"""
Khmer Character Recognition App
Recognizes 10 Khmer characters using a neural network model
"""

import logging
import os
from pathlib import Path
from typing import Optional

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# -----------------------------
# Model Definition
# -----------------------------
class KhmerModel(nn.Module):
    """Neural network for Khmer character classification.

    A four-layer fully connected network: 48*48 inputs -> 392 -> 196 -> 98
    -> ``num_classes`` raw logits, with ReLU after every layer but the last.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(48 * 48, 392)
        self.fc2 = nn.Linear(392, 196)
        self.fc3 = nn.Linear(196, 98)
        self.fc4 = nn.Linear(98, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return unnormalised class scores for a batch of flattened images."""
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        # No activation on the last layer: callers apply softmax themselves.
        x = self.fc4(x)
        return x


# -----------------------------
# Configuration
# -----------------------------
class Config:
    """Application configuration"""

    # Model settings
    IMAGE_SIZE = (48, 48)
    NUM_CLASSES = 10
    MODEL_PATH = "khmer_model_weights.pth"

    # Label mappings: romanised label -> class index produced by the model
    LABEL_TO_IDX = {
        'CHA': 0,
        'CHHA': 1,
        'CHHO': 2,
        'DA': 3,
        'KHA': 4,
        'KHO': 5,
        'KO': 6,
        'NA': 7,
        'NGO': 8,
        'TA': 9,
    }

    # Romanised label -> Khmer glyph shown to the user
    LABEL_TO_CHAR = {
        'TA': 'ត',
        'NGO': 'ង',
        'CHA': 'ច',
        'DA': 'ដ',
        'KO': 'ក',
        'NA': 'ណ',
        'KHA': 'ខ',
        'CHHA': 'ឆ',
        'CHHO': 'ឈ',
        'KHO': 'ឃ',
    }

    @classmethod
    def get_idx_to_label(cls):
        """Return the inverse of ``LABEL_TO_IDX`` (class index -> label)."""
        return {v: k for k, v in cls.LABEL_TO_IDX.items()}


# -----------------------------
# Image helpers
# -----------------------------
def _flatten_alpha(img: Image.Image) -> Image.Image:
    """Composite an image with transparency onto an opaque white background.

    ``Image.convert("L")`` ignores the alpha channel, so a black stroke on a
    fully transparent canvas would otherwise become a uniformly black image.
    Images without transparency are returned unchanged.
    """
    has_alpha = img.mode in ("RGBA", "LA") or (
        img.mode == "P" and "transparency" in img.info
    )
    if not has_alpha:
        return img

    rgba = img.convert("RGBA")
    white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(white, rgba).convert("RGB")


def _canvas_value_to_image(value) -> Optional[Image.Image]:
    """Convert one canvas payload (PIL image or numpy array) to a PIL image.

    Returns ``None`` when the value is missing or of an unsupported type.
    Any alpha channel is kept here; it is flattened during preprocessing.
    """
    if isinstance(value, Image.Image):
        return value
    if isinstance(value, np.ndarray):
        return Image.fromarray(value.astype('uint8'))
    return None


