"""
Gradio UI application for Batik Classification.

Optimized for Hugging Face Spaces deployment: the trained VGG16 weights are
downloaded from the Hugging Face Hub at startup and served through a Gradio
Blocks interface.
"""

import json
import os
import traceback
from typing import Dict, Optional, Tuple

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models, transforms

# Global variables, populated by load_model() once everything has loaded.
model = None
class_names = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
transform = None

# Hugging Face Hub location of the trained weights.
MODEL_REPO_ID = "RimsJ/Batik-Classifier"
MODEL_FILENAME = "vgg16_batik_best.pth"
# Key prefix that torch.compile() adds to parameter names; it must be removed
# before loading into an uncompiled model.
_COMPILE_PREFIX = "_orig_mod."
# Maximum number of predictions shown in the UI.
TOP_K = 5


def _build_model(num_classes: int) -> nn.Module:
    """Return an untrained VGG16 whose classifier matches the saved checkpoint.

    The checkpoint keeps only the first four layers of the stock VGG16 head
    (Linear -> ReLU -> Dropout -> Linear), with the final Linear resized to
    ``num_classes`` outputs.
    """
    net = models.vgg16(weights=None)
    net.classifier[3] = nn.Linear(4096, num_classes)
    net.classifier = nn.Sequential(*list(net.classifier.children())[:4])
    return net


def _strip_compile_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Return a copy of ``state_dict`` with a *leading* ``_orig_mod.`` removed from keys.

    Only the prefix is stripped; occurrences elsewhere in a key are preserved.
    """
    return {
        (key[len(_COMPILE_PREFIX):] if key.startswith(_COMPILE_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def _build_transform(image_size: int) -> transforms.Compose:
    """Return the preprocessing pipeline: square resize, to tensor, ImageNet mean/std normalization."""
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


def load_model():
    """Load the config, weights and preprocessing into the module globals.

    Reads ``model_config.json`` from the current working directory. It must
    contain ``num_classes`` and ``class_names``; ``image_size`` is optional
    and defaults to 224. The checkpoint is downloaded from the Hugging Face
    Hub and loaded onto ``device`` in eval mode.

    The globals ``model``, ``class_names`` and ``transform`` are assigned only
    after every step has succeeded, so a failure never leaves a
    half-initialized model behind.

    Raises:
        ValueError: if ``num_classes`` disagrees with ``len(class_names)``.
        Exception: any error from reading the config, downloading or loading
            the weights is printed and re-raised.
    """
    global model, class_names, transform

    try:
        # Load model configuration
        with open('model_config.json', 'r', encoding='utf-8') as f:
            config = json.load(f)

        num_classes = config['num_classes']
        names = list(config['class_names'])
        image_size = config.get('image_size', 224)

        # A mismatch would otherwise surface later as an IndexError or as
        # silently wrong labels in predict_image().
        if len(names) != num_classes:
            raise ValueError(
                f"model_config.json: num_classes={num_classes} but "
                f"{len(names)} class_names were given"
            )

        net = _build_model(num_classes)

        print("📥 Downloading model from Hugging Face Hub...")
        model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)

        # NOTE(security): torch.load unpickles the file. Only load checkpoints
        # from a trusted repository (weights_only defaults to False on torch < 2.6).
        checkpoint = torch.load(model_path, map_location=device)

        # Checkpoints may be a full training dict or a bare state_dict.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        net.load_state_dict(_strip_compile_prefix(state_dict))
        net = net.to(device)
        net.eval()

        preprocess = _build_transform(image_size)
    except Exception as e:
        print(f"❌ Error loading model: {str(e)}")
        raise

    # Publish only after everything above succeeded.
    model, class_names, transform = net, names, preprocess

    print(f"✅ Model loaded successfully on {device}")
    print(f"📊 Number of classes: {num_classes}")


def _format_result_text(predicted_class: str, confidence: float,
                        results: Dict[str, float]) -> str:
    """Render the Markdown summary shown next to the uploaded image.

    Args:
        predicted_class: Name of the top-1 class.
        confidence: Top-1 probability as a percentage (0-100).
        results: Class name -> probability (0-1), in descending order.

    Returns:
        A Markdown string with the top prediction and a ranked list with bars.
    """
    lines = [
        "## 🎯 Hasil Prediksi",
        "",
        f"**Motif Batik:** `{predicted_class}`",
        "",
        f"**Confidence:** `{confidence:.2f}%`",
        "",
        "---",
        "",
        f"### 📊 Top {len(results)} Prediksi:",
        "",
    ]
    for idx, (class_name, conf) in enumerate(results.items(), 1):
        bar = "█" * int(conf * 20)
        # Two trailing spaces force a Markdown line break so the bar renders on
        # its own line; the 3-space indent keeps it inside the list item.
        lines.append(f"{idx}. **{class_name}** - {conf * 100:.2f}%  \n   {bar}")
    return "\n".join(lines)


def predict_image(image: Optional[Image.Image]) -> Tuple[Optional[Dict[str, float]], str]:
    """Predict the batik class of an uploaded image.

    Args:
        image: PIL image from the Gradio input, or None if nothing was uploaded.

    Returns:
        A tuple ``(top_k_dict, formatted_text)``:

        - ``top_k_dict`` maps class name to probability (0-1) for up to
          ``TOP_K`` classes, best first.
        - ``formatted_text`` is Markdown for display.

        On missing input, an unloaded model, or any error during inference,
        ``top_k_dict`` is None and the text carries the error message. This
        is the UI boundary, so errors are reported rather than raised.
    """
    try:
        if image is None:
            return None, "❌ Silakan upload gambar batik terlebih dahulu"

        if model is None or transform is None:
            return None, "❌ Model belum dimuat. Silakan refresh halaman."

        # The normalization pipeline expects 3 channels.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        input_tensor = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            probabilities = torch.softmax(model(input_tensor), dim=1)

        k = min(TOP_K, probabilities.size(1))
        top_probs, top_indices = torch.topk(probabilities, k, dim=1)
        probs = top_probs[0].tolist()
        indices = top_indices[0].tolist()

        predicted_class = class_names[indices[0]]
        confidence = probs[0] * 100

        # setdefault keeps the highest-ranked probability if names repeat.
        results: Dict[str, float] = {}
        for prob, index in zip(probs, indices):
            results.setdefault(class_names[index], float(prob))

        return results, _format_result_text(predicted_class, confidence, results)

    except Exception as e:
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"


# Load model at startup
print("🔄 Loading model...")
load_model()
print("✅ Model ready!")

# Derived from the loaded config rather than hard-coded, so the UI stays
# accurate when the class list changes.
_NUM_MOTIFS = len(class_names)

HEADER_MD = f"""
# 🎨 Klasifikasi Motif Batik Indonesia

