# Marathi_OCR / app.py
# Commit cf7d335 ("Update app.py", verified) — OakMajesty
import fastapi
from fastapi import File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from PIL import Image
import numpy as np
import cv2
import torch
import pytesseract
import re
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, pipeline
import io
import os
import logging
import google.generativeai as genai # <--- Added Import
# --- Logging ---
# Configure logging first. Any earlier module-level logging.* call (such as the
# Gemini warning below) would install a default WARNING-level root handler and
# turn this basicConfig(level=INFO) into a no-op.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ocr_api")

# --- Configuration ---
# Ensure the TESSDATA_PREFIX environment variable is set if your Tesseract data is in a non-standard location.
# os.environ["TESSDATA_PREFIX"] = "/path/to/tesseract/data"

# --- Gemini Configuration ---
# The key is read from the GEMINI_API_KEY environment variable. Never commit API
# keys to source control. A key was previously hard-coded here; treat it as
# leaked and revoke/rotate it.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "").strip()
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    # Gemini is the *primary* Marathi engine; without a key every request uses
    # the local Tesseract/TrOCR/MahaBERT fallback.
    logger.warning("GEMINI_API_KEY is not set; Gemini OCR will be skipped and local models used.")
# --- Model Loading ---
# All models are loaded once, at import time, and shared by every request.
# A failure is logged with its traceback and re-raised, so the server refuses
# to start with a partially initialised model set.
try:
    logger.info("Loading TrOCR model (microsoft/trocr-base-stage1)...")
    # TrOCR model (Base stage1 is pre-trained on English, so it works well for English too).
    # NOTE(review): because the checkpoint is English-pretrained, its output on
    # Devanagari images is presumably unreliable — confirm before trusting it for Marathi.
    trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-stage1")
    trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")
    trocr_model.eval()  # inference only: disables dropout etc.
    logger.info("TrOCR model loaded")

    # MahaBERT model for text classification (Marathi specific).
    logger.info("Loading MahaBERT text-classification pipeline...")
    mahabert_model = pipeline(
        "text-classification",
        model="Abhi964/MahaPhrase_mahaBERTv2_Finetuning",
        tokenizer="Abhi964/MahaPhrase_mahaBERTv2_Finetuning",
    )
    logger.info("MahaBERT pipeline ready")

    # Use the GPU if available. Only TrOCR is moved explicitly. No device is
    # passed to pipeline() above, so the classifier's placement is left to
    # transformers' default.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    trocr_model.to(device)
    logger.info("TrOCR model moved to device %s", device)
except Exception:
    logger.exception("Failed to load models during startup")
    raise
# --- FastAPI App Initialization ---
app = fastapi.FastAPI(
    title="OCR API (Marathi & English)",
    description="An API to perform OCR on images for Marathi and English text.",
    version="1.2.0",
    docs_url="/docs",
)

# CORS: every origin, method and header is accepted, and credentials are allowed.
# NOTE(review): a wildcard origin list combined with allow_credentials=True is
# very permissive. Nothing in this file uses cookies or auth headers, so
# allow_credentials=False or an explicit origin list is probably safer — confirm
# with the frontend before tightening.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# --- Helper Functions ---
def preprocess_image(image: Image.Image):
    """Binarize a PIL image so Pytesseract gets a clean black/white input.

    Steps: force 3-channel RGB, convert to grayscale, apply a 3x3 Gaussian blur
    to suppress speckle noise, then apply an Otsu global threshold.

    Forcing RGB first matters. Feeding ``np.array(image)`` straight into an
    RGB->BGR conversion fails for single-channel uploads (grayscale "L",
    palette "P", bilevel "1"), and palette indices are not intensities anyway.

    Args:
        image: Any PIL image, in any mode.

    Returns:
        A 2-D uint8 NumPy array whose pixels are 0 or 255.
    """
    rgb = np.array(image.convert("RGB"))
    # RGB->GRAY directly is equivalent to the former RGB->BGR->GRAY round trip.
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    blurred = cv2.GaussianBlur(gray, (3, 3), 0)
    # Otsu picks the threshold automatically; the 0 passed here is ignored.
    _, thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return thresh
