| |
| """ |
| Gradio Demo: Meme Generator Pipeline - 2 Mode Version |
| Mode 1: Example Gallery - Select from 293 pre-loaded Chinese memes |
| Mode 2: Custom Upload - Upload your own image |
| """ |
|
|
| import gradio as gr |
| import os |
| import re |
| import csv |
| from PIL import Image, ImageDraw, ImageFont |
| import io |
| import base64 |
| import replicate |
| from pathlib import Path |
| import threading |
| from huggingface_hub import HfApi, hf_hub_download |
| from config import ( |
| SYSTEM_PROMPT, |
| CHARACTER_BATCHES, |
| LLAVA_MODEL, |
| LLAVA_MAX_TOKENS, |
| LLAVA_TEMPERATURE, |
| LLAVA_TOP_P, |
| FLUX_MODEL, |
| FLUX_INFERENCE_STEPS, |
| FLUX_GUIDANCE_SCALE, |
| FLUX_TEXT_FREE_INSTRUCTION, |
| MAX_QUEUE_SIZE, |
| MAX_CONCURRENT_THREADS, |
| INITIAL_GENERATION_COUNT, |
| MAX_GENERATIONS_PER_HOUR, |
| MAX_GENERATIONS_PER_IP_PER_HOUR, |
| COOLDOWN_SECONDS, |
| FONT_PATHS, |
| CAPTION_FONT_SIZE_RATIO |
| ) |
|
|
| |
# --- External service configuration ---
REPLICATE_API_TOKEN = os.environ.get("REPLICATE_API_TOKEN", "")  # Replicate API key; empty string disables generation
IMAGE_DIR = Path("image")                   # directory holding the pre-loaded example meme images
CSV_FILE = Path("labeled_data_clean.csv")   # CSV with filename/content/emotion/intensity columns

# --- Persistent generation counter (stored in a HF dataset repo) ---
COUNTER_REPO = "YZhao09/meme-counter"       # dataset repo used as free persistent storage
COUNTER_FILENAME = "counter.txt"            # file inside the dataset repo holding the count
HF_TOKEN = os.environ.get("HF_TOKEN", "")   # Hugging Face token; empty string means no persistence
|
|
|
|
| |
def load_gallery_examples():
    """Read the labeled-meme CSV and collect entries whose image file exists.

    Returns a list of dicts with keys: filename, image (path string),
    description, emotion, intensity. Returns [] when the CSV cannot be read.
    """
    entries = []
    try:
        with open(CSV_FILE, 'r', encoding='utf-8') as handle:
            for record in csv.DictReader(handle):
                image_path = IMAGE_DIR / record['filename']
                if not image_path.exists():
                    # Skip rows whose image is missing on disk.
                    continue
                entries.append({
                    'filename': record['filename'],
                    'image': str(image_path),
                    'description': record['content'],
                    'emotion': record.get('emotion', ''),
                    'intensity': record.get('intensity', ''),
                })
    except Exception as e:
        # Best-effort: an unreadable CSV just yields an empty gallery.
        print(f"Error loading gallery: {e}")
    return entries
|
|
# Load once at import time so the gallery is ready when the UI is built.
GALLERY_EXAMPLES = load_gallery_examples()
print(f"✅ Loaded {len(GALLERY_EXAMPLES)} gallery examples")
|
|
| |
|
|
def load_generation_count_from_dataset():
    """Fetch the persisted generation count from the HF dataset repo.

    Falls back to INITIAL_GENERATION_COUNT when HF_TOKEN is not configured
    or when the counter file cannot be downloaded/parsed (e.g. first run).
    """
    try:
        if not HF_TOKEN:
            print("⚠️ No HF_TOKEN found - counter will reset on Space restart")
            print(" To enable persistence: Add HF_TOKEN to Space secrets")
            return INITIAL_GENERATION_COUNT

        # force_download avoids serving a stale cached copy of the counter.
        counter_path = hf_hub_download(
            repo_id=COUNTER_REPO,
            filename=COUNTER_FILENAME,
            repo_type="dataset",
            token=HF_TOKEN,
            force_download=True,
        )

        with open(counter_path, 'r') as fh:
            loaded = int(fh.read().strip())
        print(f"✅ Loaded generation count from dataset repo: {loaded}")
        return loaded

    except Exception as e:
        print(f"⚠️ Could not load counter from dataset (may not exist yet): {e}")
        print(f" Using initial count: {INITIAL_GENERATION_COUNT}")
        return INITIAL_GENERATION_COUNT
