Spaces:
Sleeping
Sleeping
| """ | |
| Khmer Character Recognition App | |
| Recognizes 10 Khmer characters using a neural network model | |
| """ | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| import numpy as np | |
| from pathlib import Path | |
| import logging | |
| import os | |
# Configure the root logger at INFO level and obtain a logger named after this
# module for the messages emitted by the model manager and UI callbacks.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
| # ----------------------------- | |
| # Model Definition | |
| # ----------------------------- | |
class KhmerModel(nn.Module):
    """Fully connected classifier for flattened 48x48 grayscale Khmer glyphs.

    Layout: 2304 -> 392 -> 196 -> 98 -> ``num_classes``, with ReLU after each
    hidden layer and raw logits (no softmax) at the output.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        flat_pixels = 48 * 48
        # The attribute names fc1..fc4 are the keys load_state_dict() matches
        # against the saved weights file, so they must not be renamed.
        self.fc1 = nn.Linear(flat_pixels, 392)
        self.fc2 = nn.Linear(self.fc1.out_features, 196)
        self.fc3 = nn.Linear(self.fc2.out_features, 98)
        self.fc4 = nn.Linear(self.fc3.out_features, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return logits for ``x``, whose last dimension must be 48*48."""
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            x = self.relu(hidden_layer(x))
        return self.fc4(x)
| # ----------------------------- | |
| # Configuration | |
| # ----------------------------- | |
class Config:
    """Application configuration, exposed as class-level constants."""

    # Model settings
    IMAGE_SIZE = (48, 48)  # input size in pixels; 48*48 must equal KhmerModel's input features
    NUM_CLASSES = 10
    MODEL_PATH = "khmer_model_weights.pth"

    # Romanized label -> class index emitted by the model.
    LABEL_TO_IDX = {
        'CHA': 0,
        'CHHA': 1,
        'CHHO': 2,
        'DA': 3,
        'KHA': 4,
        'KHO': 5,
        'KO': 6,
        'NA': 7,
        'NGO': 8,
        'TA': 9,
    }

    # Romanized label -> Khmer glyph shown in the UI.
    # NOTE(review): in the previous revision every glyph had been mangled by a
    # bad text encoding into 'α' (and 'CHA' carried a stray trailing space).
    # The glyphs below were reconstructed from the romanized labels (-A labels
    # as first-series, -O labels as second-series consonants); confirm them
    # against the training data. \u escapes keep the mapping immune to
    # further encoding damage.
    LABEL_TO_CHAR = {
        'TA': '\u178F',    # ត
        'NGO': '\u1784',   # ង
        'CHA': '\u1785',   # ច
        'DA': '\u178A',    # ដ
        'KO': '\u1782',    # គ
        'NA': '\u178E',    # ណ
        'KHA': '\u1781',   # ខ
        'CHHA': '\u1786',  # ឆ
        'CHHO': '\u1788',  # ឈ
        'KHO': '\u1783',   # ឃ
    }

    @classmethod
    def get_idx_to_label(cls):
        """Return the inverse of ``LABEL_TO_IDX`` (class index -> label).

        Works both as ``Config.get_idx_to_label()`` and on an instance.
        """
        return {idx: label for label, idx in cls.LABEL_TO_IDX.items()}
