thanhcong2001 commited on
Commit
d464b17
·
verified ·
1 Parent(s): c6edf72

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -0
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import numpy as np
7
+ import gradio as gr
8
+ from facenet_pytorch import InceptionResnetV1, MTCNN
9
+ from datasets import load_dataset
10
+ from tqdm import tqdm
11
+ import nest_asyncio
12
+ import wikipedia
13
+ dataset = load_dataset("tonyassi/celebrity-1000", split="train")
14
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
+
16
+ transform = transforms.Compose([
17
+ transforms.Resize((160, 160)),
18
+ transforms.ToTensor(),
19
+ transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])
20
+ ])
21
+ class HF_CelebDataset(Dataset):
22
+ def __init__(self, hf_dataset, transform=None):
23
+ self.dataset = hf_dataset
24
+ self.transform = transform
25
+
26
+ def __len__(self):
27
+ return len(self.dataset)
28
+
29
+ def __getitem__(self, idx):
30
+ img = self.dataset[idx]['image']
31
+ label = self.dataset[idx]['label']
32
+ if img.mode != 'RGB':
33
+ img = img.convert('RGB')
34
+ if self.transform:
35
+ img = self.transform(img)
36
+ return img, label
37
+
38
+ subset_dataset = HF_CelebDataset(dataset, transform)
39
+ loader = DataLoader(subset_dataset, batch_size=64, shuffle=False, num_workers=2)
40
+ facenet = InceptionResnetV1(pretrained='vggface2').eval().to(device)
41
+ mtcnn = MTCNN(image_size=160, margin=20, device=device)
42
+ embeddings, labels = [], []
43
+ with torch.no_grad():
44
+ for imgs, lbls in tqdm(loader):
45
+ imgs = imgs.to(device)
46
+ emb = facenet(imgs)
47
+ emb = F.normalize(emb, p=2, dim=1)
48
+ embeddings.append(emb.cpu())
49
+ labels.extend(lbls.numpy())
50
+
51
+ embeddings = torch.cat(embeddings)
52
+ labels = np.array(labels)
53
+ unique_labels = np.unique(labels)
54
+ avg_embs, celeb_labels, celeb_names = [], [], []
55
+ for lbl in unique_labels:
56
+ idxs = np.where(labels == lbl)[0]
57
+ mean_emb = embeddings[idxs].mean(dim=0)
58
+ mean_emb = F.normalize(mean_emb, p=2, dim=0)
59
+ avg_embs.append(mean_emb)
60
+ celeb_labels.append(lbl)
61
+ celeb_names.append(dataset.features['label'].names[lbl])
62
+
63
+ celeb_embeddings = torch.stack(avg_embs).to(device)
64
+ def find_most_similar(user_img, model, celeb_embeddings, celeb_names, top_k=1):
65
+ face = mtcnn(user_img)
66
+ if face is None:
67
+ face = transform(user_img)
68
+ img_tensor = face.unsqueeze(0).to(device)
69
+ with torch.no_grad():
70
+ user_emb = model(img_tensor)
71
+ user_emb = F.normalize(user_emb, p=2, dim=1)
72
+ similarity = torch.matmul(user_emb, celeb_embeddings.T)
73
+ topk_vals, topk_idx = torch.topk(similarity, top_k, dim=1)
74
+ top_labels = [celeb_names[i] for i in topk_idx[0]]
75
+ top_scores = [float(topk_vals[0][i]) for i in range(top_k)]
76
+ return list(zip(top_labels, top_scores))
77
+ def get_wikipedia_info(name):
78
+ try:
79
+ summary = wikipedia.summary(name, sentences=3, auto_suggest=False, redirect=True)
80
+ return summary
81
+ except wikipedia.exceptions.DisambiguationError as e:
82
+ return f"Multiple results found for {name}: {e.options[:3]}"
83
+ except wikipedia.exceptions.PageError:
84
+ return f"No information found for {name}."
85
+ except Exception as e:
86
+ return f"Error fetching info: {e}"
87
+ def gradio_find(user_img):
88
+ top_result = find_most_similar(user_img, facenet, celeb_embeddings, celeb_names, top_k=1)[0]
89
+ name, score = top_result
90
+ lbl = celeb_labels[celeb_names.index(name)]
91
+ idxs = np.where(labels == lbl)[0]
92
+
93
+ result_imgs = []
94
+ for i in idxs[:6]:
95
+ img = dataset[int(i)]['image'].resize((128, 128))
96
+ caption = f"{name} (sim={score:.3f})"
97
+ result_imgs.append((img, caption))
98
+ wiki_text = get_wikipedia_info(name)
99
+ return result_imgs, wiki_text
100
+ nest_asyncio.apply()
101
+ with gr.Blocks() as demo:
102
+ gr.Markdown("## Which Celebrity Do You Look Like?")
103
+ with gr.Row():
104
+ img_input = gr.Image(type="pil", label="Upload your face")
105
+ gallery = gr.Gallery(label="Top 3 similar celebrities", columns=3)
106
+ info_box = gr.Textbox(label="Wikipedia Info", lines=10)
107
+ btn = gr.Button("Find Similar")
108
+ btn.click(fn=gradio_find, inputs=img_input, outputs=[gallery, info_box])
109
+
110
+ gr.close_all()
111
+ demo.launch(share=True, inbrowser=True, prevent_thread_lock=True)