import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time

# Load model and tokenizer once at startup
model_name = "agokrani/phi-mini-instruct-distilled-qwq-bespoke-qwq"
tokenizer = AutoTokenizer.from_pretrained(model_name)
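# device_map="auto" lets accelerate place the weights on the available GPU/CPU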
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

def format_messages(messages):
    """
    Combine a list of messages into a single prompt.
    """
    prompt = "\n".join(messages) + "\n"
    return prompt

def flatten_history(history):
    """
    If history is a list of pairs (or lists) like [(user, assistant), ...],
    flatten it into a list of strings.
    """
    flat = []
    for item in history:
        if isinstance(item, (list, tuple)):
            for sub in item:
                if sub:
                    flat.append(sub)
        elif isinstance(item, str):
            flat.append(item)
    return flat

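# On Hugging Face Spaces, @spaces.GPU requests a ZeroGPU allocation for up to
# `duration` seconds per call; outside ZeroGPU Spaces it is effectively a no-op.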
@spaces.GPU(duration=120)
def respond(message, history, max_tokens, temperature, top_p):
    # Check if history is structured as pairs (user, assistant)
    if history and isinstance(history[0], (list, tuple)):
        flat_history = flatten_history(history)
    else:
        flat_history = history.copy() if history else []
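        # This branch assumes the entries are already plain strings;
        # Gradio's dict-style "messages" history is not handled here.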
    
    # Append the new user message
    flat_history.append(message)
    
    # Build the prompt from the flattened conversation history
    prompt = format_messages(flat_history)
    
    # Tokenize the prompt and record its length
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    prompt_length = input_ids.shape[1]
    
    # Generate the assistant response tokens
    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    
    # Decode only the newly generated portion (excluding the prompt)
    output_text = tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)
    
    # Stream the output one character at a time
    partial = ""
    for char in output_text:
        partial += char
        time.sleep(0.01)  # Delay to simulate streaming
        yield partial
    
    # Note: gr.ChatInterface manages the displayed conversation itself, and a
    # return value from a generator function is ignored by Gradio, so no
    # explicit history update is needed here.

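# ChatInterface wires the respond generator into a chat UI; the sliders below
# are exposed as additional inputs under the chat box.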
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=4096, value=2000, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
    ],
)

if __name__ == "__main__":
    demo.launch()