|
|
def save_generation_count_to_dataset(count):
    """Persist *count* to the HF dataset repo (best-effort).

    Silently no-ops when HF_TOKEN is missing; upload failures are logged
    but never raised so generation is not blocked by storage errors.

    Fix: the counter bytes are uploaded directly via ``path_or_fileobj``
    instead of staging them in a hard-coded ``/tmp/counter.txt``. The
    shared temp file raced between concurrent saves (one thread could
    overwrite the file while another was uploading it) and was not
    portable off POSIX systems.
    """
    try:
        if not HF_TOKEN:
            return

        api = HfApi()
        api.upload_file(
            # upload_file accepts raw bytes, so no temp file is needed.
            path_or_fileobj=str(count).encode("utf-8"),
            path_in_repo=COUNTER_FILENAME,
            repo_id=COUNTER_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
            commit_message=f"Update counter to {count}"
        )

    except Exception as e:
        print(f"⚠️ Could not save counter to dataset: {e}")
|
|
| |
# Rotation index for character-batch selection; advanced per generation.
_generation_counter = 0

# Total completed generations, restored from the HF dataset repo at startup.
_total_generations = load_generation_count_from_dataset()
_generation_lock = threading.Lock()  # guards _total_generations


from collections import deque, defaultdict
import time

# Rate-limiting state; all of it is guarded by _rate_limit_lock.
_hourly_generations = deque()         # timestamps of all generations in the last hour (global cap)
_last_generation_time = {}            # session id (IP) -> timestamp of most recent generation (cooldown)
_ip_generations = defaultdict(deque)  # IP -> timestamps of that IP's generations in the last hour
_rate_limit_lock = threading.Lock()
|
|
def check_rate_limits(request: gr.Request = None) -> tuple[bool, str]:
    """
    Check whether a generation is allowed under the rate limits.

    Enforces, in order: a global hourly cap, a per-IP hourly cap, and a
    per-IP cooldown between consecutive generations. On success the
    current timestamp is recorded against all three limits.

    Fixes: (1) IP extraction no longer assumes ``request.client`` is set —
    the original ``request.client.host`` could raise AttributeError when
    ``client`` was None; (2) per-IP bookkeeping is swept opportunistically
    so ``_ip_generations`` / ``_last_generation_time`` no longer grow
    without bound for IPs that never return.

    Returns: (allowed: bool, message: str)
    """
    with _rate_limit_lock:
        current_time = time.time()

        # Robust client-IP extraction: request, or request.client, may be absent.
        client = getattr(request, 'client', None)
        ip_address = getattr(client, 'host', None) or "unknown"

        # Drop global timestamps older than one hour.
        while _hourly_generations and current_time - _hourly_generations[0] > 3600:
            _hourly_generations.popleft()

        # Opportunistic sweep of per-IP state whose entries have all aged
        # out, so the dicts don't leak one entry per unique visitor.
        # NOTE(review): assumes COOLDOWN_SECONDS is well under 3600 —
        # confirm in config before tightening.
        stale_cutoff = current_time - 3600
        for stale_ip in [ip for ip, times in _ip_generations.items()
                         if not times or times[-1] <= stale_cutoff]:
            del _ip_generations[stale_ip]
        for stale_sid in [sid for sid, ts in _last_generation_time.items()
                          if ts <= stale_cutoff]:
            del _last_generation_time[stale_sid]

        # 1) Global hourly limit.
        if len(_hourly_generations) >= MAX_GENERATIONS_PER_HOUR:
            return False, f"⚠️ Global hourly limit reached ({MAX_GENERATIONS_PER_HOUR} generations/hour). Please try again later."

        # 2) Per-IP hourly limit (skipped when the IP is unknown).
        if ip_address != "unknown":
            ip_times = _ip_generations[ip_address]
            while ip_times and current_time - ip_times[0] > 3600:
                ip_times.popleft()

            if len(ip_times) >= MAX_GENERATIONS_PER_IP_PER_HOUR:
                return False, f"⚠️ You've reached your limit of {MAX_GENERATIONS_PER_IP_PER_HOUR} generations per hour. Please try again later."

        # 3) Cooldown between consecutive generations (keyed by IP,
        # including the shared "unknown" bucket, as in the original).
        session_id = ip_address
        if session_id in _last_generation_time:
            time_since_last = current_time - _last_generation_time[session_id]
            if time_since_last < COOLDOWN_SECONDS:
                wait_time = int(COOLDOWN_SECONDS - time_since_last)
                return False, f"⏳ Please wait {wait_time} seconds before next generation."

        # Allowed: record this generation against every limit.
        _hourly_generations.append(current_time)
        _last_generation_time[session_id] = current_time
        if ip_address != "unknown":
            _ip_generations[ip_address].append(current_time)

        return True, "OK"
