Tulitula committed on
Commit
8eb7728
·
verified ·
1 Parent(s): f002c57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -37
app.py CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
3
  from PIL import Image
4
  from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
5
 
6
- # Load BLIP for image captioning (slow processor avoids torchvision dependency)
7
  blip_processor = BlipProcessor.from_pretrained(
8
  "Salesforce/blip-image-captioning-base",
9
  use_fast=False
@@ -12,33 +12,18 @@ blip_model = BlipForConditionalGeneration.from_pretrained(
12
  "Salesforce/blip-image-captioning-base"
13
  )
14
 
15
- # Hugging Face pipelines (all using Flan-T5-small for speed, temperature=1.0)
16
- category_generator = pipeline(
17
  "text2text-generation",
18
- model="google/flan-t5-small",
19
- tokenizer="google/flan-t5-small",
20
- max_new_tokens=50,
21
- do_sample=True,
22
- temperature=1.0
23
- )
24
-
25
- analysis_generator = pipeline(
26
- "text2text-generation",
27
- model="google/flan-t5-small",
28
- tokenizer="google/flan-t5-small",
29
- max_new_tokens=500,
30
- do_sample=True,
31
- temperature=1.0
32
- )
33
-
34
- suggestion_generator = pipeline(
35
- "text2text-generation",
36
- model="google/flan-t5-small",
37
- tokenizer="google/flan-t5-small",
38
- max_new_tokens=500,
39
  do_sample=True,
40
  temperature=1.0
41
  )
 
 
 
42
 
43
  # Example URLs for gallery
44
  def get_recommendations():
@@ -55,42 +40,47 @@ def get_recommendations():
55
  "https://i.imgur.com/Xj92Cjv.jpeg",
56
  ]
57
 
58
- # Generate BLIP caption from image
59
  def generate_caption(image):
60
  inputs = blip_processor(images=image, return_tensors="pt")
61
  outputs = blip_model.generate(**inputs)
62
  return blip_processor.decode(outputs[0], skip_special_tokens=True)
63
 
64
- # Generate concise category via Flan
65
  def generate_category(caption):
66
  prompt = (
67
  f"Caption: {caption}\n"
68
- "Provide a concise category label for this ad."
69
  )
70
  raw = category_generator(prompt)[0]["generated_text"].strip()
71
  return raw.splitlines()[0]
72
 
73
- # Produce 5-sentence analysis via Flan
74
  def generate_analysis(caption):
75
  prompt = (
76
  f"Caption: {caption}\n"
77
- "Write exactly five sentences explaining what this ad conveys and its key message."
78
  )
79
  raw = analysis_generator(prompt)[0]["generated_text"].strip()
80
  sentences = re.split(r'(?<=[.!?])\s+', raw)
81
  return " ".join(sentences[:5])
82
 
83
- # Suggest 5 bullet-point improvements via Flan
84
  def generate_suggestions(caption):
85
  prompt = (
86
  f"Caption: {caption}\n"
87
- "Suggest five distinct improvement points for this ad, formatted as a bullet list starting each line with '- '."
 
88
  )
89
  raw = suggestion_generator(prompt)[0]["generated_text"].strip()
90
- lines = [l for l in raw.splitlines() if l.strip().startswith('-')]
91
- return "\n".join(lines[:5]) if lines else "\n".join(raw.splitlines()[:5])
92
-
93
- # Full pipeline combining all steps
 
 
 
 
94
  def process(image):
95
  caption = generate_caption(image)
96
  category = generate_category(caption)
@@ -99,11 +89,11 @@ def process(image):
99
  recs = get_recommendations()
100
  return category, analysis, suggestions, recs
101
 
102
- # UI Layout using Gradio
103
  with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as demo:
104
  gr.Markdown("## 📢 Smart Ad Analyzer")
105
  gr.Markdown(
106
- "Upload an image ad to see: an Ad Category label, a five-sentence analysis, five bullet-point improvements, and example ads."
107
  )
108
 
109
  with gr.Row():
 
3
  from PIL import Image
4
  from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
5
 
6
+ # Initialize BLIP for image captioning (slow mode avoids torchvision dependency)
7
  blip_processor = BlipProcessor.from_pretrained(
8
  "Salesforce/blip-image-captioning-base",
9
  use_fast=False
 
12
  "Salesforce/blip-image-captioning-base"
13
  )
14
 
