# Skin-AI / app.py
# Eraly-ml's picture
# Update app.py
# cf0a5d8 verified
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import os
from typing import Tuple, List, Dict
def load_model(
    model_path: str = "skinconvnext_scripted.pt",
    labels_path: str = "labels.txt",
) -> Tuple[torch.nn.Module, List[str]]:
    """
    Load the TorchScript model and its class labels.

    Args:
        model_path: Path to a TorchScript archive readable by ``torch.jit.load``.
        labels_path: Path to a UTF-8 text file with one class label per line.

    Returns:
        model: The loaded model, mapped to CUDA when available (else CPU)
            and switched to eval mode.
        labels: The label lines in file order, with surrounding whitespace
            stripped. The i-th label is expected to correspond to the i-th
            model output (assumption - confirm against the training setup).

    Raises:
        FileNotFoundError: If either file does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found: {model_path}")
    if not os.path.exists(labels_path):
        raise FileNotFoundError(f"File {labels_path} not found.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location loads the weights directly onto the chosen device.
    model = torch.jit.load(model_path, map_location=device)
    # Eval mode changes the behavior of layers such as dropout / batch-norm, if present.
    model.eval()

    # Explicit UTF-8: the platform default encoding is locale-dependent, and
    # label names may contain non-ASCII characters.
    with open(labels_path, "r", encoding="utf-8") as f:
        labels = [line.strip() for line in f]
    return model, labels
# Load the model and its label list once at import time; every request reuses them.
model, labels = load_model()

# Per-image preprocessing before inference:
#   1. resize to a fixed 224x224,
#   2. convert the PIL image to a tensor,
#   3. normalize each RGB channel with the widely used ImageNet mean/std values
#      (presumably the statistics the model was trained with - confirm).
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict(image: Image.Image) -> Dict[str, float]:
    """
    Classify a skin image and return per-class probabilities.

    Args:
        image: The input image; it is converted to RGB before preprocessing.

    Returns:
        A dict mapping each class label to its softmax probability,
        ordered from most to least likely.

    Raises:
        gr.Error: If no image is given or inference fails. Gradio shows the
            message to the user. Returning an ``{"error": msg}`` dict instead
            would not render, because ``gr.Label`` needs numeric confidences.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    try:
        # unsqueeze(0) adds a batch dimension of size 1.
        image_tensor = preprocess(image.convert("RGB")).unsqueeze(0)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        with torch.no_grad():
            output = model(image_tensor.to(device))
        # Assumes output is (batch, num_classes); take the single item and
        # softmax over its class scores, then move to CPU as Python floats.
        scores = torch.nn.functional.softmax(output[0], dim=0).cpu().tolist()
    except Exception as e:
        # UI boundary: surface any preprocessing/inference failure to the user.
        raise gr.Error(f"Prediction failed: {e}") from e

    predictions = dict(zip(labels, scores))
    return dict(sorted(predictions.items(), key=lambda item: item[1], reverse=True))
# Page header text rendered as Markdown at the top of the UI.
title = "🔥 Skin-AI"

# Intro, usage steps, disclaimer, tech stack and source link. Paragraphs are
# separated by blank lines; the tech-stack list is a single paragraph.
description = "\n\n".join(
    [
        "🔬 **Skin-AI — AI-Powered Skin Disease Classification**",
        "This project utilizes a deep learning model to classify skin diseases based on an uploaded image.",
        "### 🚀 How to Use:",
        "1️⃣ Upload or take a photo of the affected skin area.",
        "2️⃣ Click the 'Submit' button.",
        "3️⃣ The app will return the probabilities for possible skin conditions.",
        "⚠️ **Important!** The results are for informational purposes only and do not constitute a medical diagnosis.",
        "### 🛠 Technologies Used:\n- PyTorch (Lightning)\n- Gradio\n- Hugging Face Spaces",
        "🔗 Source Code: [Hugging Face](https://huggingface.co/Eraly-ml/Skin-AI )",
    ]
)

# Sample inputs: one inner list per example, holding the value for the single
# image input (paths presumably relative to the working directory).
examples = [[filename] for filename in ("example1.jpg", "example2.jpg")]
def update_submit_state(image):
    """
    Build a Gradio update that enables the Submit button only when an image is set.

    Args:
        image: Current value of the image input (None when it is empty).

    Returns:
        A ``gr.update`` payload setting the button's ``interactive`` flag.
    """
    has_image = image is not None
    return gr.update(interactive=has_image)
# UI layout: header text, then an input / button / output row, then examples.
with gr.Blocks() as interface:
    gr.Markdown(value=title)
    gr.Markdown(value=description)

    with gr.Row():
        image_input = gr.Image(label="Upload Image", type="pil")
        # Starts disabled; update_submit_state toggles it when the image changes.
        submit_button = gr.Button(value="Submit", interactive=False)
        # Displays at most the three highest-probability classes.
        output_label = gr.Label(label="Prediction", num_top_classes=3)

    # Keep the Submit button's enabled state in sync with the image input.
    image_input.change(update_submit_state, image_input, submit_button)
    # Run the classifier on the current image when Submit is clicked.
    submit_button.click(predict, image_input, output_label)

    # Clickable sample images that fill the image input.
    gr.Examples(examples=examples, inputs=image_input)
# For Docker / Hugging Face Spaces: set the host and port explicitly.
# "0.0.0.0" binds on all network interfaces so the server is reachable from
# outside the container; the app serves on port 7860.
if __name__ == "__main__":
    interface.launch(server_port=7860, server_name="0.0.0.0")