Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from transformers import pipeline, AutoProcessor, AutoModelForVision2Seq | |
# Hugging Face access token from the environment (None when unset). It is forwarded
# to the Gemma pipeline below so gated/private checkpoints can be downloaded; with
# None the download simply proceeds anonymously, exactly as before.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Pipeline device index: first CUDA GPU (0) when available, otherwise CPU (-1).
DEVICE = 0 if torch.cuda.is_available() else -1

# Load BLIP for captioning. Processor and model come from the same checkpoint so the
# tokenizer / image preprocessing always match the weights.
BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-large"
processor = AutoProcessor.from_pretrained(BLIP_CHECKPOINT)
blip_model = AutoModelForVision2Seq.from_pretrained(BLIP_CHECKPOINT)
caption_pipe = pipeline(
    "image-to-text",
    model=blip_model,
    tokenizer=processor.tokenizer,
    image_processor=processor.image_processor,
    device=DEVICE,
)

# Load Gemma for text generation (pick your Gemma checkpoint here).
# Gemma checkpoints may be gated on the Hub, so the token read above is always
# passed through; previously HF_TOKEN was read but never used.
gemma_pipe = pipeline(
    "text-generation",
    model="google/gemma-2b-it",  # Change this to any working Gemma instruct model!
    device=DEVICE,
    token=HF_TOKEN,
)
def get_recommendations():
    """Return the example-ad image URLs shown in the gallery.

    A new list is built on every call, so callers may mutate the result
    without affecting later calls.
    """
    example_urls = (
        "https://i.imgur.com/InC88PP.jpeg",
        "https://i.imgur.com/7BHfv4T.png",
        "https://i.imgur.com/wp3Wzc4.jpeg",
        "https://i.imgur.com/5e2xOA4.jpeg",
        "https://i.imgur.com/txjRk98.jpeg",
        "https://i.imgur.com/rQ4AYl0.jpeg",
        "https://i.imgur.com/bDzwD04.jpeg",
        "https://i.imgur.com/fLMngXI.jpeg",
        "https://i.imgur.com/nYEJzxt.png",
        "https://i.imgur.com/Xj92Cjv.jpeg",
    )
    return list(example_urls)
def clean_output(text):
    """Strip echoed prompt labels from generated text and trim whitespace.

    If "Description:" occurs, everything up to and including its first
    occurrence is dropped; the same is then done for "Category:" on what
    remains. The result is returned with surrounding whitespace removed.
    """
    for marker in ("Description:", "Category:"):
        _, found, remainder = text.partition(marker)
        if found:
            text = remainder
    return text.strip()
def _ask_gemma(prompt, max_new_tokens):
    """Run one Gemma generation for *prompt* and return the cleaned completion.

    ``return_full_text=False`` makes the pipeline return only the newly generated
    text. Without it the prompt is echoed back, and since every prompt ends in
    "Description: <caption>", ``clean_output`` would leave the BLIP caption glued
    to the front of the model's answer.
    """
    outputs = gemma_pipe(prompt, max_new_tokens=max_new_tokens, return_full_text=False)
    return clean_output(outputs[0]["generated_text"].strip())


def process(image: Image.Image):
    """Caption an ad image with BLIP, then ask Gemma to categorise, analyse and critique it.

    Args:
        image: the uploaded ad as a PIL image, or None when nothing was uploaded.

    Returns:
        A 4-tuple ``(category, analysis, suggestions, example_urls)`` matching the
        four UI outputs. When *image* is None the three text fields are empty
        strings and only the example URLs are filled in.
    """
    if image is None:
        return "", "", "", get_recommendations()

    # 1. BLIP captioning. Gemma never sees the image itself, only this caption.
    caption_res = caption_pipe(image, max_new_tokens=64)
    desc = caption_res[0]["generated_text"].strip()

    # 2. Gemma: category.
    cat_out = _ask_gemma(
        f"Classify the following ad in one or two words. Description: {desc}",
        max_new_tokens=16,
    )

    # 3. Gemma: analysis (five sentences requested).
    ana_out = _ask_gemma(
        f"Describe in exactly five sentences what this ad communicates and its emotional impact. Description: {desc}",
        max_new_tokens=120,
    )

    # 4. Gemma: suggestions (five '- ' bullets requested).
    sug_out = _ask_gemma(
        f"Suggest five practical improvements for this ad. Each suggestion must be unique, address a different aspect (message, visuals, call to action, targeting, or layout), start with '- ', and be one sentence. Description: {desc}",
        max_new_tokens=120,
    )

    # Keep only bullet lines. '- ' was requested, but instruct models often emit
    # '* ' or '•' bullets instead, so accept those too ('* ' needs the space so
    # '**bold**' heading lines are not mistaken for bullets). If no bullet line is
    # found, fall back to the whole cleaned answer.
    sug_lines = [
        line.strip()
        for line in sug_out.splitlines()
        if line.strip().startswith(("-", "* ", "•"))
    ]
    suggestions = "\n".join(sug_lines[:5]) if sug_lines else sug_out
    return cat_out, ana_out, suggestions, get_recommendations()
def main():
    """Build and return (without launching) the Gradio interface.

    Layout: the uploaded image sits beside a column of three read-only result
    boxes; the "Analyze Ad" button and the example-ads gallery follow below.
    Clicking the button calls ``process`` with the image and routes its four
    return values to the three boxes and the gallery.
    """
    app_title = "Smart Ad Analyzer (BLIP + Gemma)"
    intro_text = """
Upload your ad image below and instantly get expert feedback.
Category, analysis, improvement suggestions—and example ads for inspiration.
"""
    with gr.Blocks(title=app_title) as demo:
        gr.Markdown("## 📢 Smart Ad Analyzer (BLIP + Gemma)")
        gr.Markdown(intro_text)

        with gr.Row():
            ad_image = gr.Image(type='pil', label='Upload Ad Image')
            with gr.Column():
                category_box = gr.Textbox(label='Ad Category', interactive=False)
                analysis_box = gr.Textbox(label='Ad Analysis', lines=5, interactive=False)
                suggestions_box = gr.Textbox(label='Improvement Suggestions', lines=5, interactive=False)

        analyze_button = gr.Button('Analyze Ad', variant='primary')
        examples_gallery = gr.Gallery(label='Example Ads')

        # Output order must match the tuple returned by process().
        result_targets = [category_box, analysis_box, suggestions_box, examples_gallery]
        analyze_button.click(fn=process, inputs=[ad_image], outputs=result_targets)

        gr.Markdown('Made by Simon Thalmay')
    return demo
if __name__ == "__main__":
    # Build the UI, keep it bound to the module-level name `demo`, and serve it.
    (demo := main()).launch()