"""Celebrity look-alike demo.

Embeds the ``tonyassi/celebrity-1000`` image dataset with FaceNet
(InceptionResnetV1 pretrained on VGGFace2), averages the embeddings per
identity into one prototype vector each, then serves a Gradio app that
matches an uploaded face to the most similar celebrity and shows a short
Wikipedia summary for the match.
"""

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import gradio as gr
from facenet_pytorch import InceptionResnetV1, MTCNN
from datasets import load_dataset
from tqdm import tqdm
import wikipedia

dataset = load_dataset("tonyassi/celebrity-1000", split="train")
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# FaceNet expects 160x160 RGB inputs normalized to roughly [-1, 1].
transform = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


class HF_CelebDataset(Dataset):
    """Wrap a HuggingFace image dataset so it can feed a PyTorch DataLoader.

    Each item is an ``(image_tensor, int_label)`` pair; grayscale/palette
    images are converted to RGB before the transform is applied.
    """

    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Fetch the row once instead of indexing the dataset twice.
        sample = self.dataset[idx]
        img = sample['image']
        label = sample['label']
        if img.mode != 'RGB':
            img = img.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label


subset_dataset = HF_CelebDataset(dataset, transform)
loader = DataLoader(subset_dataset, batch_size=64, shuffle=False, num_workers=2)

facenet = InceptionResnetV1(pretrained='vggface2').eval().to(device)
mtcnn = MTCNN(image_size=160, margin=20, device=device)

# Compute an L2-normalized embedding for every image in the dataset.
embeddings, labels = [], []
with torch.no_grad():
    for imgs, lbls in tqdm(loader):
        imgs = imgs.to(device)
        emb = F.normalize(facenet(imgs), p=2, dim=1)
        embeddings.append(emb.cpu())
        labels.extend(lbls.numpy())

embeddings = torch.cat(embeddings)
labels = np.array(labels)

# Average each identity's embeddings into a single unit-norm prototype.
unique_labels = np.unique(labels)
avg_embs, celeb_labels, celeb_names = [], [], []
for lbl in unique_labels:
    idxs = np.where(labels == lbl)[0]
    mean_emb = F.normalize(embeddings[idxs].mean(dim=0), p=2, dim=0)
    avg_embs.append(mean_emb)
    celeb_labels.append(lbl)
    celeb_names.append(dataset.features['label'].names[lbl])

celeb_embeddings = torch.stack(avg_embs).to(device)


def find_most_similar(user_img, model, celeb_embeddings, celeb_names, top_k=1):
    """Return the ``top_k`` best ``(name, cosine_similarity)`` matches.

    Detects and crops the face with MTCNN; if no face is found, embeds the
    whole (resized) image instead so the demo never hard-fails.
    """
    face = mtcnn(user_img)  # returns a cropped, normalized tensor or None
    if face is None:
        # No face detected — fall back to embedding the full image.
        face = transform(user_img)
    img_tensor = face.unsqueeze(0).to(device)
    with torch.no_grad():
        user_emb = F.normalize(model(img_tensor), p=2, dim=1)
    # On unit vectors, cosine similarity reduces to a plain dot product.
    similarity = torch.matmul(user_emb, celeb_embeddings.T)
    topk_vals, topk_idx = torch.topk(similarity, top_k, dim=1)
    top_labels = [celeb_names[i] for i in topk_idx[0]]
    top_scores = [float(topk_vals[0][i]) for i in range(top_k)]
    return list(zip(top_labels, top_scores))


def get_wikipedia_info(name):
    """Fetch a 3-sentence Wikipedia summary for ``name``.

    Returns a human-readable error string instead of raising, so the
    Gradio callback always gets displayable text.
    """
    try:
        return wikipedia.summary(name, sentences=3, auto_suggest=False, redirect=True)
    except wikipedia.exceptions.DisambiguationError as e:
        return f"Multiple results found for {name}: {e.options[:3]}"
    except wikipedia.exceptions.PageError:
        return f"No information found for {name}."
    except Exception as e:  # network errors etc. — keep the demo alive
        return f"Error fetching info: {e}"


def gradio_find(user_img):
    """Gradio callback: (gallery of look-alike images, Wikipedia text)."""
    # Guard against the button being clicked before an image is uploaded;
    # the original crashed with a TypeError inside MTCNN in that case.
    if user_img is None:
        return [], "Please upload an image first."
    name, _score = find_most_similar(
        user_img, facenet, celeb_embeddings, celeb_names, top_k=1
    )[0]
    lbl = celeb_labels[celeb_names.index(name)]
    idxs = np.where(labels == lbl)[0]
    # Show up to 6 dataset images of the matched celebrity as thumbnails.
    result_imgs = [
        (dataset[int(i)]['image'].resize((128, 128)), f"{name}")
        for i in idxs[:6]
    ]
    wiki_text = get_wikipedia_info(name)
    return result_imgs, wiki_text


with gr.Blocks() as demo:
    gr.Markdown("## Which Celebrity Do You Look Like?")
    with gr.Row():
        img_input = gr.Image(type="pil", label="Upload your face")
        gallery = gr.Gallery(label="Similar celebrity", columns=3)
    info_box = gr.Textbox(label="Wikipedia Info", lines=10)
    btn = gr.Button("Find Similar")
    btn.click(fn=gradio_find, inputs=img_input, outputs=[gallery, info_box])

if __name__ == "__main__":
    # BUG FIX: the original called demo.launch() twice — once at import time
    # (with prevent_thread_lock=True) and again under __main__, starting two
    # servers. Launch exactly once, guarded by __main__.
    gr.close_all()
    demo.launch(share=True, inbrowser=True)