# Scraped from a Hugging Face Space page (Space status: Sleeping).
# File size: 8,286 Bytes
import io
import logging
import os
import re

import cv2
import fastapi
import google.generativeai as genai
import numpy as np
import pytesseract
import torch
import uvicorn
from fastapi import File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, pipeline
# --- Logging ---
# Configure logging before anything logs: a module-level logging.warning()
# on an unconfigured root logger installs a default WARNING-level handler,
# which would silently turn basicConfig(level=INFO) into a no-op.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ocr_api")

# --- Configuration ---
# Ensure the TESSDATA_PREFIX environment variable is set if your Tesseract data is in a non-standard location.
# os.environ["TESSDATA_PREFIX"] = "/path/to/tesseract/data"

# --- Gemini Configuration ---
# The key is read from the GEMINI_API_KEY environment variable (e.g. a Space
# secret). Never hard-code API keys in source: a key was previously committed
# here and should be treated as leaked and rotated.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "").strip()
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    # Gemini is the *primary* Marathi OCR path; without a key the endpoint
    # goes straight to the local Tesseract/TrOCR/MahaBERT pipeline.
    logger.warning(
        "GEMINI_API_KEY not found. Gemini OCR will be skipped and the local fallback models used."
    )
# --- Model Loading ---
# Loads every model once at import time; any failure is logged with a
# traceback and re-raised so the service does not start half-initialised.
try:
    # Use GPU if available. Resolved first so each model can be placed on it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    logger.info("Loading TrOCR model (microsoft/trocr-base-stage1)...")
    # NOTE(review): this checkpoint appears to be pre-trained on English text
    # only, so its output on Marathi (Devanagari) images is presumably not
    # meaningful -- confirm before relying on "trocr_handwritten_text".
    trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-stage1")
    trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")
    trocr_model.eval()  # inference only: disables dropout
    trocr_model.to(device)
    logger.info("TrOCR model loaded")

    # MahaBERT model for text classification (Marathi specific)
    logger.info("Loading MahaBERT text-classification pipeline...")
    mahabert_model = pipeline(
        "text-classification",
        model="Abhi964/MahaPhrase_mahaBERTv2_Finetuning",
        tokenizer="Abhi964/MahaPhrase_mahaBERTv2_Finetuning",
        # pipeline() takes an integer device index: -1 = CPU, 0 = first GPU.
        # Keeps MahaBERT on the same device as TrOCR instead of always CPU.
        device=0 if device.type == "cuda" else -1,
    )
    logger.info("MahaBERT pipeline ready")
    logger.info("Models moved to device %s", device)
except Exception:
    logger.exception("Failed to load models during startup")
    raise
# --- FastAPI App Initialization ---
# ASGI application object; interactive OpenAPI docs are served at /docs.
app = fastapi.FastAPI(
    title="OCR API (Marathi & English)",
    description="An API to perform OCR on images for Marathi and English text.",
    version="1.2.0",
    docs_url="/docs",
)

# CORS policy: any origin, any method, any header, credentials allowed.
# NOTE(review): a wildcard origin together with allow_credentials=True makes
# Starlette reflect the caller's Origin for credentialed requests. These
# endpoints read no cookies/auth headers, so allow_credentials=False is
# probably sufficient -- confirm with the front-end before changing.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# --- Helper Functions ---
def preprocess_image(image: Image.Image):
    """Binarize a PIL image for Pytesseract.

    Pipeline: force RGB -> OpenCV BGR -> grayscale -> 3x3 Gaussian blur ->
    Otsu threshold.

    Args:
        image: A PIL image in any mode Pillow can convert to RGB (e.g. RGB,
            RGBA, L, P). Alpha is discarded by the RGB conversion.

    Returns:
        A 2-D uint8 numpy array whose pixels are either 0 or 255.
    """
    # np.array() of an RGBA/LA/P/L image is not a 3-channel array, so
    # COLOR_RGB2BGR would raise; normalise the mode first. For RGB input this
    # yields exactly the same pixels as before.
    rgb_array = np.array(image.convert("RGB"))
    img_cv = cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)
    gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
    # Light blur suppresses speckle noise that would skew the Otsu threshold.
    blurred = cv2.GaussianBlur(gray, (3, 3), 0)
    # With THRESH_OTSU the 0 threshold argument is ignored; Otsu picks it.
    _, thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return thresh