Upload gambar batik untuk mengetahui motif dan asalnya!

**Total {_NUM_MOTIFS} motif batik** dari berbagai daerah di Indonesia 🇮🇩
"""

TIPS_MD = """
### 💡 Tips:
- Gunakan gambar dengan kualitas baik
- Pastikan motif batik terlihat jelas
- Format: JPG, PNG, JPEG
"""

ABOUT_MD = f"""
---

### 📋 Tentang Model
- **Dataset:** {_NUM_MOTIFS} Motif Batik Indonesia
- **Kategori:** Batik dari Jawa Tengah, Jawa Timur, Jawa Barat, Bali, Jakarta, Kalimantan, Lampung

### 🎨 Contoh Motif:
Parang Kusumo, Megamendung, Kawung, Truntum, Semarangan, dan banyak lagi!

---

**Made with ❤️ for Indonesian Batik Heritage**
"""

# Create Gradio interface
with gr.Blocks(
    title="Batik Classification",
    theme=gr.themes.Soft(),
    css=".gradio-container {max-width: 1200px; margin: auto;}"
) as demo:
    gr.Markdown(HEADER_MD)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="📤 Upload Gambar Batik",
                height=400
            )
            predict_btn = gr.Button(
                "🔍 Prediksi Motif Batik",
                variant="primary",
                size="lg"
            )
            gr.Markdown(TIPS_MD)

        with gr.Column(scale=1):
            output_text = gr.Markdown(label="Hasil Prediksi")
            output_label = gr.Label(
                label="📊 Confidence Score",
                num_top_classes=TOP_K
            )

    # Event handler
    predict_btn.click(
        fn=predict_image,
        inputs=input_image,
        outputs=[output_label, output_text]
    )

    gr.Markdown(ABOUT_MD)

# Launch
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)