# ByteDream / infer.py — uploaded via huggingface_hub (revision a44493b, verified)
"""
Byte Dream - Command Line Inference Tool
Generate images from text prompts using the command line
"""
import argparse
from pathlib import Path
import torch
def _build_parser() -> argparse.ArgumentParser:
    """Build and return the command-line argument parser for the tool."""
    parser = argparse.ArgumentParser(
        description="Byte Dream - AI Image Generation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Basic usage
python infer.py --prompt "A beautiful sunset over mountains"
# With custom parameters
python infer.py --prompt "Cyberpunk city" --negative "blurry" --steps 75 --guidance 8.0
# Specify output and size
python infer.py --prompt "Fantasy landscape" --output fantasy.png --width 768 --height 768
# With seed for reproducibility
python infer.py --prompt "Dragon" --seed 42 --output dragon.png
"""
    )
    parser.add_argument(
        "--prompt", "-p",
        type=str,
        required=True,
        help="Text prompt describing the desired image"
    )
    parser.add_argument(
        "--negative", "-n",
        type=str,
        default="",
        help="Negative prompt - things to avoid in the image"
    )
    parser.add_argument(
        "--output", "-o",
        type=str,
        default="output.png",
        help="Output image filename (default: output.png)"
    )
    parser.add_argument(
        "--width", "-W",
        type=int,
        default=512,
        help="Image width in pixels (default: 512)"
    )
    parser.add_argument(
        "--height", "-H",
        type=int,
        default=512,
        help="Image height in pixels (default: 512)"
    )
    parser.add_argument(
        "--steps", "-s",
        type=int,
        default=50,
        help="Number of inference steps (default: 50)"
    )
    parser.add_argument(
        "--guidance", "-g",
        type=float,
        default=7.5,
        help="Guidance scale - how closely to follow prompt (default: 7.5)"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility (default: random)"
    )
    parser.add_argument(
        "--model", "-m",
        type=str,
        default=None,
        help="Path to model directory or Hugging Face repo ID (default: uses config)"
    )
    parser.add_argument(
        "--hf_repo",
        type=str,
        default=None,
        help="Hugging Face repository ID to load model from (e.g., username/repo)"
    )
    parser.add_argument(
        "--config", "-c",
        type=str,
        default="config.yaml",
        help="Path to config file (default: config.yaml)"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="Device to run on: cpu or cuda (default: cpu)"
    )
    return parser


def _load_generator(args):
    """Instantiate a ByteDreamGenerator from parsed CLI args.

    Loads from a Hugging Face repo when --hf_repo is given, otherwise from a
    local model path (which may be None, deferring to the config file).
    The import is deferred so that --help does not pay the startup cost.
    """
    from bytedream.generator import ByteDreamGenerator

    if args.hf_repo:
        print(f"Loading model from Hugging Face: {args.hf_repo}")
        return ByteDreamGenerator(
            hf_repo_id=args.hf_repo,
            config_path=args.config,
            device=args.device,
        )
    return ByteDreamGenerator(
        model_path=args.model,
        config_path=args.config,
        device=args.device,
    )


def main():
    """Parse CLI arguments, generate an image from the prompt, and save it.

    Exits via argparse's built-in error handling on invalid arguments.
    """
    args = _build_parser().parse_args()

    print("="*60)
    print("Byte Dream - AI Image Generator")
    print("="*60)

    generator = _load_generator(args)

    # Print model info (keys provided by ByteDreamGenerator.get_model_info —
    # assumed to include name/version/device/unet_parameters; TODO confirm)
    info = generator.get_model_info()
    print(f"\nModel: {info['name']} v{info['version']}")
    print(f"Device: {info['device']}")
    print(f"Parameters: {info['unet_parameters']}")
    print("="*60)

    # Generate image; an empty --negative string is normalized to None
    image = generator.generate(
        prompt=args.prompt,
        negative_prompt=args.negative if args.negative else None,
        width=args.width,
        height=args.height,
        num_inference_steps=args.steps,
        guidance_scale=args.guidance,
        seed=args.seed,
    )

    # Save image. Create parent directories first so a nested --output path
    # (e.g. results/run1/img.png) does not crash with FileNotFoundError.
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    image.save(output_path)
    print(f"\n✓ Image saved to: {output_path.absolute()}")
    print("="*60)


if __name__ == "__main__":
    main()