"""Gradio app that classifies handwritten Hiragana characters with a fine-tuned ResNet-50.

Loads weights from ``best_model.pth`` (expected next to this script) and serves a
single-image prediction UI.
"""

import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import gradio as gr

# ---- 1. Hiragana Classes ----
# Must match (in order) the class names used when the model was trained.
classes = [
    "aa", "chi", "ee", "fu", "ha", "he", "hi", "ho", "ii", "ka",
    "ke", "ki", "ko", "ku", "ma", "me", "mi", "mo", "mu", "na",
    "ne", "ni", "nn", "no", "nu", "oo", "ra", "re", "ri", "ro",
    "ru", "sa", "se", "shi", "so", "su", "ta", "te", "to", "tsu",
    "uu", "wa", "wo", "ya", "yo", "yu",
]

# ---- 2. Image Transform (inference) ----
# BUG FIX: the original pipeline applied training-time augmentation
# (RandomRotation, ColorJitter) at inference, which made predictions
# nondeterministic and hurt accuracy. Inference must use only the
# deterministic preprocessing: convert to RGB, resize, tensorize, normalize
# with the same mean/std used during training.
transform = transforms.Compose([
    transforms.Lambda(lambda x: x.convert("RGB")),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# ---- 3. Load Model ----
device = "cuda" if torch.cuda.is_available() else "cpu"

# Architecture must mirror the training setup: ResNet-50 backbone with a
# custom two-layer classification head sized to len(classes).
model = models.resnet50(weights=None)
in_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(512, len(classes)),
)
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.to(device)
model.eval()  # disables dropout/batch-norm updates for inference


# ---- 4. Prediction Function ----
def predict(image):
    """Classify a PIL image of a handwritten Hiragana character.

    Args:
        image: PIL.Image supplied by Gradio (any mode; converted to RGB).

    Returns:
        A string of the form ``"Predicted: <syllable>"``.
    """
    img = transform(image).unsqueeze(0).to(device)  # add batch dim: (1, 3, 224, 224)
    with torch.no_grad():
        outputs = model(img)
        _, predicted = torch.max(outputs, 1)  # argmax over class logits
    return f"Predicted: {classes[predicted.item()]}"


# ---- 5. Gradio UI ----
interface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Japanese Hiragana Classifier",
    description="Upload an image of a handwritten Hiragana character and get its predicted syllable.",
)

if __name__ == "__main__":
    interface.launch()