shayekh commited on
Commit
f3a1cdf
Β·
verified Β·
1 Parent(s): 8b261a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +179 -111
app.py CHANGED
@@ -151,7 +151,7 @@ def get_base64_image(image):
151
  return f"data:image/jpeg;base64,{img_str}"
152
 
153
  @spaces.GPU(duration=120)
154
- def extract_vocabulary(pdf_text, images, translit_lang, translit_format, target_lang, max_text_char=1500, repetition_penalty_val=1.1):
155
  """Use Transformers to extract vocabulary from text and images."""
156
  global model, processor
157
 
@@ -174,11 +174,16 @@ Return ONLY a valid JSON list of dictionaries, where each dictionary has four ke
174
  - 'transliteration' (the pronunciation transliterated into {translit_lang.upper()} script/characters, formatted as {translit_format}.{non_english})
175
  - 'translation' (the translation into {target_lang.upper()})
176
  - 'explanation' (a brief grammar or context note in {target_lang.upper()}).
177
- No markdown formatting, just raw JSON with ```json and ``` markers.
178
- CRITICAL: Do NOT provide any conversational filler, thinking steps, or reasoning. Answer quick without very long thinking. Output the JSON array IMMEDIATELY.
 
 
179
 
180
  Text:
 
 
181
  {pdf_text[:int(max_text_char)]}
 
