Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| from gradio_pdf import PDF | |
| import numpy as np | |
| import random | |
| import torch | |
| import spaces | |
| import math | |
| import os | |
| import yaml | |
| import io | |
| import tempfile | |
| import shutil | |
| import uuid | |
| import time | |
| import json | |
| from typing import List, Tuple, Dict, Optional | |
| from datetime import datetime, timedelta | |
| from pathlib import Path | |
| from PIL import Image | |
| from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler | |
| from huggingface_hub import InferenceClient | |
| from reportlab.lib.pagesizes import A4 | |
| from reportlab.pdfgen import canvas | |
| from reportlab.pdfbase import pdfmetrics | |
| from reportlab.lib.utils import ImageReader | |
| from PyPDF2 import PdfReader, PdfWriter | |
| # --- Style Presets Loading --- | |
def load_style_presets():
    """Read ``style_presets.yaml`` and return only the presets that are enabled.

    A preset counts as enabled unless its ``enabled`` field is explicitly falsy.
    Any failure (missing file, bad YAML, missing ``presets`` key, ...) is
    reported on stdout and a single built-in ``no_style`` preset is returned
    instead, so the caller always gets a usable mapping.

    Returns:
        dict: preset key -> preset definition.
    """
    builtin_only = {
        "no_style": {
            "id": "no_style",
            "label": "No style (custom)",
            "prompt_prefix": "",
            "prompt_suffix": "",
            "negative_prompt": "",
        }
    }
    try:
        with open('style_presets.yaml', 'r') as preset_file:
            config = yaml.safe_load(preset_file)
        enabled = {}
        for name, spec in config['presets'].items():
            # Presets without an explicit flag are treated as enabled.
            if spec.get('enabled', True):
                enabled[name] = spec
        return enabled
    except Exception as e:
        print(f"Error loading style presets: {e}")
        return builtin_only


# Presets are read once, when the module is imported.
STYLE_PRESETS = load_style_presets()
| # --- Page Layouts Loading --- | |
def load_page_layouts():
    """Load the page layouts from ``page_layouts.yaml``.

    The returned mapping is keyed by ``"1_image"``, ``"2_images"``, ... because
    that is the key format every lookup in this module builds. Each value is a
    list of layout dicts with ``id``, ``label`` and ``positions`` (relative
    ``[x, y, width, height]`` boxes).

    If the file cannot be read or parsed, the error is printed and a small
    built-in set of layouts for 1-4 images is returned. The fallback uses the
    same string keys as the YAML data; integer keys would never be matched by
    the lookups.

    Returns:
        dict: layout key -> list of layout definitions.
    """
    try:
        with open('page_layouts.yaml', 'r') as f:
            data = yaml.safe_load(f)
        return data['layouts']
    except Exception as e:
        print(f"Error loading page layouts: {e}")
        # Fallback to basic layouts
        return {
            "1_image": [
                {"id": "full_page", "label": "Full Page",
                 "positions": [[0.05, 0.05, 0.9, 0.9]]},
            ],
            "2_images": [
                {"id": "horizontal_split", "label": "Horizontal Split",
                 "positions": [[0.05, 0.05, 0.425, 0.9], [0.525, 0.05, 0.425, 0.9]]},
            ],
            "3_images": [
                {"id": "grid", "label": "Grid",
                 "positions": [[0.05, 0.05, 0.283, 0.5], [0.358, 0.05, 0.283, 0.5],
                               [0.666, 0.05, 0.283, 0.5]]},
            ],
            "4_images": [
                {"id": "grid_2x2", "label": "2x2 Grid",
                 "positions": [[0.05, 0.05, 0.425, 0.425], [0.525, 0.05, 0.425, 0.425],
                               [0.05, 0.525, 0.425, 0.425], [0.525, 0.525, 0.425, 0.425]]},
            ],
        }


# Load layouts at startup
PAGE_LAYOUTS = load_page_layouts()
def get_layout_choices(num_images: int) -> List[Tuple[str, str]]:
    """List the ``(label, id)`` pairs of the layouts defined for ``num_images``.

    When ``PAGE_LAYOUTS`` has no entry for that image count, a single
    ``("Default", "default")`` choice is returned.
    """
    suffix = "image" if num_images == 1 else "images"
    try:
        layouts = PAGE_LAYOUTS[f"{num_images}_{suffix}"]
    except KeyError:
        # No layouts configured for this count.
        return [("Default", "default")]
    return [(entry["label"], entry["id"]) for entry in layouts]
def get_random_style_preset():
    """Pick a random preset key, never ``'no_style'`` or ``'random'``.

    Returns ``'no_style'`` when no other preset is available.
    """
    excluded = ('no_style', 'random')
    candidates = [name for name in STYLE_PRESETS if name not in excluded]
    return random.choice(candidates) if candidates else 'no_style'
