import os
import gradio as gr
import torch
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from typing import Iterator
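# Generation limits: prompts are truncated to MAX_INPUT_TOKEN_LENGTH tokens and each
# reply is capped at MAX_MAX_NEW_TOKENS newly generated tokens.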
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "128000"))
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 2048
DESCRIPTION = """\
# DeepSeek-R1-Chat
This Space demonstrates the [DeepSeek-R1](https://huggingface.co/deepseek-ai/deepseek-r1) model by DeepSeek, a 6XXB-parameter reasoning model tuned for chat instructions.
**You can also try the R1 model on the [official homepage](https://r1.deepseek.com/chat).**
"""
model_id = "deepseek-ai/deepseek-r1"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto" if device == "cuda" else None)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.use_default_system_prompt = False
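# generate() builds the chat as a list of role/content messages (optionally with search
# context), applies the tokenizer's chat template, and streams the model's reply back.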
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 2048,
    top_p: float = 0,
    top_k: int = 50,
    repetition_penalty: float = 2,
    search_query: str = "",
) -> Iterator[str]:
    conversation = [{"role": "system", "content": system_prompt}] if system_prompt else []
    # Optionally prepend DuckDuckGo Instant Answer results as extra system context.
    if search_query:
        try:
            r = requests.get(f"https://api.duckduckgo.com/?q={search_query}&format=json", timeout=5)
            data = r.json()
            result = data.get("AbstractText", "")
            if result:
                conversation.append({"role": "system", "content": f"Search results for '{search_query}': {result}"})
        except Exception as e:
            conversation.append({"role": "system", "content": f"Search error: {e}"})
    # Replay the prior turns, then add the new user message.
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(device)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "top_p": top_p,
        "top_k": top_k,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
        "eos_token_id": 32021,
    }
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs).replace("<|EOT|>", "")
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0, maximum=1.0, step=0.01, value=0),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.01, value=2),
        gr.Textbox(label="Search Query (Optional)", placeholder="Enter search query to fetch online info", lines=1),
    ],
    stop_btn=gr.Button("Stop"),
    examples=[
        ["implement snake game using pygame"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["write a program to find the factorial of a number"],
    ],
)
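# ChatInterface forwards the additional_inputs values to generate() positionally, in the
# same order as its extra parameters after (message, chat_history).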
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
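# When run locally with `python app.py`, Gradio serves the demo at http://127.0.0.1:7860
# by default; on Hugging Face Spaces, app.py is launched automatically.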