| | import threading |
| |
|
| | import torch |
| | import gradio as gr |
| |
|
| | from transformers import AutoTokenizer |
| | from transformers import GenerationConfig |
| | from transformers import AutoModelForCausalLM |
| | from transformers import TextIteratorStreamer |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
# Hugging Face repo of the checkpoint to serve (a small Qwen3 model).
MODEL_ID = "Qwen/Qwen3-0.6B"

# System prompt prepended to every conversation before templating.
SYSTEM = "You are a helpful, concise assistant."

# Run on the GPU when torch can see one, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
| |
|
| | |
| | |
| | |
| | |
| |
|
# Load tokenizer and weights once at startup; move the model onto the
# selected device.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device)

# Determine the usable context length: prefer the model config's
# max_position_embeddings; when absent, fall back to the tokenizer's
# declared limit, defaulting to 2048 if neither is available.
context_window = getattr(model.config, "max_position_embeddings", None)
if context_window is None:
    context_window = getattr(tokenizer, "model_max_length", 2048)

print(f"model: {MODEL_ID}, context window: {context_window}.")
| |
|
| |
|
def predict(message, history):
    """
    Gradio ChatInterface callback: stream a Qwen chat completion.

    Parameters
    ----------
    message : str
        The newest user message.
    history : list[dict]
        Prior turns as ``{"role": ..., "content": ...}`` dicts
        (requires ``gr.ChatInterface(..., type="messages")``).

    Yields
    ------
    str
        The accumulated assistant reply so far. Qwen "<think>" spans are
        wrapped in a muted, italic ``<p>`` so the UI shows the reasoning
        trace in a distinct style instead of raw tags.
    """
    # Full conversation: prior turns + new user message, with the
    # system prompt (when configured) at the front.
    conversation = history + [{"role": "user", "content": message}]
    if SYSTEM:
        conversation = [{"role": "system", "content": SYSTEM}, *conversation]

    # Render through the model's chat template; the template already
    # inserts special tokens, so don't add them again when tokenizing.
    input_text = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(device)

    # Let generation use whatever room the context window leaves
    # (at least one token so generate() never gets a non-positive cap).
    input_len = inputs["input_ids"].shape[1]
    max_new_tokens = max(1, context_window - input_len)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_config = GenerationConfig.from_pretrained(MODEL_ID)
    generation_config.max_new_tokens = max_new_tokens
    # FIX: set pad_token_id on the per-call config instead of mutating
    # the shared model.generation_config. Gradio can service several
    # requests concurrently, and writing to the model's global config
    # from inside a callback mutates state shared across all requests.
    generation_config.pad_token_id = tokenizer.eos_token_id

    # generate() blocks until completion, so run it on a worker thread
    # and consume the streamer incrementally on this one.
    def _run_generation():
        model.generate(
            **inputs,
            generation_config=generation_config,
            streamer=streamer,
        )

    thread = threading.Thread(target=_run_generation)
    thread.start()

    generated = ""
    in_think = False  # currently inside a <think>...</think> span?

    for new_text in streamer:
        if not new_text:
            continue

        # Qwen emits its reasoning between <think> tags; translate the
        # markers into a styled paragraph rather than showing raw tags.
        # NOTE(review): assumes each marker arrives as its own streamed
        # chunk — confirm against the tokenizer's streaming behavior.
        chunk = new_text.strip()
        if chunk == "<think>":
            generated += "<p style='color:#777; font-size: 12px; font-style:italic;'>"
            in_think = True
            continue
        if chunk == "</think>":
            generated += "</p>"
            in_think = False
            continue

        generated += new_text

        if in_think:
            # Close the paragraph only in the yielded snapshot so the
            # partial HTML stays well-formed while thinking streams.
            yield generated + "</p>"
        else:
            yield generated

    # Ensure the generation thread has fully finished before returning.
    thread.join()
| |
|
| |
|
# FIX: pass type="messages" so Gradio delivers `history` as a list of
# {"role", "content"} dicts — the format predict() documents and builds
# on. Without it, Gradio versions defaulting to tuple-style history hand
# predict() (user, bot) pairs, and the chat-template construction breaks.
demo = gr.ChatInterface(
    predict,
    type="messages",
    api_name="chat",
)

demo.launch()
| |
|