Tulitula committed on
Commit
8f44c02
·
verified ·
1 Parent(s): b281a55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -146
app.py CHANGED
@@ -1,77 +1,18 @@
1
- import re
2
  import gradio as gr
3
- import torch
4
  from PIL import Image
5
- from transformers import (
6
- pipeline,
7
- AutoProcessor,
8
- AutoModelForVision2Seq,
9
- AutoTokenizer,
10
- AutoModelForSeq2SeqLM,
11
- )
12
-
13
# Pick GPU device 0 when CUDA is available, otherwise fall back to CPU (-1).
DEVICE = 0 if torch.cuda.is_available() else -1

# BLIP image-captioning model: turns an uploaded ad image into a text caption.
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = AutoModelForVision2Seq.from_pretrained("Salesforce/blip-image-captioning-large")
caption_pipe = pipeline(
    task="image-to-text",
    model=blip_model,
    tokenizer=processor.tokenizer,
    image_processor=processor.image_processor,
    device=DEVICE,
)

# Flan-T5 model shared by every text-to-text pipeline below.
FLAN_MODEL = "google/flan-t5-large"
flan_tokenizer = AutoTokenizer.from_pretrained(FLAN_MODEL)
flan_model = AutoModelForSeq2SeqLM.from_pretrained(FLAN_MODEL)


def _flan_pipeline(**generation_kwargs):
    # Build one text2text pipeline over the shared Flan-T5 model; only the
    # generation settings differ between the pipelines defined below.
    return pipeline(
        "text2text-generation",
        model=flan_model,
        tokenizer=flan_tokenizer,
        device=DEVICE,
        do_sample=True,
        **generation_kwargs,
    )


# Short label generation.
category_pipe = _flan_pipeline(max_new_tokens=32, temperature=1.0)
# Longer free-form analysis.
analysis_pipe = _flan_pipeline(max_new_tokens=256, temperature=1.0)
# Higher temperature plus nucleus sampling for more varied suggestions.
suggestion_pipe = _flan_pipeline(max_new_tokens=256, temperature=1.6, top_p=0.95)
# Expands captions that are too short to analyze.
expansion_pipe = _flan_pipeline(max_new_tokens=128, temperature=1.0)
72
 
73
  def get_recommendations():
74
- # Returns list of 10 example ad image URLs
75
  return [
76
  "https://i.imgur.com/InC88PP.jpeg",
77
  "https://i.imgur.com/7BHfv4T.png",
@@ -85,88 +26,82 @@ def get_recommendations():
85
  "https://i.imgur.com/Xj92Cjv.jpeg",
86
  ]
87
 
