| | import os |
| | import torch |
| | import gradio as gr |
| | from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer |
| | from threading import Thread |
| |
|
| | |
# --- Configuration -------------------------------------------------------
MODEL_ID = "google/gemma-3-270m-it"
HF_TOKEN = os.environ.get("HF_TOKEN")  # access token for gated models, if set

print("--- [1] Loading Assets ---")

# Tokenizer and model are loaded once at module import and shared by every
# chat request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    device_map="cpu",            # CPU-only deployment
    torch_dtype=torch.bfloat16,  # halves memory vs. float32
    low_cpu_mem_usage=True,      # stream weights in to limit peak RSS
)

print("--- [2] Model Ready ---")
| |
|
def chat(message, history):
    """Stream a model reply to *message*, conditioned on the chat history.

    Args:
        message: The latest user message as a plain string.
        history: Prior turns as a list of ``{"role", "content"}`` dicts
            (Gradio ``type="messages"`` format).

    Yields:
        str: The accumulated reply text, growing as tokens are generated.
    """
    # Fix: the original tokenized the raw message, which both ignored the
    # conversation history and skipped the Gemma chat template that the
    # "-it" (instruction-tuned) checkpoint expects.
    messages = list(history) + [{"role": "user", "content": message}]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # open the assistant turn for the model
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to("cpu")

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
    )

    # generate() blocks until completion, so run it in a worker thread while
    # this generator consumes the streamer.
    thread = Thread(target=model.generate, kwargs=kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer

    # Don't leave the worker dangling once the stream is exhausted.
    thread.join()
| |
|
| | |
# Wire the streaming generator into Gradio's chat UI; "messages" makes the
# history arrive as a list of {"role", "content"} dicts.
demo = gr.ChatInterface(fn=chat, type="messages")


def _serve():
    """Launch the Gradio server on all interfaces, port 7860."""
    print("--- [3] Launching on Port 7860 ---")
    demo.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    _serve()