FrostIce committed on
Commit
a761dfe
·
verified ·
1 Parent(s): 4d21fdb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -15
app.py CHANGED
@@ -5,11 +5,9 @@ import re
5
  import numpy as np
6
  import json
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 
8
  import gradio as gr
9
 
10
- # Применяем патч для streaming
11
- patch_streaming()
12
-
13
  # === Настройки ===
14
  model_name = "Qwen/Qwen2.5-0.5B-Instruct"
15
 
@@ -85,21 +83,25 @@ def generate_stream(prompt):
85
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
86
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
87
 
88
- def generate():
89
- with torch.no_grad():
90
- model.generate(
91
- **inputs,
92
- max_new_tokens=256,
93
- temperature=0.7,
94
- streamer=streamer,
95
- use_cache=True
96
- )
97
 
98
- thread = threading.Thread(target=generate)
 
 
 
 
99
  thread.start()
100
 
101
- for text in streamer:
102
- yield text
 
 
103
 
104
  # === Gradio ===
105
  with gr.Blocks() as demo:
 
5
  import numpy as np
6
  import json
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
8
+ from transformers_stream_generator import StreamingGenerationConfig
9
  import gradio as gr
10
 
 
 
 
11
  # === Настройки ===
12
  model_name = "Qwen/Qwen2.5-0.5B-Instruct"
13
 
 
83
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
84
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
85
 
86
+ # Используем специальную конфигурацию
87
+ generation_config = StreamingGenerationConfig(
88
+ max_new_tokens=256,
89
+ temperature=0.7,
90
+ top_p=0.9,
91
+ do_sample=True
92
+ )
 
 
93
 
94
+ thread = threading.Thread(target=model.generate, kwargs={
95
+ **inputs,
96
+ "streamer": streamer,
97
+ "generation_config": generation_config
98
+ })
99
  thread.start()
100
 
101
+ buffer = ""
102
+ for token in streamer:
103
+ buffer += token
104
+ yield buffer
105
 
106
  # === Gradio ===
107
  with gr.Blocks() as demo: