# Hugging Face Space page residue (not code): YZhao09's picture / Update app.py / 62ec117 verified
#!/usr/bin/env python3
"""
Gradio Demo: Meme Generator Pipeline - 2 Mode Version
Mode 1: Example Gallery - Select from 293 pre-loaded Chinese memes
Mode 2: Custom Upload - Upload your own image
"""
import gradio as gr
import os
import re
import csv
from PIL import Image, ImageDraw, ImageFont
import io
import base64
import replicate
from pathlib import Path
import threading
from huggingface_hub import HfApi, hf_hub_download
from config import (
SYSTEM_PROMPT,
CHARACTER_BATCHES,
LLAVA_MODEL,
LLAVA_MAX_TOKENS,
LLAVA_TEMPERATURE,
LLAVA_TOP_P,
FLUX_MODEL,
FLUX_INFERENCE_STEPS,
FLUX_GUIDANCE_SCALE,
FLUX_TEXT_FREE_INSTRUCTION,
MAX_QUEUE_SIZE,
MAX_CONCURRENT_THREADS,
INITIAL_GENERATION_COUNT,
MAX_GENERATIONS_PER_HOUR,
MAX_GENERATIONS_PER_IP_PER_HOUR,
COOLDOWN_SECONDS,
FONT_PATHS,
CAPTION_FONT_SIZE_RATIO
)
# ==================== Configuration ====================
REPLICATE_API_TOKEN = os.environ.get("REPLICATE_API_TOKEN", "")
IMAGE_DIR = Path("image")
CSV_FILE = Path("labeled_data_clean.csv")
# Dataset persistence configuration (FREE alternative to paid storage)
COUNTER_REPO = "YZhao09/meme-counter" # Create this dataset repo in your HF account
COUNTER_FILENAME = "counter.txt"
HF_TOKEN = os.environ.get("HF_TOKEN", "") # Add your HF write token to Space secrets
# ==================== Load Gallery Data ====================
def load_gallery_examples():
    """Read the labeled-meme CSV and return entries whose image file exists on disk.

    Each entry is a dict with keys: filename, image (path string),
    description, emotion, intensity. Returns an empty list on any error.
    """
    gallery = []
    try:
        with open(CSV_FILE, 'r', encoding='utf-8') as handle:
            for record in csv.DictReader(handle):
                candidate = IMAGE_DIR / record['filename']
                # Skip rows whose image is missing from the image directory
                if not candidate.exists():
                    continue
                gallery.append({
                    'filename': record['filename'],
                    'image': str(candidate),
                    'description': record['content'],
                    'emotion': record.get('emotion', ''),
                    'intensity': record.get('intensity', ''),
                })
    except Exception as err:
        print(f"Error loading gallery: {err}")
    return gallery
# Load the gallery once at import time so the UI can reference it immediately.
GALLERY_EXAMPLES = load_gallery_examples()
print(f"✅ Loaded {len(GALLERY_EXAMPLES)} gallery examples")
# ==================== Dataset Persistence Functions ====================
def load_generation_count_from_dataset():
    """Fetch the persisted generation counter from the HF dataset repo.

    Falls back to INITIAL_GENERATION_COUNT when no HF_TOKEN is configured or
    the counter file cannot be downloaded (e.g. the repo/file does not exist yet).
    """
    try:
        if not HF_TOKEN:
            print("⚠️ No HF_TOKEN found - counter will reset on Space restart")
            print(" To enable persistence: Add HF_TOKEN to Space secrets")
            return INITIAL_GENERATION_COUNT
        # force_download bypasses the local cache so we always read the latest value
        local_path = hf_hub_download(
            repo_id=COUNTER_REPO,
            filename=COUNTER_FILENAME,
            repo_type="dataset",
            token=HF_TOKEN,
            force_download=True,
        )
        with open(local_path, 'r') as counter_file:
            stored = int(counter_file.read().strip())
        print(f"✅ Loaded generation count from dataset repo: {stored}")
        return stored
    except Exception as err:
        print(f"⚠️ Could not load counter from dataset (may not exist yet): {err}")
        print(f" Using initial count: {INITIAL_GENERATION_COUNT}")
        return INITIAL_GENERATION_COUNT