# -----------------------------
# Model Manager
# -----------------------------
class ModelManager:
    """Handles model loading and inference"""

    def __init__(self):
        self.device = torch.device("cpu")  # Force CPU usage
        self.model = None
        self.config = Config()
        self.idx_to_label = self.config.get_idx_to_label()

    def _resolve_model_path(self) -> Path:
        """Locate the weights file.

        The configured path is used as-is when it is absolute or exists
        relative to the current working directory. Otherwise the directory
        containing this script is tried, which is where the error message
        tells the user to put the file.
        """
        configured = Path(self.config.MODEL_PATH)
        if configured.is_absolute() or configured.exists():
            return configured

        # __file__ is absent in some interactive environments.
        script_file = globals().get("__file__")
        if script_file:
            beside_script = Path(script_file).resolve().parent / configured
            if beside_script.exists():
                return beside_script
        return configured

    def load_model(self):
        """Load the trained model

        Raises:
            FileNotFoundError: if the weights file cannot be found.
            Exception: any error raised while reading or applying the
                weights is logged and re-raised.
        """
        try:
            model_path = self._resolve_model_path()

            if not model_path.exists():
                raise FileNotFoundError(
                    f"Model file not found: {model_path}\n"
                    f"Please ensure '{self.config.MODEL_PATH}' "
                    f"is in the same directory as this script."
                )

            self.model = KhmerModel(num_classes=self.config.NUM_CLASSES)
            # weights_only=True restricts unpickling to tensors/primitives.
            self.model.load_state_dict(
                torch.load(model_path, map_location=self.device, weights_only=True)
            )
            self.model.eval()
            self.model.to(self.device)

            logger.info("Model loaded successfully from %s", model_path)

        except Exception as e:
            logger.error("Error loading model: %s", e)
            raise

    def preprocess_image(self, img: Image.Image) -> torch.Tensor:
        """Preprocess image for model input

        Transparency is flattened onto white, then the image is converted to
        grayscale, resized to ``Config.IMAGE_SIZE`` and flattened into a
        ``(1, width * height)`` float32 tensor on the manager's device.
        """
        # Pixel values are left in the 0-255 range, exactly as before.
        # NOTE(review): presumably this matches the training pipeline -
        # confirm before adding any normalisation.
        gray = _flatten_alpha(img).convert("L").resize(self.config.IMAGE_SIZE)
        pixels = np.array(gray, dtype=np.float32)

        # One row per image: [batch=1, H * W]
        flat = torch.from_numpy(pixels).reshape(1, -1)
        return flat.to(self.device)

    def predict(self, img: Image.Image) -> dict:
        """Make prediction on image

        Returns:
            A dict with ``predicted_char``, ``predicted_label``,
            ``confidence`` (softmax probability of the top class) and
            ``top3`` (list of dicts with ``char``, ``label``, ``confidence``).

        Raises:
            RuntimeError: if ``load_model()`` has not been called.
            Exception: any preprocessing/inference error is logged and
                re-raised.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        try:
            tensor = self.preprocess_image(img)

            # Predict
            with torch.no_grad():
                output = self.model(tensor)
                probs = F.softmax(output, dim=1)
                pred_idx = torch.argmax(probs, dim=1).item()
                confidence = probs[0, pred_idx].item()

            # Get labels
            pred_label = self.idx_to_label[pred_idx]
            pred_char = self.config.LABEL_TO_CHAR[pred_label]

            # Get top 3 predictions (fewer if the model has < 3 classes)
            top3_probs, top3_indices = torch.topk(
                probs[0], k=min(3, self.config.NUM_CLASSES)
            )
            top3_labels = [self.idx_to_label[idx.item()] for idx in top3_indices]
            top3_predictions = [
                {
                    'char': self.config.LABEL_TO_CHAR[label],
                    'label': label,
                    'confidence': prob.item(),
                }
                for prob, label in zip(top3_probs, top3_labels)
            ]

            return {
                'predicted_char': pred_char,
                'predicted_label': pred_label,
                'confidence': confidence,
                'top3': top3_predictions,
            }

        except Exception as e:
            logger.error("Prediction error: %s", e)
            raise


# -----------------------------
# Gradio Interface Functions
# -----------------------------
model_manager = ModelManager()


def format_prediction_output(result: dict) -> str:
    """Format prediction results for display

    ``result`` is the dict returned by ``ModelManager.predict``; the return
    value is a Markdown string.
    """
    header = (
        f"## Predicted Character: {result['predicted_char']}\n\n"
        f"**Romanization:** {result['predicted_label']}\n\n"
        f"**Confidence:** {result['confidence']*100:.2f}%\n\n"
        "### Top 3 Predictions:\n"
    )
    ranked = "".join(
        f"{i}. {pred['char']} ({pred['label']}) - {pred['confidence']*100:.2f}%\n"
        for i, pred in enumerate(result['top3'], 1)
    )
    return header + ranked


def predict_uploaded_image(img):
    """Handle uploaded image prediction

    Returns Markdown with the prediction, or a user-facing error message.
    """
    if img is None:
        return "❌ Please upload an image first!"

    try:
        result = model_manager.predict(img)
        return format_prediction_output(result)
    except Exception as e:
        # UI boundary: show the failure to the user instead of crashing.
        return f"❌ Error during prediction: {str(e)}"


def predict_drawn_image(image_dict):
    """Handle drawn image prediction

    Accepts the Sketchpad value: either a dict carrying 'composite' /
    'background' entries or a bare image (numpy array or PIL image).
    Returns Markdown with the prediction, or a user-facing error message.
    """
    if image_dict is None:
        return "❌ Please draw a character first!"

    try:
        if isinstance(image_dict, dict):
            # Prefer the flattened drawing; fall back to the background.
            img = _canvas_value_to_image(image_dict.get('composite'))
            if img is None:
                img = _canvas_value_to_image(image_dict.get('background'))
            if img is None:
                return "❌ Could not extract image from canvas!"
        elif isinstance(image_dict, (np.ndarray, Image.Image)):
            img = _canvas_value_to_image(image_dict)
        else:
            return "❌ Unexpected image format!"

        # The alpha channel (if any) is deliberately kept: preprocessing
        # composites it onto white so dark strokes on a transparent canvas
        # stay visible after grayscale conversion.
        result = model_manager.predict(img)
        return format_prediction_output(result)

    except Exception as e:
        logger.error("Drawing prediction error: %s", e)
        return f"❌ Error during prediction: {str(e)}"


# -----------------------------
# Gradio App
# -----------------------------
def create_app():
    """Create and configure Gradio interface

    Returns the ``gr.Blocks`` app with Upload, Draw and About tabs.
    """

    # Custom CSS for better styling
    custom_css = """
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    .gradio-button {
        margin: 5px;
    }
    """

    with gr.Blocks(css=custom_css, title="Khmer Character Recognition") as demo:
        gr.Markdown(
            """
            # 🔤 Khmer Character Recognition

            This app recognizes 10 Khmer consonants using a neural network model.

            **Supported Characters:**
            - ត (TA), ង (NGO), ច (CHA), ដ (DA), ក (KO)
            - ណ (NA), ខ (KHA), ឆ (CHHA), ឈ (CHHO), ឃ (KHO)
            """
        )

        with gr.Tab("📤 Upload Image"):
            gr.Markdown("Upload an image of a Khmer character for recognition.")

            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(
                        type="pil",
                        label="Upload Image",
                        height=300
                    )
                    img_btn = gr.Button("🔍 Predict", variant="primary", size="lg")

                with gr.Column():
                    img_output = gr.Markdown(
                        label="Prediction Result",
                        value="Upload an image and click Predict to see results here."
                    )

            img_btn.click(
                fn=predict_uploaded_image,
                inputs=img_input,
                outputs=img_output
            )

        with gr.Tab("✏️ Draw Character"):
            gr.Markdown(
                """
                Draw a Khmer character on the canvas below.

                **Tips:**
                - Use a thick brush stroke
                - Draw the character as clearly as possible
                - Try to center the character
                """
            )

            with gr.Row():
                with gr.Column():
                    canvas_input = gr.Sketchpad(
                        label="Draw Here",
                        height=400,
                        width=400,
                        brush=gr.Brush(colors=["#000000"], color_mode="fixed")
                    )

                    with gr.Row():
                        draw_btn = gr.Button("🔍 Predict", variant="primary", size="lg")
                        clear_btn = gr.Button("🗑️ Clear", size="lg")

                with gr.Column():
                    draw_output = gr.Markdown(
                        label="Prediction Result",
                        value="Draw a character and click Predict to see results here."
                    )

            draw_btn.click(
                fn=predict_drawn_image,
                inputs=canvas_input,
                outputs=draw_output
            )

            # Returning None resets the canvas component.
            clear_btn.click(
                fn=lambda: None,
                inputs=None,
                outputs=canvas_input
            )

        with gr.Tab("ℹ️ About"):
            gr.Markdown(
                """
                ## About This App

                This application uses a neural network trained to recognize 10 Khmer consonants.

                ### Model Architecture
                - Input: 48x48 grayscale images
                - 4-layer fully connected neural network
                - Trained on handwritten Khmer characters

                ### Image Processing
                - Images are converted to grayscale
                - Resized to 48x48 pixels
                - Processed using custom preprocessing pipeline
                - Flattened to 2304-dimensional vectors

                ### How to Use
                1. **Upload Image Tab**: Upload a photo or screenshot of a Khmer character
                2. **Draw Character Tab**: Draw a character directly on the canvas
                3. Click "Predict" to see the results

                ### Tips for Best Results
                - Use clear, well-formed characters
                - Ensure good contrast (dark character on light background)
                - Center the character in the image
                - Avoid cluttered backgrounds

                ### Technical Details
                - Framework: PyTorch
                - Interface: Gradio
                - Image Processing: Custom pipeline with tensor reshaping
                - Inference: CPU-only (no GPU required)
                """
            )

    return demo


# -----------------------------
# Main Execution
# -----------------------------
def main():
    """Load the model, then build and launch the Gradio interface.

    Exits with status 1 (after printing troubleshooting hints) if the model
    cannot be loaded or the app fails to start.
    """
    try:
        # Load model at startup
        logger.info("Loading model...")
        model_manager.load_model()
        logger.info("Model loaded successfully!")

        # Create and launch the Gradio interface
        demo = create_app()
        demo.launch(
            # Bind to all interfaces only when SPACE_ID is set in the
            # environment; otherwise stay on loopback.
            server_name="0.0.0.0" if "SPACE_ID" in os.environ else "127.0.0.1",
            share=False
        )

    except Exception as e:
        logger.error("Failed to start application: %s", e)
        print(f"Error: {e}")
        print("Please ensure:")
        print(
            "1. The model file 'khmer_model_weights.pth' exists in the same "
            "directory as this script"
        )
        print("2. All required packages are installed")
        print("3. You have proper file permissions")
        # Signal the failure to the calling process instead of exiting 0.
        raise SystemExit(1)


if __name__ == "__main__":
    main()