# NOTE(review): removed a pasted-in Hugging Face Spaces status banner
# ("Spaces: / Sleeping / Sleeping") that was not part of the program.
| import fastapi | |
| from fastapi import File, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import uvicorn | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| import pytesseract | |
| import re | |
| from transformers import TrOCRProcessor, VisionEncoderDecoderModel, pipeline | |
| import io | |
| import os | |
| import logging | |
| import google.generativeai as genai # <--- Added Import | |
# --- Configuration ---
# Ensure the TESSDATA_PREFIX environment variable is set if your Tesseract data is in a non-standard location.
# os.environ["TESSDATA_PREFIX"] = "/path/to/tesseract/data"

# --- Logging ---
# Configure logging before the Gemini check below so the warning is formatted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ocr_api")

# --- Gemini Configuration ---
# SECURITY FIX: the API key was previously hard-coded in source, which leaks
# the secret to anyone with repo access. Read it from the environment instead
# (set GEMINI_API_KEY before starting the server), as the original comment
# already instructed.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    logging.warning("GEMINI_API_KEY not found. Gemini fallback will be skipped/fail.")
# --- Model Loading ---
# Models are loaded once at import time so every request reuses them; any
# failure here is logged with a traceback and re-raised to abort startup.
try:
    logger.info("Loading TrOCR model (microsoft/trocr-base-stage1)...")
    # TrOCR model (Base stage1 is pre-trained on English, so it works well for English too)
    trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-stage1")
    trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")
    trocr_model.eval()  # inference mode (disables dropout etc.)
    logger.info("TrOCR model loaded")
    # MahaBERT model for text classification (Marathi specific)
    logger.info("Loading MahaBERT text-classification pipeline...")
    mahabert_model = pipeline(
        "text-classification",
        model="Abhi964/MahaPhrase_mahaBERTv2_Finetuning",
        tokenizer="Abhi964/MahaPhrase_mahaBERTv2_Finetuning"
    )
    logger.info("MahaBERT pipeline ready")
    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Only the TrOCR model is moved; the MahaBERT pipeline presumably stays on
    # its default device (no device argument given) — TODO confirm.
    trocr_model.to(device)
    logger.info("Models moved to device %s", device)
except Exception as e:
    logger.exception("Failed to load models during startup")
    raise
# --- FastAPI App Initialization ---
# Build the application object, then attach middleware.
app = fastapi.FastAPI(
    title="OCR API (Marathi & English)",
    description="An API to perform OCR on images for Marathi and English text.",
    version="1.2.0",
    docs_url="/docs",
)

# Open CORS to every origin so browser clients anywhere can call the API.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — consider pinning concrete origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
)
# --- Helper Functions ---
def preprocess_image(image: Image.Image):
    """Convert a PIL image into a binarized OpenCV array for Tesseract.

    Pipeline: RGB -> BGR -> grayscale -> light Gaussian blur -> Otsu threshold.
    Returns the thresholded single-channel image.
    """
    bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    grayscale = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    smoothed = cv2.GaussianBlur(grayscale, (3, 3), 0)
    _, binary = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return binary
def normalize_marathi_text(text: str):
    """Keep only Devanagari characters and whitespace, then collapse every
    run of whitespace into a single space and trim the ends."""
    devanagari_only = re.sub(r'[^\u0900-\u097F\s]', '', text)
    return re.sub(r'\s+', ' ', devanagari_only).strip()
def normalize_english_text(text: str):
    """Keep only English letters, digits, whitespace and common punctuation
    (.,!?'"-), then collapse whitespace runs to single spaces and trim."""
    ascii_only = re.sub(r'[^a-zA-Z0-9\s.,!?\'"-]', '', text)
    return re.sub(r'\s+', ' ', ascii_only).strip()
def run_trocr_inference(image: Image.Image):
    """Return TrOCR's predicted text for the given PIL image.

    Relies on the module-level ``trocr_processor``, ``trocr_model`` and
    ``device`` created at startup.
    """
    rgb_image = image.convert("RGB")
    inputs = trocr_processor(images=rgb_image, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device)
    with torch.no_grad():  # inference only — no gradient bookkeeping
        token_ids = trocr_model.generate(pixel_values)
    decoded = trocr_processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0]
