| """
|
| Hugging Face API Client for Byte Dream
|
| Use Byte Dream models directly from Hugging Face Hub
|
| """
|
|
|
| import torch
|
| import requests
|
| import base64
|
| from io import BytesIO
|
| from PIL import Image
|
| from typing import Optional, List, Union
|
| import time
|
|
|
|
|
class HuggingFaceAPI:
    """
    Client for the Hugging Face Inference API.

    Sends text-to-image requests for a Byte Dream model hosted on the Hub,
    so the model never has to be downloaded by the caller.

    The three class attributes below tune the network behaviour and may be
    overridden on a subclass or on an individual instance.
    """

    # Seconds to wait for the server on each request. ``requests`` has no
    # default timeout, so without this a stalled connection blocks forever.
    REQUEST_TIMEOUT = 300
    # How many times a 503 ("model is loading") reply is retried before the
    # error is surfaced to the caller.
    MAX_LOAD_RETRIES = 60
    # Seconds to sleep between two consecutive retries.
    RETRY_DELAY = 5

    def __init__(
        self,
        repo_id: str,
        token: Optional[str] = None,
        use_gpu: bool = False,
    ):
        """
        Initialize Hugging Face API client

        Args:
            repo_id: Repository ID (e.g., "username/ByteDream")
            token: Hugging Face API token (optional but recommended)
            use_gpu: Request GPU inference (if available)
        """
        self.repo_id = repo_id
        self.token = token
        # NOTE: the flag is only stored; nothing in this class currently
        # forwards it to the API.
        self.use_gpu = use_gpu

        self.inference_api_url = f"https://api-inference.huggingface.co/models/{repo_id}"

        # Anonymous requests are sent when no token is given.
        self.headers = {}
        if token:
            self.headers["Authorization"] = f"Bearer {token}"

        print(f"✓ Hugging Face API initialized for: {repo_id}")

    def query(
        self,
        prompt: str,
        negative_prompt: str = "",
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None,
    ) -> "Image.Image":
        """
        Query the model using Inference API

        A 503 reply is treated as "model still loading": the request is
        repeated after ``RETRY_DELAY`` seconds, at most ``MAX_LOAD_RETRIES``
        times. Any other error status is raised immediately.

        Args:
            prompt: Text prompt
            negative_prompt: Negative prompt
            width: Image width
            height: Image height
            num_inference_steps: Number of denoising steps
            guidance_scale: Guidance scale
            seed: Random seed; omitted from the request when None

        Returns:
            Generated PIL Image

        Raises:
            requests.HTTPError: The server answered with an error status,
                including a 503 that persisted through every retry.
            requests.Timeout: No answer arrived within ``REQUEST_TIMEOUT``.
        """
        parameters = {
            "negative_prompt": negative_prompt,
            "width": width,
            "height": height,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
        }
        if seed is not None:
            parameters["seed"] = seed
        payload = {"inputs": prompt, "parameters": parameters}

        # Bounded loop instead of unbounded recursion: a model that never
        # finishes loading ends in a regular HTTPError, not a RecursionError.
        retries_left = self.MAX_LOAD_RETRIES
        while True:
            response = requests.post(
                self.inference_api_url,
                headers=self.headers,
                json=payload,
                timeout=self.REQUEST_TIMEOUT,
            )
            if response.status_code != 503 or retries_left <= 0:
                break
            retries_left -= 1
            print("Model is loading on HF servers. Waiting...")
            time.sleep(self.RETRY_DELAY)

        response.raise_for_status()

        # The response body is the encoded image itself.
        return Image.open(BytesIO(response.content))

    def query_batch(
        self,
        prompts: List[str],
        negative_prompt: str = "",
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seeds: Optional[List[int]] = None,
    ) -> List["Image.Image"]:
        """
        Generate multiple images

        Prompts are processed one after another with ``query``; every image
        shares the same negative prompt, size and sampling settings.

        Args:
            prompts: List of prompts
            negative_prompt: Negative prompt
            width: Image width
            height: Image height
            num_inference_steps: Number of steps
            guidance_scale: Guidance scale
            seeds: List of seeds, matched to prompts by position. Prompts
                without a matching entry are generated without a seed.

        Returns:
            List of PIL Images, in the same order as ``prompts``
        """
        total = len(prompts)
        images = []

        for index, prompt in enumerate(prompts):
            # A seeds list shorter than prompts is allowed: extras get None.
            seed = seeds[index] if seeds and index < len(seeds) else None

            print(f"Generating image {index + 1}/{total}...")
            images.append(
                self.query(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=width,
                    height=height,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    seed=seed,
                )
            )

        return images
