import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
# Gemma checkpoints are gated on the Hugging Face Hub, so load them with an
# access token (e.g. set HF_TOKEN in the environment or as a Space secret).
HF_TOKEN = os.getenv("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-270m", token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained("google/gemma-3-270m", token=HF_TOKEN)
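# Note: google/gemma-3-270m is the base (pretrained) checkpoint. For a chat
# UI, the instruction-tuned variant google/gemma-3-270m-it typically follows
# the chat template more faithfully; swapping it in is a suggested, untested
# change.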
|
|
def predict(message, history):
    # ChatInterface passes history as (user, bot) pairs; rebuild it in the
    # role/content message format the chat template expects.
    messages = []
    for user, bot in history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": bot})
    messages.append({"role": "user", "content": message})

    # apply_chat_template inserts Gemma's turn markers and appends the
    # generation prompt so the model continues as the assistant.
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=512,
            do_sample=True,
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    response = tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)
    yield response
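# predict above yields exactly once, so the UI only updates when generation
# finishes. Below is an optional, untested sketch of a streaming variant using
# transformers' TextIteratorStreamer: generate() runs on a worker thread and
# pushes decoded text into the streamer, and each yielded prefix is rendered
# incrementally by ChatInterface. The name predict_streaming is ours; swap it
# into gr.ChatInterface below to enable token-by-token output.
def predict_streaming(message, history):
    from threading import Thread  # local imports keep this sketch self-contained
    from transformers import TextIteratorStreamer

    messages = []
    for user, bot in history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": bot})
    messages.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )

    # skip_prompt drops the echoed input from the stream.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    Thread(
        target=model.generate,
        kwargs=dict(input_ids=input_ids, max_new_tokens=512, do_sample=True, streamer=streamer),
    ).start()

    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial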
|
|
if __name__ == "__main__":
    gr.ChatInterface(predict).queue().launch()
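
# Typical local run (assuming this file is saved as app.py):
#   HF_TOKEN=<your-token> python app.py
# Gradio serves the UI on http://127.0.0.1:7860 by default.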