File size: 4,437 Bytes
585b80d
 
20122ba
21a0a83
0659245
 
20122ba
b69009c
0659245
54c46af
0659245
 
 
 
 
 
 
 
 
 
 
 
21a0a83
20122ba
 
 
aa2761b
b69009c
54c46af
20122ba
54c46af
 
20122ba
 
54c46af
20122ba
 
54c46af
b69009c
20122ba
21a0a83
20122ba
efb0bd0
 
54c46af
21a0a83
 
 
54c46af
 
 
efb0bd0
b69009c
20122ba
 
 
 
54c46af
 
 
21a0a83
54c46af
21a0a83
54c46af
21a0a83
 
 
 
54c46af
21a0a83
 
 
 
 
 
 
 
 
54c46af
 
 
 
 
 
 
21a0a83
 
54c46af
 
21a0a83
54c46af
0ded455
20122ba
54c46af
20122ba
54c46af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20122ba
 
0659245
 
 
 
 
 
 
 
 
 
 
 
 
 
21a0a83
 
 
0659245
 
21a0a83
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import gradio as gr
import torch
import re
import warnings
import sys
import os
from transformers import AutoTokenizer, AutoModelForCausalLM

# Suppress all warnings including asyncio
warnings.filterwarnings("ignore")
# Env var so subprocesses / reloaded workers inherit the suppression too.
os.environ['PYTHONWARNINGS'] = 'ignore'

# Redirect stderr to suppress asyncio exceptions
# Redirect stderr to suppress asyncio exceptions
class SuppressStderr:
    """Context manager that temporarily routes ``sys.stderr`` to the null device.

    On entry the current stderr stream is stashed and replaced with an open
    handle to ``os.devnull``; on exit the devnull handle is closed and the
    stashed stream is restored.
    """

    def __enter__(self):
        self._original_stderr = sys.stderr
        sys.stderr = open(os.devnull, 'w')
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stderr.close()
        sys.stderr = self._original_stderr

# ======================
# Load model
# ======================
# Hugging Face repo id; a 270M-parameter model is small enough for CPU-only Spaces.
MODEL_ID = "google/gemma-3-270m"

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print("Loading model...")
# float32 on CPU: no GPU is available here, and full precision avoids the
# dtype issues half precision can hit on CPU backends.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    device_map="cpu"
)
print("Model loaded successfully!")

# ======================
# Clean output
# ======================
def clean_output(text):
    """Trim a raw model reply down to its first complete sentence.

    Collapses immediately repeated substrings (a common degeneration mode
    of small LMs), then keeps only the first sentence, appending a period
    when the sentence lacks terminal punctuation.

    Args:
        text: Raw decoded model output.

    Returns:
        The cleaned first sentence, or "" when the input is empty/blank.
    """
    text = text.strip()
    # Collapse runs like "abcabcabc" -> "abc" (units of 10+ chars repeated).
    text = re.sub(r'(.{10,}?)\1+', r'\1', text)

    # Keep only the first sentence. Note re.split always yields at least
    # one element, so we must also check it is non-empty: previously an
    # empty input produced a bare "." instead of "".
    sentences = re.split(r'[.!?]\s+', text)
    if sentences and sentences[0]:
        first = sentences[0]
        return first if first.endswith(('.', '!', '?')) else first + '.'
    return text

# ======================
# Chat function
# ======================
def chat(message, history):
    """Generate a single-turn reply from the model.

    Args:
        message: The user's input text.
        history: Prior (user, bot) turns; accepted for the Gradio callback
            signature but not currently folded into the prompt.

    Returns:
        The cleaned model reply, or a human-readable error/fallback string.
        Never raises: any generation failure is reported as text.
    """
    if not message or not message.strip():
        return "Please enter a message."

    try:
        # Gemma chat template markers, written out by hand.
        prompt = f"<bos><start_of_turn>user\n{message}\n<end_of_turn>\n<start_of_turn>model\n"

        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=150,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.2
            )

        # Bug fix: decode only the newly generated tokens. The previous
        # code split the full decoded text on the literal word "model",
        # which truncated any reply that itself contained "model".
        prompt_len = inputs["input_ids"].shape[1]
        reply = tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        ).strip()

        reply = clean_output(reply)

        return reply if reply else "I couldn't generate a response. Please try again."

    except Exception as e:
        # Surface failures to the UI rather than crashing the callback.
        return f"Error generating response: {str(e)}"

# ======================
# UI with proper examples
# ======================
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 Gemma3 270M Cloud Chat")
    gr.Markdown("Gemma3 270M running on Hugging Face Spaces")
    
    # NOTE(review): tuple-style Chatbot history is the legacy format; newer
    # Gradio versions prefer type="messages" — confirm the pinned version.
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(
        label="Your message",
        placeholder="Type your message here...",
        lines=2
    )
    
    with gr.Row():
        submit = gr.Button("Send", variant="primary")
        clear = gr.Button("Clear")
    
    gr.Markdown("### Try these examples:")
    with gr.Row():
        example1 = gr.Button("Hi, how are you?", size="sm")
        example2 = gr.Button("What is AI?", size="sm")
        example3 = gr.Button("Write hello world in Python", size="sm")
    
    # Chat interaction
    def respond(message, chat_history):
        """Run one chat turn: append (user, bot) to history, clear the textbox."""
        bot_message = chat(message, chat_history)
        chat_history.append((message, bot_message))
        # First output ("") clears the input textbox; second updates the chat.
        return "", chat_history
    
    # Enter key and the Send button trigger the same handler.
    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    submit.click(respond, [msg, chatbot], [msg, chatbot])
    # Returning None resets the chatbot component to empty.
    clear.click(lambda: None, None, chatbot, queue=False)
    
    # Example buttons only pre-fill the textbox; the user still presses Send.
    example1.click(lambda: "Hi, how are you?", None, msg)
    example2.click(lambda: "What is AI?", None, msg)
    example3.click(lambda: "Write hello world in Python", None, msg)

if __name__ == "__main__":
    import atexit
    import asyncio

    def cleanup():
        """Best-effort: stop a still-running asyncio loop at interpreter exit.

        Gradio's server can leave a loop running; stopping it avoids noisy
        shutdown tracebacks. Purely opportunistic — failure is ignored.
        """
        try:
            # get_running_loop() raises RuntimeError when no loop is active;
            # the deprecated get_event_loop() would create one as a side effect.
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return
        # Bug fix: the previous bare `except:` also swallowed SystemExit /
        # KeyboardInterrupt; now only the expected RuntimeError is handled.
        if loop.is_running():
            loop.stop()

    atexit.register(cleanup)

    demo.launch(
        server_name="0.0.0.0",  # bind all interfaces (required on Spaces)
        server_port=7860,       # Hugging Face Spaces' expected port
        share=False,
        quiet=True  # Suppress Gradio startup messages
    )