zjy2001 committed on
Commit
5a0b36f
·
verified ·
1 Parent(s): 5be7173

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +168 -0
README.md CHANGED
@@ -26,9 +26,177 @@ Z1: Efficient Test-time Scaling with Code
26
  <!-- <a href="#%EF%B8%8F-citation">Citation</a> -->
27
  </p>
28
 
 
 
29
  ## Model Details
30
  To begin with the shifted thinking mode, please refer to https://github.com/efficientscaling/Z1.
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  ## Evaluation
34
 
 
26
  <!-- <a href="#%EF%B8%8F-citation">Citation</a> -->
27
  </p>
28
 
29
+
30
+
31
  ## Model Details
32
  To begin with the shifted thinking mode, please refer to https://github.com/efficientscaling/Z1.
33
 
34
+ ## Gradio Demo
35
+
36
+ ```python
37
+ import copy
38
+ from typing import List
39
+ from dataclasses import dataclass
40
+
41
+ import gradio as gr
42
+ from vllm import LLM, SamplingParams
43
+ from transformers import AutoTokenizer
44
+
45
# Marker for a LaTeX boxed final answer (raw string: keeps the literal backslash).
BOX = r"\boxed{}"

# Suffixes appended to a truncated "thinking" trajectory to force a final answer,
# with and without asking for a boxed result.
ANSWER_WITH_BOX = f"\n\nI overthought it, the final answer in {BOX} should be:\n\n"
ANSWER_WITHOUT_BOX = "\n\nI overthought it, the final answer should be:\n\n"

# Hugging Face model identifier loaded by vLLM below.
model_name = "efficientscaling/Z1-7B"
50
+
51
@dataclass
class ThinkingLLM(LLM):
    """vLLM wrapper adding a two-phase "shifted thinking" generation mode.

    Phase 1 generates with a reduced token budget; any trajectory cut off by
    that budget is re-prompted with a short "final answer" suffix and completed
    in phase 2 under the full budget.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent ``vllm.LLM`` constructor."""
        super().__init__(*args, **kwargs)

    def thinking_generate(self, prompts: List[str], sampling_params: SamplingParams = None, max_tokens_for_thinking: int = None):
        """Generate with a capped "thinking" budget, then finish cut-off answers.

        Args:
            prompts: Input prompts for the LLM.
            sampling_params: Generation settings; its ``max_tokens`` is the
                overall budget. Must not be None.
            max_tokens_for_thinking: Token budget for the first (thinking)
                pass; trajectories truncated at this budget are re-prompted.

        Returns:
            Tuple ``(new_prompts, answers)``: the re-prompted strings for
            truncated trajectories, and the per-input outputs with truncated
            ones replaced by their completed answers.

        Raises:
            ValueError: If ``sampling_params`` is None.
        """
        if sampling_params is None:
            raise ValueError("Sampling_params can't be None!")

        # Remember the overall budget, then cap the first pass at the
        # thinking budget.
        all_max_tokens = sampling_params.max_tokens
        sampling_params.max_tokens = max_tokens_for_thinking
        print(f"All tokens: {all_max_tokens}")
        print(f"Tokens for thinking: {max_tokens_for_thinking}")

        trajectories = self.generate(prompts, sampling_params)

        rethinking_str = ANSWER_WITHOUT_BOX
        # Restore the full budget for the answer pass.
        sampling_params.max_tokens = all_max_tokens

        answers = copy.deepcopy(trajectories)

        unfinished_id = []
        thinking_token = 0
        new_prompts = []

        # Collect trajectories that hit the thinking budget (stopped with
        # finish_reason 'length') and build follow-up prompts that append the
        # partial trajectory plus the "final answer" suffix.
        # NOTE: loop index renamed from `id` (shadowed the builtin).
        for idx, traj in enumerate(trajectories):
            if traj.outputs[0].finish_reason == 'length':
                unfinished_id.append(idx)
                new_prompts.append(prompts[idx] + traj.outputs[0].text + rethinking_str)
            thinking_token += len(traj.outputs[0].token_ids)
        # (The original computed an unused per-prompt average here, which
        # raised ZeroDivisionError for empty `prompts`; it has been removed.)

        if new_prompts:
            print(new_prompts[0])

            # Second pass: complete only the truncated trajectories under the
            # full budget. Guarded so we never call generate([]).
            o = self.generate(
                new_prompts,
                sampling_params=sampling_params,
            )

            # Replace each truncated trajectory with its completed answer.
            for i, uid in enumerate(unfinished_id):
                answers[uid] = o[i]

        return new_prompts, answers
119
+
120
+
121
def generate_text(prompt, max_tokens, max_tokens_for_thinking, temperature, top_p):
    """Run shifted-thinking generation for the Gradio UI.

    Args:
        prompt: User prompt string.
        max_tokens: Overall generation budget (tokens).
        max_tokens_for_thinking: Budget for the first (thinking) pass.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The re-prompted trajectory plus the final answer text, or just the
        answer text when the thinking pass finished within budget.
    """
    sampling_params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        skip_special_tokens=False,
    )

    # thinking_generate expects a *list* of prompts; passing the raw string
    # (as the original did) makes its `prompts[idx]` index single characters
    # when it rebuilds the follow-up prompt.
    trajectories, outputs = llm.thinking_generate(
        [prompt], sampling_params, max_tokens_for_thinking=max_tokens_for_thinking
    )

    answer = outputs[0].outputs[0].text
    return trajectories[0] + '\n\n' + answer if trajectories else answer
133
+
134
+
135
# Load the model once at module scope so the UI reuses a single engine.
llm = ThinkingLLM(
    model=model_name,
    tensor_parallel_size=1,
    gpu_memory_utilization=0.96,
)


with gr.Blocks() as demo:
    gr.Markdown("# Reason with shifted thinking")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Input",
                lines=5,
            )
            max_tokens_for_thinking_input = gr.Slider(
                label="shifted_thinking_window_size",
                minimum=1,
                maximum=32768,  # was 32786 — typo for the power-of-two cap
                value=4000,
                step=1,
            )
            max_tokens_input = gr.Slider(
                label="all_max_tokens",
                minimum=1,
                maximum=32768,  # was 32786 — typo for the power-of-two cap
                value=32768,
                step=1,
            )
            temperature_input = gr.Slider(
                label="Temperature",
                minimum=0.0,  # was the sloppy literal `00`
                maximum=2.0,
                value=0,
                step=0.1,
            )
            top_p_input = gr.Slider(
                label="Top-p",
                minimum=0.0,
                maximum=1.0,
                value=1,
                step=0.01,
            )
            generate_button = gr.Button("Generate")

        with gr.Column():
            output_text = gr.Textbox(
                label="Shifted Thinking Window",
                placeholder="Text is here...",
                lines=10,
            )

    # BUGFIX: inputs must match generate_text(prompt, max_tokens,
    # max_tokens_for_thinking, temperature, top_p). The original listed the
    # thinking-budget slider before the overall-budget slider, silently
    # swapping the two values.
    generate_button.click(
        fn=generate_text,
        inputs=[prompt_input, max_tokens_input, max_tokens_for_thinking_input, temperature_input, top_p_input],
        outputs=output_text,
    )


if __name__ == "__main__":
    demo.launch()
199
+ ```
200
 
201
  ## Evaluation
202