"""FastAPI OCR service: docTR for printed text, TrOCR for handwriting."""

import io
import logging
import math
import os
import tempfile
import threading

import cv2
import numpy as np
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image, ImageEnhance, ImageFilter, ImageOps
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

logger = logging.getLogger(__name__)

app = FastAPI()

print("Loading docTR OCR model...")
doctr_model = ocr_predictor(pretrained=True)
print("docTR ready.")

print("Loading TrOCR handwriting model...")
trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")
print("TrOCR ready.")

# Requests are processed in worker threads (see ocr_images); this lock keeps
# model inference serialized, as it was when it ran on the event loop, so
# concurrent requests do not load several large models at the same time.
_inference_lock = threading.Lock()


def fix_image_orientation(img: Image.Image) -> Image.Image:
    """Apply the EXIF orientation tag, then force a portrait layout.

    Images wider than they are tall are rotated 90 degrees counter-clockwise
    (PIL's ``rotate`` direction), on the assumption that they are sideways
    photos of portrait pages.

    NOTE(review): this also rotates documents that really are landscape
    (wide receipts, screenshots), and a photo taken sideways the other way
    ends up upside down -- confirm this heuristic is wanted.
    """
    try:
        img = ImageOps.exif_transpose(img)
    except Exception:
        # Best effort: unreadable EXIF metadata must not fail the request.
        logger.debug("EXIF transpose failed; keeping pixel orientation", exc_info=True)
    width, height = img.size
    if width > height:
        img = img.rotate(90, expand=True)
    return img


def _skew_angle(mask: np.ndarray) -> float:
    """Return the tilt of the True pixels in *mask*, in degrees within [-45, 45).

    The angle is measured in image coordinates (y axis pointing down), so a
    positive value means the content slopes downward to the right. It is
    derived from the corners of the minimum-area bounding rectangle rather
    than from ``minAreaRect``'s reported angle, whose range and sign
    convention changed between OpenCV releases.
    """
    ys, xs = np.nonzero(mask)
    # OpenCV expects (x, y) points; np.nonzero yields (row, col) = (y, x).
    points = np.column_stack((xs, ys)).astype(np.float32)
    corners = cv2.boxPoints(cv2.minAreaRect(points))
    # Two adjacent (perpendicular) edges; prefer the longer so a degenerate,
    # zero-length edge (e.g. all ink on one line) cannot decide the angle.
    edges = (corners[1] - corners[0], corners[2] - corners[1])
    dx, dy = max(edges, key=lambda edge: math.hypot(float(edge[0]), float(edge[1])))
    angle = math.degrees(math.atan2(float(dy), float(dx)))
    # Perpendicular edges and reversed directions are equivalent modulo 90.
    return (angle + 45.0) % 90.0 - 45.0


def deskew(
    img: Image.Image,
    *,
    ink_threshold: int = 200,
    min_angle: float = 0.5,
) -> Image.Image:
    """Straighten a small tilt in an RGB image by rotating about its centre.

    Args:
        img: RGB image (``np.array(img)`` must be H x W x 3).
        ink_threshold: grayscale values below this count as ink.
        min_angle: tilts of at most this many degrees are left alone.

    Returns:
        The rotated image (same size; borders filled by edge replication),
        or *img* unchanged when there is no ink, the tilt is below
        *min_angle*, or any OpenCV/NumPy step fails (best effort).
    """
    try:
        pixels = np.array(img)
        gray = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
        ink = gray < ink_threshold
        if not ink.any():
            return img
        angle = _skew_angle(ink)
        if abs(angle) <= min_angle:
            return img
        h, w = pixels.shape[:2]
        # Positive angles rotate counter-clockwise on screen, which undoes a
        # downward-to-the-right slope.
        matrix = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1.0)
        rotated = cv2.warpAffine(
            pixels,
            matrix,
            (w, h),
            flags=cv2.INTER_CUBIC,
            borderMode=cv2.BORDER_REPLICATE,
        )
        return Image.fromarray(rotated)
    except Exception:
        logger.debug("deskew failed; using the unrotated image", exc_info=True)
        return img


def enhance_image(img: Image.Image) -> Image.Image:
    """Prepare a photo for printed-text OCR.

    Deskews, then boosts contrast (x1.4), sharpness (x2.0) and brightness
    (x1.1). Returns a new image; the input is not modified.
    """
    img = deskew(img)
    img = ImageEnhance.Contrast(img).enhance(1.4)
    img = ImageEnhance.Sharpness(img).enhance(2.0)
    img = ImageEnhance.Brightness(img).enhance(1.1)
    return img


