# Add-feedback / app.py
# Author: Tulitula — "Update app.py" (commit 93e5528, verified)
import os
import gradio as gr
import torch
from PIL import Image
from transformers import pipeline, AutoProcessor, AutoModelForVision2Seq
# Use HF_TOKEN from environment for private models if needed (can add below if your Gemma is gated)
HF_TOKEN = os.environ.get("HF_TOKEN")

# Auto-detect device: transformers pipelines take a device index (0 = first CUDA GPU, -1 = CPU).
DEVICE = 0 if torch.cuda.is_available() else -1

# Load BLIP for captioning. Processor and model are loaded separately so the
# pipeline below can be given the tokenizer and image processor explicitly.
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = AutoModelForVision2Seq.from_pretrained("Salesforce/blip-image-captioning-large")

# image-to-text pipeline used by process() to caption the uploaded ad image.
caption_pipe = pipeline(
    "image-to-text",
    model=blip_model,
    tokenizer=processor.tokenizer,
    image_processor=processor.image_processor,
    device=DEVICE,
)

# Load Gemma for text generation (pick your Gemma checkpoint here).
# Used by process() for category, analysis, and suggestion prompts.
gemma_pipe = pipeline(
    "text-generation",
    model="google/gemma-2b-it",  # Change this to any working Gemma instruct model!
    device=DEVICE,
    # token=HF_TOKEN  # Uncomment if your Gemma model requires a token
)
def get_recommendations():
    """Return the fixed list of example ad image URLs shown in the gallery."""
    image_names = [
        "InC88PP.jpeg",
        "7BHfv4T.png",
        "wp3Wzc4.jpeg",
        "5e2xOA4.jpeg",
        "txjRk98.jpeg",
        "rQ4AYl0.jpeg",
        "bDzwD04.jpeg",
        "fLMngXI.jpeg",
        "nYEJzxt.png",
        "Xj92Cjv.jpeg",
    ]
    return [f"https://i.imgur.com/{name}" for name in image_names]
def clean_output(text):
    """Strip echoed prompt markers from a model response.

    Gemma sometimes echoes the prompt; keep only the text after the last
    occurrence of each known marker, then trim surrounding whitespace.
    """
    for marker in ("Description:", "Category:"):
        if marker in text:
            text = text.split(marker, 1)[-1]
    return text.strip()
def process(image: Image.Image):
    """Analyze an uploaded ad image and return feedback for the UI.

    Pipeline: BLIP captions the image, then Gemma is prompted three times
    (category, analysis, suggestions) based on that caption.

    Args:
        image: The uploaded ad image, or None when nothing was uploaded.

    Returns:
        A 4-tuple (category, analysis, suggestions, gallery_urls) matching
        the Gradio outputs [cat_out, ana_out, sug_out, gallery].
    """
    if image is None:
        # Nothing uploaded: empty feedback, but still populate the gallery.
        return "", "", "", get_recommendations()

    # 1. BLIP captioning — the caption drives all three Gemma prompts.
    caption_res = caption_pipe(image, max_new_tokens=64)
    desc = caption_res[0]["generated_text"].strip()

    # return_full_text=False keeps the echoed prompt (including the BLIP
    # description) out of the generated text; clean_output() remains as a
    # defensive strip in case the model still repeats the markers.

    # 2. Gemma: Category
    cat_prompt = f"Classify the following ad in one or two words. Description: {desc}"
    cat_out = gemma_pipe(cat_prompt, max_new_tokens=16, return_full_text=False)[0]['generated_text'].strip()
    cat_out = clean_output(cat_out)

    # 3. Gemma: Analysis (5 sentences)
    ana_prompt = (
        f"Describe in exactly five sentences what this ad communicates and its emotional impact. Description: {desc}"
    )
    ana_out = gemma_pipe(ana_prompt, max_new_tokens=120, return_full_text=False)[0]['generated_text'].strip()
    ana_out = clean_output(ana_out)

    # 4. Gemma: Suggestions (5 bullets)
    sug_prompt = (
        f"Suggest five practical improvements for this ad. Each suggestion must be unique, address a different aspect (message, visuals, call to action, targeting, or layout), start with '- ', and be one sentence. Description: {desc}"
    )
    sug_out = gemma_pipe(sug_prompt, max_new_tokens=120, return_full_text=False)[0]['generated_text'].strip()
    sug_out = clean_output(sug_out)

    # Keep at most five lines that look like bullets; fall back to the raw
    # text when the model ignored the '- ' format entirely.
    sug_lines = [line for line in sug_out.splitlines() if line.strip().startswith('-')]
    suggestions = "\n".join(sug_lines[:5]) if sug_lines else sug_out

    return cat_out, ana_out, suggestions, get_recommendations()
def main():
    """Assemble the Gradio Blocks interface and return it (not yet launched)."""
    with gr.Blocks(title="Smart Ad Analyzer (BLIP + Gemma)") as app:
        # Header and intro copy.
        gr.Markdown("## 📢 Smart Ad Analyzer (BLIP + Gemma)")
        gr.Markdown(
            """
Upload your ad image below and instantly get expert feedback.
Category, analysis, improvement suggestions—and example ads for inspiration.
"""
        )
        # Input image on the left, the three result fields stacked on the right.
        with gr.Row():
            image_input = gr.Image(type='pil', label='Upload Ad Image')
            with gr.Column():
                category_box = gr.Textbox(label='Ad Category', interactive=False)
                analysis_box = gr.Textbox(label='Ad Analysis', lines=5, interactive=False)
                suggestions_box = gr.Textbox(label='Improvement Suggestions', lines=5, interactive=False)
        analyze_btn = gr.Button('Analyze Ad', variant='primary')
        examples_gallery = gr.Gallery(label='Example Ads')
        # Wire the button to process(); output order must match its return tuple.
        analyze_btn.click(
            fn=process,
            inputs=[image_input],
            outputs=[category_box, analysis_box, suggestions_box, examples_gallery],
        )
        gr.Markdown('Made by Simon Thalmay')
    return app
# Script entry point: build the UI and start the Gradio server.
# NOTE(review): the module-level name `demo` is kept as-is — Gradio tooling
# (e.g. `gradio` reload mode on Spaces) conventionally looks for a variable
# named `demo`; confirm before renaming.
if __name__ == "__main__":
    demo = main()
    demo.launch()