Tulitula committed on
Commit
cb67003
·
verified ·
1 Parent(s): 169b299

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -97
app.py CHANGED
@@ -2,59 +2,80 @@ import re
2
  import gradio as gr
3
  import torch
4
  from PIL import Image
5
- from transformers import pipeline, AutoProcessor, AutoModelForVision2Seq
6
 
7
  # Auto-detect CPU/GPU
8
  device = 0 if torch.cuda.is_available() else -1
9
 
10
- # 1) BLIP captioner
11
- processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
12
- model = AutoModelForVision2Seq.from_pretrained("Salesforce/blip-image-captioning-large")
13
- caption_pipe = pipeline(
14
- "image-to-text",
15
- model=model,
16
- processor=processor,
17
- device=device
18
- )
19
-
20
- # 2) Flan-T5 for text-to-text
21
- FLAN = "google/flan-t5-large"
22
- category_pipe = pipeline(
23
- "text2text-generation",
24
- model=FLAN,
25
- tokenizer=FLAN,
26
- device=device,
27
- max_new_tokens=32,
28
- do_sample=True,
29
- temperature=1.0,
30
- )
31
- analysis_pipe = pipeline(
32
- "text2text-generation",
33
- model=FLAN,
34
- tokenizer=FLAN,
35
- device=device,
36
- max_new_tokens=256,
37
- do_sample=True,
38
- temperature=1.0,
39
- )
40
- suggestion_pipe = pipeline(
41
- "text2text-generation",
42
- model=FLAN,
43
- tokenizer=FLAN,
44
- device=device,
45
- max_new_tokens=256,
46
- do_sample=True,
47
- temperature=1.0,
48
- )
49
- # Expander when BLIP caption is too short
50
- expansion_pipe = pipeline(
51
- "text2text-generation",
52
- model=FLAN,
53
- tokenizer=FLAN,
54
- device=device,
55
- max_new_tokens=128,
56
- do_sample=False,
57
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # Example gallery helper returns 10 example ad URLs
60
  def get_recommendations():
@@ -71,53 +92,67 @@ def get_recommendations():
71
  "https://i.imgur.com/Xj92Cjv.jpeg",
72
  ]
73
 
74
- # Main processing function
75
  def process(image: Image):
76
- # 1) BLIP caption
77
- caption = caption_pipe(image, max_new_tokens=64, do_sample=False)[0]['generated_text'].strip()
78
-
79
- # 1a) Expand caption if too short
80
- if len(caption.split()) < 3:
81
- desc = expansion_pipe(f"Expand into a detailed description: {caption}")[0]['generated_text'].strip()
82
- else:
83
- desc = caption
84
-
85
- # 2) Ad category
86
- cat_prompt = (
87
- f"Description: {desc}\n\n"
88
- "Provide a concise category label for this ad (e.g. 'Food', 'Fitness'):"
89
- )
90
- category = category_pipe(cat_prompt)[0]['generated_text'].splitlines()[0].strip()
 
 
 
 
 
 
 
91
 
92
- # 3) Five-sentence analysis
93
- ana_prompt = (
94
- f"Description: {desc}\n\n"
95
- "Write exactly five sentences explaining what this ad communicates and its emotional impact."
96
- )
97
- raw_ana = analysis_pipe(ana_prompt)[0]['generated_text'].strip()
98
- sentences = re.split(r'(?<=[.!?])\s+', raw_ana)
99
- analysis = " ".join(sentences[:5])
100
-
101
- # 4) Five bullet-point suggestions
102
- sug_prompt = (
103
- f"Description: {desc}\n\n"
104
- "Suggest five distinct improvements for this ad. Each must start with '- ' and be one sentence."
105
- )
106
- raw_sug = suggestion_pipe(sug_prompt)[0]['generated_text'].strip()
107
- bullets = [l for l in raw_sug.splitlines() if l.startswith('-')]
108
- if len(bullets) < 5:
109
- extra_lines = [l for l in raw_sug.splitlines() if l.strip()]
110
- for line in extra_lines:
111
- if len(bullets) >= 5:
112
- break
113
- bullets.append(line if line.startswith('-') else '- ' + line)
114
- suggestions = '\n'.join(bullets[:5])
115
-
116
- return caption, category, analysis, suggestions, get_recommendations()
117
-
118
- # Gradio UI
 
 
 
 
 
 
 