182
  """
183
 
184
  # DEBUG: Log prompt text
@@ -209,6 +214,9 @@ Text:
209
  try:
210
  model.to("cuda")
211
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
 
212
  inputs = processor(
213
  text=[text],
214
  images=pil_images if pil_images else None,
@@ -216,131 +224,105 @@ Text:
216
  padding=True
217
  ).to("cuda")
218
 
219
- global_stop_thinking[0] = False
220
- global_kill_threads[0] = False
221
- print(f"[STOP-THINK] Flags RESET. stop_thinking={global_stop_thinking[0]}, kill={global_kill_threads[0]}")
222
-
223
  from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
224
  from threading import Thread
 
225
 
226
- class StopThinkingCriteria(StoppingCriteria):
 
 
227
  def __call__(self, input_ids, scores, **kwargs):
228
- val = global_stop_thinking[0] or global_kill_threads[0]
229
- if val:
230
- print(f"[STOP-THINK] Criteria returning True! stop={global_stop_thinking[0]} kill={global_kill_threads[0]}")
231
- return val
232
-
233
- streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
234
- generation_kwargs = dict(
235
- **inputs,
236
- streamer=streamer,
237
- max_new_tokens=2048*16,
238
- do_sample=True,
239
- repetition_penalty=repetition_penalty_val,
240
- stopping_criteria=StoppingCriteriaList([StopThinkingCriteria()])
241
- )
242
-
243
- if len(images) > 0:
244
- generation_kwargs.update(dict(temperature=0.6, top_p=0.95, top_k=20, min_p=0.0))
245
- else:
246
- generation_kwargs.update(dict(temperature=1.0, top_p=0.95, top_k=20, min_p=0.0))
247
-
248
- generation_result = []
249
- def generate_and_capture(**kwargs):
250
  try:
251
- out = model.generate(**kwargs)
252
- generation_result.append(out)
253
  except Exception as e:
254
  import traceback
255
- print(f"\n[THREAD1 ERROR] model.generate crashed: {e}")
256
  traceback.print_exc()
 
 
 
 
 
257
 
258
- thread = Thread(target=generate_and_capture, kwargs=generation_kwargs)
 
 
 
259
  thread.start()
260
 
261
- output_text = ""
262
  for new_text in streamer:
263
  output_text += new_text
264
  yield output_text, None
265
-
266
- thread.join()
267
-
268
- if global_kill_threads[0]:
269
- yield output_text + "\n\n[Generation completely stopped by user.]", None
270
- return
271
-
272
- if global_stop_thinking[0]:
273
- global_stop_thinking[0] = False
274
- print(f"[STOP-THINK] INJECTION PATH entered. Reset flag to: {global_stop_thinking[0]}")
275
-
276
- # Inject the closure of thinking and start of JSON
277
- injection_text = "\n</think>\n\n```json\n[\n"
278
- output_text += injection_text
279
- yield output_text, None
280
-
281
- # Restart generation with updated context
282
- generated_ids = generation_result[0]
283
- injection_ids = processor.tokenizer(injection_text, return_tensors="pt", add_special_tokens=False).input_ids.to("cuda")
284
- new_input_ids = torch.cat([generated_ids, injection_ids], dim=-1)
285
-
286
- # Update attention mask
287
- new_mask = torch.cat([
288
- inputs["attention_mask"],
289
- torch.ones((1, new_input_ids.shape[1] - inputs["attention_mask"].shape[1]), dtype=inputs["attention_mask"].dtype, device="cuda")
290
- ], dim=-1)
291
-
292
- new_inputs = {
293
- "input_ids": new_input_ids,
294
- "attention_mask": new_mask
295
- }
296
 
297
- # Carry over only the visual features; discard stale keys like input_token_type or rope_deltas
298
- keys_to_keep = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"]
299
- for k in keys_to_keep:
300
- if k in inputs:
301
- new_inputs[k] = inputs[k]
302
-
303
- class KillCriteria(StoppingCriteria):
304
- def __call__(self, input_ids, scores, **kwargs):
305
- return global_kill_threads[0]
306
-
307
- new_streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
308
- new_generation_kwargs = dict(
309
- **new_inputs,
310
- streamer=new_streamer,
311
- max_new_tokens=2048*16,
312
- do_sample=True,
313
- repetition_penalty=repetition_penalty_val,
314
- stopping_criteria=StoppingCriteriaList([KillCriteria()])
315
- )
316
-
317
- if len(images) > 0:
318
- new_generation_kwargs.update(dict(temperature=0.6, top_p=0.95, top_k=20, min_p=0.0))
319
- else:
320
- new_generation_kwargs.update(dict(temperature=1.0, top_p=0.95, top_k=20, min_p=0.0))
321
 
322
- def thread2_target(**kwargs):
323
- try:
324
- model.generate(**kwargs)
325
- except Exception as e:
326
- import traceback
327
- print(f"\n[THREAD2 ERROR] model.generate crashed: {e}")
328
- traceback.print_exc()
329
- finally:
330
- # Always unblock the streamer to prevent Gradio UI from hanging permanently
331
  try:
332
- new_streamer.end()
333
- except Exception:
334
- pass
335
-
336
- thread2 = Thread(target=thread2_target, kwargs=new_generation_kwargs)
337
- thread2.start()
338
-
339
- for new_text in new_streamer:
340
- output_text += new_text
 
 
 
341
  yield output_text, None
342
 
343
- thread2.join()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
  # DEBUG: Log raw output text
346
  with open("log/debug_vlm_output.txt", "w", encoding="utf-8") as f:
@@ -483,7 +465,7 @@ def process_pdf(pdf_file, url_input, translit_lang, translit_format, target_lang
483
 
484
  is_url = bool(url_input and url_input.strip())
485
  if pdf_file is None and not is_url:
486
- yield "<p>Please upload a PDF or enter a URL.</p>", None, None, ""
487
  return
488
 
489
  if is_url:
@@ -890,6 +872,86 @@ def get_example_pdf():
890
  print(f"Failed to download example PDF: {e}")
891
  return file_path if os.path.exists(file_path) else None
892
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
893
  def create_demo():
894
  example_pdf = get_example_pdf()
895
 
@@ -1004,7 +1066,11 @@ def create_demo():
1004
  with gr.Column(scale=1):
1005
  # url_input = gr.Textbox(label="Enter a Website URL 🌐", placeholder=r"e.g. https://storykorean.com/stories?level=beginner&story=tiger", value=r"https://storykorean.com/stories?level=beginner&story=tiger")
1006
  # https://www.bbc.com/korean/articles/c5yz89k5dw0o
1007
- url_input = gr.Textbox(label="Enter a Website URL 🌐", placeholder=r"e.g. https://www.koreanstudyjunkie.com/post/korean-reading-exercise-for-all-levels-beginner-intermediate-advanced", value=r"https://www.koreanstudyjunkie.com/post/korean-reading-exercise-for-all-levels-beginner-intermediate-advanced")
 
 
 
 
1008
 
1009
  pdf_input = gr.File(label="Or Upload Book PDF πŸ“š", file_types=[".pdf"], value=example_pdf)
1010
 
@@ -1051,6 +1117,7 @@ def create_demo():
1051
  )
1052
 
1053
  stop_thinking_btn.click(fn=set_stop_thinking, inputs=None, outputs=stop_thinking_btn, queue=False)
 
1054
  stop_btn.click(fn=set_kill_threads, inputs=None, outputs=stop_btn, queue=False).then(fn=None, inputs=None, outputs=None, cancels=[generate_event])
1055
 
1056
  # Force autoscroll using Custom JS
@@ -1262,3 +1329,4 @@ if __name__ == "__main__":
1262
  demo.launch(server_name="0.0.0.0", server_port=7865)
1263
 
1264
 
 
 
151
  return f"data:image/jpeg;base64,{img_str}"
152
 
153
  @spaces.GPU(duration=120)
154
+ def extract_vocabulary(pdf_text, images, translit_lang, translit_format, target_lang, max_text_char=1500, repetition_penalty_val=1.1, partial_assistant_text=None):
155
  """Use Transformers to extract vocabulary from text and images."""
156
  global model, processor
157
 
 
174
  - 'transliteration' (the pronunciation transliterated into {translit_lang.upper()} script/characters, formatted as {translit_format}.{non_english})
175
  - 'translation' (the translation into {target_lang.upper()})
176
  - 'explanation' (a brief grammar or context note in {target_lang.upper()}).
177
+
178
+ Just output raw JSON with ```json and ``` markers, as the user will load in python.
179
+
180
+ CRITICAL: Answer quick without very long thinking. Output the JSON array IMMEDIATELY.
181
 
182
  Text:
183
+
184
+ <scrpated-content>
185
  {pdf_text[:int(max_text_char)]}
186
+ </scrpated-content>
187
  """
