ma4389 commited on
Commit
2ac2f61
·
verified ·
1 Parent(s): 754a731

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +62 -0
  2. best_model.pth +3 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms, models
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
+ # ---- 1. Hiragana Classes ----
8
+ # Replace with the exact class names from your dataset
9
+ classes = [
10
+ "aa", "chi", "ee", "fu", "ha", "he", "hi", "ho", "ii",
11
+ "ka", "ke", "ki", "ko", "ku", "ma", "me", "mi", "mo", "mu",
12
+ "na", "ne", "ni", "nn", "no", "nu", "oo",
13
+ "ra", "re", "ri", "ro", "ru", "sa", "se", "shi", "so", "su",
14
+ "ta", "te", "to", "tsu", "uu", "wa", "wo", "ya", "yo", "yu"
15
+ ]
16
+
17
+ # ---- 2. Image Transform (same as training) ----
18
+ transform = transforms.Compose([
19
+ transforms.Lambda(lambda x: x.convert('RGB')),
20
+ transforms.Resize((224, 224)),
21
+ transforms.ToTensor(),
22
+ transforms.RandomRotation(10),
23
+ transforms.ColorJitter(),
24
+ transforms.Normalize(mean=[0.5, 0.5, 0.5],
25
+ std=[0.5, 0.5, 0.5])
26
+ ])
27
+
28
+ # ---- 3. Load Model ----
29
+ device = "cuda" if torch.cuda.is_available() else "cpu"
30
+
31
+ model = models.resnet50(weights=None)
32
+ in_features = model.fc.in_features
33
+ model.fc = nn.Sequential(
34
+ nn.Linear(in_features, 512),
35
+ nn.ReLU(),
36
+ nn.Dropout(0.4),
37
+ nn.Linear(512, len(classes))
38
+ )
39
+
40
+ model.load_state_dict(torch.load("best_model.pth", map_location=device))
41
+ model.to(device)
42
+ model.eval()
43
+
44
+ # ---- 4. Prediction Function ----
45
+ def predict(image):
46
+ img = transform(image).unsqueeze(0).to(device)
47
+ with torch.no_grad():
48
+ outputs = model(img)
49
+ _, predicted = torch.max(outputs, 1)
50
+ return f"Predicted: {classes[predicted.item()]}"
51
+
52
+ # ---- 5. Gradio UI ----
53
+ interface = gr.Interface(
54
+ fn=predict,
55
+ inputs=gr.Image(type="pil"),
56
+ outputs="text",
57
+ title="Japanese Hiragana Classifier",
58
+ description="Upload an image of a handwritten Hiragana character and get its predicted syllable."
59
+ )
60
+
61
+ if __name__ == "__main__":
62
+ interface.launch()
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05f9552870c196b41cee15f2b7b175140f36b8540c0e6fc80c4c46cba2eeabe8
3
+ size 98642112
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ Pillow
4
+ gradio