fibrotest-base / app.py
miguelozaalon's picture
Update app.py
facf4b9 verified
import gradio as gr
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch
import torch.nn.functional as F
import os
# Hub repository that provides both the preprocessor and the classifier.
model_name = "miguelozaalon/fibrotest-base"

# The access token comes from the HF_AUTH_TOKEN environment variable
# (os.getenv returns None when it is unset); the same keyword argument is
# forwarded to both from_pretrained calls.
_hub_auth = {"use_auth_token": os.getenv("HF_AUTH_TOKEN")}

feature_extractor = AutoFeatureExtractor.from_pretrained(model_name, **_hub_auth)
model = AutoModelForImageClassification.from_pretrained(model_name, **_hub_auth)

# Class index -> display name. The model config is overwritten with this
# mapping and with its inverse.
label_map = dict(enumerate(("Normal", "Lung Fibrosis")))
model.config.id2label = label_map
model.config.label2id = {name: index for index, name in label_map.items()}
def predict_image(image):
    """Classify one image and report the winning label with its probability.

    Args:
        image: Image to classify; it is passed unchanged to
            ``feature_extractor``. The Gradio interface in this file declares
            its input as ``gr.Image(type="pil")``.

    Returns:
        A single-entry dict mapping the predicted label (looked up in
        ``label_map``) to its softmax probability as a float.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Gradio passes None when the form is submitted without an image; fail
    # with a readable message instead of an error from inside the extractor.
    if image is None:
        raise gr.Error("Please upload an image before requesting a prediction.")

    # Preprocess the image into PyTorch tensors.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only, so gradient tracking is disabled for the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Turn logits into probabilities over the class dimension.
    probs = F.softmax(logits, dim=-1)

    # .item() requires a single-element tensor, so this handles one image per call.
    predicted_class_idx = probs.argmax(-1).item()
    confidence = probs[0, predicted_class_idx].item()
    return {label_map[predicted_class_idx]: confidence}
# Text and example shown in the Gradio UI.
title = "Lung Fibrosis Classification"
description = "This model classifies lung fibrosis images. Upload an image to predict."
examples = [["example_image.jpg"]]

# The input component is configured with type="pil"; the dict returned by
# predict_image is rendered by a Label component.
image_input = gr.Image(type="pil")
label_output = gr.Label()

# Create Gradio interface
interface = gr.Interface(
    fn=predict_image,
    inputs=image_input,
    outputs=label_output,
    title=title,
    description=description,
    examples=examples,
)

# Launch only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    interface.launch()