"""OCR API (Marathi & English) built on FastAPI.

Endpoints:

* ``POST /ocr/marathi/`` -- Gemini multimodal OCR first; on any Gemini failure
  (or when no API key is configured) falls back to local Tesseract (``mar``),
  TrOCR and a MahaBERT text classifier.
* ``POST /extract-english/`` -- local Tesseract (``eng``) plus TrOCR.

All local models are loaded once at import time; a loading failure aborts
startup.
"""

import io
import logging
import os
import re

import cv2
import fastapi
import google.generativeai as genai
import numpy as np
import pytesseract
import torch
import uvicorn
from fastapi import File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, pipeline

# --- Logging ---
# Configure logging before anything logs. Module-level helpers such as
# logging.warning() implicitly call basicConfig() with the default WARNING
# level when the root logger has no handlers, which would turn a later
# basicConfig(level=INFO) into a no-op and silently drop all INFO logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ocr_api")

# --- Configuration ---
# Set the TESSDATA_PREFIX environment variable if your Tesseract language data
# (e.g. mar.traineddata / eng.traineddata) lives in a non-standard location:
# os.environ["TESSDATA_PREFIX"] = "/path/to/tesseract/data"

# --- Gemini Configuration ---
# The key is read from the environment; never commit API keys to source control.
# NOTE(security): a key was previously hard-coded in this file. Treat it as
# leaked and revoke/rotate it.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
# Gemini model used for Marathi OCR; also reported back as "source_model".
GEMINI_MODEL_NAME = "gemini-2.5-flash"
# Instruction sent alongside the image to Gemini.
GEMINI_OCR_PROMPT = (
    "Extract all the Marathi text from this image. Return only the extracted "
    "text, do not add markdown or explanations."
)

if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    logger.warning(
        "GEMINI_API_KEY not set; Marathi OCR will use only the local fallback models."
    )

# --- Model Loading ---
try:
    logger.info("Loading TrOCR model (microsoft/trocr-base-stage1)...")
    # NOTE(review): the stage-1 TrOCR checkpoint appears to be pre-trained on
    # English text; its output on Devanagari is likely unreliable -- confirm
    # before relying on "trocr_handwritten_text" in the Marathi response.
    trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-stage1")
    trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")
    trocr_model.eval()  # inference only
    logger.info("TrOCR model loaded")

    # MahaBERT model for text classification (Marathi specific).
    logger.info("Loading MahaBERT text-classification pipeline...")
    mahabert_model = pipeline(
        "text-classification",
        model="Abhi964/MahaPhrase_mahaBERTv2_Finetuning",
        tokenizer="Abhi964/MahaPhrase_mahaBERTv2_Finetuning",
    )
    logger.info("MahaBERT pipeline ready")

    # Use the GPU for TrOCR if available. Only the TrOCR model is moved here;
    # no device is passed to the MahaBERT pipeline above.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    trocr_model.to(device)
    logger.info("TrOCR model moved to device %s", device)
except Exception:
    logger.exception("Failed to load models during startup")
    raise

# --- FastAPI App Initialization ---
app = fastapi.FastAPI(
    docs_url="/docs",
    title="OCR API (Marathi & English)",
    description="An API to perform OCR on images for Marathi and English text.",
    version="1.2.0",
)

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests; tighten allow_origins if auth is ever added.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Helper Functions ---

# Everything except Devanagari-block characters and whitespace.
_NON_MARATHI_RE = re.compile(r'[^\u0900-\u097F\s]')
# Everything except ASCII letters, digits, whitespace and . , ! ? ' " -
_NON_ENGLISH_RE = re.compile(r'[^a-zA-Z0-9\s.,!?\'"-]')
# Runs of whitespace, collapsed to a single space during normalization.
_WHITESPACE_RE = re.compile(r'\s+')


def preprocess_image(image: Image.Image) -> np.ndarray:
    """Binarize a PIL image for Tesseract.

    Converts to grayscale, applies a 3x3 Gaussian blur, then Otsu thresholding.

    Args:
        image: Input image in any PIL mode.

    Returns:
        A single-channel uint8 OpenCV image whose pixels are 0 or 255.
    """
    # Normalize the mode first: RGBA / L / P / CMYK images would otherwise give
    # arrays with the wrong channel count for COLOR_RGB2BGR and crash.
    rgb = np.array(image.convert("RGB"))
    img_cv = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (3, 3), 0)
    # With THRESH_OTSU the explicit threshold (0) is ignored and chosen automatically.
    _, thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return thresh


def normalize_marathi_text(text: str) -> str:
    """Keep only Devanagari characters and whitespace, collapsing whitespace runs.

    Args:
        text: Raw OCR output.

    Returns:
        The filtered text with single spaces and no leading/trailing whitespace.
    """
    normalized = _NON_MARATHI_RE.sub('', text)
    return _WHITESPACE_RE.sub(' ', normalized).strip()


def normalize_english_text(text: str) -> str:
    """Keep only ASCII letters, digits, whitespace and basic punctuation.

    Args:
        text: Raw OCR output.

    Returns:
        The filtered text with single spaces and no leading/trailing whitespace.
    """
    normalized = _NON_ENGLISH_RE.sub('', text)
    return _WHITESPACE_RE.sub(' ', normalized).strip()


