# Captured from a Hugging Face Space page (Space status at capture time: Sleeping).
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| import numpy as np | |
| from torchvision import models, transforms | |
| from torchvision.datasets import ImageFolder | |
| from torch.utils.data import DataLoader | |
| from transformers import ViTForImageClassification | |
| from torch import nn | |
| from torch.cuda.amp import autocast | |
| import os | |
| from contextlib import nullcontext | |
# Global configuration: run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# Label mapping (HAM10K): class index -> Russian display name.
# The position of each name in the list is its class index (0..6).
label_mapping = dict(enumerate([
    "Меланоцитарный невус",
    "Меланома",
    "Базальноклеточная карцинома",
    "Актинический кератоз",
    "Доброкачественная кератома",
    "Дерматофиброма",
    "Сосудистые поражения",
]))
# Paths (overridable through environment variables) and fine-tuning hyperparameters.
CHECKPOINTS_PATH = os.environ.get("CHECKPOINTS_PATH", "./")
SUBMISSIONS_PATH = os.environ.get("SUBMISSIONS_PATH", "./submissions")

FT_BATCH = 32  # number of stored submissions that triggers a fine-tuning run
FT_EPOCHS = 1  # adjust as needed
LR = 1e-4  # learning rate passed to the fine-tuning optimizers

# Make sure both directories exist before anything reads from or writes to them.
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
os.makedirs(SUBMISSIONS_PATH, exist_ok=True)
| # Model definitions | |
def get_efficientnet():
    """Build an EfficientNetV2-S classifier with one output per entry of ``label_mapping``.

    The network is created with the ``IMAGENET1K_V1`` torchvision weights, its
    final linear layer is replaced by a freshly initialised one, and the model
    is moved to the global ``device``.

    Returns:
        The model, already placed on ``device``.
    """
    model = models.efficientnet_v2_s(weights="IMAGENET1K_V1")
    # Read the feature width from the layer being replaced rather than
    # hard-coding it, so the new head always matches the backbone.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, len(label_mapping))
    return model.to(device)
def get_deit():
    """Build the DeiT-base (patch 16, 224px) classifier sized for ``label_mapping``.

    The pretrained checkpoint is loaded through ``ViTForImageClassification``;
    ``ignore_mismatched_sizes`` is enabled because the requested number of
    labels differs from the checkpoint's own classification head.

    Returns:
        The model, already placed on the global ``device``.
    """
    head_options = {
        "num_labels": len(label_mapping),
        "ignore_mismatched_sizes": True,
    }
    deit = ViTForImageClassification.from_pretrained(
        'facebook/deit-base-patch16-224',
        **head_options,
    )
    return deit.to(device)
# Transforms: shared preprocessing used for both fine-tuning and inference.
# Images are resized to 224x224, converted to a tensor, and normalised per
# channel with the mean/std values below.
train_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def transform_image(image):
    """Preprocess one image into a single-item batch on the global ``device``.

    Args:
        image: The image to preprocess; the Gradio inputs in this file supply
            PIL images.

    Returns:
        The tensor produced by ``train_transform`` with a leading batch
        dimension added, moved to ``device``.
    """
    # train_transform normalises exactly three channels, so PIL images in other
    # modes (RGBA, greyscale, palette) are converted to RGB before preprocessing.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    batch = train_transform(image).unsqueeze(0)
    return batch.to(device)
