Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ os.system("pip install git+https://github.com/huggingface/transformers")
|
|
| 4 |
import gradio as gr
|
| 5 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
| 6 |
from threading import Thread
|
|
|
|
| 7 |
|
| 8 |
tok = AutoTokenizer.from_pretrained("distilgpt2")
|
| 9 |
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
|
|
@@ -16,10 +17,10 @@ early_stop_pattern = tok.eos_token
|
|
| 16 |
print(f'Early stop pattern = \"{early_stop_pattern}\"')
|
| 17 |
|
| 18 |
def generate(text = ""):
|
| 19 |
-
streamer = TextIteratorStreamer(tok)
|
| 20 |
if len(text) == 0:
|
| 21 |
text = " "
|
| 22 |
-
inputs = tok([text], return_tensors="pt")
|
| 23 |
generation_kwargs = dict(inputs, streamer=streamer, repetition_penalty=2.0, do_sample=True, top_k=40, top_p=0.97, max_new_tokens=128)
|
| 24 |
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
| 25 |
thread.start()
|
|
|
|
| 4 |
import gradio as gr
|
| 5 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
| 6 |
from threading import Thread
|
| 7 |
+
import torch
|
| 8 |
|
| 9 |
tok = AutoTokenizer.from_pretrained("distilgpt2")
|
| 10 |
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
|
|
|
|
| 17 |
print(f'Early stop pattern = \"{early_stop_pattern}\"')
|
| 18 |
|
| 19 |
def generate(text = ""):
|
| 20 |
+
streamer = TextIteratorStreamer(tok, timeout=10.)
|
| 21 |
if len(text) == 0:
|
| 22 |
text = " "
|
| 23 |
+
inputs = tok([text], return_tensors="pt").to(device)
|
| 24 |
generation_kwargs = dict(inputs, streamer=streamer, repetition_penalty=2.0, do_sample=True, top_k=40, top_p=0.97, max_new_tokens=128)
|
| 25 |
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
| 26 |
thread.start()
|