mrpoons-studio commited on
Commit
65a509c
·
verified ·
1 Parent(s): 27a0eb7

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -27
app.py CHANGED
@@ -2,7 +2,9 @@ import os
2
  from threading import Thread
3
  from typing import Iterator
4
  import gradio as gr
5
- import spaces, torch, requests
 
 
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
7
 
8
  MAX_MAX_NEW_TOKENS = 2048
@@ -16,11 +18,8 @@ This space demonstrates model [DeepSeek-R1](https://huggingface.co/deepseek-ai/d
16
 
17
  **You can also try our R1 model in [official homepage](https://r1.deepseek.com/chat).**
18
  """
19
- if not torch.cuda.is_available():
20
- DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
21
 
22
- if torch.cuda.is_available():
23
- model_id = "deepseek-ai/deepseek-r1"
24
  if torch.cuda.is_available():
25
  model = AutoModelForCausalLM.from_pretrained(
26
  model_id, torch_dtype=torch.bfloat16, device_map="auto"
@@ -31,37 +30,56 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
31
  tokenizer.use_default_system_prompt = False
32
 
33
  @spaces.GPU
34
- def generate(message: str,
35
- chat_history: list[tuple[str, str]],
36
- system_prompt: str,
37
- max_new_tokens: int = 2048,
38
- temperature: float = 0,
39
- top_p: float = 0,
40
- top_k: int = 50,
41
- repetition_penalty: float = 2,
42
- search_query: str = "") -> Iterator[str]:
 
 
43
  conversation = []
44
  if system_prompt:
45
  conversation.append({"role": "system", "content": system_prompt})
46
  if search_query:
47
  try:
48
- r = requests.get(f"https://api.duckduckgo.com/?q={search_query}&format=json", timeout=5)
 
 
49
  data = r.json()
50
  result = data.get("AbstractText", "")
51
  if result:
52
- conversation.append({"role": "system", "content": f"Search results for '{search_query}': {result}"})
 
 
 
 
 
53
  except Exception as e:
54
- conversation.append({"role": "system", "content": f"Search error: {e}"})
 
 
55
  for user, assistant in chat_history:
56
- conversation.extend([{"role": "user", "content": user},
57
- {"role": "assistant", "content": assistant}])
 
 
 
 
58
  conversation.append({"role": "user", "content": message})
59
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
60
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
61
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
62
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
 
63
  input_ids = input_ids.to(model.device)
64
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
 
65
  generate_kwargs = {
66
  "input_ids": input_ids,
67
  "streamer": streamer,
@@ -71,7 +89,7 @@ def generate(message: str,
71
  "top_k": top_k,
72
  "num_beams": 1,
73
  "repetition_penalty": repetition_penalty,
74
- "eos_token_id": 32021
75
  }
76
  t = Thread(target=model.generate, kwargs=generate_kwargs)
77
  t.start()
@@ -84,17 +102,35 @@ chat_interface = gr.ChatInterface(
84
  fn=generate,
85
  additional_inputs=[
86
  gr.Textbox(label="System prompt", lines=6),
87
- gr.Slider(label="Max new tokens", minimum=0, maximum=MAX_MAX_NEW_TOKENS, step=0.01, value=DEFAULT_MAX_NEW_TOKENS),
88
- gr.Slider(label="Top-p (nucleus sampling)", minimum=0, maximum=1.0, step=0.01, value=0),
 
 
 
 
 
 
 
 
 
 
 
 
89
  gr.Slider(label="Top-k", minimum=1, maximum=1000, step=0.01, value=50),
90
- gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.01, value=2),
91
- gr.Textbox(label="Search Query (Optional)", placeholder="Enter search query to fetch online info", lines=1)
 
 
 
 
 
 
92
  ],
93
  stop_btn=gr.Button("Stop"),
94
  examples=[
95
  ["implement snake game using pygame"],
96
  ["Can you explain briefly to me what is the Python programming language?"],
97
- ["write a program to find the factorial of a number"]
98
  ],
99
  )
100
 
 
2
  from threading import Thread
3
  from typing import Iterator
4
  import gradio as gr
5
+ import spaces
6
+ import torch
7
+ import requests
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
 
10
  MAX_MAX_NEW_TOKENS = 2048
 
18
 
19
  **You can also try our R1 model in [official homepage](https://r1.deepseek.com/chat).**
20
  """
 
 
21
 
22
+ model_id = "deepseek-ai/deepseek-r1"
 
23
  if torch.cuda.is_available():
24
  model = AutoModelForCausalLM.from_pretrained(
25
  model_id, torch_dtype=torch.bfloat16, device_map="auto"
 
30
  tokenizer.use_default_system_prompt = False
31
 
32
  @spaces.GPU
33
+ def generate(
34
+ message: str,
35
+ chat_history: list[tuple[str, str]],
36
+ system_prompt: str,
37
+ max_new_tokens: int = 2048,
38
+ temperature: float = 0,
39
+ top_p: float = 0,
40
+ top_k: int = 50,
41
+ repetition_penalty: float = 2,
42
+ search_query: str = "",
43
+ ) -> Iterator[str]:
44
  conversation = []
45
  if system_prompt:
46
  conversation.append({"role": "system", "content": system_prompt})
47
  if search_query:
48
  try:
49
+ r = requests.get(
50
+ f"https://api.duckduckgo.com/?q={search_query}&format=json", timeout=5
51
+ )
52
  data = r.json()
53
  result = data.get("AbstractText", "")
54
  if result:
55
+ conversation.append(
56
+ {
57
+ "role": "system",
58
+ "content": f"Search results for '{search_query}': {result}",
59
+ }
60
+ )
61
  except Exception as e:
62
+ conversation.append(
63
+ {"role": "system", "content": f"Search error: {e}"}
64
+ )
65
  for user, assistant in chat_history:
66
+ conversation.extend(
67
+ [
68
+ {"role": "user", "content": user},
69
+ {"role": "assistant", "content": assistant},
70
+ ]
71
+ )
72
  conversation.append({"role": "user", "content": message})
73
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
74
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
75
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
76
+ gr.Warning(
77
+ f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
78
+ )
79
  input_ids = input_ids.to(model.device)
80
+ streamer = TextIteratorStreamer(
81
+ tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
82
+ )
83
  generate_kwargs = {
84
  "input_ids": input_ids,
85
  "streamer": streamer,
 
89
  "top_k": top_k,
90
  "num_beams": 1,
91
  "repetition_penalty": repetition_penalty,
92
+ "eos_token_id": 32021,
93
  }
94
  t = Thread(target=model.generate, kwargs=generate_kwargs)
95
  t.start()
 
102
  fn=generate,
103
  additional_inputs=[
104
  gr.Textbox(label="System prompt", lines=6),
105
+ gr.Slider(
106
+ label="Max new tokens",
107
+ minimum=0,
108
+ maximum=MAX_MAX_NEW_TOKENS,
109
+ step=0.01,
110
+ value=DEFAULT_MAX_NEW_TOKENS,
111
+ ),
112
+ gr.Slider(
113
+ label="Top-p (nucleus sampling)",
114
+ minimum=0,
115
+ maximum=1.0,
116
+ step=0.01,
117
+ value=0,
118
+ ),
119
  gr.Slider(label="Top-k", minimum=1, maximum=1000, step=0.01, value=50),
120
+ gr.Slider(
121
+ label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.01, value=2
122
+ ),
123
+ gr.Textbox(
124
+ label="Search Query (Optional)",
125
+ placeholder="Enter search query to fetch online info",
126
+ lines=1,
127
+ ),
128
  ],
129
  stop_btn=gr.Button("Stop"),
130
  examples=[
131
  ["implement snake game using pygame"],
132
  ["Can you explain briefly to me what is the Python programming language?"],
133
+ ["write a program to find the factorial of a number"],
134
  ],
135
  )
136