File size: 2,408 Bytes
0e2e09d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import functools
import os
import textwrap

import gradio as gr
import torch
from PIL import Image
from fastai.learner import load_learner

# Model loading function
@functools.lru_cache(maxsize=None)
def load_model(model_path='models/jimi_classifier'):
    """Load the exported fastai learner, falling back to a random stub.

    The result is cached, so the model is loaded once per process instead of
    on every prediction request.

    Args:
        model_path: Path to the file written by ``Learner.export()``, or to a
            directory containing that file under its default name
            ``export.pkl``.

    Returns:
        The loaded learner. If loading fails for any reason, returns a
        ``_StubLearner`` that produces RANDOM predictions so the UI stays
        usable for testing. It never returns ``None``.
    """
    if os.path.isdir(model_path):
        # Learner.export() writes 'export.pkl' by default; accept a directory holding it.
        model_path = os.path.join(model_path, 'export.pkl')
    try:
        # load_learner unpickles the file: only ever point this at trusted model files.
        return load_learner(model_path)
    except Exception as e:
        # Deliberate best-effort: keep the demo running, but make the fallback obvious.
        print(f"Error loading model: {e}")
        print("Falling back to a stub model that returns RANDOM predictions.")
        return _StubLearner()


class _StubLearner:
    """Stand-in learner used when the real model cannot be loaded (testing only).

    Mimics the ``(pred_class, pred_idx, probs)`` tuple that ``predict_image``
    unpacks from ``Learner.predict``. The predicted class is chosen at random,
    with a fixed probability of 0.8 for that class.
    """

    def predict(self, img):
        import random
        is_jimis = random.choice([True, False])
        pred_class = 'jimis' if is_jimis else 'not_jimis'
        pred_idx = 0 if is_jimis else 1
        probs = torch.tensor([0.8, 0.2]) if is_jimis else torch.tensor([0.2, 0.8])
        return pred_class, pred_idx, probs

# Prediction function
def predict_image(img):
    """Classify an image as "Jimis" or "Not Jimis".

    Args:
        img: Image from the ``gr.Image(type="pil")`` input, or ``None`` when
            nothing has been uploaded.

    Returns:
        A ``(label, confidence_percent)`` pair feeding the Label and Number
        outputs. The confidence is the predicted class's probability times
        100, rounded to 2 decimals. On any failure the label is an error
        message and the confidence is 0.
    """
    if img is None:
        return "Please upload an image", 0

    try:
        # Inside the try so that any model-loading failure is reported in the UI too.
        model = load_model()
        if model is None:
            return "Error processing image: model is not available", 0
        # Uploads may be RGBA (e.g. PNG) or greyscale. The ResNet described in the
        # app's article presumably expects 3-channel input — convert defensively.
        if getattr(img, "mode", "RGB") != "RGB":
            img = img.convert("RGB")
        pred_class, pred_idx, probs = model.predict(img)
        confidence = float(probs[pred_idx]) * 100
    except Exception as e:
        # UI boundary: log the full traceback server-side, show a short message.
        print(f"Error during prediction: {e}")
        import traceback
        traceback.print_exc()
        return f"Error processing image: {str(e)}", 0

    result = "Jimis" if str(pred_class).lower() == "jimis" else "Not Jimis"
    return result, round(confidence, 2)

# Example images for the demo
# Sample inputs shown under the demo; add example image paths when available.
examples = []

# Create the Gradio interface
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=[
        gr.Label(label="Prediction"),
        gr.Number(label="Confidence (%)"),
    ],
    title="Jimis Classifier",
    description="Upload an image to check if it contains Jimis",
    examples=examples,
    # dedent: the literal is indented to line up with the code, but four leading
    # spaces would make Markdown render every line as an indented code block.
    article=textwrap.dedent("""
    ## How it works
    
    This application uses a machine learning model trained to recognize Jimis in images.
    The model was trained on a custom dataset of Jimis and non-Jimis images using the 
    fastai library and a ResNet architecture.
    
    Simply upload an image, and the model will tell you whether it contains Jimis and 
    how confident it is about its prediction.
    """),
)

# Launch the app
if __name__ == "__main__":
    demo.launch()