OncoVision-X / src /models /malignancy_classifier.py
adityasync's picture
Clean OncoVision-X deployment with LFS
8960670
#!/usr/bin/env python3
"""Malignancy classifier using a 3D ResNet backbone."""
from pathlib import Path
import torch
import torch.nn as nn
class MalignancyClassifier(nn.Module):
    """3D ResNet-based malignancy classifier.

    Wraps torchvision's ``r3d_18`` backbone and replaces its final
    fully-connected layer with a small head that ends in ``nn.Sigmoid``, so
    the model outputs one probability-like score in [0, 1] per sample.

    Args:
        pretrained_path: Optional path to a checkpoint whose tensors are
            loaded (best-effort, non-strictly) into the backbone. When a path
            is given, torchvision's own pretrained weights are not requested.
        use_torchvision_pretrained: When True and ``pretrained_path`` is None,
            initialise the backbone from torchvision's default ``r3d_18``
            weights.
    """

    def __init__(self, pretrained_path=None, use_torchvision_pretrained=True):
        super().__init__()
        # torchvision weights are only used when no explicit checkpoint is given.
        want_torchvision = use_torchvision_pretrained and pretrained_path is None
        try:
            # Newer torchvision exposes the ``weights=`` enum API.
            from torchvision.models.video import R3D_18_Weights, r3d_18
            weights = R3D_18_Weights.DEFAULT if want_torchvision else None
            self.backbone = r3d_18(weights=weights)
        except ImportError:
            # Older torchvision only understands the boolean ``pretrained`` flag.
            from torchvision.models.video import r3d_18
            self.backbone = r3d_18(pretrained=want_torchvision)

        if pretrained_path:
            if Path(pretrained_path).exists():
                self._load_backbone_weights(pretrained_path)
            else:
                # Without this message a typo in the path silently yields a
                # randomly initialised backbone.
                print(f" Warning: pretrained weights not found at {pretrained_path}; "
                      "backbone is randomly initialised")

        # Derive the head width from the backbone instead of hard-coding 512.
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _strip_module_prefix(key, prefix='module.'):
        """Remove a leading ``module.`` (as added by ``nn.DataParallel``) from *key*.

        Only a true prefix is removed; occurrences elsewhere in the name
        (e.g. ``submodule.weight``) are left intact.
        """
        return key[len(prefix):] if key.startswith(prefix) else key

    @staticmethod
    def _read_checkpoint(path):
        """Load a checkpoint onto the CPU, tolerating torch versions without ``weights_only``."""
        # NOTE(security): weights_only=False unpickles arbitrary Python objects;
        # only load checkpoints from trusted sources.
        try:
            return torch.load(path, map_location='cpu', weights_only=False)
        except TypeError:
            # torch < 1.13 (paired with the old-torchvision fallback above)
            # rejects the ``weights_only`` keyword; its default behaves the same.
            return torch.load(path, map_location='cpu')

    def _load_backbone_weights(self, pretrained_path):
        """Best-effort copy of compatible checkpoint tensors into ``self.backbone``.

        Accepts either a raw state dict or a dict wrapping one under
        ``'state_dict'``. Only entries whose (prefix-stripped) name exists in
        the backbone with an identical shape are copied. ``load_state_dict``
        raises on any shape mismatch even with ``strict=False``, which would
        otherwise discard every tensor. Failures are reported as warnings,
        never raised.
        """
        try:
            state = self._read_checkpoint(pretrained_path)
            if isinstance(state, dict) and 'state_dict' in state:
                state = state['state_dict']
            if not isinstance(state, dict):
                raise TypeError(f"expected a state dict, got {type(state).__name__}")

            own_state = self.backbone.state_dict()
            compatible = {}
            for key, value in state.items():
                name = self._strip_module_prefix(key)
                target = own_state.get(name)
                if target is not None and getattr(value, 'shape', None) == target.shape:
                    compatible[name] = value
            self.backbone.load_state_dict(compatible, strict=False)
        except Exception as exc:
            print(f" Warning: could not load MedicalNet weights: {exc}")
            return

        if not compatible:
            # Previously a success message was printed even when nothing matched.
            print(f" Warning: no tensors in {pretrained_path} match the backbone; "
                  "MedicalNet weights were not applied")
            return
        print(f" Loaded MedicalNet weights from {pretrained_path} "
              f"({len(compatible)}/{len(own_state)} backbone tensors)")

    def forward(self, x):
        """Run the classifier on a batch of 3D patches.

        Args:
            x: Tensor of shape (B, C, D, H, W), e.g. (B, 1, 32, 32, 32).
               Single-channel input is replicated to the 3 channels the
               backbone stem expects; other channel counts pass through as-is.

        Returns:
            The backbone output, whose last layer is ``Linear(128, 1)``
            followed by ``Sigmoid``.
        """
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1, 1)
        return self.backbone(x)