Spaces:
Sleeping
Sleeping
| """ | |
| Gradio UI application for Batik Classification | |
| Optimized for Hugging Face Spaces deployment | |
| """ | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms, models | |
| from PIL import Image | |
| import json | |
| import numpy as np | |
| from typing import Tuple, Dict | |
| from huggingface_hub import hf_hub_download | |
| import os | |
# Shared inference state. `device` is fixed at import time; `model`,
# `class_names` and `transform` start empty and are filled in by load_model().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = None
class_names = []
transform = None
def _build_vgg16(num_classes):
    """Return an untrained VGG16 whose classifier matches the saved checkpoint.

    The stock torchvision classifier is truncated after its first three
    layers and finished with ``nn.Linear(4096, num_classes)``. The result is a
    4-layer ``nn.Sequential`` with the same layer indices as the original
    "replace classifier[3], keep the first four children" construction.
    """
    net = models.vgg16(weights=None)
    kept = list(net.classifier.children())[:3]
    net.classifier = nn.Sequential(*kept, nn.Linear(4096, num_classes))
    return net


def _extract_state_dict(checkpoint):
    """Return the weights stored in *checkpoint*, with clean parameter names.

    Accepts either a training checkpoint dict holding the weights under
    ``'model_state_dict'`` or a bare state dict. A leading ``'_orig_mod.'``
    is removed from every key, presumably left by saving a
    ``torch.compile``-wrapped model. Only the prefix is stripped; the rest of
    the key is left untouched.
    """
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    prefix = "_orig_mod."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model():
    """Load the config, weights and preprocessing pipeline into module globals.

    Steps:

    - Read ``model_config.json`` from the current working directory. It must
      provide ``num_classes`` and ``class_names``; ``image_size`` is optional
      and defaults to 224.
    - Download ``vgg16_batik_best.pth`` from the Hugging Face repo
      ``RimsJ/Batik-Classifier``.
    - Load the weights into a VGG16 on ``device`` and put it in eval mode.

    The globals ``model``, ``class_names`` and ``transform`` are assigned
    together only after every step has succeeded. A failed load therefore
    never leaves a half-initialised (e.g. randomly weighted) model visible to
    ``predict_image``.

    Raises:
        ValueError: if ``class_names`` does not contain ``num_classes`` entries.
        Any error from reading the config, downloading or loading the
        checkpoint is printed and re-raised unchanged.
    """
    global model, class_names, transform
    try:
        with open("model_config.json", "r", encoding="utf-8") as f:
            config = json.load(f)
        num_classes = config["num_classes"]
        names = config["class_names"]
        image_size = config.get("image_size", 224)
        # Fail fast: a mismatch would otherwise surface later as an
        # IndexError (or silently wrong labels) during prediction.
        if len(names) != num_classes:
            raise ValueError(
                f"model_config.json lists {len(names)} class names "
                f"but num_classes is {num_classes}"
            )

        net = _build_vgg16(num_classes)

        print("๐ฅ Downloading model from Hugging Face Hub...")
        model_path = hf_hub_download(
            repo_id="RimsJ/Batik-Classifier",
            filename="vgg16_batik_best.pth",
        )
        # NOTE(review): torch.load unpickles the file. That is only acceptable
        # because it comes from this project's own Hub repo. Consider
        # weights_only=True once the checkpoint is confirmed to contain only
        # tensors and plain containers.
        checkpoint = torch.load(model_path, map_location=device)
        net.load_state_dict(_extract_state_dict(checkpoint))
        net = net.to(device)
        net.eval()

        preprocess = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            # Standard ImageNet mean/std; presumably what training used.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
    except Exception as e:
        print(f"โ Error loading model: {str(e)}")
        raise

    # Publish all shared state at once, only after everything succeeded.
    model, class_names, transform = net, names, preprocess
    print(f"โ Model loaded successfully on {device}")
    print(f"๐ Number of classes: {num_classes}")
