Spaces:
Sleeping
Sleeping
File size: 4,846 Bytes
478f262 e84c69b 478f262 e84c69b 478f262 e84c69b 478f262 e84c69b 478f262 fbe00ca e84c69b 478f262 e84c69b 478f262 e84c69b 478f262 e84c69b 478f262 e84c69b 478f262 e84c69b 478f262 e84c69b 478f262 e84c69b 478f262 e84c69b 478f262 e84c69b 478f262 e84c69b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import gradio as gr
import cv2
import numpy as np
import tensorflow as tf
import pickle
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import sakshi_ocr
import os
import io
import sys
import tempfile
import requests
# Remote locations of the OCR assets hosted on Hugging Face.
# Each one is downloaded to the working directory at import time when the
# corresponding local file does not already exist.
MODEL_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/hindi_ocr_model.keras"
ENCODER_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/label_encoder.pkl"
FONT_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/NotoSansDevanagari-Regular.ttf" # Optional font
# Download helper used for the model, the label encoder and the font
def download_file(url, dest):
    """Download ``url`` and store the response body at ``dest``.

    The body is streamed into a temporary sibling file (``dest`` + ``.part``)
    and moved into place only after the transfer has completed. A failed or
    interrupted download therefore never leaves a truncated file, or an HTML
    error page, at ``dest`` where a later existence check would mistake it
    for a valid asset.

    Args:
        url: HTTP(S) URL to fetch.
        dest: Destination file path; an existing file is replaced.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an error status code.
        OSError: If the file cannot be written or moved into place.
    """
    partial = f"{dest}.part"
    try:
        # (connect, read) timeouts: without them a stalled server would hang
        # application start-up forever.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            # Fail loudly on 4xx/5xx instead of saving the error body.
            response.raise_for_status()
            with open(partial, 'wb') as f:
                # Stream in 1 MiB chunks so a large model is never held in
                # memory as a whole.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(partial, dest)
    finally:
        # Only present if something failed before the final move.
        if os.path.exists(partial):
            os.remove(partial)
# Local file names the assets are stored under (relative to the working
# directory of the Space)
MODEL_PATH = "hindi_ocr_model.keras"
ENCODER_PATH = "label_encoder.pkl"
FONT_PATH = "NotoSansDevanagari-Regular.ttf"


def _ensure_local(url, path):
    """Download ``url`` to ``path`` unless something already exists at ``path``."""
    if os.path.exists(path):
        return
    download_file(url, path)


# Fetch whichever assets are missing; files already on disk are reused
_ensure_local(MODEL_URL, MODEL_PATH)
_ensure_local(ENCODER_URL, ENCODER_PATH)
_ensure_local(FONT_URL, FONT_PATH)

# Register the font with matplotlib and make it the default font family
if os.path.exists(FONT_PATH):
    fm.fontManager.addfont(FONT_PATH)
    plt.rcParams['font.family'] = 'Noto Sans Devanagari'
# Loaders for the trained model and its label encoder
def load_model():
    """Return the Keras model stored at ``MODEL_PATH``.

    Returns:
        The object produced by ``tf.keras.models.load_model``, or ``None``
        when nothing exists at ``MODEL_PATH``.
    """
    if os.path.exists(MODEL_PATH):
        return tf.keras.models.load_model(MODEL_PATH)
    return None
def load_label_encoder():
    """Return the object unpickled from ``ENCODER_PATH``.

    Returns:
        The unpickled object, or ``None`` when nothing exists at
        ``ENCODER_PATH``.
    """
    if os.path.exists(ENCODER_PATH):
        # NOTE(review): unpickling can execute arbitrary code, and this file
        # is downloaded from ENCODER_URL - only safe while that source is
        # trusted.
        with open(ENCODER_PATH, 'rb') as encoder_file:
            return pickle.load(encoder_file)
    return None
