Spaces:
Running on Zero
Running on Zero
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import timm | |
| from torchvision import models | |
| from ultralytics import YOLO | |
| from transformers import AutoProcessor, AutoModelForImageTextToText | |
| import joblib | |
| import os | |
# Compute device shared by every loader in this module: the GPU when CUDA is
# usable, otherwise the CPU. The choice is printed once at import time.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Device: {DEVICE}")
class AttentionGate(nn.Module):
    """Additive attention gate that re-weights a skip-connection feature map.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``f_int`` channels with a 1x1 convolution followed by batch norm, summed,
    passed through ReLU, and reduced to a single-channel sigmoid mask. The
    mask is multiplied onto ``x``.

    Args:
        f_g: Number of channels of the gating signal ``g``.
        f_l: Number of channels of the skip features ``x``.
        f_int: Number of channels of the intermediate projections.
    """

    @staticmethod
    def _pointwise(in_ch, out_ch, *tail):
        """Return a 1x1 conv + batch norm stack, followed by ``tail`` modules."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=1),
            nn.BatchNorm2d(out_ch),
            *tail,
        )

    def __init__(self, f_g, f_l, f_int):
        super().__init__()
        # Attribute names and creation order are kept stable so that
        # state_dict keys and parameter initialisation order do not change.
        self.W_g = self._pointwise(f_g, f_int)
        self.W_x = self._pointwise(f_l, f_int)
        self.psi = self._pointwise(f_int, 1, nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled by the attention mask computed from ``g`` and ``x``.

        Args:
            g: Gating tensor with ``f_g`` channels.
            x: Skip tensor with ``f_l`` channels.

        Returns:
            A tensor with the same shape as ``x``.
        """
        gate_proj = self.W_g(g)
        skip_proj = self.W_x(x)

        # Resample the gate projection onto the skip projection's spatial grid
        # whenever the two projected tensors do not line up.
        if gate_proj.shape != skip_proj.shape:
            gate_proj = F.interpolate(
                gate_proj,
                size=skip_proj.shape[2:],
                mode='bilinear',
                align_corners=False,
            )

        mask = self.psi(self.relu(gate_proj + skip_proj))
        return x * mask
def conv_block(in_ch, out_ch):
    """Build a double 3x3 convolution block.

    The block is two consecutive ``Conv2d(3x3, padding=1) -> BatchNorm2d ->
    ReLU`` stages; the first maps ``in_ch`` to ``out_ch`` channels and the
    second keeps ``out_ch`` channels.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels of both convolutions.

    Returns:
        An ``nn.Sequential`` with six layers.
    """
    stages = []
    for stage_in in (in_ch, out_ch):
        stages.append(nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1))
        stages.append(nn.BatchNorm2d(out_ch))
        stages.append(nn.ReLU(inplace=True))
    return nn.Sequential(*stages)
