Shining-Data committed on
Commit
bdac4d5
·
verified ·
1 Parent(s): 5c14066

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -34
app.py CHANGED
@@ -10,7 +10,7 @@ import re # for parsing <think> blocks
10
  import gradio as gr
11
  import torch
12
  from transformers import pipeline, TextIteratorStreamer
13
- from transformers import AutoTokenizer
14
  from duckduckgo_search import DDGS
15
 
16
  from transformers import modeling_utils
@@ -55,53 +55,34 @@ def load_pipeline(model_name):
55
  return PIPELINES[model_name]
56
  repo = MODELS[model_name]["repo_id"]
57
  if model_name == "secgpt-mini":
58
- tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True, subfolder="models")
 
59
  else:
60
- tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
 
61
  for dtype in (torch.bfloat16, torch.float16, torch.float32):
62
  try:
63
- if model_name == "secgpt-mini":
64
- pipe = pipeline(
65
  task="text-generation",
66
- model=repo,
67
  tokenizer=tokenizer,
68
  trust_remote_code=True,
69
  torch_dtype=dtype,
70
  device_map=device,
71
- subfolder="models"
72
- )
73
- else:
74
- pipe = pipeline(
75
- task="text-generation",
76
- model=repo,
77
- tokenizer=tokenizer,
78
- trust_remote_code=True,
79
- torch_dtype=device,
80
- device_map="auto",
81
  )
82
  PIPELINES[model_name] = pipe
83
  return pipe
84
  except Exception:
85
  continue
86
  # Final fallback
87
- if model_name == "secgpt-mini":
88
- pipe = pipeline(
89
- task="text-generation",
90
- model=repo,
91
- tokenizer=tokenizer,
92
- trust_remote_code=True,
93
- torch_dtype=dtype,
94
- device_map=device,
95
- subfolder="models"
96
- )
97
- else:
98
- pipe = pipeline(
99
  task="text-generation",
100
- model=repo,
101
  tokenizer=tokenizer,
102
  trust_remote_code=True,
103
- device_map=device
104
- )
 
105
  PIPELINES[model_name] = pipe
106
  return pipe
107
 
@@ -290,9 +271,9 @@ with gr.Blocks(title="Yee R1 Demo") as demo:
290
  gr.Markdown("### Generation Parameters")
291
  max_tok = gr.Slider(64, 16384, value=4096, step=32, label="Max Tokens")
292
  temp = gr.Slider(0.1, 2.0, value=0.6, step=0.1, label="Temperature")
293
- k = gr.Slider(1, 100, value=40, step=1, label="Top-K")
294
- p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
295
- rp = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
296
  gr.Markdown("### Web Search Settings")
297
  mr = gr.Number(value=6, precision=0, label="Max Results")
298
  mc = gr.Number(value=600, precision=0, label="Max Chars/Result")
 
10
  import gradio as gr
11
  import torch
12
  from transformers import pipeline, TextIteratorStreamer
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM
14
  from duckduckgo_search import DDGS
15
 
16
  from transformers import modeling_utils
 
55
  return PIPELINES[model_name]
56
  repo = MODELS[model_name]["repo_id"]
57
  if model_name == "secgpt-mini":
58
+ tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True, device_map=device, subfolder="models")
59
+ model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True, device_map=device, subfolder="models")
60
  else:
61
+ tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True, device_map=device)
62
+ model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True, device_map=device)
63
  for dtype in (torch.bfloat16, torch.float16, torch.float32):
64
  try:
65
+ pipe = pipeline(
 
66
  task="text-generation",
67
+ model=model,
68
  tokenizer=tokenizer,
69
  trust_remote_code=True,
70
  torch_dtype=dtype,
71
  device_map=device,
 
 
 
 
 
 
 
 
 
 
72
  )
73
  PIPELINES[model_name] = pipe
74
  return pipe
75
  except Exception:
76
  continue
77
  # Final fallback
78
+ pipe = pipeline(
 
 
 
 
 
 
 
 
 
 
 
79
  task="text-generation",
80
+ model=model,
81
  tokenizer=tokenizer,
82
  trust_remote_code=True,
83
+ torch_dtype=dtype,
84
+ device_map=device,
85
+ )
86
  PIPELINES[model_name] = pipe
87
  return pipe
88
 
 
271
  gr.Markdown("### Generation Parameters")
272
  max_tok = gr.Slider(64, 16384, value=4096, step=32, label="Max Tokens")
273
  temp = gr.Slider(0.1, 2.0, value=0.6, step=0.1, label="Temperature")
274
+ k = gr.Slider(1, 100, value=20, step=1, label="Top-K")
275
+ p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-P")
276
+ rp = gr.Slider(1.0, 2.0, value=1.0, step=0.1, label="Repetition Penalty")
277
  gr.Markdown("### Web Search Settings")
278
  mr = gr.Number(value=6, precision=0, label="Max Results")
279
  mc = gr.Number(value=600, precision=0, label="Max Chars/Result")