anyonehomep1mane commited on
Commit
1d7d4a2
·
0 Parent(s):

Initial Changes

Browse files
Files changed (4) hide show
  1. .gitattributes +37 -0
  2. README.md +14 -0
  3. app.py +135 -0
  4. requirements.txt +8 -0
.gitattributes ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ baklava.jpg filter=lfs diff=lfs merge=lfs -text
37
+ cat.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: OpenAI Zero Shot Image Classification
3
+ emoji: 👁
4
+ colorFrom: red
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 6.5.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ short_description: Image classification tasks in a zero-shot manner
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import CLIPProcessor, CLIPModel
3
+ import gradio as gr
4
+ from PIL import Image
5
+ import requests
6
+ from typing import Iterable
7
+
8
+ from gradio.themes import Soft
9
+ from gradio.themes.utils import colors, fonts, sizes
10
+
11
+ import warnings
12
+ warnings.filterwarnings(action="ignore")
13
+
14
+ from pathlib import Path
15
+
16
# Resolve asset paths relative to this file so the app works from any CWD.
BASE_DIR = Path(__file__).parent
# NOTE(review): ASSETS_DIR is not referenced anywhere in this file — confirm
# it is used elsewhere (e.g. example images) or remove it.
ASSETS_DIR = BASE_DIR / "images"
18
+
19
# Register a custom orange-red palette on gradio's shared colors namespace so it
# can be referenced as colors.orange_red by the theme below.
# Shades run c50 (lightest) -> c950 (darkest), anchored on #FF4500 at c500.
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)
33
+
34
class OrangeRedTheme(Soft):
    """Gradio theme derived from Soft with an orange-red accent palette.

    Defaults pick gray primaries, the custom ``colors.orange_red`` palette for
    secondary (button) accents, slate neutrals, and Google fonts, then layers
    gradient backgrounds and button styling on top of Soft's defaults.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.orange_red,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Visual overrides applied on top of Soft's defaults. "*token" strings
        # are gradio theme-variable references resolved at render time.
        overrides = dict(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            block_title_text_weight="600",
            block_shadow="*shadow_drop_lg",
        )
        super().set(**overrides)
69
+
70
orange_red_theme = OrangeRedTheme()

# Model and processor are loaded once at import time (weights are downloaded
# on first run); inference below reuses these module-level singletons.
MODEL_ID = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(MODEL_ID)
processor = CLIPProcessor.from_pretrained(MODEL_ID)
75
+
76
def postprocess_metaclip(probs, labels):
    """Map each candidate label to its probability.

    Args:
        probs: tensor of shape (1, len(labels)) — softmaxed logits for a
            single image (first row is used).
        labels: list of label strings, in the same order as the columns
            of ``probs``.

    Returns:
        dict[str, float] mapping label -> probability, the format expected
        by ``gr.Label``.
    """
    # zip pairs labels with their scores directly instead of indexing
    # both sequences via range(len(...)).
    return {label: prob.item() for label, prob in zip(labels, probs[0])}
78
+
79
def metaclip_detector(image, texts):
    """Score `texts` against `image` with the module-level CLIP model.

    Returns a tensor of shape (1, len(texts)): per-label probabilities
    (softmax over the image-to-text logits).
    """
    batch = processor(text=texts, images=image, return_tensors="pt", padding=True)
    # No gradients needed for pure inference.
    with torch.no_grad():
        logits = model(**batch).logits_per_image
    return logits.softmax(dim=1)
85
+
86
def infer(image, candidate_labels):
    """Gradio callback: zero-shot classify `image` against comma-separated labels.

    Args:
        image: a PIL image, or a filesystem path to one (gr.Image may deliver
            either depending on its ``type=`` setting; CLIPProcessor does not
            accept raw path strings, so paths are opened here).
        candidate_labels: comma-separated label string, e.g. "cat, dog".

    Returns:
        dict[str, float] of label -> probability for gr.Label; empty dict when
        no usable labels were supplied.
    """
    # Drop blank entries so trailing commas / empty input don't produce
    # meaningless "" labels (or an empty text batch) for the model.
    labels = [part.strip() for part in candidate_labels.split(",") if part.strip()]
    if not labels:
        return {}
    if isinstance(image, (str, Path)):
        image = Image.open(image)
    probs = metaclip_detector(image, labels)
    return postprocess_metaclip(probs, labels=labels)
90
+
91
# Page-level CSS: full-height document, a centered max-1000px content column,
# and a centered main title (targets elem_id/elem_classes set in the UI below).
css = """
#root, body, html {
    margin: 0;
    padding: 0;
    height: 100%;
}

.center-container {
    max-width: 1000px;
    margin: 0 auto !important;
    display: flex;
    flex-direction: column;
    align-items: center;
}

#main-title h1 {
    text-align: center !important;
    width: 100%;
}
"""
111
+
112
# --- UI wiring -------------------------------------------------------------
# Fixes: (1) gr.Image(type="filepath") handed a path *string* to infer(), but
# CLIPProcessor expects an image object — use type="pil" so a PIL image is
# delivered. (2) The on-page titles said "MetaCLIP 2" while MODEL_ID above is
# openai/clip-vit-base-patch32 (OpenAI CLIP); text updated to match the model
# actually served.
with gr.Blocks(css=css, theme=orange_red_theme) as demo:
    with gr.Column(elem_classes="center-container"):

        gr.Markdown("# **CLIP Zero-Shot Classification**", elem_id="main-title")
        gr.Markdown("This is the demo of OpenAI CLIP for zero-shot classification.")

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Image", height=310)
                text_input = gr.Textbox(label="Input labels (comma separated)")
                run_button = gr.Button("Run", variant="primary")
            with gr.Column():
                metaclip_output = gr.Label(
                    label="CLIP Output",
                    num_top_classes=3,
                )

        run_button.click(
            fn=infer,
            inputs=[image_input, text_input],
            outputs=[metaclip_output],
        )

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ huggingface-hub
2
+ sentencepiece
3
+ transformers
4
+ accelerate
5
+ protobuf
6
+ pillow
7
+ gradio
8
+ torch