| import gradio as gr |
| import torch |
| from torchvision import transforms |
| from PIL import Image |
| import os |
|
|
| |
| def load_model(): |
| model_path = "skin_disease_model_jit.pt" |
| labels_path = "labels.txt" |
|
|
| if not os.path.exists(model_path): |
| raise FileNotFoundError(f"Модель не найдена: {model_path}") |
| if not os.path.exists(labels_path): |
| raise FileNotFoundError("Файл labels.txt не найден.") |
|
|
| model = torch.jit.load(model_path, map_location=torch.device('cpu')) |
| model.eval() |
|
|
| with open(labels_path, "r") as f: |
| labels = [line.strip() for line in f.readlines()] |
|
|
| return model, labels |
|
|
| model, labels = load_model() |
|
|
| preprocess = transforms.Compose([ |
| transforms.Resize((224, 224)), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), |
| ]) |
|
|
| |
| def predict(image): |
| image = image.convert("RGB") |
| image_tensor = preprocess(image).unsqueeze(0) |
|
|
| with torch.no_grad(): |
| output = model(image_tensor) |
| scores = torch.nn.functional.softmax(output[0], dim=0) |
|
|
| return {label: float(score) for label, score in zip(labels, scores)} |
|
|
| |
| title = "🩺 Классификация кожных заболеваний" |
| description = ( |
| "Загрузите изображение кожи, чтобы получить предсказание.\n\n" |
| "⚠️ **Важно!** Данное приложение использует искусственный интеллект для анализа изображений, " |
| "но оно **не является медицинским инструментом**. Результаты предсказания могут быть неточными. " |
| "Для точной диагностики обратитесь к врачу-специалисту." |
| ) |
|
|
| css = """ |
| h1 { text-align: center; color: white; } |
| body { background-color: #131722; color: white; font-family: Arial, sans-serif; } |
| .gradio-container { max-width: 800px; margin: auto; padding-top: 20px; } |
| """ |
|
|
| with gr.Blocks(css=css) as demo: |
| gr.Markdown(f"# {title}") |
| gr.Markdown(description) |
|
|
| with gr.Row(): |
| with gr.Column(): |
| image_input = gr.Image(type="pil", label="📷 Изображение", interactive=True) |
| predict_button = gr.Button("🔍 Анализировать", variant="primary") |
| with gr.Column(): |
| result_label = gr.Label(num_top_classes=3, label="📊 Предсказания") |
|
|
| predict_button.click(fn=predict, inputs=image_input, outputs=result_label) |
|
|
| |
| if __name__ == "__main__": |
| demo.launch() |
|
|