# picpro-backend/services/generator.py
# Provenance (from the Hugging Face file page): uploaded by ManiRafy2,
# commit message "Update services/generator.py", commit 7ef34c5 (verified).
import torch
from diffusers import DiffusionPipeline
from PIL import Image
import gc
import os
class ImageGenerator:
    """Wraps a diffusers text-to-image pipeline (Z-Image-Turbo) behind a simple API.

    The model is loaded eagerly at construction time; if loading fails,
    ``self.pipe`` is left as ``None`` and ``generate()`` retries the load
    on first use before raising.
    """

    def __init__(self):
        # Prefer GPU when available; the dtype choice in _load_model()
        # depends on this.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipe = None
        # Z-Image-Turbo model id on the Hugging Face Hub.
        self.model_id = "mrfakename/Z-Image-Turbo"
        # NOTE: loading is eager, not lazy — the original comment claimed
        # lazy loading, but _load_model() runs here at startup. A failed
        # load is not fatal: generate() retries before raising.
        self._load_model()

    def _load_model(self):
        """Load the diffusion pipeline; on failure leave ``self.pipe`` as None."""
        print(f"Loading Model: {self.model_id} on {self.device}...")
        try:
            # Z-Image recommends bfloat16 on GPU; fall back to float32 on CPU.
            dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
            # DiffusionPipeline auto-resolves to the model's registered
            # pipeline class (e.g. ZImagePipeline) when supported, or a
            # compatible fallback otherwise.
            self.pipe = DiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
                use_safetensors=True,
                token=os.environ.get("HF_TOKEN"),
            )
            # Offload submodules to CPU between forward passes to cap VRAM use.
            if self.device == "cuda":
                self.pipe.enable_model_cpu_offload()
            print("Z-Image-Turbo Model Loaded Successfully.")
        except Exception as e:
            # Best-effort startup: record the failure; generate() will retry.
            print(f"Error loading Z-Image model: {e}")
            self.pipe = None

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        width: int = 1024,
        height: int = 1024,
        steps: int = 9,  # Recommended steps for Z-Image-Turbo
        guidance_scale: float = 0.0,  # Recommended guidance for Z-Image-Turbo
        seed: int = -1,  # -1 means "no fixed seed" (non-reproducible output)
    ) -> "Image.Image":
        """Generate a single image for *prompt*.

        Returns the first image produced by the pipeline.
        Raises RuntimeError if the model cannot be loaded, and re-raises
        any error from the underlying pipeline (with its traceback intact).
        """
        if self.pipe is None:
            # Retry the load in case startup failed (e.g. transient network error).
            self._load_model()
        if self.pipe is None:
            raise RuntimeError("Failed to load model.")
        try:
            # A fixed seed gives reproducible output; seeding on CPU keeps
            # results consistent regardless of the execution device.
            generator = None
            if seed != -1:
                generator = torch.Generator(device="cpu").manual_seed(int(seed))
            print(f"Generating image for prompt: '{prompt}'...")
            image = self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_inference_steps=steps,
                guidance_scale=guidance_scale,
                generator=generator,
            ).images[0]
            return image
        except Exception as e:
            print(f"Generation error: {e}")
            # Bare raise preserves the original traceback ("raise e" would
            # re-anchor it here).
            raise
        finally:
            # Free GPU memory between requests.
            if self.device == "cuda":
                torch.cuda.empty_cache()
            gc.collect()