def save_generation_count_to_dataset(count):
    """Persist the generation counter by committing a small text file to the
    HF dataset repo. No-op without HF_TOKEN; failures are logged, not raised.
    """
    try:
        if not HF_TOKEN:
            # Silently skip if no token is configured
            return
        # Stage the value in a temp file, then upload it as a dataset commit
        staging_path = "/tmp/counter.txt"
        with open(staging_path, 'w') as staging:
            staging.write(str(count))
        HfApi().upload_file(
            path_or_fileobj=staging_path,
            path_in_repo=COUNTER_FILENAME,
            repo_id=COUNTER_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
            commit_message=f"Update counter to {count}",
        )
    except Exception as err:
        print(f"⚠️ Could not save counter to dataset: {err}")
# Global counter for character batch rotation
_generation_counter = 0
# Global counter for total generations with dataset persistence
_total_generations = load_generation_count_from_dataset()
_generation_lock = threading.Lock()
# Rate limiting tracking
from collections import deque, defaultdict
import time
_hourly_generations = deque()  # Track timestamps of generations in last hour
_last_generation_time = {}  # Track last generation time per session (keyed by IP)
_ip_generations = defaultdict(deque)  # Track generations per IP address
_rate_limit_lock = threading.Lock()  # Guards all three rate-limit structures above
def check_rate_limits(request: gr.Request = None) -> tuple[bool, str]:
    """
    Check if generation is allowed based on rate limits
    Returns: (allowed: bool, message: str)
    """
    # NOTE(review): _last_generation_time entries are never pruned, so this
    # dict grows with the number of distinct client IPs over the process
    # lifetime — acceptable for a demo Space, but worth confirming.
    global _hourly_generations, _last_generation_time, _ip_generations
    with _rate_limit_lock:
        current_time = time.time()
        # Get IP address from request
        ip_address = "unknown"
        if request is not None:
            ip_address = request.client.host if hasattr(request, 'client') else "unknown"
        # Check global hourly limit
        # Remove generations older than 1 hour
        while _hourly_generations and current_time - _hourly_generations[0] > 3600:
            _hourly_generations.popleft()
        if len(_hourly_generations) >= MAX_GENERATIONS_PER_HOUR:
            return False, f"⚠️ Global hourly limit reached ({MAX_GENERATIONS_PER_HOUR} generations/hour). Please try again later."
        # Check per-IP hourly limit
        if ip_address != "unknown":
            # Remove old generations for this IP
            while _ip_generations[ip_address] and current_time - _ip_generations[ip_address][0] > 3600:
                _ip_generations[ip_address].popleft()
            if len(_ip_generations[ip_address]) >= MAX_GENERATIONS_PER_IP_PER_HOUR:
                return False, f"⚠️ You've reached your limit of {MAX_GENERATIONS_PER_IP_PER_HOUR} generations per hour. Please try again later."
        # Check cooldown period (use IP as session ID)
        session_id = ip_address
        if session_id in _last_generation_time:
            time_since_last = current_time - _last_generation_time[session_id]
            if time_since_last < COOLDOWN_SECONDS:
                wait_time = int(COOLDOWN_SECONDS - time_since_last)
                return False, f"⏳ Please wait {wait_time} seconds before next generation."
        # All checks passed - record this generation
        _hourly_generations.append(current_time)
        _last_generation_time[session_id] = current_time
        if ip_address != "unknown":
            _ip_generations[ip_address].append(current_time)
        return True, "OK"
def increment_generation_count():
    """Thread-safe increment of generation counter with dataset persistence"""
    global _total_generations
    with _generation_lock:
        _total_generations += 1
        new_total = _total_generations
        # Persist after the in-memory bump so restarts resume from this value
        save_generation_count_to_dataset(new_total)
        return new_total
def get_generation_count():
    """Get current generation count as string"""
    with _generation_lock:
        snapshot = _total_generations
    return str(snapshot)
def refresh_counter():
    """Reload the counter from the dataset repo; returns the new value as a string."""
    global _total_generations
    with _generation_lock:
        _total_generations = load_generation_count_from_dataset()
        refreshed = _total_generations
    return str(refreshed)
# ==================== Helper Functions ====================
def get_character_batch(index: int = None) -> str:
    """Return the character-suggestion batch for a generation.

    When no index is supplied, the module-level rotation counter is used
    and advanced, so successive calls cycle through CHARACTER_BATCHES.
    """
    global _generation_counter
    if index is None:
        index = _generation_counter
        _generation_counter += 1
    return CHARACTER_BATCHES[index % len(CHARACTER_BATCHES)]