| # ----------------------------- | |
| # Model Manager | |
| # ----------------------------- | |
class ModelManager:
    """Own the Khmer classifier: load its weights and run single-image inference.

    Inference is pinned to the CPU. ``load_model()`` must succeed before
    ``predict()`` is called.
    """

    def __init__(self):
        self.device = torch.device("cpu")  # Force CPU usage
        self.model = None  # only set by load_model() once weights load cleanly
        self.config = Config()
        self.idx_to_label = self.config.get_idx_to_label()

    def _resolve_model_path(self) -> Path:
        """Locate the weights file named by ``Config.MODEL_PATH``.

        A relative path is tried against the current working directory first
        (the historical behaviour), then against this script's directory,
        which is where the error message tells users to put the file.
        """
        model_path = Path(self.config.MODEL_PATH)
        if model_path.is_absolute() or model_path.exists():
            return model_path
        script = globals().get("__file__")
        if script is None:
            return model_path
        beside_script = Path(script).resolve().parent / model_path
        return beside_script if beside_script.exists() else model_path

    def load_model(self):
        """Load the trained weights into a fresh ``KhmerModel``.

        ``self.model`` is replaced only after the weights load successfully,
        so a failed load never leaves a randomly initialised model behind
        for ``predict()`` to use.

        Raises:
            FileNotFoundError: The weights file cannot be found.
            Exception: Any error from ``torch.load`` / ``load_state_dict`` is
                logged with its traceback and re-raised unchanged.
        """
        try:
            model_path = self._resolve_model_path()
            if not model_path.exists():
                raise FileNotFoundError(
                    f"Model file not found: {model_path}\n"
                    f"Please ensure '{self.config.MODEL_PATH}' is in the same directory as this script."
                )
            model = KhmerModel(num_classes=self.config.NUM_CLASSES)
            # weights_only=True restricts unpickling to tensors and plain
            # containers instead of arbitrary Python objects.
            state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
            model.load_state_dict(state_dict)
            model.to(self.device)
            model.eval()
        except Exception as e:
            logger.exception("Error loading model: %s", e)
            raise
        self.model = model
        logger.info("Model loaded successfully from %s", model_path)

    def preprocess_image(self, img: Image.Image) -> torch.Tensor:
        """Convert a PIL image into the flat float tensor the model expects.

        Steps: flatten any transparency onto white, convert to 8-bit
        grayscale, resize to ``Config.IMAGE_SIZE``, and flatten row-major to
        shape ``(1, width * height)``. Pixel values stay in the raw 0-255
        range (no normalisation), exactly as the original pipeline produced.
        """
        if img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info:
            # Fully transparent pixels carry arbitrary (often black) RGB values,
            # which convert("L") would keep. The app expects a dark character on
            # a light background, so composite onto opaque white first.
            rgba = img.convert("RGBA")
            white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            img = Image.alpha_composite(white, rgba)
        gray = img.convert("L").resize(self.config.IMAGE_SIZE)
        pixels = np.array(gray, dtype=np.float32)  # (height, width), fresh writable copy
        tensor = torch.from_numpy(pixels).reshape(1, -1)
        return tensor.to(self.device)

    def predict(self, img: Image.Image) -> dict:
        """Classify a single image.

        Returns:
            A dict with ``predicted_char``, ``predicted_label``,
            ``confidence`` (softmax probability of the predicted class) and
            ``top3``: a list of up to three ``{'char', 'label', 'confidence'}``
            dicts, highest probability first.

        Raises:
            RuntimeError: ``load_model()`` has not completed successfully.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        try:
            tensor = self.preprocess_image(img)
            with torch.no_grad():
                probs = F.softmax(self.model(tensor), dim=1)[0]

            pred_idx = int(torch.argmax(probs).item())
            pred_label = self.idx_to_label[pred_idx]

            top_k = min(3, self.config.NUM_CLASSES)
            top_probs, top_indices = torch.topk(probs, k=top_k)
            top3_predictions = []
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
                label = self.idx_to_label[idx]
                top3_predictions.append({
                    'char': self.config.LABEL_TO_CHAR[label],
                    'label': label,
                    'confidence': prob,
                })

            return {
                'predicted_char': self.config.LABEL_TO_CHAR[pred_label],
                'predicted_label': pred_label,
                'confidence': probs[pred_idx].item(),
                'top3': top3_predictions,
            }
        except Exception as e:
            logger.exception("Prediction error: %s", e)
            raise
| # ----------------------------- | |
| # Gradio Interface Functions | |
| # ----------------------------- | |
# Single shared inference manager used by the Gradio callbacks below.
# Construction does not load weights: only the __main__ block calls
# load_model(), so importing this module leaves model_manager.model as None.
model_manager = ModelManager()
def format_prediction_output(result: dict) -> str:
    """Render a ``ModelManager.predict()`` result as a Markdown report.

    Args:
        result: Dict with ``predicted_char``, ``predicted_label``,
            ``confidence`` (0-1) and ``top3`` (list of dicts with ``char``,
            ``label`` and ``confidence``).

    Returns:
        Markdown with a heading, romanization, confidence percentage and a
        numbered list of the top predictions; every list line ends in ``\\n``.
    """
    sections = [
        f"## Predicted Character: {result['predicted_char']}\n\n",
        f"**Romanization:** {result['predicted_label']}\n\n",
        f"**Confidence:** {result['confidence']*100:.2f}%\n\n",
        "### Top 3 Predictions:\n",
    ]
    sections.extend(
        f"{rank}. {entry['char']} ({entry['label']}) - {entry['confidence']*100:.2f}%\n"
        for rank, entry in enumerate(result['top3'], start=1)
    )
    return "".join(sections)
def predict_uploaded_image(img):
    """Gradio callback for the Upload tab.

    Args:
        img: The uploaded PIL image, or ``None`` when nothing was uploaded.

    Returns:
        Markdown with the prediction, or a user-facing error message. Errors
        are shown to the user rather than raised, and logged with traceback.
    """
    # NOTE(review): the leading emoji was restored from an encoding-mangled
    # source (it had been rendered as 'β'); confirm the intended icon.
    if img is None:
        return "❌ Please upload an image first!"
    try:
        result = model_manager.predict(img)
        return format_prediction_output(result)
    except Exception as e:
        # Log like predict_drawn_image does, so upload failures are not silent.
        logger.exception("Upload prediction error: %s", e)
        return f"❌ Error during prediction: {str(e)}"
def _to_pil_image(pixels):
    """Return ``pixels`` as a PIL image, casting array data to uint8 first."""
    if isinstance(pixels, Image.Image):
        return pixels
    array = np.asarray(pixels)
    if array.dtype != np.uint8:
        array = array.astype(np.uint8)
    # Alpha is intentionally kept: preprocessing decides how to flatten it.
    return Image.fromarray(array)


def predict_drawn_image(image_dict):
    """Gradio callback for the Draw tab.

    Args:
        image_dict: The Sketchpad value. Presumably a Gradio editor dict with
            ``background`` / ``layers`` / ``composite`` entries (TODO confirm
            against the installed Gradio version), or a bare numpy array.

    Returns:
        Markdown with the prediction, or a user-facing error message.
    """
    # NOTE(review): the leading emoji was restored from an encoding-mangled
    # source (it had been rendered as 'β'); confirm the intended icon.
    if image_dict is None:
        return "❌ Please draw a character first!"
    try:
        if isinstance(image_dict, dict):
            # Prefer the flattened composite of all layers; the background is
            # only a last-resort fallback.
            pixels = image_dict.get('composite')
            if pixels is None:
                pixels = image_dict.get('background')
            if pixels is None:
                return "❌ Could not extract image from canvas!"
        elif isinstance(image_dict, np.ndarray):
            pixels = image_dict
        else:
            return "❌ Unexpected image format!"

        result = model_manager.predict(_to_pil_image(pixels))
        return format_prediction_output(result)
    except Exception as e:
        logger.exception("Drawing prediction error: %s", e)
        return f"❌ Error during prediction: {str(e)}"
| # ----------------------------- | |
| # Gradio App | |
| # ----------------------------- | |
def create_app():
    """Build the Gradio Blocks interface (Upload, Draw and About tabs).

    Prediction is delegated to ``predict_uploaded_image`` and
    ``predict_drawn_image``; if the model has not been loaded they display
    an error message instead of a result.

    Returns:
        The configured ``gr.Blocks`` instance, not yet launched.
    """
    # NOTE(review): the emoji in the labels below were restored from an
    # encoding-mangled source (they had been rendered as 'π€', 'π', 'βοΈ',
    # 'ποΈ', 'βΉοΈ'); confirm they are the intended icons.

    # Custom CSS for better styling
    custom_css = """
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    .gradio-button {
        margin: 5px;
    }
    """

    # Generate the supported-character list from Config so the header can
    # never drift from the labels the model predicts (five entries per line).
    entries = [f"{char} ({label})" for label, char in Config.LABEL_TO_CHAR.items()]
    supported_md = "\n".join(
        "- " + ", ".join(entries[start:start + 5])
        for start in range(0, len(entries), 5)
    )
    # Built without leading indentation so Markdown never sees a code block.
    header_md = (
        "# 🔤 Khmer Character Recognition\n\n"
        f"This app recognizes {Config.NUM_CLASSES} Khmer consonants using a neural network model.\n\n"
        "**Supported Characters:**\n\n"
        f"{supported_md}\n"
    )

    with gr.Blocks(css=custom_css, title="Khmer Character Recognition") as demo:
        gr.Markdown(header_md)

        with gr.Tab("📤 Upload Image"):
            gr.Markdown("Upload an image of a Khmer character for recognition.")
            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(
                        type="pil",
                        label="Upload Image",
                        height=300,
                    )
                    img_btn = gr.Button("🔍 Predict", variant="primary", size="lg")
                with gr.Column():
                    img_output = gr.Markdown(
                        label="Prediction Result",
                        value="Upload an image and click Predict to see results here.",
                    )
            img_btn.click(
                fn=predict_uploaded_image,
                inputs=img_input,
                outputs=img_output,
            )

        with gr.Tab("✏️ Draw Character"):
            gr.Markdown(
                """
                Draw a Khmer character on the canvas below.

                **Tips:**
                - Use a thick brush stroke
                - Draw the character as clearly as possible
                - Try to center the character
                """
            )
            with gr.Row():
                with gr.Column():
                    canvas_input = gr.Sketchpad(
                        label="Draw Here",
                        height=400,
                        width=400,
                        brush=gr.Brush(colors=["#000000"], color_mode="fixed"),
                    )
                    with gr.Row():
                        draw_btn = gr.Button("🔍 Predict", variant="primary", size="lg")
                        clear_btn = gr.Button("🗑️ Clear", size="lg")
                with gr.Column():
                    draw_output = gr.Markdown(
                        label="Prediction Result",
                        value="Draw a character and click Predict to see results here.",
                    )
            draw_btn.click(
                fn=predict_drawn_image,
                inputs=canvas_input,
                outputs=draw_output,
            )
            # Returning None resets the Sketchpad to an empty canvas.
            clear_btn.click(
                fn=lambda: None,
                inputs=None,
                outputs=canvas_input,
            )

        with gr.Tab("ℹ️ About"):
            gr.Markdown(
                """
                ## About This App

                This application uses a neural network trained to recognize 10 Khmer consonants.

                ### Model Architecture
                - Input: 48x48 grayscale images
                - 4-layer fully connected neural network
                - Trained on handwritten Khmer characters

                ### Image Processing
                - Images are converted to grayscale
                - Resized to 48x48 pixels
                - Processed using custom preprocessing pipeline
                - Flattened to 2304-dimensional vectors

                ### How to Use
                1. **Upload Image Tab**: Upload a photo or screenshot of a Khmer character
                2. **Draw Character Tab**: Draw a character directly on the canvas
                3. Click "Predict" to see the results

                ### Tips for Best Results
                - Use clear, well-formed characters
                - Ensure good contrast (dark character on light background)
                - Center the character in the image
                - Avoid cluttered backgrounds

                ### Technical Details
                - Framework: PyTorch
                - Interface: Gradio
                - Image Processing: Custom pipeline with tensor reshaping
                - Inference: CPU-only (no GPU required)
                """
            )

    return demo
| # ----------------------------- | |
| # Main Execution | |
| # ----------------------------- | |
if __name__ == "__main__":
    # Load the model before building the UI so a missing or corrupt weights
    # file stops startup instead of surfacing on the first prediction.
    try:
        logger.info("Loading model...")
        model_manager.load_model()
        logger.info("Model loaded successfully!")

        # Create and launch the Gradio interface.
        demo = create_app()
        demo.launch(
            # SPACE_ID is presumably set only when running on Hugging Face
            # Spaces, where the server must listen on all interfaces.
            server_name="0.0.0.0" if "SPACE_ID" in os.environ else "127.0.0.1",
            share=False,
        )
    except Exception as e:
        logger.error(f"Failed to start application: {e}")
        print(f"Error: {e}")
        print("Please ensure:")
        # Point at the path the code actually reads (Config.MODEL_PATH); the
        # old hint referred to a model/ directory that is never used.
        print(f"1. The model file '{Config.MODEL_PATH}' exists in the same directory as this script")
        print("2. All required packages are installed")
        print("3. You have proper file permissions")
        # Exit non-zero so the hosting process sees the startup failure.
        raise SystemExit(1) from e