# ConvNext-large-cpp / predictor.py
# Uploaded by Yashodhar29 via huggingface_hub
# (commit f4a9a1c, "Upload folder using huggingface_hub", verified).
import json

import timm
import torch
from PIL import Image
from torchvision import transforms
class AIDetector:
    """Image classifier built from a timm architecture plus a saved state_dict.

    The architecture name, input size, normalization statistics and label
    names are read from a JSON config file; the weights are loaded from a
    PyTorch ``state_dict`` checkpoint. Inference runs on CUDA when available,
    otherwise on CPU.
    """

    def __init__(self, model_path, config_path, num_classes=2):
        """Build the model, load its weights and set up preprocessing.

        Args:
            model_path: Path to a ``state_dict`` checkpoint saved with
                ``torch.save``.
            config_path: Path to a JSON config providing ``model_name``,
                ``img_size``, ``mean``, ``std`` and ``labels``.
            num_classes: Number of logits produced by the classifier head.
                Must match the shapes in the checkpoint. Defaults to 2.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = json.load(f)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # pretrained=False: the real weights are loaded from model_path below,
        # so no pretrained weights should be fetched here.
        self.model = timm.create_model(
            self.config["model_name"], pretrained=False, num_classes=num_classes
        )

        # weights_only=True restricts unpickling to tensors and plain
        # containers (enough for a state_dict), so a tampered checkpoint
        # cannot execute arbitrary code. Requires PyTorch >= 1.13.
        state_dict = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()

        size = self.config["img_size"]
        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.config["mean"], std=self.config["std"]),
        ])

    def predict(self, image_path):
        """Classify a single image.

        Args:
            image_path: Path (or file object) of an image to open, or an
                already-loaded ``PIL.Image.Image``.

        Returns:
            dict with ``"prediction"`` (the label from the config) and
            ``"confidence"`` (the softmax probability of that label, as a float).
        """
        if isinstance(image_path, Image.Image):
            img = image_path.convert("RGB")
        else:
            # Close the underlying file once decoded; convert() returns a new
            # image, so it stays valid after the source image is closed.
            with Image.open(image_path) as src:
                img = src.convert("RGB")

        batch = self.transform(img).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(batch)
        return self._postprocess(logits)

    def _postprocess(self, logits):
        """Map a single-row logit tensor to the ``{"prediction", "confidence"}`` dict."""
        probs = torch.nn.functional.softmax(logits, dim=1)
        conf, pred = torch.max(probs, 1)
        idx = pred.item()
        labels = self.config["labels"]
        # JSON object keys are always strings, so a dict is indexed by str(idx);
        # a JSON list of label names is indexed positionally.
        label = labels[idx] if isinstance(labels, list) else labels[str(idx)]
        return {"prediction": label, "confidence": conf.item()}
# Example usage:
# detector = AIDetector('convnext_ai_detector_final.pth', 'model_config.json')
# print(detector.predict('test_image.jpg'))