prithivMLmods committed on
Commit
34a5235
·
verified ·
1 Parent(s): b49d414

update app

Browse files
Files changed (1) hide show
  1. app.py +144 -71
app.py CHANGED
@@ -222,67 +222,112 @@ def calc_timeout_image(model_name, text, image, max_new_tokens, temperature, top
222
 
223
  @spaces.GPU(duration=calc_timeout_image)
224
  def generate_image(model_name, text, image, max_new_tokens=1024, temperature=0.6, top_p=0.9, top_k=50, repetition_penalty=1.2, gpu_timeout=60):
225
- if not model_name or model_name not in MODEL_MAP:
226
- raise gr.Error("Please select a valid model.")
227
- if image is None:
228
- raise gr.Error("Please upload an image.")
229
- if not text or not str(text).strip():
230
- raise gr.Error("Please enter your OCR/query instruction.")
231
- if len(str(text)) > MAX_INPUT_TOKEN_LENGTH * 8:
232
- raise gr.Error("Query is too long. Please shorten your input.")
233
-
234
- processor, model = MODEL_MAP[model_name]
235
- images = [image]
236
-
237
- if model_name == "SmolDocling-256M-preview":
238
- if "OTSL" in text or "code" in text:
239
- images = [add_random_padding(img) for img in images]
240
- if "OCR at text at" in text or "Identify element" in text or "formula" in text:
241
- text = normalize_values(text, target_max=500)
242
-
243
- messages = [{
244
- "role": "user",
245
- "content": [{"type": "image"} for _ in images] + [{"type": "text", "text": text}]
246
- }]
247
-
248
- prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
249
- inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
250
-
251
- streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
252
- generation_kwargs = {
253
- **inputs,
254
- "streamer": streamer,
255
- "max_new_tokens": int(max_new_tokens),
256
- "temperature": float(temperature),
257
- "top_p": float(top_p),
258
- "top_k": int(top_k),
259
- "repetition_penalty": float(repetition_penalty),
260
- }
 
261
 
262
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
263
- thread.start()
264
 
265
- buffer = ""
266
- for new_text in streamer:
267
- buffer += new_text.replace("<|im_end|>", "")
268
- yield buffer
269
-
270
- if model_name == "SmolDocling-256M-preview":
271
- cleaned_output = buffer.replace("<end_of_utterance>", "").strip()
272
- if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
273
- if "<chart>" in cleaned_output:
274
- cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
275
- cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
276
- doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
277
- doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
278
- markdown_output = doc.export_to_markdown()
279
- yield markdown_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  else:
281
- yield cleaned_output
 
282
 
283
- gc.collect()
284
- if torch.cuda.is_available():
285
- torch.cuda.empty_cache()
 
 
 
286
 
287
 
288
  def noop():
@@ -669,8 +714,16 @@ function init() {
669
  const sb = document.getElementById('sb-run-state');
670
  if (sb) sb.textContent = 'Done';
671
  }
 
 
 
 
 
 
 
672
  window.__showLoader = showLoader;
673
  window.__hideLoader = hideLoader;
 
674
 
675
  function flashPromptError() {
676
  promptInput.classList.add('error-flash');
@@ -845,7 +898,12 @@ function init() {
845
  showLoader();
846
  setTimeout(() => {
847
  const gradioBtn = document.getElementById('gradio-run-btn');
848
- if (!gradioBtn) return;
 
 
 
 
 
849
  const btn = gradioBtn.querySelector('button');
850
  if (btn) btn.click(); else gradioBtn.click();
851
  }, 180);
@@ -961,6 +1019,10 @@ function watchOutputs() {
961
 
962
  let lastText = '';
963
 
 
 
 
 
964
  function syncOutput() {
965
  const el = resultContainer.querySelector('textarea') || resultContainer.querySelector('input');
966
  if (!el) return;
@@ -969,7 +1031,15 @@ function watchOutputs() {
969
  lastText = val;
970
  outArea.value = val;
971
  outArea.scrollTop = outArea.scrollHeight;
972
- if (window.__hideLoader && val.trim()) window.__hideLoader();
 
 
 
 
 
 
 
 
973
  }
974
  }
975
 
@@ -1178,18 +1248,21 @@ with gr.Blocks() as demo:
1178
  return None
1179
 
1180
  def run_ocr(model_name, text, image_b64, max_new_tokens_v, temperature_v, top_p_v, top_k_v, repetition_penalty_v, gpu_timeout_v):
1181
- image = b64_to_pil(image_b64)
1182
- yield from generate_image(
1183
- model_name=model_name,
1184
- text=text,
1185
- image=image,
1186
- max_new_tokens=max_new_tokens_v,
1187
- temperature=temperature_v,
1188
- top_p=top_p_v,
1189
- top_k=top_k_v,
1190
- repetition_penalty=repetition_penalty_v,
1191
- gpu_timeout=gpu_timeout_v,
1192
- )
 
 
 
1193
 
1194
  demo.load(fn=noop, inputs=None, outputs=None, js=gallery_js)
1195
  demo.load(fn=noop, inputs=None, outputs=None, js=wire_outputs_js)
 
222
 
223
  @spaces.GPU(duration=calc_timeout_image)
224
  def generate_image(model_name, text, image, max_new_tokens=1024, temperature=0.6, top_p=0.9, top_k=50, repetition_penalty=1.2, gpu_timeout=60):
225
+ buffer = ""
226
+ try:
227
+ if not model_name or model_name not in MODEL_MAP:
228
+ yield "[ERROR] Please select a valid model."
229
+ return
230
+ if image is None:
231
+ yield "[ERROR] Please upload an image."
232
+ return
233
+ if not text or not str(text).strip():
234
+ yield "[ERROR] Please enter your OCR/query instruction."
235
+ return
236
+ if len(str(text)) > MAX_INPUT_TOKEN_LENGTH * 8:
237
+ yield "[ERROR] Query is too long. Please shorten your input."
238
+ return
239
+
240
+ processor, model = MODEL_MAP[model_name]
241
+ images = [image]
242
+
243
+ if model_name == "SmolDocling-256M-preview":
244
+ if "OTSL" in text or "code" in text:
245
+ images = [add_random_padding(img) for img in images]
246
+ if "OCR at text at" in text or "Identify element" in text or "formula" in text:
247
+ text = normalize_values(text, target_max=500)
248
+
249
+ messages = [{
250
+ "role": "user",
251
+ "content": [{"type": "image"} for _ in images] + [{"type": "text", "text": text}]
252
+ }]
253
+
254
+ prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
255
+ inputs = processor(text=prompt, images=images, return_tensors="pt").to(device)
256
+
257
+ streamer = TextIteratorStreamer(
258
+ processor.tokenizer if hasattr(processor, "tokenizer") else processor,
259
+ skip_prompt=True,
260
+ skip_special_tokens=True
261
+ )
262
 
263
+ generation_error = {"error": None}
 
264
 
265
+ generation_kwargs = {
266
+ **inputs,
267
+ "streamer": streamer,
268
+ "max_new_tokens": int(max_new_tokens),
269
+ "temperature": float(temperature),
270
+ "top_p": float(top_p),
271
+ "top_k": int(top_k),
272
+ "repetition_penalty": float(repetition_penalty),
273
+ }
274
+
275
+ def _run_generation():
276
+ try:
277
+ model.generate(**generation_kwargs)
278
+ except Exception as e:
279
+ generation_error["error"] = e
280
+ try:
281
+ streamer.end()
282
+ except Exception:
283
+ pass
284
+
285
+ thread = Thread(target=_run_generation, daemon=True)
286
+ thread.start()
287
+
288
+ for new_text in streamer:
289
+ buffer += new_text.replace("<|im_end|>", "")
290
+ yield buffer
291
+
292
+ thread.join(timeout=1.0)
293
+
294
+ if generation_error["error"] is not None:
295
+ err_msg = f"[ERROR] Inference failed: {str(generation_error['error'])}"
296
+ if buffer.strip():
297
+ yield buffer + "\n\n" + err_msg
298
+ else:
299
+ yield err_msg
300
+ return
301
+
302
+ if model_name == "SmolDocling-256M-preview":
303
+ cleaned_output = buffer.replace("<end_of_utterance>", "").strip()
304
+ if any(tag in cleaned_output for tag in ["<doctag>", "<otsl>", "<code>", "<chart>", "<formula>"]):
305
+ try:
306
+ if "<chart>" in cleaned_output:
307
+ cleaned_output = cleaned_output.replace("<chart>", "<otsl>").replace("</chart>", "</otsl>")
308
+ cleaned_output = re.sub(r'(<loc_500>)(?!.*<loc_500>)<[^>]+>', r'\1', cleaned_output)
309
+ doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([cleaned_output], images)
310
+ doc = DoclingDocument.load_from_doctags(doctags_doc, document_name="Document")
311
+ markdown_output = doc.export_to_markdown()
312
+ yield markdown_output
313
+ except Exception as e:
314
+ yield f"[ERROR] Post-processing failed: {str(e)}"
315
+ return
316
+ else:
317
+ if cleaned_output.strip():
318
+ yield cleaned_output
319
+ else:
320
+ yield "[ERROR] No output was generated."
321
  else:
322
+ if not buffer.strip():
323
+ yield "[ERROR] No output was generated."
324
 
325
+ except Exception as e:
326
+ yield f"[ERROR] {str(e)}"
327
+ finally:
328
+ gc.collect()
329
+ if torch.cuda.is_available():
330
+ torch.cuda.empty_cache()
331
 
332
 
333
  def noop():
 
714
  const sb = document.getElementById('sb-run-state');
715
  if (sb) sb.textContent = 'Done';
716
  }
717
+ function setRunErrorState() {
718
+ const l = document.getElementById('output-loader');
719
+ if (l) l.classList.remove('active');
720
+ const sb = document.getElementById('sb-run-state');
721
+ if (sb) sb.textContent = 'Error';
722
+ }
723
+
724
  window.__showLoader = showLoader;
725
  window.__hideLoader = hideLoader;
726
+ window.__setRunErrorState = setRunErrorState;
727
 
728
  function flashPromptError() {
729
  promptInput.classList.add('error-flash');
 
898
  showLoader();
899
  setTimeout(() => {
900
  const gradioBtn = document.getElementById('gradio-run-btn');
901
+ if (!gradioBtn) {
902
+ setRunErrorState();
903
+ if (outputArea) outputArea.value = '[ERROR] Run button not found.';
904
+ showToast('Run button not found', 'error');
905
+ return;
906
+ }
907
  const btn = gradioBtn.querySelector('button');
908
  if (btn) btn.click(); else gradioBtn.click();
909
  }, 180);
 
1019
 
1020
  let lastText = '';
1021
 
1022
+ function isErrorText(val) {
1023
+ return typeof val === 'string' && val.trim().startsWith('[ERROR]');
1024
+ }
1025
+
1026
  function syncOutput() {
1027
  const el = resultContainer.querySelector('textarea') || resultContainer.querySelector('input');
1028
  if (!el) return;
 
1031
  lastText = val;
1032
  outArea.value = val;
1033
  outArea.scrollTop = outArea.scrollHeight;
1034
+
1035
+ if (val.trim()) {
1036
+ if (isErrorText(val)) {
1037
+ if (window.__setRunErrorState) window.__setRunErrorState();
1038
+ if (window.__showToast) window.__showToast('OCR failed', 'error');
1039
+ } else {
1040
+ if (window.__hideLoader) window.__hideLoader();
1041
+ }
1042
+ }
1043
  }
1044
  }
1045
 
 
1248
  return None
1249
 
1250
  def run_ocr(model_name, text, image_b64, max_new_tokens_v, temperature_v, top_p_v, top_k_v, repetition_penalty_v, gpu_timeout_v):
1251
+ try:
1252
+ image = b64_to_pil(image_b64)
1253
+ yield from generate_image(
1254
+ model_name=model_name,
1255
+ text=text,
1256
+ image=image,
1257
+ max_new_tokens=max_new_tokens_v,
1258
+ temperature=temperature_v,
1259
+ top_p=top_p_v,
1260
+ top_k=top_k_v,
1261
+ repetition_penalty=repetition_penalty_v,
1262
+ gpu_timeout=gpu_timeout_v,
1263
+ )
1264
+ except Exception as e:
1265
+ yield f"[ERROR] {str(e)}"
1266
 
1267
  demo.load(fn=noop, inputs=None, outputs=None, js=gallery_js)
1268
  demo.load(fn=noop, inputs=None, outputs=None, js=wire_outputs_js)