Tulitula committed on
Commit
7cef81c
·
verified ·
1 Parent(s): 8eb7728

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -30
app.py CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
3
  from PIL import Image
4
  from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
5
 
6
- # Initialize BLIP for image captioning (slow mode avoids torchvision dependency)
7
  blip_processor = BlipProcessor.from_pretrained(
8
  "Salesforce/blip-image-captioning-base",
9
  use_fast=False
@@ -12,20 +12,24 @@ blip_model = BlipForConditionalGeneration.from_pretrained(
12
  "Salesforce/blip-image-captioning-base"
13
  )
14
 
15
- # Flan-T5-small pipelines (temperature=1 for diversity, max_new_tokens increased for depth)
16
- gen_pipeline = lambda model_name, tokens: pipeline(
17
- "text2text-generation",
18
- model=model_name,
19
- tokenizer=model_name,
20
- max_new_tokens=tokens,
21
- do_sample=True,
22
- temperature=1.0
23
- )
24
- category_generator = gen_pipeline("google/flan-t5-small", 100)
25
- analysis_generator = gen_pipeline("google/flan-t5-small", 500)
26
- suggestion_generator = gen_pipeline("google/flan-t5-small", 500)
 
 
 
 
 
27
 
28
- # Example URLs for gallery
29
  def get_recommendations():
30
  return [
31
  "https://i.imgur.com/InC88PP.jpeg",
@@ -40,26 +44,23 @@ def get_recommendations():
40
  "https://i.imgur.com/Xj92Cjv.jpeg",
41
  ]
42
 
43
- # Step 1: BLIP caption from image
44
  def generate_caption(image):
45
  inputs = blip_processor(images=image, return_tensors="pt")
46
  outputs = blip_model.generate(**inputs)
47
  return blip_processor.decode(outputs[0], skip_special_tokens=True)
48
 
49
- # Step 2: Flan interprets caption into a category label
50
  def generate_category(caption):
51
- prompt = (
52
- f"Caption: {caption}\n"
53
- "Provide a concise category label for this ad (e.g. 'Food Ad', 'Fitness Promotion')."
54
- )
55
  raw = category_generator(prompt)[0]["generated_text"].strip()
56
  return raw.splitlines()[0]
57
 
58
- # Step 3: Flan produces a five-sentence analysis of the caption
59
  def generate_analysis(caption):
60
  prompt = (
61
  f"Caption: {caption}\n"
62
- "Write exactly five sentences explaining what the ad conveys, its core message, and its emotional impact."
63
  )
64
  raw = analysis_generator(prompt)[0]["generated_text"].strip()
65
  sentences = re.split(r'(?<=[.!?])\s+', raw)
@@ -69,18 +70,18 @@ def generate_analysis(caption):
69
  def generate_suggestions(caption):
70
  prompt = (
71
  f"Caption: {caption}\n"
72
- "Suggest five distinct improvements for this ad as a bullet list. "
73
- "Each line must start with '- ' and describe one actionable change."
74
  )
75
  raw = suggestion_generator(prompt)[0]["generated_text"].strip()
76
  lines = [line for line in raw.splitlines() if line.strip().startswith('- ')]
77
- # ensure exactly five bullets
78
  if len(lines) < 5:
79
- fallback = [line for line in raw.splitlines() if line.strip()]
80
- lines = ['- ' + fallback[i] if not fallback[i].startswith('- ') else fallback[i] for i in range(min(5, len(fallback)))]
81
  return "\n".join(lines[:5])
82
 
83
- # Full workflow
 
84
  def process(image):
85
  caption = generate_caption(image)
86
  category = generate_category(caption)
@@ -89,11 +90,11 @@ def process(image):
89
  recs = get_recommendations()
90
  return category, analysis, suggestions, recs
91
 
92
- # Gradio UI
93
  with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as demo:
94
  gr.Markdown("## 📢 Smart Ad Analyzer")
95
  gr.Markdown(
96
- "Upload an image ad to see: a category, five-sentence analysis, five bullet-point improvements, and example ads."
97
  )
98
 
99
  with gr.Row():
 
3
  from PIL import Image
4
  from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
5
 
6
+ # Load BLIP for image captioning (slow processor, no torchvision dependency)
7
  blip_processor = BlipProcessor.from_pretrained(
8
  "Salesforce/blip-image-captioning-base",
9
  use_fast=False
 
12
  "Salesforce/blip-image-captioning-base"
13
  )
14
 
15
+ # Helper to create Flan-T5 pipelines (temperature=1.0 for diversity)
16
def make_pipeline(model_name, max_tokens):
    """Build a text2text-generation pipeline for the given model.

    Sampling is enabled (do_sample=True, temperature=1.0) so repeated
    calls can produce varied outputs; ``max_tokens`` caps the number of
    newly generated tokens.
    """
    generation_kwargs = {
        "model": model_name,
        "tokenizer": model_name,
        "max_new_tokens": max_tokens,
        "do_sample": True,
        "temperature": 1.0,
    }
    return pipeline("text2text-generation", **generation_kwargs)
25
+
26
+ # Pipelines: category, analysis, suggestions
27
+ category_generator = make_pipeline("google/flan-t5-small", 100)
28
+ analysis_generator = make_pipeline("google/flan-t5-small", 500)
29
+ suggestion_generator = make_pipeline("google/flan-t5-small", 500)
30
+
31
+ # Example ads URLs for gallery
32
 
 
33
  def get_recommendations():
34
  return [
35
  "https://i.imgur.com/InC88PP.jpeg",
 
44
  "https://i.imgur.com/Xj92Cjv.jpeg",
45
  ]
46
 
47
+ # Step 1: BLIP generates a caption from the image
48
def generate_caption(image):
    """Run BLIP on *image* and return the decoded caption string."""
    encoded = blip_processor(images=image, return_tensors="pt")
    generated_ids = blip_model.generate(**encoded)
    caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
    return caption
52
 
53
+ # Step 2: Flan interprets caption into a concise category label
54
def generate_category(caption):
    """Derive a concise ad-category label from a BLIP caption.

    Args:
        caption: Image caption text produced by ``generate_caption``.

    Returns:
        The first line of the Flan-T5 output, or a generic fallback
        label when the model returns only whitespace. The guard matters
        because ``"".splitlines()`` is ``[]``, so indexing ``[0]`` on an
        empty generation would raise IndexError.
    """
    prompt = f"Caption: {caption}\nProvide a concise category label for this ad."
    raw = category_generator(prompt)[0]["generated_text"].strip()
    lines = raw.splitlines()
    # Keep only the first line so multi-line model output cannot leak
    # extra text into the category field shown in the UI.
    return lines[0] if lines else "Uncategorized Ad"
58
 
59
+ # Step 3: Flan writes exactly five sentences of analysis
60
  def generate_analysis(caption):
61
  prompt = (
62
  f"Caption: {caption}\n"
63
+ "In exactly five sentences, explain what this ad communicates and its emotional impact."
64
  )
65
  raw = analysis_generator(prompt)[0]["generated_text"].strip()
66
  sentences = re.split(r'(?<=[.!?])\s+', raw)
 
70
def generate_suggestions(caption):
    """Ask Flan-T5 for up to five bullet-point improvements for the ad.

    Keeps only output lines already formatted as '- ' bullets. If fewer
    than five such lines exist, falls back to taking up to five
    non-empty output lines and prefixing any that lack the '- ' marker.
    Returns at most five bullets joined with newlines.
    """
    prompt = (
        f"Caption: {caption}\n"
        "Suggest five distinct improvements as bullet points. "
        "Each line must start with '- '."
    )
    raw = suggestion_generator(prompt)[0]["generated_text"].strip()
    bullets = [ln for ln in raw.splitlines() if ln.strip().startswith('- ')]
    if len(bullets) < 5:
        # Fallback: rebuild the list from every non-empty line instead.
        candidates = [ln.strip() for ln in raw.splitlines() if ln.strip()]
        bullets = []
        for text in candidates[:5]:
            bullets.append(text if text.startswith('- ') else '- ' + text)
    return "\n".join(bullets[:5])
82
 
83
+ # Combine steps into one process
84
+
85
  def process(image):
86
  caption = generate_caption(image)
87
  category = generate_category(caption)
 
90
  recs = get_recommendations()
91
  return category, analysis, suggestions, recs
92
 
93
+ # Gradio UI layout
94
  with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as demo:
95
  gr.Markdown("## 📢 Smart Ad Analyzer")
96
  gr.Markdown(
97
+ "Upload an image ad to see an Ad Category, a five-sentence Analysis, five bullet-point Suggestions, and Example Ads."
98
  )
99
 
100
  with gr.Row():