# runway_reference / generate_image.py
# herodevcode
# Updated generate_image.py
# a0199e2
import os
import base64
import time
import requests
from typing import List, Optional, Tuple
from runwayml import RunwayML
import mimetypes
from urllib.parse import urlparse
def encode_image_to_data_uri(image_path: str) -> str:
    """
    Build a base64 ``data:`` URI from an image file on disk.

    Args:
        image_path: Path of the local image file to embed

    Returns:
        A string of the form "data:<mime type>;base64,<payload>"

    Raises:
        ValueError: If the MIME type guessed from the file name is missing
            or is not an image type
    """
    # The type is guessed from the file name only; the content is not inspected.
    guessed_type = mimetypes.guess_type(image_path)[0]
    is_image = bool(guessed_type) and guessed_type.startswith('image/')
    if not is_image:
        raise ValueError(f"Unsupported image type for {image_path}")

    with open(image_path, 'rb') as handle:
        raw_bytes = handle.read()

    payload = base64.b64encode(raw_bytes).decode('utf-8')
    return f"data:{guessed_type};base64,{payload}"
def save_generated_image(
    image_url: str,
    filename: Optional[str] = None,
    batch_folder: Optional[str] = None,
    timeout: float = 60.0,
) -> str:
    """
    Download the generated image and save it under ``output/<batch_folder>/``.

    Args:
        image_url: URL of the generated image
        filename: Optional filename (auto-generated from the current Unix
            time if not provided; ".jpg" is appended when it has no extension)
        batch_folder: Optional batch folder name (auto-generated with a
            timestamp if not provided)
        timeout: Seconds to wait for the HTTP request before giving up

    Returns:
        Path to the saved image file

    Raises:
        requests.HTTPError: If the server answers with an error status
            (raised by ``response.raise_for_status()``)
    """
    if not batch_folder:
        batch_folder = f"batch_{time.strftime('%Y%m%d_%H%M%S')}"
    output_dir = os.path.join("output", batch_folder)

    if not filename:
        filename = f"generated_{int(time.time())}.jpg"
    elif not os.path.splitext(filename)[1]:
        filename += ".jpg"
    output_path = os.path.join(output_dir, filename)

    # Without a timeout requests can wait on a stalled connection forever.
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()

    # Create the folder only after the download succeeded, so a failed
    # request does not leave an empty batch directory behind.
    os.makedirs(output_dir, exist_ok=True)
    with open(output_path, 'wb') as f:
        f.write(response.content)

    return output_path
# Folder names that map a reference image to one of the standard tags.
_STANDARD_TAG_FOLDERS = {
    "characters": "character",
    "scenes": "scene",
    "styles": "style",
}
_MAX_TAG_LENGTH = 16
_MAX_PROMPT_LENGTH = 1000
_MAX_REFERENCE_IMAGES = 3


def _derive_reference_tag(image_path: str, taken_tags: set) -> str:
    """Return a tag for image_path that is not already in taken_tags."""
    # normpath converts "/" to the native separator, so paths written with
    # forward slashes are still split into folders on Windows.
    for part in os.path.normpath(image_path).split(os.sep):
        standard_tag = _STANDARD_TAG_FOLDERS.get(part)
        if standard_tag and standard_tag not in taken_tags:
            return standard_tag

    # Fallback: build the tag from the file name. The "ref_" prefix means the
    # tag always starts with a letter.
    stem = os.path.splitext(os.path.basename(image_path))[0]
    base = f"ref_{stem}".replace('-', '_').replace(' ', '_')[:_MAX_TAG_LENGTH]
    base = ''.join(c for c in base if c.isalnum() or c == '_')

    # Two files can produce the same fallback tag (same name, or names that
    # share their first characters); add a numeric suffix to keep tags unique.
    tag = base
    counter = 2
    while tag in taken_tags:
        suffix = f"_{counter}"
        tag = f"{base[:_MAX_TAG_LENGTH - len(suffix)]}{suffix}"
        counter += 1
    return tag


def _append_tag_mentions(prompt_text: str, tags: List[str]) -> str:
    """Append @tag mentions to prompt_text, staying within the prompt limit."""
    tag_mentions = " ".join(f"@{tag}" for tag in tags)
    candidates = (
        f"{prompt_text} using references: {tag_mentions}",
        f"{prompt_text} {tag_mentions}",
    )
    for candidate in candidates:
        if len(candidate) <= _MAX_PROMPT_LENGTH:
            return candidate
    # Even the short form is too long: trim the prompt, keep every mention.
    available_chars = _MAX_PROMPT_LENGTH - len(tag_mentions) - 1
    return f"{prompt_text[:available_chars]} {tag_mentions}"