|
|
|
|
|
class ByteDreamHFClient:
    """
    Front door for running Byte Dream from a Hugging Face repository.

    Depending on ``use_api`` every call is delegated either to the hosted
    Inference API (through ``HuggingFaceAPI``) or to a ``ByteDreamGenerator``
    created inside this process.
    """

    def __init__(
        self,
        repo_id: str,
        token: Optional[str] = None,
        use_api: bool = False,
        device: str = "cpu",
    ):
        """
        Set up the backend selected by ``use_api``.

        Args:
            repo_id: Repository ID on Hugging Face
            token: HF API token
            use_api: When true, talk to the Inference API; otherwise build a
                local generator
            device: Device handed to the local generator
        """
        self.repo_id, self.token = repo_id, token
        self.use_api, self.device = use_api, device

        if not use_api:
            # Imported lazily: this line only runs when local inference
            # has been selected.
            from bytedream.generator import ByteDreamGenerator

            self.generator = ByteDreamGenerator(
                hf_repo_id=repo_id,
                config_path="config.yaml",
                device=device,
            )
            print("✓ Model loaded locally from Hugging Face")
            return

        self.api_client = HuggingFaceAPI(repo_id, token)
        print("✓ Using Hugging Face Inference API")

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seed: Optional[int] = None,
    ) -> Image.Image:
        """
        Produce a single image for ``prompt`` with the active backend.

        Args:
            prompt: Text description
            negative_prompt: Things to avoid
            width: Image width
            height: Image height
            num_inference_steps: Number of steps
            guidance_scale: Guidance scale
            seed: Random seed

        Returns:
            Generated PIL Image
        """
        request = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "width": width,
            "height": height,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "seed": seed,
        }

        if self.use_api:
            return self.api_client.query(**request)

        # The local generator is given None rather than an empty string.
        request["negative_prompt"] = negative_prompt or None
        return self.generator.generate(**request)

    def generate_batch(
        self,
        prompts: List[str],
        negative_prompt: str = "",
        width: int = 512,
        height: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        seeds: Optional[List[int]] = None,
    ) -> List[Image.Image]:
        """
        Produce one image per entry of ``prompts`` with the active backend.

        Args:
            prompts: List of text descriptions
            negative_prompt: Things to avoid
            width: Image width
            height: Image height
            num_inference_steps: Number of steps
            guidance_scale: Guidance scale
            seeds: List of random seeds

        Returns:
            List of PIL Images
        """
        request = {
            "prompts": prompts,
            "negative_prompt": negative_prompt,
            "width": width,
            "height": height,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "seeds": seeds,
        }

        if self.use_api:
            return self.api_client.query_batch(**request)

        # The local generator is given None rather than an empty string.
        request["negative_prompt"] = negative_prompt or None
        return self.generator.generate_batch(**request)
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: run the same repository once through the hosted Inference API
    # and once through a local generator. Each example is isolated in its own
    # try block so a failure in one does not prevent the other from running.
    rule = "=" * 60
    demo_repo = "Enzo8930302/ByteDream"

    # ---- Example 1: hosted inference -------------------------------------
    print(rule)
    print("Example 1: Using Hugging Face Inference API")
    print(rule)

    api_settings = {
        "prompt": "A beautiful sunset over mountains, digital art",
        "negative_prompt": "ugly, blurry, low quality",
        "width": 512,
        "height": 512,
        "num_inference_steps": 50,
        "guidance_scale": 7.5,
        "seed": 42,
    }

    try:
        client = ByteDreamHFClient(repo_id=demo_repo, use_api=True)
        image = client.generate(**api_settings)
        image.save("output_api.png")
        print("✓ Image saved to output_api.png")
    except Exception as e:
        print(f"Error: {e}")
        print("Make sure the model exists on Hugging Face")

    # ---- Example 2: local inference --------------------------------------
    print("\n" + rule)
    print("Example 2: Download and run locally on CPU")
    print(rule)

    local_settings = {
        "prompt": "A futuristic city at night, cyberpunk style",
        "width": 512,
        "height": 512,
        "num_inference_steps": 30,
    }

    try:
        client_local = ByteDreamHFClient(
            repo_id=demo_repo,
            use_api=False,
            device="cpu",
        )
        image_local = client_local.generate(**local_settings)
        image_local.save("output_local.png")
        print("✓ Image saved to output_local.png")
    except Exception as e:
        print(f"Error: {e}")
|
|
|