Tulitula committed on
Commit
f029782
·
verified ·
1 Parent(s): 24c1dc1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -26
app.py CHANGED
@@ -9,11 +9,11 @@ from transformers import (
9
  AutoTokenizer,
10
  AutoModelForSeq2SeqLM,
11
  )
12
- from difflib import SequenceMatcher
13
 
 
14
  DEVICE = 0 if torch.cuda.is_available() else -1
15
 
16
- # Load BLIP
17
  processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
18
  blip_model = AutoModelForVision2Seq.from_pretrained("Salesforce/blip-image-captioning-large")
19
  caption_pipe = pipeline(
@@ -24,7 +24,7 @@ caption_pipe = pipeline(
24
  device=DEVICE,
25
  )
26
 
27
- # Load Flan-T5
28
  FLAN_MODEL = "google/flan-t5-large"
29
  flan_tokenizer = AutoTokenizer.from_pretrained(FLAN_MODEL)
30
  flan_model = AutoModelForSeq2SeqLM.from_pretrained(FLAN_MODEL)
@@ -38,6 +38,7 @@ category_pipe = pipeline(
38
  do_sample=True,
39
  temperature=1.0,
40
  )
 
41
  analysis_pipe = pipeline(
42
  "text2text-generation",
43
  model=flan_model,
@@ -47,6 +48,8 @@ analysis_pipe = pipeline(
47
  do_sample=True,
48
  temperature=1.0,
49
  )
 
 
50
  suggestion_pipe = pipeline(
51
  "text2text-generation",
52
  model=flan_model,
@@ -54,8 +57,9 @@ suggestion_pipe = pipeline(
54
  device=DEVICE,
55
  max_new_tokens=256,
56
  do_sample=True,
57
- temperature=1.05, # Slightly more conservative than before
58
  )
 
59
  expansion_pipe = pipeline(
60
  "text2text-generation",
61
  model=flan_model,
@@ -66,6 +70,7 @@ expansion_pipe = pipeline(
66
  )
67
 
68
  def get_recommendations():
 
69
  return [
70
  "https://i.imgur.com/InC88PP.jpeg",
71
  "https://i.imgur.com/7BHfv4T.png",
@@ -79,53 +84,68 @@ def get_recommendations():
79
  "https://i.imgur.com/Xj92Cjv.jpeg",
80
  ]
81
 
82
# Everything except lowercase letters, digits and spaces is dropped before
# two suggestions are compared.
_NON_ALNUM_SPACE = re.compile(r'[^a-z0-9 ]')


def unique_suggestions(suggestions, threshold=0.91):
    """Remove near-duplicate suggestions while preserving their order.

    Each suggestion is lowercased and stripped of every character other than
    ``a-z``, ``0-9`` and the space; the result is compared with the
    normalised form of every suggestion already kept using
    ``difflib.SequenceMatcher``. A suggestion is kept only when its
    similarity ratio against all previously kept ones is below ``threshold``.

    Args:
        suggestions: Iterable of suggestion strings.
        threshold: Similarity ratio at or above which two suggestions are
            treated as duplicates. Defaults to ``0.91``.

    Returns:
        A new list holding the first occurrence of each distinct suggestion,
        in input order and in its original (un-normalised) form.
    """
    kept = []
    # Normalised forms of ``kept``, cached so each string is normalised once
    # instead of once per comparison.
    kept_norms = []
    for candidate in suggestions:
        norm = _NON_ALNUM_SPACE.sub('', candidate.lower())
        if all(
            SequenceMatcher(None, norm, previous).ratio() < threshold
            for previous in kept_norms
        ):
            kept.append(candidate)
            kept_norms.append(norm)
    return kept
90
-
91
  def process(image: Image):
92
  if image is None:
93
  return "", "", "", get_recommendations()
94
 
 
95
  caption_res = caption_pipe(image, max_new_tokens=64)
96
  raw_caption = caption_res[0]["generated_text"].strip()
97
- desc = raw_caption
98
- if len(desc.split()) < 3:
99
- exp = expansion_pipe(f"Expand into a detailed description: {desc}")
 
100
  desc = exp[0]["generated_text"].strip()
 
 
101
 
102
  # 2. Category
103
- cat_prompt = f"Description: {desc}\n\nProvide a concise category label for this ad (e.g. 'Food', 'Fitness'):"
 
 
 
104
  cat_out = category_pipe(cat_prompt)[0]["generated_text"].splitlines()[0].strip()
105
 
106
  # 3. Five-sentence analysis
107
- ana_prompt = f"Description: {desc}\n\nWrite exactly five sentences explaining what this ad communicates and its emotional impact."
 
 
 
108
  ana_raw = analysis_pipe(ana_prompt)[0]["generated_text"].strip()
109
  sentences = re.split(r'(?<=[.!?])\s+', ana_raw)
110
  analysis = " ".join(sentences[:5])
111
 
112
- # 4. Five improvement suggestions (not forcing uniqueness in prompt)
113
- sug_prompt = f"Description: {desc}\n\nSuggest five ways this ad could be improved. Each suggestion should be one sentence and start with '- '."
 
 
 
 
114
  sug_raw = suggestion_pipe(sug_prompt)[0]["generated_text"].strip()
115
- all_sugs = [line for line in sug_raw.splitlines() if line.strip().startswith("-")]
116
- unique_sugs = unique_suggestions(all_sugs)
117
-
118
- # Default suggestions if model outputs < 5 after filtering
 
 
 
 
 
 
 
 
119
  defaults = [
120
  "- Make the main headline more eye-catching.",
121
  "- Add a clear and visible call-to-action button.",
122
  "- Use contrasting colors for better readability.",
123
  "- Highlight the unique selling point of the product.",
124
- "- Simplify the design to reduce clutter.",
125
  ]
126
  for d in defaults:
127
- if len(unique_sugs) < 5 and d not in unique_sugs:
 
128
  unique_sugs.append(d)
 
129
  suggestions = "\n".join(unique_sugs[:5])
130
 
131
  return cat_out, analysis, suggestions, get_recommendations()
 
9
  AutoTokenizer,
10
  AutoModelForSeq2SeqLM,
11
  )
 
12
 
13
+ # Auto-detect CPU/GPU
14
  DEVICE = 0 if torch.cuda.is_available() else -1
15
 
16
+ # Load BLIP captioning model
17
  processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
18
  blip_model = AutoModelForVision2Seq.from_pretrained("Salesforce/blip-image-captioning-large")
19
  caption_pipe = pipeline(
 
24
  device=DEVICE,
25
  )
26
 
27
+ # Load Flan-T5 for text-to-text
28
  FLAN_MODEL = "google/flan-t5-large"
29
  flan_tokenizer = AutoTokenizer.from_pretrained(FLAN_MODEL)
30
  flan_model = AutoModelForSeq2SeqLM.from_pretrained(FLAN_MODEL)
 
38
  do_sample=True,
39
  temperature=1.0,
40
  )
41
+
42
  analysis_pipe = pipeline(
43
  "text2text-generation",
44
  model=flan_model,
 
48
  do_sample=True,
49
  temperature=1.0,
50
  )
51
+
52
+ # Set higher temperature for more variety
53
  suggestion_pipe = pipeline(
54
  "text2text-generation",
55
  model=flan_model,
 
57
  device=DEVICE,
58
  max_new_tokens=256,
59
  do_sample=True,
60
+ temperature=1.2,
61
  )
62
+
63
  expansion_pipe = pipeline(
64
  "text2text-generation",
65
  model=flan_model,
 
70
  )
71
 
72
  def get_recommendations():
73
+ # Returns list of 10 example ad image URLs
74
  return [
75
  "https://i.imgur.com/InC88PP.jpeg",
76
  "https://i.imgur.com/7BHfv4T.png",
 
84
  "https://i.imgur.com/Xj92Cjv.jpeg",
85
  ]
86
 
 
 
 
 
 
 
 
 
 
87
  def process(image: Image):
88
  if image is None:
89
  return "", "", "", get_recommendations()
90
 
91
+ # 1. BLIP caption
92
  caption_res = caption_pipe(image, max_new_tokens=64)
93
  raw_caption = caption_res[0]["generated_text"].strip()
94
+
95
+ # 1a. Expand caption if too short
96
+ if len(raw_caption.split()) < 3:
97
+ exp = expansion_pipe(f"Expand into a detailed description: {raw_caption}")
98
  desc = exp[0]["generated_text"].strip()
99
+ else:
100
+ desc = raw_caption
101
 
102
  # 2. Category
103
+ cat_prompt = (
104
+ f"Description: {desc}\n\n"
105
+ "Provide a concise category label for this ad (e.g. 'Food', 'Fitness'):"
106
+ )
107
  cat_out = category_pipe(cat_prompt)[0]["generated_text"].splitlines()[0].strip()
108
 
109
  # 3. Five-sentence analysis
110
+ ana_prompt = (
111
+ f"Description: {desc}\n\n"
112
+ "Write exactly five sentences explaining what this ad communicates and its emotional impact."
113
+ )
114
  ana_raw = analysis_pipe(ana_prompt)[0]["generated_text"].strip()
115
  sentences = re.split(r'(?<=[.!?])\s+', ana_raw)
116
  analysis = " ".join(sentences[:5])
117
 
118
+ # 4. Five bullet-point suggestions (uniqueness enforced)
119
+ sug_prompt = (
120
+ f"Description: {desc}\n\n"
121
+ "Suggest five ways this ad could be improved. Each suggestion must be about a different aspect, such as visuals, message, call-to-action, color, clarity, layout, or audience targeting. "
122
+ "Each suggestion must start with '- ' and be one full sentence. Make sure each is different from the others."
123
+ )
124
  sug_raw = suggestion_pipe(sug_prompt)[0]["generated_text"].strip()
125
+ all_sugs = [line.strip() for line in sug_raw.splitlines() if line.strip().startswith("-")]
126
+ unique_sugs = []
127
+ seen = set()
128
+ for line in all_sugs:
129
+ line_clean = line.lower().strip().rstrip(".")
130
+ if line_clean not in seen and len(line_clean) > 4:
131
+ unique_sugs.append(line)
132
+ seen.add(line_clean)
133
+ if len(unique_sugs) == 5:
134
+ break
135
+
136
+ # Add non-repetitive defaults if needed
137
  defaults = [
138
  "- Make the main headline more eye-catching.",
139
  "- Add a clear and visible call-to-action button.",
140
  "- Use contrasting colors for better readability.",
141
  "- Highlight the unique selling point of the product.",
142
+ "- Simplify the design to reduce clutter."
143
  ]
144
  for d in defaults:
145
+ d_clean = d.lower().strip().rstrip(".")
146
+ if len(unique_sugs) < 5 and d_clean not in seen:
147
  unique_sugs.append(d)
148
+ seen.add(d_clean)
149
  suggestions = "\n".join(unique_sugs[:5])
150
 
151
  return cat_out, analysis, suggestions, get_recommendations()