File size: 6,340 Bytes
6a441bb
8b1775e
6a441bb
 
 
 
 
 
 
 
 
 
95d742a
 
6a441bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95d742a
 
 
 
 
 
 
6a441bb
95d742a
6a441bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a442896
6a441bb
 
 
 
 
 
 
 
 
a442896
8b1775e
6a441bb
 
a442896
8b1775e
a442896
 
6a441bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a442896
 
 
6a441bb
 
 
 
 
 
 
 
 
8b1775e
6a441bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0002d5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""
Gradio UI application for Batik Classification
Optimized for Hugging Face Spaces deployment
"""
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import json
import numpy as np
from typing import Tuple, Dict
from huggingface_hub import hf_hub_download
import os

# Global variables
# Shared state: filled in by load_model() and read by predict_image().
model: "nn.Module | None" = None  # VGG16 classifier; None until loaded
class_names: "list[str]" = list()  # label for each output index (model_config.json)
# Prefer the GPU when PyTorch can see one, otherwise run on the CPU.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
transform: "transforms.Compose | None" = None  # preprocessing; None until loaded


def _extract_state_dict(checkpoint):
    """Return a plain state dict from a loaded checkpoint object.

    Accepts either a training checkpoint dict that stores the weights under
    ``'model_state_dict'`` or a bare state dict. Keys starting with
    ``'_orig_mod.'`` (the prefix present when a ``torch.compile``-wrapped
    model is saved) have that text removed so they match the plain model.
    """
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    prefix = '_orig_mod.'
    return {
        (key.replace(prefix, '') if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model():
    """Load the classifier, its class labels and the preprocessing pipeline.

    Reads ``model_config.json`` (``num_classes``, ``class_names`` and an
    optional ``image_size`` defaulting to 224). Builds a VGG16 whose
    classifier is truncated to match the saved checkpoint. Downloads the
    weights from the Hugging Face Hub and loads them in eval mode on
    ``device``.

    The module-level ``model``, ``class_names`` and ``transform`` are only
    replaced once every step has succeeded. A failed load therefore never
    leaves a partially initialised (randomly weighted) model visible to
    ``predict_image``.

    Raises:
        Any exception from reading the config, downloading or loading the
        checkpoint; it is printed first and then re-raised unchanged.
    """
    global model, class_names, transform

    try:
        # Load model configuration (explicit UTF-8: class names may be non-ASCII).
        with open('model_config.json', 'r', encoding='utf-8') as f:
            config = json.load(f)

        num_classes = config['num_classes']
        names = config['class_names']
        image_size = config.get('image_size', 224)

        # Initialize VGG16 and reshape its classifier to match the checkpoint:
        # layer 3 becomes a 4096 -> num_classes head and every later layer
        # is dropped.
        net = models.vgg16(weights=None)
        net.classifier[3] = nn.Linear(4096, num_classes)
        net.classifier = nn.Sequential(*list(net.classifier.children())[:4])

        # Download model from Hugging Face Hub
        print("๐Ÿ“ฅ Downloading model from Hugging Face Hub...")
        model_path = hf_hub_download(
            repo_id="RimsJ/Batik-Classifier",
            filename="vgg16_batik_best.pth"
        )

        # NOTE(review): torch.load unpickles the file, so only point this at
        # a repository you trust.
        # Deserialize onto the CPU: the weights are copied into `net` and the
        # whole model is moved to `device` afterwards. Mapping straight to
        # the GPU would only keep a redundant second copy there.
        checkpoint = torch.load(model_path, map_location='cpu')
        net.load_state_dict(_extract_state_dict(checkpoint))
        net = net.to(device)
        net.eval()

        # Define image preprocessing (ImageNet mean/std normalisation).
        preprocess = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        # Publish the new state only now that everything above succeeded.
        model, class_names, transform = net, names, preprocess

        print(f"โœ… Model loaded successfully on {device}")
        print(f"๐Ÿ“Š Number of classes: {num_classes}")

    except Exception as e:
        print(f"โŒ Error loading model: {str(e)}")
        raise


def predict_image(image):
    """
    Classify an uploaded batik image and render the result.

    Args:
        image: PIL Image from the upload widget, or None when nothing
            has been uploaded

    Returns:
        Tuple ``(scores, report)``. ``scores`` maps up to five class names
        to their softmax probability, highest first, and is None when no
        prediction was made. ``report`` is the Markdown summary or an
        error message.
    """
    global model, transform, class_names

    try:
        # Nothing to classify yet.
        if image is None:
            return None, "โŒ Silakan upload gambar batik terlebih dahulu"
        # load_model() has not populated the network.
        if model is None:
            return None, "โŒ Model belum dimuat. Silakan refresh halaman."

        # Feed the preprocessing pipeline a 3-channel image.
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
        batch = transform(rgb_image).unsqueeze(0).to(device)

        k = min(5, len(class_names))
        with torch.no_grad():
            logits = model(batch)
            probs = torch.nn.functional.softmax(logits, dim=1)
            best_probs, best_indices = torch.topk(probs, k, dim=1)

        # Winning class and its probability expressed as a percentage.
        predicted_class = class_names[best_indices[0][0].item()]
        confidence = best_probs[0][0].item() * 100

        # Bring the single batch row back to Python in one go.
        ranked = zip(best_indices[0].tolist(), best_probs[0].tolist())
        results = {class_names[idx]: float(p) for idx, p in ranked}

        result_text = f"""
