import time

import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer once at startup.
model_name = "agokrani/phi-mini-instruct-distilled-qwq-bespoke-qwq"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")


def format_messages(messages):
    """Combine a list of message strings into a single newline-separated prompt."""
    return "\n".join(messages) + "\n"


def flatten_history(history):
    """Flatten a Gradio chat history into a list of message strings.

    Handles pair-style histories like [(user, assistant), ...], messages-style
    histories like [{"role": ..., "content": ...}, ...], and plain lists of
    strings.
    """
    flat = []
    for item in history:
        if isinstance(item, (list, tuple)):
            for sub in item:
                if sub:
                    flat.append(sub)
        elif isinstance(item, dict):
            if item.get("content"):
                flat.append(item["content"])
        elif isinstance(item, str):
            flat.append(item)
    return flat


@spaces.GPU(duration=120)
def respond(message, history, max_tokens, temperature, top_p):
    # Normalize the history to a flat list of strings, whatever structure
    # Gradio passed it in, then append the new user message.
    flat_history = flatten_history(history or [])
    flat_history.append(message)

    # Build the prompt from the flattened conversation history.
    prompt = format_messages(flat_history)

    # Tokenize the prompt and record its length so the prompt tokens can be
    # sliced off the generated sequence later.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    prompt_length = input_ids.shape[1]

    # Generate the assistant response tokens.
    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated portion (excluding the prompt).
    output_text = tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)

    # Stream the output one character at a time. ChatInterface treats each
    # yielded value as the assistant reply so far and updates the chat history
    # itself, so nothing is returned from this generator.
    partial = ""
    for char in output_text:
        partial += char
        time.sleep(0.01)  # small delay to simulate streaming
        yield partial


demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=4096, value=2000, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()
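
# --- Optional: chat-template prompt building (a sketch, not wired in above) ---
# format_messages() joins turns with bare newlines, which an instruct-tuned
# model may not have seen during training. If replies look off, building the
# prompt with the tokenizer's chat template is worth trying. The helper name
# and the "user"/"assistant" role alternation below are assumptions, not part
# of the original app.
#
# def build_prompt_with_template(flat_history):
#     roles = ["user", "assistant"]
#     messages = [
#         {"role": roles[i % 2], "content": text}
#         for i, text in enumerate(flat_history)
#     ]
#     return tokenizer.apply_chat_template(
#         messages, tokenize=False, add_generation_prompt=True
#     )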
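
# --- Optional: true token streaming (a sketch, not wired in above) ---
# respond() generates the full reply and then replays it character by
# character. For genuine streaming, transformers' TextIteratorStreamer can
# yield text chunks while generate() runs in a background thread; the helper
# name below is an assumption.
#
# from threading import Thread
# from transformers import TextIteratorStreamer
#
# def stream_generate(input_ids, max_tokens, temperature, top_p):
#     streamer = TextIteratorStreamer(
#         tokenizer, skip_prompt=True, skip_special_tokens=True
#     )
#     Thread(
#         target=model.generate,
#         kwargs=dict(
#             input_ids=input_ids,
#             max_new_tokens=max_tokens,
#             temperature=temperature,
#             top_p=top_p,
#             do_sample=True,
#             pad_token_id=tokenizer.eos_token_id,
#             streamer=streamer,
#         ),
#     ).start()
#     partial = ""
#     for chunk in streamer:
#         partial += chunk
#         yield partial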