# Hindi OCR Gradio app for Hugging Face Spaces.
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import tensorflow as tf | |
| import pickle | |
| import matplotlib.pyplot as plt | |
| import matplotlib.font_manager as fm | |
| import sakshi_ocr | |
| import os | |
| import io | |
| import sys | |
| import tempfile | |
| import requests | |
# URLs for the trained model, label encoder, and Devanagari font, all hosted
# on the Hugging Face Hub; downloaded on first boot (see caching block below).
MODEL_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/hindi_ocr_model.keras"
ENCODER_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/label_encoder.pkl"
FONT_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/NotoSansDevanagari-Regular.ttf"  # Optional font
| # Download model and encoder | |
def download_file(url, dest):
    """Download *url* to the local path *dest*.

    Improvements over a bare ``requests.get``:
      * ``timeout`` so a hung CDN cannot block Space startup forever,
      * ``raise_for_status`` so an HTML error page is never saved as a model,
      * streamed chunks so multi-hundred-MB model files are not buffered
        entirely in memory.

    Raises:
        requests.HTTPError: if the server returns a non-2xx status.
        requests.Timeout / requests.ConnectionError: on network failure.
    """
    response = requests.get(url, stream=True, timeout=60)
    response.raise_for_status()
    with open(dest, 'wb') as f:
        for chunk in response.iter_content(chunk_size=1 << 20):
            f.write(chunk)
# Local filenames the downloaded artifacts are cached under (the Space's
# working directory).
MODEL_PATH = "hindi_ocr_model.keras"
ENCODER_PATH = "label_encoder.pkl"
FONT_PATH = "NotoSansDevanagari-Regular.ttf"
# Fetch each artifact only if it is not already present, so restarts of an
# existing Space reuse the cached copies instead of re-downloading.
if not os.path.exists(MODEL_PATH):
    download_file(MODEL_URL, MODEL_PATH)
if not os.path.exists(ENCODER_PATH):
    download_file(ENCODER_URL, ENCODER_PATH)
if not os.path.exists(FONT_PATH):
    download_file(FONT_URL, FONT_PATH)
# Register the Devanagari font with matplotlib so the predicted Hindi label
# renders correctly in plot titles; silently skipped if the download failed.
if os.path.exists(FONT_PATH):
    fm.fontManager.addfont(FONT_PATH)
    plt.rcParams['font.family'] = 'Noto Sans Devanagari'
| # Load the model and encoder | |
def load_model():
    """Load the Keras OCR model from MODEL_PATH, or return None if absent."""
    if os.path.exists(MODEL_PATH):
        return tf.keras.models.load_model(MODEL_PATH)
    return None
def load_label_encoder():
    """Unpickle the label encoder from ENCODER_PATH, or return None if absent."""
    if os.path.exists(ENCODER_PATH):
        with open(ENCODER_PATH, 'rb') as fh:
            encoder = pickle.load(fh)
        return encoder
    return None
# Load once at import time so every Gradio request reuses the same objects;
# either may be None if the corresponding download failed.
model = load_model()
label_encoder = load_label_encoder()
| # Word detection function | |
| def detect_words(image): | |
| _, binary = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) | |
| kernel = np.ones((3,3), np.uint8) | |
| dilated = cv2.dilate(binary, kernel, iterations=2) | |
| contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) | |
| word_img = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) | |
| word_count = 0 | |
| for contour in contours: | |
| x, y, w, h = cv2.boundingRect(contour) | |
| if w > 10 and h > 10: | |
| cv2.rectangle(word_img, (x, y), (x+w, y+h), (0, 255, 0), 2) | |
| word_count += 1 | |
| return word_img, word_count | |
| # Sakshi OCR output capture | |
| def run_sakshi_ocr(image_path): | |
| buffer = io.StringIO() | |
| old_stdout = sys.stdout | |
| sys.stdout = buffer | |
| try: | |
| sakshi_ocr.generate(image_path) | |
| finally: | |
| sys.stdout = old_stdout | |
| return buffer.getvalue() | |
| # Main OCR processing function | |
| def process_image(image): | |
| if image is None: | |
| return "Error: No image provided", None, 0, "No prediction available" | |
| # Convert PIL image to OpenCV format (grayscale) | |
| img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY) | |
| # Word detection | |
| word_detected_img, word_count = detect_words(img) | |
| # First OCR model prediction | |
| try: | |
| img_resized = cv2.resize(img, (128, 32)) | |
| img_norm = img_resized / 255.0 | |
| img_input = img_norm[np.newaxis, ..., np.newaxis] # Shape: (1, 32, 128, 1) | |
| if model is not None and label_encoder is not None: | |
| pred = model.predict(img_input) | |
| pred_label_idx = np.argmax(pred) | |
| pred_label = label_encoder.inverse_transform([pred_label_idx])[0] | |
| # Create plot with prediction | |
| fig, ax = plt.subplots() | |
| ax.imshow(img, cmap='gray') | |
| ax.set_title(f"Predicted: {pred_label}", fontsize=12) | |
| ax.axis('off') | |
| plt.savefig("temp_plot.png") | |
| plt.close() | |
| pred_image = cv2.imread("temp_plot.png") | |
| os.remove("temp_plot.png") | |
| else: | |
| pred_image = None | |
| pred_label = "Model or encoder not loaded" | |
| except Exception as e: | |
| pred_image = None | |
| pred_label = f"Error: {str(e)}" | |
| # Sakshi OCR processing | |
| with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file: | |
| cv2.imwrite(tmp_file.name, img) | |
| sakshi_output = run_sakshi_ocr(tmp_file.name) | |
| os.remove(tmp_file.name) | |
| return sakshi_output, word_detected_img, word_count, pred_image | |
| # Gradio Interface | |
| interface = gr.Interface( | |
| fn=process_image, | |
| inputs=gr.Image(type="pil", label="Upload an Image"), | |
| outputs=[ | |
| gr.Textbox(label="Sakshi OCR Output"), | |
| gr.Image(label="Word Detection", type="numpy"), | |
| gr.Number(label="Word Count"), | |
| gr.Image(label="Hindi OCR Prediction", type="numpy") | |
| ], | |
| title="Hindi OCR App by Sakshi", | |
| description="Upload an image to perform Hindi OCR and word detection." | |
| ) | |
| # Launch the app | |
| interface.launch() |