File size: 8,286 Bytes
51ba11a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3bc177
51ba11a
 
 
 
 
e3bc177
 
2dad195
e3bc177
 
 
 
 
 
51ba11a
 
806bcf2
51ba11a
 
 
 
806bcf2
51ba11a
 
 
 
 
806bcf2
51ba11a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b58ecdf
806bcf2
 
e3bc177
51ba11a
 
 
 
 
 
 
 
 
 
806bcf2
51ba11a
 
 
 
 
 
 
 
 
 
 
 
806bcf2
 
 
 
 
 
 
 
51ba11a
 
 
 
 
 
 
 
 
 
 
 
806bcf2
51ba11a
806bcf2
 
51ba11a
e3bc177
 
 
51ba11a
 
 
 
2d1adc0
51ba11a
 
806bcf2
e3bc177
 
 
 
 
 
 
cf7d335
e3bc177
 
 
 
 
 
 
 
 
 
 
 
cf7d335
e3bc177
 
 
 
 
 
 
 
 
 
 
806bcf2
51ba11a
 
 
 
806bcf2
51ba11a
 
806bcf2
51ba11a
 
 
806bcf2
51ba11a
e3bc177
51ba11a
806bcf2
51ba11a
806bcf2
 
 
 
 
 
 
 
 
 
e3bc177
806bcf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3bc177
806bcf2
 
 
 
 
 
51ba11a
 
2d1adc0
806bcf2
 
51ba11a
806bcf2
51ba11a
806bcf2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import fastapi
from fastapi import File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from PIL import Image
import numpy as np
import cv2
import torch
import pytesseract
import re
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, pipeline
import io
import os
import logging
import google.generativeai as genai  # <--- Added Import

# --- Configuration ---
# Ensure the TESSDATA_PREFIX environment variable is set if your Tesseract data is in a non-standard location.
# os.environ["TESSDATA_PREFIX"] = "/path/to/tesseract/data"

# --- Gemini Configuration ---
# SECURITY FIX: the API key was hard-coded in source (a leaked credential).
# Read it from the environment instead; set GEMINI_API_KEY before starting
# the server to enable the Gemini OCR path. When unset, the endpoints fall
# back to the local Tesseract/TrOCR pipeline.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")

if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    logging.warning("GEMINI_API_KEY not found. Gemini fallback will be skipped/fail.")

# --- Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ocr_api")

# --- Model Loading ---
# Models load once at import time so requests don't pay the startup cost.
try:
    logger.info("Loading TrOCR model (microsoft/trocr-base-stage1)...")
    # TrOCR model (Base stage1 is pre-trained on English, so it works well for English too)
    trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-stage1")
    trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1")
    trocr_model.eval()  # inference mode: disables dropout etc.
    logger.info("TrOCR model loaded")

    # MahaBERT model for text classification (Marathi specific)
    logger.info("Loading MahaBERT text-classification pipeline...")
    mahabert_model = pipeline(
        "text-classification",
        model="Abhi964/MahaPhrase_mahaBERTv2_Finetuning",
        tokenizer="Abhi964/MahaPhrase_mahaBERTv2_Finetuning"
    )
    logger.info("MahaBERT pipeline ready")

    # Use GPU if available
    # NOTE: `device` is a module-level global also read by run_trocr_inference().
    # Only the TrOCR model is moved; the MahaBERT pipeline stays on its default device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    trocr_model.to(device)
    logger.info("Models moved to device %s", device)

except Exception as e:
    # Fail fast: the API is useless without its models, so log and re-raise.
    logger.exception("Failed to load models during startup")
    raise

# --- FastAPI App Initialization ---
app = fastapi.FastAPI(
    docs_url="/docs",
    title="OCR API (Marathi & English)",
    description="An API to perform OCR on images for Marathi and English text.",
    version="1.2.0"
)

# CORS is wide open (any origin, method, header) — convenient for development;
# tighten allow_origins before deploying to production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Helper Functions ---

def preprocess_image(image: Image.Image):
    """Convert a PIL image to a binarized OpenCV array for Tesseract.

    Pipeline: force 3-channel RGB, convert to BGR, grayscale, light Gaussian
    blur to suppress noise, then Otsu threshold for a clean black/white image.

    Args:
        image: Source PIL image in any mode (RGB, RGBA, L, P, ...).

    Returns:
        A 2-D uint8 numpy array (binarized) suitable for pytesseract.
    """
    # Bug fix: callers pass Image.open(...) output directly, so RGBA PNGs or
    # grayscale images are not 3-channel and made cv2.cvtColor(..., COLOR_RGB2BGR)
    # raise. Normalizing to RGB first is safe for every PIL mode.
    img_cv = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
    gray = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (3, 3), 0)
    _, thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return thresh

def normalize_marathi_text(text: str):
    """Keep only Devanagari characters and whitespace, then collapse runs of
    whitespace to single spaces and trim the ends."""
    devanagari_only = re.sub(r'[^\u0900-\u097F\s]', '', text)
    return re.sub(r'\s+', ' ', devanagari_only).strip()