|
|
def increment_generation_count():
    """Atomically bump the total generation counter and persist it.

    The dataset upload deliberately happens while the lock is held so
    persisted values are always written in increasing order.
    """
    global _total_generations
    with _generation_lock:
        _total_generations += 1
        new_total = _total_generations
        save_generation_count_to_dataset(new_total)
    return new_total
|
|
def get_generation_count():
    """Return the current total generation count, formatted as a string."""
    with _generation_lock:
        snapshot = _total_generations
    return str(snapshot)
|
|
def refresh_counter():
    """Reload the counter from the dataset repo; return it as a string."""
    global _total_generations
    with _generation_lock:
        _total_generations = load_generation_count_from_dataset()
        latest = _total_generations
    return str(latest)
|
|
| |
|
|
def get_character_batch(index: int = None) -> str:
    """Return the character-suggestion batch for *index*.

    When *index* is None, the module-level rotation counter is used and
    advanced, so successive calls cycle through CHARACTER_BATCHES.
    """
    global _generation_counter
    if index is None:
        index = _generation_counter
        _generation_counter += 1
    return CHARACTER_BATCHES[index % len(CHARACTER_BATCHES)]
|
|
def call_llava_replicate(image: Image.Image, description: str, sample_index: int = None) -> dict:
    """Run LLaVA on the meme via Replicate and parse its sections.

    Returns a dict with keys ``translation`` /
    ``image_generation_instructions`` / ``us_meme_caption``, or
    ``{"error": ...}`` on any failure.
    """
    if not REPLICATE_API_TOKEN:
        return {"error": "Please set REPLICATE_API_TOKEN environment variable"}

    try:
        # Pick the character batch for this request (rotates when index is None).
        character_suggestions = get_character_batch(sample_index)

        # Inline the image as a base64 data URI for the Replicate API.
        png_buffer = io.BytesIO()
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode()
        image_uri = f"data:image/png;base64,{encoded}"

        # Splice the batch into the system prompt, then append the
        # user-supplied description.
        customized_prompt = SYSTEM_PROMPT.replace("[CHARACTER_SUGGESTIONS]", character_suggestions)
        user_prompt = f"Description: {description}\nOriginal Emotion: unknown\nOriginal Intensity: unknown"
        full_prompt = customized_prompt + "\n\n" + user_prompt

        stream = replicate.run(
            LLAVA_MODEL,
            input={
                "image": image_uri,
                "prompt": full_prompt,
                "max_tokens": LLAVA_MAX_TOKENS,
                "temperature": LLAVA_TEMPERATURE,
                "top_p": LLAVA_TOP_P,
            },
        )

        # Replicate streams the answer as chunks; concatenate them.
        translation = "".join(stream)

        return {
            "translation": translation,
            "image_generation_instructions": extract_image_gen_instructions(translation),
            "us_meme_caption": extract_us_meme_caption(translation),
        }

    except Exception as e:
        return {"error": f"Exception: {str(e)}"}
