sameernotes commited on
Commit
214c905
·
verified ·
1 Parent(s): fc14550

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +156 -107
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,66 +1,101 @@
1
- import os
2
- import io
3
- import sys
4
  import cv2
5
- import base64
6
- import pickle
7
  import numpy as np
8
  import tensorflow as tf
 
9
  import matplotlib.pyplot as plt
10
  import matplotlib.font_manager as fm
11
- import tempfile
12
  import sakshi_ocr
 
 
 
 
 
 
 
 
 
13
 
14
- from fastapi import FastAPI, File, UploadFile, HTTPException
15
- from fastapi.responses import HTMLResponse, JSONResponse
 
 
 
 
 
 
 
 
16
 
17
- # Define paths to your assets (update these if necessary)
18
- MODEL_PATH = 'hindi_ocr_model.keras'
19
- ENCODER_PATH = 'label_encoder.pkl'
20
- FONT_PATH = 'NotoSansDevanagari-Regular.ttf'
 
21
 
22
- # Load custom font if available
23
- if os.path.exists(FONT_PATH):
24
- fm.fontManager.addfont(FONT_PATH)
25
- plt.rcParams['font.family'] = 'Noto Sans Devanagari'
26
- else:
27
- print("Custom font not found. Using default font.")
28
 
29
- # Load the OCR model
 
 
 
 
 
 
30
  def load_model():
31
  if not os.path.exists(MODEL_PATH):
32
- raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")
33
  return tf.keras.models.load_model(MODEL_PATH)
34
 
35
- # Load the label encoder
36
  def load_label_encoder():
37
  if not os.path.exists(ENCODER_PATH):
38
- raise FileNotFoundError(f"Label encoder file not found at {ENCODER_PATH}")
39
  with open(ENCODER_PATH, 'rb') as f:
40
  return pickle.load(f)
41
 
42
- # Global loading so they persist across requests
43
- model = load_model()
44
- label_encoder = load_label_encoder()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- # Function for word detection
47
  def detect_words(image):
48
- # Assume input is a grayscale image
49
  _, binary = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
50
- kernel = np.ones((3, 3), np.uint8)
51
  dilated = cv2.dilate(binary, kernel, iterations=2)
52
  contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
53
 
54
  word_img = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
55
  word_count = 0
 
56
  for contour in contours:
57
  x, y, w, h = cv2.boundingRect(contour)
58
  if w > 10 and h > 10:
59
  cv2.rectangle(word_img, (x, y), (x+w, y+h), (0, 255, 0), 2)
60
  word_count += 1
 
61
  return word_img, word_count
62
 
63
- # Function to run Sakshi OCR and capture its output
64
  def run_sakshi_ocr(image_path):
65
  buffer = io.StringIO()
66
  old_stdout = sys.stdout
@@ -71,96 +106,110 @@ def run_sakshi_ocr(image_path):
71
  sys.stdout = old_stdout
72
  return buffer.getvalue()
73
 
74
- # Utility function: convert image (numpy array) to a base64 encoded string
75
- def image_to_base64(image, ext=".png"):
76
- success, encoded_image = cv2.imencode(ext, image)
77
- if not success:
78
- return None
79
- return base64.b64encode(encoded_image).decode('utf-8')
80
-
81
- # Initialize FastAPI app
82
- app = FastAPI(title="Hindi OCR App by sakshi")
83
-
84
- @app.get("/", response_class=HTMLResponse)
85
- async def root():
86
- html_content = """
87
- <html>
88
- <head>
89
- <title>Hindi OCR App by sakshi</title>
90
- </head>
91
- <body>
92
- <h1>Hindi OCR App by sakshi</h1>
93
- <form action="/predict" enctype="multipart/form-data" method="post">
94
- <input name="file" type="file" accept="image/*">
95
- <input type="submit" value="Upload and Predict">
96
- </form>
97
- </body>
98
- </html>
99
- """
100
- return HTMLResponse(content=html_content)
101
-
102
- @app.post("/predict")
103
- async def predict(file: UploadFile = File(...)):
104
- # Read and decode the uploaded image
105
- contents = await file.read()
106
- nparr = np.frombuffer(contents, np.uint8)
107
- img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
108
- if img is None:
109
- raise HTTPException(status_code=400, detail="Error reading the image.")
110
-
111
- # Encode the original image to base64 for visualization
112
- original_image = image_to_base64(cv2.cvtColor(img, cv2.COLOR_GRAY2BGR))
113
 
