| | import gradio as gr |
| | import torch |
| | import re |
| | import warnings |
| | import sys |
| | import os |
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| |
|
| | |
# Silence library warnings in this process and in any child interpreters
# (PYTHONWARNINGS is read by subprocesses at startup).
warnings.filterwarnings("ignore")
os.environ['PYTHONWARNINGS'] = 'ignore'
| |
|
| | |
class SuppressStderr:
    """Context manager that temporarily redirects ``sys.stderr`` to os.devnull.

    Fix over the original: ``__exit__`` restores the real stderr *before*
    closing the devnull handle. The old order closed the file while it was
    still bound to ``sys.stderr``, so a failing ``close()`` would have left
    the process with a closed stderr.
    """

    def __enter__(self):
        # Remember the real stream so it can be put back on exit.
        self._original_stderr = sys.stderr
        sys.stderr = open(os.devnull, 'w')
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        devnull = sys.stderr
        # Restore first: even if close() raises, stderr is usable again.
        sys.stderr = self._original_stderr
        devnull.close()
| |
|
| | |
| | |
| | |
# Hugging Face Hub model id for the small Gemma 3 checkpoint
# (small enough to run on the CPU of a free Space).
MODEL_ID = "google/gemma-3-270m"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print("Loading model...")
# float32 + explicit CPU placement: half-precision kernels are
# unsupported or slow on many CPUs, so full precision is the safe default.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    device_map="cpu"
)
print("Model loaded successfully!")
| |
|
| | |
| | |
| | |
def clean_output(text):
    """Collapse repeated fragments and keep only the first sentence.

    Small models loop; the regex collapses any fragment of 10+ characters
    that repeats back-to-back into a single copy. The reply is then cut at
    the first sentence boundary and given a terminal period if it lacks one.

    Fix over the original: empty input now returns "" instead of "."
    (``re.split`` on an empty string yields [''], so the old code appended
    a period to nothing; its final ``return text`` was unreachable because
    ``re.split`` never returns an empty list).
    """
    text = text.strip()
    if not text:
        return ""

    # Deduplicate immediately-repeated fragments of 10+ chars (loop output).
    text = re.sub(r'(.{10,}?)\1+', r'\1', text)

    # Keep only the first sentence; split on terminator + whitespace so a
    # single trailing "." is not split away.
    first = re.split(r'[.!?]\s+', text)[0]
    if not first:
        return ""
    return first + ('' if first.endswith(('.', '!', '?')) else '.')
| |
|
| | |
| | |
| | |
def chat(message, history):
    """Generate one model reply for *message*.

    Args:
        message: The user's input text; blank input short-circuits with a
            prompt to type something.
        history: Prior chat turns (unused by the model here; the prompt is
            single-turn).

    Returns:
        The cleaned model reply, or a human-readable error string. Never
        raises: this is the top-level boundary for the UI callback.
    """
    if not message or not message.strip():
        return "Please enter a message."

    try:
        # Gemma chat template: single user turn, then open the model turn.
        prompt = f"<bos><start_of_turn>user\n{message}\n<end_of_turn>\n<start_of_turn>model\n"

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.2
            )

        # Decode only the newly generated tokens. The original split the
        # full decode on the substring "model", which corrupted any reply
        # that itself contained the word "model".
        prompt_len = inputs["input_ids"].shape[1]
        reply = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

        reply = clean_output(reply)

        return reply if reply else "I couldn't generate a response. Please try again."

    except Exception as e:
        # UI boundary: surface the failure as text rather than crashing Gradio.
        return f"Error generating response: {str(e)}"
| |
|
| | |
| | |
| | |
# --- Gradio UI: transcript, input box, send/clear buttons, example prompts ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 Gemma3 270M Cloud Chat")
    gr.Markdown("Gemma3 270M running on Hugging Face Spaces")

    # Transcript rendered as (user, bot) message pairs.
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(
        label="Your message",
        placeholder="Type your message here...",
        lines=2
    )

    with gr.Row():
        submit = gr.Button("Send", variant="primary")
        clear = gr.Button("Clear")

    gr.Markdown("### Try these examples:")
    with gr.Row():
        example1 = gr.Button("Hi, how are you?", size="sm")
        example2 = gr.Button("What is AI?", size="sm")
        example3 = gr.Button("Write hello world in Python", size="sm")

    def respond(message, chat_history):
        # Run one model turn, append the (user, bot) pair to the history,
        # and return "" first to clear the input textbox.
        bot_message = chat(message, chat_history)
        chat_history.append((message, bot_message))
        return "", chat_history

    # Enter key and Send button share the same handler.
    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    submit.click(respond, [msg, chatbot], [msg, chatbot])
    # Clearing just resets the chatbot component; no queue needed.
    clear.click(lambda: None, None, chatbot, queue=False)

    # Example buttons only populate the textbox; the user still presses Send.
    example1.click(lambda: "Hi, how are you?", None, msg)
    example2.click(lambda: "What is AI?", None, msg)
    example3.click(lambda: "Write hello world in Python", None, msg)
| |
|
if __name__ == "__main__":
    import atexit

    def cleanup():
        """Best-effort stop of a lingering asyncio loop at interpreter exit."""
        try:
            import asyncio
            loop = asyncio.get_event_loop()
            if loop.is_running():
                loop.stop()
        except Exception:
            # Fix: the original bare `except:` also swallowed SystemExit and
            # KeyboardInterrupt; Exception is the widest net cleanup should cast.
            pass

    atexit.register(cleanup)

    # Bind on all interfaces on 7860, the standard Hugging Face Spaces port.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        quiet=True
    )