Pytorch / app.py
thanhcong2001's picture
Update app.py
bd0098a verified
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import gradio as gr
from facenet_pytorch import InceptionResnetV1, MTCNN
from datasets import load_dataset
from tqdm import tqdm
import wikipedia
# Training split of the celebrity image dataset from the Hugging Face Hub.
# Rows are read below through their 'image' and 'label' entries.
dataset = load_dataset("tonyassi/celebrity-1000", split="train")

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Preprocessing applied to every dataset image (and to user images when face
# detection finds nothing): resize to 160x160, convert to a tensor, then
# normalise each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((160, 160)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
class HF_CelebDataset(Dataset):
    """Expose an indexable collection of image rows as a PyTorch ``Dataset``.

    Every row of the wrapped collection must be a mapping with an ``'image'``
    entry (a PIL-style image offering ``mode`` and ``convert``) and a
    ``'label'`` entry.
    """

    def __init__(self, hf_dataset, transform=None):
        """Remember the wrapped dataset and the optional image transform.

        Args:
            hf_dataset: Sized, integer-indexable collection of rows.
            transform: Optional callable applied to each RGB image before it
                is returned; skipped when falsy.
        """
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is converted to RGB when it is in any other mode and then
        passed through ``self.transform`` if one was given.
        """
        # Index the wrapped dataset exactly once. Each lookup builds the whole
        # row (for a Hugging Face dataset that includes decoding the image),
        # so reading 'image' and 'label' through two separate lookups doubles
        # the per-sample cost for no benefit.
        row = self.dataset[idx]
        img = row['image']
        label = row['label']
        if img.mode != 'RGB':
            img = img.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label
# Wrap the Hub dataset and iterate it in its stored order (shuffle=False) so
# that row i of `embeddings` / `labels` lines up with `dataset[i]`.
subset_dataset = HF_CelebDataset(dataset, transform)
loader = DataLoader(
    subset_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=2,
)

# Embedding network (eval mode, pretrained on VGGFace2) and the MTCNN face
# detector that is applied to user uploads.
facenet = InceptionResnetV1(pretrained='vggface2').eval().to(device)
mtcnn = MTCNN(image_size=160, margin=20, device=device)

# Embed every dataset image. Each embedding is L2-normalised so that a dot
# product between two of them is their cosine similarity.
embeddings, labels = [], []
with torch.no_grad():
    for imgs, lbls in tqdm(loader):
        emb = F.normalize(facenet(imgs.to(device)), p=2, dim=1)
        embeddings.append(emb.cpu())
        labels.extend(lbls.numpy())
embeddings = torch.cat(embeddings)
labels = np.array(labels)

# Build one prototype per label: the re-normalised mean of all embeddings that
# carry that label. The three lists below stay index-aligned.
unique_labels = np.unique(labels)
avg_embs, celeb_labels, celeb_names = [], [], []
for lbl in unique_labels:
    idxs = np.where(labels == lbl)[0]
    mean_emb = F.normalize(embeddings[idxs].mean(dim=0), p=2, dim=0)
    avg_embs.append(mean_emb)
    celeb_labels.append(lbl)
    celeb_names.append(dataset.features['label'].names[lbl])
celeb_embeddings = torch.stack(avg_embs).to(device)
def find_most_similar(user_img, model, celeb_embeddings, celeb_names, top_k=1):
    """Rank celebrities by similarity of their embedding to the user's face.

    Args:
        user_img: Image of the user (a PIL image in this app). PIL images in
            a mode other than RGB are converted to RGB first.
        model: Callable embedding network; receives a batch holding the single
            preprocessed face and returns one embedding row.
        celeb_embeddings: 2-D tensor with one embedding per row, on the same
            device as the model output.
        celeb_names: Sequence of names aligned with the rows of
            ``celeb_embeddings``.
        top_k: Maximum number of matches to return. Values larger than the
            number of rows in ``celeb_embeddings`` are clamped.

    Returns:
        A list of ``(name, score)`` tuples ordered best match first, where
        ``score`` is the dot product between the L2-normalised user embedding
        and the celebrity embedding, as a Python float.
    """
    # Mirror the RGB guarantee the dataset wrapper gives its images: an RGBA,
    # greyscale or palette upload would otherwise reach the 3-channel
    # Normalize in the fallback `transform` with the wrong channel count.
    # getattr keeps non-PIL inputs (no `mode` attribute) working as before.
    if getattr(user_img, 'mode', 'RGB') != 'RGB':
        user_img = user_img.convert('RGB')

    face = mtcnn(user_img)
    if face is None:
        # The detector produced no face crop; embed the whole image instead.
        face = transform(user_img)
    img_tensor = face.unsqueeze(0).to(device)

    with torch.no_grad():
        user_emb = F.normalize(model(img_tensor), p=2, dim=1)
        similarity = torch.matmul(user_emb, celeb_embeddings.T)

    # torch.topk raises if k exceeds the size of the searched dimension, so
    # never ask for more matches than there are celebrities.
    k = min(top_k, similarity.shape[1])
    topk_vals, topk_idx = torch.topk(similarity, k, dim=1)

    # .tolist() yields plain ints/floats, so list indexing and the returned
    # scores do not depend on 0-dim tensor conversions.
    top_labels = [celeb_names[i] for i in topk_idx[0].tolist()]
    top_scores = topk_vals[0].tolist()
    return list(zip(top_labels, top_scores))
def get_wikipedia_info(name):
    """Return a three-sentence Wikipedia summary for ``name``.

    Lookup problems are reported in the returned text instead of being
    raised: an ambiguous title lists the first three candidate options, a
    missing page yields a "no information" message, and any other
    ``Exception`` is rendered into an error message.
    """
    wiki_errors = wikipedia.exceptions
    try:
        return wikipedia.summary(
            name,
            sentences=3,
            auto_suggest=False,
            redirect=True,
        )
    except wiki_errors.DisambiguationError as e:
        return f"Multiple results found for {name}: {e.options[:3]}"
    except wiki_errors.PageError:
        return f"No information found for {name}."
    except Exception as e:
        # Best-effort UI helper: show the failure rather than crash the app.
        return f"Error fetching info: {e}"
def gradio_find(user_img):
    """Gradio callback: find the closest celebrity and gather display data.

    Args:
        user_img: The uploaded image, or ``None`` when the button is pressed
            before anything has been uploaded.

    Returns:
        A ``(gallery_items, wiki_text)`` tuple. ``gallery_items`` is a list of
        up to six ``(image, caption)`` pairs showing dataset photos of the
        matched celebrity, resized to 128x128 and captioned with the name.
        ``wiki_text`` is the text produced by ``get_wikipedia_info``. When no
        image was supplied the gallery is empty and the text asks for one.
    """
    # Without this guard a missing upload propagates into face detection and
    # surfaces as an opaque error in the UI.
    if user_img is None:
        return [], "Please upload an image first."

    # Only the best match is used; its similarity score is not displayed.
    name, _score = find_most_similar(
        user_img, facenet, celeb_embeddings, celeb_names, top_k=1
    )[0]

    # Map the name back to its label, then to the dataset rows with that label.
    lbl = celeb_labels[celeb_names.index(name)]
    idxs = np.where(labels == lbl)[0]
    result_imgs = [
        (dataset[int(i)]['image'].resize((128, 128)), name)
        for i in idxs[:6]
    ]

    wiki_text = get_wikipedia_info(name)
    return result_imgs, wiki_text
# Page layout: upload widget and result gallery side by side, the Wikipedia
# text underneath, and a button that runs the lookup.
# NOTE(review): the nesting under gr.Row() is reconstructed (image + gallery
# in the row); confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## Which Celebrity Do You Look Like?")
    with gr.Row():
        img_input = gr.Image(type="pil", label="Upload your face")
        gallery = gr.Gallery(label="Similar celebrity", columns=3)
    info_box = gr.Textbox(label="Wikipedia Info", lines=10)
    btn = gr.Button("Find Similar")
    btn.click(fn=gradio_find, inputs=img_input, outputs=[gallery, info_box])

# Start the server exactly once, and only when executed as a script. The
# previous code launched unconditionally at import time (non-blocking, with a
# share link) and then called launch() a second time under this guard.
if __name__ == "__main__":
    # Shut down any Gradio servers left over from earlier runs in this process.
    gr.close_all()
    # share=True requests a public link; inbrowser=True opens a browser tab.
    demo.launch(share=True, inbrowser=True)