# deepfake_model / model.py
# Author: Simma7 — commit 5ad0794 ("Create model.py", verified)
import warnings
from collections import OrderedDict

import timm
import torch
# Default placement for loaded models: "cuda" whenever PyTorch reports a usable
# CUDA device, otherwise "cpu". Kept as a plain string, which Module.to() accepts.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"
def load_model(model_path, model_type="vit", *, device=None, strict=False):
    """Build a single-logit ViT-B/16 classifier and load checkpoint weights into it.

    Args:
        model_path: Checkpoint location, passed straight to ``torch.load``.
        model_type: ``"clip"`` builds timm's ``vit_base_patch16_clip_224.openai``;
            any other value falls back to ``vit_base_patch16_224``.
        device: Target device for the returned model. ``None`` (the default)
            uses the module-level ``DEVICE``.
        strict: Forwarded to ``nn.Module.load_state_dict``. With the default
            ``False``, missing/unexpected keys are tolerated but reported via
            ``warnings.warn`` instead of being silently ignored.

    Accepted checkpoint layouts:
        - a bare state_dict;
        - a dict wrapping one under ``"model_state_dict"``, ``"state_dict"``
          or ``"model"`` (tried in that order);
        - a pickled ``nn.Module``.
    A ``"module."`` prefix on every key (left by ``DataParallel`` /
    ``DistributedDataParallel``) is stripped before loading.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        RuntimeError: From ``load_state_dict`` on tensor shape mismatches, or
            on missing/unexpected keys when ``strict=True``.
    """
    arch = (
        "vit_base_patch16_clip_224.openai"
        if model_type == "clip"
        else "vit_base_patch16_224"
    )
    # pretrained=False: every weight must come from the checkpoint below.
    # num_classes=1: the classification head emits a single logit.
    model = timm.create_model(arch, pretrained=False, num_classes=1)

    # Map to CPU first so checkpoints saved on a GPU also open on CPU-only hosts.
    # SECURITY NOTE(review): torch.load unpickles the file; on PyTorch versions
    # where weights_only defaults to False this can execute arbitrary code, so
    # only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location="cpu")
    state_dict = _strip_key_prefix(_unwrap_state_dict(checkpoint), "module.")

    incompatible = model.load_state_dict(state_dict, strict=strict)
    missing, unexpected = incompatible.missing_keys, incompatible.unexpected_keys
    if missing or unexpected:
        # Without this report, strict=False would hide a bad checkpoint and
        # leave parts of the model at their random initialisation.
        warnings.warn(
            f"load_model({model_path!r}): {len(missing)} missing and "
            f"{len(unexpected)} unexpected checkpoint keys; missing parameters "
            f"keep their random initialisation. Missing (first 5): "
            f"{missing[:5]}; unexpected (first 5): {unexpected[:5]}",
            RuntimeWarning,
            stacklevel=2,
        )

    model.to(DEVICE if device is None else device)
    model.eval()
    return model


def _unwrap_state_dict(checkpoint):
    """Return the parameter mapping held by a ``torch.load`` result."""
    if isinstance(checkpoint, dict):
        # Training scripts often wrap the weights alongside optimizer state etc.;
        # otherwise the dict itself is taken to be the state_dict.
        for key in ("model_state_dict", "state_dict", "model"):
            candidate = checkpoint.get(key)
            if isinstance(candidate, (dict, torch.nn.Module)):
                checkpoint = candidate
                break
    if isinstance(checkpoint, torch.nn.Module):
        # torch.save(model) stores the whole module rather than its weights.
        checkpoint = checkpoint.state_dict()
    return checkpoint


def _strip_key_prefix(state_dict, prefix):
    """Remove *prefix* from every key, but only when all keys carry it.

    Ordinary checkpoints (and non-dict inputs) are returned untouched.
    """
    if not isinstance(state_dict, dict) or not state_dict:
        return state_dict
    if not all(isinstance(k, str) and k.startswith(prefix) for k in state_dict):
        return state_dict
    cut = len(prefix)
    # OrderedDict (not dict) so the _metadata attribute can be attached below.
    stripped = OrderedDict((k[cut:], v) for k, v in state_dict.items())
    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        # load_state_dict looks up per-module version info in _metadata by
        # module path; rename paths the same way ("module" -> root "").
        root = prefix.rstrip(".")
        stripped._metadata = OrderedDict(
            ("" if k == root else k[cut:] if k.startswith(prefix) else k, v)
            for k, v in metadata.items()
        )
    return stripped