def apply_style_preset(prompt, style_preset_key, custom_style_text=""):
    """Decorate ``prompt`` with the wording of the chosen style preset.

    Args:
        prompt: The base prompt.
        style_preset_key: Preset key; ``'no_style'`` and ``'random'`` are
            handled specially.
        custom_style_text: Free-form style, only used with ``'no_style'``.

    Returns:
        tuple: ``(styled_prompt, negative_prompt)``. The negative prompt is an
        empty string for ``'no_style'`` and for unknown preset keys.
    """
    if style_preset_key == 'no_style':
        # A blank / whitespace-only custom style leaves the prompt untouched.
        has_custom = bool(custom_style_text and custom_style_text.strip())
        return (f"{custom_style_text}, {prompt}" if has_custom else prompt), ""

    if style_preset_key == 'random':
        style_preset_key = get_random_style_preset()

    if style_preset_key not in STYLE_PRESETS:
        # Unknown preset: hand the prompt back unchanged.
        return prompt, ""

    preset = STYLE_PRESETS[style_preset_key]
    lead = preset.get('prompt_prefix', '')
    tail = preset.get('prompt_suffix', '')
    # Prefix and suffix are optional; the prompt itself is always included.
    segments = ([lead] if lead else []) + [prompt] + ([tail] if tail else [])
    return ', '.join(segments), preset.get('negative_prompt', '')
| # --- Story Generation using Hugging Face InferenceClient --- | |
def _placeholder_scenes(story_prompt, num_scenes, unit):
    """Build dialogue-free scenes that just repeat the prompt, numbered with ``unit``."""
    return [
        {"caption": f"{story_prompt} ({unit} {i + 1} of {num_scenes})", "dialogue": ""}
        for i in range(num_scenes)
    ]


def generate_story_scenes(story_prompt, num_scenes, style_context=""):
    """Generate a sequence of scene descriptions with captions and dialogues.

    The scenes are requested from a chat model through the Hugging Face
    ``InferenceClient`` and parsed with ``parse_yaml_scenes``. If ``HF_TOKEN``
    is not set, or the request fails, placeholder scenes derived from the
    prompt are returned instead, so the caller always receives ``num_scenes``
    items.

    Args:
        story_prompt: The user's story prompt.
        num_scenes: Number of scenes to generate.
        style_context: Accepted for API compatibility; currently not used.

    Returns:
        List of dicts with 'caption' and 'dialogue' keys.
    """
    api_key = os.environ.get("HF_TOKEN")
    if not api_key:
        print("HF_TOKEN not set, using fallback scene generation")
        return _placeholder_scenes(story_prompt, num_scenes, "scene")

    client = InferenceClient(
        provider="cerebras",
        api_key=api_key,
    )

    # The example below is indented as real YAML (``dialogue`` nested under the
    # list item): the reply is parsed with yaml.safe_load, so the model must be
    # shown a structure that actually parses as a list of mappings.
    system_prompt = f"""You are a comic book story writer. Generate exactly {num_scenes} scenes for a comic page based on the user's story prompt.
IMPORTANT INSTRUCTIONS:
1. Output ONLY a YAML list with exactly {num_scenes} items
2. Each item must have exactly two fields:
   - caption: A detailed visual description of the scene (describe characters, clothing, location, action, expressions)
   - dialogue: Natural language description of what the character says/exclaims/shouts (can be empty string if no dialogue)
3. For captions: Be very descriptive. Repeat character descriptions in each scene (appearance, clothes, etc.)
4. For dialogue: Write it as a natural language action that will be added to the scene description
   - Format: "The [character] says: [what they say]" or "The [character] exclaims: [what they exclaim]"
   - DO NOT include character names in the dialogue text itself
   - Use verbs like: says, exclaims, shouts, whispers, asks, replies, thinks
5. Keep continuity between scenes to tell a coherent story
6. Make each scene visually distinct but connected to the narrative
Example output format:
- caption: "A young woman with long red hair wearing a blue detective coat stands in a dark alley, holding a magnifying glass up to examine mysterious glowing footprints on the wet pavement"
  dialogue: "The detective exclaims: These tracks aren't human!"
- caption: "The same red-haired woman in the blue coat backs away in shock as a massive shark fin emerges from a puddle in the alley, water splashing everywhere"
  dialogue: "The detective shouts: OH NO, SHARKS IN THE CITY!"
- caption: "The red-haired detective in blue coat runs down the alley, looking back over her shoulder at the shark fin pursuing her through the puddles"
  dialogue: "The detective thinks to herself: I need to warn everyone!"
Generate exactly {num_scenes} scenes. Output ONLY the YAML list, no other text."""

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Create {num_scenes} comic scenes for this story: {story_prompt}"},
    ]

    try:
        completion = client.chat.completions.create(
            model="Qwen/Qwen3-235B-A22B-Instruct-2507",
            messages=messages,
            temperature=0.7,
            max_tokens=2000,
        )
        response = completion.choices[0].message.content
        return parse_yaml_scenes(response, num_scenes)
    except Exception as e:
        # Best effort: a failed request must not abort page generation.
        print(f"Error during story generation: {e}")
        return _placeholder_scenes(story_prompt, num_scenes, "part")
