Remostart committed on
Commit
48b53a6
·
verified ·
1 Parent(s): b639043

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -53
app.py CHANGED
@@ -6,41 +6,42 @@ from transformers import (
6
  AutoTokenizer,
7
  TextIteratorStreamer,
8
  StoppingCriteria,
9
- StoppingCriteriaList
10
  )
11
  from threading import Thread
12
 
 
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
 
16
  MODEL_NAME = "ubiodee/Plutus_Tutor_new"
17
 
18
# Load the tokenizer and the causal LM once, at import time.
logger.info("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

logger.info("Loading model...")
# float16 weights with automatic device placement; .eval() puts the module in
# inference mode and returns the same model object.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).eval()

# Ensure pad/eos set sensibly: generate() is later given tokenizer.pad_token_id.
# Inside this branch the pad token is unset, so only the EOS token or the
# literal "</s>" can serve as a fallback (the former `or tokenizer.pad_token`
# operand could never be selected).
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token or "</s>"
31
-
32
def eos_id_candidates(tok):
    """Return the ids of known end-of-sequence tokens present in the vocabulary.

    Args:
        tok: Tokenizer-like object exposing ``convert_tokens_to_ids`` and
            ``eos_token_id`` (and optionally ``unk_token`` / ``unk_token_id``).

    Returns:
        A sorted list of distinct token ids, or ``None`` when no candidate
        could be resolved.
    """
    unk_id = getattr(tok, "unk_token_id", None)
    unk_token = getattr(tok, "unk_token", None)

    ids = set()
    for tok_str in ("</s>", "<|eot_id|>", "<|end|>", "<|im_end|>"):
        tid = tok.convert_tokens_to_ids(tok_str)
        if tid is None or tid == -1:
            continue
        # Many tokenizers map out-of-vocabulary strings to the unk id instead
        # of None/-1. Registering unk as an EOS id would end generation on any
        # unknown token, so such lookups are treated as "not in vocabulary".
        if unk_id is not None and tid == unk_id and tok_str != unk_token:
            continue
        ids.add(tid)

    if tok.eos_token_id is not None:
        ids.add(tok.eos_token_id)

    # Sorted so the result is deterministic across runs.
    return sorted(ids) if ids else None
41
-
42
- EOS_IDS = eos_id_candidates(tokenizer)
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  PERSONALITY_TYPES = ["Autistic", "Dyslexic", "Expressive", "Nerd", "Visual", "Other"]
45
  PROGRAMMING_LEVELS = ["Beginner", "Intermediate", "Professional"]
46
  TOPICS = [
@@ -49,66 +50,79 @@ TOPICS = [
49
  "Smart Contracts",
50
  "Versioning in Plutus",
51
  "Monad",
52
- "Other"
53
  ]
54
 
 
# Marker the model is told to emit once its answer is complete.
END_SENTINEL = "[END]"


def create_prompt(personality, level, topic):
    """Build the single-string instruction prompt for one tutoring request.

    Args:
        personality: Learner personality label interpolated into the prompt.
        level: Programming level label interpolated into the prompt.
        topic: Plutus topic to explain; interpolated in three places.

    Returns:
        The instruction text, ending with the request to write END_SENTINEL.
    """
    chunks = (
        f"Explain {topic} in Plutus for a {level} programmer with {personality} traits.",
        f"Use only basic words and clear examples. Use a physical object analogy tied to {topic}.",
        "Avoid jargon like 'blockchain,' 'ledger,' 'Haskell,' 'decentralized,' 'cyber,' 'e-commerce,'",
        "'formal verification,' or 'immutability.' Use short sentences (6-8 words).",
        "Use exactly 3 numbered points for key ideas. Each point must have 5-10 words.",
        "Bold the first word of each point. Structure the response: 2-sentence introduction,",
        "3 numbered points, 1-sentence conclusion. For Autistic traits, use literal language,",
        "numbered lists, and **bold key terms**. Repeat key ideas for clarity.",
        "Avoid abstract terms unless concrete. Do not repeat the topic or prompt.",
        "Do not ask questions. Use a direct, instructional tone.",
        f"End with a summary sentence on {topic}’s importance, then write {END_SENTINEL} and nothing else.",
    )
    # Every chunk is separated from the next by exactly one space.
    return " ".join(chunks)
71
 
72
# StoppingCriteria that halts when a stop substring appears
class StopOnSubstrings(StoppingCriteria):
    """Stop generation once the sequence ends with a tokenized stop string.

    Args:
        tokenizer: Object whose ``encode(text, add_special_tokens=False)``
            returns a list of token ids.
        stop_strings: Iterable of strings that should end generation.

    Only the first row of ``input_ids`` is inspected, matching the original
    behaviour of this class.
    """

    def __init__(self, tokenizer, stop_strings):
        # Tokenization depends on context: the same text is often split
        # differently when it follows a space. Both spellings are stored so
        # the suffix check can fire in either case. Empty stop strings are
        # dropped here instead of being skipped on every generation step.
        self.stop_ids = []
        for text in stop_strings:
            if not text:
                continue
            for variant in (text, " " + text):
                ids = list(tokenizer.encode(variant, add_special_tokens=False))
                if ids and ids not in self.stop_ids:
                    self.stop_ids.append(ids)
        self._longest = max((len(ids) for ids in self.stop_ids), default=0)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when row 0 of ``input_ids`` ends with a stop sequence."""
        total = input_ids.shape[1]
        if self._longest == 0 or total == 0:
            return False
        # One host transfer of the tail per step, instead of building a new
        # tensor for every stop sequence.
        tail = input_ids[0, -min(self._longest, total):].tolist()
        return any(
            len(ids) <= len(tail) and tail[-len(ids):] == ids
            for ids in self.stop_ids
        )
86
 
 
87
  def generate_response(personality, level, topic):
88
  try:
89
  logger.info("Processing selections...")
90
  prompt = create_prompt(personality, level, topic)
91
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
92
 
93
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
 
 
 
 
 
94
 
95
  stopping = StoppingCriteriaList([StopOnSubstrings(tokenizer, [END_SENTINEL])])
96
 
 
97
  generation_kwargs = {
98
  **inputs,
99
  "streamer": streamer,
100
-
101
- # <<< Tighter, safer defaults for short, structured outputs >>>
102
- "max_new_tokens": 180,
103
- "do_sample": False, # deterministic; avoids tail babble
104
- "no_repeat_ngram_size": 3,
105
- "repetition_penalty": 1.1,
106
- "eos_token_id": EOS_IDS, # list of possible EOS tokens
107
  "pad_token_id": tokenizer.pad_token_id,
108
  "stopping_criteria": stopping,
109
  "use_cache": True,
110
  }
111
 
 
112
  thread = Thread(target=model.generate, kwargs=generation_kwargs, daemon=True)
113
  thread.start()
114
 
@@ -116,12 +130,12 @@ def generate_response(personality, level, topic):
116
  for new_text in streamer:
117
  generated_text += new_text
118
 
119
- # Defensive: cut off sentinel if it appears in the stream
120
  if END_SENTINEL in generated_text:
121
- clean = generated_text.split(END_SENTINEL)[0].rstrip()
122
- yield clean
123
  return
124
 
 
125
  yield generated_text.strip()
126
 
127
  logger.info("Response generated successfully.")
@@ -129,6 +143,7 @@ def generate_response(personality, level, topic):
129
  logger.error(f"Error during generation: {str(e)}")
130
  yield f"Error: {str(e)}"
131
 
 
132
  with gr.Blocks(title="Cardano Plutus AI Assistant") as demo:
133
  gr.Markdown("### Your Personalised Plutus Tutor")
134
  gr.Markdown("Select your personality type, programming level, and topic, then click Generate.")
@@ -138,9 +153,21 @@ with gr.Blocks(title="Cardano Plutus AI Assistant") as demo:
138
  topic = gr.Dropdown(choices=TOPICS, label="Topic", value="Introduction to Validation")
139
 
140
  generate_btn = gr.Button("Generate")
141
- output = gr.Textbox(label="Model Response", show_label=True, lines=10, placeholder="Generated content will appear here...")
142
 
143
- generate_btn.click(fn=generate_response, inputs=[personality, level, topic], outputs=output)
 
 
 
 
 
 
 
 
 
 
 
144
 
 
145
  logger.info("Launching Gradio interface...")
 
146
  demo.launch()
 
6
  AutoTokenizer,
7
  TextIteratorStreamer,
8
  StoppingCriteria,
9
+ StoppingCriteriaList,
10
  )
11
  from threading import Thread
12
 
13
+ # ---------------- Logging ----------------
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
16
 
17
+ # ---------------- Model & Tokenizer ----------------
18
  MODEL_NAME = "ubiodee/Plutus_Tutor_new"
19
 
20
try:
    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

    logger.info("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    model.eval()

    # Make sure pad/eos are sensible to avoid warnings/crashes:
    # generate() is later given tokenizer.pad_token_id.
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            # add_special_tokens returns how many tokens were newly added to
            # the vocabulary. A new token gets an id past the current
            # embedding matrix, so the embeddings must grow to match.
            added = tokenizer.add_special_tokens({"pad_token": "</s>"})
            if added:
                model.resize_token_embeddings(len(tokenizer))
    logger.info("Model and tokenizer loaded successfully.")
except Exception as e:
    # Startup boundary: record the full traceback, then abort the launch.
    logger.exception("Error loading model or tokenizer: %s", e)
    raise
43
+
44
+ # ---------------- UI Options ----------------
45
  PERSONALITY_TYPES = ["Autistic", "Dyslexic", "Expressive", "Nerd", "Visual", "Other"]
46
  PROGRAMMING_LEVELS = ["Beginner", "Intermediate", "Professional"]
47
  TOPICS = [
 
50
  "Smart Contracts",
51
  "Versioning in Plutus",
52
  "Monad",
53
+ "Other",
54
  ]
55
 
56
# ---------------- Prompting ----------------
# Marker the model is told to emit once its answer is complete.
END_SENTINEL = "[END]"


def create_prompt(personality, level, topic):
    """Assemble the instruction prompt for one tutoring request.

    Args:
        personality: Learner personality label interpolated into the prompt.
        level: Programming level label interpolated into the prompt.
        topic: Plutus topic to explain; interpolated in three places.

    Returns:
        One string holding the structure and tone rules, ending with the
        explicit request to write END_SENTINEL.
    """
    instructions = [
        f"Explain {topic} in Plutus for a {level} programmer with {personality} traits.",
        f"Use only basic words and clear examples. Use a physical object analogy (e.g., a lock or checklist) tied to {topic}.",
        "Avoid jargon like 'blockchain,' 'ledger,' 'Haskell,' 'decentralized,' 'cyber,' 'e-commerce,' 'formal verification,' or 'immutability.'",
        "Use short sentences (6-8 words). Use exactly 3 numbered points for key ideas. Each point must have 5-10 words.",
        "Bold the first word of each point. Structure the response: 2-sentence introduction, 3 numbered points, 1-sentence conclusion.",
        "For Autistic traits, use literal language, numbered lists, and **bold key terms**. Repeat key ideas for clarity.",
        "Avoid abstract terms unless concrete. Do not repeat the topic or prompt. Do not simulate a conversation, ask questions, or discuss unrelated topics.",
        "Use a direct, instructional tone without 'I' or 'we'.",
        f"End with a summary sentence on {topic}'s importance, then write {END_SENTINEL} and nothing else.",
    ]
    # Consecutive instructions are separated by exactly one space.
    return " ".join(instructions)
72
 
73
# ---------------- Stopping on substring ----------------
class StopOnSubstrings(StoppingCriteria):
    """Stop generation once the sequence ends with a tokenized stop string.

    Args:
        tokenizer: Object whose ``encode(text, add_special_tokens=False)``
            returns a list of token ids. It is kept on ``self.tokenizer``.
        stop_strings: Iterable of strings that should end generation.

    Only the first row of ``input_ids`` is inspected, matching the original
    behaviour of this class.
    """

    def __init__(self, tokenizer, stop_strings):
        self.tokenizer = tokenizer
        # Pre-tokenize stop strings for fast suffix checks. Tokenization
        # depends on context: the same text is often split differently when
        # it follows a space, so both spellings are stored. Empty stop strings
        # are dropped here instead of being skipped on every generation step.
        self.stop_ids = []
        for text in stop_strings:
            if not text:
                continue
            for variant in (text, " " + text):
                ids = list(tokenizer.encode(variant, add_special_tokens=False))
                if ids and ids not in self.stop_ids:
                    self.stop_ids.append(ids)
        self._longest = max((len(ids) for ids in self.stop_ids), default=0)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when row 0 of ``input_ids`` ends with a stop sequence."""
        total = input_ids.shape[1]
        if self._longest == 0 or total == 0:
            return False
        # One host transfer of the tail per step, instead of building a new
        # tensor for every stop sequence.
        tail = input_ids[0, -min(self._longest, total):].tolist()
        return any(
            len(ids) <= len(tail) and tail[-len(ids):] == ids
            for ids in self.stop_ids
        )
93
 
94
+ # ---------------- Generation (STREAMING) ----------------
95
  def generate_response(personality, level, topic):
96
  try:
97
  logger.info("Processing selections...")
98
  prompt = create_prompt(personality, level, topic)
99
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
100
 
101
# Keep streamer + background thread approach.
# `timeout` is how long the iterator waits for the NEXT chunk of text before
# raising queue.Empty. It does not control chunk size or flushing. It has to
# exceed the slowest decoding step (including the initial pass over the
# prompt), so it is kept generous. Its only job is to stop a crashed
# generation thread from blocking this request forever.
streamer = TextIteratorStreamer(
    tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
    timeout=120.0,
)
108
 
109
  stopping = StoppingCriteriaList([StopOnSubstrings(tokenizer, [END_SENTINEL])])
110
 
111
+ # Tighter, deterministic decoding to avoid trailing garbage
112
  generation_kwargs = {
113
  **inputs,
114
  "streamer": streamer,
115
+ "max_new_tokens": 200, # your format fits well under this
116
+ "do_sample": False, # deterministic; helps finish cleanly
117
+ "no_repeat_ngram_size": 3, # avoid loops
118
+ "repetition_penalty": 1.1, # gentle anti-babble
119
+ "eos_token_id": tokenizer.eos_token_id,
 
 
120
  "pad_token_id": tokenizer.pad_token_id,
121
  "stopping_criteria": stopping,
122
  "use_cache": True,
123
  }
124
 
125
+ # Run generation in a separate thread so we can iterate the streamer
126
  thread = Thread(target=model.generate, kwargs=generation_kwargs, daemon=True)
127
  thread.start()
128
 
 
130
  for new_text in streamer:
131
  generated_text += new_text
132
 
133
+ # Hard stop if sentinel appears; strip it from output
134
  if END_SENTINEL in generated_text:
135
+ yield generated_text.split(END_SENTINEL)[0].rstrip()
 
136
  return
137
 
138
+ # Stream progressively (exactly like your earlier working version)
139
  yield generated_text.strip()
140
 
141
  logger.info("Response generated successfully.")
 
143
  logger.error(f"Error during generation: {str(e)}")
144
  yield f"Error: {str(e)}"
145
 
146
+ # ---------------- Gradio UI ----------------
147
  with gr.Blocks(title="Cardano Plutus AI Assistant") as demo:
148
  gr.Markdown("### Your Personalised Plutus Tutor")
149
  gr.Markdown("Select your personality type, programming level, and topic, then click Generate.")
 
153
  topic = gr.Dropdown(choices=TOPICS, label="Topic", value="Introduction to Validation")
154
 
155
  generate_btn = gr.Button("Generate")
 
156
 
157
+ output = gr.Textbox(
158
+ label="Model Response",
159
+ show_label=True,
160
+ lines=10,
161
+ placeholder="Generated content will appear here...",
162
+ )
163
+
164
+ generate_btn.click(
165
+ fn=generate_response,
166
+ inputs=[personality, level, topic],
167
+ outputs=output,
168
+ )
169
 
170
# Ensure true streaming in Gradio: enable the request queue before launching.
logger.info("Launching Gradio interface...")
try:
    # Gradio 4+ replaced `concurrency_count` with `default_concurrency_limit`.
    demo.queue(default_concurrency_limit=1, max_size=20)
except TypeError:
    # Older Gradio releases do not accept the new keyword; fall back to the
    # original call.
    demo.queue(concurrency_count=1, max_size=20)
demo.launch()