def generate_image_with_references(
    prompt_text: str,
    reference_image_paths: List[str],
    ratio: str = "1920:1080",
    model: str = "gen4_image",
    seed: Optional[int] = None,
    api_key: Optional[str] = None,
    auto_tag_prompt: bool = True
) -> str:
    """
    Generate an image using RunwayML API with reference images.

    Each reference image gets a tag: "character", "scene" or "style" when the
    path contains a "characters", "scenes" or "styles" folder and that tag is
    still unused, otherwise a "ref_<file name>" tag capped at 16 characters.

    Args:
        prompt_text: Description of the image to generate (max 1000 characters)
        reference_image_paths: List of local image file paths to use as
            references (max 3)
        ratio: Output image resolution (default: "1920:1080")
        model: Model to use (default: "gen4_image")
        seed: Optional seed for reproducible results
        api_key: Optional API key (uses RUNWAYML_API_SECRET env var if not provided)
        auto_tag_prompt: Whether to automatically append tags to prompt (default: True)

    Returns:
        Task ID for the generation request

    Raises:
        ValueError: If more than 3 reference images are given or the prompt
            is longer than 1000 characters
        FileNotFoundError: If a reference image path does not exist
    """
    # Validate the arguments before building the API client, so bad input is
    # reported even when no API key is configured.
    if len(reference_image_paths) > _MAX_REFERENCE_IMAGES:
        raise ValueError("Maximum 3 reference images allowed")
    if len(prompt_text) > _MAX_PROMPT_LENGTH:
        raise ValueError("Prompt text must be 1000 characters or less")

    reference_images = []
    tags = []
    taken_tags = set()
    for image_path in reference_image_paths:
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")

        tag = _derive_reference_tag(image_path, taken_tags)
        taken_tags.add(tag)
        tags.append(tag)

        data_uri = encode_image_to_data_uri(image_path)
        reference_images.append({"uri": data_uri, "tag": tag})

    final_prompt = prompt_text
    if auto_tag_prompt and tags:
        final_prompt = _append_tag_mentions(prompt_text, tags)

    print(f"Using tags: {tags}")
    if auto_tag_prompt:
        print(f"Auto-tagged prompt: {final_prompt}")
    else:
        if tags:
            print(f"Manual tagging mode - use @{', @'.join(tags)} in your prompt")
        print(f"Original prompt: {final_prompt}")

    create_params = {
        "model": model,
        "prompt_text": final_prompt,
        "ratio": ratio,
        "reference_images": reference_images
    }
    # Only send a seed when the caller asked for one.
    if seed is not None:
        create_params["seed"] = seed

    client = RunwayML(api_key=api_key or os.environ.get("RUNWAYML_API_SECRET"))
    task = client.text_to_image.create(**create_params)
    return task.id
def check_task_status(task_id: str, api_key: Optional[str] = None):
    """
    Look up a generation task by its ID.

    Args:
        task_id: The task ID returned from generate_image_with_references
        api_key: Optional API key (uses RUNWAYML_API_SECRET env var if not provided)

    Returns:
        Whatever the RunwayML client's ``tasks.retrieve`` call returns for
        this task
    """
    # An explicit key wins; otherwise fall back to the environment variable.
    secret = api_key if api_key else os.environ.get("RUNWAYML_API_SECRET")
    runway = RunwayML(api_key=secret)
    task_details = runway.tasks.retrieve(id=task_id)
    return task_details
