jchwenger commited on
Commit
acf0628
·
1 Parent(s): 81df2d4

app | import from dmlcp

Browse files
Files changed (1) hide show
  1. app.py +159 -0
app.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+
3
+ import torch
4
+ import gradio as gr
5
+
6
+ from transformers import AutoTokenizer
7
+ from transformers import GenerationConfig
8
+ from transformers import AutoModelForCausalLM
9
+ from transformers import TextIteratorStreamer
10
+ # from transformers import BitsAndBytesConfig
11
+
12
# BEWARE: this app will only work with 'chat' models (that have a
# `.chat_template` in their `tokenizer` – you can check that
# Qwen3-06B has one: https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/tokenizer_config.json)
# Also, note that there is a mechanism to detect 'thinking' tokens and
# displaying them differently, but if the chosen model outputs them in
# a different format than <think></think>, then that won't work, and
# you need to study the model output and change the checks accordingly!
# MODEL_ID = "google/gemma-3-270m-it"
MODEL_ID = "Qwen/Qwen3-0.6B"

# The overall 'directive' for our bot, prepended as the system turn in predict()
SYSTEM = "You are a helpful, concise assistant."

device = (
    "cuda"
    if torch.cuda.is_available()
    # note: models using bfloat16 aren't compatible with MPS
    # else "mps"
    # if torch.backends.mps.is_available()
    else "cpu"
)

# Theoretically, you can reduce the memory footprint and increase the speed of
# your model by loading it quantized, but that means making sure bitsandbytes
# is installed (with pip only), and my tests haven't led to conclusive results
# quantization_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # quantization_config=quantization_config
).to(device)

# Context window from the model config, with a tokenizer-based fallback.
context_window = getattr(model.config, "max_position_embeddings", None)
if context_window is None:
    context_window = getattr(tokenizer, "model_max_length", 2048)
    # When a checkpoint doesn't define a real limit, HF tokenizers report
    # `model_max_length` as a huge sentinel (int(1e30)). Left unchecked,
    # that would make `max_new_tokens = context_window - input_len` in
    # predict() absurdly large, so clamp it to a conservative default.
    if not isinstance(context_window, int) or context_window > 1_000_000:
        context_window = 2048

print(f"model: {MODEL_ID}, context window: {context_window}.")
51
+
52
+
53
def predict(message, history):
    """
    Callback for `gr.ChatInterface`: stream the model's reply.

    - `history` arrives as a list of {"role": ..., "content": ...} dicts
      (type="messages"); the newest user turn is appended to a copy.
    - Generation runs on a worker thread while this generator yields
      progressively longer strings for Gradio to render.
    """

    # print(history)

    # Work on a copy rather than mutating Gradio's own history list
    conversation = history + [{"role": "user", "content": message}]

    # A system directive, when configured, goes first (this also helps
    # some Qwen templates).
    if SYSTEM:
        conversation = [{"role": "system", "content": SYSTEM}] + conversation

    # Render the turns through the model's chat template; the generation
    # prompt tells the model an assistant reply should come next.
    prompt = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

    model_inputs = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(device)

    # Let the reply use whatever room is left in the context window
    prompt_len = model_inputs["input_ids"].shape[1]
    reply_budget = max(1, context_window - prompt_len)

    # The streamer hands us decoded text incrementally while
    # `model.generate` runs on the worker thread below.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_config = GenerationConfig.from_pretrained(MODEL_ID)
    generation_config.max_new_tokens = reply_budget
    # suppressing a pesky warning (https://stackoverflow.com/a/71397707)
    model.generation_config.pad_token_id = tokenizer.eos_token_id

    # Generation happens on a separate thread so that this function can
    # consume the streamer and yield updates to Gradio as they arrive.
    worker = threading.Thread(
        target=model.generate,
        kwargs={
            **model_inputs,
            "generation_config": generation_config,
            "streamer": streamer,
        },
    )
    worker.start()

    # Streamed parsing of the `<think>...</think>` block: once a
    # `<think>` chunk appears, everything up to `</think>` is treated as
    # reasoning and wrapped in a styled paragraph.
    generated = ""
    in_think = False

    for chunk in streamer:
        if not chunk:
            continue

        stripped = chunk.strip()
        if stripped == "<think>":
            # open the styled paragraph holding the reasoning trace
            generated += "<p style='color:#777; font-size: 12px; font-style:italic;'>"
            in_think = True
            continue
        if stripped == "</think>":
            generated += "</p>"
            in_think = False
            continue

        generated += chunk

        # While inside the think block, temporarily close the paragraph
        # so the partial HTML shown by Gradio stays well-formed.
        yield (generated + "</p>") if in_think else generated

    # Ensure the generation thread is finished before returning.
    worker.join()
152
+
153
+
154
# Build the chat UI. `type="messages"` is passed explicitly because
# predict() consumes `history` as a list of {"role", "content"} dicts;
# without it, Gradio's legacy default (pairs of strings) would break the
# chat-template call inside the callback.
demo = gr.ChatInterface(
    predict,
    type="messages",
    api_name="chat",
)

demo.launch()