def predict_image(image, top_k=5):
    """Classify a batik image and format its top predictions.

    Args:
        image: PIL image from the upload widget, or ``None`` when nothing has
            been uploaded. Non-RGB images are converted to RGB first.
        top_k: Maximum number of classes to report. It is clamped to the range
            ``1..len(class_names)``. Defaults to 5.

    Returns:
        Tuple ``(scores, markdown)``:

        - ``scores`` maps class name to softmax probability in ``[0, 1]`` for
          the top-k classes, highest first.
        - ``markdown`` is the human-readable report.

        When there is no image, the model is not loaded, or inference raises,
        ``scores`` is ``None`` and ``markdown`` carries the message.
    """
    try:
        if image is None:
            return None, "โ Silakan upload gambar batik terlebih dahulu"
        # All three pieces of state are needed; check them all, not just model.
        if model is None or transform is None or not class_names:
            return None, "โ Model belum dimuat. Silakan refresh halaman."

        if image.mode != "RGB":
            image = image.convert("RGB")

        # torch.topk needs 1 <= k <= number of outputs.
        k = max(1, min(top_k, len(class_names)))

        batch = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            probabilities = torch.nn.functional.softmax(model(batch), dim=1)
        top_probs, top_indices = torch.topk(probabilities, k, dim=1)

        ranked = [
            (class_names[idx], float(prob))
            for idx, prob in zip(top_indices[0].tolist(), top_probs[0].tolist())
        ]
        predicted_class, best_prob = ranked[0]
        results = dict(ranked)

        # Blank lines matter in Markdown: without them "---" directly under
        # the confidence line is parsed as a setext-heading underline, which
        # turns the whole summary into one large heading.
        lines = [
            "",
            "## ๐ฏ Hasil Prediksi",
            "",
            f"**Motif Batik:** `{predicted_class}`",
            "",
            f"**Confidence:** `{best_prob * 100:.2f}%`",
            "",
            "---",
            "",
            f"### ๐ Top {k} Prediksi:",
            "",
        ]
        for rank, (name, prob) in enumerate(ranked, 1):
            bar = "โ" * int(prob * 20)
            lines.append(f"{rank}. **{name}** - {prob * 100:.2f}% \n {bar}")
        return results, "\n".join(lines)
    except Exception as e:
        # UI boundary: log the full traceback server-side and show the
        # message in the page instead of crashing the event handler.
        import traceback
        traceback.print_exc()
        return None, f"โ Error: {str(e)}"
# Load the model before building the UI. load_model() re-raises on failure,
# so a broken config or checkpoint aborts start-up right here rather than
# on the first prediction request.
print("๐ Loading model...")
load_model()
print("โ Model ready!")
# Markdown copy for the page, kept out of the layout code below.
# - Blank lines are significant. A "---" placed directly under a line of text
#   is parsed as a setext-heading underline rather than a horizontal rule, and
#   adjacent text lines merge into one paragraph.
# - The motif count is taken from the loaded config instead of being
#   hard-coded, so the copy cannot drift from the model.
# NOTE(review): the emoji in these strings look like UTF-8 that was decoded
# as TIS-620/cp874 at some point (e.g. "๐จ", "โค๏ธ"); confirm and restore
# the intended characters.
_num_motifs = len(class_names)

_HEADER_MD = f"""
# ๐จ Klasifikasi Motif Batik Indonesia

Upload gambar batik untuk mengetahui motif dan asalnya!

**Total {_num_motifs} motif batik** dari berbagai daerah di Indonesia ๐ฎ๐ฉ
"""

_TIPS_MD = """
### ๐ก Tips:

- Gunakan gambar dengan kualitas baik
- Pastikan motif batik terlihat jelas
- Format: JPG, PNG, JPEG
"""

_FOOTER_MD = f"""
---

### ๐ Tentang Model

- **Dataset:** {_num_motifs} Motif Batik Indonesia
- **Kategori:** Batik dari Jawa Tengah, Jawa Timur, Jawa Barat, Bali, Jakarta, Kalimantan, Lampung

### ๐จ Contoh Motif:

Parang Kusumo, Megamendung, Kawung, Truntum, Semarangan, dan banyak lagi!

---

**Made with โค๏ธ for Indonesian Batik Heritage**
"""

# Two-column layout: upload + button + tips on the left, results on the right.
with gr.Blocks(
    title="Batik Classification",
    theme=gr.themes.Soft(),
    css=".gradio-container {max-width: 1200px; margin: auto;}",
) as demo:
    gr.Markdown(_HEADER_MD)
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="๐ค Upload Gambar Batik",
                height=400,
            )
            predict_btn = gr.Button(
                "๐ Prediksi Motif Batik",
                variant="primary",
                size="lg",
            )
            gr.Markdown(_TIPS_MD)
        with gr.Column(scale=1):
            output_text = gr.Markdown(label="Hasil Prediksi")
            output_label = gr.Label(
                label="๐ Confidence Score",
                num_top_classes=5,
            )
    # predict_image returns (scores, markdown); the order of outputs matches.
    predict_btn.click(
        fn=predict_image,
        inputs=input_image,
        outputs=[output_label, output_text],
    )
    gr.Markdown(_FOOTER_MD)
# Serve on all interfaces at port 7860 when run as a script.
if __name__ == "__main__":
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)