def call_llava_replicate(image: Image.Image, description: str, sample_index: int = None) -> dict:
    """Run the LLaVA analysis step on Replicate.

    Returns a dict with keys 'translation', 'image_generation_instructions'
    and 'us_meme_caption', or {'error': ...} on any failure.
    """
    if not REPLICATE_API_TOKEN:
        return {"error": "Please set REPLICATE_API_TOKEN environment variable"}
    try:
        # Pick the character batch for this run (rotates when sample_index is None)
        character_suggestions = get_character_batch(sample_index)
        # Encode the input image as a base64 PNG data URI for the API
        png_buffer = io.BytesIO()
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode()
        image_uri = f"data:image/png;base64,{encoded}"
        # Inject the chosen batch into the system prompt template
        customized_prompt = SYSTEM_PROMPT.replace("[CHARACTER_SUGGESTIONS]", character_suggestions)
        user_prompt = f"Description: {description}\nOriginal Emotion: unknown\nOriginal Intensity: unknown"
        full_prompt = customized_prompt + "\n\n" + user_prompt
        output = replicate.run(
            LLAVA_MODEL,
            input={
                "image": image_uri,
                "prompt": full_prompt,
                "max_tokens": LLAVA_MAX_TOKENS,
                "temperature": LLAVA_TEMPERATURE,
                "top_p": LLAVA_TOP_P,
            },
        )
        # replicate.run streams tokens; concatenate them into one string
        translation = "".join(output)
        return {
            "translation": translation,
            "image_generation_instructions": extract_image_gen_instructions(translation),
            "us_meme_caption": extract_us_meme_caption(translation),
        }
    except Exception as e:
        return {"error": f"Exception: {str(e)}"}
def extract_image_gen_instructions(translation: str) -> str:
    """Pull the IMAGE GENERATION INSTRUCTIONS section out of the LLaVA output.

    Collects non-empty lines after the section header (optionally numbered),
    stopping at the next known section header. Returns "" if absent.
    """
    header_re = re.compile(r'^\s*(?:\d+\.\s*)?IMAGE GENERATION INSTRUCTIONS:', re.I)
    stop_re = re.compile(r'^\s*(?:\d+\.\s*)?(US MEME CAPTION:|US MEME CAPTIONS:|US MEME:|NOTE:|CULTURAL CONTEXT:)', re.I)
    collected = []
    in_section = False
    for raw in translation.splitlines():
        line = raw.rstrip()
        if header_re.match(line):
            in_section = True
            continue
        if not in_section:
            continue
        if stop_re.match(line):
            break
        stripped = line.strip()
        if stripped:
            collected.append(stripped)
    return "\n".join(collected).strip() if collected else ""
def extract_us_meme_caption(translation: str) -> str:
    """Pull the US MEME CAPTION out of the LLaVA output.

    Returns only the FIRST non-empty line after the header; stops early on a
    numbered list item or another section header. Returns "" if absent.
    """
    header_re = re.compile(r'^\s*(?:\d+\.\s*)?US MEME CAPTIONS?:', re.I)
    numbered_re = re.compile(r'^\s*\d+\.\s+')
    stop_re = re.compile(r'^\s*(?:CONSTRAINTS?:|NOTE:|CULTURAL CONTEXT:)', re.I)
    in_section = False
    for raw in translation.splitlines():
        line = raw.rstrip()
        if header_re.match(line):
            in_section = True
            continue
        if not in_section:
            continue
        if numbered_re.match(line) or stop_re.match(line):
            break
        text = line.strip()
        if text:
            # Only the first non-empty line is taken as the caption
            return text
    return ""
