panda1835 commited on
Commit
4f13ac7
·
verified ·
1 Parent(s): 59d3daf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.transforms as T
4
+
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+ # DINOv2
10
+
11
+ # Select checkpoint
12
+ dinov2_ckpt = ['dinov2_vits14', 'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14'][1]
13
+ dinov2 = torch.hub.load('facebookresearch/dinov2', dinov2_ckpt)
14
+
15
+ dinov2.to(device)
16
+ print()
17
+
18
+ transform_image = T.Compose([
19
+ T.Resize((224, 224)),
20
+ T.ToTensor(),
21
+ T.Normalize(mean=[0.485, 0.456, 0.406],
22
+ std=[0.229, 0.224, 0.225])
23
+ ])
24
+
25
+ def predict(image):
26
+ """
27
+ Predict the identity of an image.
28
+
29
+ Args:
30
+ image: A PIL Image object.
31
+
32
+ Returns:
33
+ A string representing the predicted identity of the image.
34
+ """
35
+
36
+ # Convert the image to a tensor.
37
+ transformed_img = transform_image(image)[:3].unsqueeze(0).to(device)
38
+
39
+ # Get the embedding of the image.
40
+ with torch.no_grad():
41
+ embedding = dinov2(transformed_img)
42
+ print(embedding.shape)
43
+ embedding = embedding[0].cpu().numpy().tolist()
44
+ print(embedding)
45
+ return {
46
+ "embedding": embedding
47
+ }
48
+
49
+ # Create a Gradio interface.
50
+ interface = gr.Interface(
51
+ fn=predict,
52
+ inputs=[gr.Image(type='pil')],
53
+ outputs=[gr.JSON()],
54
+ title="DINOv2 Image Retrieval",
55
+ )
56
+
57
+ # Start the Gradio server.
58
+ interface.launch()