Remostart committed (verified)
Commit 41206c6 · 1 Parent(s): 4125ab7

Update app.py

Files changed (1):
  app.py (+34, -19)
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 import torch
-import logging
+import logging, traceback
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -22,18 +22,27 @@ try:
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

     logger.info("Loading model...")
+    has_cuda = torch.cuda.is_available()
+    dtype = torch.float16 if has_cuda else torch.float32  # safer on CPU
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_NAME,
         device_map="auto",
-        dtype=torch.float16,  # transformers now prefers `dtype` over `torch_dtype`
+        torch_dtype=dtype,
         low_cpu_mem_usage=True,
-    ).eval()
+    )
+    model.eval()

+    # Ensure pad/eos are sensible; if we add a token, resize embeddings
+    added = False
     if tokenizer.pad_token_id is None:
-        if tokenizer.eos_token_id is not None:
+        if tokenizer.eos_token is not None:
             tokenizer.pad_token = tokenizer.eos_token
         else:
             tokenizer.add_special_tokens({"pad_token": "</s>"})
+            added = True
+    if added:
+        model.resize_token_embeddings(len(tokenizer))
+
     logger.info("Model and tokenizer loaded successfully.")
 except Exception as e:
     logger.error(f"Error loading model or tokenizer: {str(e)}")
@@ -67,7 +76,7 @@ def create_prompt(personality, level, topic):
         f"End with a summary sentence on {topic}'s importance, then write {END_SENTINEL} and nothing else."
     )

-# ---------------- Stopping on substring ----------------
+# ---------------- Stop on substring ----------------
 class StopOnSubstrings(StoppingCriteria):
     def __init__(self, tokenizer, stop_strings):
         self.stop_ids = [tokenizer.encode(s, add_special_tokens=False) for s in stop_strings]
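Note: the hunk cuts off before the criterion's `__call__`; the commit's real body is simply outside this diff context. A hypothetical sketch of how such a substring stopper is typically completed — it checks whether any encoded stop string is a suffix of the generated ids:

    from transformers import StoppingCriteria

    class StopOnSubstrings(StoppingCriteria):
        def __init__(self, tokenizer, stop_strings):
            self.stop_ids = [tokenizer.encode(s, add_special_tokens=False) for s in stop_strings]

        def __call__(self, input_ids, scores, **kwargs):
            # Hypothetical completion: stop once any stop sequence is a suffix.
            seq = input_ids[0].tolist()
            return any(
                len(seq) >= len(ids) and seq[-len(ids):] == ids
                for ids in self.stop_ids
            )

One caveat with the token-id approach: a sentinel can tokenize differently mid-text than in isolation, which is why the streaming loop below also checks the decoded text for END_SENTINEL.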
@@ -87,11 +96,12 @@ def generate_response(personality, level, topic):
         prompt = create_prompt(personality, level, topic)
         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

+        # Keep your original streaming pattern; avoid version-sensitive args
         streamer = TextIteratorStreamer(
             tokenizer,
             skip_prompt=True,
             skip_special_tokens=True,
-            timeout=0.02,  # flush small chunks quickly
+            # no timeout arg (some Gradio/HF versions don't support it)
         )

         stopping = StoppingCriteriaList([StopOnSubstrings(tokenizer, [END_SENTINEL])])
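Note: the `TextIteratorStreamer` pattern used here runs `generate()` on a background thread while the caller consumes decoded text as a plain iterator. A minimal self-contained sketch, assuming `model` and `tokenizer` are loaded as above (the prompt string is just an example):

    from threading import Thread
    from transformers import TextIteratorStreamer

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer("Explain Plutus validators.", return_tensors="pt").to(model.device)
    kwargs = {**inputs, "streamer": streamer, "max_new_tokens": 50}
    Thread(target=model.generate, kwargs=kwargs, daemon=True).start()
    for chunk in streamer:  # blocks until the next decoded piece arrives
        print(chunk, end="", flush=True)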
@@ -99,31 +109,39 @@
         generation_kwargs = {
             **inputs,
             "streamer": streamer,
-            "max_new_tokens": 200,
-            "do_sample": False,
-            "no_repeat_ngram_size": 3,
-            "repetition_penalty": 1.1,
-            "eos_token_id": tokenizer.eos_token_id,
+            "max_new_tokens": 200,       # fits your format comfortably
+            "do_sample": False,          # deterministic to avoid tail babble
+            "no_repeat_ngram_size": 3,   # loop guard
+            "repetition_penalty": 1.1,   # mild anti-babble
             "pad_token_id": tokenizer.pad_token_id,
             "stopping_criteria": stopping,
             "use_cache": True,
         }

+        # Only pass eos_token_id if it exists (avoid None issues)
+        if tokenizer.eos_token_id is not None:
+            generation_kwargs["eos_token_id"] = tokenizer.eos_token_id
+
         thread = Thread(target=model.generate, kwargs=generation_kwargs, daemon=True)
         thread.start()

         generated_text = ""
         for new_text in streamer:
             generated_text += new_text
+
+            # Hard stop the moment we see the sentinel
             if END_SENTINEL in generated_text:
                 yield generated_text.split(END_SENTINEL)[0].rstrip()
                 return
+
             yield generated_text.strip()

         logger.info("Response generated successfully.")
-    except Exception as e:
-        logger.error(f"Error during generation: {str(e)}")
-        yield f"Error: {str(e)}"
+    except Exception:
+        err = traceback.format_exc()
+        logger.error(err)
+        # Show full traceback in UI for quick debugging
+        yield "Error:\n" + err

 # ---------------- Gradio UI ----------------
 with gr.Blocks(title="Cardano Plutus AI Assistant") as demo:
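Note: because `generate_response` is a generator feeding a Gradio output, an uncaught exception would otherwise surface only as a generic error in the UI; yielding the formatted traceback shows the full failure in the output textbox, matching the commit's "quick debugging" comment. The pattern in miniature (names here are illustrative, not from the diff):

    import logging, traceback

    logger = logging.getLogger(__name__)

    def safe_stream():
        try:
            yield "partial output"
            raise RuntimeError("boom")  # stand-in for a generation failure
        except Exception:
            err = traceback.format_exc()
            logger.error(err)
            yield "Error:\n" + err      # surfaces in the UI instead of vanishing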
@@ -143,16 +161,13 @@ with gr.Blocks(title="Cardano Plutus AI Assistant") as demo:
         placeholder="Generated content will appear here...",
     )

-    # (Optional) Per-event concurrency control in Gradio 4+
     generate_btn.click(
         fn=generate_response,
         inputs=[personality, level, topic],
         outputs=output,
-        concurrency_limit=1,  # <- replaces old global concurrency_count
     )

 logger.info("Launching Gradio interface...")
-
-# ✅ Gradio 4+ queue config (no more `concurrency_count`)
-demo.queue(default_concurrency_limit=1, max_size=20)
+# Keep it version-agnostic: enable queueing without extra args
+demo.queue()
 demo.launch()
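Note: the closing change drops the version-specific queue and concurrency kwargs in favor of a bare `demo.queue()`, which is accepted across Gradio 3 and 4 (and in older releases, streaming generator outputs required the queue to be enabled at all). A minimal sketch of the version-agnostic wiring, with a toy `echo` handler standing in for `generate_response`:

    import gradio as gr

    def echo(text):
        yield text  # generator handler, streamed like generate_response

    with gr.Blocks() as demo:
        inp = gr.Textbox()
        out = gr.Textbox()
        inp.submit(echo, inputs=inp, outputs=out)

    demo.queue()   # no concurrency kwargs: works across Gradio versions
    demo.launch()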