Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from pydantic import BaseModel | |
| import cv2 | |
| import numpy as np | |
| import tensorflow as tf | |
| import pickle | |
| import matplotlib.pyplot as plt | |
| import matplotlib.font_manager as fm | |
| # import sakshi_ocr | |
| import os | |
| import io | |
| import sys | |
| import tempfile | |
| import requests | |
| from PIL import Image | |
| import uvicorn | |
| import shutil | |
| from pathlib import Path | |
| import pytext_ocr | |
# FastAPI application object together with its API metadata
# (title, description, version).
app = FastAPI(
    version="1.0.0",
    description="API for Hindi OCR and word detection",
    title="Hindi OCR API",
)
# Remote locations of the artifacts (Keras model, pickled label encoder,
# Devanagari font), all hosted in the same Hugging Face repository.
MODEL_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/hindi_ocr_model.keras"
ENCODER_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/label_encoder.pkl"
FONT_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/NotoSansDevanagari-Regular.ttf"

# Local copies of those artifacts live directly in the system temp directory.
_CACHE_DIR = tempfile.gettempdir()
MODEL_PATH = os.path.join(_CACHE_DIR, "hindi_ocr_model.keras")
ENCODER_PATH = os.path.join(_CACHE_DIR, "label_encoder.pkl")
FONT_PATH = os.path.join(_CACHE_DIR, "NotoSansDevanagari-Regular.ttf")

# Fresh, uniquely named scratch directory (created at import time) for outputs.
OUTPUT_DIR = tempfile.mkdtemp()
# Download model and encoder
def download_file(url, dest, *, timeout=60, chunk_size=1 << 20):
    """Download ``url`` and store it at ``dest``.

    The response is streamed into a temporary ``.part`` file next to ``dest``
    and moved into place with ``os.replace`` only after the whole body has been
    received. A failed, rejected (4xx/5xx) or interrupted download therefore
    never leaves a truncated file or an HTML error page at ``dest``. This
    matters because the startup code skips any artifact whose path already
    exists, so a bad file would otherwise never be re-downloaded.

    Args:
        url: HTTP(S) address to fetch.
        dest: Local file path to write.
        timeout: Seconds ``requests`` waits to connect and between received
            bytes before giving up. This does not cap total transfer time.
        chunk_size: Bytes written per chunk, which keeps memory use flat for
            large model files.

    Raises:
        requests.HTTPError: The server answered with an error status.
        requests.RequestException: Connection failure or timeout.
        OSError: The destination could not be written.
    """
    dest_dir = os.path.dirname(os.path.abspath(dest))
    fd, part_path = tempfile.mkstemp(dir=dest_dir, suffix=".part")
    try:
        with os.fdopen(fd, "wb") as f, requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)
        # Atomic on the same filesystem: readers see either no file or a complete one.
        os.replace(part_path, dest)
    finally:
        # The .part file still exists here only if something above failed.
        if os.path.exists(part_path):
            os.unlink(part_path)
# Load the model and encoder
def load_model():
    """Load the Keras OCR model from MODEL_PATH.

    Returns:
        The loaded model, or None when no file exists at MODEL_PATH yet.
    """
    model_file_present = os.path.exists(MODEL_PATH)
    return tf.keras.models.load_model(MODEL_PATH) if model_file_present else None
def load_label_encoder():
    """Unpickle the label encoder stored at ENCODER_PATH.

    Returns:
        The unpickled object, or None when no file exists at ENCODER_PATH yet.
    """
    if os.path.exists(ENCODER_PATH):
        # SECURITY: pickle.load can execute arbitrary code. ENCODER_PATH is
        # filled by downloading ENCODER_URL at startup, so that URL must point
        # at a repository you trust.
        with open(ENCODER_PATH, 'rb') as encoder_file:
            return pickle.load(encoder_file)
    return None
# Set up global variables.
# Both stay None until startup_event() assigns them. They can also remain None
# afterwards, because load_model()/load_label_encoder() return None when their
# file is missing.
model = label_encoder = None
# Download required files on startup
@app.on_event("startup")
async def startup_event():
    """Fetch missing artifacts, register the Devanagari font and load the model.

    Downloads are best-effort. A failed download is logged and skipped so the
    API still starts, and ``process_image`` then reports "Model or encoder not
    loaded" instead of the whole service failing to boot.
    """
    import logging

    global model, label_encoder
    logger = logging.getLogger(__name__)

    # Download models and font if not already present
    artifacts = (
        (MODEL_URL, MODEL_PATH),
        (ENCODER_URL, ENCODER_PATH),
        (FONT_URL, FONT_PATH),
    )
    for url, path in artifacts:
        if os.path.exists(path):
            continue
        try:
            download_file(url, path)
        except (requests.RequestException, OSError):
            logger.exception("Could not download %s to %s", url, path)

    # Load the custom font if available (needed to render Devanagari plot titles)
    if os.path.exists(FONT_PATH):
        fm.fontManager.addfont(FONT_PATH)
        plt.rcParams['font.family'] = 'Noto Sans Devanagari'

    # Initialize global variables (None when the file is missing)
    model = load_model()
    label_encoder = load_label_encoder()
