DKatheesrupan commited on
Commit
8247df2
·
verified ·
1 Parent(s): 9e15079

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -124
app.py CHANGED
@@ -1,125 +1,125 @@
1
- import os
2
- from pathlib import Path
3
-
4
- import gradio as gr
5
- from transformers import pipeline
6
-
7
-
8
- # ----------------------------
9
- # Paths
10
- # ----------------------------
11
-
12
- BASE_DIR = Path(__file__).resolve().parent
13
-
14
- # HIER ggf. den Modellordner anpassen
15
- MODEL_PATH = BASE_DIR.parent / "flower-vit"
16
-
17
- EXAMPLE_DIR = BASE_DIR / "example_images"
18
-
19
-
20
- # ----------------------------
21
- # Labels
22
- # ----------------------------
23
-
24
- CAT_LABELS = ["cheetah", "leopard", "lion", "puma", "tiger"]
25
-
26
-
27
- # ----------------------------
28
- # Load models
29
- # ----------------------------
30
-
31
- print("Loading custom model...")
32
- vit_classifier = pipeline(
33
- "image-classification",
34
- model=str(MODEL_PATH)
35
- )
36
-
37
- print("Loading CLIP model...")
38
- clip_classifier = pipeline(
39
- task="zero-shot-image-classification",
40
- model="openai/clip-vit-base-patch32"
41
- )
42
-
43
-
44
- # ----------------------------
45
- # Helper functions
46
- # ----------------------------
47
-
48
- def normalize_custom_labels(results):
49
- id2label = {
50
- "LABEL_0": "cheetah",
51
- "LABEL_1": "leopard",
52
- "LABEL_2": "lion",
53
- "LABEL_3": "puma",
54
- "LABEL_4": "tiger",
55
- }
56
-
57
- output = {}
58
-
59
- for r in results:
60
- label = r["label"]
61
- score = float(r["score"])
62
-
63
- if label in id2label:
64
- label = id2label[label]
65
- else:
66
- label = label.lower()
67
-
68
- output[label] = score
69
-
70
- return output
71
-
72
-
73
- # ----------------------------
74
- # Main function
75
- # ----------------------------
76
-
77
- def classify_cat(image):
78
- # Custom Model
79
- vit_results = vit_classifier(image)
80
- vit_output = normalize_custom_labels(vit_results)
81
-
82
- # CLIP
83
- clip_labels = [f"a photo of a {label}" for label in CAT_LABELS]
84
- clip_results = clip_classifier(image, candidate_labels=clip_labels)
85
-
86
- clip_output = {}
87
- for r in clip_results:
88
- label = r["label"].replace("a photo of a ", "").lower()
89
- score = float(r["score"])
90
- clip_output[label] = score
91
-
92
- return vit_output, clip_output
93
-
94
-
95
- # ----------------------------
96
- # Example images
97
- # ----------------------------
98
-
99
- example_images = [
100
- [str(EXAMPLE_DIR / "Cheetah_032.jpg")],
101
- [str(EXAMPLE_DIR / "Leopard_001.jpg")],
102
- [str(EXAMPLE_DIR / "Lion_003.jpg")],
103
- [str(EXAMPLE_DIR / "Puma_001.jpg")],
104
- [str(EXAMPLE_DIR / "Tiger_001.jpg")]
105
- ]
106
-
107
-
108
- # ----------------------------
109
- # Interface
110
- # ----------------------------
111
-
112
- iface = gr.Interface(
113
- fn=classify_cat,
114
- inputs=gr.Image(type="filepath"),
115
- outputs=[
116
- gr.Label(label="Custom Model"),
117
- gr.Label(label="CLIP")
118
- ],
119
- title="Big Cat Classification",
120
- description="Compare Custom Model vs CLIP",
121
- examples=example_images
122
- )
123
-
124
- if __name__ == "__main__":
125
  iface.launch()
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import gradio as gr
5
+ from transformers import pipeline
6
+
7
+
8
+ # ----------------------------
9
+ # Paths
10
+ # ----------------------------
11
+
12
+ BASE_DIR = Path(__file__).resolve().parent
13
+
14
+ # HIER ggf. den Modellordner anpassen
15
+ MODEL_PATH = BASE_DIR.parent / "cat-vit"
16
+
17
+ EXAMPLE_DIR = BASE_DIR / "example_images"
18
+
19
+
20
+ # ----------------------------
21
+ # Labels
22
+ # ----------------------------
23
+
24
+ CAT_LABELS = ["cheetah", "leopard", "lion", "puma", "tiger"]
25
+
26
+
27
+ # ----------------------------
28
+ # Load models
29
+ # ----------------------------
30
+
31
+ print("Loading custom model...")
32
+ vit_classifier = pipeline(
33
+ "image-classification",
34
+ model=str(MODEL_PATH)
35
+ )
36
+
37
+ print("Loading CLIP model...")
38
+ clip_classifier = pipeline(
39
+ task="zero-shot-image-classification",
40
+ model="openai/clip-vit-base-patch32"
41
+ )
42
+
43
+
44
+ # ----------------------------
45
+ # Helper functions
46
+ # ----------------------------
47
+
48
+ def normalize_custom_labels(results):
49
+ id2label = {
50
+ "LABEL_0": "cheetah",
51
+ "LABEL_1": "leopard",
52
+ "LABEL_2": "lion",
53
+ "LABEL_3": "puma",
54
+ "LABEL_4": "tiger",
55
+ }
56
+
57
+ output = {}
58
+
59
+ for r in results:
60
+ label = r["label"]
61
+ score = float(r["score"])
62
+
63
+ if label in id2label:
64
+ label = id2label[label]
65
+ else:
66
+ label = label.lower()
67
+
68
+ output[label] = score
69
+
70
+ return output
71
+
72
+
73
+ # ----------------------------
74
+ # Main function
75
+ # ----------------------------
76
+
77
+ def classify_cat(image):
78
+ # Custom Model
79
+ vit_results = vit_classifier(image)
80
+ vit_output = normalize_custom_labels(vit_results)
81
+
82
+ # CLIP
83
+ clip_labels = [f"a photo of a {label}" for label in CAT_LABELS]
84
+ clip_results = clip_classifier(image, candidate_labels=clip_labels)
85
+
86
+ clip_output = {}
87
+ for r in clip_results:
88
+ label = r["label"].replace("a photo of a ", "").lower()
89
+ score = float(r["score"])
90
+ clip_output[label] = score
91
+
92
+ return vit_output, clip_output
93
+
94
+
95
+ # ----------------------------
96
+ # Example images
97
+ # ----------------------------
98
+
99
+ example_images = [
100
+ [str(EXAMPLE_DIR / "Cheetah_032.jpg")],
101
+ [str(EXAMPLE_DIR / "Leopard_001.jpg")],
102
+ [str(EXAMPLE_DIR / "Lion_003.jpg")],
103
+ [str(EXAMPLE_DIR / "Puma_001.jpg")],
104
+ [str(EXAMPLE_DIR / "Tiger_001.jpg")]
105
+ ]
106
+
107
+
108
+ # ----------------------------
109
+ # Interface
110
+ # ----------------------------
111
+
112
+ iface = gr.Interface(
113
+ fn=classify_cat,
114
+ inputs=gr.Image(type="filepath"),
115
+ outputs=[
116
+ gr.Label(label="Custom Model"),
117
+ gr.Label(label="CLIP")
118
+ ],
119
+ title="Big Cat Classification",
120
+ description="Compare Custom Model vs CLIP",
121
+ examples=example_images
122
+ )
123
+
124
+ if __name__ == "__main__":
125
  iface.launch()