Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader | |
| from torchvision import transforms | |
| from PIL import Image | |
| import numpy as np | |
| import gradio as gr | |
| from facenet_pytorch import InceptionResnetV1, MTCNN | |
| from datasets import load_dataset | |
| from tqdm import tqdm | |
| import wikipedia | |
# Celebrity face dataset: the "train" split is used both to build the
# reference embeddings and to show example pictures in the UI.
dataset = load_dataset("tonyassi/celebrity-1000", split="train")

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Preprocessing for images fed straight to the embedding network:
# resize to 160x160, convert to a tensor, then map each channel
# from [0, 1] to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((160, 160)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
class HF_CelebDataset(Dataset):
    """Torch ``Dataset`` adapter over a dataset of ``image``/``label`` rows.

    Each row of the wrapped dataset must be a mapping with an ``'image'``
    entry (an object exposing ``.mode`` and ``.convert()``, e.g. a PIL image)
    and a ``'label'`` entry.
    """

    def __init__(self, hf_dataset, transform=None):
        """Store the wrapped dataset and the optional per-image transform.

        Args:
            hf_dataset: Indexable dataset whose rows have ``'image'`` and
                ``'label'`` keys.
            transform: Optional callable applied to the RGB image.
        """
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is converted to RGB when it is in any other mode and then
        passed through ``self.transform`` if one was supplied.
        """
        # Fetch the row once: indexing the wrapped dataset twice would load
        # (and, for lazily decoded datasets, decode) the same sample twice.
        row = self.dataset[idx]
        img = row['image']
        label = row['label']

        if img.mode != 'RGB':
            img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label
# Batched, in-order iteration over the whole dataset (no shuffling, so the
# position of each embedding matches the dataset row index).
subset_dataset = HF_CelebDataset(dataset, transform)
loader = DataLoader(
    subset_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=2,
)

# Embedding network (VGGFace2 weights, inference mode) and the face detector
# used on user uploads.
facenet = InceptionResnetV1(pretrained='vggface2').eval().to(device)
mtcnn = MTCNN(image_size=160, margin=20, device=device)

# Pass 1: one L2-normalised embedding per dataset image, kept on the CPU.
embeddings, labels = [], []
with torch.no_grad():
    for batch, batch_labels in tqdm(loader):
        batch_emb = F.normalize(facenet(batch.to(device)), p=2, dim=1)
        embeddings.append(batch_emb.cpu())
        labels.extend(batch_labels.numpy())
embeddings = torch.cat(embeddings)
labels = np.array(labels)

# Pass 2: collapse the per-image embeddings into one re-normalised mean
# embedding per celebrity; the three lists below are index-aligned.
unique_labels = np.unique(labels)
_class_names = dataset.features['label'].names
avg_embs = [
    F.normalize(embeddings[np.flatnonzero(labels == lbl)].mean(dim=0), p=2, dim=0)
    for lbl in unique_labels
]
celeb_labels = list(unique_labels)
celeb_names = [_class_names[lbl] for lbl in unique_labels]
celeb_embeddings = torch.stack(avg_embs).to(device)
def find_most_similar(user_img, model, celeb_embeddings, celeb_names, top_k=1):
    """Rank the celebrity gallery by similarity to the face in ``user_img``.

    The image is cropped to the detected face with the module-level ``mtcnn``;
    when no face is detected the whole image is passed through the module-level
    ``transform`` instead. The resulting embedding is L2-normalised and
    compared with every row of ``celeb_embeddings`` by dot product.

    Args:
        user_img: PIL image to look up. Converted to RGB if needed.
        model: Callable mapping an image batch to an embedding batch.
        celeb_embeddings: 2-D tensor with one reference embedding per row.
        celeb_names: Names aligned with the rows of ``celeb_embeddings``.
        top_k: Number of matches wanted; clamped to the gallery size.

    Returns:
        A list of ``(name, score)`` tuples, highest score first.
    """
    # Both the detector path and the 3-channel Normalize in the fallback
    # transform need an RGB image; RGBA / greyscale / palette inputs would
    # otherwise fail. The dataset wrapper applies the same conversion.
    if user_img.mode != 'RGB':
        user_img = user_img.convert('RGB')

    face = mtcnn(user_img)
    if face is None:
        # No face detected: embed the full (resized) image instead.
        face = transform(user_img)
    img_tensor = face.unsqueeze(0).to(device)

    with torch.no_grad():
        user_emb = F.normalize(model(img_tensor), p=2, dim=1)
        similarity = torch.matmul(user_emb, celeb_embeddings.T)

    # torch.topk raises when k exceeds the number of candidates.
    k = min(top_k, similarity.shape[1])
    topk_vals, topk_idx = torch.topk(similarity, k, dim=1)

    return [
        (celeb_names[idx], score)
        for idx, score in zip(topk_idx[0].tolist(), topk_vals[0].tolist())
    ]
def get_wikipedia_info(name):
    """Return a three-sentence Wikipedia summary for ``name``.

    Never raises: every failure is turned into a human-readable message so
    the caller can display the result as-is.
    """
    try:
        text = wikipedia.summary(
            name,
            sentences=3,
            auto_suggest=False,
            redirect=True,
        )
    except wikipedia.exceptions.DisambiguationError as ambiguous:
        # Several pages share this title: show the first few candidates.
        candidates = ambiguous.options[:3]
        return f"Multiple results found for {name}: {candidates}"
    except wikipedia.exceptions.PageError:
        return f"No information found for {name}."
    except Exception as err:
        # Any other failure is reported as text rather than propagated.
        return f"Error fetching info: {err}"
    else:
        return text
def gradio_find(user_img):
    """Gradio callback: find the closest celebrity and describe them.

    Args:
        user_img: PIL image from the upload widget, or ``None`` when the
            button is pressed before an image has been provided.

    Returns:
        A ``(gallery_items, wiki_text)`` tuple. ``gallery_items`` is a list of
        up to six ``(image, caption)`` pairs showing the matched celebrity at
        128x128, and ``wiki_text`` is the text from ``get_wikipedia_info``.
        With no input image the gallery is empty and the text asks for an
        upload.
    """
    # An empty image component delivers None; without this guard the face
    # detector would crash on it.
    if user_img is None:
        return [], "Please upload an image first."

    name, _score = find_most_similar(
        user_img, facenet, celeb_embeddings, celeb_names, top_k=1
    )[0]

    # Map the winning name back to its class label, then to the dataset rows
    # carrying that label.
    lbl = celeb_labels[celeb_names.index(name)]
    idxs = np.where(labels == lbl)[0]

    result_imgs = [
        (dataset[int(i)]['image'].resize((128, 128)), name)
        for i in idxs[:6]
    ]
    return result_imgs, get_wikipedia_info(name)
# UI definition: an upload widget and a result gallery, a text box for the
# Wikipedia summary, and a button that runs the lookup.
# NOTE(review): the original indentation was lost, so which widgets sit
# inside the Row is reconstructed here - confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## Which Celebrity Do You Look Like?")

    with gr.Row():
        img_input = gr.Image(type="pil", label="Upload your face")
        gallery = gr.Gallery(label="Similar celebrity", columns=3)

    info_box = gr.Textbox(label="Wikipedia Info", lines=10)
    btn = gr.Button("Find Similar")

    # gradio_find returns (gallery items, summary text), in that order.
    btn.click(gradio_find, inputs=img_input, outputs=[gallery, info_box])

# Shut down any Gradio servers left over from earlier runs, then start this
# app without blocking the main thread.
gr.close_all()
demo.launch(prevent_thread_lock=True, share=True, inbrowser=True)

# NOTE(review): when executed as a script this is a second launch() call on
# an app that is already running - presumably kept to block the main thread;
# confirm whether a single blocking launch was intended.
if __name__ == "__main__":
    demo.launch()