File size: 3,797 Bytes
3609855 745d07d 69e2119 745d07d 72c0616 5abc557 71667c6 5abc557 71667c6 5abc557 1180bf9 62716fe 5abc557 62716fe 4b90365 62716fe 5abc557 71667c6 72c0616 5abc557 62716fe 71667c6 5abc557 62716fe 745d07d 62716fe 745d07d 71667c6 745d07d 5abc557 745d07d 69e2119 5abc557 71667c6 5abc557 71667c6 5abc557 71667c6 5abc557 72c0616 5abc557 71667c6 5abc557 72c0616 5abc557 71667c6 5abc557 72c0616 bffd9d1 5abc557 71667c6 5abc557 cf0a5d8 5abc557 4552899 71667c6 5abc557 4b90365 5abc557 72c0616 cf0a5d8 5abc557 7b665a0 5abc557 cf0a5d8 72c0616 cf0a5d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import os
from typing import Tuple, List, Dict
def load_model(
    model_path: str = "skinconvnext_scripted.pt",
    labels_path: str = "labels.txt",
) -> Tuple[torch.nn.Module, List[str]]:
    """
    Load the TorchScript model and its class labels.

    Args:
        model_path: Path to the TorchScript (``torch.jit``) model file.
        labels_path: Path to a UTF-8 text file with one class label per line.
            Presumably the order matches the model's output logits — confirm
            against the training pipeline.

    Returns:
        model: The loaded model, mapped to CUDA when available (otherwise
            CPU) and switched to eval mode.
        labels: The class labels, one per line of ``labels_path``, with
            surrounding whitespace stripped and trailing blank lines dropped.

    Raises:
        FileNotFoundError: If either file does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found: {model_path}")
    if not os.path.exists(labels_path):
        raise FileNotFoundError(f"File {labels_path} not found.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location remaps the saved storages onto `device` at load time.
    model = torch.jit.load(model_path, map_location=device)
    model.eval()

    # Read as UTF-8 explicitly so non-ASCII labels do not depend on the
    # platform's default encoding.
    with open(labels_path, "r", encoding="utf-8") as f:
        labels = [line.strip() for line in f]
    # Drop trailing blank lines (e.g. extra newlines at end of file). Interior
    # entries are kept as-is because their positions must line up with the
    # model's output indices.
    while labels and not labels[-1]:
        labels.pop()
    return model, labels
# Load the model and its labels once, at import time, for reuse by every request.
model, labels = load_model()

# Image preprocessing pipeline applied to every uploaded picture:
#   1. resize to a fixed 224x224 input,
#   2. convert the PIL image to a float tensor,
#   3. normalize with the commonly used ImageNet per-channel mean/std
#      (presumably what the model was trained with — confirm).
preprocess = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict(image: Image.Image) -> Dict[str, float]:
    """
    Classify a skin image and return per-class probabilities.

    Args:
        image: The input image from the Gradio ``Image`` component
            (``type="pil"``). May be ``None`` when the input was cleared.

    Returns:
        A dict mapping each class label to its softmax probability, ordered
        from most to least likely.

    Raises:
        gr.Error: If no image was supplied or inference fails. Gradio shows
            the message to the user instead of trying to render it as a label.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    try:
        image_tensor = preprocess(image.convert("RGB")).unsqueeze(0)
        # Same device choice as load_model(), which already mapped the model
        # there via map_location. The model is therefore not moved again on
        # every request.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        with torch.no_grad():
            output = model(image_tensor.to(device))
        # output[0] takes the single batch item. This assumes the model
        # returns (batch, num_classes) logits — TODO confirm.
        scores = torch.nn.functional.softmax(output[0], dim=0).tolist()
    except Exception as e:
        # Surface the failure in the UI. Returning {"error": "..."} would
        # break gr.Label, which expects float confidences.
        raise gr.Error(f"Prediction failed: {e}") from e

    # zip() pairs labels and scores by position and stops at the shorter one.
    predictions = dict(zip(labels, scores))
    return dict(sorted(predictions.items(), key=lambda item: item[1], reverse=True))
# Heading rendered at the top of the page.
title = "🔥 Skin-AI"

# Markdown shown under the heading. Sections are separated by blank lines, so
# they are listed individually and joined with "\n\n".
description = "\n\n".join(
    [
        "🔬 **Skin-AI — AI-Powered Skin Disease Classification**",
        "This project utilizes a deep learning model to classify skin diseases based on an uploaded image.",
        "### 🚀 How to Use:",
        "1️⃣ Upload or take a photo of the affected skin area.",
        "2️⃣ Click the 'Submit' button.",
        "3️⃣ The app will return the probabilities for possible skin conditions.",
        "⚠️ **Important!** The results are for informational purposes only and do not constitute a medical diagnosis.",
        # The tech-stack list uses single newlines between its bullet items.
        "### 🛠 Technologies Used:\n"
        "- PyTorch (Lightning)\n"
        "- Gradio\n"
        "- Hugging Face Spaces",
        "🔗 Source Code: [Hugging Face](https://huggingface.co/Eraly-ml/Skin-AI )",
    ]
)
# Sample images offered via gr.Examples. Each inner list holds the value for
# the example's single input component (the image).
examples = [[path] for path in ("example1.jpg", "example2.jpg")]
def update_submit_state(image):
    """Return a Gradio update that enables the button only when an image is set."""
    has_image = image is not None
    return gr.update(interactive=has_image)
with gr.Blocks() as interface:
    # Header: the app title, then the usage/disclaimer markdown.
    gr.Markdown(value=title)
    gr.Markdown(value=description)

    # Input image, submit button and prediction output laid out in one row.
    with gr.Row():
        image_input = gr.Image(label="Upload Image", type="pil")
        # Starts disabled; the change handler below enables it once an image exists.
        submit_button = gr.Button(value="Submit", interactive=False)
        output_label = gr.Label(label="Prediction", num_top_classes=3)

    # Re-evaluate the button state on every image change; classify on click.
    image_input.change(
        fn=update_submit_state,
        inputs=image_input,
        outputs=submit_button,
    )
    submit_button.click(
        fn=predict,
        inputs=image_input,
        outputs=output_label,
    )

    # Clickable sample images that populate the image input.
    gr.Examples(examples=examples, inputs=image_input)
# For Docker / HF Spaces: bind to all interfaces on an explicit port.
# GRADIO_SERVER_NAME / GRADIO_SERVER_PORT override these defaults when set.
if __name__ == "__main__":
    interface.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    )