import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
# Load model and tokenizer once at startup
model_name = "agokrani/phi-mini-instruct-distilled-qwq-bespoke-qwq"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
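# Note: device_map="auto" needs the `accelerate` package installed; it places
# the model on a GPU when one is available and falls back to CPU otherwise.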
def format_messages(messages):
    """
    Combine a list of messages into a single newline-separated prompt.
    """
    prompt = "\n".join(messages) + "\n"
    return prompt
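# Example: format_messages(["Hello", "Hi there!"]) -> "Hello\nHi there!\n"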
def flatten_history(history):
    """
    If history is a list of pairs (or lists) like [(user, assistant), ...],
    flatten it into a flat list of strings, skipping empty entries.
    """
    flat = []
    for item in history:
        if isinstance(item, (list, tuple)):
            for sub in item:
                if sub:
                    flat.append(sub)
        elif isinstance(item, str):
            flat.append(item)
    return flat
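# Example: flatten_history([("Hi", "Hello!"), ("How are you?", None)])
# -> ["Hi", "Hello!", "How are you?"]  (the empty assistant slot is dropped)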
@spaces.GPU(duration=120)
def respond(message, history, max_tokens, temperature, top_p):
    # Check whether history is structured as (user, assistant) pairs
    if history and isinstance(history[0], (list, tuple)):
        flat_history = flatten_history(history)
    else:
        flat_history = history.copy() if history else []
    # Append the new user message
    flat_history.append(message)
    # Build the prompt from the flattened conversation history
    prompt = format_messages(flat_history)
    # Tokenize the prompt and record its length
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    prompt_length = input_ids.shape[1]
    # Generate the assistant response tokens
    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated portion (excluding the prompt)
    output_text = tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)
    # Stream the output one character at a time
    partial = ""
    for char in output_text:
        partial += char
        time.sleep(0.01)  # Delay to simulate streaming
        yield partial
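    # Note: the reply is generated in full and then replayed character by
    # character to simulate streaming. For genuine token-level streaming,
    # transformers' TextIteratorStreamer (with generate() running in a
    # background thread) would be a possible alternative.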
    # No explicit history update is needed here: gr.ChatInterface appends the
    # yielded text to the chat history itself, and a generator's return value
    # is ignored by Gradio, so returning an updated history would be dead code.
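# Optional sketch (not used above): many chat-tuned checkpoints expect their
# tokenizer's chat template rather than plain newline-joined text. Assuming
# this model ships a chat template, the prompt could instead be built as:
#
#     def format_with_chat_template(flat_history):
#         roles = ["user", "assistant"]  # assumed alternation, user first
#         msgs = [{"role": roles[i % 2], "content": m}
#                 for i, m in enumerate(flat_history)]
#         return tokenizer.apply_chat_template(
#             msgs, tokenize=False, add_generation_prompt=True)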
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=4096, value=2000, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
if __name__ == "__main__":
    demo.launch()