File size: 4,846 Bytes
478f262
 
 
 
 
e84c69b
 
 
 
478f262
e84c69b
478f262
e84c69b
478f262
e84c69b
478f262
 
fbe00ca
e84c69b
 
 
 
 
 
478f262
e84c69b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478f262
e84c69b
 
 
478f262
 
e84c69b
 
 
 
478f262
 
 
 
 
 
 
 
 
 
e84c69b
 
 
 
 
 
 
 
 
 
 
478f262
e84c69b
 
 
 
 
 
 
 
 
 
 
 
478f262
e84c69b
 
478f262
e84c69b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478f262
e84c69b
 
 
 
 
478f262
 
e84c69b
 
 
 
 
 
 
 
 
 
 
 
478f262
e84c69b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import contextlib
import io
import os
import pickle
import sys
import tempfile

import cv2
import gradio as gr
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import numpy as np
import requests
import tensorflow as tf

import sakshi_ocr

# URLs for the model and encoder hosted on Hugging Face.
# Downloaded once at startup into the Space's working directory (see below).
MODEL_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/hindi_ocr_model.keras"
ENCODER_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/label_encoder.pkl"
FONT_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/NotoSansDevanagari-Regular.ttf"  # Optional font

def download_file(url, dest):
    """Download *url* to the local path *dest*.

    Streams the response in chunks so a large model file never has to fit
    in memory, and raises for HTTP errors (404/500) instead of silently
    writing an error page to disk where it would later fail to deserialize.

    Args:
        url: HTTP(S) URL to fetch.
        dest: local filesystem path to write.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.RequestException: on connection failure or timeout.
    """
    response = requests.get(url, stream=True, timeout=60)
    response.raise_for_status()
    with open(dest, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)

# Paths for local storage in Hugging Face Spaces
MODEL_PATH = "hindi_ocr_model.keras"
ENCODER_PATH = "label_encoder.pkl"
FONT_PATH = "NotoSansDevanagari-Regular.ttf"

# Download models and font if not already present.
# Runs at import time so the app is ready before Gradio starts serving.
if not os.path.exists(MODEL_PATH):
    download_file(MODEL_URL, MODEL_PATH)
if not os.path.exists(ENCODER_PATH):
    download_file(ENCODER_URL, ENCODER_PATH)
if not os.path.exists(FONT_PATH):
    download_file(FONT_URL, FONT_PATH)

# Load the custom font if available, so matplotlib can render Devanagari
# glyphs in the prediction plot instead of tofu boxes.
if os.path.exists(FONT_PATH):
    fm.fontManager.addfont(FONT_PATH)
    plt.rcParams['font.family'] = 'Noto Sans Devanagari'

# Load the model and encoder
def load_model():
    """Load the Keras OCR model from MODEL_PATH, or None if the file is missing."""
    if os.path.exists(MODEL_PATH):
        return tf.keras.models.load_model(MODEL_PATH)
    return None

def load_label_encoder():
    """Unpickle the fitted label encoder from ENCODER_PATH, or None if absent."""
    if os.path.exists(ENCODER_PATH):
        with open(ENCODER_PATH, 'rb') as encoder_file:
            return pickle.load(encoder_file)
    return None

# Load the model artifacts once at import time so every request reuses them;
# either may be None if the download above failed.
model = load_model()
label_encoder = load_label_encoder()

def detect_words(image, min_width=10, min_height=10):
    """Detect word-like regions in a grayscale image and box them.

    Otsu-thresholds the image (inverted, so ink becomes foreground), dilates
    to merge the characters of a word into one connected blob, then draws a
    green bounding box around each blob larger than the minimum size.

    Args:
        image: 2-D uint8 grayscale image (OpenCV convention).
        min_width: minimum bounding-box width to count as a word
            (default 10, matching the previous hard-coded threshold).
        min_height: minimum bounding-box height to count as a word.

    Returns:
        Tuple of (BGR image annotated with green rectangles, word count).
    """
    _, binary = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    kernel = np.ones((3, 3), np.uint8)
    dilated = cv2.dilate(binary, kernel, iterations=2)
    contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    word_img = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    word_count = 0

    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        # Ignore specks below the minimum size (noise, stray diacritics).
        if w > min_width and h > min_height:
            cv2.rectangle(word_img, (x, y), (x + w, y + h), (0, 255, 0), 2)
            word_count += 1

    return word_img, word_count

