# exam-notes-ocr / main.py
# Author: Ayan8901 — commit 2fc8916 ("Update main.py")
from fastapi import FastAPI, UploadFile, File, Form
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from PIL import Image, ImageOps, ImageEnhance, ImageFilter
import tempfile
import io
import os
import numpy as np
import cv2
# FastAPI application instance; the endpoints below register themselves on it.
app = FastAPI()

# Load both OCR models once at import time so requests don't pay the startup cost.
print("Loading docTR OCR model...")
doctr_model = ocr_predictor(pretrained=True)  # printed-text OCR (docTR)
print("docTR ready.")

print("Loading TrOCR handwriting model...")
# NOTE(review): transformers is imported here, mid-file, so the progress
# message above is printed before the (slow) model download/load begins.
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")
print("TrOCR ready.")
def fix_image_orientation(img: Image.Image) -> Image.Image:
    """Apply the EXIF orientation tag, then force portrait orientation.

    Landscape images (width > height) are rotated 90 degrees so every
    page reaches the OCR pipeline upright.
    """
    # Honour the camera's EXIF orientation when present; a missing or
    # malformed tag must not break the request, so failures are ignored.
    try:
        img = ImageOps.exif_transpose(img)
    except Exception:
        pass

    width, height = img.size
    if width > height:
        # Landscape -> portrait; expand=True grows the canvas to fit.
        return img.rotate(90, expand=True)
    return img
def deskew(img: Image.Image) -> Image.Image:
    """Auto-correct small tilts using OpenCV"""
    try:
        cv_img = np.array(img)
        gray = cv2.cvtColor(cv_img, cv2.COLOR_RGB2GRAY)
        # Gather positions of every "dark" pixel (< 200) and treat them as the
        # text mass whose minimum-area bounding box reveals the skew angle.
        # NOTE(review): np.where yields (row, col) i.e. (y, x) order, not the
        # (x, y) point order cv2.minAreaRect usually receives, and the
        # reported angle range changed in OpenCV 4.5+ — confirm the sign
        # convention matches the intended correction direction.
        coords = np.column_stack(np.where(gray < 200))
        if len(coords) > 0:
            angle = cv2.minAreaRect(coords)[-1]
            # Map minAreaRect's rotated-frame angle back into a small-tilt range.
            if angle < -45:
                angle = 90 + angle
            # Only bother rotating when the tilt is noticeable (> 0.5 degrees).
            if abs(angle) > 0.5:
                (h, w) = cv_img.shape[:2]
                center = (w // 2, h // 2)
                M = cv2.getRotationMatrix2D(center, angle, 1.0)
                # BORDER_REPLICATE avoids black wedges at the rotated corners.
                cv_img = cv2.warpAffine(cv_img, M, (w, h),
                                        flags=cv2.INTER_CUBIC,
                                        borderMode=cv2.BORDER_REPLICATE)
        return Image.fromarray(cv_img)
    except Exception:
        # Best-effort: any OpenCV failure leaves the original image untouched.
        return img
def enhance_image(img: Image.Image) -> Image.Image:
    """Preprocess a printed-text photo for docTR: deskew, then boost
    contrast, sharpness and brightness."""
    straightened = deskew(img)
    contrasted = ImageEnhance.Contrast(straightened).enhance(1.4)
    sharpened = ImageEnhance.Sharpness(contrasted).enhance(2.0)
    return ImageEnhance.Brightness(sharpened).enhance(1.1)
def enhance_for_handwriting(img: Image.Image) -> Image.Image:
    """Preprocess a handwriting photo for TrOCR: deskew, then push
    contrast and sharpness hard on a grayscale copy."""
    work = deskew(img).convert("L")  # grayscale emphasises stroke edges
    work = ImageEnhance.Contrast(work).enhance(2.0)
    work = ImageEnhance.Sharpness(work).enhance(2.5)
    work = work.filter(ImageFilter.SHARPEN)
    # Restore three channels for the downstream model pipeline.
    return work.convert("RGB")
def run_trocr(img: Image.Image, chunk_height: int = 100, min_chunk_height: int = 20) -> str:
    """Run Microsoft TrOCR on a handwritten image, strip by strip.

    TrOCR is a line-level recognizer, so the page is sliced into horizontal
    strips and each strip is decoded independently.

    Args:
        img: RGB page image.
        chunk_height: Height in pixels of each horizontal strip
            (previously a hard-coded 100).
        min_chunk_height: Strips shorter than this — e.g. the final sliver
            at the bottom of the page — are skipped (previously 20).

    Returns:
        Recognized text, one strip per line; empty strips are omitted.
    """
    import torch  # local import: torch is only needed at inference time

    width, height = img.size
    lines = []
    for top in range(0, height, chunk_height):
        strip = img.crop((0, top, width, min(top + chunk_height, height)))
        if strip.size[1] < min_chunk_height:
            continue  # too short to contain readable text
        pixel_values = trocr_processor(images=strip, return_tensors="pt").pixel_values
        with torch.no_grad():  # inference only — no gradients needed
            generated_ids = trocr_model.generate(pixel_values)
        text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        if text:
            lines.append(text)
    return "\n".join(lines)
@app.get("/")
def root():
    """Health-check endpoint confirming the service is up."""
    return {"status": "OCR running"}
def _run_doctr_ocr(img: Image.Image) -> str:
    """Save *img* to a temporary JPEG, run docTR on it, return rendered text.

    The temp file is always removed, even when docTR raises — previously the
    file leaked on failure because cleanup ran only on the success path.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
        img.save(tmp.name, quality=95)
        temp_path = tmp.name
    try:
        doc = DocumentFile.from_images([temp_path])
        result = doctr_model(doc)
        return result.render().strip()
    finally:
        os.remove(temp_path)


@app.post("/ocr")
async def ocr_images(
    file: UploadFile = File(...),
    mode: str = Form("print")
):
    """OCR an uploaded image.

    Args:
        file: The uploaded image.
        mode: "handwrite" → TrOCR only; "print" → docTR only; any other
            value (auto) → docTR first, falling back to TrOCR when docTR
            yields fewer than 150 characters.

    Returns:
        ``{"success": True, "text": ...}`` on success, otherwise
        ``{"success": False, "error": ...}``.
    """
    try:
        contents = await file.read()
        pil_image = Image.open(io.BytesIO(contents)).convert("RGB")
        pil_image = fix_image_orientation(pil_image)  # fixes landscape + EXIF

        # ── HANDWRITING MODE ──────────────────────────────────────
        if mode == "handwrite":
            text = run_trocr(enhance_for_handwriting(pil_image))

        # ── PRINT MODE ────────────────────────────────────────────
        elif mode == "print":
            text = _run_doctr_ocr(enhance_image(pil_image))

        # ── AUTO MODE ────────────────────────────────────────────
        else:
            text = _run_doctr_ocr(enhance_image(pil_image.copy()))
            # Heuristic: a sparse docTR result suggests handwriting.
            if len(text) < 150:
                text = run_trocr(enhance_for_handwriting(pil_image.copy()))

        if not text or len(text) < 10:
            return {
                "success": False,
                "error": "No text found in image. Try a clearer photo with visible text."
            }
        return {"success": True, "text": text}
    except Exception as e:
        # API boundary: surface any failure as a JSON error payload.
        return {"success": False, "error": str(e)}