File size: 5,165 Bytes
7ae216c
b673820
eb4f3c5
b673820
7ae216c
b673820
 
eb4f3c5
 
 
b673820
eb4f3c5
b673820
 
eb4f3c5
b673820
eb4f3c5
e7b5bc2
7ae216c
eb4f3c5
 
7ae216c
 
eb4f3c5
b673820
eb4f3c5
 
 
 
 
b673820
7ae216c
eb4f3c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ae216c
eb4f3c5
 
 
 
 
 
 
 
 
 
 
b673820
eb4f3c5
 
b673820
eb4f3c5
 
 
 
 
 
 
b673820
 
eb4f3c5
 
b673820
eb4f3c5
b673820
eb4f3c5
 
 
 
b673820
e7b5bc2
eb4f3c5
b673820
e7b5bc2
eb4f3c5
b673820
eb4f3c5
 
 
 
 
 
e7b5bc2
eb4f3c5
 
 
e7b5bc2
7ae216c
e7b5bc2
eb4f3c5
 
b673820
eb4f3c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ae216c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import gradio as gr
import torch
import re
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from duckduckgo_search import DDGS
from threading import Thread

# --- MODEL SETUP ---
# Loaded once at module import and shared (read-only) by every chat request.
MODEL_ID = "Qwen/Qwen3-0.6B" # Official HF Repo
print("Loading model and tokenizer...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",  # take dtype from the checkpoint config
    device_map="auto",   # GPU if available, otherwise CPU
    trust_remote_code=True
)

# --- SEARCH FUNCTION ---
def web_search(query):
    """Fetch up to 3 DuckDuckGo text results for *query*.

    Returns a "Search Results" context blob to append to the user message,
    or "" when there are no results or the search fails. Failures are
    deliberately best-effort: a broken search must never break the chat.
    """
    try:
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=3))
    except Exception:
        # Narrow best-effort catch (was a bare except, which also swallowed
        # KeyboardInterrupt/SystemExit). Network/rate-limit errors -> no context.
        return ""
    if not results:
        return ""
    blob = "\n\nSearch Results:\n"
    for r in results:
        # .get guards against a result missing 'title'/'body' keys.
        blob += f"- {r.get('title', '')}: {r.get('body', '')}\n"
    return blob

# --- UI HELPERS ---
# Injected via gr.Blocks(css=...). Styles the chain-of-thought box emitted
# by parse_output (.thought-box) and the clickable <details> summary toggle.
CSS = """
.thought-box {
    background-color: rgba(255, 255, 255, 0.05);
    border-left: 4px solid #facc15;
    padding: 10px;
    margin: 10px 0;
    font-style: italic;
    color: #9ca3af;
}
details summary {
    cursor: pointer;
    color: #facc15;
    font-weight: bold;
}
"""

def parse_output(text):
    """Render raw model output for the chat UI.

    Qwen3 wraps its chain-of-thought in <think>...</think>. This converts
    that wrapper into a collapsible HTML <details> section:

    - closing tag present -> "Thought Process" box followed by the answer
    - still streaming     -> "Thinking..." box with the partial thought
    - no <think> tag      -> text returned unchanged

    Splits on the FIRST "</think>" only (str.partition), so an answer that
    itself contains that literal string is no longer truncated.
    """
    if "<think>" not in text:
        return text
    thought, closed, answer = text.partition("</think>")
    thought = thought.replace("<think>", "").strip()
    if closed:
        # Finished thinking: collapsible thought box plus the final answer.
        return (
            f"<details open><summary>πŸ’­ Thought Process</summary>"
            f"<div class='thought-box'>{thought}</div></details>\n\n{answer.strip()}"
        )
    # Still thinking: only the partial thought is available so far.
    return f"<details open><summary>πŸŒ€ Thinking...</summary><div class='thought-box'>{thought}</div></details>"

# --- GENERATION LOGIC ---
def chat(message, history, search_enabled, temperature, max_tokens):
    """Stream a model reply for one user turn (Gradio generator callback).

    Args:
        message: New user message (already appended to the chatbot by the UI).
        history: Gradio chat history as [user, assistant] pairs; assistant
            entries may contain the HTML produced by parse_output.
        search_enabled: When True, append DuckDuckGo results to the prompt.
        temperature: Sampling temperature forwarded to model.generate.
        max_tokens: Cap on newly generated tokens.

    Yields:
        Progressively longer, UI-formatted versions of the assistant reply.
    """
    # 1. Handle Web Search
    search_context = ""
    if search_enabled:
        search_context = web_search(message)
    
    # 2. Build properly formatted prompt (Fixes AI talking to itself)
    # We use the official ChatML template
    conversation = []
    for user_msg, assistant_msg in history:
        conversation.append({"role": "user", "content": user_msg})
        if assistant_msg:
            # Remove UI formatting before feeding back to model
            # (the <details> thought box is presentation-only, not model text)
            clean_assistant = re.sub(r'<details.*?</details>', '', assistant_msg, flags=re.DOTALL).strip()
            conversation.append({"role": "assistant", "content": clean_assistant})
    
    user_content = message + search_context
    conversation.append({"role": "user", "content": user_content})
    
    input_ids = tokenizer.apply_chat_template(
        conversation, 
        add_generation_prompt=True, 
        return_tensors="pt"
    ).to(model.device)

    # 3. Streamer with stop criteria
    # skip_prompt=True so only newly generated text reaches the UI.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        # Stop generating once the model tries to start a new 'User' turn
        eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|im_end|>")]
    )

    # generate() blocks, so run it on a worker thread while this generator
    # drains the streamer and yields partial replies to Gradio.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        # Crucial Fix: If the model generates "User:" or "<|im_start|>", stop displaying
        # NOTE(review): each chunk is checked in isolation — a marker split
        # across two chunks would slip through; confirm this is acceptable.
        if "User:" in new_text or "<|im_start|>" in new_text:
            break
            
        buffer += new_text
        yield parse_output(buffer)

# --- GRADIO UI ---
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1>🧠 Qwen3 Reasoning Lab</h1>")
    
    with gr.Row():
        with gr.Column(scale=4):
            chat_box = gr.Chatbot(height=600, label="Qwen3-0.6B")
            msg_input = gr.Textbox(placeholder="Ask a logic question...", show_label=False)
        
        with gr.Column(scale=1):
            search_toggle = gr.Checkbox(label="🌐 Web Search (DDG)", value=False)
            temp_slider = gr.Slider(0.1, 1.0, 0.7, label="Temperature")
            token_slider = gr.Slider(512, 4096, 1024, label="Max Tokens")
            gr.Markdown("""
            ### Tips:
            - **Thinking:** This model is trained for Chain-of-Thought.
            - **Self-Talk Fix:** We use stop sequences to prevent the AI from acting as 'User'.
            """)
            clear_btn = gr.Button("πŸ—‘ Clear Chat")

    # Set up logic
    # Step 1 (queue=False, instant): echo the user message into the chatbot
    # as a [message, None] pair. Step 2: stream the assistant reply via chat().
    chat_event = msg_input.submit(
        lambda x, y: (x, y + [[x, None]]), 
        [msg_input, chat_box], 
        [msg_input, chat_box], 
        queue=False
    ).then(
        chat, 
        [msg_input, chat_box, search_toggle, temp_slider, token_slider], 
        chat_box
    )
    
    # Reset the chatbot component to empty.
    clear_btn.click(lambda: None, None, chat_box, queue=False)

if __name__ == "__main__":
    # Bind on all interfaces (container-friendly) on the standard Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)