| # import gradio as gr | |
| # from ultralytics import YOLO | |
| # from PIL import Image, ImageDraw, ImageFont | |
| # import torch | |
| # import logging | |
| # import os | |
| # from datetime import datetime | |
| # # # ── Quiet startup ─────────────────────────────────────────────────────── | |
| # # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1' | |
| # # logging.getLogger('ultralytics').setLevel(logging.WARNING) | |
| # # logging.basicConfig( | |
| # # level=logging.INFO, | |
| # # format='%(asctime)s | %(level)-5s | %(message)s' | |
| # # ) | |
| # # logger = logging.getLogger(__name__) | |
| # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1' | |
| # logging.getLogger('ultralytics').setLevel(logging.WARNING) | |
| # # FIXED logging format: use levelname, not level | |
| # logging.basicConfig( | |
| # level=logging.INFO, | |
| # format='%(asctime)s | %(levelname)-5s | %(message)s', # β changed level β levelname | |
| # datefmt='%Y-%m-%d %H:%M:%S' | |
| # ) | |
| # logger = logging.getLogger(__name__) | |
| # logger.info("Initializing region detector...") | |
| # device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # logger.info(f"Device: {device}") | |
| # # ── Load YOLO ─────────────────────────────────────────────────────────── | |
| # try: | |
| # region_pt = 'regions.pt' | |
| # if not os.path.exists(region_pt): | |
| # for f in os.listdir('.'): | |
| # name = f.lower() | |
| # if name.endswith('.pt') and 'region' in name: | |
| # region_pt = f | |
| # break | |
| # if not os.path.exists(region_pt): | |
| # raise FileNotFoundError("No regions.pt (or similar *.pt) found in current directory") | |
| # logger.info(f"Loading model: {region_pt}") | |
| # model = YOLO(region_pt) | |
| # logger.info("Region detector loaded") | |
| # except Exception as e: | |
| # logger.error(f"Model loading failed β {e}", exc_info=True) | |
| # raise | |
| # def visualize_regions( | |
| # image, | |
| # conf_thresh: float = 0.25, | |
| # min_size: int = 60, | |
| # padding: int = 0, | |
| # show_labels: bool = True, | |
| # save_debug_crops: bool = False, | |
| # imgsz: int = 1024, | |
| # ): | |
| # start = datetime.now().strftime("%H:%M:%S") | |
| # logs = [f"[{start}] Processing started"] | |
| # if image is None: | |
| # logs.append("No image uploaded") | |
| # return None, "\n".join(logs) | |
| # # Load & convert | |
| # if isinstance(image, str): | |
| # img = Image.open(image).convert("RGB") | |
| # else: | |
| # img = image.convert("RGB") | |
| # w, h = img.size | |
| # logs.append(f"Image size: {w} Γ {h}") | |
| # debug_img = img.copy() | |
| # draw = ImageDraw.Draw(debug_img) | |
| # try: | |
| # # Font for drawing labels (fallback to default) | |
| # try: | |
| # font = ImageFont.truetype("arial.ttf", 18) | |
| # except: | |
| # font = ImageFont.load_default() | |
| # # ── Run detection ─────────────────────────────────────────────── | |
| # results = model( | |
| # img, | |
| # conf=conf_thresh, | |
| # imgsz=imgsz, | |
| # verbose=False | |
| # )[0] | |
| # boxes = results.boxes | |
| # logs.append(f"Detected {len(boxes)} region candidate(s)") | |
| # kept = 0 | |
| # # Sort top → bottom | |
| # if len(boxes) > 0: | |
| # ys = boxes.xyxy[:, 1].cpu().numpy() | |
| # order = ys.argsort() | |
| # for idx in order: | |
| # box = boxes[idx] | |
| # conf = float(box.conf) | |
| # if conf < conf_thresh: | |
| # continue | |
| # x1, y1, x2, y2 = map(int, box.xyxy[0].tolist()) | |
| # bw, bh = x2 - x1, y2 - y1 | |
| # if bw < min_size or bh < min_size: | |
| # continue | |
| # # Optional padding (mostly for crop saving) | |
| # px1 = max(0, x1 - padding) | |
| # py1 = max(0, y1 - padding) | |
| # px2 = min(w, x2 + padding) | |
| # py2 = min(h, y2 + padding) | |
| # # Draw box | |
| # draw.rectangle((x1, y1, x2, y2), outline="lime", width=3) | |
| # if show_labels: | |
| # label = f"conf {conf:.2f} {bw}Γ{bh}" | |
| # tw, th = draw.textbbox((0,0), label, font=font)[2:] | |
| # draw.rectangle( | |
| # (x1, y1 - th - 4, x1 + tw + 8, y1), | |
| # fill=(0, 180, 0, 160) | |
| # ) | |
| # draw.text((x1 + 4, y1 - th - 2), label, fill="white", font=font) | |
| # kept += 1 | |
| # # Optional: save individual crops | |
| # if save_debug_crops: | |
| # os.makedirs("debug_regions", exist_ok=True) | |
| # crop = img.crop((px1, py1, px2, py2)) | |
| # fname = f"debug_regions/r{kept:02d}_conf{conf:.2f}_{bw}x{bh}.png" | |
| # crop.save(fname) | |
| # logs.append(f"Saved crop β {fname}") | |
| # if kept == 0: | |
| # msg = f"No regions kept after filters (conf β₯ {conf_thresh}, size β₯ {min_size}px)" | |
| # logs.append(msg) | |
| # else: | |
| # logs.append(f"Visualized {kept} region(s)") | |
| # logs.append("Finished.") | |
| # return debug_img, "\n".join(logs) | |
| # except Exception as e: | |
| # logs.append(f"Error during inference: {str(e)}") | |
| # logger.exception("Inference failed") | |
| # return debug_img, "\n".join(logs) | |
| # # ── Gradio Interface ──────────────────────────────────────────────────── | |
| # demo = gr.Interface( | |
| # fn=visualize_regions, | |
| # inputs=[ | |
| # gr.Image(type="pil", label="Upload image (handwritten document)"), | |
| # gr.Slider(0.10, 0.60, step=0.02, value=0.25, label="Confidence threshold"), | |
| # gr.Slider(30, 300, step=10, value=60, label="Minimum region width/height (px)"), | |
| # gr.Slider(0, 40, step=4, value=0, label="Padding around box (for crops only)"), | |
| # gr.Checkbox(label="Draw confidence + size labels on boxes", value=True), | |
| # gr.Checkbox(label="Save individual region crops to debug_regions/", value=False), | |
| # gr.Slider(640, 1280, step=64, value=1024, label="Inference image size (imgsz)"), | |
| # ], | |
| # outputs=[ | |
| # gr.Image(label="Detected text regions (green boxes)"), | |
| # gr.Textbox(label="Log / debug info", lines=14), | |
| # ], | |
| # title="Region Detector Debug View", | |
| # description=( | |
| # "Only shows what the region YOLO model sees.\n\n" | |
| # "β’ Green boxes = detected text regions\n" | |
| # "β’ Tune confidence and min size until boxes look reasonable\n" | |
| # "β’ Use logs to see exact confidences and sizes\n" | |
| # "β’ Save crops if you want to manually check what is being detected" | |
| # ), | |
| # # theme=gr.themes.Soft(), # β comment out or remove (moved to launch) | |
| # # allow_flagging="never", # β remove this line completely | |
| # ) | |
| # if __name__ == "__main__": | |
| # logger.info("Launching debug interface...") | |
| # demo.launch() | |
| # import gradio as gr | |
| # from ultralytics import YOLO | |
| # from transformers import TrOCRProcessor, VisionEncoderDecoderModel | |
| # from PIL import Image, ImageDraw | |
| # import torch | |
| # import logging | |
| # import os | |
| # import warnings | |
| # import time | |
| # from datetime import datetime | |
| # # ── Suppress noisy logs ────────────────────────────────────────────────────── | |
| # os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1' | |
| # warnings.filterwarnings('ignore') | |
| # logging.getLogger('transformers').setLevel(logging.ERROR) | |
| # logging.getLogger('ultralytics').setLevel(logging.WARNING) | |
| # # Clean logging | |
| # logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)-5s | %(message)s') | |
| # logger = logging.getLogger(__name__) | |
| # logger.info("Initializing models...") | |
| # device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # logger.info(f"Device: {device}") | |
| # def load_with_retry(cls, name, token=None, retries=4, delay=6): | |
| # for attempt in range(1, retries + 1): | |
| # try: | |
| # logger.info(f"Loading {name} (attempt {attempt}/{retries})") | |
| # if "Processor" in str(cls): | |
| # return cls.from_pretrained(name, token=token) | |
| # return cls.from_pretrained(name, token=token).to(device) | |
| # except Exception as e: | |
| # logger.warning(f"Load failed: {e}") | |
| # if attempt < retries: | |
| # time.sleep(delay) | |
| # raise RuntimeError(f"Failed to load {name} after {retries} attempts") | |
| # try: | |
| # # Locate local YOLO line detection weights | |
| # line_pt = 'lines.pt' | |
| # if not os.path.exists(line_pt): | |
| # for f in os.listdir('.'): | |
| # name = f.lower() | |
| # if 'line' in name and name.endswith('.pt'): | |
| # line_pt = f | |
| # break | |
| # if not os.path.exists(line_pt): | |
| # raise FileNotFoundError("Could not find lines.pt (or similar *.pt file containing 'line' in name)") | |
| # logger.info("Loading YOLO line model...") | |
| # line_model = YOLO(line_pt) | |
| # logger.info("YOLO line model loaded") | |
| # hf_token = os.getenv("HF_TOKEN") | |
| # processor = load_with_retry(TrOCRProcessor, "microsoft/trocr-base-handwritten", hf_token) | |
| # trocr = load_with_retry(VisionEncoderDecoderModel, "microsoft/trocr-base-handwritten", hf_token) | |
| # logger.info("TrOCR loaded β ready") | |
| # except Exception as e: | |
| # logger.error(f"Model loading failed: {e}", exc_info=True) | |
| # raise | |
| # def run_ocr(crop: Image.Image) -> str: | |
| # if crop.width < 20 or crop.height < 12: | |
| # return "" | |
| # pixels = processor(images=crop, return_tensors="pt").pixel_values.to(device) | |
| # ids = trocr.generate(pixels, max_new_tokens=128) | |
| # return processor.batch_decode(ids, skip_special_tokens=True)[0].strip() | |
| # def process_document( | |
| # image, | |
| # enable_debug_crops: bool = False, | |
| # line_imgsz: int = 768, | |
| # conf_thresh: float = 0.25, | |
| # ): | |
| # start_ts = datetime.now().strftime("%H:%M:%S") | |
| # logs = [] | |
| # def log(msg: str, level: str = "INFO"): | |
| # line = f"[{start_ts}] {level:5} {msg}" | |
| # logs.append(line) | |
| # if level == "ERROR": | |
| # logger.error(msg) | |
| # else: | |
| # logger.info(msg) | |
| # log("Start processing") | |
| # if image is None: | |
| # log("No image uploaded", "ERROR") | |
| # return None, "Upload an image", "\n".join(logs) | |
| # try: | |
| # # ── Prepare ───────────────────────────────────────────────────────────── | |
| # if not isinstance(image, Image.Image): | |
| # img = Image.open(image).convert("RGB") | |
| # else: | |
| # img = image.convert("RGB") | |
| # debug_img = img.copy() | |
| # draw = ImageDraw.Draw(debug_img) | |
| # w, h = img.size | |
| # log(f"Input image: {w} Γ {h} px") | |
| # debug_dir = "debug_crops" | |
| # if enable_debug_crops: | |
| # os.makedirs(debug_dir, exist_ok=True) | |
| # log(f"Debug crops will be saved to {debug_dir}/") | |
| # extracted = [] | |
| # # ── Line detection on full image ──────────────────────────────────────── | |
| # # Adaptive size based on image dimensions | |
| # max_dim = max(w, h) | |
| # if max_dim > 2200: | |
| # used_sz = 1280 | |
| # elif max_dim > 1400: | |
| # used_sz = 1024 | |
| # elif max_dim < 600: | |
| # used_sz = 640 | |
| # else: | |
| # used_sz = line_imgsz | |
| # log(f"Running line detection (imgsz={used_sz}, confβ₯{conf_thresh}) β¦") | |
| # res = line_model(img, conf=conf_thresh, imgsz=used_sz, verbose=False)[0] | |
| # boxes = res.boxes | |
| # log(f"β Detected {len(boxes)} line candidate(s)") | |
| # if len(boxes) == 0: | |
| # msg = "No text lines detected" | |
| # log(msg, "WARNING") | |
| # return debug_img, msg, "\n".join(logs) | |
| # # Sort top → bottom | |
| # ys = boxes.xyxy[:, 1].cpu().numpy() # y_min | |
| # order = ys.argsort() | |
| # for j, idx in enumerate(order, 1): | |
| # conf = float(boxes.conf[idx]) | |
| # x1, y1, x2, y2 = map(round, boxes.xyxy[idx].cpu().tolist()) | |
| # lw, lh = x2 - x1, y2 - y1 | |
| # log(f" Line {j}/{len(boxes)} conf={conf:.3f} {x1},{y1} β {x2},{y2} ({lw}Γ{lh})") | |
| # # Skip very small detections | |
| # if lw < 60 or lh < 20: | |
| # log(f" β skipped (too small)") | |
| # continue | |
| # draw.rectangle((x1, y1, x2, y2), outline="red", width=3) | |
| # line_crop = img.crop((x1, y1, x2, y2)) | |
| # if enable_debug_crops: | |
| # fname = f"{debug_dir}/line_{j:02d}_conf{conf:.2f}.png" | |
| # line_crop.save(fname) | |
| # text = run_ocr(line_crop) | |
| # log(f" OCR β '{text}'") | |
| # if text.strip(): | |
| # extracted.append(text) | |
| # # ── Finalize ──────────────────────────────────────────────────────────── | |
| # if not extracted: | |
| # msg = "No readable text found after OCR" | |
| # log(msg, "WARNING") | |
| # return debug_img, msg, "\n".join(logs) | |
| # log(f"Success β extracted {len(extracted)} line(s)") | |
| # if enable_debug_crops: | |
| # log(f"Debug crops saved to {debug_dir}/") | |
| # return debug_img, "\n".join(extracted), "\n".join(logs) | |
| # except Exception as e: | |
| # log(f"Processing failed: {e}", "ERROR") | |
| # logger.exception("Traceback:") | |
| # return debug_img, f"Error: {str(e)}", "\n".join(logs) | |
| # demo = gr.Interface( | |
| # fn=process_document, | |
| # inputs=[ | |
| # gr.Image(type="pil", label="Handwritten document"), | |
| # gr.Checkbox(label="Save debug crops", value=False), | |
| # gr.Slider(512, 1280, step=64, value=768, label="Line detection size (imgsz)"), | |
| # gr.Slider(0.15, 0.5, step=0.05, value=0.25, label="Confidence threshold"), | |
| # ], | |
| # outputs=[ | |
| # gr.Image(label="Debug (red = detected text lines)"), | |
| # gr.Textbox(label="Extracted Text", lines=10), | |
| # gr.Textbox(label="Detailed Logs (copy if alignment is wrong)", lines=16), | |
| # ], | |
| # title="Handwritten Line Detection + TrOCR", | |
| # description=( | |
| # "Red boxes = text lines detected by YOLO β sent to TrOCR for recognition\n\n" | |
| # "Use **Detailed Logs** to check coordinates, sizes & confidence values if results look off." | |
| # ), | |
| # theme=gr.themes.Soft(), | |
| # flagging_mode="never", | |
| # ) | |
| # if __name__ == "__main__": | |
| # logger.info("Launching interfaceβ¦") | |
| # demo.launch() | |
| # app.py - FIXED VERSION with empty crop protection | |
import logging

