# import gradio as gr # from ultralytics import YOLO # from PIL import Image, ImageDraw, ImageFont # import torch # import logging # import os # from datetime import datetime # # # ── Quiet startup ─────────────────────────────────────────────────────── # # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1' # # logging.getLogger('ultralytics').setLevel(logging.WARNING) # # logging.basicConfig( # # level=logging.INFO, # # format='%(asctime)s | %(level)-5s | %(message)s' # # ) # # logger = logging.getLogger(__name__) # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1' # logging.getLogger('ultralytics').setLevel(logging.WARNING) # # FIXED logging format: use levelname, not level # logging.basicConfig( # level=logging.INFO, # format='%(asctime)s | %(levelname)-5s | %(message)s', # ← changed level → levelname # datefmt='%Y-%m-%d %H:%M:%S' # ) # logger = logging.getLogger(__name__) # logger.info("Initializing region detector...") # device = "cuda" if torch.cuda.is_available() else "cpu" # logger.info(f"Device: {device}") # # ── Load YOLO ─────────────────────────────────────────────────────────── # try: # region_pt = 'regions.pt' # if not os.path.exists(region_pt): # for f in os.listdir('.'): # name = f.lower() # if name.endswith('.pt') and 'region' in name: # region_pt = f # break # if not os.path.exists(region_pt): # raise FileNotFoundError("No regions.pt (or similar *.pt) found in current directory") # logger.info(f"Loading model: {region_pt}") # model = YOLO(region_pt) # logger.info("Region detector loaded") # except Exception as e: # logger.error(f"Model loading failed → {e}", exc_info=True) # raise # def visualize_regions( # image, # conf_thresh: float = 0.25, # min_size: int = 60, # padding: int = 0, # show_labels: bool = True, # save_debug_crops: bool = False, # imgsz: int = 1024, # ): # start = datetime.now().strftime("%H:%M:%S") # logs = [f"[{start}] Processing started"] # if image is None: # logs.append("No image uploaded") # return None, "\n".join(logs) # # Load 
& convert # if isinstance(image, str): # img = Image.open(image).convert("RGB") # else: # img = image.convert("RGB") # w, h = img.size # logs.append(f"Image size: {w} × {h}") # debug_img = img.copy() # draw = ImageDraw.Draw(debug_img) # try: # # Font for drawing labels (fallback to default) # try: # font = ImageFont.truetype("arial.ttf", 18) # except: # font = ImageFont.load_default() # # ── Run detection ─────────────────────────────────────────────── # results = model( # img, # conf=conf_thresh, # imgsz=imgsz, # verbose=False # )[0] # boxes = results.boxes # logs.append(f"Detected {len(boxes)} region candidate(s)") # kept = 0 # # Sort top → bottom # if len(boxes) > 0: # ys = boxes.xyxy[:, 1].cpu().numpy() # order = ys.argsort() # for idx in order: # box = boxes[idx] # conf = float(box.conf) # if conf < conf_thresh: # continue # x1, y1, x2, y2 = map(int, box.xyxy[0].tolist()) # bw, bh = x2 - x1, y2 - y1 # if bw < min_size or bh < min_size: # continue # # Optional padding (mostly for crop saving) # px1 = max(0, x1 - padding) # py1 = max(0, y1 - padding) # px2 = min(w, x2 + padding) # py2 = min(h, y2 + padding) # # Draw box # draw.rectangle((x1, y1, x2, y2), outline="lime", width=3) # if show_labels: # label = f"conf {conf:.2f} {bw}×{bh}" # tw, th = draw.textbbox((0,0), label, font=font)[2:] # draw.rectangle( # (x1, y1 - th - 4, x1 + tw + 8, y1), # fill=(0, 180, 0, 160) # ) # draw.text((x1 + 4, y1 - th - 2), label, fill="white", font=font) # kept += 1 # # Optional: save individual crops # if save_debug_crops: # os.makedirs("debug_regions", exist_ok=True) # crop = img.crop((px1, py1, px2, py2)) # fname = f"debug_regions/r{kept:02d}_conf{conf:.2f}_{bw}x{bh}.png" # crop.save(fname) # logs.append(f"Saved crop → {fname}") # if kept == 0: # msg = f"No regions kept after filters (conf ≥ {conf_thresh}, size ≥ {min_size}px)" # logs.append(msg) # else: # logs.append(f"Visualized {kept} region(s)") # logs.append("Finished.") # return debug_img, "\n".join(logs) # except 
Exception as e: # logs.append(f"Error during inference: {str(e)}") # logger.exception("Inference failed") # return debug_img, "\n".join(logs) # # ── Gradio Interface ──────────────────────────────────────────────────── # demo = gr.Interface( # fn=visualize_regions, # inputs=[ # gr.Image(type="pil", label="Upload image (handwritten document)"), # gr.Slider(0.10, 0.60, step=0.02, value=0.25, label="Confidence threshold"), # gr.Slider(30, 300, step=10, value=60, label="Minimum region width/height (px)"), # gr.Slider(0, 40, step=4, value=0, label="Padding around box (for crops only)"), # gr.Checkbox(label="Draw confidence + size labels on boxes", value=True), # gr.Checkbox(label="Save individual region crops to debug_regions/", value=False), # gr.Slider(640, 1280, step=64, value=1024, label="Inference image size (imgsz)"), # ], # outputs=[ # gr.Image(label="Detected text regions (green boxes)"), # gr.Textbox(label="Log / debug info", lines=14), # ], # title="Region Detector Debug View", # description=( # "Only shows what the region YOLO model sees.\n\n" # "• Green boxes = detected text regions\n" # "• Tune confidence and min size until boxes look reasonable\n" # "• Use logs to see exact confidences and sizes\n" # "• Save crops if you want to manually check what is being detected" # ), # # theme=gr.themes.Soft(), # ← comment out or remove (moved to launch) # # allow_flagging="never", # ← remove this line completely # ) # if __name__ == "__main__": # logger.info("Launching debug interface...") # demo.launch() # import gradio as gr # from ultralytics import YOLO # from transformers import TrOCRProcessor, VisionEncoderDecoderModel # from PIL import Image, ImageDraw # import torch # import logging # import os # import warnings # import time # from datetime import datetime # # ── Suppress noisy logs ────────────────────────────────────────────────────── # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1' # warnings.filterwarnings('ignore') # 
logging.getLogger('transformers').setLevel(logging.ERROR) # logging.getLogger('ultralytics').setLevel(logging.WARNING) # # Clean logging # logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)-5s | %(message)s') # logger = logging.getLogger(__name__) # logger.info("Initializing models...") # device = "cuda" if torch.cuda.is_available() else "cpu" # logger.info(f"Device: {device}") # def load_with_retry(cls, name, token=None, retries=4, delay=6): # for attempt in range(1, retries + 1): # try: # logger.info(f"Loading {name} (attempt {attempt}/{retries})") # if "Processor" in str(cls): # return cls.from_pretrained(name, token=token) # return cls.from_pretrained(name, token=token).to(device) # except Exception as e: # logger.warning(f"Load failed: {e}") # if attempt < retries: # time.sleep(delay) # raise RuntimeError(f"Failed to load {name} after {retries} attempts") # try: # # Locate local YOLO line detection weights # line_pt = 'lines.pt' # if not os.path.exists(line_pt): # for f in os.listdir('.'): # name = f.lower() # if 'line' in name and name.endswith('.pt'): # line_pt = f # break # if not os.path.exists(line_pt): # raise FileNotFoundError("Could not find lines.pt (or similar *.pt file containing 'line' in name)") # logger.info("Loading YOLO line model...") # line_model = YOLO(line_pt) # logger.info("YOLO line model loaded") # hf_token = os.getenv("HF_TOKEN") # processor = load_with_retry(TrOCRProcessor, "microsoft/trocr-base-handwritten", hf_token) # trocr = load_with_retry(VisionEncoderDecoderModel, "microsoft/trocr-base-handwritten", hf_token) # logger.info("TrOCR loaded → ready") # except Exception as e: # logger.error(f"Model loading failed: {e}", exc_info=True) # raise # def run_ocr(crop: Image.Image) -> str: # if crop.width < 20 or crop.height < 12: # return "" # pixels = processor(images=crop, return_tensors="pt").pixel_values.to(device) # ids = trocr.generate(pixels, max_new_tokens=128) # return processor.batch_decode(ids, 
skip_special_tokens=True)[0].strip() # def process_document( # image, # enable_debug_crops: bool = False, # line_imgsz: int = 768, # conf_thresh: float = 0.25, # ): # start_ts = datetime.now().strftime("%H:%M:%S") # logs = [] # def log(msg: str, level: str = "INFO"): # line = f"[{start_ts}] {level:5} {msg}" # logs.append(line) # if level == "ERROR": # logger.error(msg) # else: # logger.info(msg) # log("Start processing") # if image is None: # log("No image uploaded", "ERROR") # return None, "Upload an image", "\n".join(logs) # try: # # ── Prepare ───────────────────────────────────────────────────────────── # if not isinstance(image, Image.Image): # img = Image.open(image).convert("RGB") # else: # img = image.convert("RGB") # debug_img = img.copy() # draw = ImageDraw.Draw(debug_img) # w, h = img.size # log(f"Input image: {w} × {h} px") # debug_dir = "debug_crops" # if enable_debug_crops: # os.makedirs(debug_dir, exist_ok=True) # log(f"Debug crops will be saved to {debug_dir}/") # extracted = [] # # ── Line detection on full image ──────────────────────────────────────── # # Adaptive size based on image dimensions # max_dim = max(w, h) # if max_dim > 2200: # used_sz = 1280 # elif max_dim > 1400: # used_sz = 1024 # elif max_dim < 600: # used_sz = 640 # else: # used_sz = line_imgsz # log(f"Running line detection (imgsz={used_sz}, conf≥{conf_thresh}) …") # res = line_model(img, conf=conf_thresh, imgsz=used_sz, verbose=False)[0] # boxes = res.boxes # log(f"→ Detected {len(boxes)} line candidate(s)") # if len(boxes) == 0: # msg = "No text lines detected" # log(msg, "WARNING") # return debug_img, msg, "\n".join(logs) # # Sort top → bottom # ys = boxes.xyxy[:, 1].cpu().numpy() # y_min # order = ys.argsort() # for j, idx in enumerate(order, 1): # conf = float(boxes.conf[idx]) # x1, y1, x2, y2 = map(round, boxes.xyxy[idx].cpu().tolist()) # lw, lh = x2 - x1, y2 - y1 # log(f" Line {j}/{len(boxes)} conf={conf:.3f} {x1},{y1} → {x2},{y2} ({lw}×{lh})") # # Skip very small 
detections # if lw < 60 or lh < 20: # log(f" → skipped (too small)") # continue # draw.rectangle((x1, y1, x2, y2), outline="red", width=3) # line_crop = img.crop((x1, y1, x2, y2)) # if enable_debug_crops: # fname = f"{debug_dir}/line_{j:02d}_conf{conf:.2f}.png" # line_crop.save(fname) # text = run_ocr(line_crop) # log(f" OCR → '{text}'") # if text.strip(): # extracted.append(text) # # ── Finalize ──────────────────────────────────────────────────────────── # if not extracted: # msg = "No readable text found after OCR" # log(msg, "WARNING") # return debug_img, msg, "\n".join(logs) # log(f"Success — extracted {len(extracted)} line(s)") # if enable_debug_crops: # log(f"Debug crops saved to {debug_dir}/") # return debug_img, "\n".join(extracted), "\n".join(logs) # except Exception as e: # log(f"Processing failed: {e}", "ERROR") # logger.exception("Traceback:") # return debug_img, f"Error: {str(e)}", "\n".join(logs) # demo = gr.Interface( # fn=process_document, # inputs=[ # gr.Image(type="pil", label="Handwritten document"), # gr.Checkbox(label="Save debug crops", value=False), # gr.Slider(512, 1280, step=64, value=768, label="Line detection size (imgsz)"), # gr.Slider(0.15, 0.5, step=0.05, value=0.25, label="Confidence threshold"), # ], # outputs=[ # gr.Image(label="Debug (red = detected text lines)"), # gr.Textbox(label="Extracted Text", lines=10), # gr.Textbox(label="Detailed Logs (copy if alignment is wrong)", lines=16), # ], # title="Handwritten Line Detection + TrOCR", # description=( # "Red boxes = text lines detected by YOLO → sent to TrOCR for recognition\n\n" # "Use **Detailed Logs** to check coordinates, sizes & confidence values if results look off." 
# ),
#     theme=gr.themes.Soft(),
#     flagging_mode="never",
# )
# if __name__ == "__main__":
#     logger.info("Launching interface…")
#     demo.launch()

