Spaces:
Running
Running
| import os | |
| import base64 | |
| import time | |
| import requests | |
| from typing import List, Optional, Tuple | |
| from runwayml import RunwayML | |
| import mimetypes | |
| from urllib.parse import urlparse | |
def encode_image_to_data_uri(image_path: str) -> str:
    """Read a local image file and return it as a base64 ``data:`` URI.

    The MIME type is guessed from the file name (extension), not from the
    file contents.

    Args:
        image_path: Path to an image file on disk.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        ValueError: If no MIME type can be guessed, or it is not ``image/*``.
        OSError: If the file cannot be opened/read.
    """
    guessed_type = mimetypes.guess_type(image_path)[0]
    is_image = bool(guessed_type) and guessed_type.startswith('image/')
    if not is_image:
        raise ValueError(f"Unsupported image type for {image_path}")

    with open(image_path, 'rb') as handle:
        payload = base64.b64encode(handle.read()).decode('utf-8')

    return "data:" + guessed_type + ";base64," + payload
def save_generated_image(
    image_url: str,
    filename: Optional[str] = None,
    batch_folder: Optional[str] = None,
    timeout: float = 60.0,
) -> str:
    """
    Download a generated image and save it under ``output/<batch_folder>/``.

    Args:
        image_url: URL of the generated image.
        filename: Optional file name. Defaults to ``generated_<unix seconds>.jpg``;
            ``.jpg`` is appended when the given name has no extension.
        batch_folder: Optional sub-folder of ``output``. Defaults to
            ``batch_<YYYYmmdd_HHMMSS>`` built from the current local time.
        timeout: Seconds ``requests`` waits to connect / between received bytes
            before giving up. Without a timeout the call can block indefinitely.

    Returns:
        Path to the saved image file.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    if not batch_folder:
        batch_folder = f"batch_{time.strftime('%Y%m%d_%H%M%S')}"
    if not filename:
        filename = f"generated_{int(time.time())}.jpg"
    if not os.path.splitext(filename)[1]:
        filename += ".jpg"

    output_dir = os.path.join("output", batch_folder)
    output_path = os.path.join(output_dir, filename)

    # Download before touching the filesystem so a failed request does not
    # leave an empty batch folder behind.
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()

    os.makedirs(output_dir, exist_ok=True)
    with open(output_path, 'wb') as f:
        f.write(response.content)
    return output_path
_MAX_REFERENCE_IMAGES = 3
_MAX_PROMPT_CHARS = 1000
# Upper bound on reference-tag length used by this script; presumably mirrors
# the RunwayML API limit for `referenceImages[].tag` — confirm against the docs.
_MAX_TAG_CHARS = 16


def _derive_reference_tag(image_path: str, used_standard_tags: set, existing_tags: List[str]) -> str:
    """Choose the tag used to reference one image from the prompt.

    A path containing a ``characters`` / ``scenes`` / ``styles`` directory gets
    the standard tag ``character`` / ``scene`` / ``style`` the first time that
    tag is seen (``used_standard_tags`` is updated in place). Any other image
    gets ``ref_<file stem>`` restricted to ASCII letters, digits and underscores,
    cut to ``_MAX_TAG_CHARS``, and suffixed ``_2``, ``_3``, ... when it would
    duplicate a tag already in ``existing_tags``.
    """
    import string

    standard_by_folder = {'characters': 'character', 'scenes': 'scene', 'styles': 'style'}
    # normpath turns '/' into os.sep on Windows, so literal paths such as
    # "assets/characters/x.jpg" are split into components on every platform.
    for part in os.path.normpath(image_path).split(os.sep):
        standard = standard_by_folder.get(part)
        if standard is not None and standard not in used_standard_tags:
            used_standard_tags.add(standard)
            return standard

    allowed_chars = set(string.ascii_letters + string.digits + "_")
    stem = os.path.splitext(os.path.basename(image_path))[0]
    raw = f"ref_{stem}".replace('-', '_').replace(' ', '_')
    # The "ref_" prefix survives filtering, so the tag always starts with a
    # letter and is at least 4 characters long.
    base = ''.join(c for c in raw if c in allowed_chars)[:_MAX_TAG_CHARS]

    # Truncation can make two different file names collide; disambiguate.
    tag = base
    counter = 2
    while tag in existing_tags:
        suffix = f"_{counter}"
        tag = base[:_MAX_TAG_CHARS - len(suffix)] + suffix
        counter += 1
    return tag


def _compose_prompt(prompt_text: str, tags: List[str], auto_tag_prompt: bool) -> str:
    """Return the prompt to send, optionally with ``@tag`` mentions appended.

    Tries ``"<prompt> using references: @a @b"``, then ``"<prompt> @a @b"``, and
    finally truncates the prompt so the mentions still fit in
    ``_MAX_PROMPT_CHARS``. The prompt is returned unchanged when auto-tagging is
    off or there are no tags.
    """
    if not (auto_tag_prompt and tags):
        return prompt_text

    tag_mentions = " ".join(f"@{tag}" for tag in tags)
    for candidate in (f"{prompt_text} using references: {tag_mentions}",
                      f"{prompt_text} {tag_mentions}"):
        if len(candidate) <= _MAX_PROMPT_CHARS:
            return candidate

    available_chars = _MAX_PROMPT_CHARS - len(tag_mentions) - 1
    return f"{prompt_text[:available_chars]} {tag_mentions}"


def generate_image_with_references(
    prompt_text: str,
    reference_image_paths: List[str],
    ratio: str = "1920:1080",
    model: str = "gen4_image",
    seed: Optional[int] = None,
    api_key: Optional[str] = None,
    auto_tag_prompt: bool = True
) -> str:
    """
    Start a RunwayML text-to-image task that uses local reference images.

    Each image is sent inline as a base64 data URI together with a tag (see
    ``_derive_reference_tag``); the prompt can mention images as ``@tag``.

    Args:
        prompt_text: Description of the image to generate (max 1000 characters).
        reference_image_paths: Up to 3 local image file paths used as references.
        ratio: Output image resolution (default: "1920:1080").
        model: Model to use (default: "gen4_image").
        seed: Optional seed for reproducible results.
        api_key: Optional API key (uses the RUNWAYML_API_SECRET env var if not provided).
        auto_tag_prompt: Whether to append ``@tag`` mentions to the prompt (default: True).

    Returns:
        Task ID for the generation request.

    Raises:
        ValueError: More than 3 reference images, or prompt longer than 1000 characters.
        FileNotFoundError: A reference image path does not exist.
    """
    # Validate all cheap, local preconditions before reading files or creating
    # the API client (which needs credentials).
    if len(reference_image_paths) > _MAX_REFERENCE_IMAGES:
        raise ValueError(f"Maximum {_MAX_REFERENCE_IMAGES} reference images allowed")
    if len(prompt_text) > _MAX_PROMPT_CHARS:
        raise ValueError(f"Prompt text must be {_MAX_PROMPT_CHARS} characters or less")
    for image_path in reference_image_paths:
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")

    reference_images = []
    tags: List[str] = []
    used_standard_tags = set()
    for image_path in reference_image_paths:
        tag = _derive_reference_tag(image_path, used_standard_tags, tags)
        tags.append(tag)
        reference_images.append({"uri": encode_image_to_data_uri(image_path), "tag": tag})

    final_prompt = _compose_prompt(prompt_text, tags, auto_tag_prompt)

    print(f"Using tags: {tags}")
    if auto_tag_prompt:
        print(f"Auto-tagged prompt: {final_prompt}")
    else:
        if tags:
            print(f"Manual tagging mode - use @{', @'.join(tags)} in your prompt")
        print(f"Original prompt: {final_prompt}")

    create_params = {
        "model": model,
        "prompt_text": final_prompt,
        "ratio": ratio,
        "reference_images": reference_images,
    }
    if seed is not None:
        create_params["seed"] = seed

    client = RunwayML(api_key=api_key or os.environ.get("RUNWAYML_API_SECRET"))
    task = client.text_to_image.create(**create_params)
    return task.id
def check_task_status(task_id: str, api_key: Optional[str] = None):
    """
    Fetch the current state of a RunwayML generation task.

    Args:
        task_id: ID returned by generate_image_with_references.
        api_key: Optional API key; falls back to the RUNWAYML_API_SECRET
            environment variable when not given.

    Returns:
        The task object returned by ``RunwayML.tasks.retrieve``; callers in this
        file read its ``status`` and ``output`` attributes.
    """
    resolved_key = api_key or os.environ.get("RUNWAYML_API_SECRET")
    runway = RunwayML(api_key=resolved_key)
    task = runway.tasks.retrieve(id=task_id)
    return task
def generate_and_wait_for_result(
    prompt_text: str,
    reference_image_paths: List[str],
    ratio: str = "1920:1080",
    model: str = "gen4_image",
    seed: Optional[int] = None,
    api_key: Optional[str] = None,
    filename: Optional[str] = None,
    batch_folder: Optional[str] = None,
    max_retries: int = 8,
    wait_interval: int = 15,
    auto_tag_prompt: bool = True
) -> Tuple[str, Optional[str]]:
    """
    Start an image generation, poll until it finishes, and save the result.

    Args:
        prompt_text: Description of the image to generate.
        reference_image_paths: List of local image file paths to use as references.
        ratio: Output image resolution.
        model: Model to use.
        seed: Optional seed for reproducible results.
        api_key: Optional API key.
        filename: Optional filename for the saved image.
        batch_folder: Optional output sub-folder (see save_generated_image).
        max_retries: Maximum number of status checks (default: 8).
        wait_interval: Seconds to wait before each check (default: 15).
        auto_tag_prompt: Whether to automatically append tags to the prompt.

    Returns:
        ``(task_id, saved_image_path)``. ``saved_image_path`` is None when the
        task failed or was cancelled, succeeded without output, or polling ran
        out of attempts (including repeated status/download errors).

    Raises:
        Whatever generate_image_with_references raises while submitting the task.
    """
    # Status values as documented for the Runway task API; any other value is
    # reported and polling continues.
    in_progress_states = {"PENDING", "THROTTLED", "RUNNING"}
    failed_states = {"FAILED", "CANCELLED"}

    task_id = generate_image_with_references(
        prompt_text=prompt_text,
        reference_image_paths=reference_image_paths,
        ratio=ratio,
        model=model,
        seed=seed,
        api_key=api_key,
        auto_tag_prompt=auto_tag_prompt
    )
    print(f"Image generation started. Task ID: {task_id}")
    print(f"Checking status every {wait_interval} seconds (max {max_retries} attempts)...")

    for attempt in range(1, max_retries + 1):
        print(f"Attempt {attempt}/{max_retries} - Waiting {wait_interval} seconds...")
        time.sleep(wait_interval)

        try:
            status = check_task_status(task_id, api_key)
        except Exception as e:
            # Treat API/network errors as transient: try again next attempt.
            print(f"Error checking status: {e}")
            continue

        state = status.status
        print(f"Status: {state}")

        if state == "SUCCEEDED":
            output = getattr(status, 'output', None)
            if not output:
                print("Task succeeded but no output found")
                return task_id, None
            image_url = output[0]
            print(f"Generation completed! Image URL: {image_url}")
            try:
                saved_path = save_generated_image(image_url, filename, batch_folder)
            except Exception as e:
                # Retry the download on the next attempt instead of aborting.
                print(f"Error saving image: {e}")
                continue
            print(f"Image saved to: {saved_path}")
            return task_id, saved_path

        if state in failed_states:
            print("Task failed" if state == "FAILED" else "Task cancelled")
            # Print a failure reason if the task object carries one.
            failure = getattr(status, 'failure', None)
            if failure:
                print(f"Failure reason: {failure}")
            return task_id, None

        if state in in_progress_states:
            print("Task still in progress...")
        else:
            print(f"Unrecognized status '{state}', continuing to poll...")

    print(f"Timeout after {max_retries} attempts. Task may still be processing.")
    print(f"You can manually check status later using task ID: {task_id}")
    return task_id, None
def main():
    """Run two demo generations against the same three reference images.

    The first demo writes ``@character``/``@scene``/``@style`` into the prompt
    by hand (auto-tagging off); the second lets the script append the tags.
    Each demo's exceptions are caught and printed so the second still runs.
    """
    shared_refs = [
        "assets/characters/japanese_guy.jpg",
        "assets/scenes/f1-fields.jpg",
        "assets/styles/f1-cockpit.jpg",
    ]
    # (heading, message label, prompt, output filename, auto-tag flag)
    demos = (
        (
            "=== Manual Tagging Mode ===",
            "Manual tagging",
            "@character in a @scene with @style composition, cinematic lighting, high detail",
            "f1_driver_manual_tags.jpg",
            False,
        ),
        (
            "=== Auto Tagging Mode Example ===",
            "Auto tagging",
            "A Japanese F1 driver in a cockpit style setting on a racing field, cinematic lighting, high detail",
            "f1_driver_auto_tags.jpg",
            True,
        ),
    )

    print("\n=== Testing RunwayML with Reference Images ===")
    for index, (heading, label, prompt, out_name, auto_tag) in enumerate(demos):
        if index > 0:
            print("\n" + "=" * 50)
        print(heading)
        try:
            task_id, saved_path = generate_and_wait_for_result(
                prompt_text=prompt,
                reference_image_paths=shared_refs,
                ratio="1920:1080",
                filename=out_name,
                auto_tag_prompt=auto_tag,
            )
            if saved_path:
                print(f"{label} success! Image saved to: {saved_path}")
            else:
                print(f"{label} incomplete. Task ID: {task_id}")
        except Exception as e:
            print(f"{label} error: {e}")


# Run the demos only when executed as a script, not on import.
if __name__ == "__main__":
    main()