File size: 5,204 Bytes
acf0628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import threading

import torch
import gradio as gr

from transformers import AutoTokenizer
from transformers import GenerationConfig
from transformers import AutoModelForCausalLM
from transformers import TextIteratorStreamer
# from transformers import BitsAndBytesConfig

# BEWARE: this app will only work with 'chat' models (that have a
#         `.chat_template` in their `tokenizer` – you can check that
#         Qwen3-06B has one: https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/tokenizer_config.json)
#         Also, note that there is a mechanism to detect 'thinking' tokens and
#         displaying them differently, but if the chosen model outputs them in
#         a different format than <think></think>, then that won't work, and
#         you need to study the model output and change the checks accordingly!
# MODEL_ID = "google/gemma-3-270m-it"
MODEL_ID = "Qwen/Qwen3-0.6B"

# The overall 'directive' for our bot, prepended as a system message below.
SYSTEM = "You are a helpful, concise assistant."

# Prefer a CUDA GPU when one is present, otherwise fall back to the CPU.
# MPS is deliberately skipped: bfloat16 checkpoints are not MPS-compatible.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE: quantized loading (BitsAndBytesConfig, load_in_8bit=True) can in
# theory shrink the memory footprint and speed things up, but it requires
# bitsandbytes (pip-only) and tests were inconclusive — left disabled.

# Fetch the tokenizer and the weights, then move the model onto the device.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device)

# Usable context window, taken from the model config; fall back to the
# tokenizer's declared limit (or 2048) when the config does not expose one.
if (context_window := getattr(model.config, "max_position_embeddings", None)) is None:
    context_window = getattr(tokenizer, "model_max_length", 2048)

print(f"model: {MODEL_ID}, context window: {context_window}.")


def predict(message, history):
    """
    Gradio ChatInterface streaming callback.

    Parameters
    ----------
    message : str
        The latest user message typed into the chat box.
    history : list[dict]
        Prior turns as ``{"role": ..., "content": ...}`` dicts
        (ChatInterface ``type="messages"`` format).

    Yields
    ------
    str
        Progressively longer partial responses; Gradio re-renders the
        assistant bubble with each yielded value. ``<think>...</think>``
        spans are wrapped in a dimmed, italic ``<p>`` for display.
    """

    # Build the conversation without mutating Gradio's history list in-place.
    conversation = history + [{"role": "user", "content": message}]

    # Optionally prepend a system prompt; this also helps some Qwen templates.
    if SYSTEM:
        conversation = [{"role": "system", "content": SYSTEM}, *conversation]

    # Render with the model's chat template; add_generation_prompt tells the
    # model it should now produce the assistant's reply.
    input_text = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

    # The template already injected the special tokens — don't add them again.
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(device)

    # Let generation fill whatever context remains (always at least 1 token).
    input_len = inputs["input_ids"].shape[1]
    max_new_tokens = max(1, context_window - input_len)

    # Streamer lets us iterate over decoded text chunks in this function
    # while model.generate() runs in a background thread.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_config = GenerationConfig.from_pretrained(MODEL_ID)
    generation_config.max_new_tokens = max_new_tokens
    # Suppress the "pad_token_id not set" warning (https://stackoverflow.com/a/71397707)
    model.generation_config.pad_token_id = tokenizer.eos_token_id

    def _run_generation():
        # Runs in a worker thread; decoded text arrives via `streamer`.
        model.generate(
            **inputs,
            generation_config=generation_config,
            streamer=streamer,
        )

    thread = threading.Thread(target=_run_generation)
    thread.start()

    # Streamed parsing of the `<think>...</think>` block: once `<think>` is
    # seen, everything up to `</think>` is rendered as dimmed "reasoning"
    # text. NOTE: this relies on the streamer emitting each tag as its own
    # chunk, which holds for Qwen3; other models may need different checks.
    think_open = "<p style='color:#777; font-size: 12px; font-style:italic;'>"
    generated = ""
    in_think = False

    for new_text in streamer:
        if not new_text:
            continue

        chunk = new_text.strip()
        if chunk == "<think>":
            generated += think_open
            in_think = True
            continue
        if chunk == "</think>":
            generated += "</p>"
            in_think = False
            continue

        generated += new_text

        if in_think:
            # Temporarily close the paragraph so every partial render that
            # Gradio shows is well-formed HTML.
            yield generated + "</p>"
        else:
            # The thinking is over, the tag is already closed.
            yield generated

    # Ensure the generation thread is finished before returning.
    thread.join()

    # BUGFIX: if generation was truncated (max_new_tokens exhausted) while
    # still inside the <think> block, `</think>` never arrived and the styled
    # paragraph was left open — close it so the final message is valid HTML.
    if in_think:
        generated += "</p>"
        yield generated


# Wire the streaming `predict` generator into a chat UI and expose it under
# the "chat" API endpoint name.
demo = gr.ChatInterface(fn=predict, api_name="chat")

demo.launch()