Spaces:
Sleeping
Sleeping
| """ | |
| Gradio UI application for Batik Classification using VGG16 model | |
| """ | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms, models | |
| from PIL import Image | |
| import json | |
| import numpy as np | |
| from typing import Tuple, List | |
# Shared state: load_model() fills in model / class_names / transform,
# and the prediction helpers read them.
model = None
class_names = []
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
transform = None
def _strip_compile_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``_orig_mod.`` removed from keys.

    Checkpoints saved from a ``torch.compile``-wrapped model carry this prefix;
    the plain model built in load_model() expects the bare parameter names.
    Only a leading prefix is stripped, so other keys are left untouched.
    """
    prefix = '_orig_mod.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model(config_path='model_config.json',
               checkpoint_path='models/vgg16_batik_best.pth'):
    """Load the VGG16 model, its trained weights and the preprocessing pipeline.

    Sets the module-level ``model``, ``class_names`` and ``transform``; the
    module-level ``device`` decides where the weights are mapped/moved.

    Args:
        config_path: JSON file with ``num_classes``, ``class_names`` and an
            optional ``image_size`` (224 when absent).
        checkpoint_path: file loaded with ``torch.load`` holding either a
            state_dict or a dict with a ``model_state_dict`` entry.

    Raises:
        Whatever reading the config/checkpoint or ``load_state_dict`` raises
        (e.g. ``FileNotFoundError``, ``KeyError``, ``RuntimeError``); the error
        is printed and then re-raised.
    """
    global model, class_names, transform
    try:
        # Load model configuration
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        num_classes = config['num_classes']
        class_names = config['class_names']
        image_size = config.get('image_size', 224)
        if len(class_names) != num_classes:
            # Labels are looked up by output index, so a mismatch yields wrong
            # names (or an IndexError) at prediction time.
            print(f"Warning: config lists {len(class_names)} class names "
                  f"but num_classes is {num_classes}")

        # Build VGG16 so its classifier matches the saved checkpoint: the
        # module at index 3 becomes the num_classes output layer and every
        # classifier module after it is dropped.
        model = models.vgg16(weights=None)
        model.classifier[3] = nn.Linear(4096, num_classes)
        model.classifier = nn.Sequential(*list(model.classifier.children())[:4])

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (PyTorch >= 2.6 defaults to weights_only=True).
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # The checkpoint is either a training dict or a bare state_dict.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(_strip_compile_prefix(state_dict))
        model = model.to(device)
        model.eval()

        # Image preprocessing; the mean/std are the usual ImageNet channel
        # statistics (presumably what the model was trained with).
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        print(f"β Model loaded successfully on {device}")
        print(f"π Number of classes: {num_classes}")
    except Exception as e:
        print(f"β Error loading model: {str(e)}")
        raise
def predict_single(image: Image.Image) -> Tuple[str, float]:
    """
    Classify one image and report the most likely batik motif.

    Args:
        image: PIL Image to classify; ``None`` is reported as an error.

    Returns:
        ``(class_name, confidence_percent)``: the softmax probability of the
        winning class scaled to 0-100. On any failure the first element is an
        ``"Error: ..."`` message and the confidence is 0.0.
    """
    if image is None:
        return "Error: No image provided", 0.0
    try:
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
        batch = transform(rgb_image).unsqueeze(0).to(device)
        with torch.no_grad():
            scores = torch.nn.functional.softmax(model(batch), dim=1)
        best_prob, best_idx = scores.max(dim=1)
        return class_names[best_idx.item()], best_prob.item() * 100
    except Exception as exc:
        return f"Error: {exc}", 0.0
def predict_top_k(image: Image.Image, k: int = 5) -> dict:
    """
    Predict the top-k classes for an image.

    Args:
        image: PIL Image to classify.
        k: Number of top predictions (capped at the number of classes; a
            non-positive value yields an empty result).

    Returns:
        Mapping of class name -> probability in [0, 1], highest first, as
        expected by ``gr.Label``. Failures are reported as a single entry whose
        key describes the error and whose value is 1.0, so every value is
        always numeric.
    """
    if image is None:
        return {"Error": 1.0}
    try:
        # Convert to RGB if needed
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Transform and add batch dimension
        input_tensor = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            probabilities = torch.nn.functional.softmax(model(input_tensor), dim=1)
        n = max(0, min(k, len(class_names)))
        top_probs, top_indices = torch.topk(probabilities, n, dim=1)
        return {
            class_names[idx]: float(prob)
            for prob, idx in zip(top_probs[0].tolist(), top_indices[0].tolist())
        }
    except Exception as e:
        # Keep the value numeric: gr.Label and format_prediction do arithmetic on it.
        return {f"Error: {e}": 1.0}
def format_prediction(image: Image.Image) -> Tuple[str, dict]:
    """
    Build the Markdown summary and the label dict for the single-image tab.

    Args:
        image: PIL Image from the upload widget (may be None).

    Returns:
        ``(markdown_text, top_k_dict)``. On failure the text is an error
        message and the dict is empty.
    """
    try:
        if image is None:
            return "β Silakan upload gambar batik terlebih dahulu", {}
        predicted_class, confidence = predict_single(image)
        if predicted_class.startswith("Error:"):
            # Inference already failed: report it instead of a bogus result and
            # skip the second forward pass predict_top_k would make.
            return f"β {predicted_class}", {}
        top_k_results = predict_top_k(image, k=5)
        result_text = f"""