def call_flux_replicate(prompt: str) -> Image.Image:
    """Generate the base meme image with FLUX via Replicate.

    Appends FLUX_TEXT_FREE_INSTRUCTION to the prompt (if not already present)
    so the model avoids rendering text, then downloads and returns the image.
    Returns None on any failure (missing token, API error, download error).
    """
    if not REPLICATE_API_TOKEN:
        return None
    try:
        # Ensure the "text-free" instruction is part of the prompt
        if FLUX_TEXT_FREE_INSTRUCTION not in prompt:
            prompt = prompt.rstrip(".") + ". " + FLUX_TEXT_FREE_INSTRUCTION
        # Call FLUX on Replicate
        output = replicate.run(
            FLUX_MODEL,
            input={
                "prompt": prompt,
                "num_inference_steps": FLUX_INFERENCE_STEPS,
                "guidance_scale": FLUX_GUIDANCE_SCALE,
            }
        )
        # Replicate may return a list of outputs; take the first
        if isinstance(output, list) and len(output) > 0:
            output = output[0]
        # Output is either a file-like object or a URL string
        if hasattr(output, 'read'):
            image = Image.open(output)
        else:
            import requests
            # timeout prevents a hung download from blocking a worker forever;
            # raise_for_status surfaces HTTP errors instead of feeding an
            # error page to PIL (either way we fall through to except -> None)
            response = requests.get(str(output), timeout=60)
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content))
        return image
    except Exception as e:
        print(f"FLUX error: {e}")
        return None
def add_caption_to_image(image: Image.Image, caption: str) -> Image.Image:
    """Overlay `caption` in white text with a black outline near the image bottom.

    Text is greedily word-wrapped to ~90% of the image width and the block is
    anchored so its bottom sits at 95% of the image height. On any failure the
    original image is returned unchanged.
    """
    try:
        # Work on a copy so the caller's image is never mutated
        img = image.copy()
        draw = ImageDraw.Draw(img)
        width, height = img.size
        # Font size scales with image height for readability
        font_size = int(height * CAPTION_FONT_SIZE_RATIO)
        # Try the configured font paths (covers macOS and Linux layouts)
        font = None
        for font_path in FONT_PATHS:
            try:
                font = ImageFont.truetype(font_path, font_size)
                break
            except OSError:  # narrow: font file missing/unreadable on this platform
                continue
        # Fallback to PIL's built-in bitmap font
        if font is None:
            font = ImageFont.load_default()
        # Greedy word-wrap: add words while the line fits in 90% of the width
        words = caption.split()
        lines = []
        current_line = []
        for word in words:
            test_line = ' '.join(current_line + [word])
            bbox = draw.textbbox((0, 0), test_line, font=font)
            if bbox[2] - bbox[0] < width * 0.9:
                current_line.append(word)
            else:
                if current_line:
                    lines.append(' '.join(current_line))
                current_line = [word]
        if current_line:
            lines.append(' '.join(current_line))
        # Total height of the caption block (5px spacing between lines)
        total_text_height = 0
        for line in lines:
            bbox = draw.textbbox((0, 0), line, font=font)
            total_text_height += (bbox[3] - bbox[1]) + 5
        # Anchor the block so its bottom sits at 95% of the image height
        y_text = int(height * 0.95) - total_text_height
        for line in lines:
            bbox = draw.textbbox((0, 0), line, font=font)
            text_width = bbox[2] - bbox[0]
            x_text = (width - text_width) // 2
            # Black outline: stamp the text at every offset in a 5x5 grid first
            for adj_x in range(-2, 3):
                for adj_y in range(-2, 3):
                    draw.text((x_text + adj_x, y_text + adj_y), line, font=font, fill='black')
            # White fill on top of the outline
            draw.text((x_text, y_text), line, font=font, fill='white')
            y_text += bbox[3] - bbox[1] + 5
        return img
    except Exception as e:
        print(f"Caption error: {e}")
        return image
