Spaces:
Sleeping
Sleeping
| import os | |
| import io | |
| import sys | |
| import cv2 | |
| import base64 | |
| import pickle | |
| import numpy as np | |
| import tensorflow as tf | |
| import matplotlib.pyplot as plt | |
| import matplotlib.font_manager as fm | |
| import tempfile | |
| import sakshi_ocr | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
# Asset locations, resolved relative to the working directory
# (update these if your files live elsewhere).
MODEL_PATH = 'hindi_ocr_model.keras'
ENCODER_PATH = 'label_encoder.pkl'
FONT_PATH = 'NotoSansDevanagari-Regular.ttf'

# Register the bundled Devanagari font with matplotlib when the file is
# present; otherwise keep matplotlib's default font and say so.
if not os.path.exists(FONT_PATH):
    print("Custom font not found. Using default font.")
else:
    fm.fontManager.addfont(FONT_PATH)
    plt.rcParams['font.family'] = 'Noto Sans Devanagari'
def load_model():
    """Load the Keras OCR model stored at ``MODEL_PATH``.

    Returns:
        The model object returned by ``tf.keras.models.load_model``.

    Raises:
        FileNotFoundError: If nothing exists at ``MODEL_PATH``.
    """
    model_is_present = os.path.exists(MODEL_PATH)
    if model_is_present:
        return tf.keras.models.load_model(MODEL_PATH)
    raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")
def load_label_encoder():
    """Unpickle the label encoder stored at ``ENCODER_PATH``.

    Returns:
        Whatever object was pickled into the encoder file.

    Raises:
        FileNotFoundError: If nothing exists at ``ENCODER_PATH``.
    """
    if os.path.exists(ENCODER_PATH):
        # NOTE: pickle executes arbitrary code on load; only point
        # ENCODER_PATH at a file you trust.
        with open(ENCODER_PATH, 'rb') as encoder_file:
            encoder = pickle.load(encoder_file)
        return encoder
    raise FileNotFoundError(f"Label encoder file not found at {ENCODER_PATH}")