def normalize_marathi_text(text: str):
    """Clean and normalize Marathi text.

    Keeps characters from the Devanagari block (U+0900-U+097F, which includes
    the danda "।"), whitespace, and the zero-width non-joiner / joiner
    (U+200C / U+200D). Everything else is removed, including ASCII digits,
    Latin letters and Latin punctuation. Whitespace runs are then collapsed
    to a single space and the ends are trimmed.

    ZWNJ/ZWJ are kept because Devanagari orthography uses them to select
    conjunct shapes (e.g. the Marathi "eyelash" ra written as र + ् + ZWJ).
    Stripping them silently changes how words are spelled and rendered.
    """
    devanagari_only = re.sub(r'[^\u0900-\u097F\u200C\u200D\s]', '', text)
    return re.sub(r'\s+', ' ', devanagari_only).strip()
def normalize_english_text(text: str):
    """Reduce *text* to plain English characters with single spacing.

    Permitted characters are ASCII letters, digits, whitespace and the
    punctuation . , ! ? ' " -. Everything else is deleted. Whitespace runs
    then collapse to one space and the result is trimmed.
    """
    disallowed_chars = r'[^a-zA-Z0-9\s.,!?\'"-]'
    kept = re.sub(disallowed_chars, '', text)
    single_spaced = re.sub(r'\s+', ' ', kept)
    return single_spaced.strip()
def run_trocr_inference(image: Image.Image):
    """Transcribe *image* with the module-level TrOCR processor and model.

    The image is converted to RGB and encoded into pixel values on ``device``.
    Generation runs with autograd disabled. The first generated sequence is
    returned decoded, with special tokens removed.
    """
    encoded = trocr_processor(images=image.convert("RGB"), return_tensors="pt")
    pixel_values = encoded.pixel_values.to(device)
    with torch.no_grad():
        token_ids = trocr_model.generate(pixel_values)
        decoded = trocr_processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0]
# --- API Endpoints ---
# Gemini model used by /ocr/marathi/; also reported back as "source_model".
MARATHI_GEMINI_MODEL = "gemini-2.5-flash"


def _marathi_ocr_with_gemini(img_pil: Image.Image) -> str:
    """Ask Gemini to transcribe the Marathi text in *img_pil* (blocking network call).

    Returns:
        The model's text response with surrounding whitespace stripped.

    Raises:
        ValueError: if GEMINI_API_KEY is empty.
        Exception: whatever the google.generativeai client call or reading
            ``response.text`` raises; the caller treats any failure as
            "use the local fallback".
    """
    if not GEMINI_API_KEY:
        raise ValueError("Gemini API Key not configured.")
    logger.info("Attempting Gemini OCR...")
    model = genai.GenerativeModel(MARATHI_GEMINI_MODEL)
    prompt = "Extract all the Marathi text from this image. Return only the extracted text, do not add markdown or explanations."
    response = model.generate_content([prompt, img_pil])
    return response.text.strip()


def _marathi_ocr_local(img_pil: Image.Image) -> dict:
    """Run the local fallback pipeline: Tesseract ('mar'), MahaBERT, TrOCR.

    Returns:
        The fallback-specific response fields, in response order.
    """
    logger.info("Running Fallback OCR (Tesseract/TrOCR)...")
    preprocessed_img = preprocess_image(img_pil)
    raw_ocr_text = pytesseract.image_to_string(preprocessed_img, lang="mar")
    normalized_text = normalize_marathi_text(raw_ocr_text)
    # MahaBERT classifies the raw (un-normalized) Tesseract output.
    mahabert_result = mahabert_model(raw_ocr_text)
    trocr_text = run_trocr_inference(img_pil)
    return {
        "pytesseract_raw_text": raw_ocr_text,
        "normalized_text": normalized_text,
        "trocr_handwritten_text": trocr_text,
        "classification": mahabert_result,
    }


