# FastAPI OCR service: docTR for printed text, TrOCR for handwriting.
import io
import os
import tempfile

import cv2
import numpy as np
from fastapi import FastAPI, File, Form, UploadFile
from PIL import Image, ImageEnhance, ImageFilter, ImageOps

from doctr.io import DocumentFile
from doctr.models import ocr_predictor
app = FastAPI()

# Both OCR backends are loaded once at import time so every request
# reuses the warm models instead of paying the load cost per call.
print("Loading docTR OCR model...")
doctr_model = ocr_predictor(pretrained=True)
print("docTR ready.")

print("Loading TrOCR handwriting model...")
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")
print("TrOCR ready.")
def fix_image_orientation(img: Image.Image) -> Image.Image:
    """Fix EXIF rotation AND force portrait if landscape"""
    # Honor the camera's EXIF orientation tag when present; best-effort,
    # a malformed tag must not abort the request.
    try:
        img = ImageOps.exif_transpose(img)
    except Exception:
        pass

    # Landscape photos are rotated 90 degrees so downstream OCR always
    # receives a portrait-oriented page.
    width, height = img.size
    if width > height:
        img = img.rotate(90, expand=True)
    return img
def deskew(img: Image.Image) -> Image.Image:
    """Auto-correct small tilts using OpenCV"""
    try:
        cv_img = np.array(img)
        gray = cv2.cvtColor(cv_img, cv2.COLOR_RGB2GRAY)
        # Coordinates of "ink" pixels (gray value below 200); np.where yields
        # (row, col) order, stacked into an (N, 2) point array.
        coords = np.column_stack(np.where(gray < 200))
        if len(coords) > 0:
            # minAreaRect returns ((cx, cy), (w, h), angle). NOTE(review): the
            # angle convention changed in OpenCV 4.5 (pre-4.5: [-90, 0);
            # 4.5+: [0, 90)); the -45 correction below matches the older
            # convention — confirm against the installed OpenCV version.
            angle = cv2.minAreaRect(coords)[-1]
            if angle < -45:
                angle = 90 + angle
            # Only rotate for noticeable skew (> 0.5 deg) to avoid needless
            # resampling blur on already-straight images.
            if abs(angle) > 0.5:
                (h, w) = cv_img.shape[:2]
                center = (w // 2, h // 2)
                M = cv2.getRotationMatrix2D(center, angle, 1.0)
                cv_img = cv2.warpAffine(cv_img, M, (w, h),
                                        flags=cv2.INTER_CUBIC,
                                        borderMode=cv2.BORDER_REPLICATE)
        return Image.fromarray(cv_img)
    except Exception:
        # Best-effort: on any failure, return the image unchanged rather
        # than failing the whole OCR request.
        return img
def enhance_image(img: Image.Image) -> Image.Image:
    """Preprocess for printed-text OCR: deskew, then boost contrast,
    sharpness and brightness."""
    img = deskew(img)
    # Apply each enhancement with its tuned factor in order.
    for enhancer_cls, factor in (
        (ImageEnhance.Contrast, 1.4),
        (ImageEnhance.Sharpness, 2.0),
        (ImageEnhance.Brightness, 1.1),
    ):
        img = enhancer_cls(img).enhance(factor)
    return img
def enhance_for_handwriting(img: Image.Image) -> Image.Image:
    """Preprocess for handwriting OCR: deskew, grayscale, then aggressive
    contrast and sharpening."""
    work = deskew(img).convert("L")  # grayscale pass
    work = ImageEnhance.Contrast(work).enhance(2.0)
    work = ImageEnhance.Sharpness(work).enhance(2.5)
    work = work.filter(ImageFilter.SHARPEN)
    # Back to 3 channels for the downstream model input.
    return work.convert("RGB")
def run_trocr(img: Image.Image) -> str:
    """Run Microsoft TrOCR for handwritten images — processes in chunks"""
    import torch

    full_width, full_height = img.size
    strip_height = 100  # fixed-height horizontal strips fed to TrOCR one at a time
    lines = []

    top = 0
    while top < full_height:
        strip = img.crop((0, top, full_width, min(top + strip_height, full_height)))
        top += strip_height
        # Skip slivers too short to contain a text line.
        if strip.size[1] < 20:
            continue
        pixel_values = trocr_processor(images=strip, return_tensors="pt").pixel_values
        with torch.no_grad():
            generated_ids = trocr_model.generate(pixel_values)
        decoded = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        if decoded:
            lines.append(decoded)

    return "\n".join(lines)
@app.get("/")
def root():
    """Health-check endpoint.

    NOTE(review): the original handler had no route decorator, so it was
    never registered with FastAPI and unreachable. "/" is the conventional
    health path — confirm against the deployed clients.
    """
    return {"status": "OCR running"}
# NOTE(review): the original handler had no route decorator, so it was never
# registered with FastAPI — "/ocr" is an assumed path; confirm with callers.
@app.post("/ocr")
async def ocr_images(
    file: UploadFile = File(...),
    mode: str = Form("print")
):
    """OCR an uploaded image.

    mode="handwrite": TrOCR only. mode="print": docTR only. Any other value
    ("auto"): docTR first, falling back to TrOCR when the printed-text result
    is short (< 150 chars, suggesting handwriting).

    Returns {"success": True, "text": ...} or {"success": False, "error": ...};
    all exceptions are reported in-band rather than as HTTP errors.
    """
    try:
        contents = await file.read()
        pil_image = Image.open(io.BytesIO(contents)).convert("RGB")
        pil_image = fix_image_orientation(pil_image)  # EXIF + landscape fix

        if mode == "handwrite":
            text = run_trocr(enhance_for_handwriting(pil_image))
        elif mode == "print":
            text = _run_doctr(enhance_image(pil_image))
        else:
            # Auto mode: printed-text OCR first, handwriting as fallback.
            text = _run_doctr(enhance_image(pil_image.copy()))
            if len(text) < 150:
                text = run_trocr(enhance_for_handwriting(pil_image.copy()))

        if not text or len(text) < 10:
            return {
                "success": False,
                "error": "No text found in image. Try a clearer photo with visible text."
            }
        return {"success": True, "text": text}
    except Exception as e:
        return {"success": False, "error": str(e)}


def _run_doctr(img: Image.Image) -> str:
    """Run docTR on a PIL image via a temporary JPEG and return rendered text.

    Extracted from the duplicated print/auto branches. The temp file is
    removed in a finally block — the original leaked it whenever docTR raised.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
        img.save(tmp.name, quality=95)
        temp_path = tmp.name
    try:
        doc = DocumentFile.from_images([temp_path])
        result = doctr_model(doc)
        return result.render().strip()
    finally:
        os.remove(temp_path)