anyonehomep1mane commited on
Commit
aacb585
·
1 Parent(s): 1d7d4a2

Code Changes

Browse files
Files changed (2) hide show
  1. .gitignore +3 -0
  2. app.py +61 -71
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .env
2
+ .vscode
3
+ venv
app.py CHANGED
@@ -1,77 +1,51 @@
1
  import torch
2
- from transformers import CLIPProcessor, CLIPModel
3
  import gradio as gr
4
  from PIL import Image
5
- import requests
6
- from typing import Iterable
7
-
8
  from gradio.themes import Soft
9
  from gradio.themes.utils import colors, fonts, sizes
10
 
11
  import warnings
12
  warnings.filterwarnings(action="ignore")
13
 
14
- from pathlib import Path
15
-
16
- BASE_DIR = Path(__file__).parent
17
- ASSETS_DIR = BASE_DIR / "images"
18
-
19
  colors.orange_red = colors.Color(
20
  name="orange_red",
21
- c50="#FFF0E5",
22
- c100="#FFE0CC",
23
- c200="#FFC299",
24
- c300="#FFA366",
25
- c400="#FF8533",
26
- c500="#FF4500",
27
- c600="#E63E00",
28
- c700="#CC3700",
29
- c800="#B33000",
30
- c900="#992900",
31
- c950="#802200",
32
  )
33
 
34
  class OrangeRedTheme(Soft):
35
- def __init__(
36
- self,
37
- *,
38
- primary_hue: colors.Color | str = colors.gray,
39
- secondary_hue: colors.Color | str = colors.orange_red,
40
- neutral_hue: colors.Color | str = colors.slate,
41
- text_size: sizes.Size | str = sizes.text_lg,
42
- font: fonts.Font | str | Iterable[fonts.Font | str] = (
43
- fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
44
- ),
45
- font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
46
- fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
47
- ),
48
- ):
49
  super().__init__(
50
- primary_hue=primary_hue,
51
- secondary_hue=secondary_hue,
52
- neutral_hue=neutral_hue,
53
- text_size=text_size,
54
- font=font,
55
- font_mono=font_mono,
56
  )
57
  super().set(
58
- background_fill_primary="*primary_50",
59
- background_fill_primary_dark="*primary_900",
60
  body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
61
- body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
62
- button_primary_text_color="white",
63
- button_primary_text_color_hover="white",
64
  button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
65
  button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
66
- block_title_text_weight="600",
 
67
  block_shadow="*shadow_drop_lg",
68
  )
69
 
70
  orange_red_theme = OrangeRedTheme()
71
 
72
  MODEL_ID = "openai/clip-vit-base-patch32"
73
- model = CLIPModel.from_pretrained(MODEL_ID)
74
- processor = CLIPProcessor.from_pretrained(MODEL_ID)
 
 
 
 
 
 
 
75
 
76
  def postprocess_metaclip(probs, labels):
77
  return {labels[i]: probs[0][i].item() for i in range(len(labels))}
@@ -88,48 +62,64 @@ def infer(image, candidate_labels):
88
  probs = metaclip_detector(image, candidate_labels)
89
  return postprocess_metaclip(probs, labels=candidate_labels)
90
 
91
- css = """
92
- #root, body, html {
93
- margin: 0;
94
- padding: 0;
95
- height: 100%;
96
  }
97
 
98
- .center-container {
99
- max-width: 1000px;
100
- margin: 0 auto !important;
101
- display: flex;
102
- flex-direction: column;
103
- align-items: center;
104
  }
105
 
106
- #main-title h1 {
107
- text-align: center !important;
108
- width: 100%;
109
  }
110
  """
111
 
112
- with gr.Blocks(css=css, theme=orange_red_theme) as demo:
113
- with gr.Column(elem_classes="center-container"):
114
 
115
- gr.Markdown("# **MetaCLIP 2 Zero-Shot Classification**", elem_id="main-title")
116
- gr.Markdown("This is the demo of MetaCLIP 2 for zero-shot classification.")
117
 
118
- with gr.Row():
119
  with gr.Column():
120
- image_input = gr.Image(type="filepath", label="Upload Image", height=310)
121
  text_input = gr.Textbox(label="Input labels (comma separated)")
122
  run_button = gr.Button("Run", variant="primary")
123
  with gr.Column():
124
  metaclip_output = gr.Label(
125
- label="MetaCLIP 2 Output",
126
- num_top_classes=3
127
  )
128
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  run_button.click(
130
  fn=infer,
131
  inputs=[image_input, text_input],
132
  outputs=[metaclip_output]
133
  )
134
 