119
  def main():
120
- with gr.Blocks() as demo:
121
  gr.Markdown("## 📢 Smart Ad Analyzer")
122
  gr.Markdown(
123
  "Upload an image ad to get:\n"
@@ -137,7 +172,7 @@ def main():
137
  sug_out = gr.Textbox(label='Improvement Suggestions', lines=5, interactive=False)
138
  btn = gr.Button('Analyze Ad', size='sm', variant='primary')
139
 
140
- gallery = gr.Gallery(label='Example Ads')
141
 
142
  btn.click(
143
  fn=process,
@@ -147,8 +182,8 @@ def main():
147
 
148
  gr.Markdown('Made by Simon Thalmay')
149
 
150
- demo.launch()
151
 
152
  if __name__ == '__main__':
153
- main()
154
-
 
2
  import gradio as gr
3
  import torch
4
  from PIL import Image
5
+ from transformers import pipeline, AutoProcessor, AutoModelForVision2Seq, AutoTokenizer, AutoModelForSeq2SeqLM
6
 
7
  # Auto-detect CPU/GPU
8
  device = 0 if torch.cuda.is_available() else -1
9
 
10
+ # 1) BLIP captioner - Fixed tokenizer usage
11
+ try:
12
+ processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
13
+ model = AutoModelForVision2Seq.from_pretrained("Salesforce/blip-image-captioning-large")
14
+
15
+ caption_pipe = pipeline(
16
+ "image-to-text",
17
+ model=model,
18
+ feature_extractor=processor.feature_extractor,
19
+ tokenizer=processor.tokenizer,
20
+ device=device
21
+ )
22
+ print("✅ BLIP model loaded successfully")
23
+ except Exception as e:
24
+ print(f"❌ Error loading BLIP model: {e}")
25
+ raise
26
+
27
+ # 2) Flan-T5 for text-to-text - Fixed tokenizer initialization
28
+ FLAN_MODEL = "google/flan-t5-large"
29
+ try:
30
+ # Load tokenizer and model separately for better control
31
+ flan_tokenizer = AutoTokenizer.from_pretrained(FLAN_MODEL)
32
+ flan_model = AutoModelForSeq2SeqLM.from_pretrained(FLAN_MODEL)
33
+
34
+ # Create pipelines with explicit tokenizer
35
+ category_pipe = pipeline(
36
+ "text2text-generation",
37
+ model=flan_model,
38
+ tokenizer=flan_tokenizer,
39
+ device=device,
40
+ max_new_tokens=32,
41
+ do_sample=True,
42
+ temperature=1.0,
43
+ )
44
+
45
+ analysis_pipe = pipeline(
46
+ "text2text-generation",
47
+ model=flan_model,
48
+ tokenizer=flan_tokenizer,
49
+ device=device,
50
+ max_new_tokens=256,
51
+ do_sample=True,
52
+ temperature=1.0,
53
+ )
54
+
55
+ suggestion_pipe = pipeline(
56
+ "text2text-generation",
57
+ model=flan_model,
58
+ tokenizer=flan_tokenizer,
59
+ device=device,
60
+ max_new_tokens=256,
61
+ do_sample=True,
62
+ temperature=1.0,
63
+ )
64
+
65
+ # Expander when BLIP caption is too short
66
+ expansion_pipe = pipeline(
67
+ "text2text-generation",
68
+ model=flan_model,
69
+ tokenizer=flan_tokenizer,
70
+ device=device,
71
+ max_new_tokens=128,
72
+ do_sample=False,
73
+ )
74
+
75
+ print("✅ Flan-T5 model loaded successfully")
76
+ except Exception as e:
77
+ print(f"❌ Error loading Flan-T5 model: {e}")
78
+ raise
79
 
80
  # Example gallery helper returns 10 example ad URLs
81
  def get_recommendations():
 
92
  "https://i.imgur.com/Xj92Cjv.jpeg",
93
  ]
94
 
95
+ # Main processing function with error handling
96
  def process(image: Image):
97
+ try:
98
+ if image is None:
99
+ return "Please upload an image", "", "", "", get_recommendations()
100
+
101
+ # 1) BLIP caption
102
+ caption_result = caption_pipe(image, max_new_tokens=64, do_sample=False)
103
+ caption = caption_result[0]['generated_text'].strip()
104
+
105
+ # 1a) Expand caption if too short
106
+ if len(caption.split()) < 3:
107
+ desc_result = expansion_pipe(f"Expand into a detailed description: {caption}")
108
+ desc = desc_result[0]['generated_text'].strip()
109
+ else:
110
+ desc = caption
111
+
112
+ # 2) Ad category
113
+ cat_prompt = (
114
+ f"Description: {desc}\n\n"
115
+ "Provide a concise category label for this ad (e.g. 'Food', 'Fitness'):"
116
+ )
117
+ category_result = category_pipe(cat_prompt)
118
+ category = category_result[0]['generated_text'].splitlines()[0].strip()
119
 
120
+ # 3) Five-sentence analysis
121
+ ana_prompt = (
122
+ f"Description: {desc}\n\n"
123
+ "Write exactly five sentences explaining what this ad communicates and its emotional impact."
124
+ )
125
+ raw_ana_result = analysis_pipe(ana_prompt)
126
+ raw_ana = raw_ana_result[0]['generated_text'].strip()
127
+ sentences = re.split(r'(?<=[.!?])\s+', raw_ana)
128
+ analysis = " ".join(sentences[:5])
129
+
130
+ # 4) Five bullet-point suggestions
131
+ sug_prompt = (
132
+ f"Description: {desc}\n\n"
133
+ "Suggest five distinct improvements for this ad. Each must start with '- ' and be one sentence."
134
+ )
135
+ raw_sug_result = suggestion_pipe(sug_prompt)
136
+ raw_sug = raw_sug_result[0]['generated_text'].strip()
137
+ bullets = [l for l in raw_sug.splitlines() if l.startswith('-')]
138
+ if len(bullets) < 5:
139
+ extra = [l for l in raw_sug.splitlines() if l.strip()]
140
+ for line in extra:
141
+ if len(bullets) >= 5:
142
+ break
143
+ bullets.append(line if line.startswith('-') else '- ' + line)
144
+ suggestions = '\n'.join(bullets[:5])
145
+
146
+ return caption, category, analysis, suggestions, get_recommendations()
147
+
148
+ except Exception as e:
149
+ error_msg = f"Error processing image: {str(e)}"
150
+ print(error_msg)
151
+ return error_msg, "", "", "", get_recommendations()
152
+
153
+ # Gradio UI definition
154
  def main():
155
+ with gr.Blocks(title="Smart Ad Analyzer") as demo:
156
  gr.Markdown("## 📢 Smart Ad Analyzer")
157
  gr.Markdown(
158
  "Upload an image ad to get:\n"
 
172
  sug_out = gr.Textbox(label='Improvement Suggestions', lines=5, interactive=False)
173
  btn = gr.Button('Analyze Ad', size='sm', variant='primary')
174
 
175
+ gallery = gr.Gallery(label='Example Ads', value=get_recommendations())
176
 
177
  btn.click(
178
  fn=process,
 
182
 
183
  gr.Markdown('Made by Simon Thalmay')
184
 
185
+ return demo
186
 
187
  if __name__ == '__main__':
188
+ demo = main()
189
+ demo.launch()