# Characters kept by normalize_marathi_text:
# - the Devanagari block (U+0900–U+097F), which includes Devanagari digits and the danda;
# - ZWNJ/ZWJ (U+200C/U+200D), which Marathi uses to control conjunct shaping
#   (e.g. the eyelash-ra in "तर्‍हा") — stripping them changes how words render;
# - whitespace.
_NON_MARATHI_RE = re.compile(r'[^\u0900-\u097F\u200C\u200D\s]')
_MARATHI_WS_RE = re.compile(r'\s+')


def normalize_marathi_text(text: str):
    """Clean and normalize Marathi text.

    Drops every character that is not Devanagari, ZWNJ/ZWJ or whitespace. This
    includes Latin letters, ASCII digits and ASCII punctuation. Runs of
    whitespace are then collapsed to a single space, and the result is stripped.

    Args:
        text: Raw text, e.g. OCR output.

    Returns:
        The normalized string ("" if nothing survives the filter).
    """
    kept = _NON_MARATHI_RE.sub('', text)
    return _MARATHI_WS_RE.sub(' ', kept).strip()
# Anything other than ASCII letters/digits, whitespace and . , ! ? ' " - is dropped.
_ENGLISH_DISALLOWED_RE = re.compile(r'[^a-zA-Z0-9\s.,!?\'"-]')
_ENGLISH_WS_RE = re.compile(r'\s+')


def normalize_english_text(text: str):
    """Reduce *text* to basic English characters with tidy spacing.

    Characters outside the allowed set are deleted, not replaced: for example
    ':', ';', '(' and non-Latin scripts. Runs of whitespace become a single
    space, and leading/trailing whitespace is removed.

    Args:
        text: Raw text, e.g. Tesseract output.

    Returns:
        The cleaned string.
    """
    filtered = _ENGLISH_DISALLOWED_RE.sub('', text)
    single_spaced = _ENGLISH_WS_RE.sub(' ', filtered)
    return single_spaced.strip()
def run_trocr_inference(image: Image.Image):
    """Transcribe *image* with the module-level TrOCR model.

    The image is converted to RGB and encoded by ``trocr_processor`` into pixel
    values on ``device``. Token ids are generated with ``trocr_model.generate``
    under ``torch.no_grad()``, then decoded without special tokens.

    Returns:
        The decoded text of the first (only) sequence in the batch.
    """
    encoded = trocr_processor(images=image.convert("RGB"), return_tensors="pt")
    pixels = encoded.pixel_values.to(device)
    with torch.no_grad():
        token_ids = trocr_model.generate(pixels)
    texts = trocr_processor.batch_decode(token_ids, skip_special_tokens=True)
    return texts[0]
# --- API Endpoints ---
# Gemini model used as the primary Marathi OCR engine. It is also reported to
# clients in the "source_model" field of Gemini responses.
_GEMINI_OCR_MODEL = "gemini-2.5-flash"
_GEMINI_MARATHI_PROMPT = (
    "Extract all the Marathi text from this image. Return only the extracted text, "
    "do not add markdown or explanations."
)
# Token cap for the MahaBERT classifier. Assumes a BERT-base-style encoder with
# 512 positions — TODO confirm against the model config.
_MAHABERT_MAX_TOKENS = 512


def _marathi_ocr_gemini(img_pil):
    """Return Gemini's transcription of the Marathi text in *img_pil*, or None.

    None means Gemini is not configured or the call raised. Any exception from
    the Gemini client is logged and swallowed on purpose, so the caller can
    fall back to the local models.
    """
    if not GEMINI_API_KEY:
        logger.info("Gemini API key not configured; skipping Gemini OCR.")
        return None
    try:
        logger.info("Attempting Gemini OCR...")
        model = genai.GenerativeModel(_GEMINI_OCR_MODEL)
        response = model.generate_content([_GEMINI_MARATHI_PROMPT, img_pil])
        text = response.text.strip()
    except Exception:  # best-effort primary path; the fallback handles failures
        logger.warning("Gemini OCR failed. Falling back to local models.", exc_info=True)
        return None
    logger.info("Gemini OCR successful.")
    return text


