# Hugging Face Space (status when captured: Sleeping)
| from transformers import Blip2Processor, Blip2ForConditionalGeneration, AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| from PIL import Image, ImageDraw, ImageFont | |
| import gradio as gr | |
| import torch | |
| import os | |
# Check CUDA availability
if not torch.cuda.is_available():
    # CPU-only host: silence the bitsandbytes banner and hide any GPUs.
    os.environ["BITSANDBYTES_NOWELCOME"] = "1"
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    # Python does not expand "$LIBRARY_PATH" inside a string literal, so the
    # existing value (if any) has to be appended explicitly.
    _cuda_stubs = "/usr/local/cuda/lib64/stubs"
    _previous_library_path = os.environ.get("LIBRARY_PATH")
    os.environ["LIBRARY_PATH"] = (
        f"{_cuda_stubs}:{_previous_library_path}"
        if _previous_library_path
        else _cuda_stubs
    )

# Initialize environment
os.makedirs("./offload", exist_ok=True)
HF_TOKEN = os.environ.get("HF_TOKEN")

# Memory optimization
# Set the allocator configuration before any torch.cuda call touches the
# caching allocator, so the setting is in place when it is first used.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
if torch.cuda.is_available():
    torch.cuda.empty_cache()
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# Load BLIP-2 (processor + captioning model) from a single checkpoint id.
_BLIP_CHECKPOINT = "Salesforce/blip2-opt-2.7b"

blip_processor = Blip2Processor.from_pretrained(_BLIP_CHECKPOINT)

blip_model = Blip2ForConditionalGeneration.from_pretrained(
    _BLIP_CHECKPOINT,
    torch_dtype=torch.float16,
    device_map="auto",
)
# Inference only: switch the model to evaluation mode.
blip_model.eval()
# Load Phi-3 (causal LM + tokenizer) from a single checkpoint id.
_PHI_CHECKPOINT = "microsoft/Phi-3-mini-4k-instruct"

_phi_load_options = {
    "trust_remote_code": True,
    "device_map": "auto",
    "torch_dtype": torch.float16,
    # Only use 4bit if CUDA available
    "load_in_4bit": torch.cuda.is_available(),
    "token": HF_TOKEN,
}

phi_model = AutoModelForCausalLM.from_pretrained(
    _PHI_CHECKPOINT, **_phi_load_options
)
# Inference only: switch the model to evaluation mode.
phi_model.eval()

phi_tokenizer = AutoTokenizer.from_pretrained(_PHI_CHECKPOINT, token=HF_TOKEN)
def analyze_image(image):
    """Describe ``image`` with a short BLIP-2 caption.

    Args:
        image: The image to caption, passed straight to ``blip_processor``.

    Returns:
        The decoded text of the first generated sequence, with surrounding
        whitespace removed.
    """
    # Move the tensors to the model's device AND cast floating-point tensors
    # (the pixel values) to the model's dtype. The model is loaded in
    # float16, so float32 pixel values would cause a dtype mismatch inside
    # generate().
    inputs = blip_processor(images=image, return_tensors="pt").to(
        blip_model.device, blip_model.dtype
    )
    generated_ids = blip_model.generate(**inputs, max_length=50)
    caption = blip_processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0]
    return caption.strip()
def generate_meme_caption(image_desc, user_prompt):
    """Ask Phi-3 for meme captions that fit an image description.

    Args:
        image_desc: Text description of the image (e.g. a BLIP-2 caption).
        user_prompt: Free-form theme supplied by the user; ``None`` is
            treated as an empty string.

    Returns:
        The text generated by the model, without the prompt. The model is
        instructed to emit captions as ``TOP TEXT | BOTTOM TEXT`` lines.
    """
    user_prompt = user_prompt or ""
    messages = [
        {"role": "system", "content": "You are a meme expert. Create funny captions in format: TOP TEXT | BOTTOM TEXT"},
        {"role": "user", "content": f"Image context: {image_desc}\nUser input: {user_prompt}\nGenerate 3 meme captions (max 10 words each):"}
    ]
    inputs = phi_tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True,
    ).to(phi_model.device)
    outputs = phi_model.generate(
        inputs,
        max_new_tokens=200,
        temperature=0.7,
        do_sample=True,
    )
    # The generated sequence starts with the prompt tokens. Decoding them
    # would echo the system prompt, whose "TOP TEXT | BOTTOM TEXT" example
    # looks exactly like a caption line, so keep only the new tokens.
    prompt_length = inputs.shape[-1]
    completion_ids = outputs[0][prompt_length:]
    return phi_tokenizer.decode(completion_ids, skip_special_tokens=True)