# Word detection function
def detect_words(image):
    """Draw boxes around word-like regions of a single-channel image.

    Steps:
    1. Binarise with Otsu's threshold, inverted, so pixels darker than the
       threshold become 255 (foreground).
    2. Dilate twice with a 3x3 kernel so neighbouring strokes merge into blobs.
    3. Box each outer contour, keeping only boxes wider and taller than 10 px.

    Args:
        image: Single-channel image. Presumably uint8, which cv2's Otsu
            thresholding expects (TODO confirm for all callers).

    Returns:
        A tuple ``(annotated, count)``. ``annotated`` is a BGR copy of
        ``image`` with green 2-px rectangles; ``count`` is the number of
        rectangles drawn.
    """
    _, mask = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    merged = cv2.dilate(mask, np.ones((3, 3), np.uint8), iterations=2)
    outlines, _ = cv2.findContours(merged, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    boxes = [cv2.boundingRect(outline) for outline in outlines]
    kept = [(x, y, w, h) for (x, y, w, h) in boxes if w > 10 and h > 10]

    annotated = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    for x, y, w, h in kept:
        cv2.rectangle(annotated, (x, y), (x + w, y + h), (0, 255, 0), 2)
    return annotated, len(kept)
# Sakshi OCR output capture
def run_sakshi_ocr(image_path):
    """Run Sakshi OCR on ``image_path`` and return whatever it printed.

    ``sakshi_ocr.generate`` is called with stdout redirected into a buffer. Its
    printed text is the OCR result this wrapper returns.

    The module-level ``import sakshi_ocr`` is commented out in this file, so
    the import is attempted here. If the package is not installed, a short
    notice is returned instead of raising NameError on every request.

    NOTE: ``redirect_stdout`` swaps the process-wide ``sys.stdout``, so
    concurrent threads printing at the same time would be captured too.

    Args:
        image_path: Path of the image file to recognise.

    Returns:
        The captured stdout text, or "Sakshi OCR not available".
    """
    import contextlib

    try:
        import sakshi_ocr
    except ImportError:
        return "Sakshi OCR not available"

    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        sakshi_ocr.generate(image_path)
    return buffer.getvalue()
# File storage for session: maps an output kind ('word_detection',
# 'prediction') to the path of the most recently generated image.
# NOTE(review): this dict is module-global, so every client of the API shares it.
session_files = dict()
# Main OCR processing function
def process_image(image_array):
    """Run word detection, the Keras classifier and Sakshi OCR on one image.

    Args:
        image_array: NumPy image, e.g. ``np.array(PIL.Image)``. RGB/RGBA
            (H x W x 3|4) input is converted to gray; already single-channel
            input (H x W or H x W x 1) is used as-is. Presumably uint8 (TODO
            confirm for other upload formats).

    Side effects:
        Writes annotated PNGs under ``OUTPUT_DIR``. Records their paths in the
        module-level ``session_files`` under ``'word_detection'`` and, when a
        prediction image was rendered, ``'prediction'``.

    Returns:
        dict with keys:
            ``sakshi_output``: OCR text or an "Error: ..." message.
            ``word_detection_path``: path of the annotated word-box image.
            ``word_count``: number of word boxes drawn.
            ``prediction_path``: path of the prediction image, or None when
                none was written.
            ``prediction_label``: the predicted label, or a status/error
                message.
    """
    img = _to_grayscale(image_array)

    # Word detection
    word_detected_img, word_count = detect_words(img)
    word_detection_path = _new_output_png()
    cv2.imwrite(word_detection_path, word_detected_img)
    session_files['word_detection'] = word_detection_path

    # First OCR model prediction
    pred_label, pred_path = _predict_label(img)
    if pred_path is not None:
        session_files['prediction'] = pred_path

    # Sakshi OCR processing
    sakshi_output = _sakshi_text(img)

    return {
        "sakshi_output": sakshi_output,
        "word_detection_path": word_detection_path,
        "word_count": word_count,
        "prediction_path": pred_path,
        "prediction_label": pred_label,
    }


def _to_grayscale(image_array):
    """Return ``image_array`` as a single-channel (H x W) image."""
    if image_array.ndim == 2:
        return image_array
    if image_array.shape[2] == 1:
        return image_array.reshape(image_array.shape[:2])
    # COLOR_RGB2GRAY accepts 3- or 4-channel input. The channel order is RGB
    # because callers pass arrays built from PIL images.
    return cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)


