Shining-Data committed on
Commit 24fbba8 · verified · 1 Parent(s): 053d245

Update app.py

Files changed (1)
  1. app.py +65 -69
app.py CHANGED
@@ -9,13 +9,14 @@ from datetime import datetime
 import re # for parsing <think> blocks
 import gradio as gr
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import pipeline, TextIteratorStreamer
+from transformers import AutoTokenizer
 from duckduckgo_search import DDGS

 from transformers import modeling_utils
 if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
     modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none","colwise",'rowwise']
-
+
 # import spaces # Import spaces early to enable ZeroGPU support

 # Optional: Disable GPU visibility if you wish to force CPU usage
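The import swap above brings in transformers' built-in TextIteratorStreamer, which the next hunk uses to replace the hand-rolled TextIterStreamer class. For reference, a minimal sketch of the standard TextIteratorStreamer pattern, independent of this app (the model id and prompt are placeholders, not from the commit):

```python
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder model; any causal LM follows the same pattern.
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
inputs = tok("Hello, world", return_tensors="pt")

# generate() blocks, so it runs in a worker thread while the caller
# iterates the streamer and receives decoded text chunks as they arrive.
Thread(target=model.generate,
       kwargs={**inputs, "max_new_tokens": 32, "streamer": streamer}).start()
for chunk in streamer:
    print(chunk, end="", flush=True)
```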
@@ -44,69 +45,66 @@ MODELS = {
 # Global cache for pipelines to avoid re-loading.
 PIPELINES = {}

-class TextIterStreamer:
-    def __init__(self, tokenizer, skip_prompt=True, skip_special_tokens=True):
-        self.tokenizer = tokenizer
-        self.skip_prompt = skip_prompt
-        self.skip_special_tokens = skip_special_tokens
-        self.tokens = []
-        self.text_queue = Queue()
-        # self.text_queue = []
-        self.next_tokens_are_prompt = True
-
-    def put(self, value):
-        if self.skip_prompt and self.next_tokens_are_prompt:
-            self.next_tokens_are_prompt = False
-        else:
-            if len(value.shape) > 1:
-                value = value[0]
-            self.tokens.extend(value.tolist())
-            word = self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens)
-            # self.text_queue.append(word)
-            self.text_queue.put(word)
-
-    def end(self):
-        # self.text_queue.append(None)
-        self.text_queue.put(None)
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        value = self.text_queue.get()
-        if value is None:
-            raise StopIteration()
-        else:
-            return value
-
-
 def load_pipeline(model_name):
     """
     Load and cache a transformers pipeline for text generation.
     Tries bfloat16, falls back to float16 or float32 if unsupported.
     """
     global PIPELINES
-
-    if model_name in PIPELINES.keys():
+    if model_name in PIPELINES:
         return PIPELINES[model_name]
     repo = MODELS[model_name]["repo_id"]
     if model_name == "secgpt-mini":
+        tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
+    else:
         tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True, subfolder="models")
-        model = AutoModelForCausalLM.from_pretrained(
-            repo,
-            device_map=device,
+    for dtype in (torch.bfloat16, torch.float16, torch.float32):
+        try:
+            if model_name == "secgpt-mini":
+                pipe = pipeline(
+                    task="text-generation",
+                    model=repo,
+                    tokenizer=tokenizer,
+                    trust_remote_code=True,
+                    torch_dtype=dtype,
+                    device_map=device,
+                    subfolder="models"
+                )
+            else:
+                pipe = pipeline(
+                    task="text-generation",
+                    model=repo,
+                    tokenizer=tokenizer,
+                    trust_remote_code=True,
+                    torch_dtype=dtype,
+                    device_map="auto",
+                )
+            PIPELINES[model_name] = pipe
+            return pipe
+        except Exception:
+            continue
+    # Final fallback
+    if model_name == "secgpt-mini":
+        pipe = pipeline(
+            task="text-generation",
+            model=repo,
+            tokenizer=tokenizer,
             trust_remote_code=True,
-            subfolder="models",
-        )
-    else:
-        tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
-        model = AutoModelForCausalLM.from_pretrained(
-            repo,
+            torch_dtype=dtype,
             device_map=device,
-            trust_remote_code=True,
-        )
-    PIPELINES[model_name] = {"tokenizer": tokenizer, "model": model}
-    return {"tokenizer": tokenizer, "model": model}
+            subfolder="models"
+        )
+    else:
+        pipe = pipeline(
+            task="text-generation",
+            model=repo,
+            tokenizer=tokenizer,
+            trust_remote_code=True,
+            device_map=device
+        )
+    PIPELINES[model_name] = pipe
+    return pipe
+


 def retrieve_context(query, max_results=6, max_chars=600):
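The rewritten loader above tries bfloat16 first, then float16, then float32, because bfloat16 is unavailable on older GPUs and most CPUs. A condensed sketch of that fallback pattern with a hypothetical helper name and a generic repo id (the deliberately broad except mirrors the commit, where load errors vary by backend):

```python
import torch
from transformers import pipeline

def load_with_dtype_fallback(repo_id, device_map="auto"):
    # Hypothetical helper: try the most memory-efficient dtype first,
    # widening until the backend accepts one.
    last_err = None
    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            return pipeline(
                task="text-generation",
                model=repo_id,
                torch_dtype=dtype,
                device_map=device_map,
                trust_remote_code=True,
            )
        except Exception as err:  # broad on purpose; dtype/device errors differ
            last_err = err
    raise RuntimeError(f"could not load {repo_id}") from last_err
```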
@@ -182,26 +180,24 @@ def chat_response(user_msg, chat_history, system_prompt,
     enriched = system_prompt

     pipe = load_pipeline(model_name)
-    prompt = format_conversation(history, enriched, pipe["tokenizer"])
+    prompt = format_conversation(history, enriched, pipe.tokenizer)
     prompt_debug = f"\n\n--- Prompt Preview ---\n```\n{prompt}\n```"
-    streamer = TextIterStreamer(pipe["tokenizer"],
+    streamer = TextIteratorStreamer(pipe.tokenizer,
                                 skip_prompt=True,
                                 skip_special_tokens=True)
-    generation_config = dict(
-        temperature=temperature,
-        top_k=top_k,
-        top_p=top_p,
-        max_new_tokens=max_tokens,
-        do_sample=True,
-        repetition_penalty=repeat_penalty,
-        streamer=streamer,
+    gen_thread = Thread(
+        target=pipe,
+        args=(prompt,),
+        kwargs={
+            'max_new_tokens': max_tokens,
+            'temperature': temperature,
+            'top_k': top_k,
+            'top_p': top_p,
+            'repetition_penalty': repeat_penalty,
+            'streamer': streamer,
+            'return_full_text': False,
+        }
     )
-    inputs = pipe["tokenizer"](prompt, return_tensors="pt")
-    if device == "auto":
-        input_ids = inputs["input_ids"].cuda()
-    else:
-        input_ids = inputs["input_ids"]
-    gen_thread = Thread(target=lambda: pipe["model"].generate(input_ids=input_ids, **generation_config))
     gen_thread.start()

     # Buffers for thought vs answer
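With the pipeline object itself as the Thread target, the old tokenize / .cuda() / model.generate() sequence collapses into one call, and the generation kwargs move into Thread(kwargs=...). A minimal sketch of how the streaming side is then consumed, with a hypothetical stream_reply helper (pipe and streamer as in the hunk above; return_full_text=False keeps the prompt out of the output):

```python
from threading import Thread

def stream_reply(pipe, streamer, prompt, max_tokens=256):
    # Run generation in the background; kwargs mirror the commit.
    Thread(
        target=pipe,
        args=(prompt,),
        kwargs={
            "max_new_tokens": max_tokens,
            "streamer": streamer,
            "return_full_text": False,
        },
    ).start()
    partial = ""
    for chunk in streamer:  # decoded text arrives incrementally
        partial += chunk
        yield partial  # e.g. Gradio re-renders the growing reply
```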