|
|
|
|
def extract_image_gen_instructions(translation: str) -> str:
    """Pull the IMAGE GENERATION INSTRUCTIONS section out of LLaVA output.

    Collects the non-empty lines (stripped) between the section header and
    the next recognized section header (US MEME CAPTION(S) / US MEME /
    NOTE / CULTURAL CONTEXT). Returns "" when the section is absent.
    """
    start_pat = re.compile(r'^\s*(?:\d+\.\s*)?IMAGE GENERATION INSTRUCTIONS:', re.I)
    stop_pat = re.compile(
        r'^\s*(?:\d+\.\s*)?(US MEME CAPTION:|US MEME CAPTIONS:|US MEME:|NOTE:|CULTURAL CONTEXT:)',
        re.I,
    )

    collected = []
    in_section = False
    for raw in translation.splitlines():
        line = raw.rstrip()
        if start_pat.match(line):
            # Header line itself is never collected (repeats are skipped too).
            in_section = True
            continue
        if not in_section:
            continue
        if stop_pat.match(line):
            break
        if line.strip():
            collected.append(line.strip())

    return "\n".join(collected).strip() if collected else ""
|
|
|
|
def extract_us_meme_caption(translation: str) -> str:
    """Extract the US meme caption from LLaVA output.

    Handles both layouts the model may produce:
      * caption on the same line as the header ("US MEME CAPTION: text") —
        the original discarded this and returned later unrelated text;
      * caption on the first non-empty line after the header (original
        behavior, unchanged).
    Capture stops at numbered items or CONSTRAINTS/NOTE/CULTURAL CONTEXT
    headers. Returns "" when no caption is found.
    """
    header_pat = re.compile(r'^\s*(?:\d+\.\s*)?US MEME CAPTIONS?:\s*(.*)$', re.I)
    capture = False
    for raw in translation.splitlines():
        line = raw.rstrip()
        header = header_pat.match(line)
        if header:
            # Fix: use a caption written on the header line itself.
            inline = header.group(1).strip()
            if inline:
                return inline
            capture = True
            continue
        if capture:
            # A new numbered item or another section ends the caption area.
            if re.match(r'^\s*\d+\.\s+', line):
                break
            if re.match(r'^\s*(?:CONSTRAINTS?:|NOTE:|CULTURAL CONTEXT:)', line, re.I):
                break
            if line.strip():
                # Caption is a single line: first non-empty line wins.
                return line.strip()
    return ""
|
|
|
|
def call_flux_replicate(prompt: str) -> Image.Image:
    """Generate a base image from *prompt* with FLUX via Replicate.

    Appends the text-free instruction when it is missing from the prompt.
    Returns a PIL image, or None when the token is unset or the call fails.
    """
    if not REPLICATE_API_TOKEN:
        return None

    try:
        # Ensure the "no embedded text" instruction is part of the prompt.
        if FLUX_TEXT_FREE_INSTRUCTION not in prompt:
            prompt = prompt.rstrip(".") + ". " + FLUX_TEXT_FREE_INSTRUCTION

        result = replicate.run(
            FLUX_MODEL,
            input={
                "prompt": prompt,
                "num_inference_steps": FLUX_INFERENCE_STEPS,
                "guidance_scale": FLUX_GUIDANCE_SCALE,
            },
        )

        # Replicate may return a list of outputs; keep the first one.
        if isinstance(result, list) and result:
            result = result[0]

        # The output is either a readable file-like object or a URL string.
        if hasattr(result, 'read'):
            return Image.open(result)

        import requests
        response = requests.get(str(result))
        return Image.open(io.BytesIO(response.content))

    except Exception as e:
        print(f"FLUX error: {e}")
        return None
|
|
|
|
def _load_caption_font(font_size: int):
    """Return the first loadable TrueType font from FONT_PATHS, else PIL's default."""
    for font_path in FONT_PATHS:
        try:
            return ImageFont.truetype(font_path, font_size)
        except Exception:
            # Fix: was a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit.
            continue
    return ImageFont.load_default()


