sedtha committed on
Commit
fa685f9
·
verified ·
1 Parent(s): 0b4a5cb

Upload 3 files

Browse files
Files changed (3) hide show
  1. khmer_model_weights.pth +3 -0
  2. main.py +394 -131
  3. requirements.txt +11 -6
khmer_model_weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06d5efe7ca467186f9e4207d99d370bc27721d77504e2a973253e29170e9e309
3
+ size 4007093
main.py CHANGED
@@ -1,131 +1,394 @@
1
- # main.py
2
-
3
- import gradio as gr
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- from PIL import Image
8
- import numpy as np
9
-
10
- # -----------------------------
11
- # 1. Define the model class
12
- # -----------------------------
13
- class MyModel(nn.Module):
14
- def __init__(self, num_classes=10):
15
- super().__init__()
16
- self.fc1 = nn.Linear(48*48, 392)
17
- self.fc2 = nn.Linear(392, 196)
18
- self.fc3 = nn.Linear(196, 98)
19
- self.fc4 = nn.Linear(98, num_classes)
20
- self.relu = nn.ReLU()
21
-
22
- def forward(self, x):
23
- x = self.fc1(x)
24
- x = self.relu(x)
25
- x = self.fc2(x)
26
- x = self.relu(x)
27
- x = self.fc3(x)
28
- x = self.relu(x)
29
- x = self.fc4(x)
30
- return x
31
-
32
- # -----------------------------
33
- # 2. Manual label mapping
34
- # -----------------------------
35
- label_to_idx = {
36
- 'TA': 0, # ត
37
- 'NGO': 1, # αž„
38
- 'CHA': 2, # αž…
39
- 'DA': 3, # ដ
40
- 'KO': 4, # αž€
41
- 'NA': 5, # ណ
42
- 'KHA': 6, # ខ
43
- 'CHHA': 7, # αž†
44
- 'CHHO': 8, # ឈ
45
- 'KHO': 9 # αžƒ
46
- }
47
-
48
- idx_to_label = {v: k for k, v in label_to_idx.items()}
49
-
50
- label_to_char = {
51
- 'TA': 'ត',
52
- 'NGO': 'αž„',
53
- 'CHA': 'αž…',
54
- 'DA': 'ដ',
55
- 'KO': 'αž€',
56
- 'NA': 'ណ',
57
- 'KHA': 'ខ',
58
- 'CHHA': 'αž†',
59
- 'CHHO': 'ឈ',
60
- 'KHO': 'αžƒ'
61
- }
62
-
63
- num_classes = len(label_to_idx)
64
-
65
- # -----------------------------
66
- # 3. Load model
67
- # -----------------------------
68
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
69
- model = MyModel(num_classes=num_classes)
70
- model.load_state_dict(torch.load(r"sedtha/khmerhandwriting", map_location=device))
71
- model.eval()
72
- model.to(device)
73
-
74
- # -----------------------------
75
- # 4. Preprocess image
76
- # -----------------------------
77
- def preprocess_image(img: Image.Image):
78
- img = img.convert("L").resize((48,48))
79
- img_array = np.array(img, dtype=np.float32)
80
- img_array = img_array.reshape(1, -1) # flatten
81
- img_array /= 255.0 # normalize
82
- tensor = torch.tensor(img_array).to(device)
83
- return tensor
84
-
85
- # -----------------------------
86
- # 5. Prediction functions
87
- # -----------------------------
88
- def predict_image(img: Image.Image):
89
- tensor = preprocess_image(img)
90
- with torch.no_grad():
91
- output = model(tensor)
92
- probs = F.softmax(output, dim=1)
93
- pred_idx = torch.argmax(probs, dim=1).item()
94
- confidence = probs[0, pred_idx].item()
95
- pred_label = idx_to_label[pred_idx]
96
- pred_char = label_to_char[pred_label]
97
- return f"Predicted: {pred_char} ({pred_label}), Confidence: {confidence*100:.2f}%"
98
-
99
- def predict_draw(image_array: np.ndarray):
100
- if image_array.shape[-1] == 3:
101
- img = Image.fromarray(image_array).convert("L")
102
- else:
103
- img = Image.fromarray(image_array.squeeze()).convert("L")
104
- return predict_image(img)
105
-
106
- # -----------------------------
107
- # 6. Gradio interface
108
- # -----------------------------
109
- def main():
110
- with gr.Blocks() as demo:
111
- gr.Markdown("## Khmer Character Recognition")
112
-
113
- with gr.Tab("Upload Image"):
114
- img_input = gr.Image(type="pil")
115
- img_output = gr.Textbox()
116
- btn = gr.Button("Predict")
117
- btn.click(predict_image, inputs=img_input, outputs=img_output)
118
-
119
- with gr.Tab("Draw Letter"):
120
- canvas_input = gr.Image(shape=(48,48), image_mode='L', invert_colors=True, source="canvas")
121
- draw_output = gr.Textbox()
122
- draw_btn = gr.Button("Predict Drawing")
123
- draw_btn.click(predict_draw, inputs=canvas_input, outputs=draw_output)
124
-
125
- demo.launch(share=True)
126
-
127
- # -----------------------------
128
- # 7. Run app
129
- # -----------------------------
130
- if __name__ == "__main__":
131
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Khmer Character Recognition App
3
+ Recognizes 10 Khmer characters using a neural network model
4
+ """
5
+
6
+ import gradio as gr
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from PIL import Image
11
+ import numpy as np
12
+ from pathlib import Path
13
+ import logging
14
+
15
+ # Setup logging
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ # -----------------------------
21
+ # Model Definition
22
+ # -----------------------------
23
class KhmerModel(nn.Module):
    """Fully connected classifier for flattened 48x48 Khmer character images.

    The network is 2304 -> 392 -> 196 -> 98 -> ``num_classes`` with ReLU after
    the first three linear layers and dropout (p=0.2) after the first two.
    """

    def __init__(self, num_classes=10):
        """Build the layers; ``num_classes`` sets the width of the output layer."""
        super().__init__()
        # Attribute names and creation order are kept stable: they determine
        # the state_dict keys of saved checkpoints and the seeded init order.
        self.fc1 = nn.Linear(48 * 48, 392)
        self.fc2 = nn.Linear(392, 196)
        self.fc3 = nn.Linear(196, 98)
        self.fc4 = nn.Linear(98, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        """Return raw class logits for a batch of flattened images."""
        hidden = self.dropout(self.relu(self.fc1(x)))
        hidden = self.dropout(self.relu(self.fc2(hidden)))
        hidden = self.relu(self.fc3(hidden))
        logits = self.fc4(hidden)
        return logits
43
+
44
+
45
+ # -----------------------------
46
+ # Configuration
47
+ # -----------------------------
48
class Config:
    """Static application configuration: model settings and label tables."""

    # Model settings
    IMAGE_SIZE = (48, 48)  # (width, height) passed to PIL's resize()
    MODEL_PATH = "khmer_model_weights.pth"

    # Romanized label -> class index produced by the network.
    LABEL_TO_IDX = {
        'TA': 0,    # ត
        'NGO': 1,   # ង
        'CHA': 2,   # ច
        'DA': 3,    # ដ
        'KO': 4,    # ក
        'NA': 5,    # ណ
        'KHA': 6,   # ខ
        'CHHA': 7,  # ឆ
        'CHHO': 8,  # ឈ
        'KHO': 9,   # ឃ
    }

    # Derived from the label table so the two can never drift apart.
    NUM_CLASSES = len(LABEL_TO_IDX)

    # Romanized label -> Khmer consonant shown to the user. The glyphs were
    # previously stored as mojibake (UTF-8 bytes decoded with a legacy code
    # page); they are restored here to the real code points U+1780..U+178F.
    LABEL_TO_CHAR = {
        'TA': 'ត',
        'NGO': 'ង',
        'CHA': 'ច',
        'DA': 'ដ',
        'KO': 'ក',
        'NA': 'ណ',
        'KHA': 'ខ',
        'CHHA': 'ឆ',
        'CHHO': 'ឈ',
        'KHO': 'ឃ',
    }

    @classmethod
    def get_idx_to_label(cls):
        """Return the inverse of ``LABEL_TO_IDX`` (class index -> label)."""
        return {v: k for k, v in cls.LABEL_TO_IDX.items()}
86
+
87
+
88
+ # -----------------------------
89
+ # Model Manager
90
+ # -----------------------------
91
class ModelManager:
    """Own the classifier instance and run single-image inference on the CPU.

    ``load_model()`` must succeed once before ``predict()`` is called; until
    then ``self.model`` stays ``None``.
    """

    def __init__(self):
        self.device = torch.device("cpu")  # Force CPU usage
        self.model = None
        self.config = Config()
        self.idx_to_label = self.config.get_idx_to_label()

    def load_model(self):
        """Load the trained weights from ``Config.MODEL_PATH``.

        Raises:
            FileNotFoundError: if the weights file does not exist.
            Exception: any error raised while reading or applying the
                checkpoint; it is logged with its traceback and re-raised.
        """
        try:
            model_path = Path(self.config.MODEL_PATH)
            if not model_path.exists():
                raise FileNotFoundError(
                    f"Model file not found: {model_path}\n"
                    f"Please ensure '{self.config.MODEL_PATH}' is in the same directory as this script."
                )

            # weights_only=True restricts unpickling to tensors and plain
            # containers, so a tampered checkpoint cannot execute code.
            state_dict = torch.load(
                model_path, map_location=self.device, weights_only=True
            )
            model = KhmerModel(num_classes=self.config.NUM_CLASSES)
            model.load_state_dict(state_dict)
            model.eval()
            model.to(self.device)
        except Exception as e:
            logger.exception("Error loading model: %s", e)
            raise

        # Publish the model only once it is fully loaded; a failed load must
        # not leave a randomly initialised network behind for predict().
        self.model = model
        logger.info("Model loaded successfully from %s", model_path)

    def preprocess_image(self, img: Image.Image) -> torch.Tensor:
        """Turn a PIL image into a normalized ``(1, width*height)`` float tensor."""
        # Grayscale + resize to the size given by Config.IMAGE_SIZE.
        gray = img.convert("L").resize(self.config.IMAGE_SIZE)

        # Flatten to a single row and scale pixel values to [0, 1].
        pixels = np.array(gray, dtype=np.float32).reshape(1, -1)
        pixels /= 255.0

        return torch.tensor(pixels, dtype=torch.float32).to(self.device)

    def _describe(self, class_idx: int, confidence: float) -> dict:
        """Build the ``{'char', 'label', 'confidence'}`` entry for one class."""
        label = self.idx_to_label[class_idx]
        return {
            'char': self.config.LABEL_TO_CHAR[label],
            'label': label,
            'confidence': confidence,
        }

    def predict(self, img: Image.Image) -> dict:
        """Classify ``img`` and return the best guess plus the top-3 ranking.

        Returns:
            dict with keys ``predicted_char``, ``predicted_label``,
            ``confidence`` (softmax probability of the best class) and
            ``top3`` (list of ``{'char', 'label', 'confidence'}`` dicts,
            highest probability first).

        Raises:
            RuntimeError: if ``load_model()`` has not completed.
            Exception: any preprocessing/inference error, logged and re-raised.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        try:
            tensor = self.preprocess_image(img)

            with torch.no_grad():
                output = self.model(tensor)
                probs = F.softmax(output, dim=1)
                pred_idx = torch.argmax(probs, dim=1).item()
                confidence = probs[0, pred_idx].item()

            best = self._describe(pred_idx, confidence)

            top3_probs, top3_indices = torch.topk(
                probs[0], k=min(3, self.config.NUM_CLASSES)
            )
            top3_predictions = [
                self._describe(idx, prob)
                for prob, idx in zip(top3_probs.tolist(), top3_indices.tolist())
            ]

            return {
                'predicted_char': best['char'],
                'predicted_label': best['label'],
                'confidence': confidence,
                'top3': top3_predictions,
            }

        except Exception as e:
            logger.exception("Prediction error: %s", e)
            raise
