# Source: glaunet-screening / models.py
# Repository page header: "rallou's picture"
# Commit message: "Initial deploy GlauNet"
# Commit: 6fcadef
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torchvision import models
from ultralytics import YOLO
from transformers import AutoProcessor, AutoModelForImageTextToText
import joblib
import os
# Use the CUDA device when one is available; otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Report the selected device once, at import time.
print(f"Device: {DEVICE}")
class AttentionGate(nn.Module):
    """Additive attention gate that rescales a skip-connection feature map.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``f_int`` channels (1x1 convolution followed by batch norm), summed,
    passed through ReLU, and reduced to a one-channel sigmoid map that
    multiplies ``x``.

    Args:
        f_g: Number of channels of the gating signal ``g``.
        f_l: Number of channels of the skip features ``x``.
        f_int: Number of channels of the intermediate projections.
    """

    def __init__(self, f_g, f_l, f_int):
        super().__init__()
        # Attribute names and sub-module order determine the state_dict keys,
        # so they are kept exactly as checkpoints expect them.
        gate_projection = nn.Sequential(
            nn.Conv2d(f_g, f_int, kernel_size=1),
            nn.BatchNorm2d(f_int),
        )
        skip_projection = nn.Sequential(
            nn.Conv2d(f_l, f_int, kernel_size=1),
            nn.BatchNorm2d(f_int),
        )
        attention_head = nn.Sequential(
            nn.Conv2d(f_int, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.W_g = gate_projection
        self.W_x = skip_projection
        self.psi = attention_head
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` multiplied by the attention map computed from ``g`` and ``x``."""
        gate_feat = self.W_g(g)
        skip_feat = self.W_x(x)
        # Bring the projected gate to the skip's spatial size before summing.
        if gate_feat.shape != skip_feat.shape:
            gate_feat = F.interpolate(
                gate_feat,
                size=skip_feat.shape[2:],
                mode='bilinear',
                align_corners=False,
            )
        attention = self.psi(self.relu(gate_feat + skip_feat))
        return x * attention
def conv_block(in_ch, out_ch):
    """Build two consecutive 3x3 Conv -> BatchNorm -> ReLU stages.

    Args:
        in_ch: Input channels of the first convolution.
        out_ch: Output channels of both convolutions.

    Returns:
        An ``nn.Sequential`` of six layers; ``padding=1`` keeps the spatial
        size unchanged.
    """
    layers = []
    # The first stage maps in_ch -> out_ch, the second out_ch -> out_ch.
    for stage_in in (in_ch, out_ch):
        layers.append(nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1))
        layers.append(nn.BatchNorm2d(out_ch))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)
class UNetEfficientNetB5(nn.Module):
    """Attention U-Net decoder on top of a timm EfficientNet feature encoder.

    The encoder returns five feature maps. The deepest one goes through a
    ``conv_block`` "center"; four decoder stages then each upsample with a
    stride-2 transposed convolution, gate the matching encoder skip with an
    ``AttentionGate``, concatenate, and refine with a ``conv_block``. A final
    stride-2 transposed convolution and a 1x1 convolution produce
    ``num_classes`` output channels.

    NOTE(review): despite the class name, the encoder is created from timm's
    ``'efficientnet_b4'``. It is left unchanged because ``load_unet`` loads the
    checkpoint with ``strict=True`` -- confirm which backbone the weights
    were trained with before renaming either side.

    Args:
        num_classes: Number of output channels of the segmentation head.
        pretrained: Forwarded to ``timm.create_model`` for the encoder.
    """

    def __init__(self, num_classes=3, pretrained=False):
        super().__init__()
        self.encoder = timm.create_model(
            'efficientnet_b4',
            pretrained=pretrained,
            features_only=True,
            out_indices=(0, 1, 2, 3, 4)
        )
        # Channel count of each of the five encoder feature maps.
        enc_channels = self.encoder.feature_info.channels()
        self.center = conv_block(enc_channels[4], 256)

        # One attention gate per skip connection (deepest first).
        self.ag4 = AttentionGate(f_g=256, f_l=enc_channels[3], f_int=128)
        self.ag3 = AttentionGate(f_g=enc_channels[3], f_l=enc_channels[2], f_int=64)
        self.ag2 = AttentionGate(f_g=enc_channels[2], f_l=enc_channels[1], f_int=32)
        self.ag1 = AttentionGate(f_g=enc_channels[1], f_l=enc_channels[0], f_int=16)

        # Decoder stages: transposed-conv upsample, then a conv_block over
        # the concatenation of the upsampled tensor and the gated skip.
        self.up4 = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)
        self.dec4 = conv_block(256 + enc_channels[3], enc_channels[3])
        self.up3 = nn.ConvTranspose2d(enc_channels[3], enc_channels[3], kernel_size=2, stride=2)
        self.dec3 = conv_block(enc_channels[3] + enc_channels[2], enc_channels[2])
        self.up2 = nn.ConvTranspose2d(enc_channels[2], enc_channels[2], kernel_size=2, stride=2)
        self.dec2 = conv_block(enc_channels[2] + enc_channels[1], enc_channels[1])
        self.up1 = nn.ConvTranspose2d(enc_channels[1], enc_channels[1], kernel_size=2, stride=2)
        self.dec1 = conv_block(enc_channels[1] + enc_channels[0], enc_channels[0])

        self.final_up = nn.ConvTranspose2d(enc_channels[0], 32, kernel_size=2, stride=2)
        self.out = nn.Conv2d(32, num_classes, kernel_size=1)

        # These 1x1 heads are not used by forward(); they are kept so the
        # module's state_dict keys stay the same.
        # NOTE(review): presumably auxiliary deep-supervision heads used
        # during training -- confirm.
        self.ds3 = nn.Conv2d(enc_channels[2], num_classes, kernel_size=1)
        self.ds2 = nn.Conv2d(enc_channels[1], num_classes, kernel_size=1)

    @staticmethod
    def _match_spatial(tensor, reference):
        """Bilinearly resize ``tensor`` to ``reference``'s H x W if they differ."""
        # Compare only the spatial dims. The previous full-shape comparison
        # also included the channel axis, which differs between the upsampled
        # tensor and the skip, so a same-size resize ran on every stage.
        if tensor.shape[2:] != reference.shape[2:]:
            tensor = F.interpolate(
                tensor, size=reference.shape[2:], mode='bilinear', align_corners=False
            )
        return tensor

    def forward(self, x):
        """Run encoder and attention-gated decoder; return the output of ``self.out``."""
        e0, e1, e2, e3, e4 = self.encoder(x)
        d = self.center(e4)

        # (upsample, attention gate, refine block, encoder skip), deepest first.
        stages = (
            (self.up4, self.ag4, self.dec4, e3),
            (self.up3, self.ag3, self.dec3, e2),
            (self.up2, self.ag2, self.dec2, e1),
            (self.up1, self.ag1, self.dec1, e0),
        )
        for upsample, gate, refine, skip in stages:
            d = upsample(d)
            # The gate sees the un-resized upsampled tensor and aligns its own
            # projection internally, as in the original flow.
            skip_att = gate(g=d, x=skip)
            d = self._match_spatial(d, skip_att)
            d = refine(torch.cat([d, skip_att], dim=1))

        return self.out(self.final_up(d))
class EfficientNetB3Classifier(nn.Module):
    """torchvision EfficientNet-B3 whose classifier head emits a single value.

    The backbone is built without pretrained weights (``weights=None``) and
    its classifier is replaced by ``Dropout(p=0.3)`` followed by a
    ``Linear(in_features, 1)`` layer.
    """

    def __init__(self):
        super().__init__()
        net = models.efficientnet_b3(weights=None)
        # Reuse the input width of the stock classifier's linear layer.
        head_width = net.classifier[1].in_features
        net.classifier = nn.Sequential(
            nn.Dropout(p=0.3, inplace=True),
            nn.Linear(head_width, 1),
        )
        self.backbone = net

    def forward(self, x):
        """Return the backbone's output for ``x``."""
        return self.backbone(x)
def load_yolo(path: str):
    """Create a ``YOLO`` model from ``path`` and move it to ``DEVICE``.

    Args:
        path: Weights/model path passed straight to ``YOLO``.

    Returns:
        The ``YOLO`` model object.
    """
    detector = YOLO(path)
    detector.to(DEVICE)
    print("YOLO încărcat")
    return detector
def load_unet(path: str):
    """Build the 3-class U-Net, load its weights from ``path``, and set eval mode.

    Args:
        path: Checkpoint file containing the model's state dict.

    Returns:
        The model on ``DEVICE`` in evaluation mode.
    """
    net = UNetEfficientNetB5(num_classes=3, pretrained=False)
    net = net.to(DEVICE)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    weights = torch.load(path, map_location=DEVICE)
    # strict=True: every key must match the module exactly.
    net.load_state_dict(weights, strict=True)
    net.eval()
    print("U-Net încărcat")
    return net
def load_efficientnet(path: str):
    """Build the EfficientNet-B3 classifier, load its weights, and set eval mode.

    The state dict is loaded with ``strict=False`` so that a partially
    matching checkpoint still loads. Any missing or unexpected keys are now
    reported through a ``RuntimeWarning``; previously they were dropped
    silently, so mismatched layers kept their random initialisation unnoticed.

    Args:
        path: Checkpoint file containing the model's state dict.

    Returns:
        The model on ``DEVICE`` in evaluation mode.
    """
    import warnings  # local import keeps the file's top-level import block untouched

    model = EfficientNetB3Classifier().to(DEVICE)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(path, map_location=DEVICE)
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        # Warn rather than raise: the lenient load is kept as best-effort.
        warnings.warn(
            "EfficientNet checkpoint does not fully match the model: "
            f"missing keys={list(load_result.missing_keys)}, "
            f"unexpected keys={list(load_result.unexpected_keys)}",
            RuntimeWarning,
            stacklevel=2,
        )
    model.eval()
    print("EfficientNet încărcat")
    return model
def load_medgemma(model_id: str = "google/medgemma-1.5-4b-it"):
    """Load the image-text-to-text model and its processor for ``model_id``.

    Args:
        model_id: Identifier passed to both ``from_pretrained`` calls.

    Returns:
        A ``(model, processor)`` tuple; the model is in evaluation mode.
    """
    processor = AutoProcessor.from_pretrained(model_id, use_fast=False)
    # bfloat16 weights, with device placement delegated via device_map="auto".
    load_options = {
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
    }
    vlm = AutoModelForImageTextToText.from_pretrained(model_id, **load_options)
    vlm.eval()
    print("MedGemma încărcat")
    return vlm, processor
def load_fusion(pkl_path: str, thresh_path: str):
    """Load the fusion model and its decision threshold from joblib files.

    Args:
        pkl_path: Path of the joblib file holding the fusion model.
        thresh_path: Path of the joblib file holding the threshold; the value
            must support ``:.4f`` formatting.

    Returns:
        A ``(fusion_model, threshold)`` tuple.
    """
    # NOTE(review): joblib.load unpickles its input; only use trusted files.
    # The generator is consumed in order, so the model is loaded first.
    fusion_model, threshold = (
        joblib.load(artifact) for artifact in (pkl_path, thresh_path)
    )
    print(f"Fusion model încărcat (threshold={threshold:.4f})")
    return fusion_model, threshold