Spaces:
Sleeping
Sleeping
| from transformers import ViTModel, ViTImageProcessor | |
| from PIL import Image, ImageOps | |
| import gradio as gr | |
| import torch | |
| from datasets import Dataset | |
| from torch.nn import CosineSimilarity | |
# One processor feeds both encoders: load_candidates and scribble_matching
# each pass their images through it before calling their encoder.
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")

# Encoder applied to candidate images (see load_candidates).
image_encoder = ViTModel.from_pretrained("model/image_encoder/epoch_29")
image_encoder.eval()

# Encoder applied to the user's scribble (see scribble_matching).
scribble_encoder = ViTModel.from_pretrained("model/scibble_encoder/epoch_29")
scribble_encoder.eval()

cosinesimilarity = CosineSimilarity()

# Embedded candidate set; stays None until load_candidates_in_cache runs.
candidates: Dataset = None
def load_candidates(candidate_dir, progress=gr.Progress()):
    """Load the uploaded candidate images and attach an embedding to each.

    Args:
        candidate_dir: Sequence of uploaded file objects; each one must expose
            the on-disk path of its file as ``.name``.
        progress: Gradio progress tracker. The ``gr.Progress()`` default is
            deliberate: it is the documented way to ask Gradio to supply a
            tracker for the running event.

    Returns:
        A ``Dataset`` with an ``image`` column (RGB, resized to 224x224) and
        an ``image_embedding`` column holding the ``pooler_output`` of
        ``image_encoder`` for that image.

    Raises:
        gr.Error: If ``candidate_dir`` is ``None`` or empty.
    """
    if not candidate_dir:
        raise gr.Error("No candidate images to load.")

    def read_image(path):
        # convert() returns an independent in-memory copy, so the file can be
        # closed right away. The context manager guarantees that even for
        # multi-frame formats, whose handle Pillow otherwise keeps open.
        with Image.open(path) as source:
            return source.convert("RGB").resize((224, 224))

    def preprocess(examples):
        # Called by Dataset.map with a batch laid out as a dict of columns.
        images = list(examples["image"])
        pixel_values = image_processor(images, return_tensors="pt")["pixel_values"]
        examples["image_embedding"] = image_encoder(pixel_values)["pooler_output"]
        progress.update(len(images))
        return examples

    rows = [
        dict(image=read_image(upload.name))
        for upload in progress.tqdm(candidate_dir)
    ]
    dataset = Dataset.from_list(rows)

    # The return value is intentionally discarded: presumably this registers
    # the dataset with the tracker so the progress.update() calls made from
    # preprocess advance it -- confirm against the installed Gradio version.
    progress.tqdm(dataset)

    # Inference only; no autograd graph is needed for the embeddings.
    with torch.no_grad():
        dataset = dataset.map(preprocess, batched=True, batch_size=1)
    return dataset
def load_candidates_in_cache(candidate_files):
    """Embed the uploaded files and keep the result in the module cache.

    The dataset returned by ``load_candidates`` is stored in the module-level
    ``candidates`` variable, replacing whatever was there before.

    Args:
        candidate_files: Uploaded file objects, each exposing ``.name``.

    Returns:
        The ``.name`` of every uploaded file, in upload order.
    """
    global candidates
    candidates = load_candidates(candidate_files)

    uploaded_names = []
    for uploaded in candidate_files:
        uploaded_names.append(uploaded.name)
    return uploaded_names
def scribble_matching(input_img: Image.Image):
    """Rank the cached candidate images by similarity to a scribble.

    Args:
        input_img: The scribble as a PIL image. It is colour-inverted before
            being encoded.

    Returns:
        A list of ``(image, caption)`` pairs: first the inverted scribble
        captioned ``"preview"``, then at most 15 candidate images in
        descending cosine similarity, each captioned with its score formatted
        to three decimals.

    Raises:
        gr.Error: If no scribble was supplied or no candidates are loaded.
    """
    if input_img is None:
        raise gr.Error("Draw a scribble before matching.")
    if candidates is None or len(candidates) == 0:
        raise gr.Error("Load candidate images before matching.")

    input_img = ImageOps.invert(input_img)

    # Inference only, matching how the candidate embeddings are computed.
    with torch.no_grad():
        pixel_values = image_processor(input_img, return_tensors="pt")["pixel_values"]
        scribble_embedding = scribble_encoder(pixel_values)["pooler_output"].to("cpu")

    image_embeddings = torch.tensor(candidates["image_embedding"], dtype=torch.float32)
    sim = cosinesimilarity(scribble_embedding, image_embeddings)

    # torch.topk raises when k exceeds the size of the dimension it searches
    # (the last one by default), so cap k at the number of scores available.
    predicts = torch.topk(sim, k=min(15, sim.shape[-1]))

    output_imgs = candidates[predicts.indices.tolist()]["image"]
    labels = [f"{score:.3f}" for score in predicts.values.tolist()]
    return list(zip([input_img] + output_imgs, ["preview"] + labels))
def main():
    """Assemble the Gradio Blocks interface, wire both buttons and launch it."""
    # NOTE(review): the nesting below was reconstructed from a listing whose
    # indentation had been lost -- confirm which components share a row.
    with gr.Blocks() as app:
        with gr.Row():
            sketch_pad = gr.Image(
                type="pil",
                label="scribble",
                height=512,
                width=512,
                source="canvas",
                tool="color-sketch",
                brush_radius=10,
            )
            results_gallery = gr.Gallery(
                min_width=512,
                columns=4,
                show_label=True,
            )

        with gr.Row():
            upload_box = gr.File(
                file_count="directory",
                min_width=300,
                height=300,
            )
            load_button = gr.Button("Load", variant="secondary", size="sm")
            match_button = gr.Button("Scribble Matching", variant="primary")

        # "Load" embeds the uploaded files and writes their names back into
        # the same file component.
        load_button.click(
            fn=load_candidates_in_cache,
            inputs=[upload_box],
            outputs=upload_box,
        )
        # "Scribble Matching" sends the sketch through the matcher and shows
        # the returned (image, caption) pairs in the gallery.
        match_button.click(
            fn=scribble_matching,
            inputs=[sketch_pad],
            outputs=[results_gallery],
        )

    app.queue()
    app.launch()


if __name__ == "__main__":
    main()