sameernotes commited on
Commit
e84c69b
·
verified ·
1 Parent(s): ccfdc8b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -44
app.py CHANGED
@@ -3,36 +3,56 @@ import cv2
3
  import numpy as np
4
  import tensorflow as tf
5
  import pickle
6
- import requests
 
 
 
7
  import io
 
8
  import tempfile
9
- import sakshi_ocr
10
 
11
- # Model & Encoder URLs
12
  MODEL_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/hindi_ocr_model.keras"
13
  ENCODER_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/label_encoder.pkl"
 
 
 
 
 
 
 
14
 
15
- # Load model from Hugging Face
16
- @tf.function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  def load_model():
18
- response = requests.get(MODEL_URL)
19
- if response.status_code == 200:
20
- with tempfile.NamedTemporaryFile(delete=False, suffix=".keras") as temp_model:
21
- temp_model.write(response.content)
22
- model = tf.keras.models.load_model(temp_model.name)
23
- return model
24
- else:
25
- raise ValueError("Failed to load model from Hugging Face.")
26
 
27
- # Load label encoder from Hugging Face
28
  def load_label_encoder():
29
- response = requests.get(ENCODER_URL)
30
- if response.status_code == 200:
31
- return pickle.loads(response.content)
32
- else:
33
- raise ValueError("Failed to load label encoder.")
34
 
35
- # Initialize model and encoder
36
  model = load_model()
37
  label_encoder = load_label_encoder()
38
 
@@ -42,34 +62,88 @@ def detect_words(image):
42
  kernel = np.ones((3,3), np.uint8)
43
  dilated = cv2.dilate(binary, kernel, iterations=2)
44
  contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
45
- word_count = sum(1 for c in contours if cv2.boundingRect(c)[2] > 10 and cv2.boundingRect(c)[3] > 10)
46
- return word_count
 
 
 
 
 
 
 
 
 
47
 
48
- # Process image and predict text
 
 
 
 
 
 
 
 
 
 
 
49
  def process_image(image):
50
- gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
51
- word_count = detect_words(gray)
52
- img_resized = cv2.resize(gray, (128, 32)) / 255.0
53
- img_input = img_resized[np.newaxis, ..., np.newaxis]
54
- pred = model.predict(img_input)
55
- pred_label_idx = np.argmax(pred)
56
- pred_label = label_encoder.inverse_transform([pred_label_idx])[0]
57
- return f"Detected Words: {word_count}\nPredicted Text: {pred_label}"
58
 
