import threading
import torch
import gradio as gr
from transformers import AutoTokenizer
from transformers import GenerationConfig
from transformers import AutoModelForCausalLM
from transformers import TextIteratorStreamer
# from transformers import BitsAndBytesConfig
# BEWARE: this app will only work with 'chat' models (that have a
# `.chat_template` in their `tokenizer` – you can check that
# Qwen3-06B has one: https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/tokenizer_config.json)
# Also, note that there is a mechanism to detect 'thinking' tokens and
# displaying them differently, but if the chosen model outputs them in
# a different format than <think></think>, then that won't work, and
# you need to study the model output and change the checks accordingly!
# MODEL_ID = "google/gemma-3-270m-it"
MODEL_ID = "Qwen/Qwen3-0.6B"
# The overall 'directive' for our bot, see below
SYSTEM = "You are a helpful, concise assistant."
# Pick the compute device once at import time: CUDA when available,
# otherwise plain CPU (MPS intentionally disabled, see note below).
device = (
    "cuda"
    if torch.cuda.is_available()
    # note: models using bfloat16 aren't compatible with MPS
    # else "mps"
    # if torch.backends.mps.is_available()
    else "cpu"
)
# Theoretically, you can reduce the memory footprint and increase the speed of
# your model by loading it quantized, but that means making sure bitsandbytes
# is installed (with pip only), and my tests haven't led to conclusive results
# quantization_config = BitsAndBytesConfig(load_in_8bit=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # quantization_config=quantization_config
).to(device)
# Context window from model config (fallback if missing)
context_window = getattr(model.config, "max_position_embeddings", None)
if context_window is None:
    # NOTE(review): `tokenizer.model_max_length` can be a huge sentinel
    # (~1e30) when the tokenizer declares no limit — confirm this fallback
    # yields a sane budget for the chosen model before relying on it.
    context_window = getattr(tokenizer, "model_max_length", 2048)
print(f"model: {MODEL_ID}, context window: {context_window}.")
def predict(message, history):
    """
    Gradio ChatInterface streaming callback (a generator).

    Parameters
    ----------
    message : str
        The latest user message.
    history : list[dict]
        Prior turns as ``{"role": ..., "content": ...}`` dicts
        (ChatInterface ``type="messages"`` format).

    Yields
    ------
    str
        The assistant reply accumulated so far; any streamed
        ``<think>...</think>`` block is wrapped in a styled ``<p>``
        element so Gradio renders it dimmed/italic.

    Raises
    ------
    Exception
        Re-raises whatever ``model.generate`` raised in its worker
        thread, instead of silently truncating the reply.
    """
    # Build the conversation without mutating Gradio's history list in-place.
    conversation = history + [{"role": "user", "content": message}]
    if SYSTEM:
        # Optionally prepend a system prompt; this also helps some Qwen
        # templates.
        conversation = [{"role": "system", "content": SYSTEM}, *conversation]

    # Render with the model's chat template, adding a generation prompt so
    # the model knows it should now produce the assistant's reply.
    input_text = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    # The template already inserted the special tokens — don't add them twice.
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(device)

    # Let generation fill whatever context remains (always at least 1 token).
    input_len = inputs["input_ids"].shape[1]
    max_new_tokens = max(1, context_window - input_len)

    # Streamer lets us yield partial generations token-by-token while the
    # model runs in a background thread.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_config = GenerationConfig.from_pretrained(MODEL_ID)
    generation_config.max_new_tokens = max_new_tokens
    # suppressing a pesky warning (https://stackoverflow.com/a/71397707)
    model.generation_config.pad_token_id = tokenizer.eos_token_id

    # Exceptions raised in a plain Thread are lost; stash them here so the
    # caller sees a real error instead of a silently truncated reply.
    generation_errors = []

    def _run_generation():
        # Runs in a worker thread so this generator can consume `streamer`.
        try:
            model.generate(
                **inputs,
                generation_config=generation_config,
                streamer=streamer,
            )
        except Exception as exc:  # noqa: BLE001 — re-raised below
            generation_errors.append(exc)

    # daemon=True: a hung generation must not keep the interpreter alive
    # when the app shuts down.
    thread = threading.Thread(target=_run_generation, daemon=True)
    thread.start()

    # Streamed parsing of the `<think>...</think>` block: once `<think>`
    # appears, everything until `</think>` is rendered as "reasoning".
    # NOTE: this relies on the tags arriving as standalone chunks from the
    # streamer, which holds for Qwen3's tokenizer.
    generated = ""
    in_think = False
    for new_text in streamer:
        if not new_text:
            continue
        stripped_chunk = new_text.strip()
        if stripped_chunk == "<think>":
            # Open a dedicated, dimmed paragraph for the reasoning text.
            generated += "<p style='color:#777; font-size: 12px; font-style:italic;'>"
            in_think = True
            continue
        if stripped_chunk == "</think>":
            generated += "</p>"
            in_think = False
            continue
        generated += new_text
        if in_think:
            # Temporarily close the <p> so every partial yield is
            # well-formed HTML for Gradio.
            yield generated + "</p>"
        else:
            yield generated

    # Ensure the generation thread is finished before returning, then
    # surface any error it hit.
    thread.join()
    if generation_errors:
        raise generation_errors[0]
# Wire the streaming callback into a chat UI; `api_name="chat"` exposes
# the endpoint under /chat for programmatic clients.
demo = gr.ChatInterface(
    predict,
    api_name="chat",
)
# Start the Gradio server (blocks until the app is shut down).
demo.launch()