# id-detector / custom_tools / liveness_detection_tool.py
# Uploaded by reagvis ("Upload 14 files"), commit 45419f0 (verified).
from agentlego.tools import BaseTool
from PIL import Image
import torch
import tempfile
import cv2
import os
class LivenessDetectionTool(BaseTool):
    """Agent tool that runs a liveness image classifier on a single image.

    The image processor and classification model are loaded once, at
    construction time, from a Hugging Face checkpoint. Each ``apply`` call
    then performs one forward pass and reports the top-1 label together with
    its softmax probability. The label names come from the checkpoint's
    ``config.id2label``. Presumably they are live/spoof style classes; confirm
    against the checkpoint.
    """

    default_desc = 'Detects liveness in an image using a DinoV2 image classification model.'

    # Hugging Face Hub id of the checkpoint this tool was written for; used
    # when the caller does not pass ``model_name``.
    DEFAULT_MODEL = "nguyenkhoa/dinov2_Liveness_detection_v2.2.3"

    def __init__(self, model_name: str = DEFAULT_MODEL):
        """Load the image processor and classification model.

        Args:
            model_name: Hugging Face Hub id (or local path) of an image
                classification checkpoint. Defaults to ``DEFAULT_MODEL``.
        """
        super().__init__()
        # Imported here rather than at module top so that importing this
        # module does not by itself pull in transformers.
        from transformers import AutoImageProcessor, AutoModelForImageClassification

        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModelForImageClassification.from_pretrained(model_name)

    def apply(self, image_path: str) -> str:
        """Classify the image at ``image_path`` and describe the result.

        Args:
            image_path: Path to the image file (anything ``PIL.Image.open``
                accepts).

        Returns:
            ``"Liveness: <label> (Confidence: <p>)"``, where ``<label>`` is
            the top-1 class name from ``model.config.id2label`` and ``<p>`` is
            its softmax probability rounded to 4 decimal places. If any step
            fails, the exception is not raised. The method instead returns
            ``"Error during liveness detection: <message>"``.
        """
        try:
            # The context manager closes the file handle PIL opened.
            # convert() returns an independent in-memory image, so
            # ``image`` remains usable after the ``with`` block exits.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")

            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                logits = self.model(**inputs).logits
            # A single image was passed, so take the first (only) batch row.
            probs = torch.nn.functional.softmax(logits, dim=-1)[0]

            predicted_class_idx = torch.argmax(probs).item()
            predicted_label = self.model.config.id2label[predicted_class_idx]
            confidence = round(probs[predicted_class_idx].item(), 4)
            return f"Liveness: {predicted_label} (Confidence: {confidence})"
        except Exception as e:
            # Deliberate catch-all: this tool reports every failure as text
            # instead of raising (unreadable file, bad image, model error...).
            return f"Error during liveness detection: {str(e)}"