def parse_yaml_scenes(yaml_text, expected_count):
    """Parse the model's YAML reply into exactly ``expected_count`` scenes.

    Args:
        yaml_text: Raw reply text, optionally wrapped in a Markdown code fence.
        expected_count: Number of scenes the caller needs.

    Returns:
        List of ``{'caption': str, 'dialogue': str}`` dicts of length
        ``expected_count``. Items without a usable caption are dropped,
        missing scenes are padded with a generic caption, and surplus scenes
        are cut off. If the text cannot be parsed at all, generic placeholder
        scenes are returned.
    """

    def _text(value):
        # YAML nulls (e.g. an empty ``dialogue:``) load as None; str(None)
        # would leak the literal word "None" into the image prompt.
        return "" if value is None else str(value).strip()

    try:
        yaml_text = (yaml_text or "").strip()

        # Remove a surrounding Markdown code fence if present.
        for opener in ("```yaml", "```yml", "```"):
            if yaml_text.startswith(opener):
                yaml_text = yaml_text[len(opener):]
                break
        if yaml_text.endswith("```"):
            yaml_text = yaml_text[:-3]

        scenes = yaml.safe_load(yaml_text)
        if not isinstance(scenes, list):
            raise ValueError("Expected a list of scenes")

        valid_scenes = []
        for scene in scenes:
            if not isinstance(scene, dict):
                continue
            caption = _text(scene.get('caption'))
            if not caption:
                # A scene without a description cannot be drawn.
                continue
            valid_scenes.append({
                'caption': caption,
                'dialogue': _text(scene.get('dialogue')),
            })

        # Pad so that the caller can index scenes[0..expected_count-1] safely.
        while len(valid_scenes) < expected_count:
            valid_scenes.append({
                'caption': 'continuation of the story',
                'dialogue': ''
            })
        return valid_scenes[:expected_count]
    except Exception as e:
        print(f"Error parsing YAML scenes: {e}")
        return [{'caption': 'scene description', 'dialogue': ''} for _ in range(expected_count)]
def get_caption_language(prompt):
    """Return ``'zh'`` if the prompt holds a CJK Unified Ideograph, else ``'en'``."""
    # Only the basic CJK Unified Ideographs block (U+4E00..U+9FFF) is checked.
    cjk_first, cjk_last = '\u4e00', '\u9fff'
    has_cjk = any(cjk_first <= ch <= cjk_last for ch in prompt)
    return 'zh' if has_cjk else 'en'
| # --- Model Loading --- | |
| # Use the new lightning-fast model setup | |
# Base checkpoint the diffusion pipeline is built from.
ckpt_id = "Qwen/Qwen-Image"
# Scheduler configuration from the Qwen-Image-Lightning repository
scheduler_config = {
    "base_image_seq_len": 256,
    "base_shift": math.log(3),
    "invert_sigmas": False,
    "max_image_seq_len": 8192,
    "max_shift": math.log(3),
    "num_train_timesteps": 1000,
    "shift": 1.0,
    "shift_terminal": None,
    "stochastic_sampling": False,
    "time_shift_type": "exponential",
    "use_beta_sigmas": False,
    "use_dynamic_shifting": True,
    "use_exponential_sigmas": False,
    "use_karras_sigmas": False,
}
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
# The pipeline is created at import time, in bfloat16, and moved to the GPU
# immediately; importing this module therefore requires a CUDA device.
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id, scheduler=scheduler, torch_dtype=torch.bfloat16
).to("cuda")
# Load LoRA weights for acceleration (the 8-step Lightning weights, matching
# the default of 8 inference steps used by the generation functions).
pipe.load_lora_weights(
    "lightx2v/Qwen-Image-Lightning", weight_name="Qwen-Image-Lightning-8steps-V1.1.safetensors"
)
pipe.fuse_lora()
# Disabled experiment: stacking an additional realism LoRA on top.
#pipe.unload_lora_weights()
#pipe.load_lora_weights("flymy-ai/qwen-image-realism-lora")
#pipe.fuse_lora()
#pipe.unload_lora_weights()
# --- UI Constants and Helpers ---
# Largest 32-bit signed integer; upper bound for randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
def get_image_size_for_position(position_data, image_index, num_images):
    """Choose a generation size whose aspect ratio follows a layout slot.

    Args:
        position_data: Slot as ``[x, y, width, height]`` in relative units, or
            a falsy value for "no slot".
        image_index: Index of the image (0-based); not used in the calculation.
        num_images: Number of images in the layout; not used in the calculation.

    Returns:
        tuple: ``(width, height)`` in pixels, each a multiple of 64 between 256
        and 1024. ``(1024, 1024)`` when no slot is given.
    """
    limit, smallest, step = 1024, 256, 64

    if not position_data:
        return limit, limit

    _x, _y, rel_w, rel_h = position_data
    # A slot with no height is treated as square.
    ratio = rel_w / rel_h if rel_h > 0 else 1.0

    if ratio >= 1:
        # Landscape (or square): pin the width, derive the height.
        width, height = limit, int(limit / ratio)
        if height < smallest:
            width, height = int(smallest * ratio), smallest
    else:
        # Portrait: pin the height, derive the width.
        width, height = int(limit * ratio), limit
        if width < smallest:
            width, height = smallest, int(smallest / ratio)

    def _snap(value):
        # Round down to a multiple of 64, then keep within [256, 1024].
        snapped = value - value % step
        return min(max(snapped, smallest), limit)

    return _snap(width), _snap(height)
