"""Gradio app: caption an image with BLIP, then restyle the caption via Groq."""

from typing import Optional

import gradio as gr
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from groq import Groq

# ==============================
# LOAD MODEL (ONCE)
# ==============================
# Loaded at import time so every request reuses the same processor/model
# instead of re-reading the checkpoint per click.
BLIP_MODEL_ID = "Salesforce/blip-image-captioning-base"
GROQ_MODEL_ID = "llama-3.3-70b-versatile"
DEFAULT_STYLE = "Normal"

processor = BlipProcessor.from_pretrained(BLIP_MODEL_ID)
model = BlipForConditionalGeneration.from_pretrained(BLIP_MODEL_ID)


# ==============================
# CORE FUNCTION
# ==============================
def generate_caption(
    api_key: Optional[str],
    image: Optional[Image.Image],
    style: Optional[str],
) -> str:
    """Caption *image* with BLIP and rewrite the caption in *style* via Groq.

    Args:
        api_key: Groq API key typed by the user; leading/trailing whitespace
            is ignored.
        image: PIL image from the upload widget, or ``None`` if nothing was
            uploaded.
        style: Caption style chosen in the dropdown; falls back to
            ``"Normal"`` when empty.

    Returns:
        A Markdown string containing the basic and the refined caption, or a
        Markdown error message starting with "❌" when validation fails or any
        step raises.
    """
    # Validate inputs up front. NOTE: never return or embed credentials here;
    # the message is rendered verbatim to whoever is using the app.
    cleaned_key = (api_key or "").strip()
    if not cleaned_key:
        return "❌ Please enter your Groq API key."

    if image is None:
        return "❌ Please upload an image."

    # The dropdown can be cleared by the user, which yields an empty value.
    style = style or DEFAULT_STYLE

    # This is the UI boundary: any failure (model, network, auth) is turned
    # into a readable message instead of crashing the request.
    try:
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Image → basic caption
        inputs = processor(image, return_tensors="pt")
        output = model.generate(**inputs, max_new_tokens=30)
        basic_caption = processor.decode(output[0], skip_special_tokens=True)

        # Groq refinement
        client = Groq(api_key=cleaned_key)
        prompt = (
            f"Rewrite the following image caption in a {style.lower()} style.\n"
            "Keep it short (1–2 lines).\n\n"
            f'Caption: "{basic_caption}"'
        )
        response = client.chat.completions.create(
            model=GROQ_MODEL_ID,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
        )
        refined = response.choices[0].message.content

        return (
            f"🖼️ **Basic Caption:** {basic_caption}\n\n"
            f"✨ **AI Refined Caption ({style}):**\n{refined}"
        )
    except Exception as e:
        return f"❌ Error:\n{str(e)}"


# ==============================
# UI
# ==============================
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown(
        """
# 🖼️ Image Caption Generator
Image → Caption → Groq AI ✨
Hugging Face Deployment
"""
    )

    # type="password" masks the key in the browser; it is only passed to the
    # handler for the duration of a single request.
    api_key = gr.Textbox(
        label="🔑 Groq API Key",
        type="password",
        placeholder="Paste your Groq API key here",
    )
    image = gr.Image(type="pil", label="📷 Upload Image")
    style = gr.Dropdown(
        ["Normal", "Creative", "Fun / Gen-Z"],
        value="Normal",
        label="🎨 Caption Style",
    )
    btn = gr.Button("🚀 Generate Caption")
    output = gr.Markdown()

    btn.click(
        generate_caption,
        inputs=[api_key, image, style],
        outputs=output,
    )

app.launch()