import time

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer once at startup; device_map="auto" places the
# weights on the GPU when one is available.
model_name = "agokrani/phi-mini-instruct-distilled-qwq-bespoke-qwq"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
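
# A hedged alternative (an assumption, not required by the app): on a GPU
# Space the weights can be loaded in half precision to roughly halve memory:
# model = AutoModelForCausalLM.from_pretrained(
#     model_name, device_map="auto", torch_dtype=torch.bfloat16
# )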


def format_messages(messages):
    """Combine a list of message strings into a single newline-separated prompt."""
    return "\n".join(messages) + "\n"
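

# A minimal alternative sketch (hypothetical helper, not used below): for an
# instruct-tuned checkpoint, the tokenizer's chat template usually frames the
# turns better than a plain newline join. Roles are assumed to alternate
# user/assistant, starting with the user.
def format_messages_with_template(messages):
    roles = ["user", "assistant"]
    chat = [{"role": roles[i % 2], "content": m} for i, m in enumerate(messages)]
    return tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )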


def flatten_history(history):
    """
    Flatten a chat history into a list of strings.

    Handles histories given as [(user, assistant), ...] pairs or as plain
    strings; message dicts like {"role": ..., "content": ...} (the format
    newer Gradio versions pass) are also unpacked.
    """
    flat = []
    for item in history:
        if isinstance(item, (list, tuple)):
            for sub in item:
                if sub:
                    flat.append(sub)
        elif isinstance(item, dict):
            if item.get("content"):
                flat.append(item["content"])
        elif isinstance(item, str):
            flat.append(item)
    return flat
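
# Example:
#     flatten_history([("hi", "hello"), ("how are you?", None)])
#     -> ["hi", "hello", "how are you?"]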


@spaces.GPU(duration=120)
def respond(message, history, max_tokens, temperature, top_p):
    # Normalize whatever history format Gradio passes (pairs, dicts, or
    # plain strings) into a flat list of strings.
    if history and isinstance(history[0], (list, tuple, dict)):
        flat_history = flatten_history(history)
    else:
        flat_history = list(history) if history else []

    # Append the new user message and build the prompt.
    flat_history.append(message)
    prompt = format_messages(flat_history)

    # Tokenize, keeping the prompt length so the prompt echo can be sliced
    # off the generated ids later.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    prompt_length = input_ids.shape[1]

    # Generate the full reply in one call, with gradients disabled for
    # inference.
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    output_text = tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)

    # Simulate streaming by yielding the reply one character at a time.
    partial = ""
    for char in output_text:
        partial += char
        time.sleep(0.01)
        yield partial
    # No explicit return is needed: ChatInterface appends (message, reply)
    # to the visible history itself, and return values from a generator
    # callback are ignored.
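

# A sketch of true token-level streaming (an assumption, not wired into the
# app above): transformers' TextIteratorStreamer yields text chunks as
# generate() produces them while generation runs in a background thread. The
# helper name respond_streaming is hypothetical.
def respond_streaming(input_ids, max_tokens, temperature, top_p):
    from threading import Thread

    from transformers import TextIteratorStreamer

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        streamer=streamer,
    )
    # Run generation in a worker thread so the streamer can be consumed here.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial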


demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=4096, value=2000, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
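
# Streaming (generator) callbacks go through Gradio's queue; recent Gradio
# versions enable it automatically on launch(), but it can also be
# configured explicitly, e.g.:
# demo.queue(max_size=32)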

if __name__ == "__main__":
    demo.launch()