Yifei Wang committed on
Commit
4e6fd45
·
1 Parent(s): 6d31849

fixed seed bugs

Browse files
src/numen_scriptorium/inference/qwen.py CHANGED
@@ -97,10 +97,8 @@ def generate(
97
  prompt = f"指令:{instruction}\n回答:"
98
 
99
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
100
- generator = None
101
  if seed is not None:
102
- generator = torch.Generator(device=inputs["input_ids"].device)
103
- generator.manual_seed(seed)
104
 
105
  with torch.no_grad():
106
  outputs = model.generate(
@@ -110,7 +108,6 @@ def generate(
110
  temperature=temperature,
111
  top_p=top_p,
112
  eos_token_id=tokenizer.eos_token_id,
113
- generator=generator,
114
  )
115
  text = tokenizer.decode(outputs[0], skip_special_tokens=True)
116
  if "回答:" in text:
@@ -136,10 +133,6 @@ def stream_generate(
136
  prompt = f"指令:{instruction}\n回答:"
137
 
138
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
139
- generator = None
140
- if seed is not None:
141
- generator = torch.Generator(device=inputs["input_ids"].device)
142
- generator.manual_seed(seed)
143
 
144
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
145
  class _EventStoppingCriteria(StoppingCriteria):
@@ -158,18 +151,30 @@ def stream_generate(
158
  eos_token_id=tokenizer.eos_token_id,
159
  streamer=streamer,
160
  )
161
- if generator is not None:
162
- generate_kwargs["generator"] = generator
163
  if stop_event is not None:
164
  generate_kwargs["stopping_criteria"] = StoppingCriteriaList([_EventStoppingCriteria(stop_event)])
165
 
166
- worker = Thread(target=model.generate, kwargs=generate_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
167
  worker.start()
168
  for new_text in streamer:
169
  if stop_event is not None and stop_event.is_set():
170
  break
171
  yield new_text
172
  worker.join(timeout=0.5)
 
 
173
 
174
 
175
  def get_model_device(model) -> str:
 
97
  prompt = f"指令:{instruction}\n回答:"
98
 
99
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
100
  if seed is not None:
101
+ set_seed(int(seed))
 
102
 
103
  with torch.no_grad():
104
  outputs = model.generate(
 
108
  temperature=temperature,
109
  top_p=top_p,
110
  eos_token_id=tokenizer.eos_token_id,
 
111
  )
112
  text = tokenizer.decode(outputs[0], skip_special_tokens=True)
113
  if "回答:" in text:
 
133
  prompt = f"指令:{instruction}\n回答:"
134
 
135
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
 
 
 
136
 
137
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
138
  class _EventStoppingCriteria(StoppingCriteria):
 
151
  eos_token_id=tokenizer.eos_token_id,
152
  streamer=streamer,
153
  )
 
 
154
  if stop_event is not None:
155
  generate_kwargs["stopping_criteria"] = StoppingCriteriaList([_EventStoppingCriteria(stop_event)])
156
 
157
+ worker_error: list[Exception] = []
158
+
159
+ def _run_generate():
160
+ try:
161
+ if seed is not None:
162
+ set_seed(int(seed))
163
+ model.generate(**generate_kwargs)
164
+ except Exception as exc:
165
+ worker_error.append(exc)
166
+ # Ensure streamer consumer can exit even if generation fails early.
167
+ streamer.end()
168
+
169
+ worker = Thread(target=_run_generate)
170
  worker.start()
171
  for new_text in streamer:
172
  if stop_event is not None and stop_event.is_set():
173
  break
174
  yield new_text
175
  worker.join(timeout=0.5)
176
+ if worker_error:
177
+ raise worker_error[0]
178
 
179
 
180
  def get_model_device(model) -> str: