Yifei Wang committed on
Commit ·
4e6fd45
1
Parent(s): 6d31849
fixed seed bugs
Browse files
src/numen_scriptorium/inference/qwen.py
CHANGED
|
@@ -97,10 +97,8 @@ def generate(
|
|
| 97 |
prompt = f"指令:{instruction}\n回答:"
|
| 98 |
|
| 99 |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 100 |
-
generator = None
|
| 101 |
if seed is not None:
|
| 102 |
-
|
| 103 |
-
generator.manual_seed(seed)
|
| 104 |
|
| 105 |
with torch.no_grad():
|
| 106 |
outputs = model.generate(
|
|
@@ -110,7 +108,6 @@ def generate(
|
|
| 110 |
temperature=temperature,
|
| 111 |
top_p=top_p,
|
| 112 |
eos_token_id=tokenizer.eos_token_id,
|
| 113 |
-
generator=generator,
|
| 114 |
)
|
| 115 |
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 116 |
if "回答:" in text:
|
|
@@ -136,10 +133,6 @@ def stream_generate(
|
|
| 136 |
prompt = f"指令:{instruction}\n回答:"
|
| 137 |
|
| 138 |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 139 |
-
generator = None
|
| 140 |
-
if seed is not None:
|
| 141 |
-
generator = torch.Generator(device=inputs["input_ids"].device)
|
| 142 |
-
generator.manual_seed(seed)
|
| 143 |
|
| 144 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 145 |
class _EventStoppingCriteria(StoppingCriteria):
|
|
@@ -158,18 +151,30 @@ def stream_generate(
|
|
| 158 |
eos_token_id=tokenizer.eos_token_id,
|
| 159 |
streamer=streamer,
|
| 160 |
)
|
| 161 |
-
if generator is not None:
|
| 162 |
-
generate_kwargs["generator"] = generator
|
| 163 |
if stop_event is not None:
|
| 164 |
generate_kwargs["stopping_criteria"] = StoppingCriteriaList([_EventStoppingCriteria(stop_event)])
|
| 165 |
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
worker.start()
|
| 168 |
for new_text in streamer:
|
| 169 |
if stop_event is not None and stop_event.is_set():
|
| 170 |
break
|
| 171 |
yield new_text
|
| 172 |
worker.join(timeout=0.5)
|
|
|
|
|
|
|
| 173 |
|
| 174 |
|
| 175 |
def get_model_device(model) -> str:
|
|
|
|
| 97 |
prompt = f"指令:{instruction}\n回答:"
|
| 98 |
|
| 99 |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
|
|
|
| 100 |
if seed is not None:
|
| 101 |
+
set_seed(int(seed))
|
|
|
|
| 102 |
|
| 103 |
with torch.no_grad():
|
| 104 |
outputs = model.generate(
|
|
|
|
| 108 |
temperature=temperature,
|
| 109 |
top_p=top_p,
|
| 110 |
eos_token_id=tokenizer.eos_token_id,
|
|
|
|
| 111 |
)
|
| 112 |
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 113 |
if "回答:" in text:
|
|
|
|
| 133 |
prompt = f"指令:{instruction}\n回答:"
|
| 134 |
|
| 135 |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 138 |
class _EventStoppingCriteria(StoppingCriteria):
|
|
|
|
| 151 |
eos_token_id=tokenizer.eos_token_id,
|
| 152 |
streamer=streamer,
|
| 153 |
)
|
|
|
|
|
|
|
| 154 |
if stop_event is not None:
|
| 155 |
generate_kwargs["stopping_criteria"] = StoppingCriteriaList([_EventStoppingCriteria(stop_event)])
|
| 156 |
|
| 157 |
+
worker_error: list[Exception] = []
|
| 158 |
+
|
| 159 |
+
def _run_generate():
|
| 160 |
+
try:
|
| 161 |
+
if seed is not None:
|
| 162 |
+
set_seed(int(seed))
|
| 163 |
+
model.generate(**generate_kwargs)
|
| 164 |
+
except Exception as exc:
|
| 165 |
+
worker_error.append(exc)
|
| 166 |
+
# Ensure streamer consumer can exit even if generation fails early.
|
| 167 |
+
streamer.end()
|
| 168 |
+
|
| 169 |
+
worker = Thread(target=_run_generate)
|
| 170 |
worker.start()
|
| 171 |
for new_text in streamer:
|
| 172 |
if stop_event is not None and stop_event.is_set():
|
| 173 |
break
|
| 174 |
yield new_text
|
| 175 |
worker.join(timeout=0.5)
|
| 176 |
+
if worker_error:
|
| 177 |
+
raise worker_error[0]
|
| 178 |
|
| 179 |
|
| 180 |
def get_model_device(model) -> str:
|