Skin-AI / app.py
Eraly-ml's picture
Update app.py
003c4e4 verified
raw
history blame
2.79 kB
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
import os
# Загрузка модели и меток классов
def load_model():
model_path = "skin_disease_model_jit.pt"
labels_path = "labels.txt"
if not os.path.exists(model_path):
raise FileNotFoundError(f"Модель не найдена: {model_path}")
if not os.path.exists(labels_path):
raise FileNotFoundError("Файл labels.txt не найден.")
model = torch.jit.load(model_path, map_location=torch.device('cpu'))
model.eval()
with open(labels_path, "r") as f:
labels = [line.strip() for line in f.readlines()]
return model, labels
model, labels = load_model()
preprocess = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Функция предсказания
def predict(image):
image = image.convert("RGB")
image_tensor = preprocess(image).unsqueeze(0)
with torch.no_grad():
output = model(image_tensor)
scores = torch.nn.functional.softmax(output[0], dim=0)
return {label: float(score) for label, score in zip(labels, scores)}
# Создание интерфейса
title = "🩺 Классификация кожных заболеваний"
description = (
"Загрузите изображение кожи, чтобы получить предсказание.\n\n"
"⚠️ **Важно!** Данное приложение использует искусственный интеллект для анализа изображений, "
"но оно **не является медицинским инструментом**. Результаты предсказания могут быть неточными. "
"Для точной диагностики обратитесь к врачу-специалисту."
)
css = """
h1 { text-align: center; color: white; }
body { background-color: #131722; color: white; font-family: Arial, sans-serif; }
.gradio-container { max-width: 800px; margin: auto; padding-top: 20px; }
"""
with gr.Blocks(css=css) as demo:
gr.Markdown(f"# {title}")
gr.Markdown(description)
with gr.Row():
with gr.Column():
image_input = gr.Image(type="pil", label="📷 Изображение", interactive=True)
predict_button = gr.Button("🔍 Анализировать", variant="primary")
with gr.Column():
result_label = gr.Label(num_top_classes=3, label="📊 Предсказания")
predict_button.click(fn=predict, inputs=image_input, outputs=result_label)
# Запуск
if __name__ == "__main__":
demo.launch()