File size: 3,797 Bytes
3609855
745d07d
 
 
 
69e2119
745d07d
72c0616
5abc557
71667c6
5abc557
71667c6
 
5abc557
1180bf9
62716fe
5abc557
62716fe
4b90365
62716fe
5abc557
 
71667c6
72c0616
5abc557
 
62716fe
71667c6
5abc557
62716fe
745d07d
62716fe
745d07d
71667c6
745d07d
 
 
5abc557
 
745d07d
 
69e2119
5abc557
71667c6
5abc557
71667c6
5abc557
 
71667c6
5abc557
72c0616
5abc557
71667c6
 
 
 
5abc557
72c0616
5abc557
 
 
 
 
71667c6
5abc557
72c0616
 
bffd9d1
5abc557
 
 
71667c6
 
 
 
 
 
 
 
5abc557
 
cf0a5d8
5abc557
4552899
71667c6
5abc557
 
 
 
4b90365
5abc557
 
72c0616
cf0a5d8
5abc557
 
 
 
 
 
 
 
 
 
7b665a0
5abc557
 
cf0a5d8
 
72c0616
cf0a5d8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import os
from typing import Tuple, List, Dict

def load_model() -> Tuple[torch.nn.Module, List[str]]:
    """
    Load the TorchScript model and the class labels from disk.

    Returns:
        model: The scripted PyTorch model, placed on the best available
            device and set to eval mode.
        labels: List of class label names, one per model output index.

    Raises:
        FileNotFoundError: If the model file or the labels file is missing.
    """
    model_path = "skinconvnext_scripted.pt"
    labels_path = "labels.txt"

    # Fail fast with consistent, path-bearing messages for both files
    # (previously the labels message hardcoded the filename in a different style).
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model not found: {model_path}")
    if not os.path.exists(labels_path):
        raise FileNotFoundError(f"Labels file not found: {labels_path}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # map_location places the scripted model directly on the chosen device.
    model = torch.jit.load(model_path, map_location=device)
    model.eval()

    # Iterate the file directly (no readlines) and skip blank lines so a
    # trailing newline does not produce an empty class label.
    with open(labels_path, "r", encoding="utf-8") as f:
        labels = [line.strip() for line in f if line.strip()]

    return model, labels

# Load the model and labels once at import time so every request reuses them.
model, labels = load_model()

# Define image preprocessing steps: 224x224 resize, tensor conversion, and
# normalization with the standard ImageNet mean/std constants — presumably
# matching the model's training-time pipeline (confirm against training code).
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

def predict(image: Image.Image) -> Dict[str, float]:
    """
    Classify a skin image and return per-class probabilities.

    Args:
        image (PIL.Image): The input image.

    Returns:
        Dict[str, float]: Class names mapped to probabilities, sorted in
        descending order of probability.

    Raises:
        gr.Error: If preprocessing or inference fails; Gradio renders this
        as an error message in the UI.
    """
    try:
        # Input may be RGBA or grayscale; normalize to 3-channel RGB.
        image = image.convert("RGB")
        image_tensor = preprocess(image).unsqueeze(0)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        image_tensor = image_tensor.to(device)
        model.to(device)

        with torch.no_grad():
            output = model(image_tensor)
            scores = torch.nn.functional.softmax(output[0], dim=0)

        predictions = {label: float(score) for label, score in zip(labels, scores)}
        # Pre-sort so the returned mapping lists the most likely classes first.
        return dict(sorted(predictions.items(), key=lambda item: item[1], reverse=True))
    except Exception as e:
        # Previously this returned {"error": str(e)}, but gr.Label expects
        # float confidences and cannot render a string value. Raising
        # gr.Error surfaces the failure properly in the Gradio UI.
        raise gr.Error(f"Prediction failed: {e}") from e

title = "🔥 Skin-AI"
# Markdown shown above the interface. Note: the source-code link previously
# contained trailing spaces inside the URL ("...Skin-AI  )"), which corrupts
# the rendered markdown link; the URL is now clean.
description = (
    "🔬 **Skin-AI — AI-Powered Skin Disease Classification**\n\n"
    "This project utilizes a deep learning model to classify skin diseases based on an uploaded image.\n\n"
    "### 🚀 How to Use:\n\n"
    "1️⃣ Upload or take a photo of the affected skin area.\n\n"
    "2️⃣ Click the 'Submit' button.\n\n"
    "3️⃣ The app will return the probabilities for possible skin conditions.\n\n"
    "⚠️ **Important!** The results are for informational purposes only and do not constitute a medical diagnosis.\n\n"
    "### 🛠 Technologies Used:\n"
    "- PyTorch (Lightning)\n"
    "- Gradio\n"
    "- Hugging Face Spaces\n\n"
    "🔗 Source Code: [Hugging Face](https://huggingface.co/Eraly-ml/Skin-AI)"
)

# Example images bundled with the app (displayed below the input widget).
examples = [
    ["example1.jpg"],
    ["example2.jpg"],
]

def update_submit_state(image):
    """Enable the Submit button only while an image is present in the input."""
    has_image = image is not None
    return gr.update(interactive=has_image)

# Build the Gradio UI: markdown header, image input, gated submit button,
# and a top-3 prediction label.
with gr.Blocks() as interface:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")

    # Submit starts disabled; it is enabled only once an image is provided.
    submit_button = gr.Button("Submit", interactive=False)
    output_label = gr.Label(num_top_classes=3, label="Prediction")

    # Toggle the button on every image change (covers both upload and clear).
    image_input.change(fn=update_submit_state, inputs=image_input, outputs=submit_button)
    submit_button.click(fn=predict, inputs=image_input, outputs=output_label)

    gr.Examples(examples, inputs=image_input)

# For Docker/HF Spaces — bind the host and port explicitly.
if __name__ == "__main__":
    interface.launch(server_name="0.0.0.0", server_port=7860)