Spaces:
Running
Running
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from transformers import CLIPModel, CLIPProcessor, ViTFeatureExtractor, ViTForImageClassification | |
# Lazily populated model caches. They stay None until the loader helpers
# below fill them on first use.
_clip_model = _clip_processor = None
_age_model = _age_extractor = None

# nateraw/vit-age-classifier outputs probabilities over these age range buckets.
# The order is assumed to match the model's label order (TODO confirm against
# the model's id2label). _AGE_MIDPOINTS holds one representative age per
# bucket, index-aligned with _AGE_LABELS.
_AGE_LABELS = [
    "0-2", "3-9", "10-19", "20-29", "30-39",
    "40-49", "50-59", "60-69", "70+",
]
_AGE_MIDPOINTS = [1, 6, 14, 24, 34, 44, 54, 64, 74]
def _get_clip():
    """Return the cached ``(model, processor)`` pair for CLIP ViT-B/32.

    Both objects are loaded with ``from_pretrained`` on the first call and
    stored in module globals, so later calls reuse them.

    The two objects are loaded into locals first and only then published
    together. A failure while loading the processor therefore cannot leave a
    half-initialised cache (model set, processor still None) that later calls
    would silently return.

    NOTE(review): no lock guards the load, so concurrent first calls may each
    load the checkpoint.
    """
    global _clip_model, _clip_processor
    if _clip_model is None or _clip_processor is None:
        checkpoint = "openai/clip-vit-base-patch32"
        model = CLIPModel.from_pretrained(checkpoint)
        processor = CLIPProcessor.from_pretrained(checkpoint)
        # Publish both only after both loaded successfully.
        _clip_model, _clip_processor = model, processor
    return _clip_model, _clip_processor
def _get_age_model():
    """Return the cached ``(model, extractor)`` pair for the ViT age classifier.

    Both objects are loaded from ``nateraw/vit-age-classifier`` on the first
    call and stored in module globals, so later calls reuse them.

    The two objects are loaded into locals first and only then published
    together. A failure while loading the extractor therefore cannot leave a
    half-initialised cache (model set, extractor still None).

    NOTE(review): newer transformers releases deprecate ViTFeatureExtractor
    in favour of ViTImageProcessor. Consider migrating.
    """
    global _age_model, _age_extractor
    if _age_model is None or _age_extractor is None:
        checkpoint = "nateraw/vit-age-classifier"
        model = ViTForImageClassification.from_pretrained(checkpoint)
        extractor = ViTFeatureExtractor.from_pretrained(checkpoint)
        # Publish both only after both loaded successfully.
        _age_model, _age_extractor = model, extractor
    return _age_model, _age_extractor
def _embed(image_path):
    """Return the unit-length CLIP image embedding of the image at *image_path*.

    Args:
        image_path: anything accepted by ``PIL.Image.open`` (typically a path).

    Returns:
        The first (only) row of the CLIP image features as a numpy array,
        divided by its L2 norm. The dot product of two such vectors is
        therefore their cosine similarity.
    """
    model, processor = _get_clip()
    # Use a context manager so the file handle is released promptly (Pillow
    # keeps it open for multi-frame formats). convert() returns an independent
    # copy, so `image` remains valid after the source image is closed.
    with Image.open(image_path) as raw:
        image = raw.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        features = model.get_image_features(**inputs)
    # Scale each feature row to unit L2 norm.
    features = features / features.norm(dim=-1, keepdim=True)
    return features[0].numpy()
def _estimate_age(image_path):
    """Estimate the age of the person in the image at *image_path*.

    The classifier's logits (assumed to be ordered like ``_AGE_LABELS``) are
    turned into probabilities with a softmax. The function returns the
    probability-weighted average of ``_AGE_MIDPOINTS``, rounded with Python's
    ``round()`` and returned as an ``int``.

    Args:
        image_path: anything accepted by ``PIL.Image.open`` (typically a path).
    """
    model, extractor = _get_age_model()
    # Use a context manager so the file handle is released promptly;
    # convert() returns an independent copy that outlives the source image.
    with Image.open(image_path) as raw:
        image = raw.convert("RGB")
    inputs = extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=-1)[0].numpy()
    # Expected age over bucket midpoints. For example, an even split between
    # "20-29" (24) and "30-39" (34) gives 29.
    return int(round(float(np.dot(probs, _AGE_MIDPOINTS))))
def _to_pct(sim):
    """Map a cosine similarity in [-1, 1] to a percentage in [0, 100].

    The similarity is shifted and halved into [0, 1], clamped to that range,
    scaled by 100 and rounded to one decimal place. The inputs are expected
    to be dot products of L2-normalised vectors, i.e. cosine similarities.
    """
    unit = (float(sim) + 1) / 2
    # Clamp guards against tiny numerical overshoot outside [-1, 1].
    bounded = max(0.0, min(1.0, unit))
    return round(bounded * 100, 1)
def analyze(father_path, mother_path, child_path):
    """Compare a child's photo with each parent's photo and estimate the child's age.

    Args:
        father_path: image of the father (anything ``_embed`` accepts).
        mother_path: image of the mother.
        child_path: image of the child.

    Returns:
        dict with:
          - ``"age"``: integer age estimate for the child.
          - ``"father_score"`` / ``"mother_score"``: child-to-parent CLIP
            similarity mapped to a 0-100 percentage by ``_to_pct``.
    """
    child_vec = _embed(child_path)
    father_vec = _embed(father_path)
    mother_vec = _embed(mother_path)

    estimated_age = _estimate_age(child_path)
    # Embeddings are unit-length, so the dot product is the cosine similarity.
    father_pct = _to_pct(np.dot(child_vec, father_vec))
    mother_pct = _to_pct(np.dot(child_vec, mother_vec))

    return {
        "age": estimated_age,
        "father_score": father_pct,
        "mother_score": mother_pct,
    }