# app.py - FIXED VERSION with empty crop protection
import gradio as gr
from ultralytics import YOLO
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch
import numpy as np

# ── Load models once at startup ────────────────────────────────────────────
region_model = YOLO("regions.pt")  # detects text regions (block level)
line_model = YOLO("lines.pt")      # detects individual text lines in a region
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # inference only — disable dropout etc.


def get_crop(image: Image.Image, result, idx: int, padding: int = 15):
    """Crop detection *idx* of *image* out of the YOLO *result*.

    If segmentation masks are available, everything outside the mask is
    painted white so TrOCR only sees the detected ink; otherwise a plain
    padded bounding-box crop is returned.

    Returns a ``PIL.Image``, or ``None`` when the crop would be empty.
    """
    img_np = np.array(image)  # (H, W, 3) copy — safe to mutate below
    full_h, full_w = img_np.shape[:2]

    if result.masks is None:
        # Bounding-box fallback (detection model without a segmentation head).
        x1, y1, x2, y2 = result.boxes.xyxy[idx].cpu().numpy().astype(int)
        x1 = max(0, x1 - padding)
        y1 = max(0, y1 - padding)
        x2 = min(image.width, x2 + padding)
        y2 = min(image.height, y2 + padding)
        if x2 <= x1 or y2 <= y1:
            return None
        return image.crop((x1, y1, x2, y2))

    # Box coordinates are in ORIGINAL image space; clamp them so padding
    # arithmetic below can never go negative.
    x1, y1, x2, y2 = result.boxes.xyxy[idx].cpu().numpy().astype(int)
    x1, y1 = max(0, x1), max(0, y1)
    x2, y2 = min(full_w, x2), min(full_h, y2)
    if x2 <= x1 or y2 <= y1:
        return None

    # FIX: masks.data is usually at the model's *inference* resolution, not
    # the original image size, so indexing it with original-space box
    # coordinates silently selects the wrong (or an empty) area. Upsample
    # the boolean mask to full image size first (nearest keeps it binary).
    mask_bool = result.masks.data[idx].cpu().numpy() > 0.5
    if mask_bool.shape != (full_h, full_w):
        resized = Image.fromarray(mask_bool.astype(np.uint8) * 255).resize(
            (full_w, full_h), resample=Image.NEAREST
        )
        mask_bool = np.array(resized) > 127

    # Symmetric padding, clipped to the image borders.
    pad_top = min(padding, y1)
    pad_bottom = min(padding, full_h - y2)
    pad_left = min(padding, x1)
    pad_right = min(padding, full_w - x2)

    y_start, y_end = y1 - pad_top, y2 + pad_bottom
    x_start, x_end = x1 - pad_left, x2 + pad_right

    padded_img = img_np[y_start:y_end, x_start:x_end]
    padded_mask = mask_bool[y_start:y_end, x_start:x_end]
    if padded_img.size == 0 or padded_mask.size == 0:
        return None

    # White-out the background outside the segmentation mask
    # (img_np is a private copy, so in-place mutation is safe).
    padded_img[~padded_mask] = 255
    return Image.fromarray(padded_img)