class UNetEfficientNetB5(nn.Module):
    """U-Net style segmentation network with a timm EfficientNet encoder.

    Five encoder feature maps feed a decoder made of four identical stages
    (transposed-conv upsample, attention-gated skip connection, double conv),
    followed by one more transposed convolution and a 1x1 classifier.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution.
        pretrained: Forwarded to ``timm.create_model`` for the encoder.
    """

    def __init__(self, num_classes=3, pretrained=False):
        super().__init__()
        # NOTE(review): the class is named "...B5" but the backbone created
        # here is 'efficientnet_b4'. It is left unchanged because existing
        # checkpoints are loaded against this exact architecture - confirm
        # which of the two names is the intended one.
        self.encoder = timm.create_model(
            'efficientnet_b4',
            pretrained=pretrained,
            features_only=True,
            out_indices=(0, 1, 2, 3, 4)
        )
        enc_channels = self.encoder.feature_info.channels()

        # Bottleneck on the deepest encoder feature map.
        self.center = conv_block(enc_channels[4], 256)

        # One attention gate per skip connection (deepest first).
        self.ag4 = AttentionGate(f_g=256, f_l=enc_channels[3], f_int=128)
        self.ag3 = AttentionGate(f_g=enc_channels[3], f_l=enc_channels[2], f_int=64)
        self.ag2 = AttentionGate(f_g=enc_channels[2], f_l=enc_channels[1], f_int=32)
        self.ag1 = AttentionGate(f_g=enc_channels[1], f_l=enc_channels[0], f_int=16)

        # Decoder stages: 2x transposed-conv upsample, then a double conv over
        # the concatenation of the upsampled map and the gated skip features.
        self.up4 = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)
        self.dec4 = conv_block(256 + enc_channels[3], enc_channels[3])
        self.up3 = nn.ConvTranspose2d(enc_channels[3], enc_channels[3], kernel_size=2, stride=2)
        self.dec3 = conv_block(enc_channels[3] + enc_channels[2], enc_channels[2])
        self.up2 = nn.ConvTranspose2d(enc_channels[2], enc_channels[2], kernel_size=2, stride=2)
        self.dec2 = conv_block(enc_channels[2] + enc_channels[1], enc_channels[1])
        self.up1 = nn.ConvTranspose2d(enc_channels[1], enc_channels[1], kernel_size=2, stride=2)
        self.dec1 = conv_block(enc_channels[1] + enc_channels[0], enc_channels[0])

        self.final_up = nn.ConvTranspose2d(enc_channels[0], 32, kernel_size=2, stride=2)
        self.out = nn.Conv2d(32, num_classes, kernel_size=1)

        # These two heads are registered but never called in forward().
        # NOTE(review): presumably deep-supervision heads kept so that
        # checkpoints containing their weights still load - confirm.
        self.ds3 = nn.Conv2d(enc_channels[2], num_classes, kernel_size=1)
        self.ds2 = nn.Conv2d(enc_channels[1], num_classes, kernel_size=1)

    @staticmethod
    def _match_spatial(tensor, reference):
        """Resize ``tensor`` bilinearly to ``reference``'s height and width if they differ."""
        # Only the spatial dims are compared: the two tensors carry different
        # channel counts, so comparing full shapes would always report a
        # mismatch and trigger a redundant same-size interpolation.
        if tensor.shape[2:] != reference.shape[2:]:
            tensor = F.interpolate(
                tensor, size=reference.shape[2:], mode='bilinear', align_corners=False
            )
        return tensor

    def _decode_stage(self, up, gate, dec, deeper, skip):
        """Run one decoder stage: upsample, gate the skip, align, concat, convolve."""
        upsampled = up(deeper)
        skip_att = gate(g=upsampled, x=skip)
        upsampled = self._match_spatial(upsampled, skip_att)
        return dec(torch.cat([upsampled, skip_att], dim=1))

    def forward(self, x):
        """Return the ``num_classes``-channel output map for input batch ``x``."""
        e0, e1, e2, e3, e4 = self.encoder(x)

        d = self.center(e4)
        d = self._decode_stage(self.up4, self.ag4, self.dec4, d, e3)
        d = self._decode_stage(self.up3, self.ag3, self.dec3, d, e2)
        d = self._decode_stage(self.up2, self.ag2, self.dec2, d, e1)
        d = self._decode_stage(self.up1, self.ag1, self.dec1, d, e0)

        return self.out(self.final_up(d))
class EfficientNetB3Classifier(nn.Module):
    """torchvision EfficientNet-B3 whose classifier head emits a single value.

    The backbone is created without pretrained weights and its stock
    classifier is replaced by ``Dropout(p=0.3) -> Linear(in_features, 1)``.
    NOTE(review): the single output is presumably a binary logit - confirm
    against the training code.
    """

    def __init__(self):
        super().__init__()
        net = models.efficientnet_b3(weights=None)
        # Width of the features feeding the original Linear layer.
        head_width = net.classifier[1].in_features
        net.classifier = nn.Sequential(
            nn.Dropout(p=0.3, inplace=True),
            nn.Linear(head_width, 1),
        )
        self.backbone = net

    def forward(self, x):
        """Run ``x`` through the backbone and return its single-unit output."""
        return self.backbone(x)