188
 
189
  # DEBUG: Log prompt text
 
214
  try:
215
  model.to("cuda")
216
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
217
+ if partial_assistant_text:
218
+ text += partial_assistant_text + "\n</think>\n\n```json\n[\n"
219
+
220
  inputs = processor(
221
  text=[text],
222
  images=pil_images if pil_images else None,
 
224
  padding=True
225
  ).to("cuda")
226
 
 
 
 
 
227
  from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
228
  from threading import Thread
229
+ import queue
230
 
231
+ local_stop = [False]
232
+
233
+ class LocalKillCriteria(StoppingCriteria):
234
  def __call__(self, input_ids, scores, **kwargs):
235
+ return local_stop[0] or global_kill_threads[0]
236
+
237
+ def run_generation(cur_inputs, cur_streamer, cur_local_stop):
238
+ """Run model.generate in a thread, always calling streamer.end() on exit."""
239
+ kill_criteria = StoppingCriteriaList([LocalKillCriteria()])
240
+ gen_kwargs = dict(
241
+ **cur_inputs,
242
+ streamer=cur_streamer,
243
+ max_new_tokens=2048*16,
244
+ do_sample=True,
245
+ repetition_penalty=repetition_penalty_val,
246
+ stopping_criteria=kill_criteria
247
+ )
248
+ if len(images) > 0:
249
+ gen_kwargs.update(dict(temperature=0.6, top_p=0.95, top_k=20, min_p=0.0))
250
+ else:
251
+ gen_kwargs.update(dict(temperature=1.0, top_p=0.95, top_k=20, min_p=0.0))
 
 
 
 
 
252
  try:
253
+ model.generate(**gen_kwargs)
 
254
  except Exception as e:
255
  import traceback
256
+ print(f"\n[THREAD ERROR] model.generate crashed: {e}")
257
  traceback.print_exc()
258
+ finally:
259
+ try:
260
+ cur_streamer.end()
261
+ except Exception:
262
+ pass
263
 