def _wrap_caption(draw, caption: str, font, max_width: float) -> list:
    """Greedy word-wrap: pack words into lines no wider than max_width px."""
    lines, current = [], []
    for word in caption.split():
        candidate = ' '.join(current + [word])
        bbox = draw.textbbox((0, 0), candidate, font=font)
        if bbox[2] - bbox[0] < max_width:
            current.append(word)
        else:
            if current:
                lines.append(' '.join(current))
            current = [word]
    if current:
        lines.append(' '.join(current))
    return lines


def add_caption_to_image(image: Image.Image, caption: str) -> Image.Image:
    """Overlay *caption* near the bottom of *image* in meme style.

    White text with a black outline, word-wrapped to 90% of the image
    width, bottom-anchored at 95% of the image height. Returns a new
    image; on any failure the original image is returned unchanged.
    """
    try:
        img = image.copy()
        draw = ImageDraw.Draw(img)
        width, height = img.size

        # Font size scales with image height.
        font_size = int(height * CAPTION_FONT_SIZE_RATIO)
        font = _load_caption_font(font_size)

        lines = _wrap_caption(draw, caption, font, width * 0.9)

        # Total block height, including 5px spacing per line.
        total_text_height = 0
        for line in lines:
            bbox = draw.textbbox((0, 0), line, font=font)
            total_text_height += (bbox[3] - bbox[1]) + 5

        # Anchor the block so it ends at 95% of the image height.
        y_text = int(height * 0.95) - total_text_height
        for line in lines:
            bbox = draw.textbbox((0, 0), line, font=font)
            text_width = bbox[2] - bbox[0]
            x_text = (width - text_width) // 2

            # 5x5 grid of black offsets fakes an outline for readability.
            for adj_x in range(-2, 3):
                for adj_y in range(-2, 3):
                    draw.text((x_text + adj_x, y_text + adj_y), line, font=font, fill='black')

            draw.text((x_text, y_text), line, font=font, fill='white')
            y_text += bbox[3] - bbox[1] + 5

        return img

    except Exception as e:
        print(f"Caption error: {e}")
        return image
|
|
|
|
| |
|
|
def generate_meme(input_image, description, sample_index: int = None, request: gr.Request = None):
    """Full pipeline: Chinese meme -> US meme (LLaVA analyze, FLUX generate, caption).

    Generator yielding 7-tuples for the Gradio outputs:
    (final_meme, base_image, status, translation, image_instructions,
    us_caption, generation_count).

    Fixes vs. original:
      * This function contains ``yield``, so it is a generator — the early
        ``return <tuple>`` guard clauses never reached the UI (a
        generator's return value is swallowed, not yielded). Guards now
        yield their error tuple and then return.
      * Removed the redundant ``get_character_batch()`` call whose result
        was unused; it advanced the rotation counter a second time per
        generation (call_llava_replicate already selects the batch).
    """
    # --- Guard clauses: each must YIELD so the UI shows the message. ---
    allowed, rate_limit_msg = check_rate_limits(request)
    if not allowed:
        yield None, None, rate_limit_msg, None, None, None, get_generation_count()
        return

    if input_image is None:
        yield None, None, "Please provide an input image", None, None, None, get_generation_count()
        return

    if not description or not description.strip():
        yield None, None, "Please provide a description", None, None, None, get_generation_count()
        return

    if not REPLICATE_API_TOKEN:
        yield None, None, "Error: REPLICATE_API_TOKEN not set. Please configure it in environment variables or Hugging Face Spaces secrets.", None, None, None, get_generation_count()
        return

    try:
        # Count this attempt (persisted best-effort to the dataset repo).
        count = increment_generation_count()

        # Accept either a file path (gallery mode) or a PIL image (upload mode).
        if isinstance(input_image, str):
            image = Image.open(input_image).convert('RGB')
        else:
            image = input_image.convert('RGB')

        yield None, None, "Step 1/3: Analyzing with LLaVA...", None, None, None, count

        # Step 1: cultural analysis + transcreation via LLaVA.
        llava_result = call_llava_replicate(image, description, sample_index)

        if "error" in llava_result:
            yield None, None, f"Error in Step 1: {llava_result['error']}", None, None, None, count
            return

        translation = llava_result.get("translation", "")
        image_instructions = llava_result.get("image_generation_instructions", "")
        us_caption = llava_result.get("us_meme_caption", "")

        if not image_instructions:
            yield None, None, "Error: Could not extract image instructions", translation, None, None, count
            return

        yield None, None, "Step 2/3: Generating new meme image with FLUX...", translation, None, None, count

        # Step 2: text-free base image from FLUX.
        generated_image = call_flux_replicate(image_instructions)

        if generated_image is None:
            yield None, None, "Error in Step 2: Failed to generate image", translation, None, None, count
            return

        yield None, generated_image, "Step 3/3: Adding caption...", translation, image_instructions, us_caption, count

        # Step 3: overlay the extracted caption.
        final_meme = add_caption_to_image(generated_image, us_caption)

        yield final_meme, generated_image, "Complete! Your US meme is ready!", translation, image_instructions, us_caption, count

    except Exception as e:
        yield None, None, f"Error: {str(e)}", None, None, None, get_generation_count()
