mrpoons-studio committed on
Commit
7b81a62
·
verified ·
1 Parent(s): 79739eb

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +39 -77
  2. requirements.txt +4 -4
app.py CHANGED
@@ -1,16 +1,12 @@
1
  import os
2
-
3
  from threading import Thread
4
  from typing import Iterator
5
-
6
  import gradio as gr
7
- import spaces
8
- import torch
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
 
11
  MAX_MAX_NEW_TOKENS = 2048
12
  DEFAULT_MAX_NEW_TOKENS = 2048
13
- total_count=0
14
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "128000"))
15
 
16
  DESCRIPTION = """\
@@ -20,114 +16,80 @@ This space demonstrates model [DeepSeek-R1](https://huggingface.co/deepseek-ai/d
20
 
21
  **You can also try our R1 model in [official homepage](https://r1.deepseek.com/chat).**
22
  """
23
-
24
  if not torch.cuda.is_available():
25
  DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
26
 
27
-
28
  if torch.cuda.is_available():
29
  model_id = "deepseek-ai/deepseek-r1"
30
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
31
  tokenizer = AutoTokenizer.from_pretrained(model_id)
32
  tokenizer.use_default_system_prompt = False
33
-
34
-
35
 
36
  @spaces.GPU
37
- def generate(
38
- message: str,
39
- chat_history: list[tuple[str, str]],
40
- system_prompt: str,
41
- max_new_tokens: int = 2048,
42
- temperature: float = 0,
43
- top_p: float = 0,
44
- top_k: int = 50,
45
- repetition_penalty: float = 2,
46
- ) -> Iterator[str]:
47
- global total_count
48
- total_count += 1
49
- print(total_count)
50
- os.system("nvidia-smi")
51
  conversation = []
52
  if system_prompt:
53
  conversation.append({"role": "system", "content": system_prompt})
 
 
 
 
 
 
 
 
 
54
  for user, assistant in chat_history:
55
- conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
 
56
  conversation.append({"role": "user", "content": message})
57
-
58
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
59
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
60
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
61
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
62
  input_ids = input_ids.to(model.device)
63
-
64
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
65
- generate_kwargs = dict(
66
- {"input_ids": input_ids},
67
- streamer=streamer,
68
- max_new_tokens=max_new_tokens,
69
- do_sample=False,
70
- top_p=top_p,
71
- top_k=top_k,
72
- num_beams=1,
73
- # temperature=temperature,
74
- repetition_penalty=repetition_penalty,
75
- eos_token_id=32021
76
- )
77
  t = Thread(target=model.generate, kwargs=generate_kwargs)
78
  t.start()
79
-
80
  outputs = []
81
  for text in streamer:
82
  outputs.append(text)
83
- yield "".join(outputs).replace("<|EOT|>","")
84
-
85
 
86
  chat_interface = gr.ChatInterface(
87
  fn=generate,
88
  additional_inputs=[
89
  gr.Textbox(label="System prompt", lines=6),
90
- gr.Slider(
91
- label="Max new tokens",
92
- minimum=0,
93
- maximum=MAX_MAX_NEW_TOKENS,
94
- step=0.01,
95
- value=DEFAULT_MAX_NEW_TOKENS,
96
- ),
97
- # gr.Slider(
98
- # label="Temperature",
99
- # minimum=0,
100
- # maximum=4.0,
101
- # step=0.01,
102
- # value=0,
103
- # ),
104
- gr.Slider(
105
- label="Top-p (nucleus sampling)",
106
- minimum=0,
107
- maximum=1.0,
108
- step=0.01,
109
- value=0,
110
- ),
111
- gr.Slider(
112
- label="Top-k",
113
- minimum=1,
114
- maximum=1000,
115
- step=0.01,
116
- value=50,
117
- ),
118
- gr.Slider(
119
- label="Repetition penalty",
120
- minimum=1.0,
121
- maximum=2.0,
122
- step=0.01,
123
- value=2,
124
- ),
125
  ],
126
  stop_btn=gr.Button("Stop"),
127
  examples=[
128
  ["implement snake game using pygame"],
129
  ["Can you explain briefly to me what is the Python programming language?"],
130
- ["write a program to find the factorial of a number"],
131
  ],
132
  )
133
 
 
1
  import os
 
2
  from threading import Thread
3
  from typing import Iterator
 
4
  import gradio as gr
5
+ import spaces, torch, requests
 
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
7
 
