# Demo_v4 / app.py
# Last updated by Sagar31658 ("Update app.py"), commit 08d7458 (verified).
import gradio as gr
import torch
import numpy as np
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
from PIL import Image
import requests
from io import BytesIO
# Fetch the fine-tuned image classifier and its matching preprocessing
# configuration from the Hugging Face Hub (both keyed by the same repo id).
model_name = "Team-SknAI/SknAI-v4-10Labels" # Replace with your model name
model, feature_extractor = (
    AutoModelForImageClassification.from_pretrained(model_name),  # loaded first
    AutoFeatureExtractor.from_pretrained(model_name),
)
def predict_image(image):
    """Classify a single image and return the predicted label name.

    Args:
        image: The value from the Gradio ``"image"`` input: a NumPy array,
            a ``PIL.Image.Image``, or ``None`` when no image was supplied.

    Returns:
        The label string from ``model.config.id2label`` for the arg-max
        class, or a short prompt message when ``image`` is ``None``.
    """
    # Without this guard a missing image fails below with an AttributeError
    # (``None`` has no ``convert``/``resize``).
    if image is None:
        return "Please upload an image."

    model.eval()

    # Array inputs are converted to PIL so the convert/resize calls work.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Normalise to 3-channel RGB so grayscale or RGBA uploads do not reach
    # the feature extractor with an unexpected channel count.
    image = image.convert("RGB")

    # Resize to 224x224 before preprocessing (kept from the original app).
    # NOTE(review): presumably the model's training resolution — confirm
    # against the feature extractor's configured size.
    image = image.resize((224, 224))

    inputs = feature_extractor(images=image, return_tensors="pt")

    # Move every tensor to wherever the model lives (rather than hard-coding
    # "cpu"), and cast only floating-point tensors to the model's dtype so
    # any integer tensors keep their type.
    inputs = {
        key: (
            tensor.to(device=model.device, dtype=model.dtype)
            if torch.is_floating_point(tensor)
            else tensor.to(model.device)
        )
        for key, tensor in inputs.items()
    }

    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class = torch.argmax(logits, dim=-1).item()
    return model.config.id2label[predicted_class]
# Build the web UI: one image input feeding predict_image, one text output
# showing the returned label.
app = gr.Interface(fn=predict_image, inputs="image", outputs="text")

# Launch only when executed as a script, so importing this module (e.g. to
# reuse predict_image) does not start a blocking server as a side effect.
# share=True requests a public Gradio share link.
# NOTE(review): confirm a public link is intended; hosted environments such
# as Hugging Face Spaces may not support it.
if __name__ == "__main__":
    app.launch(share=True)