178
+
179
+
180
+ # -----------------------------
181
+ # Gradio Interface Functions
182
+ # -----------------------------
183
+ model_manager = ModelManager()
184
+
185
def format_prediction_output(result: dict) -> str:
    """Render a prediction dict as the Markdown shown in the result panel.

    ``result`` must provide ``predicted_char``, ``predicted_label``,
    ``confidence`` and ``top3`` (a list of dicts with ``char``, ``label`` and
    ``confidence``); confidences are printed as percentages with two decimals.
    """
    sections = [
        f"## Predicted Character: {result['predicted_char']}\n\n",
        f"**Romanization:** {result['predicted_label']}\n\n",
        f"**Confidence:** {result['confidence']*100:.2f}%\n\n",
        "### Top 3 Predictions:\n",
    ]
    sections.extend(
        f"{rank}. {entry['char']} ({entry['label']}) - {entry['confidence']*100:.2f}%\n"
        for rank, entry in enumerate(result['top3'], 1)
    )
    return "".join(sections)
196
+
197
+
198
def predict_uploaded_image(img):
    """Gradio callback for the upload tab.

    Args:
        img: the uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        Markdown text: the formatted prediction, or a message starting with a
        cross mark when no image was given or prediction failed. Errors are
        reported as text rather than raised so the UI always shows something.
    """
    if img is None:
        # The cross-mark emoji (U+274C) was previously stored as mojibake.
        return "❌ Please upload an image first!"

    try:
        result = model_manager.predict(img)
        return format_prediction_output(result)
    except Exception as e:
        return f"❌ Error during prediction: {str(e)}"