def get_layout_position_for_image(layout_id, num_images, image_index):
    """Look up the slot of one image within a page layout.

    The configured layouts in ``PAGE_LAYOUTS`` are consulted first. If the
    layout is unknown, or it has no slot for ``image_index``, a built-in table
    covering 1-6 images is used instead.

    Args:
        layout_id: ID of the selected layout.
        num_images: Total number of images on the page.
        image_index: Index of the image (0-based).

    Returns:
        Slot as ``[x, y, width, height]`` in relative units. A near-full-page
        slot is returned when nothing else matches.
    """
    suffix = "image" if num_images == 1 else "images"
    chosen = None
    for candidate in PAGE_LAYOUTS.get(f"{num_images}_{suffix}", []):
        if candidate["id"] == layout_id:
            chosen = candidate
            break

    if chosen and "positions" in chosen:
        slots = chosen["positions"]
        if image_index < len(slots):
            return slots[image_index]

    # Built-in slots, used when the configured layout cannot answer.
    full_page = [[0.05, 0.05, 0.9, 0.9]]
    builtin = {
        1: full_page,
        2: [[0.05, 0.05, 0.425, 0.9], [0.525, 0.05, 0.425, 0.9]],
        3: [[0.05, 0.25, 0.283, 0.5], [0.358, 0.25, 0.283, 0.5], [0.666, 0.25, 0.283, 0.5]],
        4: [
            [0.05, 0.05, 0.425, 0.425], [0.525, 0.05, 0.425, 0.425],
            [0.05, 0.525, 0.425, 0.425], [0.525, 0.525, 0.425, 0.425],
        ],
        5: [
            [0.05, 0.05, 0.9, 0.3], [0.05, 0.4, 0.283, 0.55], [0.358, 0.4, 0.283, 0.55],
            [0.666, 0.4, 0.283, 0.275], [0.666, 0.7, 0.283, 0.275],
        ],
        6: [
            [0.05, 0.05, 0.425, 0.283], [0.525, 0.05, 0.425, 0.283],
            [0.05, 0.358, 0.425, 0.283], [0.525, 0.358, 0.425, 0.283],
            [0.05, 0.666, 0.425, 0.283], [0.525, 0.666, 0.425, 0.283],
        ],
    }
    slots = builtin.get(num_images, full_page)
    if image_index < len(slots):
        return slots[image_index]
    return [0.05, 0.05, 0.9, 0.9]  # Ultimate default