| # Model Handler | |
class ModelHandler:
    """Own the EfficientNet and DeiT classifiers and run single-image inference.

    ``models_loaded`` is True only after both checkpoints under
    ``CHECKPOINTS_PATH`` have been restored by ``load_models``.
    """

    def __init__(self):
        self.efficientnet = None
        self.deit = None
        self.models_loaded = False
        self.load_models()

    def load_models(self):
        """(Re)build both networks and restore their weights from disk.

        Reads ``efficientnet_best.pth`` and ``deit_best.pth`` from
        ``CHECKPOINTS_PATH``. Any failure is printed and leaves
        ``models_loaded`` False instead of raising, so the caller can keep
        running without models.
        """
        try:
            self.efficientnet = get_efficientnet()
            eff_path = os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth")
            self.efficientnet.load_state_dict(torch.load(eff_path, map_location=device))
            self.efficientnet.eval()

            self.deit = get_deit()
            deit_path = os.path.join(CHECKPOINTS_PATH, "deit_best.pth")
            self.deit.load_state_dict(torch.load(deit_path, map_location=device))
            self.deit.eval()

            self.models_loaded = True
            print("✅ Models loaded successfully")
        except Exception as e:
            print(f"❌ Error loading models: {e}")
            self.models_loaded = False

    def predict(self, image, use='efficientnet'):
        """Classify one image and return the top class probabilities.

        Args:
            image: Image accepted by ``transform_image``.
            use: ``'efficientnet'``, ``'deit'``, or any other value to average
                the logits of both models (ensemble).

        Returns:
            ``{label name: probability}`` for the highest-scoring classes, or
            ``{"error": ...}`` when the models are not loaded.
        """
        if not self.models_loaded:
            return {"error": "Модели не загружены"}

        inputs = transform_image(image)
        amp_ctx = autocast() if device.type == 'cuda' else nullcontext()
        # Inference only: disable autograd so no graph is built or kept alive
        # for each request.
        with torch.no_grad(), amp_ctx:
            if use == 'efficientnet':
                logits = self.efficientnet(inputs)
            elif use == 'deit':
                logits = self.deit(pixel_values=inputs).logits
            else:
                logits = (self.efficientnet(inputs) + self.deit(pixel_values=inputs).logits) / 2
        # Under CUDA autocast the logits may be half precision; compute the
        # probabilities in float32.
        probs = torch.nn.functional.softmax(logits.float(), dim=1)
        return self._format_predictions(probs)

    def _format_predictions(self, probs):
        """Map the top (at most five) probabilities of the first row to label names."""
        # Never ask topk for more entries than there are classes.
        k = min(5, probs.shape[1])
        top_probs, top_inds = torch.topk(probs, k)
        return {
            label_mapping[class_idx]: float(prob)
            for class_idx, prob in zip(top_inds[0].tolist(), top_probs[0].tolist())
        }
# Shared handler instance, created once at import time; its constructor
# immediately attempts to load both model checkpoints.
model_handler = ModelHandler()
def predict_efficientnet(image):
    """Gradio callback: classify ``image`` with the EfficientNet model.

    Returns a prompt string when no image was supplied, otherwise whatever
    ``model_handler.predict`` returns.
    """
    if image is None:
        return "⚠️ Загрузите изображение"
    return model_handler.predict(image, 'efficientnet')
def predict_deit(image):
    """Gradio callback: classify ``image`` with the DeiT model.

    Returns a prompt string when no image was supplied, otherwise whatever
    ``model_handler.predict`` returns.
    """
    if image is None:
        return "⚠️ Загрузите изображение"
    return model_handler.predict(image, 'deit')
def predict_ensemble(image):
    """Gradio callback: classify ``image`` with the averaged two-model ensemble.

    Returns a prompt string when no image was supplied, otherwise whatever
    ``model_handler.predict`` returns.
    """
    if image is None:
        return "⚠️ Загрузите изображение"
    return model_handler.predict(image, 'ensemble')
| # Finetuning logic | |
def _finetune_and_save(model, forward, loader, checkpoint_name):
    """Restore ``checkpoint_name``, train on ``loader`` for FT_EPOCHS, and overwrite it.

    Args:
        model: Freshly built network, already on ``device``.
        forward: Callable ``(model, images) -> logits`` hiding the per-model
            call convention.
        loader: DataLoader yielding ``(images, labels)`` batches.
        checkpoint_name: File name of the checkpoint inside ``CHECKPOINTS_PATH``.
    """
    checkpoint_path = os.path.join(CHECKPOINTS_PATH, checkpoint_name)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()
    for _ in range(FT_EPOCHS):
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            optimizer.zero_grad()
            loss = criterion(forward(model, imgs), lbls)
            loss.backward()
            optimizer.step()
    torch.save(model.state_dict(), checkpoint_path)


