# sedtha's picture
# Update app.py
# 8a064f7 verified
"""
Khmer Character Recognition App
Recognizes 10 Khmer characters using a neural network model
"""
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
from pathlib import Path
import logging
import os
# Setup logging: configure the root logger to emit INFO-and-above records and
# create the module-level logger used by the rest of this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# -----------------------------
# Model Definition
# -----------------------------
class KhmerModel(nn.Module):
    """Fully connected classifier for flattened 48x48 Khmer character images.

    Layer widths are 2304 -> 392 -> 196 -> 98 -> ``num_classes``; a ReLU
    follows every layer except the last one.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # The attribute names are also the state-dict keys of saved weights,
        # so they have to stay fc1..fc4.
        self.fc1 = nn.Linear(48 * 48, 392)
        self.fc2 = nn.Linear(392, 196)
        self.fc3 = nn.Linear(196, 98)
        self.fc4 = nn.Linear(98, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return raw (un-softmaxed) class scores for ``x``.

        ``x`` must have 2304 features in its last dimension, as required by
        ``fc1``.
        """
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.relu(layer(hidden))
        # Output layer: no activation, callers apply softmax themselves.
        return self.fc4(hidden)
# -----------------------------
# Configuration
# -----------------------------
class Config:
    """Application configuration: model settings and label lookup tables."""

    # Model settings
    IMAGE_SIZE = (48, 48)  # input size; 48 * 48 matches KhmerModel.fc1's 2304 inputs
    NUM_CLASSES = 10
    MODEL_PATH = "khmer_model_weights.pth"  # relative path to the saved state dict

    # Label mappings: romanized label -> class index produced by the model.
    LABEL_TO_IDX = {
        'CHA': 0,
        'CHHA': 1,
        'CHHO': 2,
        'DA': 3,
        'KHA': 4,
        'KHO': 5,
        'KO': 6,
        'NA': 7,
        'NGO': 8,
        'TA': 9,
    }

    # Romanized label -> Khmer consonant shown to the user.
    # Written as \u escapes so the glyphs survive a wrong-codec round trip
    # (the literal characters had been corrupted into multi-character
    # mojibake); the glyph each escape stands for is shown in the comment.
    LABEL_TO_CHAR = {
        'TA': '\u178f',    # ត
        'NGO': '\u1784',   # ង
        'CHA': '\u1785',   # ច
        'DA': '\u178a',    # ដ
        'KO': '\u1780',    # ក
        'NA': '\u178e',    # ណ
        'KHA': '\u1781',   # ខ
        'CHHA': '\u1786',  # ឆ
        'CHHO': '\u1788',  # ឈ
        'KHO': '\u1783',   # ឃ
    }

    @classmethod
    def get_idx_to_label(cls):
        """Return the inverse of ``LABEL_TO_IDX`` (class index -> label)."""
        return {v: k for k, v in cls.LABEL_TO_IDX.items()}