59
- # Sakshi OCR function
60
- def run_sakshi_ocr(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
62
- cv2.imwrite(tmp_file.name, image)
63
- output = io.StringIO()
64
- sakshi_ocr.generate(tmp_file.name, output)
65
- return output.getvalue()
 
66
 
67
  # Gradio Interface
68
- def ocr_pipeline(image):
69
- text_prediction = process_image(image)
70
- sakshi_output = run_sakshi_ocr(image)
71
- return f"{text_prediction}\n\nSakshi OCR Output:\n{sakshi_output}"
72
-
73
- demo = gr.Interface(fn=ocr_pipeline, inputs=gr.Image(type="numpy"), outputs="text")
 
 
 
 
 
 
74
 
75
- demo.launch()
 
 
3
  import numpy as np
4
  import tensorflow as tf
5
  import pickle
6
+ import matplotlib.pyplot as plt
7
+ import matplotlib.font_manager as fm
8
+ import sakshi_ocr
9
+ import os
10
  import io
11
+ import sys
12
  import tempfile
13
+ import requests
14
 
15
+ # URLs for the model and encoder hosted on Hugging Face
16
  MODEL_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/hindi_ocr_model.keras"
17
  ENCODER_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/label_encoder.pkl"
18
+ FONT_URL = "https://noto-website-2.storage.googleapis.com/pkgs/NotoSansDevanagari-Regular.ttf" # Optional font
19
+
20
+ # Download model and encoder
21
+ def download_file(url, dest):
22
+ response = requests.get(url)
23
+ with open(dest, 'wb') as f:
24
+ f.write(response.content)
25
 
26
+ # Paths for local storage in Hugging Face Spaces
27
+ MODEL_PATH = "hindi_ocr_model.keras"
28
+ ENCODER_PATH = "label_encoder.pkl"
29
+ FONT_PATH = "NotoSansDevanagari-Regular.ttf"
30
+
31
+ # Download models and font if not already present
32
+ if not os.path.exists(MODEL_PATH):
33
+ download_file(MODEL_URL, MODEL_PATH)
34
+ if not os.path.exists(ENCODER_PATH):
35
+ download_file(ENCODER_URL, ENCODER_PATH)
36
+ if not os.path.exists(FONT_PATH):
37
+ download_file(FONT_URL, FONT_PATH)
38
+
39
+ # Load the custom font if available
40
+ if os.path.exists(FONT_PATH):
41
+ fm.fontManager.addfont(FONT_PATH)
42
+ plt.rcParams['font.family'] = 'Noto Sans Devanagari'
43
+
44
+ # Load the model and encoder
45
  def load_model():
46
+ if not os.path.exists(MODEL_PATH):
47
+ return None
48
+ return tf.keras.models.load_model(MODEL_PATH)
 
 
 
 
 
49
 
 
50
  def load_label_encoder():
51
+ if not os.path.exists(ENCODER_PATH):
52
+ return None
53
+ with open(ENCODER_PATH, 'rb') as f:
54
+ return pickle.load(f)
 
55
 
 
56
  model = load_model()
57
  label_encoder = load_label_encoder()
58
 
 
62
  kernel = np.ones((3,3), np.uint8)
63
  dilated = cv2.dilate(binary, kernel, iterations=2)
64
  contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
65
+
66
+ word_img = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
67
+ word_count = 0
68
+
69
+ for contour in contours:
70
+ x, y, w, h = cv2.boundingRect(contour)
71
+ if w > 10 and h > 10:
72
+ cv2.rectangle(word_img, (x, y), (x+w, y+h), (0, 255, 0), 2)
73
+ word_count += 1
74
+
75
+ return word_img, word_count
76
 
77
+ # Sakshi OCR output capture
78
+ def run_sakshi_ocr(image_path):
79
+ buffer = io.StringIO()
80
+ old_stdout = sys.stdout
81
+ sys.stdout = buffer
82
+ try:
83
+ sakshi_ocr.generate(image_path)
84
+ finally:
85
+ sys.stdout = old_stdout
86
+ return buffer.getvalue()
87
+
88
+ # Main OCR processing function
89
  def process_image(image):
90
+ if image is None:
91
+ return "Error: No image provided", None, 0, "No prediction available"
 
 
 
 
 
 
92
 
93
+ # Convert PIL image to OpenCV format (grayscale)
94
+ img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
95
+
96
+ # Word detection
97
+ word_detected_img, word_count = detect_words(img)
98
+
99
+ # First OCR model prediction
100
+ try:
101
+ img_resized = cv2.resize(img, (128, 32))
102
+ img_norm = img_resized / 255.0
103
+ img_input = img_norm[np.newaxis, ..., np.newaxis] # Shape: (1, 32, 128, 1)
104
+
105
+ if model is not None and label_encoder is not None:
106
+ pred = model.predict(img_input)
107
+ pred_label_idx = np.argmax(pred)
108
+ pred_label = label_encoder.inverse_transform([pred_label_idx])[0]
109
+
110
+ # Create plot with prediction
111
+ fig, ax = plt.subplots()
112
+ ax.imshow(img, cmap='gray')
113
+ ax.set_title(f"Predicted: {pred_label}", fontsize=12)
114
+ ax.axis('off')
115
+ plt.savefig("temp_plot.png")
116
+ plt.close()
117
+ pred_image = cv2.imread("temp_plot.png")
118
+ os.remove("temp_plot.png")
119
+ else:
120
+ pred_image = None
121
+ pred_label = "Model or encoder not loaded"
122
+ except Exception as e:
123
+ pred_image = None
124
+ pred_label = f"Error: {str(e)}"
125
+
126
+ # Sakshi OCR processing
127
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
128
+ cv2.imwrite(tmp_file.name, img)
129
+ sakshi_output = run_sakshi_ocr(tmp_file.name)
130
+ os.remove(tmp_file.name)
131
+
132
+ return sakshi_output, word_detected_img, word_count, pred_image
133
 
134
  # Gradio Interface
135
+ interface = gr.Interface(
136
+ fn=process_image,
137
+ inputs=gr.Image(type="pil", label="Upload an Image"),
138
+ outputs=[
139
+ gr.Textbox(label="Sakshi OCR Output"),
140
+ gr.Image(label="Word Detection", type="numpy"),
141
+ gr.Number(label="Word Count"),
142
+ gr.Image(label="Hindi OCR Prediction", type="numpy")
143
+ ],
144
+ title="Hindi OCR App by Sakshi",
145
+ description="Upload an image to perform Hindi OCR and word detection."
146
+ )
147
 
148
+ # Launch the app
149
+ interface.launch()