anyonehomep1mane committed on
Commit
5aa6736
·
1 Parent(s): e344222

Code Changes

Browse files
app.py CHANGED
@@ -1,125 +1,24 @@
1
- import torch
2
- from transformers import AutoModel, AutoProcessor
3
- import gradio as gr
4
- from PIL import Image
5
- from gradio.themes import Soft
6
- from gradio.themes.utils import colors, fonts, sizes
7
 
8
- import warnings
9
- warnings.filterwarnings(action="ignore")
10
 
11
- colors.orange_red = colors.Color(
12
- name="orange_red",
13
- c50="#FFF0E5", c100="#FFE0CC", c200="#FFC299", c300="#FFA366",
14
- c400="#FF8533", c500="#FF4500", c600="#E63E00", c700="#CC3700",
15
- c800="#B33000", c900="#992900", c950="#802200",
16
- )
17
 
18
- class OrangeRedTheme(Soft):
19
- def __init__(self):
20
- super().__init__(
21
- primary_hue=colors.orange_red,
22
- secondary_hue=colors.orange_red,
23
- neutral_hue=colors.slate,
24
- text_size=sizes.text_lg,
25
- font=(fonts.GoogleFont("Outfit"), "Arial", "sans-serif"),
26
- font_mono=(fonts.GoogleFont("IBM Plex Mono"), "monospace"),
27
- )
28
- super().set(
29
- body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
30
- button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
31
- button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
32
- button_primary_text_color="white",
33
- block_border_width="3px",
34
- block_shadow="*shadow_drop_lg",
35
- )
36
-
37
- orange_red_theme = OrangeRedTheme()
38
-
39
- MODEL_ID = "openai/clip-vit-base-patch32"
40
- model = AutoModel.from_pretrained(
41
- MODEL_ID,
42
- torch_dtype=torch.bfloat16,
43
- attn_implementation="sdpa"
44
- )
45
- processor = AutoProcessor.from_pretrained(MODEL_ID)
46
-
47
- device = "cuda" if torch.cuda.is_available() else "cpu"
48
- model = model.to(device)
49
-
50
- def postprocess_metaclip(probs, labels):
51
- return {labels[i]: probs[0][i].item() for i in range(len(labels))}
52
-
53
- def metaclip_detector(image, texts):
54
- inputs = processor(text=texts, images=image, return_tensors="pt", padding=True)
55
- with torch.no_grad():
56
- outputs = model(**inputs)
57
- probs = outputs.logits_per_image.softmax(dim=1)
58
- return probs
59
-
60
- def infer(image, candidate_labels):
61
- candidate_labels = [l.strip() for l in candidate_labels.split(",")]
62
- probs = metaclip_detector(image, candidate_labels)
63
- return postprocess_metaclip(probs, labels=candidate_labels)
64
-
65
- css_style = """
66
- #container {
67
- max-width: 1280px; /* wider layout */
68
- margin: auto;
69
- }
70
-
71
- @media (min-width: 1600px) {
72
- #container {
73
- max-width: 1440px;
74
- }
75
- }
76
-
77
- #title h1 {
78
- font-size: 2.4em !important;
79
- }
80
- """
81
-
82
- with gr.Blocks(title="AI Document Summarizer") as demo:
83
- with gr.Column(elem_id="container"):
84
-
85
- gr.Markdown("# **Open AI Zero-Shot Classification**", elem_id="title")
86
- gr.Markdown("This is the demo of model 'openai/clip-vit-base-patch32' for zero-shot classification.")
87
-
88
- with gr.Row(equal_height=True):
89
- with gr.Column():
90
- image_input = gr.Image(type="pil", label="Upload Image", height=310)
91
- text_input = gr.Textbox(label="Input labels (comma separated)")
92
- run_button = gr.Button("Run", variant="primary")
93
- with gr.Column():
94
- metaclip_output = gr.Label(
95
- label="Open AI Zero-Shot Classification Output",
96
- num_top_classes=5
97
- )
98
-
99
- with gr.Row(equal_height=True):
100
- gr.Examples(
101
- examples=[
102
- ["./zebra.jpg", "a photo of a zebra, a photo of a horse, a photo of a donkey"],
103
- ["./cat.jpg", "a photo of a cat, a photo of two cats, a photo of three cats"],
104
- ["./fridge.jpg", "a photo of a fridge, a photo of a cupboard, a photo of a wardrobe"]
105
- ],
106
- inputs=[image_input, text_input],
107
- outputs=[metaclip_output],
108
- fn=infer,
109
- )
110
-
111
- run_button.click(
112
- fn=infer,
113
- inputs=[image_input, text_input],
114
- outputs=[metaclip_output]
115
- )
116
-
117
- if __name__ == "__main__":
118
  demo.queue().launch(
119
- theme=orange_red_theme,
120
- css=css_style,
121
  show_error=True,
122
  server_name="0.0.0.0",
123
  server_port=7860,
124
  debug=True
125
- )
 
 
 
 
1
"""Application entry point for the CLIP zero-shot classification demo."""
from utils.warnings import suppress_warnings
from core.model_loader import load_model
from ui.theme import OrangeRedTheme
from ui.styles import CSS_STYLE
from ui.layout import build_ui


