# test / app.py
# iammraat's picture
# Update app.py
# e6dcbb9 verified
# import gradio as gr
# from ultralytics import YOLO
# from PIL import Image, ImageDraw, ImageFont
# import torch
# import logging
# import os
# from datetime import datetime
# # # ── Quiet startup ───────────────────────────────────────────────────────
# # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
# # logging.getLogger('ultralytics').setLevel(logging.WARNING)
# # logging.basicConfig(
# # level=logging.INFO,
# # format='%(asctime)s | %(level)-5s | %(message)s'
# # )
# # logger = logging.getLogger(__name__)
# os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
# logging.getLogger('ultralytics').setLevel(logging.WARNING)
# # FIXED logging format: use levelname, not level
# logging.basicConfig(
# level=logging.INFO,
# format='%(asctime)s | %(levelname)-5s | %(message)s', # ← changed level β†’ levelname
# datefmt='%Y-%m-%d %H:%M:%S'
# )
# logger = logging.getLogger(__name__)
# logger.info("Initializing region detector...")
# device = "cuda" if torch.cuda.is_available() else "cpu"
# logger.info(f"Device: {device}")
# # ── Load YOLO ───────────────────────────────────────────────────────────
# try:
# region_pt = 'regions.pt'
# if not os.path.exists(region_pt):
# for f in os.listdir('.'):
# name = f.lower()
# if name.endswith('.pt') and 'region' in name:
# region_pt = f
# break
# if not os.path.exists(region_pt):
# raise FileNotFoundError("No regions.pt (or similar *.pt) found in current directory")
# logger.info(f"Loading model: {region_pt}")
# model = YOLO(region_pt)
# logger.info("Region detector loaded")
# except Exception as e:
# logger.error(f"Model loading failed β†’ {e}", exc_info=True)
# raise
# def visualize_regions(
# image,
# conf_thresh: float = 0.25,
# min_size: int = 60,
# padding: int = 0,
# show_labels: bool = True,
# save_debug_crops: bool = False,
# imgsz: int = 1024,
# ):
# start = datetime.now().strftime("%H:%M:%S")
# logs = [f"[{start}] Processing started"]
# if image is None:
# logs.append("No image uploaded")
# return None, "\n".join(logs)
# # Load & convert
# if isinstance(image, str):
# img = Image.open(image).convert("RGB")
# else:
# img = image.convert("RGB")
# w, h = img.size
# logs.append(f"Image size: {w} Γ— {h}")
# debug_img = img.copy()
# draw = ImageDraw.Draw(debug_img)
# try:
# # Font for drawing labels (fallback to default)
# try:
# font = ImageFont.truetype("arial.ttf", 18)
# except:
# font = ImageFont.load_default()
# # ── Run detection ───────────────────────────────────────────────
# results = model(
# img,
# conf=conf_thresh,
# imgsz=imgsz,
# verbose=False
# )[0]
# boxes = results.boxes
# logs.append(f"Detected {len(boxes)} region candidate(s)")
# kept = 0
# # Sort top β†’ bottom
# if len(boxes) > 0:
# ys = boxes.xyxy[:, 1].cpu().numpy()
# order = ys.argsort()
# for idx in order:
# box = boxes[idx]
# conf = float(box.conf)
# if conf < conf_thresh:
# continue
# x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
# bw, bh = x2 - x1, y2 - y1
# if bw < min_size or bh < min_size:
# continue
# # Optional padding (mostly for crop saving)
# px1 = max(0, x1 - padding)
# py1 = max(0, y1 - padding)
# px2 = min(w, x2 + padding)
# py2 = min(h, y2 + padding)
# # Draw box
# draw.rectangle((x1, y1, x2, y2), outline="lime", width=3)
# if show_labels:
# label = f"conf {conf:.2f} {bw}Γ—{bh}"
# tw, th = draw.textbbox((0,0), label, font=font)[2:]
# draw.rectangle(
# (x1, y1 - th - 4, x1 + tw + 8, y1),
# fill=(0, 180, 0, 160)
# )
# draw.text((x1 + 4, y1 - th - 2), label, fill="white", font=font)
# kept += 1
# # Optional: save individual crops
# if save_debug_crops:
# os.makedirs("debug_regions", exist_ok=True)
# crop = img.crop((px1, py1, px2, py2))
# fname = f"debug_regions/r{kept:02d}_conf{conf:.2f}_{bw}x{bh}.png"
# crop.save(fname)
# logs.append(f"Saved crop β†’ {fname}")
# if kept == 0:
# msg = f"No regions kept after filters (conf β‰₯ {conf_thresh}, size β‰₯ {min_size}px)"
# logs.append(msg)
# else:
# logs.append(f"Visualized {kept} region(s)")
# logs.append("Finished.")
# return debug_img, "\n".join(logs)
# except Exception as e:
# logs.append(f"Error during inference: {str(e)}")
# logger.exception("Inference failed")
# return debug_img, "\n".join(logs)
# # ── Gradio Interface ────────────────────────────────────────────────────
# demo = gr.Interface(
# fn=visualize_regions,
# inputs=[
# gr.Image(type="pil", label="Upload image (handwritten document)"),
# gr.Slider(0.10, 0.60, step=0.02, value=0.25, label="Confidence threshold"),
# gr.Slider(30, 300, step=10, value=60, label="Minimum region width/height (px)"),
# gr.Slider(0, 40, step=4, value=0, label="Padding around box (for crops only)"),
# gr.Checkbox(label="Draw confidence + size labels on boxes", value=True),
# gr.Checkbox(label="Save individual region crops to debug_regions/", value=False),
# gr.Slider(640, 1280, step=64, value=1024, label="Inference image size (imgsz)"),
# ],
# outputs=[
# gr.Image(label="Detected text regions (green boxes)"),
# gr.Textbox(label="Log / debug info", lines=14),
# ],
# title="Region Detector Debug View",
# description=(
# "Only shows what the region YOLO model sees.\n\n"
# "β€’ Green boxes = detected text regions\n"
# "β€’ Tune confidence and min size until boxes look reasonable\n"
# "β€’ Use logs to see exact confidences and sizes\n"
# "β€’ Save crops if you want to manually check what is being detected"
# ),
# # theme=gr.themes.Soft(), # ← comment out or remove (moved to launch)
# # allow_flagging="never", # ← remove this line completely
# )
# if __name__ == "__main__":
# logger.info("Launching debug interface...")
# demo.launch()
# import gradio as gr
# from ultralytics import YOLO
# from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# from PIL import Image, ImageDraw
# import torch
# import logging
# import os
# import warnings
# import time
# from datetime import datetime
# # ── Suppress noisy logs ──────────────────────────────────────────────────────
# os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
# warnings.filterwarnings('ignore')
# logging.getLogger('transformers').setLevel(logging.ERROR)
# logging.getLogger('ultralytics').setLevel(logging.WARNING)
# # Clean logging
# logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)-5s | %(message)s')
# logger = logging.getLogger(__name__)
# logger.info("Initializing models...")
# device = "cuda" if torch.cuda.is_available() else "cpu"
# logger.info(f"Device: {device}")
# def load_with_retry(cls, name, token=None, retries=4, delay=6):
# for attempt in range(1, retries + 1):
# try:
# logger.info(f"Loading {name} (attempt {attempt}/{retries})")
# if "Processor" in str(cls):
# return cls.from_pretrained(name, token=token)
# return cls.from_pretrained(name, token=token).to(device)
# except Exception as e:
# logger.warning(f"Load failed: {e}")
# if attempt < retries:
# time.sleep(delay)
# raise RuntimeError(f"Failed to load {name} after {retries} attempts")
# try:
# # Locate local YOLO line detection weights
# line_pt = 'lines.pt'
# if not os.path.exists(line_pt):
# for f in os.listdir('.'):
# name = f.lower()
# if 'line' in name and name.endswith('.pt'):
# line_pt = f
# break
# if not os.path.exists(line_pt):
# raise FileNotFoundError("Could not find lines.pt (or similar *.pt file containing 'line' in name)")
# logger.info("Loading YOLO line model...")
# line_model = YOLO(line_pt)
# logger.info("YOLO line model loaded")
# hf_token = os.getenv("HF_TOKEN")
# processor = load_with_retry(TrOCRProcessor, "microsoft/trocr-base-handwritten", hf_token)
# trocr = load_with_retry(VisionEncoderDecoderModel, "microsoft/trocr-base-handwritten", hf_token)
# logger.info("TrOCR loaded β†’ ready")
# except Exception as e:
# logger.error(f"Model loading failed: {e}", exc_info=True)
# raise
# def run_ocr(crop: Image.Image) -> str:
# if crop.width < 20 or crop.height < 12:
# return ""
# pixels = processor(images=crop, return_tensors="pt").pixel_values.to(device)
# ids = trocr.generate(pixels, max_new_tokens=128)
# return processor.batch_decode(ids, skip_special_tokens=True)[0].strip()
# def process_document(
# image,
# enable_debug_crops: bool = False,
# line_imgsz: int = 768,
# conf_thresh: float = 0.25,
# ):
# start_ts = datetime.now().strftime("%H:%M:%S")
# logs = []
# def log(msg: str, level: str = "INFO"):
# line = f"[{start_ts}] {level:5} {msg}"
# logs.append(line)
# if level == "ERROR":
# logger.error(msg)
# else:
# logger.info(msg)
# log("Start processing")
# if image is None:
# log("No image uploaded", "ERROR")
# return None, "Upload an image", "\n".join(logs)
# try:
# # ── Prepare ─────────────────────────────────────────────────────────────
# if not isinstance(image, Image.Image):
# img = Image.open(image).convert("RGB")
# else:
# img = image.convert("RGB")
# debug_img = img.copy()
# draw = ImageDraw.Draw(debug_img)
# w, h = img.size
# log(f"Input image: {w} Γ— {h} px")
# debug_dir = "debug_crops"
# if enable_debug_crops:
# os.makedirs(debug_dir, exist_ok=True)
# log(f"Debug crops will be saved to {debug_dir}/")
# extracted = []
# # ── Line detection on full image ────────────────────────────────────────
# # Adaptive size based on image dimensions
# max_dim = max(w, h)
# if max_dim > 2200:
# used_sz = 1280
# elif max_dim > 1400:
# used_sz = 1024
# elif max_dim < 600:
# used_sz = 640
# else:
# used_sz = line_imgsz
# log(f"Running line detection (imgsz={used_sz}, confβ‰₯{conf_thresh}) …")
# res = line_model(img, conf=conf_thresh, imgsz=used_sz, verbose=False)[0]
# boxes = res.boxes
# log(f"β†’ Detected {len(boxes)} line candidate(s)")
# if len(boxes) == 0:
# msg = "No text lines detected"
# log(msg, "WARNING")
# return debug_img, msg, "\n".join(logs)
# # Sort top β†’ bottom
# ys = boxes.xyxy[:, 1].cpu().numpy() # y_min
# order = ys.argsort()
# for j, idx in enumerate(order, 1):
# conf = float(boxes.conf[idx])
# x1, y1, x2, y2 = map(round, boxes.xyxy[idx].cpu().tolist())
# lw, lh = x2 - x1, y2 - y1
# log(f" Line {j}/{len(boxes)} conf={conf:.3f} {x1},{y1} β†’ {x2},{y2} ({lw}Γ—{lh})")
# # Skip very small detections
# if lw < 60 or lh < 20:
# log(f" β†’ skipped (too small)")
# continue
# draw.rectangle((x1, y1, x2, y2), outline="red", width=3)
# line_crop = img.crop((x1, y1, x2, y2))
# if enable_debug_crops:
# fname = f"{debug_dir}/line_{j:02d}_conf{conf:.2f}.png"
# line_crop.save(fname)
# text = run_ocr(line_crop)
# log(f" OCR β†’ '{text}'")
# if text.strip():
# extracted.append(text)
# # ── Finalize ────────────────────────────────────────────────────────────
# if not extracted:
# msg = "No readable text found after OCR"
# log(msg, "WARNING")
# return debug_img, msg, "\n".join(logs)
# log(f"Success β€” extracted {len(extracted)} line(s)")
# if enable_debug_crops:
# log(f"Debug crops saved to {debug_dir}/")
# return debug_img, "\n".join(extracted), "\n".join(logs)
# except Exception as e:
# log(f"Processing failed: {e}", "ERROR")
# logger.exception("Traceback:")
# return debug_img, f"Error: {str(e)}", "\n".join(logs)
# demo = gr.Interface(
# fn=process_document,
# inputs=[
# gr.Image(type="pil", label="Handwritten document"),
# gr.Checkbox(label="Save debug crops", value=False),
# gr.Slider(512, 1280, step=64, value=768, label="Line detection size (imgsz)"),
# gr.Slider(0.15, 0.5, step=0.05, value=0.25, label="Confidence threshold"),
# ],
# outputs=[
# gr.Image(label="Debug (red = detected text lines)"),
# gr.Textbox(label="Extracted Text", lines=10),
# gr.Textbox(label="Detailed Logs (copy if alignment is wrong)", lines=16),
# ],
# title="Handwritten Line Detection + TrOCR",
# description=(
# "Red boxes = text lines detected by YOLO β†’ sent to TrOCR for recognition\n\n"
# "Use **Detailed Logs** to check coordinates, sizes & confidence values if results look off."
# ),
# theme=gr.themes.Soft(),
# flagging_mode="never",
# )
# if __name__ == "__main__":
# logger.info("Launching interface…")
# demo.launch()
# app.py - FIXED VERSION with empty crop protection
import gradio as gr
from ultralytics import YOLO
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch
import numpy as np
# Load models
# Two-stage pipeline: a YOLO model finds text regions, a second YOLO model
# finds text lines inside each region, and TrOCR transcribes each line crop.
region_model = YOLO("regions.pt")  # local weights: text-region detector
line_model = YOLO("lines.pt")      # local weights: text-line detector
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
# Only the TrOCR model is moved explicitly; ultralytics manages YOLO devices itself.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
def get_crop(image: Image.Image, result, idx: int, padding: int = 15):
    """Extract detection ``idx`` from ``image`` as a PIL crop.

    When the YOLO result carries segmentation masks, the crop is mask-tight:
    pixels outside the instance mask are painted white so downstream OCR only
    sees the detected text. Without masks it falls back to a plain padded
    bounding-box crop.

    Args:
        image: Source PIL image the detections refer to.
        result: A single ultralytics Results object (``.boxes`` required,
            ``.masks`` optional).
        idx: Index of the detection to crop.
        padding: Context margin in pixels added around the box, clipped to
            the image borders.

    Returns:
        A PIL image, or ``None`` when the crop would be empty/degenerate.
    """
    img_np = np.array(image)  # (H, W, 3), original-image resolution

    if result.masks is not None:
        # Box coordinates are in ORIGINAL image space.
        box = result.boxes.xyxy[idx].cpu().numpy().astype(int)
        x1, y1, x2, y2 = box
        # Clamp to image bounds — detector boxes can overshoot by a pixel or two.
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(img_np.shape[1], x2), min(img_np.shape[0], y2)

        # masks.data is produced at the network's inference resolution, which
        # is generally NOT the original image size. Resize it to the image
        # before indexing with original-image coordinates (previously the mask
        # was indexed directly, silently shifting/truncating it whenever the
        # two resolutions differed).
        mask_t = result.masks.data[idx]
        if tuple(mask_t.shape[-2:]) != img_np.shape[:2]:
            mask_t = torch.nn.functional.interpolate(
                mask_t[None, None].float(),  # (1, 1, h, w) for interpolate
                size=img_np.shape[:2],
                mode="bilinear",
                align_corners=False,
            )[0, 0]
        mask_bool = mask_t.cpu().numpy() > 0.5

        crop_img = img_np[y1:y2, x1:x2]
        crop_mask = mask_bool[y1:y2, x1:x2]
        if crop_img.size == 0 or crop_mask.size == 0:
            return None

        # Grow the crop by `padding`, but never past the image borders.
        pad_top = min(padding, y1)
        pad_bottom = min(padding, img_np.shape[0] - y2)
        pad_left = min(padding, x1)
        pad_right = min(padding, img_np.shape[1] - x2)
        y_start, y_end = y1 - pad_top, y2 + pad_bottom
        x_start, x_end = x1 - pad_left, x2 + pad_right

        padded_img = img_np[y_start:y_end, x_start:x_end]
        padded_mask = mask_bool[y_start:y_end, x_start:x_end]
        # White out everything outside the instance mask. `img_np` is a private
        # copy made by np.array(image), so this mutation cannot leak back to
        # the caller's image.
        padded_img[~padded_mask] = 255
        return Image.fromarray(padded_img)
    else:
        # Bounding-box fallback (detection model without a segmentation head).
        xyxy = result.boxes.xyxy[idx].cpu().numpy().astype(int)
        x1, y1, x2, y2 = xyxy
        x1 = max(0, x1 - padding)
        y1 = max(0, y1 - padding)
        x2 = min(image.width, x2 + padding)
        y2 = min(image.height, y2 + padding)
        if x2 <= x1 or y2 <= y1:
            return None
        return image.crop((x1, y1, x2, y2))