def _load_meme_font(size):
    """Return a font of roughly ``size`` pixels, falling back to Pillow's default."""
    size = max(int(size), 1)  # very small images would otherwise request size 0
    # "arial.ttf" is usually only present on Windows; the other names are
    # fonts commonly installed on Linux hosts.
    candidates = (
        "arial.ttf",
        "DejaVuSans-Bold.ttf",
        "DejaVuSans.ttf",
        "LiberationSans-Bold.ttf",
    )
    for font_name in candidates:
        try:
            return ImageFont.truetype(font_name, size=size)
        except (OSError, ImportError):
            # Font file not found / unreadable, or FreeType support missing.
            continue
    try:
        # Newer Pillow releases accept a size for the built-in default font.
        return ImageFont.load_default(size=size)
    except (TypeError, OSError, ImportError):
        return ImageFont.load_default()


def create_meme(image, top_text, bottom_text):
    """Draw meme text onto a copy of ``image``.

    Args:
        image: Source image; it is copied, not modified.
        top_text: Text drawn centred near the top edge.
        bottom_text: Text drawn centred near the bottom edge.

    Returns:
        A new image with both texts drawn in white with a black outline.
    """
    img = image.copy()
    draw = ImageDraw.Draw(img)
    # Scale the font with the smaller image dimension.
    font = _load_meme_font(min(img.size) // 12)

    # Options shared by both captions.
    text_style = {
        "font": font,
        "fill": "white",
        "stroke_width": 2,
        "stroke_fill": "black",
    }
    # Top text: anchored at its middle-top point, 10 px below the top edge.
    draw.text((img.width / 2, 10), top_text, anchor="mt", **text_style)
    # Bottom text: anchored at its middle-bottom point, 10 px above the bottom edge.
    draw.text(
        (img.width / 2, img.height - 10),
        bottom_text,
        anchor="mb",
        **text_style,
    )
    return img
def _parse_captions(raw_output, limit=3):
    """Extract up to ``limit`` ``(top, bottom)`` caption pairs from model text."""
    import re

    # Leading list markers such as "1. ", "2) ", "- " or "* ".
    list_marker = re.compile(r"^\s*(?:\d{1,2}[.)]|[-*•])\s+")
    decoration = " \t\"'*"  # quotes / markdown emphasis around a caption

    captions = []
    for line in raw_output.splitlines():
        top, separator, bottom = line.partition("|")
        if not separator:
            continue
        top = list_marker.sub("", top).strip(decoration)
        bottom = bottom.strip(decoration)
        if not top and not bottom:
            continue
        # Skip an echoed format example ("... TOP TEXT | BOTTOM TEXT").
        if top.upper().endswith("TOP TEXT") and bottom.upper() == "BOTTOM TEXT":
            continue
        captions.append((top, bottom))
        if len(captions) >= limit:
            break
    return captions


def process_meme(image, user_prompt):
    """Generate meme images for an uploaded image and a user theme.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.
        user_prompt: Theme/prompt text entered by the user.

    Returns:
        A list of up to three meme images, one per parsed caption pair.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    image_desc = analyze_image(image)
    raw_output = generate_meme_caption(image_desc, user_prompt)
    return [
        create_meme(image, top, bottom)
        for top, bottom in _parse_captions(raw_output)
    ]
# Gradio UI: an image and a prompt go in, a gallery of generated memes comes out.
demo = gr.Blocks(title="AI Meme Generator")

with demo:
    gr.Markdown("# 🚀 AI Meme Generator")

    # Inputs side by side.
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        text_input = gr.Textbox(label="Meme Theme/Prompt")

    submit_btn = gr.Button("Generate Memes!")
    gallery = gr.Gallery(label="Generated Memes", columns=3)

    # Clicking the button runs the full pipeline and fills the gallery.
    submit_btn.click(
        process_meme,
        inputs=[image_input, text_input],
        outputs=gallery,
    )

# Start the app only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()