ulichovick commited on
Commit
4cd7b8d
·
verified ·
1 Parent(s): ce78316

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -0
app.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from model import SurfinBird
2
+ from torchvision import transforms, io
3
+ import csv
4
+ import gradio as gr
5
+
6
+ with open("birds.csv", "r") as r:
7
+ birds = list(csv.reader(r, delimiter=","))
8
+
9
+ config = {"num_channels": 3, "hidden_units": 256, "num_classes": 525, "labels": birds}
10
+
11
+ model = SurfinBird(config=config)
12
+ model = SurfinBird.from_pretrained("ulichovick/birdnet")
13
+
14
+ usr_img_transform = transforms.Compose([
15
+ transforms.Resize(size=(224, 224)),
16
+ ])
17
+
18
+ target_image = io.read_image(str(image_path)).type(torch.float32)
19
+ target_image = target_image / 255.
20
+ target_image = usr_img_transform(target_image)
21
+
22
+ model.eval()
23
+ with torch.inference_mode():
24
+ target_image = target_image.unsqueeze(dim=0)
25
+ target_image_pred = model(target_image)
26
+ target_image_pred_label = torch.argmax(target_image_pred, dim=1)
27
+
28
+ #print(config["labels"][target_image_pred_label])
29
+
30
+ gr.Interface(
31
+ transcribe,
32
+ inputs=gr.Image(label="gimme da bird"),
33
+ outputs=["textbox"],
34
+ title=titulo,
35
+ description=desc,
36
+ ).launch()