File size: 1,039 Bytes
ead9956
2b90e82
 
 
 
 
ead9956
ac7f408
ead9956
 
2b90e82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os
import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face access token, read from the environment so it is never
# hard-coded; Gemma is a gated model, so an authorized token is required
# to download the weights.  NOTE(review): if HF_TOKEN is unset this is
# None and the from_pretrained calls below will fail for gated repos.
HF_TOKEN = os.getenv("HF_TOKEN")

# Load the tokenizer and the Gemma 3 270M causal LM once at module import
# time, so every chat request reuses the same in-memory model instead of
# re-downloading/re-initializing per call.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-270m", token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained("google/gemma-3-270m", token=HF_TOKEN)


def predict(message, history):
    """Generate one chat reply from the Gemma model.

    Args:
        message: The newest user message (str).
        history: Prior turns as a list of (user, bot) string pairs
            (Gradio tuples-format chat history).

    Yields:
        str: The model's reply to *message*, conditioned on the full
        conversation so far.
    """
    # Flatten the paired history plus the new message into ONE prompt
    # string.  The original code tokenized the list as a *batch*
    # (padding="longest"), which made generate() produce an independent
    # continuation per message instead of a single reply conditioned on
    # the whole conversation.
    turns = []
    for user_msg, bot_msg in history:
        turns.append(user_msg)
        turns.append(bot_msg)
    turns.append(message)
    prompt = "\n".join(turns)

    encoded = tokenizer(prompt, return_tensors="pt")
    prompt_len = encoded["input_ids"].shape[1]

    with torch.no_grad():  # inference only — skip autograd bookkeeping
        output_ids = model.generate(
            encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            # Bound the *reply* length; max_length capped prompt+reply
            # together, silently truncating long conversations.
            max_new_tokens=256,
            do_sample=True,
            # Dropped early_stopping=False: it only applies to beam
            # search and emits a warning under sampling.
        )

    # Decode only the newly generated tokens — decoding the whole
    # sequence would echo the prompt back to the user.
    response = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
    yield response


if __name__ == "__main__":
    # Wrap predict() in a Gradio chat UI; queue() enables request
    # queuing so concurrent users are served in order.
    gr.ChatInterface(predict).queue().launch()