def main():
    """Load the model, build the Gradio UI, and serve it on port 7860."""
    # Silence library warning chatter before the heavyweight model load.
    suppress_warnings()

    # Model and processor come back ready for inference (see core.model_loader).
    model, processor = load_model()
    theme = OrangeRedTheme()

    demo = build_ui(model, processor)

    # NOTE(review): Gradio applies `theme`/`css` when constructing gr.Blocks,
    # not in launch(); confirm the installed gradio version accepts these
    # kwargs here — otherwise they belong inside build_ui / gr.Blocks.
    demo.queue().launch(
        theme=theme,
        css=CSS_STYLE,
        show_error=True,
        server_name="0.0.0.0",   # listen on all interfaces (container/Space use)
        server_port=7860,
        debug=True
    )


if __name__ == "__main__":
    main()
cat.jpg → assets/cat.jpg RENAMED
File without changes
fridge.jpg → assets/fridge.jpg RENAMED
File without changes
zebra.jpg → assets/zebra.jpg RENAMED
File without changes
config/__pycache__/settings.cpython-310.pyc ADDED
Binary file (361 Bytes). View file
 
config/settings.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
"""Central configuration for the zero-shot classification demo."""
import torch

# Hugging Face checkpoint identifier for the CLIP model.
MODEL_ID = "openai/clip-vit-base-patch32"

# Load weights in bfloat16 and use PyTorch's scaled-dot-product attention.
TORCH_DTYPE = torch.bfloat16
ATTN_IMPLEMENTATION = "sdpa"

# Prefer the GPU whenever torch can see one; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
core/__pycache__/inference.cpython-310.pyc ADDED
Binary file (1.3 kB). View file
 
core/__pycache__/model_loader.cpython-310.pyc ADDED
Binary file (652 Bytes). View file
 
core/inference.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Inference helpers: run CLIP on an image against free-form text labels."""
import torch

from config.settings import DEVICE


def post_processed_probs(probs, labels):
    """Map each label to its probability from a (1, num_labels) tensor."""
    return {labels[i]: probs[0][i].item() for i in range(len(labels))}


def generate_output(model, processor, image, texts):
    """Run one CLIP forward pass and return softmaxed image-to-text scores.

    Args:
        model: CLIP-style model whose output exposes `logits_per_image`.
        processor: matching processor (tokenizes text, preprocesses images).
        image: image input accepted by the processor (e.g. a PIL image).
        texts: list of candidate label strings.

    Returns:
        Tensor of shape (1, len(texts)) with probabilities over `texts`.
    """
    inputs = processor(
        text=texts,
        images=image,
        return_tensors="pt",
        padding=True
    ).to(DEVICE)
    # NOTE(review): the model may be loaded in bfloat16 while the processor
    # emits float32 pixel values; confirm the installed transformers version
    # handles the cast, or move inputs with `.to(DEVICE, dtype=model.dtype)`.

    with torch.no_grad():
        outputs = model(**inputs)
        probs = outputs.logits_per_image.softmax(dim=1)

    return probs


# Backward-compatible alias: the original public name carried a typo ("ouput").
generate_ouput = generate_output


def infer(model, processor, image, candidate_labels):
    """Split a comma-separated label string and classify `image` against it."""
    labels = [l.strip() for l in candidate_labels.split(",")]
    probs = generate_output(model, processor, image, labels)
    return post_processed_probs(probs, labels)
core/model_loader.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Model construction utilities."""
import torch
from transformers import AutoModel, AutoProcessor

from config.settings import ATTN_IMPLEMENTATION, DEVICE, MODEL_ID, TORCH_DTYPE


def load_model():
    """Load the configured CLIP checkpoint and its processor.

    Returns:
        (model, processor): model is placed on DEVICE and put in eval mode.
    """
    model = AutoModel.from_pretrained(
        MODEL_ID,
        torch_dtype=TORCH_DTYPE,
        attn_implementation=ATTN_IMPLEMENTATION,
    ).to(DEVICE)
    model.eval()  # inference only — disables dropout and the like

    processor = AutoProcessor.from_pretrained(MODEL_ID)
    return model, processor
