import os
import warnings
from pathlib import Path

import torch
import torch.nn as nn
from torchvision import transforms
from transformers import (
    CLIPVisionConfig,
    CLIPVisionModel,
    Swinv2Config,
    Swinv2ForImageClassification,
)


def get_swinv2_transform():
    """Build the preprocessing pipeline used for SwinV2Classifier inputs.

    Resizes the shorter image side to 256 px (bicubic), center-crops to
    256x256, converts to a float tensor and normalizes with the ImageNet
    mean/std.
    """
    return transforms.Compose(
        [
            transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )


def get_clip_transform():
    """Build the preprocessing pipeline used for DF40CLIPModel inputs.

    Resizes the shorter image side to 224 px (bicubic), center-crops to
    224x224, converts to a float tensor and normalizes with the CLIP
    mean/std constants.
    """
    return transforms.Compose(
        [
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )


class SwinV2Classifier(nn.Module):
    """SwinV2-Small image classifier whose weights come from a local checkpoint.

    Only the architecture config is fetched for ``HF_MODEL_ID``; the model is
    built from that config and then populated from ``ckpt_path``.
    """

    HF_MODEL_ID = "microsoft/swinv2-small-patch4-window16-256"

    # Key prefixes that appear when a checkpoint is saved from a wrapper:
    # "module." from DataParallel/DistributedDataParallel, "model." from this
    # class's own state_dict(). Order matters: a DDP-wrapped SwinV2Classifier
    # produces "module.model.<...>", which is unwrapped outermost-first.
    _WRAPPER_PREFIXES = ("module.", "model.")

    def __init__(self, ckpt_path: str, num_labels: int = 2):
        """Build the model and load weights.

        Args:
            ckpt_path: Path to a ``.safetensors`` file or a ``torch.save``
                state dict for ``Swinv2ForImageClassification``.
            num_labels: Size of the classification head.

        Raises:
            FileNotFoundError: If ``ckpt_path`` does not exist.
            RuntimeError: If no parameter in the checkpoint matches the model.
        """
        super().__init__()
        self.num_labels = num_labels
        # ``ignore_mismatched_sizes`` is a model-loading option, not a config
        # field; the config loader discards it, so it is not passed here.
        config = Swinv2Config.from_pretrained(self.HF_MODEL_ID, num_labels=num_labels)
        self.model = Swinv2ForImageClassification(config)
        self._load_weights(ckpt_path)

    @classmethod
    def _strip_wrapper_prefixes(cls, state_dict):
        """Remove any wrapper prefix that is shared by every key of *state_dict*."""
        if not isinstance(state_dict, dict):
            # Let load_state_dict report the unexpected checkpoint type.
            return state_dict
        for prefix in cls._WRAPPER_PREFIXES:
            if state_dict and all(
                isinstance(key, str) and key.startswith(prefix) for key in state_dict
            ):
                state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
        return state_dict

    def _load_weights(self, ckpt_path: str) -> None:
        """Load ``ckpt_path`` into ``self.model`` non-strictly, but never silently.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
            RuntimeError: If none of the model's parameters were found in it.
        """
        path = Path(ckpt_path)
        if not path.exists():
            raise FileNotFoundError(f"Checkpoint não encontrado: {ckpt_path}")
        if path.suffix == ".safetensors":
            from safetensors.torch import load_file

            state_dict = load_file(str(path), device="cpu")
        else:
            # SECURITY: depending on the installed torch version's
            # ``weights_only`` default, torch.load may unpickle arbitrary
            # objects; only load checkpoints from trusted sources.
            state_dict = torch.load(str(path), map_location="cpu")

        state_dict = self._strip_wrapper_prefixes(state_dict)
        result = self.model.load_state_dict(state_dict, strict=False)

        # strict=False tolerates partial checkpoints, but loading *nothing*
        # would leave a randomly initialized classifier without any signal.
        expected_count = len(self.model.state_dict())
        if expected_count - len(result.missing_keys) == 0:
            raise RuntimeError(
                f"No parameter in checkpoint {ckpt_path} matches "
                f"Swinv2ForImageClassification (unexpected keys, e.g. "
                f"{result.unexpected_keys[:3]})"
            )
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                f"Checkpoint {ckpt_path}: {len(result.missing_keys)} missing keys "
                f"(e.g. {result.missing_keys[:3]}) keep their initial values; "
                f"{len(result.unexpected_keys)} unexpected keys "
                f"(e.g. {result.unexpected_keys[:3]}) were ignored.",
                stacklevel=2,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the classification logits for the pixel batch ``x``."""
        return self.model(pixel_values=x).logits

    def predict_prob(self, x: torch.Tensor) -> float:
        """Return the softmax probability of class index 1 for the first sample.

        Runs in eval mode without gradients and restores the module's previous
        train/eval mode afterwards. Only ``x[0]`` contributes to the result.
        """
        was_training = self.training
        # nn.Module starts in training mode; without eval() any dropout /
        # stochastic-depth layers would make the prediction nondeterministic.
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(x)
        finally:
            self.train(was_training)
        return torch.softmax(logits, dim=1)[0, 1].item()


class DF40CLIPModel(nn.Module):
    """CLIP ViT-L/14 vision encoder followed by a linear classification head.

    Only the vision config of ``openai/clip-vit-large-patch14`` is fetched;
    the backbone is built from that config with freshly initialized weights.
    Presumably the caller loads trained weights afterwards — TODO confirm.
    """

    def __init__(self, num_labels: int = 2):
        super().__init__()
        config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
        self.backbone = CLIPVisionModel(config)
        self.head = nn.Linear(config.hidden_size, num_labels)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return logits computed from the backbone's ``pooler_output``."""
        outputs = self.backbone(pixel_values=pixel_values)
        return self.head(outputs.pooler_output)