208
+
209
+
210
def predict_drawn_image(image_array):
    """Gradio callback for the drawing tab.

    Args:
        image_array: the drawn image as a NumPy array (grayscale, RGB or
            RGBA), ``None`` for an empty canvas, or the dict payload emitted
            by Gradio 4 editor components, whose ``"composite"`` entry holds
            the flattened drawing.

    Returns:
        Markdown text: the formatted prediction, or a message starting with a
        cross mark when nothing was drawn or prediction failed.
    """
    # Gradio 4 sketch/editor components wrap the drawing in a dict.
    if isinstance(image_array, dict):
        image_array = image_array.get("composite")

    if image_array is None:
        # The cross-mark emoji (U+274C) was previously stored as mojibake.
        return "❌ Please draw a character first!"

    try:
        pixels = np.asarray(image_array)
        # Drop the alpha channel of RGBA input; PIL then converts to grayscale.
        if pixels.ndim == 3 and pixels.shape[-1] == 4:
            pixels = pixels[:, :, :3]
        img = Image.fromarray(pixels.astype('uint8')).convert("L")

        result = model_manager.predict(img)
        return format_prediction_output(result)
    except Exception as e:
        return f"❌ Error during prediction: {str(e)}"
229
+
230
+
231
def clear_canvas():
    """Return ``None``, the value Gradio interprets as an empty canvas."""
    emptied = None
    return emptied
