# Source header: Ariyan-Pro — Enterprise Adversarial ML Governance Engine v5.0 LTS (commit f4bee9e)
"""
Model utilities: loading, saving, evaluation, etc.
"""
import torch
import torch.nn as nn
import json
import yaml
from pathlib import Path
import numpy as np
from typing import Dict, Any, Optional
from datetime import datetime
def save_model(model: nn.Module, path: str, metadata: Optional[Dict] = None):
    """Persist a model checkpoint plus a JSON "model card" next to it.

    The checkpoint at *path* stores the state dict, the model class name,
    and any caller-supplied metadata.  The model card (same path with a
    ``.json`` suffix) records parameter counts and a save timestamp.

    Args:
        model: PyTorch model to save
        path: Path to save model (parent directories are created)
        metadata: Additional metadata to save, merged into both the
            checkpoint and the model card
    """
    meta = metadata or {}
    Path(path).parent.mkdir(parents=True, exist_ok=True)

    # Save model state
    torch.save({
        'state_dict': model.state_dict(),
        'model_class': model.__class__.__name__,
        'metadata': meta
    }, path)

    # Save model card alongside the checkpoint
    model_card = {
        'path': path,
        'model_class': model.__class__.__name__,
        'parameters': sum(p.numel() for p in model.parameters()),
        'trainable_parameters': sum(p.numel() for p in model.parameters() if p.requires_grad),
        'save_timestamp': str(datetime.now()),
        # BUG FIX: the original spread **metadata directly, which raised
        # TypeError when metadata was None (the documented default).
        **meta
    }
    model_card_path = Path(path).with_suffix('.json')
    with open(model_card_path, 'w') as f:
        json.dump(model_card, f, indent=2)
def load_model(path: str, model_class: Optional[type] = None, device: str = 'cpu'):
    """Load a checkpoint written by ``save_model``.

    Args:
        path: Path to saved model
        model_class: Model class to instantiate; if None, the class named in
            the checkpoint is imported from ``models/base/mnist_cnn.py``
        device: Device to load model on ('cpu', 'cuda', ...)

    Returns:
        Tuple of (model in eval mode on *device*, metadata dict)

    Raises:
        FileNotFoundError: if *path* does not exist
        ValueError: if the stored model class cannot be resolved
    """
    if not Path(path).exists():
        raise FileNotFoundError(f"Model file not found: {path}")

    # weights_only=False is needed because checkpoints may embed numpy
    # objects in their metadata.  Passing it explicitly also guards against
    # the default flipping to True in torch >= 2.6.
    # SECURITY: this path uses pickle under the hood — only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(path, map_location=device, weights_only=False)

    if model_class is None:
        # Resolve the class by name from the project's base-model module.
        import sys
        sys.path.insert(0, 'models/base')
        try:
            module = __import__('mnist_cnn')
            # BUG FIX: getattr raises AttributeError (not ImportError) when
            # the module exists but lacks the class — catch both.
            model_class = getattr(module, checkpoint['model_class'])
        except (ImportError, AttributeError):
            raise ValueError(f"Could not import model class: {checkpoint['model_class']}")
        finally:
            # Undo the sys.path mutation so repeated calls don't accumulate.
            sys.path.remove('models/base')

    model = model_class()
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()
    return model, checkpoint.get('metadata', {})
def evaluate_model(model: nn.Module, dataloader: torch.utils.data.DataLoader,
                   device: str = 'cpu') -> Dict[str, float]:
    """Evaluate classification accuracy and mean cross-entropy loss.

    Args:
        model: PyTorch model producing class logits of shape (batch, classes)
        dataloader: DataLoader yielding (data, target) batches for evaluation
        device: Device for computation

    Returns:
        Dict with 'accuracy' (percent), 'loss' (mean per-batch loss),
        'correct' and 'total' sample counts.  An empty dataloader yields
        all-zero metrics instead of a ZeroDivisionError.
    """
    model.eval()
    correct = 0
    total = 0
    losses = []
    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            # Mean loss is computed over per-batch losses, matching the
            # original behavior (not weighted by batch size).
            losses.append(criterion(output, target).item())

            # argmax over the class dimension gives the predicted label
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)

    # BUG FIX: guard the empty-dataloader case — the original divided by
    # zero and took np.mean of an empty list.
    if total == 0:
        return {'accuracy': 0.0, 'loss': 0.0, 'correct': 0, 'total': 0}

    return {
        'accuracy': 100. * correct / total,
        'loss': float(np.mean(losses)),
        'correct': correct,
        'total': total
    }
def get_model_summary(model: nn.Module) -> str:
    """Generate a summary of model architecture.

    Walks the module tree and lists Conv2d/Linear layers (BatchNorm layers
    contribute to the parameter totals but get no line of their own),
    followed by total / trainable / non-trainable parameter counts.
    """
    tracked = (nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.BatchNorm1d)
    lines = []
    param_total = 0
    param_trainable = 0

    for name, layer in model.named_modules():
        if not isinstance(layer, tracked):
            continue
        params = list(layer.parameters())
        param_total += sum(p.numel() for p in params)
        param_trainable += sum(p.numel() for p in params if p.requires_grad)
        if isinstance(layer, nn.Conv2d):
            lines.append(
                f"{name}: Conv2d(in={layer.in_channels}, out={layer.out_channels}, "
                f"kernel={layer.kernel_size}, stride={layer.stride})"
            )
        elif isinstance(layer, nn.Linear):
            lines.append(
                f"{name}: Linear(in={layer.in_features}, out={layer.out_features})"
            )

    return (
        "\n".join(lines)
        + f"\n\nTotal parameters: {param_total:,}"
        + f"\nTrainable parameters: {param_trainable:,}"
        + f"\nNon-trainable parameters: {param_total - param_trainable:,}"
    )
def update_registry(model_name: str, path: str, metadata: Dict[str, Any]):
    """Add or overwrite an entry in the JSON model registry.

    Args:
        model_name: Registry key for the model
        path: Checkpoint path recorded in the entry
        metadata: Arbitrary metadata stored alongside the entry
    """
    registry_path = Path("models/registry.json")
    # BUG FIX: the final write failed with FileNotFoundError when the
    # models/ directory did not exist yet — create it up front.
    registry_path.parent.mkdir(parents=True, exist_ok=True)

    registry: Dict[str, Any] = {}
    if registry_path.exists():
        with open(registry_path, 'r') as f:
            try:
                registry = json.load(f)
            except json.JSONDecodeError:
                # Corrupt registry: start fresh rather than crash.
                registry = {}

    registry[model_name] = {
        'path': path,
        # NOTE(review): hard-coded MNIST-style shape/classes — confirm these
        # apply to every model registered through this helper.
        'input_size': '1x28x28',
        'num_classes': 10,
        'metadata': metadata,
        'timestamp': str(datetime.now())
    }
    with open(registry_path, 'w') as f:
        json.dump(registry, f, indent=2)
# NOTE: datetime is imported once at the top of the file.