import gradio as gr
from ultralytics import YOLO
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch
import numpy as np
# Detection stage: local YOLO weights, one model for text regions and one for
# the text lines inside each region.
region_model = YOLO("regions.pt")
line_model = YOLO("lines.pt")

# Recognition stage: TrOCR processor and encoder-decoder. The model is moved to
# CUDA when torch reports it as available, otherwise it stays on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
model.to(device)
def _fit_mask_to_image(mask: np.ndarray, height: int, width: int) -> np.ndarray:
    """Resample a 2-D ``mask`` onto a ``height`` x ``width`` pixel grid.

    A mask that already has the requested shape is returned unchanged.
    Otherwise the mask is assumed to be letterboxed (aspect-preserving resize
    with a symmetric border), which is how Ultralytics lays out ``masks.data``
    when ``retina_masks`` is off. The border is stripped and the remainder is
    resampled with nearest-neighbour lookup. The border arithmetic mirrors
    ``ultralytics.utils.ops.scale_image``.

    Args:
        mask: 2-D array (bool or numeric).
        height: Target number of rows; must be > 0.
        width: Target number of columns; must be > 0.

    Returns:
        Array of shape ``(height, width)`` with the same dtype as ``mask``.
    """
    mask_h, mask_w = mask.shape[:2]
    if (mask_h, mask_w) == (height, width):
        return mask

    # Size of the symmetric letterbox border on each axis.
    gain = min(mask_h / height, mask_w / width)
    pad_y = (mask_h - height * gain) / 2
    pad_x = (mask_w - width * gain) / 2
    top = int(round(pad_y - 0.1))
    left = int(round(pad_x - 0.1))
    # max(...) keeps at least one row/column for extremely thin images.
    bottom = max(mask_h - int(round(pad_y + 0.1)), top + 1)
    right = max(mask_w - int(round(pad_x + 0.1)), left + 1)
    inner = mask[top:bottom, left:right]

    # Nearest-neighbour resample, sampling at target pixel centres.
    inner_h, inner_w = inner.shape[:2]
    rows = np.minimum((np.arange(height) + 0.5) * inner_h / height, inner_h - 1).astype(int)
    cols = np.minimum((np.arange(width) + 0.5) * inner_w / width, inner_w - 1).astype(int)
    return inner[np.ix_(rows, cols)]