264
+ output_text = partial_assistant_text + "\n</think>\n\n```json\n[\n" if partial_assistant_text else ""
265
+
266
+ streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
267
+ thread = Thread(target=run_generation, args=(inputs, streamer, local_stop))
268
  thread.start()
269
 
270
+ force_triggered = False
271
  for new_text in streamer:
272
  output_text += new_text
273
  yield output_text, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
+ # Check if user clicked "Stop thinking"
276
+ if global_stop_thinking[0] and not force_triggered:
277
+ force_triggered = True
278
+ print("[STOP-THINK] Flag detected inside streamer loop! Killing current generation...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
 
280
+ # 1. Kill the current generation thread
281
+ local_stop[0] = True
282
+ # Drain queue so the thread can exit
283
+ while not streamer.text_queue.empty():
 
 
 
 
 
284
  try:
285
+ streamer.text_queue.get_nowait()
286
+ except queue.Empty:
287
+ break
288
+ thread.join(timeout=5)
289
+ print("[STOP-THINK] Old thread joined. Starting forced JSON generation...")
290
+
291
+ # 2. Reset flags
292
+ global_stop_thinking[0] = False
293
+ local_stop[0] = False
294
+
295
+ # 3. Append the think-closing + JSON prefix
296
+ output_text += "\n</think>\n\n```json\n[\n"
297
  yield output_text, None
298
 
299
+ # 4. Build new prompt with partial assistant text
300
+ text2 = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
301
+ text2 += output_text
302
+ inputs2 = processor(
303
+ text=[text2],
304
+ images=pil_images if pil_images else None,
305
+ return_tensors="pt",
306
+ padding=True
307
+ ).to("cuda")
308
+
309
+ # 5. Start new generation thread
310
+ streamer2 = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
311
+ thread2 = Thread(target=run_generation, args=(inputs2, streamer2, local_stop))
312
+ thread2.start()
313
+
314
+ for new_text2 in streamer2:
315
+ output_text += new_text2
316
+ yield output_text, None
317
+
318
+ thread2.join(timeout=10)
319
+ break # Exit the outer streamer loop
320
+
321
+ if not force_triggered:
322
+ thread.join()
323
+
324
+ # Reset flag in case it was set but generation finished naturally
325
+ global_stop_thinking[0] = False
326
 
327
  # DEBUG: Log raw output text
328
  with open("log/debug_vlm_output.txt", "w", encoding="utf-8") as f:
 
465
 
466
  is_url = bool(url_input and url_input.strip())
467
  if pdf_file is None and not is_url:
468
+ yield "<p>Please upload a PDF or enter a URL.</p>", None, None, "", "", []
469
  return
470
 
471
  if is_url:
 
872
  print(f"Failed to download example PDF: {e}")
873
  return file_path if os.path.exists(file_path) else None
874
 
875
+ @spaces.GPU(duration=120)
876
+ def process_pdf_force(partial_text, pdf_file, url_input, translit_lang, translit_format, target_lang, max_text_char, repetition_penalty_val, last_source_state, last_korean_words_state):
877
+ """Force JSON generation using the current partial stream_box text."""
878
+ is_url = bool(url_input and url_input.strip())
879
+
880
+ current_source_hash = ""
881
+ if is_url:
882
+ current_source_hash = "url:" + url_input.strip()
883
+ elif pdf_file is not None:
884
+ import hashlib
885
+ with open(pdf_file.name, "rb") as f:
886
+ current_source_hash = "pdf:" + hashlib.md5(f.read()).hexdigest()
887
+
888
+ try:
889
+ if is_url:
890
+ progress(0, desc="Fetching Website...")
891
+ content_text, images = extract_website_content(url_input.strip())
892
+ else:
893
+ progress(0, desc="Reading PDF...")
894
+ content_text, images = extract_pdf_content(pdf_file.name)
895
+
896
+ if not content_text.strip() and not images:
897
+ yield "<p>No content found.</p>", current_source_hash, None, partial_text, "", []
898
+ return
899
+ except Exception as e:
900
+ yield f"<p>Error reading content: {e}</p>", None, None, partial_text, "", []
901
+ return
902
+
903
+ vocab_list = []
904
+ stream_text = partial_text
905
+
906
+ progress(0.2, desc="Extracting vocabulary (Forced JSON)...")
907
+ for stream_t, v_list in extract_vocabulary(content_text, images, translit_lang, translit_format, target_lang, max_text_char, repetition_penalty_val, partial_assistant_text=partial_text):
908
+ stream_text = stream_t
909
+ if v_list is not None:
910
+ vocab_list = v_list
911
+ yield "", current_source_hash, None, stream_text, content_text, images
912
+
913
+ if not vocab_list:
914
+ yield "<p>Failed to parse forced JSON.</p>", current_source_hash, None, stream_text, content_text, images
915
+ return
916
+
917
+ progress(0.6, desc="Generating TTS audio...")
918
+ for i, item in enumerate(vocab_list):
919
+ korean = item.get("korean", "")
920
+ if korean and tts is not None:
921
+ progress(0.6 + 0.3 * (i / len(vocab_list)), desc=f"Generating audio {i+1}/{len(vocab_list)}...")
922
+ try:
923
+ wav, dur = tts.synthesize(
924
+ korean, voice_style=voice_style, lang="ko",
925
+ total_steps=12,
926
+ speed=0.7,
927
+ )
928
+ import numpy as np
929
+ import soundfile as sf
930
+
931
+ audio_path = f"audio_{i}.wav"
932
+ sf.write(audio_path, wav, 24000)
933
+ item["audio_uri"] = numpy_to_base64_audio(wav, tts.sample_rate)
934
+ except Exception as e:
935
+ print(f"Failed to generate audio for {korean}: {e}")
936
+ item["audio_uri"] = None
937
+
938
+ progress(1.0, desc="Rendering flashcards...")
939
+
940
+ from jinja2 import Environment, BaseLoader
941
+ import json
942
+
943
+ env = Environment(loader=BaseLoader())
944
+ template = env.from_string(html_template)
945
+ html_output = template.render(
946
+ vocab_list=vocab_list,
947
+ translit_lang=translit_lang,
948
+ target_lang=target_lang
949
+ )
950
+
951
+ safe_srcdoc = html_output.replace('"', '&quot;')
952
+ yield f'<iframe srcdoc="{safe_srcdoc}" style="width: 100%; height: 650px; border: none; overflow-y: auto;"></iframe>', current_source_hash, vocab_list, stream_text, content_text, images
953
+
954
+
955
  def create_demo():
956
  example_pdf = get_example_pdf()
957
 
 
1066
  with gr.Column(scale=1):
1067
  # url_input = gr.Textbox(label="Enter a Website URL 🌐", placeholder=r"e.g. https://storykorean.com/stories?level=beginner&story=tiger", value=r"https://storykorean.com/stories?level=beginner&story=tiger")
1068
  # https://www.bbc.com/korean/articles/c5yz89k5dw0o
1069
+ # https://www.bbc.com/korean/articles/cn0p7rkvxdgo
1070
+ # https://www.koreanstudyjunkie.com/post/korean-reading-exercise-for-all-levels-beginner-intermediate-advanced
1071
+ url_input = gr.Textbox(label="Enter a Website URL 🌐",
1072
+ placeholder=r"e.g. # https://www.bbc.com/korean/articles/cn0p7rkvxdgo",
1073
+ value=r"https://www.bbc.com/korean/articles/cn0p7rkvxdgo")
1074
 
1075
  pdf_input = gr.File(label="Or Upload Book PDF πŸ“š", file_types=[".pdf"], value=example_pdf)
1076
 
 
1117
  )
1118
 
1119
  stop_thinking_btn.click(fn=set_stop_thinking, inputs=None, outputs=stop_thinking_btn, queue=False)
1120
+
1121
  stop_btn.click(fn=set_kill_threads, inputs=None, outputs=stop_btn, queue=False).then(fn=None, inputs=None, outputs=None, cancels=[generate_event])
1122
 
1123
  # Force autoscroll using Custom JS
 
1329
  demo.launch(server_name="0.0.0.0", server_port=7865)
1330
 
1331
 
1332
+