def run_trocr_inference(image: Image.Image) -> str:
    """Run the TrOCR model on a whole image and return the decoded text.

    Args:
        image: Input image in any PIL mode (converted to RGB here).

    Returns:
        The decoded text for the single image in the batch.
    """
    img_pil_rgb = image.convert("RGB")
    pixel_values = trocr_processor(images=img_pil_rgb, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        generated_ids = trocr_model.generate(pixel_values)
    return trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


async def _read_upload_image(file: UploadFile) -> Image.Image:
    """Validate an upload's declared type and decode it into a PIL image.

    Raises:
        HTTPException: 400 if the content type is not ``image/*`` or the bytes
            cannot be decoded as an image.
    """
    # content_type is None when the client sends no Content-Type for the part.
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File provided is not an image.")
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents))
        # Image.open is lazy; force decoding now so corrupt/truncated data is
        # reported as a client error instead of a 500 deep in the pipeline.
        image.load()
    except (OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400, detail="Uploaded file could not be decoded as an image."
        ) from exc
    return image


def _gemini_marathi_ocr(img_pil: Image.Image) -> str:
    """Ask Gemini to transcribe the Marathi text in an image (blocking network call)."""
    model = genai.GenerativeModel(GEMINI_MODEL_NAME)
    response = model.generate_content([GEMINI_OCR_PROMPT, img_pil])
    return response.text.strip()


def _local_marathi_ocr(img_pil: Image.Image) -> dict:
    """Run the local Marathi pipeline (Tesseract + MahaBERT + TrOCR); CPU/GPU-bound."""
    preprocessed_img = preprocess_image(img_pil)
    raw_ocr_text = pytesseract.image_to_string(preprocessed_img, lang="mar")
    return {
        "pytesseract_raw_text": raw_ocr_text,
        "normalized_text": normalize_marathi_text(raw_ocr_text),
        "trocr_handwritten_text": run_trocr_inference(img_pil),
        # truncation=True: long OCR output would otherwise exceed the model's
        # maximum sequence length and fail the whole request.
        "classification": mahabert_model(raw_ocr_text, truncation=True),
    }


def _local_english_ocr(img_pil: Image.Image) -> dict:
    """Run the local English pipeline (Tesseract ``eng`` + TrOCR); CPU/GPU-bound."""
    preprocessed_img = preprocess_image(img_pil)
    # Requires 'eng.traineddata' (normally installed with Tesseract).
    raw_ocr_text = pytesseract.image_to_string(preprocessed_img, lang="eng")
    return {
        "pytesseract_raw_text": raw_ocr_text,
        "normalized_text": normalize_english_text(raw_ocr_text),
        "trocr_text": run_trocr_inference(img_pil),
    }


# --- API Endpoints ---
# Blocking model/network work is offloaded with run_in_threadpool so a slow
# OCR request does not stall the event loop for every other request.

@app.post("/ocr/marathi/")
async def marathi_ocr(file: UploadFile = File(...)):
    """
    Extracts Marathi text.
    Priority 1: Gemini Multimodal API
    Priority 2 (Fallback): Tesseract + TrOCR + MahaBERT
    """
    img_pil = await _read_upload_image(file)

    # --- ATTEMPT 1: GEMINI API ---
    if GEMINI_API_KEY:
        logger.info("Attempting Gemini OCR...")
        try:
            gemini_text = await run_in_threadpool(_gemini_marathi_ocr, img_pil)
        except Exception:
            # Deliberate best-effort: any Gemini failure (network, quota,
            # blocked response, ...) falls through to the local models.
            logger.warning("Gemini OCR failed. Falling back to local models.", exc_info=True)
        else:
            logger.info("Gemini OCR successful.")
            return JSONResponse(content={
                "language": "marathi",
                "filename": file.filename,
                "source_model": GEMINI_MODEL_NAME,
                "extracted_text": gemini_text,
                "normalized_text": normalize_marathi_text(gemini_text),
            })
    else:
        logger.info("Gemini API key not configured; skipping Gemini OCR.")

    # --- ATTEMPT 2: FALLBACK (local models) ---
    logger.info("Running Fallback OCR (Tesseract/TrOCR)...")
    try:
        results = await run_in_threadpool(_local_marathi_ocr, img_pil)
    except Exception as exc:
        logger.exception("Marathi OCR Error")
        raise HTTPException(status_code=500, detail=f"Processing error: {exc}") from exc

    return JSONResponse(content={
        "language": "marathi",
        "filename": file.filename,
        "source_model": "local_fallback",
        **results,
    })


@app.post("/extract-english/")
async def english_ocr(file: UploadFile = File(...)):
    """
    Extracts English text.
    Uses Tesseract (lang='eng') and TrOCR.
    """
    img_pil = await _read_upload_image(file)

    try:
        results = await run_in_threadpool(_local_english_ocr, img_pil)
    except Exception as exc:
        logger.exception("English OCR Error")
        raise HTTPException(status_code=500, detail=f"Processing error: {exc}") from exc

    return JSONResponse(content={
        "language": "english",
        "filename": file.filename,
        **results,
    })


# --- Main entry point ---
if __name__ == "__main__":
    # uvicorn needs an import string for reload=True; this one assumes the
    # module is saved as app.py and launched from its own directory.
    uvicorn.run("app:app", host="127.0.0.1", port=8001, reload=True)