def generate_and_wait_for_result(
    prompt_text: str,
    reference_image_paths: List[str],
    ratio: str = "1920:1080",
    model: str = "gen4_image",
    seed: Optional[int] = None,
    api_key: Optional[str] = None,
    filename: Optional[str] = None,
    batch_folder: Optional[str] = None,
    max_retries: int = 8,
    wait_interval: int = 15,
    auto_tag_prompt: bool = True
) -> Tuple[str, Optional[str]]:
    """
    Generate an image and wait for completion with automatic retries.

    Args:
        prompt_text: Description of the image to generate
        reference_image_paths: List of local image file paths to use as references
        ratio: Output image resolution
        model: Model to use
        seed: Optional seed for reproducible results
        api_key: Optional API key
        filename: Optional filename for saved image
        batch_folder: Optional batch folder name for the saved image
        max_retries: Maximum number of status checks (default: 8)
        wait_interval: Seconds to wait between checks (default: 15)
        auto_tag_prompt: Whether to automatically append tags to prompt

    Returns:
        Tuple of (task_id, saved_image_path). saved_image_path is None when
        the task failed, was cancelled, produced no output, or did not finish
        within max_retries checks.
    """
    task_id = generate_image_with_references(
        prompt_text=prompt_text,
        reference_image_paths=reference_image_paths,
        ratio=ratio,
        model=model,
        seed=seed,
        api_key=api_key,
        auto_tag_prompt=auto_tag_prompt
    )

    print(f"Image generation started. Task ID: {task_id}")
    print(f"Checking status every {wait_interval} seconds (max {max_retries} attempts)...")

    for attempt in range(max_retries):
        print(f"Attempt {attempt + 1}/{max_retries} - Waiting {wait_interval} seconds...")
        time.sleep(wait_interval)

        # The try also covers the download: an error there is reported and the
        # next attempt re-checks the task and tries the download again.
        try:
            status = check_task_status(task_id, api_key)
            print(f"Status: {status.status}")

            if status.status == "SUCCEEDED":
                output = getattr(status, 'output', None)
                if not output:
                    print("Task succeeded but no output found")
                    return task_id, None
                image_url = output[0]
                print(f"Generation completed! Image URL: {image_url}")
                saved_path = save_generated_image(image_url, filename, batch_folder)
                print(f"Image saved to: {saved_path}")
                return task_id, saved_path

            # FAILED and CANCELLED are both terminal: polling further cannot
            # produce an image. (CANCELLED/THROTTLED are task states listed in
            # the Runway API reference.)
            if status.status in ("FAILED", "CANCELLED"):
                if status.status == "FAILED":
                    print("Task failed")
                else:
                    print("Task cancelled")
                failure_reason = getattr(status, 'failure', None)
                if failure_reason:
                    print(f"Failure reason: {failure_reason}")
                return task_id, None

            if status.status in ("PENDING", "THROTTLED", "RUNNING"):
                print("Task still in progress...")
        except Exception as e:
            print(f"Error checking status: {e}")
            if attempt == max_retries - 1:
                print("Max retries reached. Task may still be processing.")
                return task_id, None

    print(f"Timeout after {max_retries} attempts. Task may still be processing.")
    print(f"You can manually check status later using task ID: {task_id}")
    return task_id, None
def main():
    """
    Demo run: generate one image with manual tags, then one with auto tags.

    Both runs use the same three reference images. Any exception raised by a
    run is printed instead of propagated, so the second run still executes.
    """
    print("\n=== Testing RunwayML with Reference Images ===")

    sample_references = [
        "assets/characters/japanese_guy.jpg",
        "assets/scenes/f1-fields.jpg",
        "assets/styles/f1-cockpit.jpg"
    ]

    # Run 1: the prompt already contains the @tags, so auto-tagging is off.
    print("=== Manual Tagging Mode ===")
    manual_request = dict(
        prompt_text="@character in a @scene with @style composition, cinematic lighting, high detail",
        reference_image_paths=sample_references,
        ratio="1920:1080",
        filename="f1_driver_manual_tags.jpg",
        auto_tag_prompt=False,
    )
    try:
        job_id, image_path = generate_and_wait_for_result(**manual_request)
        print(
            f"Manual tagging success! Image saved to: {image_path}"
            if image_path
            else f"Manual tagging incomplete. Task ID: {job_id}"
        )
    except Exception as err:
        print(f"Manual tagging error: {err}")

    print("\n" + "=" * 50)

    # Run 2: plain prompt; the @tags are appended automatically.
    print("=== Auto Tagging Mode Example ===")
    auto_request = dict(
        prompt_text="A Japanese F1 driver in a cockpit style setting on a racing field, cinematic lighting, high detail",
        reference_image_paths=sample_references,
        ratio="1920:1080",
        filename="f1_driver_auto_tags.jpg",
        auto_tag_prompt=True,
    )
    try:
        job_id, image_path = generate_and_wait_for_result(**auto_request)
        print(
            f"Auto tagging success! Image saved to: {image_path}"
            if image_path
            else f"Auto tagging incomplete. Task ID: {job_id}"
        )
    except Exception as err:
        print(f"Auto tagging error: {err}")


# Run the demo only when the file is executed as a script.
if __name__ == "__main__":
    main()