def get_crop(image: Image.Image, result, idx: int, padding: int = 15):
    """Cut detection ``idx`` of ``result`` out of ``image`` with some padding.

    When ``result.masks`` is present, every pixel of the padded crop that lies
    outside the instance mask is painted white (255). Without masks, the padded
    bounding box is cropped as-is. ``image`` itself is never modified.

    Args:
        image: Source PIL image that ``result`` was predicted on.
        result: Detection result exposing ``boxes.xyxy`` and optionally
            ``masks.data`` (tensors supporting ``.cpu().numpy()``).
        idx: Index of the detection to crop.
        padding: Pixels added on each side of the box, clipped to the image.

    Returns:
        The cropped PIL image, or ``None`` when the crop would be empty.
    """
    # np.array() copies, so whitening below cannot leak into `image`.
    img_np = np.array(image)
    img_h, img_w = img_np.shape[:2]
    if img_h == 0 or img_w == 0:
        return None

    x1, y1, x2, y2 = result.boxes.xyxy[idx].cpu().numpy().astype(int).tolist()

    # Padded box, clipped to the image bounds.
    px1 = max(0, x1 - padding)
    py1 = max(0, y1 - padding)
    px2 = min(img_w, x2 + padding)
    py2 = min(img_h, y2 + padding)
    if px2 <= px1 or py2 <= py1:
        return None

    if result.masks is None:
        # Bounding box fallback (no mask).
        return image.crop((px1, py1, px2, py2))

    # A box with no area inside the image has nothing to mask.
    if min(x2, img_w) <= max(x1, 0) or min(y2, img_h) <= max(y1, 0):
        return None

    # Box coordinates are in original-image pixels, but the mask tensor may be
    # at the network's letterboxed input resolution. Bring it onto the image
    # grid before slicing, otherwise the slice is misaligned or mis-shaped.
    mask = result.masks.data[idx].cpu().numpy()
    mask_bool = _fit_mask_to_image(mask, img_h, img_w) > 0.5

    padded_img = img_np[py1:py2, px1:px2]
    padded_mask = mask_bool[py1:py2, px1:px2]
    # Set background (outside mask) to white.
    padded_img[~padded_mask] = 255
    return Image.fromarray(padded_img)