# -----------------------------
# Model Manager
# -----------------------------
class ModelManager:
    """Handles model loading and inference.

    Attributes:
        device: Torch device for weights and inputs (always CPU).
        model: The loaded ``KhmerModel``; ``None`` until ``load_model()``
            has completed successfully.
        config: ``Config`` instance providing paths, sizes and label tables.
        idx_to_label: Class index -> romanized label lookup.
    """

    def __init__(self):
        self.device = torch.device("cpu")  # Force CPU usage
        self.model = None
        self.config = Config()
        self.idx_to_label = self.config.get_idx_to_label()

    def load_model(self):
        """Load the trained weights from ``Config.MODEL_PATH``.

        Raises:
            FileNotFoundError: If the weights file does not exist.
            Exception: Any other failure while building the model or reading
                the state dict; it is logged and re-raised unchanged.
        """
        try:
            model_path = Path(self.config.MODEL_PATH)
            if not model_path.exists():
                raise FileNotFoundError(
                    f"Model file not found: {model_path}\n"
                    f"Please ensure '{self.config.MODEL_PATH}' is in the same directory as this script."
                )
            # Build into a local first: self.model is only published once the
            # weights are in place, so a failed load can never leave a randomly
            # initialised network behind for predict() to use.
            model = KhmerModel(num_classes=self.config.NUM_CLASSES)
            model.load_state_dict(
                torch.load(model_path, map_location=self.device, weights_only=True)
            )
            model.eval()
            model.to(self.device)
            self.model = model
            logger.info(f"Model loaded successfully from {model_path}")
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            raise

    def preprocess_image(self, img: Image.Image) -> torch.Tensor:
        """Convert a PIL image into a ``(1, H*W)`` float32 tensor.

        Steps: flatten any transparency onto white, convert to grayscale,
        resize to ``Config.IMAGE_SIZE`` and flatten. Pixel values stay in the
        0-255 range; no scaling is applied here.
        NOTE(review): presumably this matches how the training data was
        prepared - confirm before adding normalisation.
        """
        # convert("L") simply drops the alpha channel, so fully transparent
        # pixels (typically stored as black) would become black and dark
        # strokes on a transparent canvas would disappear. Composite onto a
        # white background first; opaque images are left untouched.
        if img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info:
            rgba = img.convert("RGBA")
            white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            img = Image.alpha_composite(white, rgba)

        # Use the configured size instead of repeating the 48x48 literal.
        gray = img.convert("L").resize(self.config.IMAGE_SIZE)
        pixels = np.array(gray, dtype=np.float32)
        # One row per image: [batch=1, H*W]
        img_tensor = torch.tensor(pixels, dtype=torch.float32).reshape(1, -1)
        return img_tensor.to(self.device)

    def predict(self, img: Image.Image) -> dict:
        """Classify ``img`` and return the best match plus the top candidates.

        Returns:
            A dict with keys ``predicted_char``, ``predicted_label``,
            ``confidence`` (softmax probability of the best class) and
            ``top3`` (list of ``{'char', 'label', 'confidence'}`` dicts in
            descending probability order).

        Raises:
            RuntimeError: If ``load_model()`` has not been called successfully.
            Exception: Any preprocessing/inference error is logged and
                re-raised unchanged.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        try:
            tensor = self.preprocess_image(img)

            # Inference only: no gradients needed.
            with torch.no_grad():
                output = self.model(tensor)
                probs = F.softmax(output, dim=1)
                pred_idx = torch.argmax(probs, dim=1).item()
                confidence = probs[0, pred_idx].item()

            # Get labels
            pred_label = self.idx_to_label[pred_idx]
            pred_char = self.config.LABEL_TO_CHAR[pred_label]

            # Get top 3 predictions (fewer if the model has fewer classes)
            top3_probs, top3_indices = torch.topk(
                probs[0], k=min(3, self.config.NUM_CLASSES)
            )
            top3_predictions = []
            for prob, idx in zip(top3_probs, top3_indices):
                label = self.idx_to_label[idx.item()]
                top3_predictions.append({
                    'char': self.config.LABEL_TO_CHAR[label],
                    'label': label,
                    'confidence': prob.item()
                })

            return {
                'predicted_char': pred_char,
                'predicted_label': pred_label,
                'confidence': confidence,
                'top3': top3_predictions
            }
        except Exception as e:
            logger.error(f"Prediction error: {e}")
            raise
# -----------------------------
# Gradio Interface Functions
# -----------------------------
# Shared module-level manager used by the prediction callbacks below. Its
# weights are loaded by load_model() in the __main__ block, so
# model_manager.model is None until that call has run.
model_manager = ModelManager()
def format_prediction_output(result: dict) -> str:
    """Render a prediction dict as the Markdown shown in the result pane.

    ``result`` must provide ``predicted_char``, ``predicted_label``,
    ``confidence`` and ``top3`` (an iterable of dicts with ``char``,
    ``label`` and ``confidence``). Confidences are printed as percentages
    with two decimals.
    """
    header = [
        f"## Predicted Character: {result['predicted_char']}\n\n",
        f"**Romanization:** {result['predicted_label']}\n\n",
        f"**Confidence:** {result['confidence']*100:.2f}%\n\n",
        "### Top 3 Predictions:\n",
    ]
    # One numbered line per candidate, ranked from 1.
    ranking = [
        f"{rank}. {entry['char']} ({entry['label']}) - {entry['confidence']*100:.2f}%\n"
        for rank, entry in enumerate(result['top3'], 1)
    ]
    return "".join(header + ranking)
def predict_uploaded_image(img):
    """Callback for the upload tab: classify ``img`` and return Markdown.

    Returns a prompt message when nothing was uploaded, and an error message
    (instead of raising) when prediction or formatting fails.
    """
    # Nothing uploaded yet.
    if img is None:
        return "❌ Please upload an image first!"

    try:
        return format_prediction_output(model_manager.predict(img))
    except Exception as err:
        return f"❌ Error during prediction: {err}"
def predict_drawn_image(image_dict):
    """Callback for the draw tab: classify the canvas content, return Markdown.

    Accepts either the editor's dict value (the ``composite`` entry is
    preferred, ``background`` is the fallback) or a bare numpy array. Every
    failure is reported as a message string rather than raised.
    """
    if image_dict is None:
        return "❌ Please draw a character first!"

    try:
        if isinstance(image_dict, dict):
            # Prefer the flattened composite; fall back to the background.
            pixels = image_dict.get('composite')
            if pixels is None:
                pixels = image_dict.get('background')
            if pixels is None:
                return "❌ Could not extract image from canvas!"
            img = Image.fromarray(pixels)
        elif isinstance(image_dict, np.ndarray):
            pixels = image_dict
            # A 4-channel array is reduced to its first three channels.
            if pixels.ndim == 3 and pixels.shape[-1] == 4:
                pixels = pixels[:, :, :3]
            img = Image.fromarray(pixels.astype('uint8'))
        else:
            return "❌ Unexpected image format!"

        return format_prediction_output(model_manager.predict(img))
    except Exception as err:
        logger.error(f"Drawing prediction error: {err}")
        return f"❌ Error during prediction: {err}"
# -----------------------------
# Gradio App
# -----------------------------
def create_app():
    """Create and configure the Gradio interface.

    Builds a ``gr.Blocks`` app with three tabs (upload, draw, about), wires
    the predict buttons to ``predict_uploaded_image`` /
    ``predict_drawn_image`` and returns the un-launched ``Blocks`` object.
    """
    # NOTE: the multi-line strings below are kept flush-left on purpose so
    # that no leading whitespace becomes part of the CSS/Markdown text
    # (Markdown treats lines indented by four or more spaces as code).

    # Custom CSS for better styling
    custom_css = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.gradio-button {
margin: 5px;
}
"""
    with gr.Blocks(css=custom_css, title="Khmer Character Recognition") as demo:
        # Emoji and Khmer glyphs in the UI text were restored from
        # mis-decoded UTF-8 (they had been corrupted into mojibake).
        gr.Markdown(
            """
# 🔤 Khmer Character Recognition
This app recognizes 10 Khmer consonants using a neural network model.
**Supported Characters:**
- ត (TA), ង (NGO), ច (CHA), ដ (DA), ក (KO)
- ណ (NA), ខ (KHA), ឆ (CHHA), ឈ (CHHO), ឃ (KHO)
"""
        )

        with gr.Tab("📤 Upload Image"):
            gr.Markdown("Upload an image of a Khmer character for recognition.")
            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(
                        type="pil",
                        label="Upload Image",
                        height=300
                    )
                    img_btn = gr.Button("🔍 Predict", variant="primary", size="lg")
                with gr.Column():
                    img_output = gr.Markdown(
                        label="Prediction Result",
                        value="Upload an image and click Predict to see results here."
                    )
            img_btn.click(
                fn=predict_uploaded_image,
                inputs=img_input,
                outputs=img_output
            )

        with gr.Tab("✏️ Draw Character"):
            gr.Markdown(
                """
Draw a Khmer character on the canvas below.
**Tips:**
- Use a thick brush stroke
- Draw the character as clearly as possible
- Try to center the character
"""
            )
            with gr.Row():
                with gr.Column():
                    # Single fixed black brush, matching the "dark character
                    # on light background" advice in the About tab.
                    canvas_input = gr.Sketchpad(
                        label="Draw Here",
                        height=400,
                        width=400,
                        brush=gr.Brush(colors=["#000000"], color_mode="fixed")
                    )
                    with gr.Row():
                        draw_btn = gr.Button("🔍 Predict", variant="primary", size="lg")
                        clear_btn = gr.Button("🗑️ Clear", size="lg")
                with gr.Column():
                    draw_output = gr.Markdown(
                        label="Prediction Result",
                        value="Draw a character and click Predict to see results here."
                    )
            draw_btn.click(
                fn=predict_drawn_image,
                inputs=canvas_input,
                outputs=draw_output
            )
            # Writing None back into the sketchpad resets the canvas.
            clear_btn.click(
                fn=lambda: None,
                inputs=None,
                outputs=canvas_input
            )

        with gr.Tab("ℹ️ About"):
            gr.Markdown(
                """
## About This App
This application uses a neural network trained to recognize 10 Khmer consonants.
### Model Architecture
- Input: 48x48 grayscale images
- 4-layer fully connected neural network
- Trained on handwritten Khmer characters
### Image Processing
- Images are converted to grayscale
- Resized to 48x48 pixels
- Processed using custom preprocessing pipeline
- Flattened to 2304-dimensional vectors
### How to Use
1. **Upload Image Tab**: Upload a photo or screenshot of a Khmer character
2. **Draw Character Tab**: Draw a character directly on the canvas
3. Click "Predict" to see the results
### Tips for Best Results
- Use clear, well-formed characters
- Ensure good contrast (dark character on light background)
- Center the character in the image
- Avoid cluttered backgrounds
### Technical Details
- Framework: PyTorch
- Interface: Gradio
- Image Processing: Custom pipeline with tensor reshaping
- Inference: CPU-only (no GPU required)
"""
            )
    return demo
# -----------------------------
# Main Execution
# -----------------------------
if __name__ == "__main__":
    # Load the model first, then build and serve the UI. Any failure along
    # the way is reported with troubleshooting hints.
    try:
        logger.info("Loading model...")
        model_manager.load_model()
        logger.info("Model loaded successfully!")
        # Create and launch the Gradio interface
        demo = create_app()
        demo.launch(
            # Listen on all interfaces only when SPACE_ID is present in the
            # environment (presumably set by the hosting platform - confirm);
            # otherwise stay on localhost.
            server_name="0.0.0.0" if "SPACE_ID" in os.environ else "127.0.0.1",
            share=False
        )
    except Exception as e:
        logger.error(f"Failed to start application: {e}")
        print(f"Error: {e}")
        print("Please ensure:")
        # The hint must agree with Config.MODEL_PATH and load_model(), which
        # look for the weights next to the script, not in a model/ directory.
        print(
            f"1. The model file '{Config.MODEL_PATH}' exists in the same "
            "directory as this script"
        )
        print("2. All required packages are installed")
        print("3. You have proper file permissions")
        # Report the failed start through the exit status instead of
        # finishing with a success code.
        raise SystemExit(1) from e