def run_sakshi_ocr(image_path):
    """Run the sakshi_ocr engine on *image_path* and return its printed output.

    sakshi_ocr.generate() prints its result to stdout rather than returning
    it, so stdout is captured for the duration of the call.
    contextlib.redirect_stdout restores sys.stdout even if generate() raises,
    replacing the manual save/swap/restore dance.

    Args:
        image_path: path to an image file on disk.

    Returns:
        Everything the engine printed, as a single string.
    """
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        sakshi_ocr.generate(image_path)
    return buffer.getvalue()

# Main OCR processing function
def process_image(image):
    """Run the full pipeline on one uploaded image.

    Args:
        image: PIL RGB image from Gradio, or None when nothing was uploaded.

    Returns:
        Tuple of (sakshi OCR text, word-detection image, word count,
        prediction plot as a BGR numpy image — or None when the model is
        unavailable or prediction fails).
    """
    if image is None:
        return "Error: No image provided", None, 0, "No prediction available"

    # Convert PIL image to OpenCV format (grayscale)
    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)

    # Word detection
    word_detected_img, word_count = detect_words(img)

    # First OCR model prediction (single-word classifier + annotated plot)
    pred_image = _predict_word(img)

    # Sakshi OCR needs a file path, so write the image to a unique temp
    # file; close the handle before cv2.imwrite for Windows compatibility,
    # and always clean up even if the OCR call raises.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
        tmp_path = tmp_file.name
    try:
        cv2.imwrite(tmp_path, img)
        sakshi_output = run_sakshi_ocr(tmp_path)
    finally:
        os.remove(tmp_path)

    return sakshi_output, word_detected_img, word_count, pred_image


def _predict_word(img):
    """Classify *img* with the Keras model and render an annotated plot.

    Resizes to the model's 128x32 input, predicts a single label, and
    draws the image with the predicted label as the plot title.

    Args:
        img: 2-D uint8 grayscale image.

    Returns:
        BGR numpy image of the rendered plot, or None when the model or
        encoder is not loaded or any step fails (matching the previous
        behavior of returning None on error).
    """
    try:
        if model is None or label_encoder is None:
            return None

        img_resized = cv2.resize(img, (128, 32))
        img_norm = img_resized / 255.0
        img_input = img_norm[np.newaxis, ..., np.newaxis]  # Shape: (1, 32, 128, 1)

        pred = model.predict(img_input)
        pred_label_idx = np.argmax(pred)
        pred_label = label_encoder.inverse_transform([pred_label_idx])[0]

        # Create plot with prediction
        fig, ax = plt.subplots()
        ax.imshow(img, cmap='gray')
        ax.set_title(f"Predicted: {pred_label}", fontsize=12)
        ax.axis('off')

        # Use a unique temp file rather than the fixed name "temp_plot.png":
        # concurrent Gradio requests would otherwise clobber each other's
        # plot and race on os.remove.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as plot_file:
            plot_path = plot_file.name
        try:
            plt.savefig(plot_path)
            plt.close(fig)
            return cv2.imread(plot_path)
        finally:
            os.remove(plot_path)
    except Exception:
        # Preserve the original contract: any failure yields no plot image.
        return None

# Gradio Interface: wires process_image to a simple web UI. The four
# outputs correspond positionally to process_image's return tuple
# (sakshi text, annotated word image, word count, prediction plot).
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=[
        gr.Textbox(label="Sakshi OCR Output"),
        gr.Image(label="Word Detection", type="numpy"),
        gr.Number(label="Word Count"),
        gr.Image(label="Hindi OCR Prediction", type="numpy")
    ],
    title="Hindi OCR App by Sakshi",
    description="Upload an image to perform Hindi OCR and word detection."
)

# Launch the app (blocks; serves the UI until the process exits)
interface.launch()