| # --- Session Management Functions --- | |
class SessionManager:
    """Manages user session data and temporary file storage.

    Every session owns a directory below ``<tempdir>/gradio_comic_sessions``
    that holds a ``metadata.json`` file, one ``page_<n>`` folder of JPEG images
    per generated page, and the assembled ``comic.pdf``.
    """

    def __init__(self, session_id: Optional[str] = None):
        """Open the session ``session_id``, creating it (and a new id) if needed."""
        self.session_id = session_id or str(uuid.uuid4())
        self.base_dir = Path(tempfile.gettempdir()) / "gradio_comic_sessions"
        self.session_dir = self.base_dir / self.session_id
        self.session_dir.mkdir(parents=True, exist_ok=True)
        self.metadata_file = self.session_dir / "metadata.json"
        self.pdf_path = self.session_dir / "comic.pdf"
        self.load_or_create_metadata()

    def load_or_create_metadata(self):
        """Load existing metadata or create new.

        An unreadable or corrupt metadata file is reported and replaced by
        fresh metadata; otherwise one damaged file would make the session
        unusable for good.
        """
        if self.metadata_file.exists():
            try:
                with open(self.metadata_file, 'r') as f:
                    self.metadata = json.load(f)
                return
            except (json.JSONDecodeError, OSError) as e:
                print(f"Error reading session metadata {self.metadata_file}: {e}")
        self.metadata = {
            "created_at": datetime.now().isoformat(),
            "pages": [],
            "total_pages": 0
        }
        self.save_metadata()

    def save_metadata(self):
        """Save metadata to file."""
        with open(self.metadata_file, 'w') as f:
            json.dump(self.metadata, f, indent=2)

    def add_page(self, images: "List[Image.Image]", layout_id: str, seeds: List[int]):
        """Add a new page to the session.

        Args:
            images: Images of the page, in layout order; stored as JPEG files.
            layout_id: ID of the layout the page uses.
            seeds: Seeds the images were generated with.

        Returns:
            The 1-based number of the page that was added.
        """
        page_num = self.metadata["total_pages"] + 1
        page_dir = self.session_dir / f"page_{page_num}"
        # parents=True: the session directory may have been removed since this
        # object was created (e.g. by a cleanup run of another request).
        page_dir.mkdir(parents=True, exist_ok=True)

        image_paths = []
        for i, img in enumerate(images):
            img_path = page_dir / f"image_{i+1}.jpg"
            img.save(img_path, 'JPEG', quality=95)
            image_paths.append(str(img_path))

        self.metadata["pages"].append({
            "page_num": page_num,
            "layout_id": layout_id,
            "num_images": len(images),
            "image_paths": image_paths,
            "seeds": seeds,
            "created_at": datetime.now().isoformat()
        })
        self.metadata["total_pages"] = page_num
        self.save_metadata()
        return page_num

    def get_all_pages_images(self) -> "List[Tuple[List[Image.Image], str, int]]":
        """Get all images from all pages.

        Returns:
            One ``(images, layout_id, num_images)`` tuple per page that still
            has at least one image file on disk. ``num_images`` is the count
            recorded when the page was added.
        """
        pages_data = []
        for page in self.metadata["pages"]:
            images = []
            for img_path in page["image_paths"]:
                if Path(img_path).exists():
                    # Image.open is lazy and keeps the file open; copy the
                    # pixels into memory and close the file right away.
                    with Image.open(img_path) as img:
                        images.append(img.copy())
            if images:
                pages_data.append((images, page["layout_id"], page["num_images"]))
        return pages_data

    def cleanup_old_sessions(self, max_age_hours: int = 24):
        """Clean up sessions older than max_age_hours.

        The session this object manages is never removed, even when it is old
        enough: it is in use by the current request.
        """
        if not self.base_dir.exists():
            return
        cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
        for session_dir in self.base_dir.iterdir():
            if not session_dir.is_dir() or session_dir.name == self.session_id:
                continue
            metadata_file = session_dir / "metadata.json"
            if not metadata_file.exists():
                continue
            try:
                with open(metadata_file, 'r') as f:
                    metadata = json.load(f)
                created_at = datetime.fromisoformat(metadata["created_at"])
                if created_at < cutoff_time:
                    shutil.rmtree(session_dir)
                    print(f"Cleaned up old session: {session_dir.name}")
            except Exception as e:
                # Best effort: one broken session must not stop the sweep.
                print(f"Error cleaning session {session_dir.name}: {e}")
| # --- PDF Generation Functions --- | |
def _default_panel_positions(num_images):
    """Return the built-in relative [x, y, w, h] slots used when no layout matches."""
    defaults = {
        1: [[0.05, 0.05, 0.9, 0.9]],
        2: [[0.05, 0.05, 0.425, 0.9], [0.525, 0.05, 0.425, 0.9]],
        3: [[0.05, 0.05, 0.283, 0.9], [0.358, 0.05, 0.283, 0.9], [0.666, 0.05, 0.283, 0.9]],
        4: [[0.05, 0.05, 0.425, 0.425], [0.525, 0.05, 0.425, 0.425],
            [0.05, 0.525, 0.425, 0.425], [0.525, 0.525, 0.425, 0.425]],
        5: [[0.05, 0.05, 0.9, 0.3], [0.05, 0.4, 0.283, 0.55], [0.358, 0.4, 0.283, 0.55],
            [0.666, 0.4, 0.283, 0.275], [0.666, 0.7, 0.283, 0.275]],
        6: [[0.05, 0.05, 0.425, 0.283], [0.525, 0.05, 0.425, 0.283],
            [0.05, 0.358, 0.425, 0.283], [0.525, 0.358, 0.425, 0.283],
            [0.05, 0.666, 0.425, 0.283], [0.525, 0.666, 0.425, 0.283]],
    }
    # Any other count falls back to a single near-full-page slot.
    return defaults.get(num_images, [[0.05, 0.05, 0.9, 0.9]])


def _expand_slot(pos, scale_factor=1.15, padding=0.005):
    """Scale a relative slot about the page centre and clamp it to the padded page."""
    x_rel, y_rel, w_rel, h_rel = pos
    center_x = 0.5
    center_y = 0.5
    # Enlarging by 15% packs the panels closer to each other and to the edges.
    x_rel = center_x + (x_rel - center_x) * scale_factor
    y_rel = center_y + (y_rel - center_y) * scale_factor
    w_rel = w_rel * scale_factor
    h_rel = h_rel * scale_factor
    # Keep at least `padding` (0.5% of the page) free on every side.
    if x_rel < padding:
        x_rel = padding
    if y_rel < padding:
        y_rel = padding
    if x_rel + w_rel > 1 - padding:
        w_rel = 1 - padding - x_rel
    if y_rel + h_rel > 1 - padding:
        h_rel = 1 - padding - y_rel
    return x_rel, y_rel, w_rel, h_rel


