# slm / app.py — Qwen3 reasoning chat demo (Gradio)
# Origin: Hugging Face Space file (author: OrbitMC, commit eb4f3c5, verified)
import gradio as gr
import torch
import re
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from duckduckgo_search import DDGS
from threading import Thread
# --- MODEL SETUP ---
# Module-level, one-time load: runs at import time and downloads the
# checkpoint on first start. device_map="auto" places weights on an
# accelerator when one is available; torch_dtype="auto" takes the dtype
# recorded in the checkpoint config.
MODEL_ID = "Qwen/Qwen3-0.6B" # Official HF Repo
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True
)
# --- SEARCH FUNCTION ---
def web_search(query):
    """Fetch up to three DuckDuckGo text results for *query*.

    Returns a "\\n\\nSearch Results:\\n- title: body\\n..." block suitable
    for appending to the user prompt, or "" when there are no results or
    the search fails for any reason (best-effort augmentation: a network
    or rate-limit error must not crash the chat turn).
    """
    try:
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=3))
        if not results:
            return ""
        # Join instead of += in a loop; trailing "\n" matches each entry
        # ending with a newline.
        entries = [f"- {r['title']}: {r['body']}" for r in results]
        return "\n\nSearch Results:\n" + "\n".join(entries) + "\n"
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt /
        # SystemExit — narrowed to Exception, same "no context" fallback.
        return ""
# --- UI HELPERS ---
# Custom stylesheet injected into gr.Blocks(css=CSS): styles the
# collapsible "thought process" panel that parse_output() emits
# (.thought-box) and its clickable <summary> header.
CSS = """
.thought-box {
background-color: rgba(255, 255, 255, 0.05);
border-left: 4px solid #facc15;
padding: 10px;
margin: 10px 0;
font-style: italic;
color: #9ca3af;
}
details summary {
cursor: pointer;
color: #facc15;
font-weight: bold;
}
"""
def parse_output(text):
    """Convert raw model output containing <think> tags into HTML.

    Returns the text unchanged when no <think> tag is present. Otherwise
    the reasoning segment is wrapped in a collapsible <details> panel;
    once a closing </think> is seen, the final answer follows the panel.
    """
    if "<think>" not in text:
        return text

    segments = text.split("</think>")
    reasoning = segments[0].replace("<think>", "").strip()

    if len(segments) > 1:
        # Closing tag reached: show the finished thought plus the answer.
        answer = segments[1].strip()
        return (
            "<details open><summary>πŸ’­ Thought Process</summary>"
            f"<div class='thought-box'>{reasoning}</div></details>\n\n{answer}"
        )

    # No closing tag yet: the model is still mid-reasoning.
    return (
        "<details open><summary>πŸŒ€ Thinking...</summary>"
        f"<div class='thought-box'>{reasoning}</div></details>"
    )
# --- GENERATION LOGIC ---
def chat(message, history, search_enabled, temperature, max_tokens):
    """Stream a model reply for the latest user message.

    Args:
        message: latest user message (string).
        history: Gradio tuple-style history — list of [user, assistant] pairs.
        search_enabled: when True, append DuckDuckGo results to the prompt.
        temperature: sampling temperature (UI slider keeps it >= 0.1).
        max_tokens: cap passed to generate() as max_new_tokens.

    Yields progressively longer UI-formatted strings (see parse_output).
    """
    # 1. Optionally augment the prompt with live web-search context.
    search_context = web_search(message) if search_enabled else ""

    # 2. Rebuild the conversation in ChatML form via the official chat
    #    template (fixes the model "talking to itself").
    conversation = []
    for user_msg, assistant_msg in history:
        conversation.append({"role": "user", "content": user_msg})
        if assistant_msg:
            # Strip the <details> thought-panel the UI wrapped around the
            # reply before feeding the turn back to the model.
            clean_assistant = re.sub(
                r'<details.*?</details>', '', assistant_msg, flags=re.DOTALL
            ).strip()
            conversation.append({"role": "assistant", "content": clean_assistant})
    conversation.append({"role": "user", "content": message + search_context})

    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    # 3. Generate on a worker thread; stream decoded text back here.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        # Stop generating once the model tries to start a new turn.
        eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|im_end|>")]
    )
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        # Safety net: if the model hallucinates a new 'User' turn, stop
        # streaming — but keep any legitimate text emitted before the
        # marker (previously the whole chunk was silently dropped).
        if "User:" in new_text or "<|im_start|>" in new_text:
            prefix = new_text.split("User:")[0].split("<|im_start|>")[0]
            if prefix:
                buffer += prefix
                yield parse_output(buffer)
            break
        buffer += new_text
        yield parse_output(buffer)
# --- GRADIO UI ---
# Declarative layout: components created inside the Blocks context are
# auto-attached to `demo`; event wiring happens at the bottom.
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1>🧠 Qwen3 Reasoning Lab</h1>")
    with gr.Row():
        with gr.Column(scale=4):
            chat_box = gr.Chatbot(height=600, label="Qwen3-0.6B")
            msg_input = gr.Textbox(placeholder="Ask a logic question...", show_label=False)
        with gr.Column(scale=1):
            search_toggle = gr.Checkbox(label="🌐 Web Search (DDG)", value=False)
            temp_slider = gr.Slider(0.1, 1.0, 0.7, label="Temperature")
            token_slider = gr.Slider(512, 4096, 1024, label="Max Tokens")
            gr.Markdown("""
### Tips:
- **Thinking:** This model is trained for Chain-of-Thought.
- **Self-Talk Fix:** We use stop sequences to prevent the AI from acting as 'User'.
""")
            clear_btn = gr.Button("πŸ—‘ Clear Chat")
    # Set up logic
    # Step 1 (queue=False, instant): append [message, None] to the chatbot.
    # NOTE: the lambda returns `x` back into msg_input on purpose — the
    # chained .then() reads the message from msg_input, so clearing it
    # here would hand chat() an empty string.
    chat_event = msg_input.submit(
        lambda x, y: (x, y + [[x, None]]),
        [msg_input, chat_box],
        [msg_input, chat_box],
        queue=False
    ).then(
        # Step 2: stream the model reply into the chatbot.
        chat,
        [msg_input, chat_box, search_toggle, temp_slider, token_slider],
        chat_box
    )
    # Reset the chatbot to empty without queueing.
    clear_btn.click(lambda: None, None, chat_box, queue=False)
if __name__ == "__main__":
    # Bind all interfaces on 7860 — the port Hugging Face Spaces /
    # container deployments expect.
    demo.launch(server_name="0.0.0.0", server_port=7860)