kbsss commited on
Commit
c10f086
·
verified ·
1 Parent(s): 1787ace

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. Dockerfile +27 -0
  2. app.py +213 -0
  3. label_encoder.pkl +3 -0
  4. model.h5 +3 -0
  5. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system dependencies for OpenCV
6
+ RUN apt-get update && apt-get install -y \
7
+ libgl1 \
8
+ libglib2.0-0 \
9
+ libsm6 \
10
+ libxext6 \
11
+ libxrender-dev \
12
+ && rm -rf /var/lib/apt/lists/*
13
+
14
+ # Copy requirements and install
15
+ COPY requirements.txt .
16
+ RUN pip install --no-cache-dir -r requirements.txt
17
+
18
+ # Copy app files
19
+ COPY app.py .
20
+ COPY model.h5 .
21
+ COPY label_encoder.pkl .
22
+
23
+ # Expose Gradio port
24
+ EXPOSE 7860
25
+
26
+ # Run the app
27
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Handwritten Equation Solver - API
3
+ """
4
+ import os
5
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
6
+
7
+ from fastapi import FastAPI, UploadFile, File
8
+ from fastapi.middleware.cors import CORSMiddleware
9
+ from fastapi.responses import JSONResponse
10
+ import cv2
11
+ import numpy as np
12
+ import re
13
+ from imutils.contours import sort_contours
14
+ import imutils
15
+ import base64
16
+ from io import BytesIO
17
+ from PIL import Image
18
+ import tensorflow as tf
19
+
20
+ tf.get_logger().setLevel('ERROR')
21
+
22
+ app = FastAPI(title="Equation Solver API")
23
+
24
+ # Enable CORS for frontend
25
+ app.add_middleware(
26
+ CORSMiddleware,
27
+ allow_origins=["*"],
28
+ allow_credentials=True,
29
+ allow_methods=["*"],
30
+ allow_headers=["*"],
31
+ )
32
+
33
+ # Load model at startup
34
+ print("Loading model...")
35
+ model = tf.keras.models.load_model('model.h5', compile=False)
36
+ print("Model loaded!")
37
+
38
+ # Label mapping
39
+ CLASSES = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "add", "div", "mul", "sub"]
40
+ SYMBOL_MAP = {'add': '+', 'sub': '-', 'mul': '×', 'div': '÷'}
41
+
42
+
43
+ def preprocess_symbol(image):
44
+ if len(image.shape) == 3:
45
+ img_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
46
+ else:
47
+ img_gray = image.copy()
48
+
49
+ threshold_img = cv2.threshold(img_gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)[1]
50
+ threshold_img = cv2.resize(threshold_img, (32, 32))
51
+ threshold_img = threshold_img / 255.0
52
+ threshold_img = np.expand_dims(threshold_img, axis=-1)
53
+
54
+ return threshold_img
55
+
56
+
57
+ def segment_equation(image):
58
+ if len(image.shape) == 3:
59
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
60
+ else:
61
+ gray = image.copy()
62
+
63
+ binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)[1]
64
+
65
+ cnts = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
66
+ cnts = imutils.grab_contours(cnts)
67
+
68
+ if cnts:
69
+ cnts = sort_contours(cnts, method="left-to-right")[0]
70
+
71
+ symbols = []
72
+ boxes = []
73
+
74
+ for c in cnts:
75
+ (x, y, w, h) = cv2.boundingRect(c)
76
+
77
+ if w < 10 or h < 10:
78
+ continue
79
+
80
+ padding = 5
81
+ y_start = max(0, y - padding)
82
+ y_end = min(image.shape[0], y + h + padding)
83
+ x_start = max(0, x - padding)
84
+ x_end = min(image.shape[1], x + w + padding)
85
+
86
+ symbol_img = gray[y_start:y_end, x_start:x_end]
87
+
88
+ boxes.append({"x": int(x), "y": int(y), "w": int(w), "h": int(h)})
89
+ symbols.append(symbol_img)
90
+
91
+ return boxes, symbols
92
+
93
+
94
+ def correct_symbol_by_geometry(symbol, box):
95
+ if symbol not in ['+', '-']:
96
+ return symbol
97
+
98
+ w = box["w"]
99
+ h = box["h"]
100
+ if h == 0:
101
+ return symbol
102
+
103
+ aspect_ratio = w / h
104
+
105
+ if aspect_ratio > 1.5:
106
+ return '-'
107
+ elif aspect_ratio < 1.2:
108
+ return '+'
109
+
110
+ return symbol
111
+
112
+
113
+ def solve_equation(equation_str):
114
+ try:
115
+ eq = equation_str.replace('×', '*').replace('÷', '/').replace(' ', '')
116
+ eq = eq.split('=')[0].replace('?', '')
117
+
118
+ if not re.match(r'^[\d\+\-\*/\(\)\.\s]+$', eq):
119
+ return None, "Invalid equation format"
120
+
121
+ result = eval(eq)
122
+
123
+ if isinstance(result, float) and result.is_integer():
124
+ result = int(result)
125
+
126
+ return result, None
127
+ except Exception as e:
128
+ return None, str(e)
129
+
130
+
131
+ def process_image(image_array):
132
+ if len(image_array.shape) == 3 and image_array.shape[2] == 3:
133
+ img_cv = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)
134
+ else:
135
+ img_cv = image_array
136
+
137
+ boxes, symbol_images = segment_equation(img_cv)
138
+
139
+ if not symbol_images:
140
+ return {"error": "No symbols detected in image"}
141
+
142
+ processed = [preprocess_symbol(s) for s in symbol_images]
143
+ X = np.array(processed)
144
+
145
+ predictions = model.predict(X, verbose=0)
146
+ predicted_indices = np.argmax(predictions, axis=1)
147
+
148
+ symbols = []
149
+ for i, idx in enumerate(predicted_indices):
150
+ label = CLASSES[idx]
151
+ symbol = SYMBOL_MAP.get(label, label)
152
+ if i < len(boxes):
153
+ symbol = correct_symbol_by_geometry(symbol, boxes[i])
154
+ symbols.append(symbol)
155
+
156
+ equation_str = ''.join(symbols)
157
+ result, error = solve_equation(equation_str)
158
+
159
+ return {
160
+ "equation": equation_str,
161
+ "result": result,
162
+ "symbols_count": len(symbols),
163
+ "boxes": boxes,
164
+ "error": error
165
+ }
166
+
167
+
168
+ @app.get("/")
169
+ async def root():
170
+ return {"status": "ok", "message": "Equation Solver API"}
171
+
172
+
173
+ @app.post("/api/predict")
174
+ async def predict(file: UploadFile = File(...)):
175
+ try:
176
+ contents = await file.read()
177
+ image = Image.open(BytesIO(contents))
178
+ image_array = np.array(image)
179
+
180
+ result = process_image(image_array)
181
+ return JSONResponse(content={"data": [result]})
182
+ except Exception as e:
183
+ return JSONResponse(content={"error": str(e)}, status_code=500)
184
+
185
+
186
+ @app.post("/predict")
187
+ async def predict_json(data: dict):
188
+ """Handle Gradio-style base64 image input"""
189
+ try:
190
+ if "data" not in data or not data["data"]:
191
+ return JSONResponse(content={"error": "No data provided"}, status_code=400)
192
+
193
+ image_data = data["data"][0]
194
+
195
+ # Handle base64 encoded image
196
+ if isinstance(image_data, str) and image_data.startswith("data:"):
197
+ # Remove data URL prefix
198
+ base64_str = image_data.split(",")[1]
199
+ image_bytes = base64.b64decode(base64_str)
200
+ image = Image.open(BytesIO(image_bytes))
201
+ image_array = np.array(image)
202
+ else:
203
+ return JSONResponse(content={"error": "Invalid image format"}, status_code=400)
204
+
205
+ result = process_image(image_array)
206
+ return JSONResponse(content={"data": [result]})
207
+ except Exception as e:
208
+ return JSONResponse(content={"error": str(e)}, status_code=500)
209
+
210
+
211
+ if __name__ == "__main__":
212
+ import uvicorn
213
+ uvicorn.run(app, host="0.0.0.0", port=7860)
label_encoder.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf5c1c0438546069cfc3fd908c87c38c8790a13dbd00de9ed51e59f191aad4c4
3
+ size 495
model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:61e1007c0d0c2992a1aa737d72697865cedef92639a9ef9e1b04d6ecedda61fd
3
+ size 2009672
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ tensorflow==2.15.0
2
+ opencv-python-headless==4.8.1.78
3
+ numpy==1.24.3
4
+ imutils==0.5.4
5
+ scikit-learn==1.3.2
6
+ fastapi==0.109.0
7
+ uvicorn==0.27.0
8
+ python-multipart==0.0.6
9
+ pillow==10.2.0