Geraldine committed
Commit 136b6ad · verified · 1 Parent(s): f4fb569

Update app.py

Files changed (1): app.py (+81 −19)
app.py CHANGED
@@ -16,6 +16,8 @@ from PIL import Image
 from huggingface_hub import snapshot_download
 
 from transformers import (
+    LightOnOcrForConditionalGeneration,
+    LightOnOcrProcessor,
     Qwen2VLForConditionalGeneration,
     Qwen3VLForConditionalGeneration,
     Qwen2_5_VLForConditionalGeneration,
@@ -130,14 +132,12 @@ model_v = load_model_with_attention_fallback(
     torch_dtype=torch.float16
 ).to(device).eval()
 
-MODEL_ID_Y = "rednote-hilab/dots.ocr"
-MODEL_PATH_Y = resolve_dots_ocr_model_path(MODEL_ID_Y)
-processor_y = AutoProcessor.from_pretrained(MODEL_PATH_Y, trust_remote_code=True)
-model_y = load_model_with_attention_fallback(
-    AutoModelForCausalLM,
-    MODEL_PATH_Y,
-    trust_remote_code=True,
-    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
+MODEL_ID_Y = "lightonai/LightOnOCR-2-1B"
+LIGHTON_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+processor_y = LightOnOcrProcessor.from_pretrained(MODEL_ID_Y)
+model_y = LightOnOcrForConditionalGeneration.from_pretrained(
+    MODEL_ID_Y,
+    torch_dtype=LIGHTON_DTYPE,
 ).to(device).eval()
 
 MODEL_ID_X = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
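
Note: the new loading path above can be smoke-tested outside the Space. A minimal sketch, assuming a transformers build that ships the LightOnOcr* classes; `page.png` is an illustrative path, not a file from this repo:

```python
# Standalone check of the LightOnOCR-2-1B load + generate path used in app.py.
import torch
from PIL import Image
from transformers import LightOnOcrForConditionalGeneration, LightOnOcrProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

processor = LightOnOcrProcessor.from_pretrained("lightonai/LightOnOCR-2-1B")
model = LightOnOcrForConditionalGeneration.from_pretrained(
    "lightonai/LightOnOCR-2-1B", torch_dtype=dtype
).to(device).eval()

# Prompt-free usage: the conversation carries only the image.
conversation = [{"role": "user", "content": [{"type": "image", "image": Image.open("page.png")}]}]
inputs = processor.apply_chat_template(
    conversation, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt",
)
# Move tensors to the device, casting floating-point ones to the model dtype.
inputs = {
    k: v.to(device=device, dtype=dtype) if torch.is_tensor(v) and v.is_floating_point() else v.to(device)
    for k, v in inputs.items()
}

with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=512)
print(processor.decode(out[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```
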
@@ -168,7 +168,7 @@ model_m = load_model_with_attention_fallback(
 
 MODEL_MAP = {
     "Nanonets-OCR2-3B": (processor_v, model_v),
-    "dots.OCR": (processor_y, model_y),
+    "LightOnOCR-2-1B": (processor_y, model_y),
     "olmOCR-7B-0725": (processor_w, model_w),
     "Qwen3-VL-4B-Instruct": (processor_m, model_m),
     "Qwen2-VL-OCR-2B": (processor_x, model_x),
@@ -184,7 +184,7 @@ PROMPTS = {
         "icon": "📄"
     },
     "MARKDOWN": {
-        "name": "Markdown Conversion",
+        "name": "Simple Markdown Conversion",
         "description": "Convert document to Markdown format",
         "prompt": "Convert this document to Markdown. Preserve headings, lists, and formatting.",
         "icon": "📝"
@@ -192,7 +192,7 @@ PROMPTS = {
     "MARKDOWN_OCR": {
         "name": "Markdown OCR",
         "description": "Perform OCR and convert to Markdown",
-        "prompt": "Perform OCR on this document and convert to Markdown. Preserve headings, lists, and formatting.",
+        "prompt": "Perform OCR including inside images and logos and convert to Markdown.",
         "icon": "🔍"
     },
     "TITLE_JSON": {
@@ -255,6 +255,30 @@ Return ONLY valid JSON with this exact structure:
 IMPORTANT: Return null for any field where information is NOT clearly visible.
 Return ONLY the JSON, no explanation.""",
         "icon": "📄"
+    },
+    "NUEXTRACT_SCHEMA_JSON": {
+        "name": "NuExtract Json Schema",
+        "description": "Strict data extraction following deterministic JSON schema",
+        "prompt": """{
+    "title": "verbatim-string",
+    "subtitle": "verbatim-string",
+    "author": "verbatim-string",
+    "degree_type": "verbatim-string",
+    "discipline": [["Mathématiques", "Physique", "Autres"]],
+    "granting_institution": ["verbatim-string"],
+    "doctoral_school": ["verbatim-string"],
+    "co_tutelle_institutions": ["verbatim-string"],
+    "partner_institutions": ["verbatim-string"],
+    "defense_year": "integer",
+    "thesis_advisor": ["verbatim-string"],
+    "co_advisors": ["verbatim-string"],
+    "jury_president": "verbatim-string",
+    "reviewers": ["verbatim-string"],
+    "other_jury_members": ["verbatim-string"],
+    "language": "verbatim-string",
+    "confidence": "float"
+}""",
+        "icon": "📄"
     }
 }
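
The NUEXTRACT_SCHEMA_JSON prompt doubles as a machine-readable template, since it is itself valid JSON. A hedged post-processing sketch (the `validate_extraction` helper is illustrative, not part of app.py):

```python
import json

def validate_extraction(raw_output: str, template: str) -> dict | None:
    """Coerce model output to the shape of the NuExtract template.

    `template` is the NUEXTRACT_SCHEMA_JSON prompt above (itself valid JSON).
    Returns None when the output is not a JSON object.
    """
    try:
        data = json.loads(raw_output)
    except json.JSONDecodeError:
        return None
    if not isinstance(data, dict):
        return None
    # Drop keys the model invented; fill missing keys with None so callers
    # always see the same structure.
    return {k: data.get(k) for k in json.loads(template)}

# Usage sketch: validate_extraction(output_text, PROMPTS["NUEXTRACT_SCHEMA_JSON"]["prompt"])
```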
 
@@ -264,7 +288,7 @@ image_examples = [
     {"query": PROMPTS["TITLE_JSON"]["prompt"], "image": "examples/ephesvt_theses_doc13.jpg", "model": "Qwen3-VL-4B-Instruct"},
     {"query": PROMPTS["LOCATED_TITLE_JSON"]["prompt"], "image": "examples/memoires_cridaf_doc07.jpg", "model": "Qwen2-VL-OCR-2B"},
     {"query": PROMPTS["GROUNDED_TITLE_JSON"]["prompt"], "image": "examples/thesefr_2015PA010690.png", "model": "Qwen2-VL-OCR-2B"},
-    {"query": PROMPTS["FULL_SCHEMA_JSON"]["prompt"], "image": "examples/thesefr_2015PA010690.png", "model": "dots.OCR"},
+    {"query": "", "image": "examples/thesefr_2015PA010690.png", "model": "LightOnOCR-2-1B"},
 ]
 
 
@@ -415,6 +439,10 @@ def align_inputs_to_model_dtype(inputs, model):
     return inputs
 
 
+def model_requires_text_prompt(model_name: str) -> bool:
+    return model_name != "LightOnOCR-2-1B"
+
+
 @spaces.GPU(duration=calc_timeout_duration)
 def generate_image(model_name, text, image, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_timeout):
     try:
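
The helper's contract, spelled out against the MODEL_MAP keys above:

```python
# Every model except LightOnOCR-2-1B still requires a text instruction.
assert model_requires_text_prompt("Qwen2-VL-OCR-2B")
assert not model_requires_text_prompt("LightOnOCR-2-1B")
```
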
@@ -425,7 +453,9 @@ def generate_image(model_name, text, image, max_new_tokens, temperature, top_p,
             yield "[ERROR] Please upload an image."
             return
         text = str(text or "").strip()
-        if not text:
+        if not model_requires_text_prompt(model_name):
+            text = ""
+        if model_requires_text_prompt(model_name) and not text:
             yield "[ERROR] Please enter your OCR/query instruction."
             return
         if len(str(text)) > MAX_INPUT_TOKEN_LENGTH * 8:
@@ -434,6 +464,30 @@ def generate_image(model_name, text, image, max_new_tokens, temperature, top_p,
 
         processor, model = select_model(model_name)
 
+        if model_name == "LightOnOCR-2-1B":
+            conversation = [{"role": "user", "content": [{"type": "image", "image": image}]}]
+            inputs = processor.apply_chat_template(
+                conversation,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt",
+            )
+            inputs = {
+                k: v.to(device=device, dtype=LIGHTON_DTYPE) if torch.is_tensor(v) and v.is_floating_point() else v.to(device)
+                for k, v in inputs.items()
+            }
+
+            output_ids = model.generate(**inputs, max_new_tokens=int(max_new_tokens))
+            generated_ids = output_ids[0, inputs["input_ids"].shape[1]:]
+            output_text = processor.decode(generated_ids, skip_special_tokens=True)
+
+            if output_text.strip():
+                yield output_text
+            else:
+                yield "[ERROR] No output was generated."
+            return
+
         streamer = TextIteratorStreamer(
             processor.tokenizer if hasattr(processor, "tokenizer") else processor,
             skip_prompt=True,
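
Note that this branch yields the decoded text once, while the other models stream through TextIteratorStreamer below. If incremental output were wanted here too, a sketch along these lines could reuse the streamer pattern — assuming LightOnOcrProcessor exposes a `tokenizer` attribute like the app's other processors (an assumption the branch above deliberately avoids relying on):

```python
from threading import Thread

from transformers import TextIteratorStreamer

def stream_lighton_ocr(model, processor, inputs, max_new_tokens):
    """Streaming variant of the blocking generate() call above."""
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    Thread(
        target=model.generate,
        kwargs={**inputs, "max_new_tokens": int(max_new_tokens), "streamer": streamer},
    ).start()
    buffer = ""
    for chunk in streamer:
        buffer += chunk
        yield buffer  # Gradio re-renders the growing buffer on each yield.
```
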
@@ -724,6 +778,7 @@ footer{display:none!important}
   padding:10px 14px;font-family:'Inter',sans-serif;font-size:14px;color:#e4e4e7;
   resize:none;outline:none;min-height:100px;transition:border-color .2s;
 }
+.is-hidden{display:none!important}
 .modern-textarea:focus{border-color:#ADFF2F;box-shadow:0 0 0 3px rgba(173,255,47,.14)}
 .modern-textarea::placeholder{color:#3f3f46}
 .modern-textarea.error-flash{
@@ -883,6 +938,8 @@ function init() {
   const btnUpload = document.getElementById('preview-upload-btn');
   const btnClear = document.getElementById('preview-clear-btn');
   const promptInput = document.getElementById('custom-query-input');
+  const promptPanel = document.getElementById('prompt-panel');
+  const promptTabsBar = document.getElementById('prompt-tabs-bar');
   const runBtnEl = document.getElementById('custom-run-btn');
   const outputArea = document.getElementById('custom-output-textarea');
   const imgStatus = document.getElementById('sb-image-status');
@@ -982,7 +1039,8 @@ function init() {
     if (imgStatus) imgStatus.textContent = txt;
   }
   function syncPromptToGradio() {
-    setGradioValue('prompt-gradio-input', promptInput.value);
+    const activeModel = (document.querySelector('.model-tab.active') || {}).dataset?.model;
+    setGradioValue('prompt-gradio-input', activeModel === 'LightOnOCR-2-1B' ? '' : promptInput.value);
   }
   function syncModelToGradio(name) {
     setGradioValue('hidden-model-name', name);
@@ -1039,6 +1097,9 @@ function init() {
     document.querySelectorAll('.model-tab[data-model]').forEach(btn => {
       btn.classList.toggle('active', btn.getAttribute('data-model') === name);
     });
+    const hidePrompt = name === 'LightOnOCR-2-1B';
+    if (promptPanel) promptPanel.classList.toggle('is-hidden', hidePrompt);
+    if (promptTabsBar) promptTabsBar.classList.toggle('is-hidden', hidePrompt);
     syncModelToGradio(name);
     syncPromptToGradio();
   }
@@ -1105,7 +1166,9 @@ function init() {
   syncSlider('custom-gpu-duration', 'gradio-gpu-duration');
   function validateBeforeRun() {
     const promptVal = promptInput.value.trim();
-    if (!imageState && !promptVal) {
+    const currentModel = (document.querySelector('.model-tab.active') || {}).dataset?.model;
+    const requiresPrompt = currentModel !== 'LightOnOCR-2-1B';
+    if (!imageState && !promptVal && requiresPrompt) {
       showToast('Please upload an image and enter your OCR instruction', 'error');
       flashPromptError();
       return false;
@@ -1114,12 +1177,11 @@ function init() {
       showToast('Please upload an image', 'error');
       return false;
     }
-    if (!promptVal) {
+    if (requiresPrompt && !promptVal) {
       showToast('Please enter your OCR/query instruction', 'warning');
       flashPromptError();
       return false;
     }
-    const currentModel = (document.querySelector('.model-tab.active') || {}).dataset?.model;
     if (!currentModel) {
       showToast('Please select a model', 'error');
       return false;
@@ -1383,7 +1445,7 @@ with gr.Blocks() as demo:
     <div class="model-tabs-bar">
       {MODEL_TABS_HTML}
     </div>
-    <div class="model-tabs-bar">
+    <div id="prompt-tabs-bar" class="model-tabs-bar">
       {PROMPT_TABS_HTML}
     </div>
     <div class="app-main-row">
@@ -1420,7 +1482,7 @@ with gr.Blocks() as demo:
       </div>
     </div>
     <div class="app-main-right">
-      <div class="panel-card">
+      <div id="prompt-panel" class="panel-card">
         <div class="panel-card-title">OCR / Vision Instruction</div>
         <div class="panel-card-body">
           <label class="modern-label" for="custom-query-input">Query Input</label>
 