## ๐ŸŽฏ Hasil Prediksi

**Motif Batik:** `{predicted_class}`
**Confidence:** `{confidence:.2f}%`

---

### ๐Ÿ“Š Top 5 Prediksi:
"""
        # One numbered entry per class, with a 20-segment text bar.
        rows = []
        for rank, (name, score) in enumerate(results.items(), start=1):
            bar = "โ–ˆ" * int(score * 20)
            rows.append(f"\n{rank}. **{name}** - {score*100:.2f}%  \n   {bar}")

        return results, result_text + "".join(rows)

    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, f"โŒ Error: {e!s}"


# Load model at startup
# Runs at import time, before the UI is built. load_model() prints and
# re-raises any error, so a broken config/download/checkpoint aborts the
# import here instead of serving a UI without a model.
print("๐Ÿ”„ Loading model...")
load_model()
print("โœ… Model ready!")

# Create Gradio interface
# Inside each `with` context, components appear in the order they are
# created, so the statement order below defines the page layout.
with gr.Blocks(
    title="Batik Classification",
    theme=gr.themes.Soft(),
    css=".gradio-container {max-width: 1200px; margin: auto;}"
) as demo:

    # Page header.
    # NOTE(review): "111" is hard-coded here and in the footer, while the
    # actual number of classes comes from model_config.json - keep in sync.
    gr.Markdown("""
    # ๐ŸŽจ Klasifikasi Motif Batik Indonesia

    Upload gambar batik untuk mengetahui motif dan asalnya!
    **Total 111 motif batik** dari berbagai daerah di Indonesia ๐Ÿ‡ฎ๐Ÿ‡ฉ
    """)

    # Two equal-width columns: input on the left, results on the right.
    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" hands predict_image a PIL Image (or None if empty).
            input_image = gr.Image(
                type="pil",
                label="๐Ÿ“ค Upload Gambar Batik",
                height=400
            )
            predict_btn = gr.Button(
                "๐Ÿ” Prediksi Motif Batik",
                variant="primary",
                size="lg"
            )

            gr.Markdown("""
            ### ๐Ÿ’ก Tips:
            - Gunakan gambar dengan kualitas baik
            - Pastikan motif batik terlihat jelas
            - Format: JPG, PNG, JPEG
            """)

        with gr.Column(scale=1):
            # Receives the Markdown report returned by predict_image.
            output_text = gr.Markdown(label="Hasil Prediksi")
            # Receives the {class_name: probability} dict (up to 5 entries).
            output_label = gr.Label(
                label="๐Ÿ“Š Confidence Score",
                num_top_classes=5
            )

    # Event handler
    # predict_image returns (scores, report); the outputs list must keep the
    # same order: scores -> output_label, report -> output_text.
    predict_btn.click(
        fn=predict_image,
        inputs=input_image,
        outputs=[output_label, output_text]
    )

    # Footer with dataset / model information.
    gr.Markdown("""
    ---
    ### ๐Ÿ“‹ Tentang Model
    - **Dataset:** 111 Motif Batik Indonesia
    - **Kategori:** Batik dari Jawa Tengah, Jawa Timur, Jawa Barat, Bali, Jakarta, Kalimantan, Lampung

    ### ๐ŸŽจ Contoh Motif:
    Parang Kusumo, Megamendung, Kawung, Truntum, Semarangan, dan banyak lagi!

    ---
    **Made with โค๏ธ for Indonesian Batik Heritage**
    """)


# Launch
# Serve only when run as a script. Importing this module builds `demo`
# (and loads the model) but does not start a server.
# server_name="0.0.0.0" listens on all network interfaces, on port 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)