iammraat committed on
Commit
edc69a6
Β·
verified Β·
1 Parent(s): 651887a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +284 -201
app.py CHANGED
@@ -204,222 +204,305 @@
204
 
205
 
206
 
207
- import gradio as gr
208
- from ultralytics import YOLO
209
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
210
- from PIL import Image, ImageDraw
211
- import torch
212
- import logging
213
- import os
214
- import warnings
215
- import time
216
- from datetime import datetime
217
-
218
- # ── Suppress noisy logs ──────────────────────────────────────────────────────
219
- os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
220
- warnings.filterwarnings('ignore')
221
- logging.getLogger('transformers').setLevel(logging.ERROR)
222
- logging.getLogger('ultralytics').setLevel(logging.WARNING)
223
-
224
- # Clean logging
225
- logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)-5s | %(message)s')
226
- logger = logging.getLogger(__name__)
227
-
228
- logger.info("Initializing models...")
229
- device = "cuda" if torch.cuda.is_available() else "cpu"
230
- logger.info(f"Device: {device}")
231
-
232
- def load_with_retry(cls, name, token=None, retries=4, delay=6):
233
- for attempt in range(1, retries + 1):
234
- try:
235
- logger.info(f"Loading {name} (attempt {attempt}/{retries})")
236
- if "Processor" in str(cls):
237
- return cls.from_pretrained(name, token=token)
238
- return cls.from_pretrained(name, token=token).to(device)
239
- except Exception as e:
240
- logger.warning(f"Load failed: {e}")
241
- if attempt < retries:
242
- time.sleep(delay)
243
- raise RuntimeError(f"Failed to load {name} after {retries} attempts")
244
-
245
-
246
- try:
247
- # Locate local YOLO line detection weights
248
- line_pt = 'lines.pt'
249
-
250
- if not os.path.exists(line_pt):
251
- for f in os.listdir('.'):
252
- name = f.lower()
253
- if 'line' in name and name.endswith('.pt'):
254
- line_pt = f
255
- break
256
-
257
- if not os.path.exists(line_pt):
258
- raise FileNotFoundError("Could not find lines.pt (or similar *.pt file containing 'line' in name)")
259
-
260
- logger.info("Loading YOLO line model...")
261
- line_model = YOLO(line_pt)
262
- logger.info("YOLO line model loaded")
263
-
264
- hf_token = os.getenv("HF_TOKEN")
265
- processor = load_with_retry(TrOCRProcessor, "microsoft/trocr-base-handwritten", hf_token)
266
- trocr = load_with_retry(VisionEncoderDecoderModel, "microsoft/trocr-base-handwritten", hf_token)
267
- logger.info("TrOCR loaded β†’ ready")
268
-
269
- except Exception as e:
270
- logger.error(f"Model loading failed: {e}", exc_info=True)
271
- raise
272
-
273
-
274
- def run_ocr(crop: Image.Image) -> str:
275
- if crop.width < 20 or crop.height < 12:
276
- return ""
277
- pixels = processor(images=crop, return_tensors="pt").pixel_values.to(device)
278
- ids = trocr.generate(pixels, max_new_tokens=128)
279
- return processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
280
-
281
-
282
- def process_document(
283
- image,
284
- enable_debug_crops: bool = False,
285
- line_imgsz: int = 768,
286
- conf_thresh: float = 0.25,
287
- ):
288
- start_ts = datetime.now().strftime("%H:%M:%S")
289
- logs = []
290
-
291
- def log(msg: str, level: str = "INFO"):
292
- line = f"[{start_ts}] {level:5} {msg}"
293
- logs.append(line)
294
- if level == "ERROR":
295
- logger.error(msg)
296
- else:
297
- logger.info(msg)
298
-
299
- log("Start processing")
300
-
301
- if image is None:
302
- log("No image uploaded", "ERROR")
303
- return None, "Upload an image", "\n".join(logs)
304
-
305
- try:
306
- # ── Prepare ─────────────────────────────────────────────────────────────
307
- if not isinstance(image, Image.Image):
308
- img = Image.open(image).convert("RGB")
309
- else:
310
- img = image.convert("RGB")
311
-
312
- debug_img = img.copy()
313
- draw = ImageDraw.Draw(debug_img)
314
- w, h = img.size
315
- log(f"Input image: {w} Γ— {h} px")
316
-
317
- debug_dir = "debug_crops"
318
- if enable_debug_crops:
319
- os.makedirs(debug_dir, exist_ok=True)
320
- log(f"Debug crops will be saved to {debug_dir}/")
321
-
322
- extracted = []
323
-
324
- # ── Line detection on full image ────────────────────────────────────────
325
- # Adaptive size based on image dimensions
326
- max_dim = max(w, h)
327
- if max_dim > 2200:
328
- used_sz = 1280
329
- elif max_dim > 1400:
330
- used_sz = 1024
331
- elif max_dim < 600:
332
- used_sz = 640
333
- else:
334
- used_sz = line_imgsz
335
-
336
- log(f"Running line detection (imgsz={used_sz}, confβ‰₯{conf_thresh}) …")
337
-
338
- res = line_model(img, conf=conf_thresh, imgsz=used_sz, verbose=False)[0]
339
- boxes = res.boxes
340
-
341
- log(f"β†’ Detected {len(boxes)} line candidate(s)")
342
 