def _ocr_line(line_crop: Image.Image) -> str:
    """Transcribe one line crop with TrOCR; returns '' if recognition fails.

    Best-effort by design: a single unreadable line must not abort the whole
    document, so any processor/model error is swallowed and the line skipped.
    """
    try:
        pixel_values = processor(line_crop, return_tensors="pt").pixel_values.to(device)
        # Inference only — no_grad avoids accumulating autograd state/memory.
        with torch.no_grad():
            generated_ids = model.generate(pixel_values)
        return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    except Exception:
        return ""


def process_image(image: Image.Image):
    """Full pipeline: regions → lines → TrOCR, assembled in reading order.

    Args:
        image: Uploaded document image (PIL), or None.

    Returns:
        Recognized text (regions separated by blank lines), or a human-readable
        status message when nothing usable was found.
    """
    if image is None:
        return "Please upload an image."

    region_result = region_model(image)[0]
    if region_result.boxes is None or len(region_result.boxes) == 0:
        return "No text regions detected."

    # Pair each region crop with its top edge so regions read top-to-bottom.
    regions_with_pos = []
    for i in range(len(region_result.boxes)):
        top_y = region_result.boxes.xyxy[i][1].item()
        crop = get_crop(image, region_result, i, padding=20)
        if crop is not None and crop.size[0] > 0 and crop.size[1] > 0:
            regions_with_pos.append((top_y, crop))
    if not regions_with_pos:
        return "No valid text regions after cropping."
    regions_with_pos.sort(key=lambda item: item[0])

    full_text_parts = []
    for _, region_crop in regions_with_pos:
        line_result = line_model(region_crop)[0]
        if line_result.boxes is None or len(line_result.boxes) == 0:
            continue

        # Collect (y, x, text) so lines sort top-to-bottom, then left-to-right.
        lines_with_pos = []
        for j in range(len(line_result.boxes)):
            rel_x1 = line_result.boxes.xyxy[j][0].item()
            rel_y1 = line_result.boxes.xyxy[j][1].item()
            line_crop = get_crop(region_crop, line_result, j, padding=15)
            # Skip tiny/invalid crops to prevent a TrOCR crash on empty input.
            if line_crop is None or line_crop.size[0] < 10 or line_crop.size[1] < 8:
                continue
            text = _ocr_line(line_crop)
            if text:  # only keep non-empty transcriptions
                lines_with_pos.append((rel_y1, rel_x1, text))

        lines_with_pos.sort(key=lambda item: (item[0], item[1]))
        region_text = "\n".join(item[2] for item in lines_with_pos)
        if region_text:
            full_text_parts.append(region_text)

    if not full_text_parts:
        return "No readable text recognized (possibly due to small/tiny lines or model limitations). Try a clearer document or larger padding."
    return "\n\n".join(full_text_parts)
# Gradio interface
# Single image in, plain text out. NOTE(review): `flagging_mode` is the newer
# spelling of `allow_flagging` — confirm the pinned Gradio version supports it.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload handwritten document"),
    outputs=gr.Textbox(label="Recognized Text"),
    title="Handwritten Text Recognition (YOLO + TrOCR)",
    description="Local models: regions.pt / lines.pt + microsoft/trocr-base-handwritten. Mask-based cropping + safeguards against empty crops.",
    flagging_mode="never"
)
# Entry point for local runs; HF Spaces also imports `demo` directly.
if __name__ == "__main__":
    demo.launch()