def finetune_models():
    """Fine-tune both models on the stored submissions and reload the handler.

    Submissions are read from ``SUBMISSIONS_PATH`` with one sub-folder per
    label name. Both checkpoints are overwritten in place, after which
    ``model_handler`` reloads them.

    Raises:
        ValueError: If a sub-folder name is not one of the ``label_mapping`` names.
    """
    # Prepare dataset
    dataset = ImageFolder(SUBMISSIONS_PATH, transform=train_transform)

    # ImageFolder numbers classes by sorted folder name among the folders that
    # exist, which is unrelated to the indices in label_mapping. Remap every
    # target so training uses the same class indices the models predict.
    name_to_index = {name: idx for idx, name in label_mapping.items()}
    unknown = [cls for cls in dataset.classes if cls not in name_to_index]
    if unknown:
        raise ValueError(f"Unknown label folders in {SUBMISSIONS_PATH}: {unknown}")
    folder_to_label = {
        folder_idx: name_to_index[cls]
        for cls, folder_idx in dataset.class_to_idx.items()
    }
    dataset.target_transform = folder_to_label.__getitem__

    loader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Finetune EfficientNet (the model is released when the helper returns,
    # before the second network is built).
    _finetune_and_save(
        get_efficientnet(),
        lambda net, imgs: net(imgs),
        loader,
        "efficientnet_best.pth",
    )
    # Finetune DeiT
    _finetune_and_save(
        get_deit(),
        lambda net, imgs: net(pixel_values=imgs).logits,
        loader,
        "deit_best.pth",
    )

    # Reload into handler
    model_handler.load_models()
    print("🔄 Models fine-tuned and reloaded")
def handle_submission(image, label):
    """Store a labelled image and fine-tune once ``FT_BATCH`` submissions exist.

    Args:
        image: Object with a ``save(path)`` method (a PIL image from Gradio), or None.
        label: Label name chosen in the dropdown, or None.

    Returns:
        A status string: either a prompt to supply both inputs, or the number
        of submissions still missing before the next fine-tuning run.
    """
    if image is None or label is None:
        return "⚠️ Загрузите изображение и выберите метку"

    # Save image under label folder
    lbl_dir = os.path.join(SUBMISSIONS_PATH, str(label))
    os.makedirs(lbl_dir, exist_ok=True)
    idx = len([f for f in os.listdir(lbl_dir) if f.endswith(('.png', '.jpg'))]) + 1
    path = os.path.join(lbl_dir, f"{label}_{idx}.png")
    image.save(path)

    # Count total submissions; rem is 0 exactly when a batch has just been completed.
    total = sum(len(files) for _, _, files in os.walk(SUBMISSIONS_PATH))
    rem = (FT_BATCH - total % FT_BATCH) % FT_BATCH

    # Trigger finetune if batch complete
    if rem == 0:
        finetune_models()
        # Clear submissions. The emptied label folders are removed as well:
        # ImageFolder rejects class folders that contain no images, so leaving
        # them behind would break the next fine-tuning run.
        for root, _, files in os.walk(SUBMISSIONS_PATH, topdown=False):
            for f in files:
                os.remove(os.path.join(root, f))
            if root != SUBMISSIONS_PATH:
                try:
                    os.rmdir(root)
                except OSError:
                    # Folder is not empty (e.g. a new file arrived meanwhile); keep it.
                    pass

    return f"Осталось {rem} изображений до следующей тонкой настройки"
| # Create Gradio interface | |
def create_interface():
    """Assemble the Gradio Blocks UI and return it without launching.

    The UI has one prediction tab per model variant (EfficientNet, DeiT,
    ensemble) and one tab for submitting labelled images for fine-tuning.
    The model-status line reflects ``model_handler.models_loaded`` at the
    moment this function runs.
    """
    # (tab title, click callback) for the three structurally identical tabs.
    prediction_tabs = (
        ("EfficientNet", predict_efficientnet),
        ("DeiT", predict_deit),
        ("Ансамблевая модель", predict_ensemble),
    )

    if model_handler.models_loaded:
        status = "✅ Модели готовы к предсказанию"
    else:
        status = "⚠️ Предупреждение: Модели не загружены"

    with gr.Blocks() as demo:
        gr.Markdown("# Диагностика кожных поражений (HAM10K)")
        gr.Markdown(f"**Состояние моделей:** {status}")

        with gr.Tabs():
            for tab_title, callback in prediction_tabs:
                with gr.TabItem(tab_title):
                    image_input = gr.Image(type="pil", label="Загрузите изображение")
                    result_view = gr.Label(label="Результаты")
                    predict_button = gr.Button("Предсказать")
                    predict_button.click(callback, inputs=image_input, outputs=result_view)

            with gr.TabItem("Submit for Finetuning"):
                sub_img = gr.Image(type="pil", label="Изображение для тонкой настройки")
                sub_lbl = gr.Dropdown(
                    choices=list(label_mapping.values()),
                    label="Выберите метку",
                )
                sub_btn = gr.Button("Отправить")
                sub_out = gr.Textbox(label="Статус")
                sub_btn.click(
                    handle_submission,
                    inputs=[sub_img, sub_lbl],
                    outputs=sub_out,
                )

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, announce start-up, then serve on port 7860.
    interface = create_interface()
    print("🚀 Запуск интерфейса...")
    interface.launch(server_port=7860)