15
# Flan-T5-small pipelines (temperature=1 for diversity, max_new_tokens increased for depth)
def gen_pipeline(model_name, tokens):
    """Build a sampling text2text-generation pipeline.

    Args:
        model_name: Hugging Face model id, used for both model and tokenizer.
        tokens: cap on newly generated tokens (max_new_tokens).

    Returns:
        A transformers "text2text-generation" pipeline with sampling enabled
        (do_sample=True, temperature=1.0).
    """
    # A proper def instead of a lambda assignment (PEP 8 E731): it can carry
    # a docstring and shows a real name in tracebacks.
    return pipeline(
        "text2text-generation",
        model=model_name,
        tokenizer=model_name,
        max_new_tokens=tokens,
        do_sample=True,
        temperature=1.0,
    )


category_generator = gen_pipeline("google/flan-t5-small", 100)
analysis_generator = gen_pipeline("google/flan-t5-small", 500)
suggestion_generator = gen_pipeline("google/flan-t5-small", 500)
27
 
28
  # Example URLs for gallery
29
  def get_recommendations():
 
40
  "https://i.imgur.com/Xj92Cjv.jpeg",
41
  ]
42
 
43
# Step 1: BLIP caption from image
def generate_caption(image):
    """Run BLIP over a PIL image and return the decoded caption string."""
    encoding = blip_processor(images=image, return_tensors="pt")
    generated_ids = blip_model.generate(**encoding)
    return blip_processor.decode(generated_ids[0], skip_special_tokens=True)
48
 
49
# Step 2: Flan interprets caption into a category label
def generate_category(caption):
    """Ask Flan for a one-line category label for the captioned ad.

    Args:
        caption: BLIP-generated caption text for the uploaded ad image.

    Returns:
        The first non-blank line of the model output, or "Uncategorized"
        when the model produces no usable text. (The original
        ``raw.splitlines()[0]`` raised IndexError on an empty generation.)
    """
    prompt = (
        f"Caption: {caption}\n"
        "Provide a concise category label for this ad (e.g. 'Food Ad', 'Fitness Promotion')."
    )
    raw = category_generator(prompt)[0]["generated_text"].strip()
    for line in raw.splitlines():
        if line.strip():
            return line
    return "Uncategorized"
57
 
58
# Step 3: Flan produces a five-sentence analysis of the caption
def generate_analysis(caption):
    """Generate an analysis of the ad caption, trimmed to at most five sentences."""
    prompt = (
        f"Caption: {caption}\n"
        "Write exactly five sentences explaining what the ad conveys, its core message, and its emotional impact."
    )
    generated = analysis_generator(prompt)[0]["generated_text"]
    text = generated.strip()
    # Split after sentence-ending punctuation, then keep the first five pieces.
    first_five = re.split(r'(?<=[.!?])\s+', text)[:5]
    return " ".join(first_five)
67
 
68
# Step 4: Flan suggests five bullet-point improvements
def generate_suggestions(caption):
    """Ask Flan for up to five "- " bullet-point improvement suggestions.

    Args:
        caption: BLIP-generated caption text for the uploaded ad image.

    Returns:
        At most five newline-joined bullet lines, each starting with "- "
        (empty string if the model produced no text). Prefers lines the
        model already bulleted; otherwise bulletizes the first non-empty
        output lines.

    Fix vs. original fallback: lines are stripped before the "- " check,
    so an indented bullet such as "  - x" is no longer double-prefixed
    into "-   - x".
    """
    prompt = (
        f"Caption: {caption}\n"
        "Suggest five distinct improvements for this ad as a bullet list. "
        "Each line must start with '- ' and describe one actionable change."
    )
    raw = suggestion_generator(prompt)[0]["generated_text"].strip()
    bullets = [line.strip() for line in raw.splitlines() if line.strip().startswith('- ')]
    if len(bullets) < 5:
        # Fallback: treat every non-empty output line as one suggestion.
        cleaned = [line.strip() for line in raw.splitlines() if line.strip()]
        bullets = [line if line.startswith('- ') else '- ' + line for line in cleaned[:5]]
    return "\n".join(bullets[:5])
82
+
83
+ # Full workflow
84
  def process(image):
85
  caption = generate_caption(image)
86
  category = generate_category(caption)
 
89
  recs = get_recommendations()
90
  return category, analysis, suggestions, recs
91
 
92
+ # Gradio UI
93
  with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as demo:
94
  gr.Markdown("## 📢 Smart Ad Analyzer")
95
  gr.Markdown(
96
+ "Upload an image ad to see: a category, five-sentence analysis, five bullet-point improvements, and example ads."
97
  )
98
 
99
  with gr.Row():