def create_single_page_pdf(images: List[Image.Image], layout_id: str, num_images: int) -> bytes:
    """Create a single A4 PDF page with images arranged per the selected layout.

    Each image is centred inside its slot and scaled to fit while keeping its
    aspect ratio. Images beyond ``num_images`` or beyond the number of slots
    of the layout are not drawn.

    Args:
        images: List of PIL images.
        layout_id: ID of the selected layout.
        num_images: Number of images to include; also selects the layout group.

    Returns:
        PDF page as bytes.
    """
    pdf_buffer = io.BytesIO()
    pdf = canvas.Canvas(pdf_buffer, pagesize=A4)
    page_width, page_height = A4

    key = f"{num_images}_image" if num_images == 1 else f"{num_images}_images"
    layout = next(
        (entry for entry in PAGE_LAYOUTS.get(key, []) if entry["id"] == layout_id),
        None,
    )
    positions = layout["positions"] if layout else _default_panel_positions(num_images)

    # zip() stops at the shorter sequence, so no extra bounds check is needed.
    for image, pos in zip(images[:num_images], positions):
        x_rel, y_rel, w_rel, h_rel = _expand_slot(pos)

        # Convert relative positions to absolute positions.
        # Note: In ReportLab, y=0 is at the bottom, hence the flipped Y.
        x = x_rel * page_width
        y = (1 - y_rel - h_rel) * page_height
        width = w_rel * page_width
        height = h_rel * page_height

        # A degenerate slot or image cannot be drawn and would divide by zero.
        if width <= 0 or height <= 0 or image.width <= 0 or image.height <= 0:
            continue

        img_aspect = image.width / image.height
        layout_aspect = width / height

        if img_aspect > layout_aspect:
            # Image is relatively wider than the slot: fill the width, centre vertically.
            actual_width = width
            actual_height = width / img_aspect
            actual_x = x
            actual_y = y + (height - actual_height) / 2
        else:
            # Image is relatively taller than the slot: fill the height, centre horizontally.
            actual_width = height * img_aspect
            actual_height = height
            actual_x = x + (width - actual_width) / 2
            actual_y = y

        # JPEG cannot store alpha or palette modes; convert those first.
        if image.mode not in ("RGB", "L", "CMYK"):
            image = image.convert("RGB")
        img_buffer = io.BytesIO()
        image.save(img_buffer, format='JPEG', quality=95)
        img_buffer.seek(0)

        pdf.drawImage(ImageReader(img_buffer), actual_x, actual_y,
                      width=actual_width, height=actual_height,
                      preserveAspectRatio=True, mask='auto')

    pdf.save()
    return pdf_buffer.getvalue()
def create_multi_page_pdf(session_manager: SessionManager) -> str:
    """Assemble every stored page of the session into one PDF file.

    Each page is rendered on its own with ``create_single_page_pdf`` and the
    results are concatenated and written to ``session_manager.pdf_path``.

    Args:
        session_manager: SessionManager instance with page data.

    Returns:
        Path to the written PDF as a string, or ``None`` when the session has
        no pages with images.
    """
    pages = session_manager.get_all_pages_images()
    if not pages:
        return None

    writer = PdfWriter()
    for page_images, page_layout, image_count in pages:
        rendered = create_single_page_pdf(page_images, page_layout, image_count)
        for pdf_page in PdfReader(io.BytesIO(rendered)).pages:
            writer.add_page(pdf_page)

    target = session_manager.pdf_path
    with open(target, 'wb') as out:
        writer.write(out)
    return str(target)