def normalize_english_text(text: str):
    """Drop characters outside English letters, digits, and basic punctuation
    (. , ! ? ' " -), then collapse whitespace runs and trim."""
    kept = re.sub(r'[^a-zA-Z0-9\s.,!?\'"-]', '', text)
    return re.sub(r'\s+', ' ', kept).strip()

def run_trocr_inference(image: Image.Image):
    """Run the global TrOCR model on an image and return the decoded text.

    Converts the image to RGB, encodes it with the TrOCR processor, generates
    token ids on the configured device without tracking gradients, and decodes
    the first (only) sequence.
    """
    rgb_image = image.convert("RGB")
    encoded = trocr_processor(images=rgb_image, return_tensors="pt")
    pixel_values = encoded.pixel_values.to(device)
    with torch.no_grad():
        output_ids = trocr_model.generate(pixel_values)
    decoded = trocr_processor.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]

# --- API Endpoints ---

@app.post("/ocr/marathi/")
async def marathi_ocr(file: UploadFile = File(...)):
    """
    Extracts Marathi text.
    Priority 1: Gemini Multimodal API
    Priority 2 (Fallback): Tesseract + TrOCR + MahaBERT
    """
    if not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File provided is not an image.")

    try:
        contents = await file.read()
        img_pil = Image.open(io.BytesIO(contents))
        
        # --- ATTEMPT 1: GEMINI API ---
        try:
            if not GEMINI_API_KEY:
                raise ValueError("Gemini API Key not configured.")

            logger.info("Attempting Gemini OCR...")
            # Using gemini-1.5-flash for speed and efficiency, or use 'gemini-1.5-pro' for higher accuracy
            model = genai.GenerativeModel('gemini-2.5-flash')
            
            # Prompt for the model
            prompt = "Extract all the Marathi text from this image. Return only the extracted text, do not add markdown or explanations."
            
            response = model.generate_content([prompt, img_pil])
            gemini_text = response.text.strip()
            
            # If successful, return here
            logger.info("Gemini OCR successful.")
            return JSONResponse(content={
                "language": "marathi",
                "filename": file.filename,
                "source_model": "gemini-2.5-flash",
                "extracted_text": gemini_text,
                "normalized_text": normalize_marathi_text(gemini_text) # Optional normalization
            })

        except Exception as gemini_error:
            logger.error(f"Gemini OCR failed: {gemini_error}. Falling back to local models.")
            # Do not raise HTTPException here; just let code flow to the fallback

        # --- ATTEMPT 2: FALLBACK (Original Method) ---
        logger.info("Running Fallback OCR (Tesseract/TrOCR)...")
        
        # Preprocess and Run Tesseract (Marathi)
        preprocessed_img = preprocess_image(img_pil)
        raw_ocr_text = pytesseract.image_to_string(preprocessed_img, lang="mar")
        normalized_text = normalize_marathi_text(raw_ocr_text)

        # Run Classification
        mahabert_result = mahabert_model(raw_ocr_text)

        # Run TrOCR
        trocr_text = run_trocr_inference(img_pil)

        return JSONResponse(content={
            "language": "marathi",
            "filename": file.filename,
            "source_model": "local_fallback",
            "pytesseract_raw_text": raw_ocr_text,
            "normalized_text": normalized_text,
            "trocr_handwritten_text": trocr_text,
            "classification": mahabert_result
        })

    except Exception as e:
        logger.error(f"Marathi OCR Error: {e}")
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")

@app.post("/extract-english/")
async def english_ocr(file: UploadFile = File(...)):
    """
    Extracts English text.
    Uses Tesseract (lang='eng') and TrOCR.
    """
    if not file.content_type.startswith("image/"):
        raise HTTPException(status_code=400, detail="File provided is not an image.")

    try:
        # 1. Read Image
        contents = await file.read()
        img_pil = Image.open(io.BytesIO(contents))

        # 2. Preprocess for Tesseract
        preprocessed_img = preprocess_image(img_pil)

        # 3. Run Pytesseract (English)
        # Note: Ensure 'eng.traineddata' is installed (usually default)
        raw_ocr_text = pytesseract.image_to_string(preprocessed_img, lang="eng")

        # 4. Normalize English Text
        normalized_text = normalize_english_text(raw_ocr_text)

        # 5. Run TrOCR 
        trocr_text = run_trocr_inference(img_pil)

        # 6. Compile results
        return JSONResponse(content={
            "language": "english",
            "filename": file.filename,
            "pytesseract_raw_text": raw_ocr_text,
            "normalized_text": normalized_text,
            "trocr_text": trocr_text
        })

    except Exception as e:
        logger.error(f"English OCR Error: {e}")
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")

# --- Main entry point ---
if __name__ == "__main__":
    # Dev server only: the "app:app" module path assumes this file is named
    # app.py, and reload=True enables auto-restart on code changes.
    uvicorn.run("app:app", host="127.0.0.1", port=8001, reload=True)