File size: 1,820 Bytes
0117c54
8187f33
0117c54
8187f33
 
0117c54
8187f33
 
0117c54
8187f33
0117c54
 
 
 
 
 
 
 
8187f33
0117c54
8187f33
 
 
 
 
 
0117c54
 
 
8187f33
 
 
0117c54
8187f33
 
 
 
0117c54
 
8187f33
 
0117c54
8187f33
0117c54
8187f33
0117c54
8187f33
0117c54
 
 
8187f33
 
 
 
0117c54
 
 
8187f33
0117c54
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Identifier of the checkpoint handed to both from_pretrained() calls below.
MODEL_NAME = "deepseek-ai/DeepSeek-R1"

# Tokenizer first, then the causal-LM weights, both resolved from MODEL_NAME.
# trust_remote_code=True lets the loaders execute custom code shipped with
# the checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)

# Function to handle chat interactions
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Generate the assistant's reply for one chat turn.

    Args:
        message: The latest user message.
        history: Earlier turns as ``(user, assistant)`` pairs; empty
            entries within a pair are skipped.
        system_message: Text placed in the leading system message.
        max_tokens: Upper bound on the number of *newly generated* tokens.
            Cast to ``int`` because UI sliders may deliver floats.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        not echoed back).
    """
    # Construct messages format
    messages = [{"role": "system", "content": system_message}]

    for user_input, bot_response in history:
        if user_input:
            messages.append({"role": "user", "content": user_input})
        if bot_response:
            messages.append({"role": "assistant", "content": bot_response})

    messages.append({"role": "user", "content": message})

    # Render the conversation with the tokenizer's chat template so role
    # boundaries survive tokenization; joining raw contents with newlines
    # would discard who said what. add_generation_prompt=True appends the
    # marker that cues the model to start the assistant turn.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    # Generate response. max_new_tokens limits only the reply; max_length
    # would also count the prompt and can leave no room to answer once the
    # history grows.
    output = model.generate(
        **inputs,
        max_new_tokens=int(max_tokens),
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # The returned sequence starts with the prompt tokens; drop them so the
    # chat window shows just the new reply.
    generated_tokens = output[0][prompt_length:]
    response_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    return response_text

# Chat UI: `respond` handles each turn. The extra widgets below are passed to
# it, in this order, after the message and the history.
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(
            label="System Message",
            value="You are a helpful AI assistant.",
        ),
        gr.Slider(
            label="Max Tokens",
            minimum=1,
            maximum=2048,
            step=1,
            value=512,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=0.7,
        ),
        gr.Slider(
            label="Top-p",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
    ],
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()