# ma4389's picture
# Upload 3 files
# 2ac2f61 verified
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import gradio as gr
# ---- 1. Hiragana Classes ----
# Romanised label for each model output index: position i in this list is the
# name reported when the network's highest-scoring output is i.  The entries
# are in lexicographic order; replace them with the exact class names (and
# order) of your dataset if it differs.
classes = [
    # a-row vowels and c/f/h/i
    "aa", "chi", "ee", "fu",
    "ha", "he", "hi", "ho",
    "ii",
    # k / m
    "ka", "ke", "ki", "ko", "ku",
    "ma", "me", "mi", "mo", "mu",
    # n / o
    "na", "ne", "ni", "nn", "no", "nu",
    "oo",
    # r / s
    "ra", "re", "ri", "ro", "ru",
    "sa", "se", "shi", "so", "su",
    # t / u / w / y
    "ta", "te", "to", "tsu",
    "uu",
    "wa", "wo",
    "ya", "yo", "yu",
]
# ---- 2. Image Transform (inference) ----
# Deterministic preprocessing only: force three channels, resize to 224x224,
# convert to a tensor and normalise each channel with mean 0.5 / std 0.5.
# Random augmentations (rotation, colour jitter) are deliberately left out
# here: they are a training-time technique, and applying a random rotation at
# inference makes the same image produce different predictions between calls.
transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert('RGB')),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5]),
])
# ---- 3. Load Model ----
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ResNet-50 backbone created without pretrained weights; all parameters are
# filled in from the checkpoint loaded below.
model = models.resnet50(weights=None)

# Swap the stock classifier for a two-layer head with one output per entry in
# `classes`.  The layers stay positional (indices 0..3) so the parameter names
# match the ones stored in the checkpoint.
in_features = model.fc.in_features
classifier_head = nn.Sequential(
    nn.Linear(in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(512, len(classes)),
)
model.fc = classifier_head

# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
# map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
state_dict = torch.load("best_model.pth", map_location=device)
model.load_state_dict(state_dict)

model.to(device)
# Evaluation mode: dropout is disabled and batch-norm uses running statistics.
model.eval()
# ---- 4. Prediction Function ----
def predict(image):
    """Classify one image and return the predicted class name as text.

    Args:
        image: PIL image handed over by the Gradio image input, or ``None``
            when the request was submitted without an image.

    Returns:
        ``"Predicted: <class>"`` for the highest-scoring class, or a short
        prompt asking for an image when ``image`` is ``None``.
    """
    # Submitting the form without an image yields None; answer with a
    # readable message instead of failing inside the transform pipeline.
    if image is None:
        return "Please upload an image."

    # Preprocess, add the batch dimension the model expects, and move the
    # tensor to the same device as the model.
    batch = transform(image).unsqueeze(0).to(device)

    # Gradients are not needed for inference.
    with torch.no_grad():
        logits = model(batch)

    # Index of the largest score along the class dimension.
    best_index = int(logits.argmax(dim=1).item())
    return f"Predicted: {classes[best_index]}"
# ---- 5. Gradio UI ----
# One image input (delivered to `predict` as a PIL image) and one plain-text
# output showing the predicted syllable.
interface = gr.Interface(
    title="Japanese Hiragana Classifier",
    description=(
        "Upload an image of a handwritten Hiragana character "
        "and get its predicted syllable."
    ),
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()