def load_yolo(path: str):
    """Load an Ultralytics YOLO model from ``path`` and move it to ``DEVICE``.

    Args:
        path: Value passed to the ``YOLO`` constructor.

    Returns:
        The ``YOLO`` instance.
    """
    detector = YOLO(path)
    detector.to(DEVICE)
    print("YOLO încărcat")
    return detector
def load_unet(path: str):
    """Build the 3-class U-Net, load its weights from ``path``, and set eval mode.

    The checkpoint must match the architecture exactly (``strict=True``), so
    any missing or unexpected key raises.

    Args:
        path: Checkpoint file readable by ``torch.load``.

    Returns:
        The model on ``DEVICE`` in evaluation mode.
    """
    net = UNetEfficientNetB5(num_classes=3, pretrained=False)
    net = net.to(DEVICE)
    # NOTE: torch.load may unpickle arbitrary objects; only load trusted files.
    weights = torch.load(path, map_location=DEVICE)
    net.load_state_dict(weights, strict=True)
    net.eval()
    print("U-Net încărcat")
    return net
def load_efficientnet(path: str):
    """Build the EfficientNet-B3 classifier, load weights from ``path``, set eval mode.

    Loading is non-strict, so a checkpoint whose keys do not fully match the
    model is still accepted. Any missing or unexpected keys are printed,
    because parameters absent from the checkpoint keep their random
    initialisation and would otherwise go unnoticed.

    Args:
        path: Checkpoint file readable by ``torch.load``.

    Returns:
        The model on ``DEVICE`` in evaluation mode.
    """
    model = EfficientNetB3Classifier().to(DEVICE)
    # NOTE: torch.load may unpickle arbitrary objects; only load trusted files.
    state_dict = torch.load(path, map_location=DEVICE)

    # load_state_dict returns the key mismatches instead of raising when
    # strict=False; report them rather than dropping them silently.
    mismatch = model.load_state_dict(state_dict, strict=False)
    if mismatch.missing_keys:
        print(f"EfficientNet: missing keys in checkpoint: {list(mismatch.missing_keys)}")
    if mismatch.unexpected_keys:
        print(f"EfficientNet: unexpected keys in checkpoint: {list(mismatch.unexpected_keys)}")

    model.eval()
    print("EfficientNet încărcat")
    return model
def load_medgemma(model_id: str = "google/medgemma-1.5-4b-it"):
    """Load an image-text-to-text model and its processor from ``model_id``.

    The processor is loaded with ``use_fast=False``. The model is loaded in
    bfloat16 with ``device_map="auto"`` and switched to evaluation mode.

    Args:
        model_id: Identifier passed to both ``from_pretrained`` calls.

    Returns:
        A ``(model, processor)`` tuple.
    """
    processor = AutoProcessor.from_pretrained(model_id, use_fast=False)

    model_kwargs = {
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
    }
    vlm = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
    vlm.eval()

    print("MedGemma încărcat")
    return vlm, processor
def load_fusion(pkl_path: str, thresh_path: str):
    """Load the fusion model and its decision threshold from joblib files.

    Args:
        pkl_path: joblib file holding the fusion model.
        thresh_path: joblib file holding the threshold; the value must
            support the ``.4f`` format used in the log line.

    Returns:
        A ``(fusion_model, threshold)`` tuple.
    """
    # NOTE: joblib.load unpickles its input; only load trusted files.
    fusion_model, threshold = (
        joblib.load(source) for source in (pkl_path, thresh_path)
    )
    print(f"Fusion model încărcat (threshold={threshold:.4f})")
    return fusion_model, threshold