def _ocr_line(line_crop: Image.Image) -> str:
    """Run TrOCR on a single line crop; return '' if recognition fails."""
    try:
        with torch.inference_mode():  # no autograd bookkeeping at inference
            pixel_values = processor(line_crop, return_tensors="pt").pixel_values.to(device)
            generated_ids = model.generate(pixel_values, max_new_tokens=128)
        return processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    except Exception:
        # A corrupt crop must not kill the whole document — skip the line.
        return ""


def process_image(image: Image.Image):
    """Full pipeline: detect regions → detect lines → TrOCR, read top-to-bottom.

    Returns the recognized text as a single string (regions separated by a
    blank line), or a human-readable message when nothing usable was found.
    """
    if image is None:
        return "Please upload an image."

    region_result = region_model(image)[0]
    if region_result.boxes is None or len(region_result.boxes) == 0:
        return "No text regions detected."

    # Collect region crops together with their vertical position so the
    # document can be read in top-to-bottom order.
    regions_with_pos = []
    for i in range(len(region_result.boxes)):
        y1 = region_result.boxes.xyxy[i][1].item()
        crop = get_crop(image, region_result, i, padding=20)
        if crop is not None and crop.size[0] > 0 and crop.size[1] > 0:
            regions_with_pos.append((y1, crop))

    if not regions_with_pos:
        return "No valid text regions after cropping."

    regions_with_pos.sort(key=lambda item: item[0])  # top → bottom

    full_text_parts = []
    for _, region_crop in regions_with_pos:
        line_result = line_model(region_crop)[0]
        if line_result.boxes is None or len(line_result.boxes) == 0:
            continue

        lines_with_pos = []
        for j in range(len(line_result.boxes)):
            rel_y1 = line_result.boxes.xyxy[j][1].item()
            rel_x1 = line_result.boxes.xyxy[j][0].item()
            line_crop = get_crop(region_crop, line_result, j, padding=15)
            # Skip tiny/invalid crops — TrOCR crashes on (near-)empty input.
            if line_crop is None or line_crop.size[0] < 10 or line_crop.size[1] < 8:
                continue
            text = _ocr_line(line_crop)
            if text:  # only keep non-empty recognitions
                lines_with_pos.append((rel_y1, rel_x1, text))

        # Reading order inside a region: top-to-bottom, then left-to-right.
        lines_with_pos.sort(key=lambda item: (item[0], item[1]))
        region_text = "\n".join(text for _, _, text in lines_with_pos)
        if region_text:
            full_text_parts.append(region_text)

    if not full_text_parts:
        return "No readable text recognized (possibly due to small/tiny lines or model limitations). Try a clearer document or larger padding."

    return "\n\n".join(full_text_parts)


# ── Gradio interface ───────────────────────────────────────────────────────
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload handwritten document"),
    outputs=gr.Textbox(label="Recognized Text"),
    title="Handwritten Text Recognition (YOLO + TrOCR)",
    description="Local models: regions.pt / lines.pt + microsoft/trocr-base-handwritten. Mask-based cropping + safeguards against empty crops.",
    flagging_mode="never",
)

if __name__ == "__main__":
    demo.launch()