234
+
235
+
236
+ # -----------------------------
237
+ # Gradio App
238
+ # -----------------------------
239
def create_app():
    """Build the Gradio Blocks UI (upload, draw and about tabs) and return it.

    The drawing tab uses ``gr.Sketchpad``. The previous
    ``gr.Image(source="canvas", tool="sketch", invert_colors=True, ...)`` call
    used Gradio 3 keyword arguments that Gradio 4 (which provides ``gr.Brush``)
    no longer accepts, so building the UI raised ``TypeError``.

    Returns:
        The configured ``gr.Blocks`` instance; the caller launches it.
    """

    # Custom CSS for better styling
    custom_css = """
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    .character-display {
        font-size: 72px;
        text-align: center;
        padding: 20px;
    }
    """

    canvas_px = 400

    def _blank_canvas():
        """Return a black background so white strokes are visible (white on black)."""
        return np.zeros((canvas_px, canvas_px, 3), dtype=np.uint8)

    def _predict_from_canvas(editor_value):
        """Unwrap the editor payload and delegate to ``predict_drawn_image``."""
        # Sketchpad emits {"background", "layers", "composite"}; the composite
        # is the flattened drawing the classifier should see.
        if isinstance(editor_value, dict):
            editor_value = editor_value.get("composite")
        return predict_drawn_image(editor_value)

    with gr.Blocks(css=custom_css, title="Khmer Character Recognition") as demo:
        gr.Markdown(
            """
            # 🔤 Khmer Character Recognition

            This app recognizes 10 Khmer consonants using a neural network model.

            **Supported Characters:**
            - ត (TA), ង (NGO), ច (CHA), ដ (DA), ក (KO)
            - ណ (NA), ខ (KHA), ឆ (CHHA), ឈ (CHHO), ឃ (KHO)
            """
        )

        with gr.Tab("📤 Upload Image"):
            gr.Markdown("Upload an image of a Khmer character for recognition.")

            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(
                        type="pil",
                        label="Upload Image",
                        height=300
                    )
                    img_btn = gr.Button("🔍 Predict", variant="primary", size="lg")

                with gr.Column():
                    img_output = gr.Markdown(label="Prediction Result")

            img_btn.click(
                fn=predict_uploaded_image,
                inputs=img_input,
                outputs=img_output
            )

        with gr.Tab("✏️ Draw Character"):
            gr.Markdown(
                """
                Draw a Khmer character on the canvas below.

                **Tips:**
                - Use a thick brush stroke
                - Draw the character as clearly as possible
                - Try to center the character
                """
            )

            with gr.Row():
                with gr.Column():
                    canvas_input = gr.Sketchpad(
                        value=_blank_canvas(),  # White on black
                        type="numpy",
                        label="Draw Here",
                        height=canvas_px,
                        width=canvas_px,
                        brush=gr.Brush(
                            default_size=8,
                            colors=["#FFFFFF"],
                            default_color="#FFFFFF"
                        )
                    )
                    with gr.Row():
                        draw_btn = gr.Button("🔍 Predict", variant="primary", size="lg")
                        clear_btn = gr.Button("🗑️ Clear", size="lg")

                with gr.Column():
                    draw_output = gr.Markdown(label="Prediction Result")

            draw_btn.click(
                fn=_predict_from_canvas,
                inputs=canvas_input,
                outputs=draw_output
            )

            # Clearing restores the black background rather than an empty
            # value, otherwise white strokes would be drawn on a blank canvas.
            clear_btn.click(
                fn=_blank_canvas,
                outputs=canvas_input
            )

        with gr.Tab("ℹ️ About"):
            gr.Markdown(
                """
                ## About This App

                This application uses a neural network trained to recognize 10 Khmer consonants.

                ### Model Architecture
                - Input: 48x48 grayscale images
                - 4-layer fully connected neural network
                - Trained on handwritten Khmer characters

                ### How to Use
                1. **Upload Image Tab**: Upload a photo or screenshot of a Khmer character
                2. **Draw Character Tab**: Draw a character directly on the canvas
                3. Click "Predict" to see the results

                ### Tips for Best Results
                - Use clear, well-formed characters
                - Ensure good contrast (dark character on light background or vice versa)
                - Center the character in the image
                - Avoid cluttered backgrounds

                ### Technical Details
                - Framework: PyTorch
                - Interface: Gradio
                - Inference: CPU-only (no GPU required)
                """
            )

    return demo
