"""Estimate a child's age and visual resemblance to each parent from photos.

Resemblance uses CLIP image embeddings (openai/clip-vit-base-patch32); age uses
nateraw/vit-age-classifier. Both models are loaded lazily on first use and
cached in module-level globals.
"""

import numpy as np
import torch
from PIL import Image
from transformers import (
    CLIPModel,
    CLIPProcessor,
    ViTFeatureExtractor,
    ViTForImageClassification,
)

_CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"
_AGE_MODEL_NAME = "nateraw/vit-age-classifier"

# Lazily-populated caches; filled by _get_clip / _get_age_model on first use.
_clip_model = None
_clip_processor = None
_age_model = None
_age_extractor = None

# nateraw/vit-age-classifier outputs probabilities over these age range buckets.
# NOTE(review): the order is assumed to match the model's config.id2label — confirm.
_AGE_LABELS = ["0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+"]
# Representative age for each bucket above (same order, same length); used to
# turn the bucket probabilities into a single expected age.
_AGE_MIDPOINTS = [1, 6, 14, 24, 34, 44, 54, 64, 74]


def _get_clip():
    """Return the cached (CLIPModel, CLIPProcessor) pair, loading it on first call."""
    global _clip_model, _clip_processor
    if _clip_model is None or _clip_processor is None:
        # Load into locals first so a failure in either call cannot leave the
        # cache half-populated (model set, processor still None).
        model = CLIPModel.from_pretrained(_CLIP_MODEL_NAME)
        processor = CLIPProcessor.from_pretrained(_CLIP_MODEL_NAME)
        _clip_model, _clip_processor = model, processor
    return _clip_model, _clip_processor


def _get_age_model():
    """Return the cached (ViTForImageClassification, ViTFeatureExtractor) pair, loading it on first call."""
    global _age_model, _age_extractor
    if _age_model is None or _age_extractor is None:
        # Same all-or-nothing caching as _get_clip.
        model = ViTForImageClassification.from_pretrained(_AGE_MODEL_NAME)
        extractor = ViTFeatureExtractor.from_pretrained(_AGE_MODEL_NAME)
        _age_model, _age_extractor = model, extractor
    return _age_model, _age_extractor


def _load_rgb(image_path):
    """Open *image_path*, return an RGB copy, and close the underlying file."""
    with Image.open(image_path) as img:
        # convert() returns a new image, so the result stays usable after the
        # file handle is closed by the with-block.
        return img.convert("RGB")


def _embed_image(image):
    """Return the L2-normalized CLIP embedding of an already-loaded RGB image as a 1-D numpy array."""
    model, processor = _get_clip()
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        features = model.get_image_features(**inputs)
    # Normalize so a plain dot product between two embeddings is their cosine similarity.
    features = features / features.norm(dim=-1, keepdim=True)
    return features[0].numpy()


def _embed(image_path):
    """Return the L2-normalized CLIP embedding of the image at *image_path*."""
    return _embed_image(_load_rgb(image_path))


def _estimate_age_from_image(image):
    """Return the expected age (rounded int) of the face in an already-loaded RGB image."""
    model, extractor = _get_age_model()
    inputs = extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1)[0].numpy()
    # Probability-weighted sum of bucket midpoints = expected age over the buckets.
    return int(round(float(np.dot(probs, _AGE_MIDPOINTS))))


def _estimate_age(image_path):
    """Return the expected age (rounded int) of the face in the image at *image_path*."""
    return _estimate_age_from_image(_load_rgb(image_path))


def _to_pct(sim):
    """Map a cosine similarity in [-1, 1] linearly onto a percentage in [0, 100], rounded to 0.1."""
    # _embed L2-normalizes its output, so the dot product of two embeddings is a
    # cosine similarity; clamp guards against tiny float overshoot outside [-1, 1].
    return round(max(0.0, min(1.0, (float(sim) + 1) / 2)) * 100, 1)


def analyze(father_path, mother_path, child_path):
    """Compare a child's photo against both parents' photos.

    Args:
        father_path: Path to the father's image.
        mother_path: Path to the mother's image.
        child_path: Path to the child's image.

    Returns:
        dict with keys:
            "age": estimated age of the child (int).
            "father_score": child/father CLIP similarity as a 0-100 percentage.
            "mother_score": child/mother CLIP similarity as a 0-100 percentage.
    """
    # Decode the child's photo once; it feeds both the embedding and the age model.
    child_image = _load_rgb(child_path)
    child_emb = _embed_image(child_image)
    father_emb = _embed(father_path)
    mother_emb = _embed(mother_path)
    return {
        "age": _estimate_age_from_image(child_image),
        "father_score": _to_pct(np.dot(child_emb, father_emb)),
        "mother_score": _to_pct(np.dot(child_emb, mother_emb)),
    }