| # --- API Endpoints --- | |
# FIX: the handler was defined but never registered with the app (no route
# decorator), so it was unreachable. Path chosen from the function name —
# adjust if clients expect a different path.
@app.post("/marathi-ocr")
async def marathi_ocr(file: UploadFile = File(...)):
    """
    Extracts Marathi text from an uploaded image.

    Priority 1: Gemini Multimodal API
    Priority 2 (Fallback): Tesseract + TrOCR + MahaBERT

    Raises:
        HTTPException 400: the upload is not an image.
        HTTPException 500: any processing failure.
    """
    # FIX: content_type is None when the client omits the header entirely;
    # calling .startswith on None crashed with AttributeError before.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File provided is not an image.")
    try:
        contents = await file.read()
        img_pil = Image.open(io.BytesIO(contents))
        # --- ATTEMPT 1: GEMINI API ---
        try:
            if not GEMINI_API_KEY:
                raise ValueError("Gemini API Key not configured.")
            logger.info("Attempting Gemini OCR...")
            model = genai.GenerativeModel('gemini-2.5-flash')
            prompt = "Extract all the Marathi text from this image. Return only the extracted text, do not add markdown or explanations."
            response = model.generate_content([prompt, img_pil])
            gemini_text = response.text.strip()
            logger.info("Gemini OCR successful.")
            return JSONResponse(content={
                "language": "marathi",
                "filename": file.filename,
                "source_model": "gemini-2.5-flash",
                "extracted_text": gemini_text,
                "normalized_text": normalize_marathi_text(gemini_text)
            })
        except Exception as gemini_error:
            # Deliberate best-effort: log and fall through to the local models.
            logger.error(f"Gemini OCR failed: {gemini_error}. Falling back to local models.")
        # --- ATTEMPT 2: FALLBACK (local models) ---
        logger.info("Running Fallback OCR (Tesseract/TrOCR)...")
        preprocessed_img = preprocess_image(img_pil)
        raw_ocr_text = pytesseract.image_to_string(preprocessed_img, lang="mar")
        normalized_text = normalize_marathi_text(raw_ocr_text)
        mahabert_result = mahabert_model(raw_ocr_text)
        trocr_text = run_trocr_inference(img_pil)
        return JSONResponse(content={
            "language": "marathi",
            "filename": file.filename,
            "source_model": "local_fallback",
            "pytesseract_raw_text": raw_ocr_text,
            "normalized_text": normalized_text,
            "trocr_handwritten_text": trocr_text,
            "classification": mahabert_result
        })
    except Exception as e:
        logger.error(f"Marathi OCR Error: {e}")
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
# FIX: the handler was defined but never registered with the app (no route
# decorator), so it was unreachable. Path chosen from the function name —
# adjust if clients expect a different path.
@app.post("/english-ocr")
async def english_ocr(file: UploadFile = File(...)):
    """
    Extracts English text from an uploaded image.

    Uses Tesseract (lang='eng') and TrOCR.

    Raises:
        HTTPException 400: the upload is not an image.
        HTTPException 500: any processing failure.
    """
    # FIX: content_type is None when the client omits the header entirely;
    # calling .startswith on None crashed with AttributeError before.
    if not file.content_type or not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File provided is not an image.")
    try:
        # 1. Read Image
        contents = await file.read()
        img_pil = Image.open(io.BytesIO(contents))
        # 2. Preprocess for Tesseract
        preprocessed_img = preprocess_image(img_pil)
        # 3. Run Pytesseract (English)
        # Note: Ensure 'eng.traineddata' is installed (usually default)
        raw_ocr_text = pytesseract.image_to_string(preprocessed_img, lang="eng")
        # 4. Normalize English Text
        normalized_text = normalize_english_text(raw_ocr_text)
        # 5. Run TrOCR
        trocr_text = run_trocr_inference(img_pil)
        # 6. Compile results
        return JSONResponse(content={
            "language": "english",
            "filename": file.filename,
            "pytesseract_raw_text": raw_ocr_text,
            "normalized_text": normalized_text,
            "trocr_text": trocr_text
        })
    except Exception as e:
        logger.error(f"English OCR Error: {e}")
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
# --- Main entry point ---
if __name__ == "__main__":
    # Development server with auto-reload. "app:app" assumes this file is
    # saved as app.py — TODO confirm the module name matches.
    uvicorn.run("app:app", host="127.0.0.1", port=8001, reload=True)