# Load both assets once at import time so every request reuses the same
# objects instead of reloading them.
model, label_encoder = load_model(), load_label_encoder()
def detect_words(image, *, min_width=10, min_height=10, kernel_size=3,
                 dilate_iterations=2):
    """Draw bounding boxes around word-like regions of a grayscale image.

    The image is binarised with an inverted Otsu threshold, dilated so that
    nearby strokes merge into one blob, and the external contours of those
    blobs are boxed.

    Args:
        image: Grayscale image (the Otsu threshold and the GRAY2BGR
            conversion both expect a single-channel input).
        min_width: A box is kept only if its width is strictly greater
            than this many pixels.
        min_height: A box is kept only if its height is strictly greater
            than this many pixels.
        kernel_size: Side length of the square dilation kernel.
        dilate_iterations: Number of dilation passes.

    Returns:
        A ``(word_img, word_count)`` tuple: a BGR copy of ``image`` with a
        green rectangle drawn around each kept region, and the number of
        rectangles drawn.
    """
    # THRESH_BINARY_INV turns the pixels below the Otsu threshold white, so
    # dark text on a light page becomes the foreground that gets contoured.
    _, binary = cv2.threshold(
        image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
    )
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    dilated = cv2.dilate(binary, kernel, iterations=dilate_iterations)
    contours, _ = cv2.findContours(
        dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # Draw on a colour copy so the green boxes are visible.
    word_img = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    word_count = 0
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        # Skip tiny blobs, which are most likely noise rather than words.
        if w > min_width and h > min_height:
            cv2.rectangle(word_img, (x, y), (x + w, y + h), (0, 255, 0), 2)
            word_count += 1
    return word_img, word_count
def run_sakshi_ocr(image_path):
    """Run ``sakshi_ocr.generate`` on an image file and return what it printed.

    The text written to standard output during the call is captured and
    returned instead of being shown.

    Args:
        image_path: Path of the image file handed to ``sakshi_ocr.generate``.

    Returns:
        Everything written to ``sys.stdout`` during the call, as one string.

    Raises:
        Any exception raised by ``sakshi_ocr.generate`` propagates; stdout is
        restored first.
    """
    # Imported locally so this function does not depend on a new
    # module-level import.
    from contextlib import redirect_stdout

    captured = io.StringIO()
    # redirect_stdout restores the previous sys.stdout on exit, including
    # when generate() raises.
    with redirect_stdout(captured):
        sakshi_ocr.generate(image_path)
    return captured.getvalue()
def image_to_base64(image, ext=".png"):
    """Encode an image array as a base64 text string.

    Args:
        image: Image array passed to ``cv2.imencode``.
        ext: File extension selecting the encoding format (default PNG).

    Returns:
        The base64 string of the encoded image, or ``None`` when
        ``cv2.imencode`` reports failure.
    """
    ok, encoded_bytes = cv2.imencode(ext, image)
    return base64.b64encode(encoded_bytes).decode('utf-8') if ok else None
# The ASGI application object; this is what gets passed to uvicorn.
app = FastAPI(
    title="Hindi OCR App by sakshi",
)
# Without a route decorator the handler is never registered on ``app`` and
# the upload page would not be reachable.
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the upload page.

    Returns:
        An ``HTMLResponse`` containing a form that posts the chosen image
        file to ``/predict`` as multipart form data.
    """
    html_content = """
    <html>
        <head>
            <title>Hindi OCR App by sakshi</title>
        </head>
        <body>
            <h1>Hindi OCR App by sakshi</h1>
            <form action="/predict" enctype="multipart/form-data" method="post">
                <input name="file" type="file" accept="image/*">
                <input type="submit" value="Upload and Predict">
            </form>
        </body>
    </html>
    """
    return HTMLResponse(content=html_content)
def _render_prediction_image(img, pred_label):
    """Render ``img`` under a "Predicted: ..." title and return it as base64 PNG."""
    fig, ax = plt.subplots()
    try:
        ax.imshow(img, cmap='gray')
        ax.set_title(f"Predicted: {pred_label}", fontsize=12)
        ax.axis('off')
        buf = io.BytesIO()
        # Save through the figure object rather than pyplot's implicit
        # "current figure", so the right figure is always written.
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, including on error; pyplot otherwise
        # keeps it alive.
        plt.close(fig)
    pred_img_array = np.frombuffer(buf.getvalue(), np.uint8)
    prediction_img = cv2.imdecode(pred_img_array, cv2.IMREAD_COLOR)
    return image_to_base64(prediction_img)


def _run_sakshi_on_image(img):
    """Write ``img`` to a temporary PNG, run Sakshi OCR on it, return its output.

    Failures are reported in the returned string rather than raised, so the
    rest of the response can still be produced.
    """
    tmp_file_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
            tmp_file_path = tmp_file.name
        # Write only after the handle is closed, so the path can be reopened
        # on platforms that do not allow a second open.
        if not cv2.imwrite(tmp_file_path, img):
            return "Error running Sakshi OCR: could not write temporary image"
        return run_sakshi_ocr(tmp_file_path)
    except Exception as e:
        return f"Error running Sakshi OCR: {e}"
    finally:
        # Remove the temp file on failure paths too.
        if tmp_file_path is not None:
            try:
                os.remove(tmp_file_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file must not fail
                # the request.
                pass


# Registered as POST /predict, the target of the upload form on the index page.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run word detection, model prediction and Sakshi OCR on an uploaded image.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        A ``JSONResponse`` with the keys ``word_count``, ``ocr_prediction``,
        ``sakshi_ocr_output``, ``original_image``, ``word_detected_image``
        and ``prediction_image`` (the three images are base64 PNG strings).

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as an
            image; 500 if the model prediction step fails.
    """
    # Read and decode the uploaded image
    contents = await file.read()
    if not contents:
        # Reject empty uploads up front instead of handing an empty buffer
        # to the decoder.
        raise HTTPException(status_code=400, detail="Error reading the image.")
    nparr = np.frombuffer(contents, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise HTTPException(status_code=400, detail="Error reading the image.")

    # Encode the original image to base64 for visualization
    original_image = image_to_base64(cv2.cvtColor(img, cv2.COLOR_GRAY2BGR))

    # Word detection
    word_img, word_count = detect_words(img)
    word_img_encoded = image_to_base64(word_img)

    # OCR model prediction for a single word
    try:
        # cv2.resize takes (width, height), so the result is 32 rows x 128 cols.
        img_resized = cv2.resize(img, (128, 32))
        img_norm = img_resized / 255.0
        img_input = img_norm[np.newaxis, ..., np.newaxis]  # shape: (1, 32, 128, 1)
        pred = model.predict(img_input)
        pred_label_idx = np.argmax(pred)
        pred_label = label_encoder.inverse_transform([pred_label_idx])[0]
        if isinstance(pred_label, np.generic):
            # Convert NumPy scalars to plain Python values so the label is
            # always JSON-serializable.
            pred_label = pred_label.item()
        prediction_img_encoded = _render_prediction_image(img, pred_label)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error in OCR model processing: {e}"
        ) from e

    # Run Sakshi OCR on the image via a temporary file
    sakshi_output = _run_sakshi_on_image(img)

    # Prepare the response
    response_data = {
        "word_count": word_count,
        "ocr_prediction": pred_label,
        "sakshi_ocr_output": sakshi_output,
        "original_image": original_image,
        "word_detected_image": word_img_encoded,
        "prediction_image": prediction_img_encoded,
    }
    return JSONResponse(content=response_data)
# Start a development server when this file is executed directly.
if __name__ == "__main__":
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)