def _marathi_ocr_local(img_pil):
    """Run Tesseract (lang="mar"), MahaBERT and TrOCR locally; return the response fields."""
    logger.info("Running Fallback OCR (Tesseract/TrOCR)...")
    preprocessed_img = preprocess_image(img_pil)
    raw_ocr_text = pytesseract.image_to_string(preprocessed_img, lang="mar")
    normalized_text = normalize_marathi_text(raw_ocr_text)
    # Truncate so that a long page of OCR text cannot exceed the encoder's
    # position limit and fail the whole request.
    mahabert_result = mahabert_model(
        raw_ocr_text, truncation=True, max_length=_MAHABERT_MAX_TOKENS
    )
    trocr_text = run_trocr_inference(img_pil)
    return {
        "pytesseract_raw_text": raw_ocr_text,
        "normalized_text": normalized_text,
        "trocr_handwritten_text": trocr_text,
        "classification": mahabert_result,
    }


@app.post("/ocr/marathi/")
async def marathi_ocr(file: UploadFile = File(...)):
    """
    Extract Marathi text from an uploaded image.

    Priority 1: the Gemini multimodal API (when GEMINI_API_KEY is configured).
    Priority 2 (fallback): Tesseract + TrOCR + MahaBERT, run locally.

    Raises:
        HTTPException(400): the upload is not declared as an image, or cannot be decoded.
        HTTPException(500): any other processing failure.

    NOTE(review): the OCR work is synchronous and runs inside this async
    handler, so it blocks the event loop while it runs. Consider offloading it
    once the shared models are confirmed thread-safe.
    """
    from PIL import UnidentifiedImageError

    # content_type is None when the multipart part carries no Content-Type header.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File provided is not an image.")
    try:
        contents = await file.read()
        try:
            img_pil = Image.open(io.BytesIO(contents))
        except UnidentifiedImageError:
            raise HTTPException(status_code=400, detail="File provided is not a valid image.") from None

        gemini_text = _marathi_ocr_gemini(img_pil)
        if gemini_text is not None:
            return JSONResponse(content={
                "language": "marathi",
                "filename": file.filename,
                "source_model": _GEMINI_OCR_MODEL,
                "extracted_text": gemini_text,
                "normalized_text": normalize_marathi_text(gemini_text),
            })

        return JSONResponse(content={
            "language": "marathi",
            "filename": file.filename,
            "source_model": "local_fallback",
            **_marathi_ocr_local(img_pil),
        })
    except HTTPException:
        raise  # already carries the intended status code
    except Exception as e:
        logger.exception("Marathi OCR Error: %s", e)
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}") from e
@app.post("/extract-english/")
async def english_ocr(file: UploadFile = File(...)):
    """
    Extract English text from an uploaded image.

    Runs Tesseract (lang="eng") on a binarized copy of the image, and TrOCR on
    the original image. Returns both transcriptions plus a normalized version
    of the Tesseract output.

    Raises:
        HTTPException(400): the upload is not declared as an image, or cannot be decoded.
        HTTPException(500): any other processing failure.
    """
    from PIL import UnidentifiedImageError

    # content_type is None when the multipart part carries no Content-Type header.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File provided is not an image.")
    try:
        # 1. Read and decode the image.
        contents = await file.read()
        try:
            img_pil = Image.open(io.BytesIO(contents))
        except UnidentifiedImageError:
            raise HTTPException(status_code=400, detail="File provided is not a valid image.") from None

        # 2. Preprocess, then run Pytesseract (English).
        # Requires Tesseract's 'eng.traineddata' (usually installed by default).
        preprocessed_img = preprocess_image(img_pil)
        raw_ocr_text = pytesseract.image_to_string(preprocessed_img, lang="eng")

        # 3. Normalize the Tesseract output.
        normalized_text = normalize_english_text(raw_ocr_text)

        # 4. Run TrOCR on the original (non-binarized) image.
        trocr_text = run_trocr_inference(img_pil)

        return JSONResponse(content={
            "language": "english",
            "filename": file.filename,
            "pytesseract_raw_text": raw_ocr_text,
            "normalized_text": normalized_text,
            "trocr_text": trocr_text,
        })
    except HTTPException:
        raise  # already carries the intended status code
    except Exception as e:
        logger.exception("English OCR Error: %s", e)
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}") from e
# --- Main entry point ---
def _serve_dev():
    """Start a local, auto-reloading development server on 127.0.0.1:8001."""
    uvicorn.run(
        "app:app",
        host="127.0.0.1",
        port=8001,
        reload=True,
    )


if __name__ == "__main__":
    _serve_dev()