# Gemma-3-1b-it / app.py
# Hugging Face Space by broadfield-dev — "Update app.py" (commit e7c4a12, verified)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from huggingface_hub import login
from os import getenv
from threading import Thread
# --- System Prompt ---
# Prepended to every conversation before it is handed to the chat template.
SYSTEM_MESSAGE = {"role": "system", "content": "You are a helpful and friendly AI assistant."}

# Authenticate with the Hugging Face Hub only when a token is provided via the
# environment. Read the variable once (the original called getenv twice).
if hf_token := getenv("HF_TOKEN"):
    login(token=hf_token)
# --- Model and Tokenizer Loading ---
# Load the Gemma 3 1B instruction-tuned checkpoint, pinned to CPU.
model_name = "google/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# NOTE(review): bfloat16 inference on CPU assumes the host kernels support it —
# confirm on the deployment hardware.
_load_options = {
    "torch_dtype": torch.bfloat16,
    "low_cpu_mem_usage": True,
    "device_map": "cpu",
}
model = AutoModelForCausalLM.from_pretrained(model_name, **_load_options)
# --- Gradio Interface ---
# NOTE(review): indentation appears stripped in this copy — the statements
# below presumably belong inside the `with gr.Blocks(...)` context; verify
# against the original file. Component creation order defines the page layout,
# so the order of these statements is load-bearing.
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as demo:
gr.Markdown("# Gemma 3 1b-IT Chatbot (CPU-Optimized)")
gr.Markdown("Chat with local 1 billion parameter model ")
# Main conversation pane; model output is rendered as markdown.
chatbot = gr.Chatbot(label="Gemma 3 Chat", bubble_full_width=False, height=600, render_markdown=True)
with gr.Row():
# Message entry box and an explicit Send button sharing one row.
text_input = gr.Textbox(label="Your message", scale=4, placeholder="Type your message here...", container=False)
submit_button = gr.Button("Send", variant="primary", scale=1, min_width=150)
# Sampling controls, collapsed by default; values feed respond() via wiring below.
with gr.Accordion("Advanced Settings", open=False):
# Increased the max length default to better accommodate code.
max_length_slider = gr.Slider(minimum=50, maximum=4096, value=1024, label="Max New Tokens")
temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Temperature")
top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p")
top_k_slider = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k")
def format_chat_history(chat_history, new_message):
    """Convert Gradio tuple-style history into a chat-template message list.

    Produces [system, user, assistant, user, assistant, ..., new user turn].
    Empty/None assistant turns (e.g. an in-progress placeholder) are skipped.
    """
    turns = []
    for user_turn, assistant_turn in chat_history:
        turns.append({"role": "user", "content": user_turn})
        if assistant_turn:
            turns.append({"role": "assistant", "content": assistant_turn})
    return [SYSTEM_MESSAGE, *turns, {"role": "user", "content": new_message}]
def respond(message, chat_history, max_length, temperature, top_p, top_k):
    """Stream a model reply into the chat, yielding the updated history.

    Generator used as a Gradio event handler: generation runs on a background
    thread while the main thread drains the streamer and yields progressively
    longer histories so the UI updates token-by-token.

    Args:
        message: The user's new message.
        chat_history: Gradio tuple-style history, list of (user, assistant).
        max_length: Max new tokens to generate.
        temperature / top_p / top_k: Sampling parameters.
    Yields:
        The mutated chat_history after each streamed chunk.
    """
    messages = format_chat_history(chat_history, message)
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    # FIX: skip_special_tokens=True keeps control tokens (e.g. end-of-turn
    # markers) out of the visible reply. The previous manual comparison
    # `new_text != tokenizer.eos_token` could miss an EOS embedded mid-chunk
    # and returned early without draining the streamer.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    # Generation parameters
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "streamer": streamer,
        "max_new_tokens": int(max_length),
        "temperature": float(temperature),
        "top_p": float(top_p),
        "top_k": int(top_k),
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
    }
    # Run generate() off the main thread so we can consume the streamer here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    chat_history.append((message, ""))
    # Show the user's turn immediately, before the first token arrives.
    yield chat_history
    accumulated_text = ""
    for new_text in streamer:
        accumulated_text += new_text
        chat_history[-1] = (message, accumulated_text)
        yield chat_history
    # Streamer is exhausted only after generate() finishes; join for cleanliness.
    thread.join()
# Shared wiring: both pressing Enter in the textbox and clicking Send invoke
# the same streaming handler, then clear the input box once it finishes.
handler_inputs = [text_input, chatbot, max_length_slider, temperature_slider, top_p_slider, top_k_slider]
for trigger in (text_input.submit, submit_button.click):
    trigger(respond, handler_inputs, chatbot).then(lambda: gr.update(value=""), outputs=text_input)
demo.launch(debug=True)