def enhance_for_handwriting(img: Image.Image) -> Image.Image:
    """Prepare a photo for handwriting recognition.

    Deskews, converts to grayscale, boosts contrast (x2.0) and sharpness
    (x2.5), applies an extra SHARPEN filter, and converts back to RGB so
    downstream models receive 3-channel input.
    """
    img = deskew(img)
    img = img.convert("L")
    img = ImageEnhance.Contrast(img).enhance(2.0)
    img = ImageEnhance.Sharpness(img).enhance(2.5)
    img = img.filter(ImageFilter.SHARPEN)
    img = img.convert("RGB")
    return img


def run_trocr(
    img: Image.Image,
    chunk_height: int = 100,
    min_chunk_height: int = 20,
) -> str:
    """Recognize handwriting with TrOCR by scanning fixed-height horizontal strips.

    Args:
        img: RGB image to read.
        chunk_height: height in pixels of each full-width strip.
        min_chunk_height: strips shorter than this (only the final one can
            be) are skipped.

    Returns:
        The non-empty decoded strip texts, top to bottom, joined by newlines
        (empty string if none).

    Raises:
        ValueError: if *chunk_height* is not positive.
    """
    import torch  # deferred, as in the original; transformers already loads it

    if chunk_height <= 0:
        raise ValueError("chunk_height must be positive")

    width, height = img.size
    lines = []
    with _inference_lock, torch.no_grad():
        for top in range(0, height, chunk_height):
            bottom = min(top + chunk_height, height)
            if bottom - top < min_chunk_height:
                continue
            strip = img.crop((0, top, width, bottom))
            pixel_values = trocr_processor(images=strip, return_tensors="pt").pixel_values
            generated_ids = trocr_model.generate(pixel_values)
            text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
            if text:
                lines.append(text)
    return "\n".join(lines)


def _run_doctr(img: Image.Image) -> str:
    """Run docTR on a single page and return its rendered text, stripped."""
    # DocumentFile.from_images yields a list of H x W x 3 uint8 RGB arrays;
    # passing the same kind of array directly avoids a temp file (previously
    # leaked on errors) and a lossy JPEG re-encode.
    page = np.array(img.convert("RGB"))
    with _inference_lock:
        result = doctr_model([page])
    return result.render().strip()


def _ocr_bytes(contents: bytes, mode: str) -> dict:
    """Decode an uploaded image and OCR it synchronously.

    Args:
        contents: raw bytes of the uploaded image file.
        mode: "handwrite" (TrOCR), "print" (docTR); any other value selects
            auto mode (docTR, falling back to TrOCR when docTR finds fewer
            than 150 characters).

    Returns:
        ``{"success": True, "text": ...}`` or, when fewer than 10 characters
        were recognized, ``{"success": False, "error": ...}``.
    """
    image = Image.open(io.BytesIO(contents))
    # Fix orientation on the decoded image itself, then normalize to RGB.
    image = fix_image_orientation(image).convert("RGB")

    if mode == "handwrite":
        text = run_trocr(enhance_for_handwriting(image))
    elif mode == "print":
        text = _run_doctr(enhance_image(image))
    else:
        text = _run_doctr(enhance_image(image))
        if len(text) < 150:
            # Little printed text found: try the handwriting model, but keep
            # docTR's output if TrOCR recognizes nothing at all.
            text = run_trocr(enhance_for_handwriting(image)) or text

    if len(text) < 10:
        return {
            "success": False,
            "error": "No text found in image. Try a clearer photo with visible text.",
        }
    return {"success": True, "text": text}


@app.get("/")
def root():
    """Health check."""
    return {"status": "OCR running"}


@app.post("/ocr")
async def ocr_images(
    file: UploadFile = File(...),
    mode: str = Form("print"),
):
    """OCR an uploaded image.

    ``mode`` is "print" (default), "handwrite", or anything else for auto.
    Always answers with a JSON object carrying a boolean ``success`` and
    either ``text`` or ``error``; failures are logged server-side.
    """
    try:
        contents = await file.read()
        # Model inference is synchronous and slow; running it in a worker
        # thread keeps the event loop (and endpoints like "/") responsive.
        return await run_in_threadpool(_ocr_bytes, contents, mode)
    except Exception as e:
        logger.exception("OCR request failed")
        return {"success": False, "error": str(e)}