114
  # Word detection
115
- word_img, word_count = detect_words(img)
116
- word_img_encoded = image_to_base64(word_img)
 
117
 
118
- # OCR model prediction for single word
119
  try:
120
  img_resized = cv2.resize(img, (128, 32))
121
  img_norm = img_resized / 255.0
122
- img_input = img_norm[np.newaxis, ..., np.newaxis] # shape: (1, 32, 128, 1)
123
- pred = model.predict(img_input)
124
- pred_label_idx = np.argmax(pred)
125
- pred_label = label_encoder.inverse_transform([pred_label_idx])[0]
126
 
127
- # Generate an image with the prediction using matplotlib
128
- fig, ax = plt.subplots()
129
- ax.imshow(img, cmap='gray')
130
- ax.set_title(f"Predicted: {pred_label}", fontsize=12)
131
- ax.axis('off')
132
- buf = io.BytesIO()
133
- plt.savefig(buf, format="png")
134
- buf.seek(0)
135
- pred_img_array = np.frombuffer(buf.getvalue(), np.uint8)
136
- prediction_img = cv2.imdecode(pred_img_array, cv2.IMREAD_COLOR)
137
- prediction_img_encoded = image_to_base64(prediction_img)
138
- plt.close(fig)
 
 
 
 
139
  except Exception as e:
140
- raise HTTPException(status_code=500, detail=f"Error in OCR model processing: {e}")
 
141
 
142
- # Run Sakshi OCR on the image by saving temporarily
143
- try:
144
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
145
- cv2.imwrite(tmp_file.name, img)
146
- tmp_file_path = tmp_file.name
147
- sakshi_output = run_sakshi_ocr(tmp_file_path)
148
- os.remove(tmp_file_path)
149
- except Exception as e:
150
- sakshi_output = f"Error running Sakshi OCR: {e}"
151
 
152
- # Prepare the response
153
- response_data = {
 
154
  "word_count": word_count,
155
- "ocr_prediction": pred_label,
156
- "sakshi_ocr_output": sakshi_output,
157
- "original_image": original_image,
158
- "word_detected_image": word_img_encoded,
159
- "prediction_image": prediction_img_encoded
160
  }
 
 
 
 
 
 
 
 
 
 
 
161
 
