# mewton's picture
# Update app.py
# 9744ded verified
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
# Re-define the model with exactly the layer layout it was trained with, so the
# saved state_dict keys (conv_layers.<i>.*, fc_layers.<i>.*) line up on load.
class SimpleCNN(nn.Module):
    """Three-stage convolutional classifier for square 3-channel images.

    Each stage is a 3x3 convolution (stride 1, padding 1, so height/width are
    unchanged), a ReLU, and a 2x2 max-pool that halves height and width. The
    flattened 128-channel feature map feeds a 128-unit hidden layer and then a
    linear layer producing ``num_classes`` logits.

    Args:
        num_classes: Number of output logits. Defaults to 6, matching the
            six-entry class mapping used by this app.
        input_size: Side length in pixels of the square inputs. Only used to
            size the first linear layer; the default of 128 reproduces the
            original ``128 * 16 * 16`` input features.

    Raises:
        ValueError: If ``input_size`` is smaller than 8 (the three pools would
            leave no spatial extent).
    """

    def __init__(self, num_classes: int = 6, input_size: int = 128):
        super().__init__()
        if input_size < 8:
            raise ValueError(f"input_size must be at least 8, got {input_size}")
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Three stride-2 pools floor-divide each side by 2 three times, which
        # equals floor(input_size / 8); the convolutions keep the size.
        feature_side = input_size // 8
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * feature_side * feature_side, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),  # one logit per class
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` is expected to be a ``(batch, 3, H, W)`` tensor with
        ``H == W == input_size``; otherwise the flattened features may not
        match the first linear layer.
        """
        x = self.conv_layers(x)
        return self.fc_layers(x)
# Instantiate the network and load the trained weights from disk.
model = SimpleCNN()

# Best-effort load: if the checkpoint is missing or unreadable the app still
# starts (with untrained weights), but the failure is reported.
try:
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    _load_result = model.load_state_dict(
        torch.load("model_deri.pth", map_location=torch.device("cpu")),
        strict=False,
    )
except Exception as e:
    print(f"❌ Error loading model: {e}")
else:
    # strict=False tolerates key mismatches without raising, so surface them
    # instead of claiming success while some layers keep random weights.
    if _load_result.missing_keys or _load_result.unexpected_keys:
        print(
            "⚠️ Model loaded with mismatched keys - "
            f"missing: {_load_result.missing_keys}, "
            f"unexpected: {_load_result.unexpected_keys}"
        )
    else:
        print("✅ Model berhasil dimuat!")

# Always switch to inference mode, even when loading failed.
model.eval()
# Lookup from the model's output index to its class label.
class_mapping = dict(
    enumerate(["Bu dian", "Deri", "Putra", "Unknown", "Uqi", "Uwa"])
)
# Image preprocessing: resize to 128x128 (the size the network's first linear
# layer is built for), convert to a tensor, then normalize each of the three
# channels with mean 0.5 / std 0.5, i.e. x -> (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
def predict(image):
    """Classify one image and return a human-readable prediction string.

    Args:
        image: A PIL image (as delivered by ``gr.Image(type="pil")``), or
            ``None`` when the user submitted without providing an image.

    Returns:
        ``"Predicted: <label> (Confidence: <pct>%)"`` where ``<pct>`` is the
        softmax probability of the top class, or a prompt asking for an image
        when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image."

    # The first Conv2d expects exactly 3 channels; RGBA, grayscale ("L") or
    # palette ("P") images would otherwise yield 4 or 1 channels after ToTensor.
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        output = model(batch)
    probabilities = torch.nn.functional.softmax(output, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()
    confidence = probabilities[0, predicted_class].item() * 100  # as a percentage
    return f"Predicted: {class_mapping[predicted_class]} (Confidence: {confidence:.2f}%)"
# Gradio UI: a single PIL image input wired to `predict`, with a text output.
_image_input = gr.Image(type="pil")
iface = gr.Interface(
    fn=predict,
    inputs=_image_input,
    outputs="text",
)
# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    try:
        iface.launch()
    except Exception as e:
        print(f"❌ Gradio error: {e}")
        # Exit non-zero so whatever started the process can tell the app
        # failed to launch, instead of seeing a clean exit.
        raise SystemExit(1) from e