yuhueng committed on
Commit
ffbee98
·
verified ·
1 Parent(s): 06c1b5d

fix: revert to non streaming

Browse files
Files changed (1) hide show
  1. app.py +26 -18
app.py CHANGED
@@ -28,28 +28,36 @@ def inference(prompt: str, max_tokens: int = 256) -> str:
28
  )
29
 
30
  inputs = tokenizer(text, return_tensors="pt").to("cuda")
 
 
 
 
 
 
 
 
31
 
32
- # Use TextIteratorStreamer instead of TextStreamer
33
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
34
 
35
- generation_kwargs = dict(
36
- **inputs,
37
- max_new_tokens=max_tokens,
38
- temperature=0.7,
39
- top_p=0.8,
40
- top_k=20,
41
- streamer=streamer,
42
- )
43
 
44
- # Run generation in separate thread
45
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
46
- thread.start()
47
 
48
- # Yield tokens as they come
49
- generated_text = ""
50
- for new_text in streamer:
51
- generated_text += new_text
52
- yield generated_text # yield cumulative text for Gradio
53
 
54
  demo = gr.Interface(
55
  fn=inference,
 
28
  )
29
 
30
  inputs = tokenizer(text, return_tensors="pt").to("cuda")
31
+
32
+ outputs = model.generate(
33
+ inputs,
34
+ max_new_tokens = 100, # Increase for longer outputs!
35
+ temperature = 0.7, top_p = 0.8, top_k = 20, # For non thinking
36
+ streamer = TextStreamer(tokenizer, skip_prompt = True),
37
+ )
38
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
39
 
40
+ # # Use TextIteratorStreamer instead of TextStreamer
41
+ # streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
42
 
43
+ # generation_kwargs = dict(
44
+ # **inputs,
45
+ # max_new_tokens=max_tokens,
46
+ # temperature=0.7,
47
+ # top_p=0.8,
48
+ # top_k=20,
49
+ # streamer=streamer,
50
+ # )
51
 
52
+ # # Run generation in separate thread
53
+ # thread = Thread(target=model.generate, kwargs=generation_kwargs)
54
+ # thread.start()
55
 
56
+ # # Yield tokens as they come
57
+ # generated_text = ""
58
+ # for new_text in streamer:
59
+ # generated_text += new_text
60
+ # yield generated_text # yield cumulative text for Gradio
61
 
62
  demo = gr.Interface(
63
  fn=inference,