import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Hugging Face Hub checkpoint to load (small enough to run on CPU)
MODEL_REPO = "google/flan-t5-small"
print("Loading FLAN-T5 Small model...") |
|
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO) |
|
|
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_REPO) |
|
|
model.eval() |
|
|
print("Model loaded") |
|
|
|
|
|
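
# Optional (an assumption, not part of the original app): pinning the PyTorch
# CPU thread count can make latency more predictable on small CPU hosts.
# torch.set_num_threads(4)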


def chat(prompt):
    """Generate one reply to `prompt` with FLAN-T5 and return it as text."""
    if not prompt.strip():
        return "Type a message first"
    # Tokenize the prompt; anything beyond 256 tokens is truncated.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    )
    # Sample a reply with nucleus sampling; no gradients needed for inference.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_length=128,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)
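
# Quick local sanity check (illustrative only; the prompt below is made up):
#   print(chat("Translate English to German: How old are you?"))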


# Minimal Blocks UI: one input box, one output box, with chat() wired to both
# the Send button and the Enter key in the input textbox.
with gr.Blocks(title="SmallGPT CPU") as demo:
    gr.Markdown("# SmallGPT CPU Chat\nPowered by google/flan-t5-small (CPU only)")

    user_input = gr.Textbox(label="Your message", lines=2)
    bot_output = gr.Textbox(label="SmallGPT says", lines=4)

    send_btn = gr.Button("Send")
    send_btn.click(chat, inputs=user_input, outputs=bot_output)
    user_input.submit(chat, inputs=user_input, outputs=bot_output)

demo.launch()