# jimapi / app.py
# Author: nikosteskos
# Commit 43e8ffb (verified): Rename app_gradio.py to app.py
import os
from functools import lru_cache

import gradio as gr
import torch
from PIL import Image
from fastai.learner import load_learner
# Model loading
_DEFAULT_MODEL_PATH = 'models/jimi_classifier'
# Default file name written by fastai's Learner.export().
_EXPORT_FILENAME = 'export.pkl'


class _StubLearner:
    """Random stand-in used when the real fastai learner cannot be loaded.

    Returns the same ``(pred_class, pred_idx, probs)`` triple shape that
    ``predict_image`` unpacks, so the UI stays usable for testing.
    """

    def predict(self, img):
        import random
        is_jimis = random.choice([True, False])
        pred_class = 'jimis' if is_jimis else 'not_jimis'
        pred_idx = 0 if is_jimis else 1
        probs = torch.tensor([0.8, 0.2]) if is_jimis else torch.tensor([0.2, 0.8])
        return pred_class, pred_idx, probs


@lru_cache(maxsize=4)
def load_model(model_path=_DEFAULT_MODEL_PATH):
    """Load the exported fastai learner, falling back to a random stub.

    The result is cached per ``model_path``. Repeated calls (one per
    prediction request) reuse the same learner instead of re-reading the
    model from disk every time.

    Args:
        model_path: Either the exported learner file itself, or a directory
            containing ``export.pkl``.

    Returns:
        The loaded learner. If the file is missing or fails to load, a
        ``_StubLearner`` is returned instead.
    """
    path = model_path
    if os.path.isdir(path):
        # load_learner needs the pickle file, not its parent directory;
        # passing a directory always fails.
        path = os.path.join(path, _EXPORT_FILENAME)
    if os.path.isfile(path):
        try:
            # NOTE: load_learner unpickles the file; only point it at trusted models.
            return load_learner(path)
        except Exception as e:
            print(f"Error loading model: {e}")
    else:
        print(f"Model file not found at {path}; using stub model")
    # Fallback stub model for testing
    return _StubLearner()
# Prediction function
def predict_image(img):
    """Classify *img* and format the result for the Gradio outputs.

    Returns a ``(label, confidence)`` pair. ``label`` is "Jimis" or
    "Not Jimis", and ``confidence`` is the probability at the predicted
    index as a percentage rounded to two decimals.

    If no image is given, or prediction raises, the label slot carries a
    message instead and the confidence is 0.
    """
    # Nothing to classify: the user submitted without an image.
    if img is None:
        return "Please upload an image", 0

    learner = load_model()
    try:
        label, idx, probabilities = learner.predict(img)
        percent = float(probabilities[idx]) * 100
        verdict = "Not Jimis"
        if str(label).lower() == "jimis":
            verdict = "Jimis"
        return verdict, round(percent, 2)
    except Exception as exc:
        print(f"Error during prediction: {exc}")
        import traceback
        traceback.print_exc()
        return f"Error processing image: {exc!s}", 0
# Sample inputs handed to the Gradio interface. Currently none; add image
# file paths to this list to offer them as example inputs.
examples: list = []
# Markdown rendered as the interface's article section.
_ARTICLE = """
## How it works
This application uses a machine learning model trained to recognize Jimis in images.
The model was trained on a custom dataset of Jimis and non-Jimis images using the
fastai library and a ResNet architecture.
Simply upload an image, and the model will tell you whether it contains Jimis and
how confident it is about its prediction.
"""

# One PIL image in; a (label, confidence %) pair out, matching predict_image's return.
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil", label="Upload an image"),
    outputs=[gr.Label(label="Prediction"), gr.Number(label="Confidence (%)")],
    examples=examples,
    title="Jimis Classifier",
    description="Upload an image to check if it contains Jimis",
    article=_ARTICLE,
)
def _main():
    """Start the Gradio server for the interface defined in this module."""
    demo.launch()


# Only launch when executed as a script, not when imported.
if __name__ == "__main__":
    _main()