# Load both artifacts once at import time so every request reuses them.
# Either value is None when its file is missing; process_image checks for
# that before attempting a prediction.
model = load_model()
label_encoder = load_label_encoder()
# Word detection function
def detect_words(image):
    """Outline word-sized regions of a grayscale image and count them.

    The image is binarised with an inverted Otsu threshold and dilated twice
    with a 3x3 kernel (presumably so neighbouring strokes merge into one
    blob per word). The external contours of the result are then collected.
    Every contour whose bounding box is more than 10 pixels wide and more
    than 10 pixels high is outlined in green and counted.

    Args:
        image: Single-channel image array; it is converted with
            ``cv2.COLOR_GRAY2BGR`` to build the annotated copy.

    Returns:
        tuple: ``(annotated, count)`` where ``annotated`` is the 3-channel
        copy of ``image`` with the rectangles drawn on it and ``count`` is
        the number of rectangles.
    """
    threshold_flags = cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
    binary = cv2.threshold(image, 0, 255, threshold_flags)[1]
    dilated = cv2.dilate(binary, np.ones((3,3), np.uint8), iterations=2)
    contours, _hierarchy = cv2.findContours(
        dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    annotated = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

    # Keep only boxes large enough in both dimensions; smaller ones are skipped.
    boxes = [cv2.boundingRect(contour) for contour in contours]
    word_boxes = [(x, y, w, h) for x, y, w, h in boxes if w > 10 and h > 10]

    for x, y, w, h in word_boxes:
        cv2.rectangle(annotated, (x, y), (x+w, y+h), (0, 255, 0), 2)
    return annotated, len(word_boxes)
# Sakshi OCR output capture
def run_sakshi_ocr(image_path):
    """Run ``sakshi_ocr.generate`` on an image file and return what it printed.

    The return value of ``generate`` is ignored; only the text written to
    ``sys.stdout`` while the call runs is collected.

    Args:
        image_path: Path of the image file passed to ``sakshi_ocr.generate``.

    Returns:
        str: Everything written to standard output during the call.

    Raises:
        Any exception raised by ``sakshi_ocr.generate`` propagates unchanged;
        standard output is restored first.
    """
    from contextlib import redirect_stdout

    captured = io.StringIO()
    # redirect_stdout puts the previous sys.stdout back on exit, even when
    # generate() raises.
    # NOTE: the redirect is process-wide, so anything other threads print
    # during the call ends up in the captured text as well.
    with redirect_stdout(captured):
        sakshi_ocr.generate(image_path)
    return captured.getvalue()
# Main OCR processing function
def _predict_label(gray):
    """Classify ``gray`` with the global model and return the decoded label."""
    resized = cv2.resize(gray, (128, 32))
    batch = (resized / 255.0)[np.newaxis, ..., np.newaxis]  # Shape: (1, 32, 128, 1)
    scores = model.predict(batch)
    return label_encoder.inverse_transform([np.argmax(scores)])[0]


def _render_prediction(gray, label):
    """Return an RGB array showing ``gray`` under a "Predicted: ..." title."""
    fig, ax = plt.subplots()
    try:
        ax.imshow(gray, cmap='gray')
        ax.set_title(f"Predicted: {label}", fontsize=12)
        ax.axis('off')
        # Render into memory: a fixed file name in the working directory
        # would be overwritten by concurrent requests.
        png = io.BytesIO()
        fig.savefig(png, format="png")
    finally:
        # Always release the figure, otherwise failed requests leak figures.
        plt.close(fig)
    bgr = cv2.imdecode(np.frombuffer(png.getvalue(), dtype=np.uint8), cv2.IMREAD_COLOR)
    # OpenCV decodes to BGR; the Gradio image output expects RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def process_image(image):
    """Run word detection, the Keras classifier and Sakshi OCR on one image.

    Args:
        image: The uploaded image (a PIL image as delivered by the Gradio
            input component), or ``None`` when nothing was uploaded.

    Returns:
        tuple: ``(sakshi_output, word_detected_img, word_count, pred_image)``

        * ``sakshi_output`` - text printed by Sakshi OCR, or an error message
          when no image was provided.
        * ``word_detected_img`` - annotated array from ``detect_words``
          (``None`` when no image was provided).
        * ``word_count`` - number of detected word regions.
        * ``pred_image`` - RGB array of the input with the predicted label as
          title, or ``None`` when the model/encoder is not loaded or the
          prediction failed.

    Raises:
        Exceptions from ``detect_words`` and ``run_sakshi_ocr`` propagate.
        Failures of the classifier step are reported on stderr instead.
    """
    if image is None:
        # The fourth output is an image component, so it must receive None
        # rather than a message string.
        return "Error: No image provided", None, 0, None

    # Normalise PIL inputs to 3-channel RGB first so RGBA, palette or
    # grayscale uploads do not break the RGB -> gray conversion below.
    rgb = image.convert("RGB") if hasattr(image, "convert") else image
    img = cv2.cvtColor(np.array(rgb), cv2.COLOR_RGB2GRAY)

    # Word detection
    word_detected_img, word_count = detect_words(img)

    # First OCR model prediction (best effort: a failure here must not stop
    # the other outputs from being produced)
    pred_image = None
    if model is None or label_encoder is None:
        print("Hindi OCR prediction skipped: Model or encoder not loaded", file=sys.stderr)
    else:
        try:
            pred_image = _render_prediction(img, _predict_label(img))
        except Exception as e:
            # The interface has no text slot for this message, so surface it
            # in the logs instead of dropping it silently.
            print(f"Hindi OCR prediction failed: {e}", file=sys.stderr)

    # Sakshi OCR processing
    # Reserve a path and close the handle before writing, so OpenCV can open
    # the file itself on every platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
        tmp_path = tmp_file.name
    try:
        cv2.imwrite(tmp_path, img)
        sakshi_output = run_sakshi_ocr(tmp_path)
    finally:
        # Remove the temp file even when Sakshi OCR raises.
        os.remove(tmp_path)

    return sakshi_output, word_detected_img, word_count, pred_image
# Gradio Interface
# The order of `outputs` must match the tuple returned by process_image:
# (OCR text, annotated image, word count, prediction image).
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=[
        gr.Textbox(label="Sakshi OCR Output"),
        gr.Image(label="Word Detection", type="numpy"),
        gr.Number(label="Word Count"),
        gr.Image(label="Hindi OCR Prediction", type="numpy"),
    ],
    title="Hindi OCR App by Sakshi",
    description="Upload an image to perform Hindi OCR and word detection.",
)

# Launch the app
interface.launch()