OrbitMC committed on
Commit
eb4f3c5
Β·
verified Β·
1 Parent(s): b673820

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +115 -45
app.py CHANGED
@@ -1,82 +1,152 @@
1
  import gradio as gr
2
  import torch
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
  from duckduckgo_search import DDGS
5
  from threading import Thread
6
 
7
- # --- MODEL CONFIG ---
8
- MODEL_ID = "Qwen/Qwen3-0.6B" # Pure HF Datacard
 
9
 
10
- print(f"Loading model {MODEL_ID}...")
11
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
12
  model = AutoModelForCausalLM.from_pretrained(
13
  MODEL_ID,
 
14
  device_map="auto",
15
- torch_dtype=torch.float16,
16
- low_cpu_mem_usage=True
17
  )
18
 
19
- # --- WEB SEARCH ---
20
- def search_web(query):
21
  try:
22
  with DDGS() as ddgs:
23
- results = [r for r in ddgs.text(query, max_results=3)]
24
  if not results: return ""
25
- context = "\n".join([f"Source: {r['title']}\nContent: {r['body']}" for r in results])
26
- return f"\n\nWeb Search Context:\n{context}\n"
27
- except Exception as e:
28
- print(f"Search error: {e}")
 
29
  return ""
30
 
31
- # --- INFERENCE ---
32
- def stream_response(message, history, search_enabled, temperature, max_new_tokens):
33
- # Prepare prompt
34
- context = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  if search_enabled:
36
- context = search_web(message)
 
 
 
 
 
 
 
 
 
 
37
 
38
- # Simple Chat Template
39
- full_prompt = f"User: {message}{context}\nAssistant:"
40
 
41
- inputs = tokenizer([full_prompt], return_tensors="pt").to(model.device)
 
 
 
 
 
 
42
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
43
 
44
- generation_kwargs = dict(
45
- inputs,
46
  streamer=streamer,
47
- max_new_tokens=max_new_tokens,
48
- do_sample=True,
49
  temperature=temperature,
50
- pad_token_id=tokenizer.eos_token_id
 
 
 
51
  )
52
 
53
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
54
  thread.start()
55
 
56
- partial_text = ""
57
  for new_text in streamer:
58
- # Handle the thinking process tags if present in output
59
- new_text = new_text.replace("<think>", "πŸ’­ *Thinking:* ").replace("</think>", "\n\n---\n\n")
60
- partial_text += new_text
61
- yield partial_text
 
 
62
 
63
- # --- CLEAN UI ---
64
- with gr.Blocks(theme=gr.themes.Default(primary_hue="orange", secondary_hue="gray")) as demo:
65
- gr.Markdown("# πŸ›Έ Qwen3 Pure-Python Explorer")
66
 
67
  with gr.Row():
68
  with gr.Column(scale=4):
69
- chatbot = gr.ChatInterface(
70
- fn=stream_response,
71
- additional_inputs=[
72
- gr.Checkbox(label="🌐 Enable Web Search", value=False),
73
- gr.Slider(0.1, 1.0, 0.7, label="Temperature"),
74
- gr.Slider(128, 4096, 1024, label="Max Tokens"),
75
- ],
76
- fill_height=True
77
- )
78
 
79
- gr.Markdown("### Features:\n- βœ… **Zero C++ / Zero llama-cpp**\n- βœ… **Native HuggingFace Transformers**\n- βœ… **DuckDuckGo Integration**")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  if __name__ == "__main__":
82
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
2
  import torch
3
+ import re
4
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
5
  from duckduckgo_search import DDGS
6
  from threading import Thread
7
 
8
# --- MODEL SETUP ---
# Model weights/tokenizer are loaded once at import time; every chat request
# shares these module-level objects.
MODEL_ID = "Qwen/Qwen3-0.6B"  # Official HF Repo
print("Loading model and tokenizer...")

# trust_remote_code lets the repo's custom tokenizer/model classes execute.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",   # use the checkpoint's native dtype
    device_map="auto",    # place weights on GPU/CPU automatically
    trust_remote_code=True
)
19
 
20
# --- SEARCH FUNCTION ---
def web_search(query):
    """Return a text blob of the top DuckDuckGo results for *query*.

    Returns "" when there are no results or the search fails, so the
    caller can proceed without web context.
    """
    try:
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=3))
        if not results:
            return ""
        blob = "\n\nSearch Results:\n"
        for r in results:
            blob += f"- {r['title']}: {r['body']}\n"
        return blob
    except Exception as e:
        # Best-effort degradation. The previous bare `except:` also caught
        # KeyboardInterrupt/SystemExit and hid every failure silently;
        # catch Exception only and log the cause.
        print(f"Search error: {e}")
        return ""
32
 