ui/__pycache__/layout.cpython-310.pyc ADDED
Binary file (1.91 kB). View file
 
ui/__pycache__/styles.cpython-310.pyc ADDED
Binary file (367 Bytes). View file
 
ui/__pycache__/theme.cpython-310.pyc ADDED
Binary file (1.5 kB). View file
 
ui/layout.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Gradio UI assembly for the zero-shot classification demo."""
import gradio as gr

from core.inference import infer


def build_ui(model, processor, theme=None, css=None):
    """Assemble the Gradio Blocks app for zero-shot image classification.

    Args:
        model: loaded CLIP model, forwarded to `infer`.
        processor: matching processor, forwarded to `infer`.
        theme: optional Gradio theme, applied at Blocks construction time.
        css: optional CSS string, applied at Blocks construction time.

    Returns:
        The assembled (unlaunched) gr.Blocks app.
    """

    def run_inference(image, labels_text):
        # Single callback shared by the Run button and the examples widget,
        # replacing two duplicated lambdas.
        return infer(model, processor, image, labels_text)

    # NOTE(review): the window title says "AI Document Summarizer" but this is
    # an image-classification demo — looks like a leftover; confirm and rename.
    with gr.Blocks(title="AI Document Summarizer", theme=theme, css=css) as demo:
        with gr.Column(elem_id="container"):
            gr.Markdown("# **Open AI Zero-Shot Classification**", elem_id="title")
            gr.Markdown(
                "This is the demo of model **openai/clip-vit-base-patch32** "
                "for zero-shot image classification."
            )

            # Left column: inputs and trigger; right column: top-5 label scores.
            with gr.Row(equal_height=True):
                with gr.Column():
                    image_input = gr.Image(type="pil", label="Upload Image", height=310)
                    text_input = gr.Textbox(label="Input labels (comma separated)")
                    run_button = gr.Button("Run", variant="primary")

                with gr.Column():
                    output = gr.Label(
                        label="Open AI Zero-Shot Classification Output",
                        num_top_classes=5
                    )

            # Clickable sample inputs shipped with the repo under ./assets/.
            with gr.Row(equal_height=True):
                gr.Examples(
                    examples=[
                        ["./assets/zebra.jpg", "a photo of a zebra, a photo of a horse, a photo of a donkey"],
                        ["./assets/cat.jpg", "a photo of a cat, a photo of two cats, a photo of three cats"],
                        ["./assets/fridge.jpg", "a photo of a fridge, a photo of a cupboard, a photo of a wardrobe"]
                    ],
                    inputs=[image_input, text_input],
                    outputs=[output],
                    fn=run_inference
                )

            run_button.click(
                fn=run_inference,
                inputs=[image_input, text_input],
                outputs=[output]
            )

    return demo
ui/styles.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CSS_STYLE = """
2
+ #container {
3
+ max-width: 1280px;
4
+ margin: auto;
5
+ }
6
+
7
+ @media (min-width: 1600px) {
8
+ #container {
9
+ max-width: 1440px;
10
+ }
11
+ }
12
+
13
+ #title h1 {
14
+ font-size: 2.4em !important;
15
+ }
16
+ """
ui/theme.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Custom orange-red Gradio theme built on top of the Soft base theme."""
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes

# Register the custom palette on the shared `colors` module so it can be
# referenced like any built-in hue.
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5", c100="#FFE0CC", c200="#FFC299", c300="#FFA366",
    c400="#FF8533", c500="#FF4500", c600="#E63E00", c700="#CC3700",
    c800="#B33000", c900="#992900", c950="#802200",
)


class OrangeRedTheme(Soft):
    """Soft-derived theme with an orange-red accent and Outfit typography."""

    def __init__(self):
        super().__init__(
            primary_hue=colors.orange_red,
            secondary_hue=colors.orange_red,
            neutral_hue=colors.slate,
            text_size=sizes.text_lg,
            font=(fonts.GoogleFont("Outfit"), "Arial", "sans-serif"),
            font_mono=(fonts.GoogleFont("IBM Plex Mono"), "monospace"),
        )
        # Layer gradient backgrounds and buttons over the Soft defaults.
        self.set(
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_text_color="white",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
        )
utils/__pycache__/warnings.cpython-310.pyc ADDED
Binary file (351 Bytes). View file
 
utils/warnings.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
"""Warning-suppression helper."""
import warnings


def suppress_warnings():
    """Silence all Python warnings for the remainder of the process.

    The demo intentionally hides library deprecation chatter from end users.
    """
    warnings.filterwarnings("ignore")
