Shining-Data committed on
Commit 5cf5d21 · verified · 1 Parent(s): 5ba955b

Update app.py

Files changed (1): app.py +32 -42
app.py CHANGED
@@ -7,14 +7,19 @@ from datetime import datetime
 import re  # for parsing <think> blocks
 import gradio as gr
 import torch
-from transformers import pipeline, TextIteratorStreamer
-from transformers import AutoTokenizer
+from transformers import TextIteratorStreamer
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from duckduckgo_search import DDGS
 # import spaces  # Import spaces early to enable ZeroGPU support
 
 # Optional: Disable GPU visibility if you wish to force CPU usage
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
 
+if torch.cuda.is_available():
+    device = "auto"
+else:
+    device = "cpu"
+
 # ------------------------------
 # Global Cancellation Event
 # ------------------------------
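Note on this hunk: `CUDA_VISIBLE_DEVICES` is cleared a few lines before the new `torch.cuda.is_available()` check, so the check should come back False and `device` will always resolve to "cpu" unless that environment line is removed. A minimal sketch of the selection logic as a reusable helper, assuming GPUs should only be hidden when CPU is explicitly forced (the helper name is hypothetical):

import os
import torch

def pick_device(force_cpu: bool = False) -> str:
    # Hide GPUs before any CUDA initialization when CPU is forced.
    if force_cpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    # "auto" hands placement to transformers' device_map machinery,
    # which requires the accelerate package to be installed.
    return "auto" if not force_cpu and torch.cuda.is_available() else "cpu"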
@@ -43,30 +48,13 @@ def load_pipeline(model_name):
         return PIPELINES[model_name]
     repo = MODELS[model_name]["repo_id"]
     tokenizer = AutoTokenizer.from_pretrained(repo)
-    for dtype in (torch.bfloat16, torch.float16, torch.float32):
-        try:
-            pipe = pipeline(
-                task="text-generation",
-                model=repo,
-                tokenizer=tokenizer,
-                trust_remote_code=True,
-                torch_dtype=dtype,
-                device=-1  # CPU only  # device_map="auto"
-            )
-            PIPELINES[model_name] = pipe
-            return pipe
-        except Exception:
-            continue
-    # Final fallback
-    pipe = pipeline(
-        task="text-generation",
-        model=repo,
-        tokenizer=tokenizer,
-        trust_remote_code=True,
-        device=-1  # CPU only  # device_map="auto"
-    )
-    PIPELINES[model_name] = pipe
-    return pipe
+    model = AutoModelForCausalLM.from_pretrained(
+        repo,
+        device_map=device,
+        trust_remote_code=True,
+    )
+    PIPELINES[model_name] = {"tokenizer": tokenizer, "model": model}
+    return PIPELINES[model_name]
 
 
 def retrieve_context(query, max_results=6, max_chars=600):
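Because `load_pipeline` now returns a plain dict rather than a callable `transformers` pipeline, callers index into it instead of calling it. A minimal usage sketch under that assumption ("Yee-R1" is a stand-in for an actual key in `MODELS`):

bundle = load_pipeline("Yee-R1")
tokenizer, model = bundle["tokenizer"], bundle["model"]

# One-off, non-streaming generation to sanity-check the loaded model.
inputs = tokenizer("Hello", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))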
@@ -153,19 +141,21 @@ def chat_response(user_msg, chat_history, system_prompt,
-    streamer = TextIteratorStreamer(pipe.tokenizer,
+    streamer = TextIteratorStreamer(pipe["tokenizer"],
                                     skip_prompt=True,
                                     skip_special_tokens=True)
-    gen_thread = threading.Thread(
-        target=pipe,
-        args=(prompt,),
-        kwargs={
-            'max_new_tokens': max_tokens,
-            'temperature': temperature,
-            'top_k': top_k,
-            'top_p': top_p,
-            'repetition_penalty': repeat_penalty,
-            'streamer': streamer,
-            'return_full_text': False,
-        }
-    )
+    generation_config = dict(
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        max_new_tokens=max_tokens,
+        do_sample=True,
+        repetition_penalty=repeat_penalty,
+        streamer=streamer,
+    )
+    inputs = pipe["tokenizer"](prompt, return_tensors="pt")
+    if device == "auto":
+        input_ids = inputs["input_ids"].cuda()
+    else:
+        input_ids = inputs["input_ids"]
+    gen_thread = threading.Thread(target=lambda: pipe["model"].generate(input_ids=input_ids, **generation_config))
     gen_thread.start()
 
     # Buffers for thought vs answer
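The background thread only produces tokens; the streamer still has to be drained where the response is assembled. A sketch of the consuming side of this pattern, assuming a generator-style Gradio handler (the <think>-block splitting hinted at by the "Buffers" comment is left out):

# TextIteratorStreamer yields decoded text chunks as generate() runs.
answer = ""
for chunk in streamer:
    answer += chunk
    yield answer  # push the partial reply to the UI
gen_thread.join()  # ensure the generation thread has finished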
@@ -253,11 +243,11 @@ with gr.Blocks(title="Yee R1 Demo") as demo:
     with gr.Row():
         with gr.Column(scale=3):
             model_dd = gr.Dropdown(label="Select Model", choices=list(MODELS.keys()), value=list(MODELS.keys())[0])
-            search_chk = gr.Checkbox(label="Enable Web Search", value=True)
+            search_chk = gr.Checkbox(label="Enable Web Search", value=False)
             sys_prompt = gr.Textbox(label="System Prompt", lines=3, value=update_default_prompt(search_chk.value))
             gr.Markdown("### Generation Parameters")
-            max_tok = gr.Slider(64, 16384, value=2048, step=32, label="Max Tokens")
-            temp = gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature")
+            max_tok = gr.Slider(64, 16384, value=4096, step=32, label="Max Tokens")
+            temp = gr.Slider(0.1, 2.0, value=0.6, step=0.1, label="Temperature")
             k = gr.Slider(1, 100, value=40, step=1, label="Top-K")
             p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
             rp = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
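One more note: moving from `pipeline(...)` to a raw `model.generate` call in a thread means the "Global Cancellation Event" from the top of the file can be honored through a stopping criterion. A hedged sketch, assuming a module-level `threading.Event` named `cancel_event` (the name is inferred from the banner, not confirmed by this diff):

import threading
from transformers import StoppingCriteria, StoppingCriteriaList

class CancelCriteria(StoppingCriteria):
    # Stop generation as soon as the shared event is set.
    def __init__(self, event: threading.Event):
        self.event = event

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        return self.event.is_set()

# Hypothetical wiring into the generation_config built above:
# generation_config["stopping_criteria"] = StoppingCriteriaList([CancelCriteria(cancel_event)])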