import threading

import torch
import gradio as gr

from transformers import AutoTokenizer
from transformers import GenerationConfig
from transformers import AutoModelForCausalLM
from transformers import TextIteratorStreamer

# from transformers import BitsAndBytesConfig

# BEWARE: this app will only work with 'chat' models (that have a
# `.chat_template` in their `tokenizer` – you can check that
# Qwen3-0.6B has one: https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/tokenizer_config.json)

# Also, note that there is a mechanism to detect 'thinking' tokens and
# display them differently, but if the chosen model outputs them in
# a different format than `<think>...</think>`, then that won't work, and
# you need to study the model output and change the checks accordingly!

# MODEL_ID = "google/gemma-3-270m-it"
MODEL_ID = "Qwen/Qwen3-0.6B"

# The overall 'directive' for our bot, see below
SYSTEM = "You are a helpful, concise assistant."

# Markers the model emits around its reasoning. They are matched against
# whole (stripped) stream chunks, so they must arrive as their own chunk.
THINK_OPEN_TAG = "<think>"
THINK_CLOSE_TAG = "</think>"

# HTML used to display the reasoning differently from the final answer.
# Purely presentational: adjust the markup/styling freely, as long as
# THINK_HTML_CLOSE closes whatever THINK_HTML_OPEN opens.
THINK_HTML_OPEN = '<div style="opacity: 0.6; font-style: italic;">\n\n'
THINK_HTML_CLOSE = "\n\n</div>\n\n"

device = (
    "cuda"
    if torch.cuda.is_available()
    # note: models using bfloat16 aren't compatible with MPS
    # else "mps"
    # if torch.backends.mps.is_available()
    else "cpu"
)

# Theoretically, you can reduce the memory footprint and increase the speed of
# your model by loading it quantized, but that means making sure bitsandbytes
# is installed (with pip only), and my tests haven't led to conclusive results
# quantization_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # quantization_config=quantization_config
).to(device)

# Context window from model config (fallback if missing)
context_window = getattr(model.config, "max_position_embeddings", None)
if context_window is None:
    context_window = getattr(tokenizer, "model_max_length", 2048)

print(f"model: {MODEL_ID}, context window: {context_window}.")


def _render_stream(chunks):
    """
    Turn an iterable of streamed text chunks into cumulative chat updates.

    Each yielded value is the full reply so far. A chunk that is exactly
    the opening/closing thinking tag (ignoring surrounding whitespace) is
    replaced by the HTML wrapper and toggles 'thinking' mode; it does not
    produce an update by itself. Empty chunks are skipped.
    """
    generated = ""
    in_think = False

    for new_text in chunks:
        if not new_text:
            continue

        stripped = new_text.strip()

        if stripped == THINK_OPEN_TAG:
            generated += THINK_HTML_OPEN
            in_think = True
            continue

        if stripped == THINK_CLOSE_TAG:
            generated += THINK_HTML_CLOSE
            in_think = False
            continue

        generated += new_text

        if in_think:
            # The wrapper is still open: close it in the yielded copy only,
            # so that every partial update is well-formed HTML.
            yield generated + THINK_HTML_CLOSE
        else:
            # Either no thinking block was seen, or it is already closed.
            yield generated


def predict(message, history):
    """
    Gradio ChatInterface callback, streaming the assistant's reply.

    - `message` is the latest user input (a string).
    - `history` is expected to be a list of dicts with `role` and
      `content` (Gradio's "messages" format).
    - Yields the reply accumulated so far, with the model's thinking
      block wrapped in dedicated HTML.
    """
    # print(history)

    # Build a new list so we don't mutate Gradio's history in-place
    conversation = history + [{"role": "user", "content": message}]

    # Optionally prepend a system prompt; this also helps some Qwen templates.
    if SYSTEM:
        conversation = [
            {
                "role": "system",
                "content": SYSTEM,
            },
            *conversation,
        ]

    # Use the model's chat template and add a generation prompt so the model
    # knows it should now produce the assistant's reply.
    input_text = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

    # The template already inserted the special tokens, don't add them twice.
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(device)

    # Set max_new_tokens to fill the remaining context (at least one token)
    input_len = inputs["input_ids"].shape[1]
    max_new_tokens = max(1, context_window - input_len)

    # Set up a text streamer so we can yield partial generations
    # token-by-token (or small chunks), while the model runs in a
    # background thread.
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_config = GenerationConfig.from_pretrained(MODEL_ID)
    generation_config.max_new_tokens = max_new_tokens

    # suppressing a pesky warning (https://stackoverflow.com/a/71397707)
    model.generation_config.pad_token_id = tokenizer.eos_token_id
    # `generate` is given the config built above, so the fallback has to be
    # set there as well; a pad token shipped with the model is left alone.
    if generation_config.pad_token_id is None:
        generation_config.pad_token_id = tokenizer.eos_token_id

    # Run generation in a separate thread so that we can iterate over
    # the streamer in this function and yield updates to Gradio.
    def _run_generation():
        try:
            model.generate(
                **inputs,
                generation_config=generation_config,
                streamer=streamer,
            )
        except Exception:
            # Without the end-of-stream signal, the consumer loop would
            # wait on the streamer's queue forever.
            streamer.end()
            raise

    thread = threading.Thread(target=_run_generation)
    thread.start()

    # Streamed parsing of the `<think>...</think>` block: everything
    # between the two tags is displayed as "reasoning".
    yield from _render_stream(streamer)

    # Ensure the generation thread is finished before returning.
    thread.join()


demo = gr.ChatInterface(
    predict,
    api_name="chat",
)

demo.launch()