import os
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torchvision import models
from ultralytics import YOLO
from transformers import AutoProcessor, AutoModelForImageTextToText
import joblib

# Device every torch model in this module is moved to / loaded onto.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {DEVICE}")


def _match_spatial(tensor: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    """Return *tensor* bilinearly resized to the spatial size of *reference*.

    Only the trailing (spatial) dimensions are compared, so a difference in
    channel count alone never triggers a resize. When the spatial sizes
    already agree the input is returned untouched.
    """
    target_size = reference.shape[2:]
    if tensor.shape[2:] == target_size:
        return tensor
    return F.interpolate(
        tensor, size=target_size, mode='bilinear', align_corners=False
    )


class AttentionGate(nn.Module):
    """Additive attention gate applied to a skip connection.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``f_int`` channels (1x1 conv + batch norm), summed, passed through ReLU,
    and reduced to a single-channel sigmoid map that multiplies ``x``.

    Args:
        f_g: Channel count of the gating signal ``g``.
        f_l: Channel count of the skip features ``x``.
        f_int: Channel count of the intermediate projection.
    """

    def __init__(self, f_g, f_l, f_int):
        super().__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(f_g, f_int, kernel_size=1),
            nn.BatchNorm2d(f_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(f_l, f_int, kernel_size=1),
            nn.BatchNorm2d(f_int),
        )
        self.psi = nn.Sequential(
            nn.Conv2d(f_int, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by the attention map computed from ``g`` and ``x``."""
        gate_proj = self.W_g(g)
        skip_proj = self.W_x(x)
        # The projected gate must line up spatially with the projected skip
        # before the element-wise sum.
        gate_proj = _match_spatial(gate_proj, skip_proj)
        attention = self.psi(self.relu(gate_proj + skip_proj))
        return x * attention


def conv_block(in_ch, out_ch):
    """Build two consecutive 3x3 conv -> batch norm -> ReLU stages.

    The first convolution maps ``in_ch`` to ``out_ch`` channels, the second
    keeps ``out_ch``. ``padding=1`` is used on both convolutions. The layer
    order inside the returned ``nn.Sequential`` determines the parameter
    names in a state dict, so it must not be rearranged.
    """
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


class UNetEfficientNetB5(nn.Module):
    """U-Net style decoder with attention gates on a timm EfficientNet encoder.

    NOTE(review): despite the class name, the encoder created here is
    ``'efficientnet_b4'``. The identifier is left as-is because existing
    checkpoints are loaded with ``strict=True`` and depend on this exact
    architecture -- confirm which variant was intended before changing it.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution.
        pretrained: Forwarded to ``timm.create_model`` for the encoder.
    """

    def __init__(self, num_classes=3, pretrained=False):
        super().__init__()
        self.encoder = timm.create_model(
            'efficientnet_b4',
            pretrained=pretrained,
            features_only=True,
            out_indices=(0, 1, 2, 3, 4),
        )
        enc_channels = self.encoder.feature_info.channels()

        # Bottleneck on the deepest encoder feature map.
        self.center = conv_block(enc_channels[4], 256)

        # One attention gate per skip connection, deepest first.
        self.ag4 = AttentionGate(f_g=256, f_l=enc_channels[3], f_int=128)
        self.ag3 = AttentionGate(f_g=enc_channels[3], f_l=enc_channels[2], f_int=64)
        self.ag2 = AttentionGate(f_g=enc_channels[2], f_l=enc_channels[1], f_int=32)
        self.ag1 = AttentionGate(f_g=enc_channels[1], f_l=enc_channels[0], f_int=16)

        # Decoder: transposed-conv upsampling followed by a conv block on the
        # concatenation of the upsampled tensor and the gated skip features.
        self.up4 = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)
        self.dec4 = conv_block(256 + enc_channels[3], enc_channels[3])
        self.up3 = nn.ConvTranspose2d(enc_channels[3], enc_channels[3], kernel_size=2, stride=2)
        self.dec3 = conv_block(enc_channels[3] + enc_channels[2], enc_channels[2])
        self.up2 = nn.ConvTranspose2d(enc_channels[2], enc_channels[2], kernel_size=2, stride=2)
        self.dec2 = conv_block(enc_channels[2] + enc_channels[1], enc_channels[1])
        self.up1 = nn.ConvTranspose2d(enc_channels[1], enc_channels[1], kernel_size=2, stride=2)
        self.dec1 = conv_block(enc_channels[1] + enc_channels[0], enc_channels[0])

        self.final_up = nn.ConvTranspose2d(enc_channels[0], 32, kernel_size=2, stride=2)
        self.out = nn.Conv2d(32, num_classes, kernel_size=1)

        # These two heads are not called in forward(); presumably auxiliary
        # deep-supervision outputs used at training time (TODO confirm).
        # They stay registered because removing them would change the set of
        # state-dict keys expected by a strict load.
        self.ds3 = nn.Conv2d(enc_channels[2], num_classes, kernel_size=1)
        self.ds2 = nn.Conv2d(enc_channels[1], num_classes, kernel_size=1)

    @staticmethod
    def _decode_stage(upsample, gate, decode, deeper, skip):
        """Run one decoder stage: upsample, gate the skip, align, concat, convolve."""
        upsampled = upsample(deeper)
        gated_skip = gate(g=upsampled, x=skip)
        # Compare spatial dims only: the two tensors typically differ in
        # channel count, so a full-shape comparison would request a resize on
        # every stage even when height and width already agree.
        upsampled = _match_spatial(upsampled, gated_skip)
        return decode(torch.cat([upsampled, gated_skip], dim=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``num_classes``-channel output map for input batch ``x``."""
        e0, e1, e2, e3, e4 = self.encoder(x)

        c = self.center(e4)
        d4 = self._decode_stage(self.up4, self.ag4, self.dec4, c, e3)
        d3 = self._decode_stage(self.up3, self.ag3, self.dec3, d4, e2)
        d2 = self._decode_stage(self.up2, self.ag2, self.dec2, d3, e1)
        d1 = self._decode_stage(self.up1, self.ag1, self.dec1, d2, e0)

        return self.out(self.final_up(d1))


class EfficientNetB3Classifier(nn.Module):
    """torchvision EfficientNet-B3 with a single-output classification head.

    The backbone is created without pretrained weights and its classifier is
    replaced by dropout (p=0.3) followed by a linear layer with one output.
    No activation is applied to that output here.
    """

    def __init__(self):
        super().__init__()
        self.backbone = models.efficientnet_b3(weights=None)
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(p=0.3, inplace=True),
            nn.Linear(in_features, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the backbone output for input batch ``x``."""
        return self.backbone(x)


def load_yolo(path: str):
    """Load an Ultralytics YOLO model from *path* and move it to ``DEVICE``."""
    model = YOLO(path)
    model.to(DEVICE)
    print("YOLO încărcat")
    return model


def load_unet(path: str):
    """Load U-Net weights from *path* and return the model in eval mode.

    The state dict is loaded with ``strict=True``, so any key mismatch raises.
    """
    model = UNetEfficientNetB5(num_classes=3, pretrained=False).to(DEVICE)
    state_dict = torch.load(path, map_location=DEVICE)
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    print("U-Net încărcat")
    return model


def load_efficientnet(path: str):
    """Load classifier weights from *path* and return the model in eval mode.

    Loading stays non-strict, but mismatched keys are no longer dropped
    silently: a ``UserWarning`` lists the missing and unexpected keys so a
    wrong or partially matching checkpoint is visible to the caller.
    """
    model = EfficientNetB3Classifier().to(DEVICE)
    state_dict = torch.load(path, map_location=DEVICE)
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        warnings.warn(
            "EfficientNet checkpoint did not match the model exactly: "
            f"missing keys={list(load_result.missing_keys)}, "
            f"unexpected keys={list(load_result.unexpected_keys)}",
            stacklevel=2,
        )
    model.eval()
    print("EfficientNet încărcat")
    return model


def load_medgemma(model_id: str = "google/medgemma-1.5-4b-it"):
    """Load the image-text-to-text model and its processor for *model_id*.

    Returns:
        A ``(model, processor)`` tuple; the model is loaded in bfloat16 with
        ``device_map="auto"`` and put in eval mode.
    """
    processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    model.eval()
    print("MedGemma încărcat")
    return model, processor


def load_fusion(pkl_path: str, thresh_path: str):
    """Load the fusion model and its decision threshold from joblib files.

    Returns:
        A ``(fusion_model, threshold)`` tuple.
    """
    # NOTE(review): joblib.load unpickles arbitrary objects; only use files
    # from a trusted source.
    fusion_model = joblib.load(pkl_path)
    threshold = joblib.load(thresh_path)
    print(f"Fusion model încărcat (threshold={threshold:.4f})")
    return fusion_model, threshold