| # --- Main Inference Function (with session support) --- | |
| # Increased duration for up to 6 images | |
def infer_page(
    prompt,
    guidance_scale=1.0,
    num_inference_steps=8,
    style_preset="no_style",
    custom_style_text="",
    num_images=1,
    layout="default",
    session_state=None,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generates images for a new page and adds them to the PDF.

    Args:
        prompt (str): The text prompt to generate images from.
        guidance_scale (float): Corresponds to `true_cfg_scale`.
        num_inference_steps (int): The number of denoising steps.
        style_preset (str): The key of the style preset to apply.
        custom_style_text (str): Custom style text when 'no_style' is selected.
        num_images (int): Number of images to generate (1-6); values outside
            that range are clamped.
        layout (str): The layout ID for arranging images in the PDF.
        session_state: Current session state dictionary, or None to start a
            new session.
        progress (gr.Progress): A Gradio Progress object to track generation.

    Returns:
        tuple: (session state, PDF path for the download widget, PDF path for
        the preview widget, label for the generate button).
    """
    # NOTE(review): `spaces` is imported by this file, but no GPU decorator is
    # visible on this function here -- confirm whether one is expected.
    max_pages = 128

    # Initialize or retrieve session
    if session_state is None or "session_id" not in session_state:
        session_state = {"session_id": str(uuid.uuid4()), "page_count": 0}
    session_manager = SessionManager(session_state["session_id"])

    # Clean up old sessions periodically (10% chance on each request)
    if random.random() < 0.1:
        session_manager.cleanup_old_sessions()

    if session_manager.metadata["total_pages"] >= max_pages:
        # Keep showing the document built so far instead of blanking it.
        existing_pdf = str(session_manager.pdf_path) if session_manager.pdf_path.exists() else None
        return session_state, existing_pdf, existing_pdf, "Page limit reached"

    # The UI slider already enforces 1-6, but this function is also reachable
    # through the API / MCP server, so the range is enforced here as well.
    num_images = max(1, min(6, int(num_images)))
    page_label = session_manager.metadata["total_pages"] + 1

    generated_images = []
    used_seeds = []

    progress(0, f"Generating story with {num_images} scenes...")
    scenes = generate_story_scenes(prompt, num_images, style_preset)

    for i in range(num_images):
        # Image generation occupies 0-80% of the bar so that the later
        # 0.8 / 0.9 / 1.0 steps never move it backwards.
        progress(
            0.8 * (i + 0.5) / num_images,
            f"Generating image {i+1} of {num_images} for page {page_label}",
        )
        current_seed = random.randint(0, MAX_SEED)  # Always randomize seed

        # The slot of the image decides its aspect ratio.
        position_data = get_layout_position_for_image(layout, num_images, i)

        image, used_seed = infer_single_auto(
            prompt=scenes[i]['caption'],
            seed=current_seed,
            randomize_seed=False,  # We handle randomization here
            position_data=position_data,
            image_index=i,
            num_images=num_images,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            dialogue=scenes[i]['dialogue'],  # Pass dialogue separately
            style_preset=style_preset,
            custom_style_text=custom_style_text,
        )
        generated_images.append(image)
        used_seeds.append(used_seed)

    progress(0.8, "Adding page to document...")
    page_num = session_manager.add_page(generated_images, layout, used_seeds)

    progress(0.9, "Creating PDF...")
    pdf_path = create_multi_page_pdf(session_manager)
    progress(1.0, "Done!")

    session_state["page_count"] = page_num

    next_page_num = page_num + 1
    button_label = f"Generate page {next_page_num}" if next_page_num <= max_pages else "Page limit reached"
    return session_state, pdf_path, pdf_path, button_label
| # New inference function with automatic aspect ratio | |
def infer_single_auto(
    prompt,
    seed=42,
    randomize_seed=False,
    position_data=None,
    image_index=0,
    num_images=1,
    guidance_scale=1.0,
    num_inference_steps=8,
    dialogue="",  # New parameter for dialogue
    style_preset="no_style",
    custom_style_text="",
):
    """
    Generate one image whose size is derived from its layout slot.

    The prompt is first wrapped with the style preset, then the dialogue (if
    any) is appended as an extra sentence. Images for the ``manga_no_color``
    preset are converted to grayscale after generation.

    Returns:
        tuple: ``(image, seed)`` where ``seed`` is the seed actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Size follows the aspect ratio of the slot the image will be placed in.
    width, height = get_image_size_for_position(position_data, image_index, num_images)

    # A seeded generator makes the result reproducible for a given seed.
    rng = torch.Generator(device="cuda").manual_seed(seed)

    print(f"Original prompt: '{prompt}'")
    print(f"Style preset: '{style_preset}'")
    print(f"Auto-selected size based on layout: {width}x{height}")

    styled_prompt, style_negative_prompt = apply_style_preset(prompt, style_preset, custom_style_text)

    # The dialogue arrives as a ready-made sentence; it is appended verbatim.
    spoken = dialogue.strip() if dialogue else ""
    if spoken:
        styled_prompt = f"{styled_prompt}. {spoken}"

    # Presets without a negative prompt fall back to a single space.
    negative_prompt = style_negative_prompt or " "

    print(f"Final Prompt: '{styled_prompt}'")
    print(f"Negative Prompt: '{negative_prompt}'")
    print(f"Seed: {seed}, Size: {width}x{height}, Steps: {num_inference_steps}, True CFG Scale: {guidance_scale}")

    result = pipe(
        prompt=styled_prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=rng,
        true_cfg_scale=guidance_scale,  # Use true_cfg_scale for this model
    )
    image = result.images[0]

    if style_preset == "manga_no_color":
        # Drop the colour, but hand back an RGB image like every other style.
        image = image.convert('L').convert('RGB')

    return image, seed


# Keep the old infer function for backward compatibility (simplified)
infer = infer_single_auto
| # --- Examples and UI Layout --- | |
# Plain prompt examples. The UI below builds its own `styled_examples` list
# for the examples widget.
examples = [
    "A capybara wearing a suit holding a sign that reads Hello World",
]
# Page styling: centres the main column (max 1024px wide) and the logo/title.
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
#logo-title {
text-align: center;
}
#logo-title img {
width: 400px;
}
"""
with gr.Blocks(css=css) as demo:
    # Session state. It deliberately starts empty: a value built here would be
    # evaluated once at startup, so every visitor would begin with the same
    # session id and share one session directory. infer_page creates a fresh
    # session id when it receives an empty state.
    session_state = gr.State(value=None)

    with gr.Column(elem_id="col-container"):
        gr.HTML("""
