ParentCloseTesting / src /backends /hf_backend.py
Prince-1's picture
Updated the Project
e793c54 verified
Raw
History Blame Contribute Delete
2.3 kB
import numpy as np
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor, ViTFeatureExtractor, ViTForImageClassification
# Module-level caches for the CLIP model/processor pair. They start as None and
# are filled in by _get_clip() the first time an embedding is requested.
_clip_model = None
_clip_processor = None
# Module-level caches for the age classifier and its preprocessor. They start
# as None and are filled in by _get_age_model() on first use.
_age_model = None
_age_extractor = None
# nateraw/vit-age-classifier outputs probabilities over these age range buckets
# NOTE(review): _AGE_LABELS is not referenced by the code visible in this file;
# it documents the class order that _AGE_MIDPOINTS is expected to follow.
_AGE_LABELS = ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"]
# One representative age (in years) per bucket, index-aligned with _AGE_LABELS.
# _estimate_age() takes the dot product of these with the class probabilities
# to obtain a single expected age. The value for the open-ended "70+" bucket
# is a chosen stand-in, not a true midpoint.
_AGE_MIDPOINTS = [1, 6, 14, 24, 34, 44, 54, 64, 74]
def _get_clip():
    """Return the shared CLIP model and processor, loading them on first use.

    The pair is cached in the module-level ``_clip_model`` and
    ``_clip_processor`` globals so the pretrained weights are only loaded
    once per process.

    Returns:
        A ``(model, processor)`` tuple for ``openai/clip-vit-base-patch32``.

    Raises:
        Whatever ``from_pretrained`` raises if either artifact cannot be
        loaded. In that case the cache is left untouched, so a later call
        retries the load.
    """
    global _clip_model, _clip_processor
    if _clip_model is None or _clip_processor is None:
        # Load into locals first and publish both together: if the processor
        # load fails, the cache must not be left with a model but no
        # processor, which would make every later call return (model, None).
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        _clip_model, _clip_processor = model, processor
    return _clip_model, _clip_processor
def _get_age_model():
    """Return the shared age classifier and its preprocessor, loading on first use.

    The pair is cached in the module-level ``_age_model`` and
    ``_age_extractor`` globals so the pretrained weights are only loaded
    once per process.

    Returns:
        A ``(model, extractor)`` tuple for ``nateraw/vit-age-classifier``.

    Raises:
        Whatever ``from_pretrained`` raises if either artifact cannot be
        loaded. In that case the cache is left untouched, so a later call
        retries the load.
    """
    global _age_model, _age_extractor
    if _age_model is None or _age_extractor is None:
        # Load into locals first and publish both together: if the extractor
        # load fails, the cache must not be left with a model but no
        # extractor, which would make every later call return (model, None).
        # NOTE(review): ViTFeatureExtractor is reported as deprecated in favor
        # of ViTImageProcessor in newer transformers releases — confirm
        # against the pinned version before switching.
        model = ViTForImageClassification.from_pretrained("nateraw/vit-age-classifier")
        extractor = ViTFeatureExtractor.from_pretrained("nateraw/vit-age-classifier")
        _age_model, _age_extractor = model, extractor
    return _age_model, _age_extractor
def _embed(image_path):
    """Return the L2-normalized CLIP image embedding for the file at *image_path*.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A 1-D NumPy array holding the unit-length image feature vector for
        the single input image.

    Raises:
        Whatever ``PIL.Image.open`` raises for a missing or unreadable image,
        plus any error from loading the CLIP model.
    """
    model, processor = _get_clip()
    # Open the file in a context manager so the underlying handle is released
    # deterministically; convert() returns an independent in-memory image, so
    # it remains usable after the file is closed.
    with Image.open(image_path) as raw:
        image = raw.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping so .numpy() works directly.
    with torch.no_grad():
        features = model.get_image_features(**inputs)
        # Scale to unit length so a plain dot product between two embeddings
        # equals their cosine similarity.
        features = features / features.norm(dim=-1, keepdim=True)
    # The batch contains exactly one image; return its vector.
    return features[0].numpy()
def _estimate_age(image_path):
    """Estimate the age, in whole years, of the person in the image at *image_path*.

    The classifier's logits are turned into probabilities with a softmax and
    the result is the probability-weighted average of ``_AGE_MIDPOINTS``,
    rounded to the nearest integer.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The expected age as an ``int``.

    Raises:
        Whatever ``PIL.Image.open`` raises for a missing or unreadable image,
        plus any error from loading the age model.
    """
    model, extractor = _get_age_model()
    # Open the file in a context manager so the underlying handle is released
    # deterministically; convert() returns an independent in-memory image, so
    # it remains usable after the file is closed.
    with Image.open(image_path) as raw:
        image = raw.convert("RGB")
    inputs = extractor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping so .numpy() works directly.
    with torch.no_grad():
        outputs = model(**inputs)
        # Softmax over the class axis; [0] selects the single image in the batch.
        probs = torch.softmax(outputs.logits, dim=-1)[0].numpy()
    # Expected value of the bucket representative ages under the predicted
    # distribution, rounded to a whole number of years.
    return int(round(float(np.dot(probs, _AGE_MIDPOINTS))))
def _to_pct(sim):
# CLIP features are already L2-normalized so dot product == cosine similarity (-1 to 1)
return round(max(0.0, min(1.0, (float(sim) + 1) / 2)) * 100, 1)
def analyze(father_path, mother_path, child_path):
    """Compare a child's photo against each parent's and estimate the child's age.

    Args:
        father_path: Path to the father's image.
        mother_path: Path to the mother's image.
        child_path: Path to the child's image.

    Returns:
        A dict with three keys:
        ``"age"`` — the child's estimated age from ``_estimate_age``;
        ``"father_score"`` and ``"mother_score"`` — the similarity between
        the child's embedding and each parent's embedding, expressed as a
        percentage by ``_to_pct``.
    """
    # Embed all three faces first (child, then father, then mother).
    child_vec = _embed(child_path)
    father_vec = _embed(father_path)
    mother_vec = _embed(mother_path)

    # Embeddings are unit-length, so each dot product is a cosine similarity.
    report = {}
    report["age"] = _estimate_age(child_path)
    report["father_score"] = _to_pct(np.dot(child_vec, father_vec))
    report["mother_score"] = _to_pct(np.dot(child_vec, mother_vec))
    return report