88
def process(image: Image.Image):
    """Analyze an uploaded ad image and return UI-ready fields.

    Pipeline: BLIP caption -> (optional) caption expansion -> Flan-T5
    category / analysis / suggestion generation.

    Returns a 4-tuple (category, analysis, suggestions, gallery_urls);
    the three text fields are empty strings when no image is supplied.
    """
    if image is None:
        # Nothing to analyze: only populate the inspiration gallery.
        return "", "", "", get_recommendations()

    # 1. BLIP caption of the raw image.
    caption_res = caption_pipe(image, max_new_tokens=64)
    raw_caption = caption_res[0]["generated_text"].strip()

    # 1a. Expand caption if too short — fewer than 3 words gives the text
    # model too little context to work with.
    if len(raw_caption.split()) < 3:
        exp = expansion_pipe(f"Expand into a detailed description: {raw_caption}")
        desc = exp[0]["generated_text"].strip()
    else:
        desc = raw_caption

    # 2. Category — only the first output line is kept as the label.
    cat_prompt = (
        f"Description: {desc}\n\n"
        "Provide a concise category label for this ad (e.g. 'Food', 'Fitness'):"
    )
    cat_out = category_pipe(cat_prompt)[0]["generated_text"].splitlines()[0].strip()

    # 3. Five-sentence analysis — the model may overshoot, so the output is
    # split on sentence boundaries and trimmed to the first five.
    ana_prompt = (
        f"Description: {desc}\n\n"
        "Write exactly five sentences explaining what this ad communicates and its emotional impact."
    )
    ana_raw = analysis_pipe(ana_prompt)[0]["generated_text"].strip()
    sentences = re.split(r'(?<=[.!?])\s+', ana_raw)
    analysis = " ".join(sentences[:5])

    # 4. Five improvement suggestions (model + fallback defaults, no repeats).
    sug_prompt = (
        f"Ad description: {desc}\n"
        f"Ad analysis: {analysis}\n\n"
        "Suggest five ways to improve this ad. Write each suggestion as one practical sentence starting with '- '."
    )
    sug_raw = suggestion_pipe(sug_prompt)[0]["generated_text"].strip()
    # Keep only bullet lines; anything else the model emitted is ignored.
    all_sugs = [line.strip() for line in sug_raw.splitlines() if line.strip().startswith("-")]

    # Filter exact duplicates, keep order, allow model output first.
    # Normalization lowercases and strips trailing punctuation so that
    # near-identical bullets are treated as the same suggestion.
    unique_sugs = []
    seen = set()
    for s in all_sugs:
        norm = s.lower().strip(".:; ")
        if norm not in seen and len(norm) > 4:
            unique_sugs.append(s)
            seen.add(norm)
        if len(unique_sugs) == 5:
            break
    # Add default suggestions only if the model produced fewer than five.
    defaults = [
        "- Make the main headline more eye-catching.",
        "- Add a clear and visible call-to-action button.",
        "- Use contrasting colors for better readability.",
        "- Highlight the unique selling point of the product.",
        "- Simplify the design to reduce clutter."
    ]
    for d in defaults:
        norm = d.lower().strip(".:; ")
        if len(unique_sugs) < 5 and norm not in seen:
            unique_sugs.append(d)
            seen.add(norm)
    suggestions = "\n".join(unique_sugs[:5])

    return cat_out, analysis, suggestions, get_recommendations()
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  def main():
156
- with gr.Blocks(title="Smart Ad Analyzer") as demo:
157
- gr.Markdown("## 📢 Smart Ad Analyzer")
158
  gr.Markdown(
159
- """
160
- **Upload your ad image below and instantly get expert feedback.**
161
-
162
- This AI tool will analyze your ad and provide:
163
- - 📂 **Category** — What type of ad is this?
164
- - 📊 **In-depth Analysis** — Five detailed sentences covering message, visuals, emotional impact, and more.
165
- - 🚀 **Improvement Suggestions** — Five actionable, unique ways to make your ad better.
166
- - 📸 **Inspiration Gallery** — See other effective ads for ideas.
167
-
168
- Perfect for marketers, founders, designers, and anyone looking to boost ad performance with actionable insights!
169
- """
170
  )
171
  with gr.Row():
172
  inp = gr.Image(type='pil', label='Upload Ad Image')
@@ -181,7 +116,7 @@ def main():
181
  inputs=[inp],
182
  outputs=[cat_out, ana_out, sug_out, gallery],
183
  )
184
- gr.Markdown('Made by Simon Thalmay')
185
  return demo
186
 
187
  if __name__ == "__main__":
 
1
import base64
import io
import os
import tempfile

import gradio as gr
from huggingface_hub import InferenceClient
from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
# ---- Gemma-3 setup ----
# Inference client bound to the Gemma-3 chat model. The token is read from
# the environment; os.environ.get returns None when HF_TOKEN is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

client = InferenceClient(
    model="google/gemma-3-4b-it",
    api_key=HF_TOKEN,
    provider="featherless-ai",  # or "huggingface_hub"
)
13
 
14
  def get_recommendations():
15
+ # As before: returns list of 10 example ad image URLs
16
  return [
17
  "https://i.imgur.com/InC88PP.jpeg",
18
  "https://i.imgur.com/7BHfv4T.png",
 
26
  "https://i.imgur.com/Xj92Cjv.jpeg",
27
  ]
28
 