8
  MAX_MAX_NEW_TOKENS = 2048
9
  DEFAULT_MAX_NEW_TOKENS = 2048
 
10
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "128000"))
11
 
12
  DESCRIPTION = """\
 
16
 
17
  **You can also try our R1 model in [official homepage](https://r1.deepseek.com/chat).**
18
  """
 
19
  if not torch.cuda.is_available():
20
  DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
21
 
 
22
  if torch.cuda.is_available():
23
  model_id = "deepseek-ai/deepseek-r1"
24
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
25
  tokenizer = AutoTokenizer.from_pretrained(model_id)
26
  tokenizer.use_default_system_prompt = False
 
 
27
 
28
  @spaces.GPU
29
+ def generate(message: str,
30
+ chat_history: list[tuple[str, str]],
31
+ system_prompt: str,
32
+ max_new_tokens: int = 2048,
33
+ temperature: float = 0,
34
+ top_p: float = 0,
35
+ top_k: int = 50,
36
+ repetition_penalty: float = 2,
37
+ search_query: str = "") -> Iterator[str]:
 
 
 
 
 
38
  conversation = []
39
  if system_prompt:
40
  conversation.append({"role": "system", "content": system_prompt})
41
+ if search_query:
42
+ try:
43
+ r = requests.get(f"https://api.duckduckgo.com/?q={search_query}&format=json", timeout=5)
44
+ data = r.json()
45
+ result = data.get("AbstractText", "")
46
+ if result:
47
+ conversation.append({"role": "system", "content": f"Search results for '{search_query}': {result}"})
48
+ except Exception as e:
49
+ conversation.append({"role": "system", "content": f"Search error: {e}"})
50
  for user, assistant in chat_history:
51
+ conversation.extend([{"role": "user", "content": user},
52
+ {"role": "assistant", "content": assistant}])
53
  conversation.append({"role": "user", "content": message})
 
54
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
55
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
56
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
57
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
58
  input_ids = input_ids.to(model.device)
 
59
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
60
+ generate_kwargs = {
61
+ "input_ids": input_ids,
62
+ "streamer": streamer,
63
+ "max_new_tokens": max_new_tokens,
64
+ "do_sample": False,
65
+ "top_p": top_p,
66
+ "top_k": top_k,
67
+ "num_beams": 1,
68
+ "repetition_penalty": repetition_penalty,
69
+ "eos_token_id": 32021
70
+ }
 
71
  t = Thread(target=model.generate, kwargs=generate_kwargs)
72
  t.start()
 
73
  outputs = []
74
  for text in streamer:
75
  outputs.append(text)
76
+ yield "".join(outputs).replace("<|EOT|>", "")
 
77
 
78
  chat_interface = gr.ChatInterface(
79
  fn=generate,
80
  additional_inputs=[
81
  gr.Textbox(label="System prompt", lines=6),
82
+ gr.Slider(label="Max new tokens", minimum=0, maximum=MAX_MAX_NEW_TOKENS, step=0.01, value=DEFAULT_MAX_NEW_TOKENS),
83
+ gr.Slider(label="Top-p (nucleus sampling)", minimum=0, maximum=1.0, step=0.01, value=0),
84
+ gr.Slider(label="Top-k", minimum=1, maximum=1000, step=0.01, value=50),
85
+ gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.01, value=2),
86
+ gr.Textbox(label="Search Query (Optional)", placeholder="Enter search query to fetch online info", lines=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  ],
88
  stop_btn=gr.Button("Stop"),
89
  examples=[
90
  ["implement snake game using pygame"],
91
  ["Can you explain briefly to me what is the Python programming language?"],
92
+ ["write a program to find the factorial of a number"]
93
  ],
94
  )
95
 
requirements.txt CHANGED
@@ -1,9 +1,9 @@
1
- accelerate==0.23.0
2
  bitsandbytes==0.41.1
3
- gradio==3.48.0
4
  protobuf==3.20.3
5
  scipy==1.11.2
6
  sentencepiece==0.1.99
7
  spaces==0.16.1
8
- torch==2.0.0
9
- transformers==4.34.0
 
1
+ accelerate==0.23.2
2
  bitsandbytes==0.41.1
3
+ gradio==3.50.1
4
  protobuf==3.20.3
5
  scipy==1.11.2
6
  sentencepiece==0.1.99
7
  spaces==0.16.1
8
+ torch==2.0.1
9
+ transformers==4.35.1