import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from huggingface_hub import login
from os import getenv
from threading import Thread

# --- System Prompt ---
SYSTEM_MESSAGE = {"role": "system", "content": "You are a helpful and friendly AI assistant."}

# Authenticate with the Hugging Face Hub if a token is available (Gemma is a gated model).
if getenv("HF_TOKEN"):
    login(token=getenv("HF_TOKEN"))

# --- Model and Tokenizer Loading ---
model_name = "google/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,  # requires accelerate; avoids materializing a full fp32 copy while loading
    device_map="cpu",
)

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as demo:
    gr.Markdown("# Gemma 3 1B-IT Chatbot (CPU-Optimized)")
    gr.Markdown("Chat with a local 1-billion-parameter model.")

    chatbot = gr.Chatbot(label="Gemma 3 Chat", bubble_full_width=False, height=600, render_markdown=True)

    with gr.Row():
        text_input = gr.Textbox(label="Your message", scale=4, placeholder="Type your message here...", container=False)
        submit_button = gr.Button("Send", variant="primary", scale=1, min_width=150)

    with gr.Accordion("Advanced Settings", open=False):
        # Default max length is raised to better accommodate code in responses.
        max_length_slider = gr.Slider(minimum=50, maximum=4096, value=1024, label="Max New Tokens")
        temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Temperature")
        top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p")
        top_k_slider = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k")

    def format_chat_history(chat_history, new_message):
        """Convert Gradio's (user, assistant) tuple history into a chat-template messages list."""
        messages = [SYSTEM_MESSAGE]
        for user_msg, model_msg in chat_history:
            messages.append({"role": "user", "content": user_msg})
            if model_msg:
                messages.append({"role": "assistant", "content": model_msg})
        messages.append({"role": "user", "content": new_message})
        return messages

    def respond(message, chat_history, max_length, temperature, top_p, top_k):
        messages = format_chat_history(chat_history, message)
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt").to("cpu")

        # skip_special_tokens=True keeps markers such as the end-of-turn token out of
        # the streamed text, so no manual EOS filtering is needed in the loop below.
        streamer = TextIteratorStreamer(
            tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        # Generation parameters
        generation_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "streamer": streamer,
            "max_new_tokens": int(max_length),
            "temperature": float(temperature),
            "top_p": float(top_p),
            "top_k": int(top_k),
            "do_sample": True,
            "eos_token_id": tokenizer.eos_token_id,
        }

        # generate() blocks, so it runs in a background thread while this
        # generator consumes the streamer and updates the UI incrementally.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        chat_history.append((message, ""))
        accumulated_text = ""
        for new_text in streamer:
            accumulated_text += new_text
            chat_history[-1] = (message, accumulated_text)
            yield chat_history
        thread.join()

    # Wire both Enter-to-submit and the Send button to the same streaming handler,
    # then clear the input box once the event has been queued.
    submit_action = (respond, [text_input, chatbot, max_length_slider, temperature_slider, top_p_slider, top_k_slider], chatbot)
    text_input.submit(*submit_action).then(lambda: gr.update(value=""), outputs=text_input)
    submit_button.click(*submit_action).then(lambda: gr.update(value=""), outputs=text_input)

demo.launch(debug=True)
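
# Usage note (a sketch of assumptions, not guaranteed by this file): the script needs
# gradio, transformers, torch, and accelerate (for low_cpu_mem_usage=True) installed,
# e.g. `pip install gradio transformers torch accelerate`. google/gemma-3-1b-it is a
# gated checkpoint, so HF_TOKEN should point to a token whose account has accepted
# the Gemma license before launching, e.g. `HF_TOKEN=hf_xxx python app.py`
# (filename assumed).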