|
|
|
|
def select_gallery_example(evt: gr.SelectData):
    """Resolve a gallery click to the selected example's image path and description.

    Out-of-range indices yield (None, "") so the hidden inputs stay empty.
    """
    idx = evt.index
    if idx < 0 or idx >= len(GALLERY_EXAMPLES):
        return None, ""
    chosen = GALLERY_EXAMPLES[idx]
    return chosen['image'], chosen['description']
|
|
|
|
| |
|
|
def create_demo():
    """Build the two-tab Gradio Blocks UI (Example Gallery / Custom Upload).

    Both tabs feed the same generate_meme pipeline; a header counter
    shows the persisted total and is refreshed from the dataset repo on
    page load.
    """

    # CSS for the counter widget: bold black centered text, and a dimmed
    # italic look while the "Loading..." placeholder is still shown.
    custom_css = """
    .counter-display label {
        color: #000000 !important;
        font-weight: bold !important;
    }
    .counter-display input {
        color: #000000 !important;
        font-size: 1.2em !important;
        font-weight: bold !important;
        text-align: center !important;
    }
    /* Show loading state when counter is empty */
    .counter-display input:placeholder-shown {
        font-style: italic;
        opacity: 0.6;
    }
    """

    with gr.Blocks(title="Chinese to US Meme Generator", css=custom_css) as demo:
        # Header row: title + paper link on the left, live counter on the right.
        with gr.Row():
            gr.Markdown("""
            # MemeXGen

            Cross-Cultural Meme Transcreation with Vision-Language Models

            📄 **Read our paper: [Beyond Translation: Cross-Cultural Meme Transcreation with Vision-Language Models](https://arxiv.org/pdf/2602.02510)**
            """)
            with gr.Column(scale=0, min_width=150):
                # Non-interactive display of the persisted generation count;
                # updated by every generate_meme yield and by demo.load below.
                generation_counter = gr.Textbox(
                    label="🔥 Generations",
                    value="Loading...",
                    interactive=False,
                    container=True,
                    elem_classes="counter-display",
                    show_label=True
                )

        # Pipeline overview and usage tips.
        gr.Markdown("""
        **Pipeline**: LLaVA-13B (analyze) → FLUX.1-schnell (generate) → Caption overlay

        ---
        ### Tips:
        - **Gallery Mode**: Browse and select from real Chinese memes with authentic descriptions
        - **Custom Mode**: Upload your own images for processing
        - **Models**: Using LLaVA-13B for cultural analysis, FLUX.1-schnell for image generation
        ---
        """)

        with gr.Tabs() as tabs:

            # ----- Tab 1: pre-loaded example gallery -----
            with gr.Tab("Example Gallery"):
                gr.Markdown("""
                ### Select from 293 real memes
                Click any image below to select it, then click "Transcreate Meme".
                """)

                gallery = gr.Gallery(
                    value=[ex['image'] for ex in GALLERY_EXAMPLES],
                    label="Chinese Meme Examples",
                    columns=6,
                    rows=3,
                    height=500,
                    object_fit="contain",
                    show_label=False,
                    interactive=True
                )

                # Hidden components holding the currently selected example;
                # filled by select_gallery_example, read by generate_meme.
                gallery_input = gr.Image(visible=False, type="pil")
                gallery_description = gr.Textbox(visible=False)

                gallery_btn = gr.Button("Transcreate Meme", variant="primary", size="lg")

                with gr.Row():
                    with gr.Column(scale=1):
                        gallery_output = gr.Image(label="Final Meme", height=400)
                    with gr.Column(scale=1):
                        gallery_base = gr.Image(label="Generated Base Image (before caption)", height=400)

                gallery_status = gr.Textbox(label="Status", lines=2)

                # Intermediate pipeline artifacts, collapsed by default.
                with gr.Accordion("Analysis Details", open=False):
                    gallery_translation = gr.Textbox(
                        label="LLaVA Analysis",
                        lines=15,
                        max_lines=30
                    )
                    gallery_instructions = gr.Textbox(
                        label="Image Generation Instructions",
                        lines=5
                    )
                    gallery_caption = gr.Textbox(
                        label="Extracted Caption",
                        lines=2
                    )

                # Clicking a thumbnail stores its path and description.
                gallery.select(
                    fn=select_gallery_example,
                    outputs=[gallery_input, gallery_description]
                )

                gallery_btn.click(
                    fn=generate_meme,
                    inputs=[gallery_input, gallery_description],
                    outputs=[
                        gallery_output,
                        gallery_base,
                        gallery_status,
                        gallery_translation,
                        gallery_instructions,
                        gallery_caption,
                        generation_counter
                    ]
                )

            # ----- Tab 2: user-uploaded image -----
            with gr.Tab("Custom Upload"):
                gr.Markdown("""
                ### Upload your own meme
                Upload any image and provide a description in Chinese or English.
                """)

                with gr.Row():
                    with gr.Column(scale=1):
                        custom_input = gr.Image(
                            label="Upload Your Meme",
                            type="pil",
                            height=300
                        )
                    with gr.Column(scale=1):
                        custom_description = gr.Textbox(
                            label="Description (Chinese or English)",
                            lines=5,
                            placeholder="Describe the meme's content and emotion..."
                        )

                custom_btn = gr.Button("Transcreate Meme", variant="primary", size="lg")

                with gr.Row():
                    with gr.Column(scale=1):
                        custom_output = gr.Image(label="Final Meme", height=400)
                    with gr.Column(scale=1):
                        custom_base = gr.Image(label="Generated Base Image (before caption)", height=400)

                custom_status = gr.Textbox(label="Status", lines=2)

                # Intermediate pipeline artifacts, collapsed by default.
                with gr.Accordion("Analysis Details", open=False):
                    custom_translation = gr.Textbox(
                        label="LLaVA Analysis",
                        lines=15,
                        max_lines=30
                    )
                    custom_instructions = gr.Textbox(
                        label="Image Generation Instructions",
                        lines=5
                    )
                    custom_caption = gr.Textbox(
                        label="Extracted Caption",
                        lines=2
                    )

                custom_btn.click(
                    fn=generate_meme,
                    inputs=[custom_input, custom_description],
                    outputs=[
                        custom_output,
                        custom_base,
                        custom_status,
                        custom_translation,
                        custom_instructions,
                        custom_caption,
                        generation_counter
                    ]
                )

        # Refresh the counter from the dataset repo on every page load;
        # queue=False so it runs immediately, bypassing the job queue.
        demo.load(
            fn=refresh_counter,
            outputs=generation_counter,
            queue=False
        )

    return demo
|
|
|
|
if __name__ == "__main__":
    demo = create_demo()
    # Bound the pending-job queue, then serve on all interfaces with a
    # public Gradio share link and a capped worker-thread pool.
    demo.queue(max_size=MAX_QUEUE_SIZE)
    demo.launch(share=True, server_name="0.0.0.0", max_threads=MAX_CONCURRENT_THREADS)