mrpoons-studio committed on
Commit
506582f
·
verified ·
1 Parent(s): 65a509c

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -74
app.py CHANGED
@@ -1,15 +1,14 @@
1
  import os
2
- from threading import Thread
3
- from typing import Iterator
4
  import gradio as gr
5
- import spaces
6
  import torch
7
  import requests
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 
9
 
 
10
  MAX_MAX_NEW_TOKENS = 2048
11
  DEFAULT_MAX_NEW_TOKENS = 2048
12
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "128000"))
13
 
14
  DESCRIPTION = """\
15
  # DeepSeek-R1-Chat
@@ -20,66 +19,29 @@ This space demonstrates model [DeepSeek-R1](https://huggingface.co/deepseek-ai/d
20
  """
21
 
22
  model_id = "deepseek-ai/deepseek-r1"
23
- if torch.cuda.is_available():
24
- model = AutoModelForCausalLM.from_pretrained(
25
- model_id, torch_dtype=torch.bfloat16, device_map="auto"
26
- )
27
- else:
28
- model = AutoModelForCausalLM.from_pretrained(model_id)
29
  tokenizer = AutoTokenizer.from_pretrained(model_id)
30
  tokenizer.use_default_system_prompt = False
31
 
32
- @spaces.GPU
33
- def generate(
34
- message: str,
35
- chat_history: list[tuple[str, str]],
36
- system_prompt: str,
37
- max_new_tokens: int = 2048,
38
- temperature: float = 0,
39
- top_p: float = 0,
40
- top_k: int = 50,
41
- repetition_penalty: float = 2,
42
- search_query: str = "",
43
- ) -> Iterator[str]:
44
- conversation = []
45
- if system_prompt:
46
- conversation.append({"role": "system", "content": system_prompt})
47
  if search_query:
48
  try:
49
- r = requests.get(
50
- f"https://api.duckduckgo.com/?q={search_query}&format=json", timeout=5
51
- )
52
  data = r.json()
53
  result = data.get("AbstractText", "")
54
  if result:
55
- conversation.append(
56
- {
57
- "role": "system",
58
- "content": f"Search results for '{search_query}': {result}",
59
- }
60
- )
61
  except Exception as e:
62
- conversation.append(
63
- {"role": "system", "content": f"Search error: {e}"}
64
- )
65
- for user, assistant in chat_history:
66
- conversation.extend(
67
- [
68
- {"role": "user", "content": user},
69
- {"role": "assistant", "content": assistant},
70
- ]
71
- )
72
  conversation.append({"role": "user", "content": message})
73
- input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
74
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
75
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
76
- gr.Warning(
77
- f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
78
- )
79
- input_ids = input_ids.to(model.device)
80
- streamer = TextIteratorStreamer(
81
- tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
82
- )
83
  generate_kwargs = {
84
  "input_ids": input_ids,
85
  "streamer": streamer,
@@ -102,29 +64,11 @@ chat_interface = gr.ChatInterface(
102
  fn=generate,
103
  additional_inputs=[
104
  gr.Textbox(label="System prompt", lines=6),
105
- gr.Slider(
106
- label="Max new tokens",
107
- minimum=0,
108
- maximum=MAX_MAX_NEW_TOKENS,
109
- step=0.01,
110
- value=DEFAULT_MAX_NEW_TOKENS,
111
- ),
112
- gr.Slider(
113
- label="Top-p (nucleus sampling)",
114
- minimum=0,
115
- maximum=1.0,
116
- step=0.01,
117
- value=0,
118
- ),
119
  gr.Slider(label="Top-k", minimum=1, maximum=1000, step=0.01, value=50),
120
- gr.Slider(
121
- label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.01, value=2
122
- ),
123
- gr.Textbox(
124
- label="Search Query (Optional)",
125
- placeholder="Enter search query to fetch online info",
126
- lines=1,
127
- ),
128
  ],
129
  stop_btn=gr.Button("Stop"),
130
  examples=[
 
1
  import os
 
 
2
  import gradio as gr
 
3
  import torch
4
  import requests
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
+ from threading import Thread
7
+ from typing import Iterator
8
 
9
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "128000"))
10
  MAX_MAX_NEW_TOKENS = 2048
11
  DEFAULT_MAX_NEW_TOKENS = 2048
 
12
 
13
  DESCRIPTION = """\
14
  # DeepSeek-R1-Chat
 
19
  """
20
 
21
  model_id = "deepseek-ai/deepseek-r1"
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto" if device == "cuda" else None)
 
 
 
 
24
  tokenizer = AutoTokenizer.from_pretrained(model_id)
25
  tokenizer.use_default_system_prompt = False
26
 
27
+ def generate(message: str, chat_history: list[tuple[str, str]], system_prompt: str, max_new_tokens: int = 2048, temperature: float = 0, top_p: float = 0, top_k: int = 50, repetition_penalty: float = 2, search_query: str = "") -> Iterator[str]:
28
+ conversation = [{"role": "system", "content": system_prompt}] if system_prompt else []
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  if search_query:
30
  try:
31
+ r = requests.get(f"https://api.duckduckgo.com/?q={search_query}&format=json", timeout=5)
 
 
32
  data = r.json()
33
  result = data.get("AbstractText", "")
34
  if result:
35
+ conversation.append({"role": "system", "content": f"Search results for '{search_query}': {result}"})
 
 
 
 
 
36
  except Exception as e:
37
+ conversation.append({"role": "system", "content": f"Search error: {e}"})
38
+ conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant} for user, assistant in chat_history])
 
 
 
 
 
 
 
 
39
  conversation.append({"role": "user", "content": message})
40
+ input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(device)
41
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
42
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
43
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
44
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
 
 
 
 
45
  generate_kwargs = {
46
  "input_ids": input_ids,
47
  "streamer": streamer,
 
64
  fn=generate,
65
  additional_inputs=[
66
  gr.Textbox(label="System prompt", lines=6),
67
+ gr.Slider(label="Max new tokens", minimum=0, maximum=MAX_MAX_NEW_TOKENS, step=0.01, value=DEFAULT_MAX_NEW_TOKENS),
68
+ gr.Slider(label="Top-p (nucleus sampling)", minimum=0, maximum=1.0, step=0.01, value=0),
 
 
 
 
 
 
 
 
 
 
 
 
69
  gr.Slider(label="Top-k", minimum=1, maximum=1000, step=0.01, value=50),
70
+ gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.01, value=2),
71
+ gr.Textbox(label="Search Query (Optional)", placeholder="Enter search query to fetch online info", lines=1),
 
 
 
 
 
 
72
  ],
73
  stop_btn=gr.Button("Stop"),
74
  examples=[