def process_image(image: Image.Image):
    """Run region detection, line detection and TrOCR on one document image.

    Regions are read top to bottom. Within a region, lines are ordered by the
    top-left corner of their box (y first, then x). Crops that are missing or
    smaller than 10x8 px are skipped. A line whose recognition raises is logged
    and skipped, so one bad line does not abort the document.

    Args:
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        The recognized text, with lines separated by ``"\\n"`` and regions by a
        blank line, or a human-readable status message when nothing could be
        recognized.
    """
    if image is None:
        return "Please upload an image."

    # Normalise to 3-channel RGB so crops handed to TrOCR have a consistent
    # layout even when the function is called with RGBA / grayscale input.
    if image.mode != "RGB":
        image = image.convert("RGB")

    log = logging.getLogger(__name__)

    region_result = region_model(image)[0]
    if region_result.boxes is None or len(region_result.boxes) == 0:
        return "No text regions detected."

    regions_with_pos = []
    for i, region_box in enumerate(region_result.boxes.xyxy):
        crop = get_crop(image, region_result, i, padding=20)
        if crop is not None and crop.size[0] > 0 and crop.size[1] > 0:
            regions_with_pos.append((region_box[1].item(), crop))

    if not regions_with_pos:
        return "No valid text regions after cropping."

    # Key on y only: the crops themselves are not orderable.
    regions_with_pos.sort(key=lambda entry: entry[0])

    full_text_parts = []
    for region_idx, (_, region_crop) in enumerate(regions_with_pos):
        line_result = line_model(region_crop)[0]
        if line_result.boxes is None or len(line_result.boxes) == 0:
            continue

        lines_with_pos = []
        for j, line_box in enumerate(line_result.boxes.xyxy):
            line_crop = get_crop(region_crop, line_result, j, padding=15)
            if line_crop is None or line_crop.size[0] < 10 or line_crop.size[1] < 8:
                # Tiny/invalid crops are skipped before they reach TrOCR.
                log.debug("Skipped tiny line %d in region %d", j, region_idx)
                continue

            try:
                pixel_values = processor(line_crop, return_tensors="pt").pixel_values.to(device)
                generated_ids = model.generate(pixel_values)
                text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
            except Exception:
                # Best effort: keep going, but leave a trace of what failed.
                log.warning(
                    "TrOCR failed on line %d of region %d; skipping",
                    j,
                    region_idx,
                    exc_info=True,
                )
                continue

            if text:  # only add non-empty
                lines_with_pos.append((line_box[1].item(), line_box[0].item(), text))

        lines_with_pos.sort(key=lambda entry: (entry[0], entry[1]))
        region_text = "\n".join(text for _, _, text in lines_with_pos)
        if region_text:
            full_text_parts.append(region_text)

    if not full_text_parts:
        return "No readable text recognized (possibly due to small/tiny lines or model limitations). Try a clearer document or larger padding."

    return "\n\n".join(full_text_parts)
# Gradio front end: a single image upload in, the recognized text out.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload handwritten document"),
    outputs=gr.Textbox(label="Recognized Text"),
    title="Handwritten Text Recognition (YOLO + TrOCR)",
    description=(
        "Local models: regions.pt / lines.pt + microsoft/trocr-base-handwritten. "
        "Mask-based cropping + safeguards against empty crops."
    ),
    flagging_mode="never",
)


# Start the web UI only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()