## π― Hasil Prediksi
**Motif Batik:** `{predicted_class}`
**Confidence:** `{confidence:.2f}%`
---
### π Top 5 Prediksi:
"""
        for idx, (class_name, conf) in enumerate(list(top_k_results.items())[:5], 1):
            if not isinstance(conf, (int, float)):
                # Non-numeric value means predict_top_k reported an error;
                # surface it rather than crashing in the bar arithmetic below.
                return f"β Error: {conf}", {}
            bar = "β" * int(conf * 20)  # Simple bar visualization (20 chars = 100%)
            result_text += f"\n{idx}. **{class_name}** - {conf*100:.2f}% \n {bar}"
        return result_text, top_k_results
    except Exception as e:
        return f"β Error: {str(e)}", {}
def get_model_info() -> str:
    """Return a Markdown summary of the architecture, device and class count."""
    total = len(class_names)
    return f"""
### π Informasi Model
- **Arsitektur:** VGG16
- **Device:** {device}
- **Jumlah Kelas:** {total}
- **Status:** β Model siap digunakan
### π¨ Kategori Batik:
Total {total} motif batik dari berbagai daerah di Indonesia
"""
# Load model at startup
load_model()

# Create Gradio interface
with gr.Blocks(title="Batik Classification - VGG16", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# π¨ Klasifikasi Motif Batik Indonesia
### Menggunakan Model VGG16 Deep Learning
Upload gambar batik untuk mengetahui motif dan asalnya!
""")
    with gr.Tabs():
        # Tab 1: Single Prediction
        with gr.Tab("πΌοΈ Prediksi Tunggal"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(
                        type="pil",
                        label="Upload Gambar Batik",
                        height=400
                    )
                    predict_btn = gr.Button("π Prediksi", variant="primary", size="lg")
                    gr.Examples(
                        examples=[],  # Add example images if available
                        inputs=input_image,
                        label="Contoh Gambar (jika tersedia)"
                    )
                with gr.Column():
                    output_text = gr.Markdown(label="Hasil Prediksi")
                    output_label = gr.Label(
                        label="Top 5 Prediksi",
                        num_top_classes=5
                    )
            predict_btn.click(
                fn=format_prediction,
                inputs=input_image,
                outputs=[output_text, output_label]
            )

        # Tab 2: Batch Prediction
        with gr.Tab("π Prediksi Batch"):
            gr.Markdown("### Upload multiple gambar batik sekaligus")
            batch_input = gr.File(
                file_count="multiple",
                file_types=["image"],
                label="Upload Gambar (Multiple)"
            )
            batch_btn = gr.Button("π Prediksi Semua", variant="primary")
            batch_output = gr.Dataframe(
                headers=["Filename", "Predicted Class", "Confidence (%)"],
                label="Hasil Prediksi Batch"
            )

            def predict_batch(files):
                """
                Classify every uploaded file, one table row per file.

                Args:
                    files: uploads from gr.File. Older Gradio versions pass
                        temp-file objects exposing ``.name``; newer ones pass
                        plain path strings. Both are accepted.

                Returns:
                    Rows ``[filename, predicted_class, confidence]`` with the
                    confidence formatted as "%.2f", or
                    ``[filename, "Error", message]`` for files that fail.
                """
                import os

                if not files:
                    return []
                results = []
                for file in files:
                    path = getattr(file, "name", file)
                    filename = os.path.basename(path)
                    try:
                        # Context manager closes the file handle after inference.
                        with Image.open(path) as image:
                            pred_class, confidence = predict_single(image)
                        results.append([filename, pred_class, f"{confidence:.2f}"])
                    except Exception as e:
                        results.append([filename, "Error", str(e)])
                return results

            batch_btn.click(
                fn=predict_batch,
                inputs=batch_input,
                outputs=batch_output
            )

        # Tab 3: Model Info
        with gr.Tab("βΉοΈ Info Model"):
            gr.Markdown(get_model_info())
            with gr.Accordion("π Daftar Semua Kelas Batik", open=False):
                class_list = "\n".join(f"{i}. {name}" for i, name in enumerate(class_names, 1))
                gr.Textbox(
                    value=class_list,
                    label=f"Total {len(class_names)} Kelas",
                    lines=20,
                    max_lines=30
                )

    gr.Markdown("""
---
### π Cara Penggunaan:
1. **Prediksi Tunggal:** Upload satu gambar batik dan klik tombol Prediksi
2. **Prediksi Batch:** Upload beberapa gambar sekaligus untuk prediksi massal
3. **Info Model:** Lihat informasi lengkap tentang model dan daftar kelas
### π‘ Tips:
- Gunakan gambar dengan kualitas yang baik untuk hasil terbaik
- Pastikan gambar menunjukkan motif batik dengan jelas
- Model mendukung format JPG, PNG, dan format gambar umum lainnya
""")
# Launch the app
if __name__ == "__main__":
    import os

    try:
        # Explicit launch() arguments take precedence over Gradio's
        # GRADIO_SERVER_NAME / GRADIO_SERVER_PORT variables, so read them here
        # to let the hosting environment rebind the app; otherwise use the
        # local defaults.
        demo.launch(
            server_name=os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1"),
            server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
            share=False,  # Set to True to get a public share link
            inbrowser=True,
            quiet=False
        )
    except Exception as e:
        print(f"Error launching Gradio: {e}")
        # Fallback: let Gradio pick its own host/port defaults
        demo.launch()