version_one_app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModel, AutoProcessor
3
+ import gradio as gr
4
+ from PIL import Image
5
+ from gradio.themes import Soft
6
+ from gradio.themes.utils import colors, fonts, sizes
7
+
8
+ import warnings
9
+ warnings.filterwarnings(action="ignore")
10
+
11
# Legacy (version 1) theme setup, preserved verbatim for reference.
# Registers a custom orange-red palette on the shared `colors` module.
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5", c100="#FFE0CC", c200="#FFC299", c300="#FFA366",
    c400="#FF8533", c500="#FF4500", c600="#E63E00", c700="#CC3700",
    c800="#B33000", c900="#992900", c950="#802200",
)

class OrangeRedTheme(Soft):
    """Soft-derived Gradio theme with an orange-red accent."""
    def __init__(self):
        super().__init__(
            primary_hue=colors.orange_red,
            secondary_hue=colors.orange_red,
            neutral_hue=colors.slate,
            text_size=sizes.text_lg,
            font=(fonts.GoogleFont("Outfit"), "Arial", "sans-serif"),
            font_mono=(fonts.GoogleFont("IBM Plex Mono"), "monospace"),
        )
        # Layer gradient backgrounds and buttons over the Soft defaults.
        super().set(
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_text_color="white",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
        )
36
+
37
# Legacy v1 module-level globals: theme instance, model, processor, device.
orange_red_theme = OrangeRedTheme()

# CLIP checkpoint loaded in bfloat16 with SDPA attention.
MODEL_ID = "openai/clip-vit-base-patch32"
model = AutoModel.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa"
)
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Move the model to the GPU when available; otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
49
+
50
def postprocess_metaclip(probs, labels):
    """Convert a (1, num_labels) probability tensor into a label->score dict."""
    return {label: probs[0][i].item() for i, label in enumerate(labels)}
52
+
53
def metaclip_detector(image, texts):
    """Score `image` against candidate `texts` with the module-level CLIP model.

    Returns:
        Tensor of shape (1, len(texts)) with softmaxed probabilities.
    """
    # Bug fix: inputs must live on the same device as the model, otherwise
    # this crashes whenever the model was moved to CUDA above.
    inputs = processor(text=texts, images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        probs = outputs.logits_per_image.softmax(dim=1)
    return probs
59
+
60
def infer(image, candidate_labels):
    """Gradio callback: comma-separated labels in, label->probability dict out."""
    labels = [label.strip() for label in candidate_labels.split(",")]
    scores = metaclip_detector(image, labels)
    return postprocess_metaclip(scores, labels=labels)
64
+
65
+ css_style = """
66
+ #container {
67
+ max-width: 1280px; /* wider layout */
68
+ margin: auto;
69
+ }
70
+
71
+ @media (min-width: 1600px) {
72
+ #container {
73
+ max-width: 1440px;
74
+ }
75
+ }
76
+
77
+ #title h1 {
78
+ font-size: 2.4em !important;
79
+ }
80
+ """
81
+
82
+ with gr.Blocks(title="AI Document Summarizer") as demo:
83
+ with gr.Column(elem_id="container"):
84
+
85
+ gr.Markdown("# **Open AI Zero-Shot Classification**", elem_id="title")
86
+ gr.Markdown("This is the demo of model 'openai/clip-vit-base-patch32' for zero-shot classification.")
87
+
88
+ with gr.Row(equal_height=True):
89
+ with gr.Column():
90
+ image_input = gr.Image(type="pil", label="Upload Image", height=310)
91
+ text_input = gr.Textbox(label="Input labels (comma separated)")
92
+ run_button = gr.Button("Run", variant="primary")
93
+ with gr.Column():
94
+ metaclip_output = gr.Label(
95
+ label="Open AI Zero-Shot Classification Output",
96
+ num_top_classes=5
97
+ )
98
+
99
+ with gr.Row(equal_height=True):
100
+ gr.Examples(
101
+ examples=[
102
+ ["./zebra.jpg", "a photo of a zebra, a photo of a horse, a photo of a donkey"],
103
+ ["./cat.jpg", "a photo of a cat, a photo of two cats, a photo of three cats"],
104
+ ["./fridge.jpg", "a photo of a fridge, a photo of a cupboard, a photo of a wardrobe"]
105
+ ],
106
+ inputs=[image_input, text_input],
107
+ outputs=[metaclip_output],
108
+ fn=infer,
109
+ )
110
+
111
+ run_button.click(
112
+ fn=infer,
113
+ inputs=[image_input, text_input],
114
+ outputs=[metaclip_output]
115
+ )
116
+
117
+ if __name__ == "__main__":
118
+ demo.queue().launch(
119
+ theme=orange_red_theme,
120
+ css=css_style,
121
+ show_error=True,
122
+ server_name="0.0.0.0",
123
+ server_port=7860,
124
+ debug=True
125
+ )