162
- return JSONResponse(content=response_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
 
164
  if __name__ == "__main__":
165
- import uvicorn
166
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.responses import FileResponse, JSONResponse
3
+ from pydantic import BaseModel
4
  import cv2
 
 
5
  import numpy as np
6
  import tensorflow as tf
7
+ import pickle
8
  import matplotlib.pyplot as plt
9
  import matplotlib.font_manager as fm
 
10
  import sakshi_ocr
11
+ import os
12
+ import io
13
+ import sys
14
+ import tempfile
15
+ import requests
16
+ from PIL import Image
17
+ import uvicorn
18
+ import shutil
19
+ from pathlib import Path
20
 
21
+ app = FastAPI(
22
+ title="Hindi OCR API",
23
+ description="API for Hindi OCR and word detection",
24
+ version="1.0.0"
25
+ )
26
+
27
+ # URLs for the model and encoder hosted on Hugging Face
28
+ MODEL_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/hindi_ocr_model.keras"
29
+ ENCODER_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/label_encoder.pkl"
30
+ FONT_URL = "https://huggingface.co/sameernotes/hindi-ocr/resolve/main/NotoSansDevanagari-Regular.ttf"
31
 
32
+ # Paths for local storage
33
+ MODEL_PATH = "hindi_ocr_model.keras"
34
+ ENCODER_PATH = "label_encoder.pkl"
35
+ FONT_PATH = "NotoSansDevanagari-Regular.ttf"
36
+ OUTPUT_DIR = "output"
37
 
38
+ # Create output directory if it doesn't exist
39
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
 
 
 
 
40
 
41
+ # Download model and encoder
42
+ def download_file(url, dest):
43
+ response = requests.get(url)
44
+ with open(dest, 'wb') as f:
45
+ f.write(response.content)
46
+
47
+ # Load the model and encoder
48
  def load_model():
49
  if not os.path.exists(MODEL_PATH):
50
+ return None
51
  return tf.keras.models.load_model(MODEL_PATH)
52
 
 
53
  def load_label_encoder():
54
  if not os.path.exists(ENCODER_PATH):
55
+ return None
56
  with open(ENCODER_PATH, 'rb') as f:
57
  return pickle.load(f)
58
 
59
+ # Download required files on startup
60
+ @app.on_event("startup")
61
+ async def startup_event():
62
+ # Download models and font if not already present
63
+ if not os.path.exists(MODEL_PATH):
64
+ download_file(MODEL_URL, MODEL_PATH)
65
+ if not os.path.exists(ENCODER_PATH):
66
+ download_file(ENCODER_URL, ENCODER_PATH)
67
+ if not os.path.exists(FONT_PATH):
68
+ download_file(FONT_URL, FONT_PATH)
69
+
70
+ # Load the custom font if available
71
+ if os.path.exists(FONT_PATH):
72
+ fm.fontManager.addfont(FONT_PATH)
73
+ plt.rcParams['font.family'] = 'Noto Sans Devanagari'
74
+
75
+ # Initialize global variables
76
+ global model, label_encoder
77
+ model = load_model()
78
+ label_encoder = load_label_encoder()
79
 
80
+ # Word detection function
81
  def detect_words(image):
 
82
  _, binary = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
83
+ kernel = np.ones((3,3), np.uint8)
84
  dilated = cv2.dilate(binary, kernel, iterations=2)
85
  contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
86
 
87
  word_img = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
88
  word_count = 0
89
+
90
  for contour in contours:
91
  x, y, w, h = cv2.boundingRect(contour)
92
  if w > 10 and h > 10:
93
  cv2.rectangle(word_img, (x, y), (x+w, y+h), (0, 255, 0), 2)
94
  word_count += 1
95
+
96
  return word_img, word_count
97
 
98
+ # Sakshi OCR output capture
99
  def run_sakshi_ocr(image_path):
100
  buffer = io.StringIO()
101
  old_stdout = sys.stdout
 
106
  sys.stdout = old_stdout
107
  return buffer.getvalue()
108
 
109
+ # Main OCR processing function
110
+ def process_image(image_array):
111
+ # Convert image array to grayscale
112
+ img = cv2.cvtColor(image_array, cv2.COLOR_RGB2GRAY)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
  # Word detection
115
+ word_detected_img, word_count = detect_words(img)
116
+ word_detection_path = os.path.join(OUTPUT_DIR, "word_detection.png")
117
+ cv2.imwrite(word_detection_path, word_detected_img)
118
 
119
+ # First OCR model prediction
120
  try:
121
  img_resized = cv2.resize(img, (128, 32))
122
  img_norm = img_resized / 255.0
123
+ img_input = img_norm[np.newaxis, ..., np.newaxis] # Shape: (1, 32, 128, 1)
 
 
 
124
 
125
+ if model is not None and label_encoder is not None:
126
+ pred = model.predict(img_input)
127
+ pred_label_idx = np.argmax(pred)
128
+ pred_label = label_encoder.inverse_transform([pred_label_idx])[0]
129
+
130
+ # Create plot with prediction
131
+ fig, ax = plt.subplots()
132
+ ax.imshow(img, cmap='gray')
133
+ ax.set_title(f"Predicted: {pred_label}", fontsize=12)
134
+ ax.axis('off')
135
+ pred_path = os.path.join(OUTPUT_DIR, "prediction.png")
136
+ plt.savefig(pred_path)
137
+ plt.close()
138
+ else:
139
+ pred_path = None
140
+ pred_label = "Model or encoder not loaded"
141
  except Exception as e:
142
+ pred_path = None
143
+ pred_label = f"Error: {str(e)}"
144
 
145
+ # Sakshi OCR processing
146
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
147
+ cv2.imwrite(tmp_file.name, img)
148
+ sakshi_output = run_sakshi_ocr(tmp_file.name)
149
+ os.remove(tmp_file.name)
 
 
 
 
150
 
151
+ return {
152
+ "sakshi_output": sakshi_output,
153
+ "word_detection_path": word_detection_path,
154
  "word_count": word_count,
155
+ "prediction_path": pred_path,
156
+ "prediction_label": pred_label
 
 
 
157
  }
158
+
159
+ class OCRResponse(BaseModel):
160
+ sakshi_output: str
161
+ word_count: int
162
+ prediction_label: str
163
+
164
+ @app.post("/process/", response_model=OCRResponse)
165
+ async def process(file: UploadFile = File(...)):
166
+ # Check if the file is an image
167
+ if not file.content_type.startswith("image/"):
168
+ raise HTTPException(status_code=400, detail="File must be an image")
169
 
170
+ # Create a temporary file to save the uploaded image
171
+ temp_file = tempfile.NamedTemporaryFile(delete=False)
172
+ try:
173
+ # Save the uploaded file
174
+ with temp_file as f:
175
+ shutil.copyfileobj(file.file, f)
176
+
177
+ # Open and process the image
178
+ image = Image.open(temp_file.name)
179
+ image_array = np.array(image)
180
+ result = process_image(image_array)
181
+
182
+ return OCRResponse(
183
+ sakshi_output=result["sakshi_output"],
184
+ word_count=result["word_count"],
185
+ prediction_label=result["prediction_label"]
186
+ )
187
+ except Exception as e:
188
+ raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
189
+ finally:
190
+ # Clean up the temporary file
191
+ os.unlink(temp_file.name)
192
+
193
+ @app.get("/word-detection/")
194
+ async def get_word_detection():
195
+ """Return the word detection image."""
196
+ word_detection_path = Path(OUTPUT_DIR) / "word_detection.png"
197
+ if not word_detection_path.exists():
198
+ raise HTTPException(status_code=404, detail="Word detection image not found. Process an image first.")
199
+ return FileResponse(word_detection_path)
200
+
201
+ @app.get("/prediction/")
202
+ async def get_prediction():
203
+ """Return the prediction image."""
204
+ prediction_path = Path(OUTPUT_DIR) / "prediction.png"
205
+ if not prediction_path.exists():
206
+ raise HTTPException(status_code=404, detail="Prediction image not found. Process an image first.")
207
+ return FileResponse(prediction_path)
208
+
209
+ @app.get("/")
210
+ async def root():
211
+ return {"message": "Hindi OCR API is running. Use POST /process/ to analyze images."}
212
 
213
+ # For local testing
214
  if __name__ == "__main__":
215
+ uvicorn.run(app, host="0.0.0.0", port=8000)
 
requirements.txt CHANGED
@@ -7,4 +7,6 @@ opencv-python
7
  matplotlib
8
  scikit-learn
9
  python-multipart
10
- sakshi-ocr
 
 
 
7
  matplotlib
8
  scikit-learn
9
  python-multipart
10
+ sakshi-ocr
11
+ pydantic
12
+ requests