def _new_output_png():
    """Create an empty, uniquely named .png file in OUTPUT_DIR and return its path."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)  # recreate if a temp cleaner removed it
    fd, path = tempfile.mkstemp(suffix=".png", dir=OUTPUT_DIR)
    os.close(fd)  # cv2 / matplotlib reopen the file by name
    return path


def _predict_label(img):
    """Classify ``img`` and render it, titled with the prediction, as a PNG.

    Returns:
        A tuple ``(label, png_path)``. ``png_path`` is None when nothing was
        written; ``label`` then explains why ("Model or encoder not loaded"
        or "Error: ...").
    """
    if model is None or label_encoder is None:
        return "Model or encoder not loaded", None

    pred_path = None
    try:
        img_resized = cv2.resize(img, (128, 32))  # cv2 dsize is (width, height)
        img_input = (img_resized / 255.0)[np.newaxis, ..., np.newaxis]  # Shape: (1, 32, 128, 1)
        pred = model.predict(img_input)
        pred_label = label_encoder.inverse_transform([np.argmax(pred)])[0]

        fig, ax = plt.subplots()
        try:
            ax.imshow(img, cmap='gray')
            ax.set_title(f"Predicted: {pred_label}", fontsize=12)
            ax.axis('off')
            pred_path = _new_output_png()
            fig.savefig(pred_path)
        finally:
            plt.close(fig)  # release the figure even if rendering fails
        return pred_label, pred_path
    except Exception as e:
        # Best-effort: report the failure in the label rather than failing the request.
        if pred_path is not None and os.path.exists(pred_path):
            os.unlink(pred_path)
        return f"Error: {e}", None


def _sakshi_text(img):
    """Run Sakshi OCR on ``img`` via a scratch PNG.

    Errors are returned as an "Error: ..." string, matching how classifier
    failures are reported.
    """
    fd, tmp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        cv2.imwrite(tmp_path, img)
        return run_sakshi_ocr(tmp_path)
    except Exception as e:
        return f"Error: {e}"
    finally:
        os.unlink(tmp_path)  # always remove the scratch file, even if OCR raised
class OCRResponse(BaseModel):
    """Response body returned by ``process``.

    The annotated images are not part of the body. Their file paths are kept
    in the module-level ``session_files`` dict instead.
    """

    sakshi_output: str  # text produced by run_sakshi_ocr for the uploaded image
    word_count: int  # number of word boxes drawn by detect_words
    prediction_label: str  # classifier label, or a "Model or encoder not loaded" / "Error: ..." message
@app.post("/process/", response_model=OCRResponse)
async def process(file: UploadFile = File(...)):
    """Run the OCR pipeline on an uploaded image.

    Steps:
    1. Delete the annotated images from the previous call.
    2. Save the upload to a temporary file.
    3. Decode it with PIL, normalise it to RGB and pass it to
       ``process_image``, which records the new image paths in
       ``session_files``.

    Raises:
        HTTPException: 400 if the upload is not declared as ``image/*`` or
            cannot be decoded as an image; 500 for any other processing error.
    """
    # Check if the file is an image (content_type can be None when the client omits it)
    if not (file.content_type or "").startswith("image/"):
        raise HTTPException(status_code=400, detail="File must be an image")

    # Clean up previous session files.
    # NOTE(review): session_files is module-global, so this also removes the
    # images of a concurrent client's most recent request.
    for filepath in session_files.values():
        try:
            os.unlink(filepath)
        except OSError:
            pass  # already gone or not removable; cleanup is best-effort
    session_files.clear()

    # Create a temporary file to save the uploaded image
    temp_file = tempfile.NamedTemporaryFile(delete=False)
    try:
        with temp_file as f:
            shutil.copyfileobj(file.file, f)
        # convert("RGB") normalises grayscale, palette and RGBA uploads to the
        # 3-channel layout process_image converts from. The with-block closes
        # the image file handle.
        with Image.open(temp_file.name) as image:
            image_array = np.array(image.convert("RGB"))
        result = process_image(image_array)
        return OCRResponse(
            sakshi_output=result["sakshi_output"],
            word_count=result["word_count"],
            prediction_label=result["prediction_label"],
        )
    except Image.UnidentifiedImageError as e:
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
    finally:
        # Clean up the temporary file
        os.unlink(temp_file.name)
def _session_file_response(key, not_found_detail):
    """Serve the file recorded under ``session_files[key]``.

    Raises:
        HTTPException: 404 with ``not_found_detail`` when nothing is recorded
            under ``key`` or the file no longer exists.
    """
    path = session_files.get(key)
    if path is None or not os.path.exists(path):
        raise HTTPException(status_code=404, detail=not_found_detail)
    return FileResponse(path)


# NOTE(review): no route decorator is visible for this handler in the source;
# the path below is a guess. Confirm it with API clients.
@app.get("/word-detection/")
async def get_word_detection():
    """Return the word detection image."""
    return _session_file_response(
        'word_detection',
        "Word detection image not found. Process an image first.",
    )


# NOTE(review): no route decorator is visible for this handler in the source;
# the path below is a guess. Confirm it with API clients.
@app.get("/prediction/")
async def get_prediction():
    """Return the prediction image."""
    return _session_file_response(
        'prediction',
        "Prediction image not found. Process an image first.",
    )
@app.get("/")
async def root():
    """Report that the API is running and point clients at the processing endpoint."""
    return {"message": "Hindi OCR API is running. Use POST /process/ to analyze images."}
# For local testing: when executed directly, serve the app with uvicorn on all
# network interfaces at port 8000.
if __name__ == "__main__":
    uvicorn.run(app, port=8000, host="0.0.0.0")