# ==================== Main Pipeline ====================
def generate_meme(input_image, description, sample_index: int = None, request: gr.Request = None):
    """Full pipeline: analyze a Chinese meme with LLaVA, generate a US-style
    image with FLUX, then overlay the extracted caption.

    Yields 7-tuples of (final_meme, base_image, status, llava_translation,
    image_instructions, us_caption, generation_count) so Gradio can stream
    progress updates.

    BUGFIX: this is a generator function (it contains `yield`), so the early
    exits must `yield` their error tuple before returning — a bare
    `return <value>` inside a generator rides StopIteration and never reaches
    the Gradio outputs, which made rate-limit / validation / missing-token
    messages invisible to the user.
    """
    # Check rate limits first (with IP tracking)
    allowed, rate_limit_msg = check_rate_limits(request)
    if not allowed:
        yield None, None, rate_limit_msg, None, None, None, get_generation_count()
        return
    if input_image is None:
        yield None, None, "Please provide an input image", None, None, None, get_generation_count()
        return
    if not description or not description.strip():
        yield None, None, "Please provide a description", None, None, None, get_generation_count()
        return
    if not REPLICATE_API_TOKEN:
        yield None, None, "Error: REPLICATE_API_TOKEN not set. Please configure it in environment variables or Hugging Face Spaces secrets.", None, None, None, get_generation_count()
        return
    try:
        # Count this generation up-front (also persisted to the dataset repo)
        count = increment_generation_count()
        # BUGFIX: removed a dead get_character_batch(sample_index) call here —
        # call_llava_replicate() performs the rotation itself, so the extra
        # call double-incremented _generation_counter and skewed the rotation.
        # Accept either a file path or a PIL image
        if isinstance(input_image, str):
            image = Image.open(input_image).convert('RGB')
        else:
            image = input_image.convert('RGB')
        yield None, None, "Step 1/3: Analyzing with LLaVA...", None, None, None, count
        # Step 1: LLaVA analysis with character batch rotation
        llava_result = call_llava_replicate(image, description, sample_index)
        if "error" in llava_result:
            yield None, None, f"Error in Step 1: {llava_result['error']}", None, None, None, count
            return
        translation = llava_result.get("translation", "")
        image_instructions = llava_result.get("image_generation_instructions", "")
        us_caption = llava_result.get("us_meme_caption", "")
        if not image_instructions:
            yield None, None, "Error: Could not extract image instructions", translation, None, None, count
            return
        yield None, None, "Step 2/3: Generating new meme image with FLUX...", translation, None, None, count
        # Step 2: FLUX image generation
        generated_image = call_flux_replicate(image_instructions)
        if generated_image is None:
            yield None, None, "Error in Step 2: Failed to generate image", translation, None, None, count
            return
        yield None, generated_image, "Step 3/3: Adding caption...", translation, image_instructions, us_caption, count
        # Step 3: caption overlay
        final_meme = add_caption_to_image(generated_image, us_caption)
        yield final_meme, generated_image, "Complete! Your US meme is ready!", translation, image_instructions, us_caption, count
    except Exception as e:
        yield None, None, f"Error: {str(e)}", None, None, None, get_generation_count()
def select_gallery_example(evt: gr.SelectData):
    """Map a gallery click to the selected example's image path and description."""
    position = evt.index
    if not (0 <= position < len(GALLERY_EXAMPLES)):
        return None, ""
    chosen = GALLERY_EXAMPLES[position]
    return chosen['image'], chosen['description']