<div id="logo-title">
<img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/qwen_image_logo.png" alt="Qwen-Image Logo" width="400" style="display: block; margin: 0 auto;">
<h2 style="font-style: italic;color: #5b47d1;margin-top: -33px !important;margin-left: 133px;">AiComicFactory-GradioEdition</h2>
</div>
""")
        gr.Markdown("This demo uses [Qwen-Image-Lightning](https://huggingface.co/lightx2v/Qwen-Image-Lightning). Hugging Face PRO users can perform more generations.")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                placeholder="Enter your prompt",
                container=False,
            )
            with gr.Column(scale=0):
                run_button = gr.Button("Generate page 1", variant="primary")
                reset_button = gr.Button("Reset", variant="secondary")

        # Row for images-per-page, page layout and style controls
        with gr.Row():
            with gr.Column(scale=1):
                # Number of images slider (affects layout choices)
                num_images_slider = gr.Slider(
                    label="Images per page",
                    minimum=1,
                    maximum=6,
                    step=1,
                    value=1,
                    info="Number of images to generate for the PDF (1-6)"
                )
            with gr.Column(scale=2):
                layout_dropdown = gr.Dropdown(
                    label="Page Layout",
                    choices=[("Full Page", "full_page")],
                    value="full_page",
                    interactive=True,
                    info="How images are arranged on the page"
                )
            with gr.Column(scale=2):
                # Create dropdown choices from loaded presets
                style_choices = [(preset["label"], key) for key, preset in STYLE_PRESETS.items()]
                style_preset = gr.Dropdown(
                    label="Style Preset",
                    choices=style_choices,
                    value="no_style",
                    interactive=True
                )
            with gr.Column(scale=2):
                # Visible from the start because the default preset is
                # "no_style", the one preset this field applies to.
                custom_style_text = gr.Textbox(
                    label="Custom Style Text",
                    placeholder="Enter custom style (e.g., 'oil painting')",
                    visible=True,
                    lines=1
                )

        with gr.Row():
            with gr.Column():
                pdf_preview = PDF(label="PDF Preview", show_label=True, height=600, elem_id="pdf-preview")
                pdf_output = gr.File(label="Download PDF", show_label=True, elem_id="pdf-download")
                gr.Markdown(
                    "**Note:** Your images and PDF are saved for up to 24 hours.\n"
                    "You can continue adding pages (up to 128) by clicking the generate button."
                )

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale (True CFG Scale)",
                    minimum=1.0,
                    maximum=5.0,
                    step=0.1,
                    value=1.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=4,
                    maximum=28,
                    step=1,
                    value=8,
                )

        # Show the custom style field only while "no_style" is selected
        def toggle_custom_style(style_value):
            return gr.update(visible=(style_value == "no_style"))

        style_preset.change(
            fn=toggle_custom_style,
            inputs=[style_preset],
            outputs=[custom_style_text]
        )

        # Offer the layouts that exist for the chosen number of images
        def update_layout_choices(num_images):
            choices = get_layout_choices(int(num_images))
            return gr.update(choices=choices, value=choices[0][1] if choices else "default")

        num_images_slider.change(
            fn=update_layout_choices,
            inputs=[num_images_slider],
            outputs=[layout_dropdown]
        )

        # Examples covering different styles and image counts.
        # NOTE(review): the style keys below must exist in style_presets.yaml -- confirm.
        styled_examples = [
            ["A capybara wearing a suit holding a sign that reads Hello World", "no_style", "", 1],
            ["sharks raining down on san francisco", "anime", "", 2],
            ["A beautiful landscape with mountains and a lake", "watercolor", "", 3],
            ["A knight fighting a dragon", "medieval", "", 4],
            ["Space battle with laser beams", "sci-fi", "", 5],
            ["Detective investigating a mystery", "noir", "", 6],
        ]
        gr.Examples(
            examples=styled_examples,
            inputs=[prompt, style_preset, custom_style_text, num_images_slider],
            outputs=None,  # Don't show outputs for examples
            fn=None,
            cache_examples=False
        )

    # Main generation event: button click or Enter in the prompt box
    generation_event = gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer_page,
        inputs=[
            prompt,
            guidance_scale,
            num_inference_steps,
            style_preset,
            custom_style_text,
            num_images_slider,
            layout_dropdown,
            session_state,
        ],
        outputs=[session_state, pdf_output, pdf_preview, run_button],
    )

    # Reset: start a new session and clear both PDF widgets
    def reset_session():
        new_state = {"session_id": str(uuid.uuid4()), "page_count": 0}
        return new_state, None, None, "Generate page 1"

    reset_button.click(
        fn=reset_session,
        inputs=[],
        outputs=[session_state, pdf_output, pdf_preview, run_button]
    )
# Start the app only when run as a script; mcp_server=True is passed so the
# app is launched with Gradio's MCP server option enabled.
if __name__ == "__main__":
    demo.launch(mcp_server=True)