import time

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer once at startup.
model_name = "agokrani/phi-mini-instruct-distilled-qwq-bespoke-qwq"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
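
# Note: instruct-tuned models usually expect their own chat template rather
# than the plain newline-joined prompt used below. A sketch of that
# alternative, assuming this tokenizer ships a chat template (not verified
# here):
#
#   chat = [{"role": "user", "content": "Hi"}]
#   ids = tokenizer.apply_chat_template(
#       chat, add_generation_prompt=True, return_tensors="pt"
#   )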


def format_messages(messages):
    """
    Combine a list of message strings into a single newline-separated prompt.
    """
    prompt = "\n".join(messages) + "\n"
    return prompt
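
# For illustration: format_messages(["Hi", "Hello! How can I help?"]) returns
# "Hi\nHello! How can I help?\n".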


def flatten_history(history):
    """
    If history is a list of pairs (or lists) like [(user, assistant), ...],
    flatten it into a list of strings, skipping empty entries.
    """
    flat = []
    for item in history:
        if isinstance(item, (list, tuple)):
            for sub in item:
                if sub:
                    flat.append(sub)
        elif isinstance(item, str):
            flat.append(item)
    return flat
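
# For example, flatten_history([("Hi", "Hello!"), ("Thanks", None)]) returns
# ["Hi", "Hello!", "Thanks"]; the empty assistant slot is skipped.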


@spaces.GPU(duration=120)
def respond(message, history, max_tokens, temperature, top_p):
    # Normalize the history: it may arrive as (user, assistant) pairs.
    if history and isinstance(history[0], (list, tuple)):
        flat_history = flatten_history(history)
    else:
        flat_history = history.copy() if history else []

    # Append the new user message and build the prompt from the
    # flattened conversation history.
    flat_history.append(message)
    prompt = format_messages(flat_history)

    # Tokenize the prompt and record its length so the prompt tokens can be
    # stripped from the generated sequence below.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    prompt_length = input_ids.shape[1]

    # Generate the assistant response tokens.
    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated portion (excluding the prompt).
    output_text = tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)

    # Stream the output one character at a time.
    partial = ""
    for char in output_text:
        partial += char
        time.sleep(0.01)  # Delay to simulate streaming.
        yield partial
    # gr.ChatInterface records the final yielded value in its own chat
    # history, so the generator does not need to return an updated history;
    # a return value from a generator function would be ignored here anyway.
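
# A sketch of real token streaming with transformers.TextIteratorStreamer, in
# place of the simulated character loop above (an alternative this Space does
# not use; exact kwargs depend on the installed transformers version):
#
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#
#   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
#                                   skip_special_tokens=True)
#   Thread(target=model.generate,
#          kwargs=dict(input_ids=input_ids, streamer=streamer,
#                      max_new_tokens=max_tokens)).start()
#   partial = ""
#   for new_text in streamer:
#       partial += new_text
#       yield partial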


demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=4096, value=2000, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

if __name__ == "__main__":
    demo.launch()
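
# To run locally (assuming gradio, spaces, torch, and transformers are
# installed), execute `python app.py` and open the printed local URL.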