343
- if len(boxes) == 0:
344
- msg = "No text lines detected"
345
- log(msg, "WARNING")
346
- return debug_img, msg, "\n".join(logs)
347
-
348
- # Sort top β†’ bottom
349
- ys = boxes.xyxy[:, 1].cpu().numpy() # y_min
350
- order = ys.argsort()
351
-
352
- for j, idx in enumerate(order, 1):
353
- conf = float(boxes.conf[idx])
354
- x1, y1, x2, y2 = map(round, boxes.xyxy[idx].cpu().tolist())
355
 
356
- lw, lh = x2 - x1, y2 - y1
357
- log(f" Line {j}/{len(boxes)} conf={conf:.3f} {x1},{y1} β†’ {x2},{y2} ({lw}Γ—{lh})")
358
 
359
- # Skip very small detections
360
- if lw < 60 or lh < 20:
361
- log(f" β†’ skipped (too small)")
362
- continue
363
 
364
- draw.rectangle((x1, y1, x2, y2), outline="red", width=3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
 
366
- line_crop = img.crop((x1, y1, x2, y2))
 
 
367
 
368
- if enable_debug_crops:
369
- fname = f"{debug_dir}/line_{j:02d}_conf{conf:.2f}.png"
370
- line_crop.save(fname)
 
 
 
 
 
 
 
 
 
371
 
372
- text = run_ocr(line_crop)
373
- log(f" OCR β†’ '{text}'")
374
-
375
- if text.strip():
376
- extracted.append(text)
377
 
378
- # ── Finalize ────────────────────────────────────────────────────────────
379
- if not extracted:
380
- msg = "No readable text found after OCR"
381
- log(msg, "WARNING")
382
- return debug_img, msg, "\n".join(logs)
383
-
384
- log(f"Success β€” extracted {len(extracted)} line(s)")
385
- if enable_debug_crops:
386
- log(f"Debug crops saved to {debug_dir}/")
387
 
388
- return debug_img, "\n".join(extracted), "\n".join(logs)
 
 
 
 
 
389
 
390
- except Exception as e:
391
- log(f"Processing failed: {e}", "ERROR")
392
- logger.exception("Traceback:")
393
- return debug_img, f"Error: {str(e)}", "\n".join(logs)
394
 
 
 
 
395
 
396
- demo = gr.Interface(
397
- fn=process_document,
398
- inputs=[
399
- gr.Image(type="pil", label="Handwritten document"),
400
- gr.Checkbox(label="Save debug crops", value=False),
401
- gr.Slider(512, 1280, step=64, value=768, label="Line detection size (imgsz)"),
402
- gr.Slider(0.15, 0.5, step=0.05, value=0.25, label="Confidence threshold"),
403
- ],
404
- outputs=[
405
- gr.Image(label="Debug (red = detected text lines)"),
406
- gr.Textbox(label="Extracted Text", lines=10),
407
- gr.Textbox(label="Detailed Logs (copy if alignment is wrong)", lines=16),
408
- ],
409
- title="Handwritten Line Detection + TrOCR",
410
- description=(
411
- "Red boxes = text lines detected by YOLO β†’ sent to TrOCR for recognition\n\n"
412
- "Use **Detailed Logs** to check coordinates, sizes & confidence values if results look off."
413
- ),
414
- theme=gr.themes.Soft(),
415
- flagging_mode="never",
416
- )
417
 
418
- if __name__ == "__main__":
419
- logger.info("Launching interface…")
420
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
 
 
422
 
 
423
 
 
 
 
424
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
 
 
204
 
205
 
206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
 
 
209
 
 
 
 
 
210
 
211
+ # import gradio as gr
212
+ # from ultralytics import YOLO
213
+ # from transformers import TrOCRProcessor, VisionEncoderDecoderModel
214
+ # from PIL import Image, ImageDraw
215
+ # import torch
216
+ # import logging
217
+ # import os
218
+ # import warnings
219
+ # import time
220
+ # from datetime import datetime
221
+
222
+ # # ── Suppress noisy logs ──────────────────────────────────────────────────────
223
+ # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
224
+ # warnings.filterwarnings('ignore')
225
+ # logging.getLogger('transformers').setLevel(logging.ERROR)
226
+ # logging.getLogger('ultralytics').setLevel(logging.WARNING)
227
+
228
+ # # Clean logging
229
+ # logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)-5s | %(message)s')
230
+ # logger = logging.getLogger(__name__)
231
 
232
+ # logger.info("Initializing models...")
233
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
234
+ # logger.info(f"Device: {device}")
235
 
236
+ # def load_with_retry(cls, name, token=None, retries=4, delay=6):
237
+ # for attempt in range(1, retries + 1):
238
+ # try:
239
+ # logger.info(f"Loading {name} (attempt {attempt}/{retries})")
240
+ # if "Processor" in str(cls):
241
+ # return cls.from_pretrained(name, token=token)
242
+ # return cls.from_pretrained(name, token=token).to(device)
243
+ # except Exception as e:
244
+ # logger.warning(f"Load failed: {e}")
245
+ # if attempt < retries:
246
+ # time.sleep(delay)
247
+ # raise RuntimeError(f"Failed to load {name} after {retries} attempts")
248
 
 
 
 
 
 
249
 
250
+ # try:
251
+ # # Locate local YOLO line detection weights
252
+ # line_pt = 'lines.pt'
 
 
 
 
 
 
253
 
254
+ # if not os.path.exists(line_pt):
255
+ # for f in os.listdir('.'):
256
+ # name = f.lower()
257
+ # if 'line' in name and name.endswith('.pt'):
258
+ # line_pt = f
259
+ # break
260
 
261
+ # if not os.path.exists(line_pt):
262
+ # raise FileNotFoundError("Could not find lines.pt (or similar *.pt file containing 'line' in name)")
 
 
263
 
264
+ # logger.info("Loading YOLO line model...")
265
+ # line_model = YOLO(line_pt)
266
+ # logger.info("YOLO line model loaded")
267
 
268
+ # hf_token = os.getenv("HF_TOKEN")
269
+ # processor = load_with_retry(TrOCRProcessor, "microsoft/trocr-base-handwritten", hf_token)
270
+ # trocr = load_with_retry(VisionEncoderDecoderModel, "microsoft/trocr-base-handwritten", hf_token)
271
+ # logger.info("TrOCR loaded β†’ ready")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
+ # except Exception as e:
274
+ # logger.error(f"Model loading failed: {e}", exc_info=True)
275
+ # raise
276
+
277
+
278
+ # def run_ocr(crop: Image.Image) -> str:
279
+ # if crop.width < 20 or crop.height < 12:
280
+ # return ""
281
+ # pixels = processor(images=crop, return_tensors="pt").pixel_values.to(device)
282
+ # ids = trocr.generate(pixels, max_new_tokens=128)
283
+ # return processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
284
+
285
+
286
+ # def process_document(
287
+ # image,
288
+ # enable_debug_crops: bool = False,
289
+ # line_imgsz: int = 768,
290
+ # conf_thresh: float = 0.25,
291
+ # ):
292
+ # start_ts = datetime.now().strftime("%H:%M:%S")
293
+ # logs = []
294
+
295
+ # def log(msg: str, level: str = "INFO"):
296
+ # line = f"[{start_ts}] {level:5} {msg}"
297
+ # logs.append(line)
298
+ # if level == "ERROR":
299
+ # logger.error(msg)
300
+ # else:
301
+ # logger.info(msg)
302
+
303
+ # log("Start processing")
304
+
305
+ # if image is None:
306
+ # log("No image uploaded", "ERROR")
307
+ # return None, "Upload an image", "\n".join(logs)
308
+
309
+ # try:
310
+ # # ── Prepare ─────────────────────────────────────────────────────────────
311
+ # if not isinstance(image, Image.Image):
312
+ # img = Image.open(image).convert("RGB")
313
+ # else:
314
+ # img = image.convert("RGB")
315
+
316
+ # debug_img = img.copy()
317
+ # draw = ImageDraw.Draw(debug_img)
318
+ # w, h = img.size
319
+ # log(f"Input image: {w} Γ— {h} px")
320
+
321
+ # debug_dir = "debug_crops"
322
+ # if enable_debug_crops:
323
+ # os.makedirs(debug_dir, exist_ok=True)
324
+ # log(f"Debug crops will be saved to {debug_dir}/")
325
+
326
+ # extracted = []
327
+
328
+ # # ── Line detection on full image ────────────────────────────────────────
329
+ # # Adaptive size based on image dimensions
330
+ # max_dim = max(w, h)
331
+ # if max_dim > 2200:
332
+ # used_sz = 1280
333
+ # elif max_dim > 1400:
334
+ # used_sz = 1024
335
+ # elif max_dim < 600:
336
+ # used_sz = 640
337
+ # else:
338
+ # used_sz = line_imgsz
339
+
340
+ # log(f"Running line detection (imgsz={used_sz}, confβ‰₯{conf_thresh}) …")
341
+
342
+ # res = line_model(img, conf=conf_thresh, imgsz=used_sz, verbose=False)[0]
343
+ # boxes = res.boxes
344
+
345
+ # log(f"β†’ Detected {len(boxes)} line candidate(s)")
346
+
347
+ # if len(boxes) == 0:
348
+ # msg = "No text lines detected"
349
+ # log(msg, "WARNING")
350
+ # return debug_img, msg, "\n".join(logs)
351
+
352
+ # # Sort top β†’ bottom
353
+ # ys = boxes.xyxy[:, 1].cpu().numpy() # y_min
354
+ # order = ys.argsort()
355
+
356
+ # for j, idx in enumerate(order, 1):
357
+ # conf = float(boxes.conf[idx])
358
+ # x1, y1, x2, y2 = map(round, boxes.xyxy[idx].cpu().tolist())
359
+
360
+ # lw, lh = x2 - x1, y2 - y1
361
+ # log(f" Line {j}/{len(boxes)} conf={conf:.3f} {x1},{y1} β†’ {x2},{y2} ({lw}Γ—{lh})")
362
+
363
+ # # Skip very small detections
364
+ # if lw < 60 or lh < 20:
365
+ # log(f" β†’ skipped (too small)")
366
+ # continue
367
 
368
+ # draw.rectangle((x1, y1, x2, y2), outline="red", width=3)
369
 
370
+ # line_crop = img.crop((x1, y1, x2, y2))
371
 
372
+ # if enable_debug_crops:
373
+ # fname = f"{debug_dir}/line_{j:02d}_conf{conf:.2f}.png"
374
+ # line_crop.save(fname)
375
 
376
+ # text = run_ocr(line_crop)
377
+ # log(f" OCR β†’ '{text}'")
378
+
379
+ # if text.strip():
380
+ # extracted.append(text)
381
+
382
+ # # ── Finalize ────────────────────────────────────────────────────────────
383
+ # if not extracted:
384
+ # msg = "No readable text found after OCR"
385
+ # log(msg, "WARNING")
386
+ # return debug_img, msg, "\n".join(logs)
387
+
388
+ # log(f"Success β€” extracted {len(extracted)} line(s)")
389
+ # if enable_debug_crops:
390
+ # log(f"Debug crops saved to {debug_dir}/")
391
+
392
+ # return debug_img, "\n".join(extracted), "\n".join(logs)
393
+
394
+ # except Exception as e:
395
+ # log(f"Processing failed: {e}", "ERROR")
396
+ # logger.exception("Traceback:")
397
+ # return debug_img, f"Error: {str(e)}", "\n".join(logs)
398
+
399
+
400
+ # demo = gr.Interface(
401
+ # fn=process_document,
402
+ # inputs=[
403
+ # gr.Image(type="pil", label="Handwritten document"),
404
+ # gr.Checkbox(label="Save debug crops", value=False),
405
+ # gr.Slider(512, 1280, step=64, value=768, label="Line detection size (imgsz)"),
406
+ # gr.Slider(0.15, 0.5, step=0.05, value=0.25, label="Confidence threshold"),
407
+ # ],
408
+ # outputs=[
409
+ # gr.Image(label="Debug (red = detected text lines)"),
410
+ # gr.Textbox(label="Extracted Text", lines=10),
411
+ # gr.Textbox(label="Detailed Logs (copy if alignment is wrong)", lines=16),
412
+ # ],
413
+ # title="Handwritten Line Detection + TrOCR",
414
+ # description=(
415
+ # "Red boxes = text lines detected by YOLO β†’ sent to TrOCR for recognition\n\n"
416
+ # "Use **Detailed Logs** to check coordinates, sizes & confidence values if results look off."
417
+ # ),
418
+ # theme=gr.themes.Soft(),
419
+ # flagging_mode="never",
420
+ # )
421
+
422
+ # if __name__ == "__main__":
423
+ # logger.info("Launching interface…")
424
+ # demo.launch()
425
+
426
+
427
+
428
+
429
+
430
+
431
+
432
+
433
+
434
+
435
+
436
+
437
+
438
+
439
+
440
+
441
+
442
+
443
+
444
+
445
+
446
# app.py for a Hugging Face Space.
# Gradio demo for handwritten text recognition (HTR): Riksarkivet YOLO models
# perform region and line detection, Microsoft's TrOCR performs recognition.

import gradio as gr
from ultralytics import YOLO
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch

# Hub identifiers for the pretrained models (downloaded on first run,
# served from the local HF cache afterwards).
_REGION_MODEL_ID = "Riksarkivet/yolov9-regions-1"
_LINE_MODEL_ID = "Riksarkivet/yolov9-lines-within-regions-1"
_TROCR_ID = "microsoft/trocr-base-handwritten"

# Detection stage: one model for text regions, one for lines within a region.
region_model = YOLO(_REGION_MODEL_ID)
line_model = YOLO(_LINE_MODEL_ID)

# Recognition stage: TrOCR processor (image preproc+ tokenizer) and model.
trocr_processor = TrOCRProcessor.from_pretrained(_TROCR_ID)
trocr_model = VisionEncoderDecoderModel.from_pretrained(_TROCR_ID)
461
+
462
def process_image(image):
    """Run the two-stage HTR pipeline on a PIL image and return plain text.

    Pipeline:
      1. Detect text regions with the region YOLO model.
      2. Detect text lines inside each region with the line YOLO model.
      3. Recognize each line crop with TrOCR.

    Args:
        image: PIL.Image.Image supplied by the Gradio ``Image(type="pil")`` input.

    Returns:
        str: recognized text — lines within a region are joined with spaces,
        regions are separated by a blank line. A placeholder message is
        returned when nothing is detected.
    """
    # Step 1: Detect text regions.
    region_results = region_model(image)

    if not region_results or not region_results[0].boxes:
        return "No regions detected."

    texts = []

    # FIX: sort regions top-to-bottom so the output follows reading order.
    # YOLO returns boxes in detection/confidence order, not spatial order —
    # lines were already sorted by y, but regions were not, scrambling
    # multi-region pages.
    regions = sorted(region_results[0].boxes, key=lambda b: float(b.xyxy[0][1]))

    for region in regions:
        # Bounding box in full-image coordinates (x1, y1, x2, y2).
        x1, y1, x2, y2 = map(int, region.xyxy[0])
        region_crop = image.crop((x1, y1, x2, y2))

        # Step 2: Detect lines within the region crop.
        line_results = line_model(region_crop)

        if not line_results or not line_results[0].boxes:
            texts.append("No lines detected in this region.")
            continue

        region_texts = []
        # Sort lines by y-coordinate (top to bottom) within the region.
        for line in sorted(line_results[0].boxes, key=lambda b: float(b.xyxy[0][1])):
            # Line box is in region-crop coordinates.
            lx1, ly1, lx2, ly2 = map(int, line.xyxy[0])
            line_crop = region_crop.crop((lx1, ly1, lx2, ly2))

            # Step 3: Recognize text with TrOCR. Inference only, so skip
            # autograd bookkeeping (saves memory and time).
            pixel_values = trocr_processor(images=line_crop, return_tensors="pt").pixel_values
            with torch.no_grad():
                generated_ids = trocr_model.generate(pixel_values)
            text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            region_texts.append(text)

        texts.append(" ".join(region_texts))  # Join lines in region with a space.

    return "\n\n".join(texts)  # Separate regions with a double newline.
496
+
497
# Gradio interface: a single handwritten-document image in, plain text out.
_image_input = gr.Image(type="pil")

demo = gr.Interface(
    fn=process_image,
    inputs=_image_input,
    outputs="text",
    title="HTR Demo with YOLO Detection and TrOCR Recognition",
    description="Upload an image of a handwritten document. The app will detect regions, then lines, and recognize text using Microsoft's TrOCR.",
)

if __name__ == "__main__":
    # Start the Gradio server when executed as a script.
    demo.launch()