@app.post("/ocr/marathi/")
async def marathi_ocr(file: UploadFile = File(...)):
    """
    Extracts Marathi text.
    Priority 1: Gemini Multimodal API
    Priority 2 (Fallback): Tesseract + TrOCR + MahaBERT

    Responses:
        200: "source_model" is "gemini-2.5-flash" (Gemini path) or
             "local_fallback" (local pipeline), plus the extracted text fields.
        400: the upload is not declared as an image or cannot be decoded.
        500: any other processing error.
    """
    # content_type is None when the multipart part carries no Content-Type.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File provided is not an image.")
    try:
        contents = await file.read()
        try:
            img_pil = Image.open(io.BytesIO(contents))
            # Decode fully now so corrupt/truncated uploads become a 400 here
            # rather than a 500 from deep inside Gemini/Tesseract/TrOCR.
            img_pil.load()
        except OSError as decode_error:  # PIL.UnidentifiedImageError is an OSError
            raise HTTPException(
                status_code=400, detail="Uploaded file could not be decoded as an image."
            ) from decode_error

        # --- ATTEMPT 1: GEMINI API ---
        try:
            # Blocking network I/O: run it in a worker thread so the event
            # loop keeps serving other requests while Gemini responds.
            gemini_text = await run_in_threadpool(_marathi_ocr_with_gemini, img_pil)
        except Exception as gemini_error:
            # Deliberate best effort: any Gemini failure (missing key, network,
            # quota, unusable response, ...) falls through to the local models.
            logger.error("Gemini OCR failed: %s. Falling back to local models.", gemini_error)
        else:
            logger.info("Gemini OCR successful.")
            return JSONResponse(content={
                "language": "marathi",
                "filename": file.filename,
                "source_model": MARATHI_GEMINI_MODEL,
                "extracted_text": gemini_text,
                "normalized_text": normalize_marathi_text(gemini_text),
            })

        # --- ATTEMPT 2: FALLBACK (local models) ---
        # NOTE: local inference still runs inline and blocks the event loop
        # while it executes, as before; the shared models are not guarded
        # for concurrent use from worker threads.
        fallback_fields = _marathi_ocr_local(img_pil)
        return JSONResponse(content={
            "language": "marathi",
            "filename": file.filename,
            "source_model": "local_fallback",
            **fallback_fields,
        })
    except HTTPException:
        # Already a client-facing error (e.g. undecodable image); keep its status.
        raise
    except Exception as e:
        logger.exception("Marathi OCR Error: %s", e)
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}") from e
@app.post("/extract-english/")
async def english_ocr(file: UploadFile = File(...)):
    """
    Extracts English text.
    Uses Tesseract (lang='eng') and TrOCR.

    Responses:
        200: raw and normalized Tesseract text plus the TrOCR transcription.
        400: the upload is not declared as an image or cannot be decoded.
        500: any other processing error.
    """
    # content_type is None when the multipart part carries no Content-Type.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File provided is not an image.")
    try:
        # 1. Read and fully decode the image (corrupt uploads -> 400, not 500)
        contents = await file.read()
        try:
            img_pil = Image.open(io.BytesIO(contents))
            img_pil.load()
        except OSError as decode_error:  # PIL.UnidentifiedImageError is an OSError
            raise HTTPException(
                status_code=400, detail="Uploaded file could not be decoded as an image."
            ) from decode_error
        # 2. Preprocess for Tesseract
        preprocessed_img = preprocess_image(img_pil)
        # 3. Run Pytesseract (English)
        # Note: Ensure 'eng.traineddata' is installed (usually default)
        raw_ocr_text = pytesseract.image_to_string(preprocessed_img, lang="eng")
        # 4. Normalize English Text
        normalized_text = normalize_english_text(raw_ocr_text)
        # 5. Run TrOCR
        trocr_text = run_trocr_inference(img_pil)
        # 6. Compile results
        return JSONResponse(content={
            "language": "english",
            "filename": file.filename,
            "pytesseract_raw_text": raw_ocr_text,
            "normalized_text": normalized_text,
            "trocr_text": trocr_text,
        })
    except HTTPException:
        # Already a client-facing error (e.g. undecodable image); keep its status.
        raise
    except Exception as e:
        logger.exception("English OCR Error: %s", e)
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}") from e
# --- Main entry point ---
if __name__ == "__main__":
    # With reload=True uvicorn needs an "module:attribute" import string rather
    # than the app object. Derive the module from this file's name so the
    # script keeps working if it is renamed from app.py.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(f"{module_name}:app", host="127.0.0.1", port=8001, reload=True)