# ==================== Gradio Interface ====================
def create_demo():
    """Create the Gradio demo interface"""
    # Custom CSS for better styling (forces the counter box to bold black text)
    custom_css = """
.counter-display label {
color: #000000 !important;
font-weight: bold !important;
}
.counter-display input {
color: #000000 !important;
font-size: 1.2em !important;
font-weight: bold !important;
text-align: center !important;
}
/* Show loading state when counter is empty */
.counter-display input:placeholder-shown {
font-style: italic;
opacity: 0.6;
}
"""
    with gr.Blocks(title="Chinese to US Meme Generator", css=custom_css) as demo:
        # Header row: title/paper link on the left, generation counter on the right
        with gr.Row():
            gr.Markdown("""
# MemeXGen
Cross-Cultural Meme Transcreation with Vision-Language Models
📄 **Read our paper: [Beyond Translation: Cross-Cultural Meme Transcreation with Vision-Language Models](https://arxiv.org/pdf/2602.02510)**
""")
            with gr.Column(scale=0, min_width=150):
                # Shows the persisted total; refreshed on page load and after each run
                generation_counter = gr.Textbox(
                    label="🔥 Generations",
                    value="Loading...",
                    interactive=False,
                    container=True,
                    elem_classes="counter-display",
                    show_label=True
                )
        gr.Markdown("""
**Pipeline**: LLaVA-13B (analyze) → FLUX.1-schnell (generate) → Caption overlay
---
### Tips:
- **Gallery Mode**: Browse and select from real Chinese memes with authentic descriptions
- **Custom Mode**: Upload your own images for processing
- **Models**: Using LLaVA-13B for cultural analysis, FLUX.1-schnell for image generation
---
""")
        with gr.Tabs() as tabs:
            # ==================== Mode 1: Gallery ====================
            with gr.Tab("Example Gallery"):
                gr.Markdown("""
### Select from 293 real memes
Click any image below to select it, then click "Transcreate Meme".
""")
                # Gallery view for visual selection
                gallery = gr.Gallery(
                    value=[ex['image'] for ex in GALLERY_EXAMPLES],
                    label="Chinese Meme Examples",
                    columns=6,
                    rows=3,
                    height=500,
                    object_fit="contain",
                    show_label=False,
                    interactive=True
                )
                # Hidden components hold the selection handed to generate_meme
                gallery_input = gr.Image(visible=False, type="pil")
                gallery_description = gr.Textbox(visible=False)
                gallery_btn = gr.Button("Transcreate Meme", variant="primary", size="lg")
                with gr.Row():
                    with gr.Column(scale=1):
                        gallery_output = gr.Image(label="Final Meme", height=400)
                    with gr.Column(scale=1):
                        gallery_base = gr.Image(label="Generated Base Image (before caption)", height=400)
                gallery_status = gr.Textbox(label="Status", lines=2)
                with gr.Accordion("Analysis Details", open=False):
                    gallery_translation = gr.Textbox(
                        label="LLaVA Analysis",
                        lines=15,
                        max_lines=30
                    )
                    gallery_instructions = gr.Textbox(
                        label="Image Generation Instructions",
                        lines=5
                    )
                    gallery_caption = gr.Textbox(
                        label="Extracted Caption",
                        lines=2
                    )
                # Wire up gallery selection: clicking a thumbnail fills the hidden inputs
                gallery.select(
                    fn=select_gallery_example,
                    outputs=[gallery_input, gallery_description]
                )
                gallery_btn.click(
                    fn=generate_meme,
                    inputs=[gallery_input, gallery_description],
                    outputs=[
                        gallery_output,
                        gallery_base,
                        gallery_status,
                        gallery_translation,
                        gallery_instructions,
                        gallery_caption,
                        generation_counter
                    ]
                )
            # ==================== Mode 2: Custom Upload ====================
            with gr.Tab("Custom Upload"):
                gr.Markdown("""
### Upload your own meme
Upload any image and provide a description in Chinese or English.
""")
                with gr.Row():
                    with gr.Column(scale=1):
                        custom_input = gr.Image(
                            label="Upload Your Meme",
                            type="pil",
                            height=300
                        )
                    with gr.Column(scale=1):
                        custom_description = gr.Textbox(
                            label="Description (Chinese or English)",
                            lines=5,
                            placeholder="Describe the meme's content and emotion..."
                        )
                custom_btn = gr.Button("Transcreate Meme", variant="primary", size="lg")
                with gr.Row():
                    with gr.Column(scale=1):
                        custom_output = gr.Image(label="Final Meme", height=400)
                    with gr.Column(scale=1):
                        custom_base = gr.Image(label="Generated Base Image (before caption)", height=400)
                custom_status = gr.Textbox(label="Status", lines=2)
                with gr.Accordion("Analysis Details", open=False):
                    custom_translation = gr.Textbox(
                        label="LLaVA Analysis",
                        lines=15,
                        max_lines=30
                    )
                    custom_instructions = gr.Textbox(
                        label="Image Generation Instructions",
                        lines=5
                    )
                    custom_caption = gr.Textbox(
                        label="Extracted Caption",
                        lines=2
                    )
                # Wire up custom upload
                custom_btn.click(
                    fn=generate_meme,
                    inputs=[custom_input, custom_description],
                    outputs=[
                        custom_output,
                        custom_base,
                        custom_status,
                        custom_translation,
                        custom_instructions,
                        custom_caption,
                        generation_counter
                    ]
                )
        # Auto-refresh counter on page load (queue=False so it renders immediately)
        demo.load(
            fn=refresh_counter,
            outputs=generation_counter,
            queue=False
        )
    return demo
if __name__ == "__main__":
    demo = create_demo()
    # Bound the request queue and worker threads (limits come from config.py)
    demo.queue(max_size=MAX_QUEUE_SIZE)
    demo.launch(share=True, server_name="0.0.0.0", max_threads=MAX_CONCURRENT_THREADS)