365
+
366
+
367
+ # -----------------------------
368
+ # Main Execution
369
+ # -----------------------------
370
def main():
    """Load the model, build the Gradio UI and serve it.

    Any startup failure is logged and re-raised so the process exits with the
    original exception.
    """
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }

    try:
        # Step 1: weights must be in memory before the UI accepts requests.
        logger.info("Loading model...")
        model_manager.load_model()
        logger.info("Model loaded successfully!")

        # Step 2: build the interface and start serving.
        logger.info("Starting Gradio interface...")
        app = create_app()
        app.launch(**launch_options)

    except Exception as e:
        logger.error(f"Failed to start application: {e}")
        raise
391
+
392
+
393
+ if __name__ == "__main__":
394
+ main()
requirements.txt CHANGED
@@ -1,6 +1,11 @@
1
- torch
2
- torchvision
3
- numpy
4
- pillow
5
- gradio
6
- scikit-learn
 
 
 
 
 
 
1
+ # Core dependencies
2
+ torch==2.1.0
3
+ torchvision==0.16.0
4
+ gradio==4.44.0
5
+
6
+ # Image processing
7
+ Pillow==10.1.0
8
+ numpy==1.24.3
9
+
10
+ # Extra package index serving CPU-only PyTorch wheels (smaller install; no CUDA runtime needed)
11
+ --extra-index-url https://download.pytorch.org/whl/cpu