File size: 3,936 Bytes
5711651
7e876de
65a509c
 
5711651
506582f
 
7e876de
506582f
3071f4d
 
5711651
 
 
 
1441501
5711651
 
7e876de
5711651
65a509c
506582f
 
27a0eb7
 
5711651
506582f
 
7b81a62
 
506582f
7b81a62
 
 
506582f
7b81a62
506582f
 
1441501
506582f
1441501
 
506582f
 
7b81a62
 
 
 
 
 
 
 
 
65a509c
7b81a62
1441501
 
 
 
 
7b81a62
7e876de
5711651
1441501
 
 
506582f
 
7b81a62
506582f
 
1441501
 
 
 
 
65a509c
1441501
7e876de
 
1441501
 
 
7e876de
 
1441501
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import gradio as gr
import torch
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from typing import Iterator

# Maximum prompt length in tokens; override with the MAX_INPUT_TOKEN_LENGTH
# environment variable. Longer prompts are trimmed from the left before generation.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "128000"))
# Upper bound and initial value of the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 2048

# Markdown header rendered above the chat UI.
DESCRIPTION = """\
# DeepSeek-R1-Chat

This space demonstrates the [DeepSeek-R1](https://huggingface.co/deepseek-ai/deepseek-r1) model by DeepSeek, a reasoning model with 671B total parameters (37B activated per token).

**You can also try our R1 model on the [official homepage](https://r1.deepseek.com/chat).**
"""

model_id = "deepseek-ai/deepseek-r1"

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# On GPU, device_map="auto" lets transformers decide weight placement across
# the visible devices; on CPU the model is loaded without a device map.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto" if device == "cuda" else None,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Turn off the tokenizer's built-in default system prompt (for tokenizers that
# honor this flag) so only the user-supplied system prompt is used.
tokenizer.use_default_system_prompt = False

def _search_messages(search_query: str) -> list[dict[str, str]]:
    """Return DuckDuckGo Instant Answer context as zero or one system messages.

    The result is an empty list when the query is blank or the API returns no
    abstract text. On success it is a single "Search results ..." message. If
    the request or JSON decoding fails, it is a single "Search error: ..."
    message. Search is best-effort and must not abort the chat turn.
    """
    query = (search_query or "").strip()
    if not query:
        return []
    try:
        response = requests.get(
            "https://api.duckduckgo.com/",
            # Let requests URL-encode the user's text; interpolating it into the
            # URL directly breaks on '&', '#', '+', spaces, etc.
            params={"q": query, "format": "json"},
            timeout=5,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a response body that is not valid JSON.
        return [{"role": "system", "content": f"Search error: {e}"}]
    result = data.get("AbstractText", "") if isinstance(data, dict) else ""
    if not result:
        return []
    return [{"role": "system", "content": f"Search results for '{query}': {result}"}]


def _build_conversation(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    search_query: str,
) -> list[dict[str, str]]:
    """Assemble the chat-template message list for one turn.

    Message order is: the optional system prompt, optional search context, the
    prior (user, assistant) turns, and finally the new user ``message``.
    """
    conversation = [{"role": "system", "content": system_prompt}] if system_prompt else []
    conversation.extend(_search_messages(search_query))
    for user, assistant in chat_history or []:
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})
    return conversation


def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 2048,
    temperature: float = 0,
    top_p: float = 0,
    top_k: int = 50,
    repetition_penalty: float = 2,
    search_query: str = "",
) -> Iterator[str]:
    """Stream the model's reply to ``message``.

    Gradio calls this with the UI's ``additional_inputs`` passed positionally
    after ``message`` and ``chat_history``.

    Args:
        message: The new user message.
        chat_history: Prior (user, assistant) turn pairs.
        system_prompt: Optional system prompt; skipped when empty.
        max_new_tokens: Cap on generated tokens. Coerced to an int >= 1, since
            UI sliders can deliver floats.
        temperature: 0 (the default) selects greedy decoding. A value > 0
            enables sampling with this temperature, ``top_p`` and ``top_k``.
        top_p: Nucleus-sampling mass; used only when sampling.
        top_k: Top-k cutoff; used only when sampling. Coerced to int.
        repetition_penalty: Passed through to ``model.generate``.
        search_query: Optional DuckDuckGo query. Its abstract text is added as
            extra system context.

    Yields:
        The full reply decoded so far, after each streamed chunk.
    """
    conversation = _build_conversation(message, chat_history, system_prompt, search_query)
    # add_generation_prompt appends the assistant-turn prefix, so the model
    # answers rather than continuing the user's message.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens; the oldest context is dropped.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    # With device_map="auto" the weights may be spread over devices; model.device
    # reports where the model's leading parameters live.
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    do_sample = (temperature or 0) > 0
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max(1, int(max_new_tokens)),
        "do_sample": do_sample,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
        # Take EOS from the loaded tokenizer instead of hard-coding an id that
        # belongs to a different vocabulary.
        "eos_token_id": tokenizer.eos_token_id,
    }
    if do_sample:
        # Sampling knobs only take effect (and are only validated) when sampling.
        generate_kwargs.update(
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )

    # model.generate blocks, so it runs in a worker thread while the streamer
    # hands decoded text back to this generator.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    reply = ""
    for text in streamer:
        reply += text
        yield reply.replace("<|EOT|>", "")
    worker.join()

# Gradio passes ``additional_inputs`` to ``fn`` positionally after
# (message, chat_history), so this list MUST follow generate()'s parameter
# order: system_prompt, max_new_tokens, temperature, top_p, top_k,
# repetition_penalty, search_query.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        # Integer-valued parameters use step=1 so the UI never sends fractions.
        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
        # Fills generate()'s `temperature` slot; its default 0 matches generate()'s default.
        gr.Slider(label="Temperature", minimum=0, maximum=2.0, step=0.05, value=0),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0, maximum=1.0, step=0.01, value=0),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.01, value=2),
        gr.Textbox(label="Search Query (Optional)", placeholder="Enter search query to fetch online info", lines=1),
    ],
    stop_btn=gr.Button("Stop"),
    examples=[
        ["implement snake game using pygame"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["write a program to find the factorial of a number"],
    ],
)

# Page layout: the markdown header, followed by the pre-built chat_interface.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    # Enable Gradio's request queue (at most 20 pending events), then serve.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()