33
# --- UI HELPERS ---
# CSS injected into the Gradio Blocks app. `.thought-box` styles the
# collapsible <details> reasoning section emitted by parse_output().
CSS = """
.thought-box {
    background-color: rgba(255, 255, 255, 0.05);
    border-left: 4px solid #facc15;
    padding: 10px;
    margin: 10px 0;
    font-style: italic;
    color: #9ca3af;
}
details summary {
    cursor: pointer;
    color: #facc15;
    font-weight: bold;
}
"""
49
+
50
def parse_output(text):
    """Render a raw model reply, converting <think> tags into a
    collapsible HTML details section for the chat UI.

    Text without a <think> tag is returned untouched.
    """
    if "<think>" not in text:
        return text

    pieces = text.split("</think>")
    reasoning = pieces[0].replace("<think>", "").strip()

    if len(pieces) > 1:
        # Reasoning is complete: show it collapsed-style above the answer.
        final = pieces[1].strip()
        return f"<details open><summary>πŸ’­ Thought Process</summary><div class='thought-box'>{reasoning}</div></details>\n\n{final}"

    # No closing tag yet: the model is still mid-thought.
    return f"<details open><summary>πŸŒ€ Thinking...</summary><div class='thought-box'>{reasoning}</div></details>"
64
+
65
# --- GENERATION LOGIC ---
def chat(message, history, search_enabled, temperature, max_tokens):
    """Stream a model reply for *message*.

    Yields the full chatbot history (list of [user, assistant] pairs) so the
    gr.Chatbot output component can render it — the previous version yielded
    a bare string, which is not a valid Chatbot value.

    history: list of [user, assistant] pairs; the UI pre-appends the pending
    pair [message, None] before this generator runs.
    """
    # 1. Optional web search context, appended to the user turn.
    search_context = ""
    if search_enabled:
        search_context = web_search(message)

    # The UI already appended the pending [message, None] pair; keep only the
    # completed turns for the prompt. (The old code re-appended the current
    # message on top of it, so the prompt contained the user turn twice.)
    if history and history[-1][1] is None:
        past_turns = history[:-1]
    else:
        past_turns = history

    # 2. Build a ChatML-formatted conversation via the official template
    # (prevents the model from continuing as 'User').
    conversation = []
    for user_msg, assistant_msg in past_turns:
        conversation.append({"role": "user", "content": user_msg})
        if assistant_msg:
            # Strip the <details> UI wrapper before feeding text back.
            clean_assistant = re.sub(r'<details.*?</details>', '', assistant_msg, flags=re.DOTALL).strip()
            conversation.append({"role": "assistant", "content": clean_assistant})
    conversation.append({"role": "user", "content": message + search_context})

    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    # 3. Stream tokens off a background generation thread.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        # Stop generating once the model tries to start a new turn.
        eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|im_end|>")]
    )

    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        # Safety net: if the model hallucinates a fake user turn, stop showing it.
        if "User:" in new_text or "<|im_start|>" in new_text:
            break
        buffer += new_text
        # Yield Chatbot-format history: completed turns + the streaming pair.
        yield past_turns + [[message, parse_output(buffer)]]
116
 
117
# --- GRADIO UI ---
# Layout: chat column (left) + settings column (right); events wired below.
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1>🧠 Qwen3 Reasoning Lab</h1>")

    with gr.Row():
        with gr.Column(scale=4):
            chat_box = gr.Chatbot(height=600, label="Qwen3-0.6B")
            msg_input = gr.Textbox(placeholder="Ask a logic question...", show_label=False)

        with gr.Column(scale=1):
            search_toggle = gr.Checkbox(label="🌐 Web Search (DDG)", value=False)
            temp_slider = gr.Slider(0.1, 1.0, 0.7, label="Temperature")
            token_slider = gr.Slider(512, 4096, 1024, label="Max Tokens")
            gr.Markdown("""
            ### Tips:
            - **Thinking:** This model is trained for Chain-of-Thought.
            - **Self-Talk Fix:** We use stop sequences to prevent the AI from acting as 'User'.
            """)
            clear_btn = gr.Button("πŸ—‘ Clear Chat")

    # Set up logic
    # Step 1 (queue=False, instant): append the pending [message, None] pair
    # to the chatbot so the user's turn shows immediately.
    # NOTE(review): the lambda echoes x back into msg_input (it does NOT clear
    # the textbox). This appears deliberate — the .then() step reads `message`
    # from msg_input, so clearing it here would hand chat() an empty message.
    # Confirm before "fixing"; clearing would need a gr.State to carry the text.
    chat_event = msg_input.submit(
        lambda x, y: (x, y + [[x, None]]),
        [msg_input, chat_box],
        [msg_input, chat_box],
        queue=False
    ).then(
        chat,
        [msg_input, chat_box, search_toggle, temp_slider, token_slider],
        chat_box
    )

    # Reset the chat display (None clears a Chatbot).
    clear_btn.click(lambda: None, None, chat_box, queue=False)
150
 
151
if __name__ == "__main__":
    # Bind all interfaces on port 7860 (the port HF Spaces expects).
    demo.launch(server_name="0.0.0.0", server_port=7860)