# disaster-app / src /model.py
# yuto090612's picture
# Upload src/model.py with huggingface_hub
# edebaa5 verified
"""DINOv2 multi-head model for disaster building damage assessment.
Ported from hisaichi research code — simplified to inference-only,
original_only ROI mode (no two_stream, no CoVT, no cascade).
"""
import logging
from pathlib import Path
import timm
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from .config import InferenceConfig
logger = logging.getLogger(__name__)
class DINOv2MultiHeadModel(nn.Module):
    """Multi-head classifier on top of a timm backbone (DINOv2 per the module docs).

    Architecture::

        backbone -> feature_transform -> head_full (``config.num_classes`` logits)

    When ``config.use_auxiliary_heads`` is true and ``config.ce_only`` is
    false, three auxiliary linear heads are attached to the same transformed
    features: ``head_damage`` (2 outputs), ``head_disaster_type`` (2 outputs)
    and ``head_severity`` (3 outputs). Otherwise those attributes are ``None``.

    Parameters
    ----------
    config : InferenceConfig
        Supplies ``model_name``, ``image_size``, ``hidden_dim``,
        ``num_classes``, ``use_auxiliary_heads`` and ``ce_only``.
    pretrained : bool, optional
        Forwarded to ``timm.create_model``. Defaults to ``True`` (the
        previous hard-coded behavior). Pass ``False`` to skip fetching the
        pretrained backbone weights, e.g. when a checkpoint that contains
        every backbone weight is loaded right after construction.
    """

    def __init__(self, config: InferenceConfig, *, pretrained: bool = True) -> None:
        super().__init__()
        self.config = config

        # Backbone. ``num_classes=0`` asks timm for a model without its
        # classification layer, so the backbone output is a feature tensor.
        self.backbone = timm.create_model(
            config.model_name,
            pretrained=pretrained,
            num_classes=0,
            img_size=config.image_size,
        )

        # Feature transform. Its input width is ``config.hidden_dim``, so the
        # backbone's feature width must equal ``config.hidden_dim``.
        self.feature_transform = nn.Sequential(
            nn.Linear(config.hidden_dim, config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

        # Main classification head.
        self.head_full = nn.Linear(config.hidden_dim, config.num_classes)

        # Auxiliary heads (optional). All three are created or disabled
        # together, which is why ``forward`` only has to test ``head_damage``.
        if config.use_auxiliary_heads and not config.ce_only:
            self.head_damage = nn.Linear(config.hidden_dim, 2)
            self.head_disaster_type = nn.Linear(config.hidden_dim, 2)
            self.head_severity = nn.Linear(config.hidden_dim, 3)
        else:
            self.head_damage = None
            self.head_disaster_type = None
            self.head_severity = None

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the backbone, the feature transform and every enabled head.

        Parameters
        ----------
        x : torch.Tensor
            Input images of shape ``(B, 3, H, W)``.

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary with key ``"full"`` (always present) and, when the
            auxiliary heads are enabled, ``"damage"``, ``"disaster_type"``
            and ``"severity"``. Each value is the raw output of the
            corresponding linear head.
        """
        feat = self.backbone(x)
        transformed = self.feature_transform(feat)

        outputs: dict[str, torch.Tensor] = {
            "full": self.head_full(transformed),
        }
        if self.head_damage is not None:
            outputs["damage"] = self.head_damage(transformed)
            outputs["disaster_type"] = self.head_disaster_type(transformed)
            outputs["severity"] = self.head_severity(transformed)
        return outputs
def build_model(config: InferenceConfig, device: str) -> nn.Module:
    """Instantiate the multi-head model on ``device`` and wrap it with LoRA.

    Parameters
    ----------
    config : InferenceConfig
        Model configuration; ``lora_rank``, ``lora_alpha`` and
        ``lora_dropout`` are read here in addition to whatever
        :class:`DINOv2MultiHeadModel` consumes.
    device : str
        Target device (e.g. ``"cuda"`` or ``"cpu"``).

    Returns
    -------
    nn.Module
        The object returned by ``get_peft_model`` for the base model, which
        was moved to ``device`` before wrapping.
    """
    base_model = DINOv2MultiHeadModel(config).to(device)

    rank = config.lora_rank
    alpha = config.lora_alpha

    peft_settings = LoraConfig(
        r=rank,
        lora_alpha=alpha,
        # LoRA is requested for the modules named "qkv".
        target_modules=["qkv"],
        lora_dropout=config.lora_dropout,
        bias="none",
        # The feature transform and all heads are listed as modules to save
        # alongside the adapters. The auxiliary head names are listed
        # unconditionally, even though the model may have set them to None.
        modules_to_save=[
            "feature_transform",
            "head_full",
            "head_damage",
            "head_disaster_type",
            "head_severity",
        ],
    )
    wrapped_model = get_peft_model(base_model, peft_settings)

    logger.info(
        "Built DINOv2 model with LoRA (r=%d, alpha=%d) on %s",
        rank,
        alpha,
        device,
    )
    return wrapped_model
def load_checkpoint(model: nn.Module, checkpoint_path: "Path | str", device: str) -> nn.Module:
    """Load trained weights from a checkpoint file and switch to eval mode.

    Parameters
    ----------
    model : nn.Module
        Model with LoRA adapters (from :func:`build_model`). It is modified
        in place and also returned.
    checkpoint_path : Path or str
        Path to ``best_model.pth``. Plain strings are accepted and converted
        to :class:`pathlib.Path`.
    device : str
        Target device for weight mapping (passed to ``torch.load`` as
        ``map_location``).

    Returns
    -------
    nn.Module
        Model in eval mode with loaded weights.

    Raises
    ------
    FileNotFoundError
        If ``checkpoint_path`` does not exist or is not a regular file.
    """
    # Coerce first so callers passing a string get the documented
    # FileNotFoundError instead of an AttributeError on ``.exists()``.
    checkpoint_path = Path(checkpoint_path)

    # ``is_file`` (rather than ``exists``) also rejects directories up front,
    # instead of letting them fail later inside ``torch.load``.
    if not checkpoint_path.is_file():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # ``weights_only=True`` restricts unpickling to tensors and plain
    # containers, which is the safe way to read a state dict from disk.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    logger.info("Loaded checkpoint from %s", checkpoint_path)
    return model