135
- demo.launch()
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ from transformers import AutoModel, AutoProcessor
3
  import gradio as gr
4
  from PIL import Image
 
 
 
5
  from gradio.themes import Soft
6
  from gradio.themes.utils import colors, fonts, sizes
7
 
8
  import warnings
9
  warnings.filterwarnings(action="ignore")
10
 
 
 
 
 
 
11
# Register a custom orange-red palette on gradio's color registry so the
# theme below can reference it like any built-in hue.
colors.orange_red = colors.Color(
    name="orange_red",
    c50="#FFF0E5",
    c100="#FFE0CC",
    c200="#FFC299",
    c300="#FFA366",
    c400="#FF8533",
    c500="#FF4500",
    c600="#E63E00",
    c700="#CC3700",
    c800="#B33000",
    c900="#992900",
    c950="#802200",
)
17
 
18
class OrangeRedTheme(Soft):
    """Soft-derived Gradio theme tinted with the custom ``orange_red`` palette.

    Uses orange_red for both primary and secondary hues, slate neutrals,
    large text, and Google-hosted Outfit / IBM Plex Mono fonts.
    """

    def __init__(self):
        body_fonts = (fonts.GoogleFont("Outfit"), "Arial", "sans-serif")
        mono_fonts = (fonts.GoogleFont("IBM Plex Mono"), "monospace")
        super().__init__(
            primary_hue=colors.orange_red,
            secondary_hue=colors.orange_red,
            neutral_hue=colors.slate,
            text_size=sizes.text_lg,
            font=body_fonts,
            font_mono=mono_fonts,
        )
        # Gradient page background and primary buttons, plus a thicker
        # block border and a drop shadow on blocks.
        super().set(
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_text_color="white",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
        )


orange_red_theme = OrangeRedTheme()
38
 
39
MODEL_ID = "openai/clip-vit-base-patch32"

# Pick the device first so the load dtype can match it: bfloat16 is a win on
# GPU, but forcing it on CPU (as the previous code did unconditionally) makes
# CLIP inference slow and is poorly supported by some CPU kernels.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32

# Load the CLIP checkpoint with PyTorch's scaled-dot-product attention and
# move it to the selected device in one step.
model = AutoModel.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    attn_implementation="sdpa",
).to(device)
processor = AutoProcessor.from_pretrained(MODEL_ID)
49
 
50
def postprocess_metaclip(probs, labels):
    """Map each candidate label to its probability.

    Reads the first row of ``probs`` (batch of size 1) and pairs it
    positionally with ``labels``, returning ``{label: float}``.
    """
    row = probs[0]
    return {label: row[idx].item() for idx, label in enumerate(labels)}
 
62
  probs = metaclip_detector(image, candidate_labels)
63
  return postprocess_metaclip(probs, labels=candidate_labels)
64
 
65
# Page-level CSS: a centered #container column (wider on large screens via a
# media query) and an enlarged #title heading. The ids are attached to
# components below with elem_id=.
css_style = """
#container {
    max-width: 1280px; /* wider layout */
    margin: auto;
}

@media (min-width: 1600px) {
    #container {
        max-width: 1440px;
    }
}

#title h1 {
    font-size: 2.4em !important;
}
"""
81
 
82
# Build the UI. theme= and css= must be given to gr.Blocks(), not to
# launch() — gradio's launch() accepts neither, so the previous code's
# demo.queue().launch(theme=..., css=...) never applied them. The browser
# title is also corrected: "AI Document Summarizer" was a copy-paste
# leftover unrelated to this zero-shot classification demo.
with gr.Blocks(
    title="Zero-Shot Image Classification",
    theme=orange_red_theme,
    css=css_style,
) as demo:
    with gr.Column(elem_id="container"):

        gr.Markdown("# **Open AI Zero-Shot Classification**", elem_id="title")
        gr.Markdown("This is the demo of model 'openai/clip-vit-base-patch32' for zero-shot classification.")

        with gr.Row(equal_height=True):
            # Left column: inputs; right column: top-5 label probabilities.
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Image", height=310)
                text_input = gr.Textbox(label="Input labels (comma separated)")
                run_button = gr.Button("Run", variant="primary")
            with gr.Column():
                metaclip_output = gr.Label(
                    label="Open AI Zero-Shot Classification Output",
                    num_top_classes=5
                )

        run_button.click(
            fn=infer,
            inputs=[image_input, text_input],
            outputs=[metaclip_output]
        )

if __name__ == "__main__":
    # queue() serializes inference requests; bind on all interfaces for
    # containerized deployment (standard Spaces port 7860).
    demo.queue().launch(
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860,
        debug=True
    )