29
def gemma_image_analysis(image: Image.Image) -> str:
    """Ask Gemma-3 for a category / analysis / suggestions write-up of an ad image.

    The PIL image is serialized to PNG in memory and embedded in the chat
    request as a base64 ``data:`` URL, which the chat-completion
    ``image_url`` content type accepts. This replaces the previous
    tempfile-based approach, which leaked the temp file (``delete=False``,
    never removed) and called ``client.upload()`` — a method
    ``InferenceClient`` does not provide.

    Returns the model's raw text reply (sections: Category / Analysis /
    Suggestions, as requested in the prompt).
    """
    # Encode the image once, in memory — no filesystem round-trip needed.
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    image_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode("ascii")

    prompt = (
        "You are an expert ad analyst. "
        "Please give a short category for this ad, a detailed analysis of its message, visuals, and emotional impact in five sentences, "
        "and five unique, actionable improvement suggestions (as bullet points), each addressing a different aspect (visuals, message, call-to-action, targeting, or layout). "
        "Output should have clear sections: 'Category', 'Analysis', and 'Suggestions'."
    )
    messages = [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a helpful assistant."}]
        },
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": image_url}},
                {"type": "text", "text": prompt}
            ]
        }
    ]
    # The client is already bound to this model; passing it again is
    # redundant but harmless, and keeps the call self-describing.
    result = client.chat.completions.create(
        model="google/gemma-3-4b-it",
        messages=messages,
        max_tokens=500,
    )
    # NOTE(review): huggingface_hub response messages appear to support
    # dict-style access here; `.content` attribute access should be
    # equivalent — confirm against the installed huggingface_hub version.
    return result.choices[0].message["content"]
 
60
 
61
def process(image):
    """Analyze an uploaded ad image with Gemma and return UI-ready fields.

    Returns a 4-tuple (category, analysis, suggestions, gallery_urls).
    When no image is supplied, the text fields are empty and only the
    inspiration gallery is populated.
    """
    if image is None:
        return "", "", "", get_recommendations()
    # One model call returns all three sections in a single string.
    full_output = gemma_image_analysis(image)
    category, analysis, suggestions = _split_sections(full_output)
    # If the model ignored the requested section headings entirely, show the
    # whole reply as the analysis rather than silently dropping it.
    if not (category or analysis or suggestions):
        analysis = full_output.strip()
    return category, analysis, suggestions, get_recommendations()


def _split_sections(text):
    """Split Gemma's reply into (category, analysis, suggestions) strings.

    Fixes two parser defects: headings decorated with markdown (e.g.
    '**Category:**' or '## Analysis') were not recognized, and a value
    written on the same line as its heading ('Category: Food') was
    discarded. Returns three stripped strings; all empty if no heading
    matched.
    """
    buckets = {"cat": [], "ana": [], "sug": []}
    headings = (("category", "cat"), ("analysis", "ana"), ("suggestion", "sug"))
    section = None
    for raw_line in text.splitlines():
        line = raw_line.strip()
        # Strip markdown emphasis/heading characters before matching.
        head = line.lstrip("#* ").lower()
        hit = next((key for prefix, key in headings if head.startswith(prefix)), None)
        if hit is not None:
            section = hit
            buckets[section] = []  # a repeated heading restarts its section
            # Keep any text that follows the heading on the same line.
            rest = line.partition(":")[2].strip(" *")
            if rest:
                buckets[section].append(rest)
        elif section is not None:
            buckets[section].append(line)
    return tuple("\n".join(buckets[key]).strip() for key in ("cat", "ana", "sug"))
97
+
98
+ # ---- Gradio UI ----
99
  def main():
100
+ with gr.Blocks(title="Smart Ad Analyzer (Gemma-powered)") as demo:
101
+ gr.Markdown("## 📢 Smart Ad Analyzer (Gemma-3 Edition)")
102
  gr.Markdown(
103
+ "**Upload your ad image below and instantly get expert feedback.**<br>"
104
+ "Category, analysis, improvement suggestions—and example ads for inspiration."
 
 
 
 
 
 
 
 
 
105
  )
106
  with gr.Row():
107
  inp = gr.Image(type='pil', label='Upload Ad Image')
 
116
  inputs=[inp],
117
  outputs=[cat_out, ana_out, sug_out, gallery],
118
  )
119
+ gr.Markdown('Made by Simon Thalmay • Powered by google/gemma-3-4b-it')
120
  return demo
121
 
122
  if __name__ == "__main__":