Xernive's picture
Fix Hunyuan3D error handling + enhanced logging
0e805d4
"""FLUX.1-dev 2D image generation."""
# CRITICAL: Import spaces BEFORE torch/CUDA packages
import spaces
import torch
from pathlib import Path
from diffusers import DiffusionPipeline
from core.config import FLUX_MODELS, QualityPreset
from utils.memory import MemoryManager
class FluxGenerator:
    """Generates 2D images using FLUX.1-dev.

    A fresh pipeline is loaded for every ``generate`` call and handed to
    ``MemoryManager.cleanup_model`` afterwards; nothing is cached on the
    instance.
    """

    def __init__(self):
        """Create the generator with its own ``MemoryManager``."""
        self.memory_manager = MemoryManager()

    def _load_model(self, model_id: str) -> DiffusionPipeline:
        """Load a FLUX pipeline onto the GPU (no caching to prevent OOM).

        Args:
            model_id: Identifier passed to ``DiffusionPipeline.from_pretrained``.

        Returns:
            The pipeline, moved to ``"cuda"`` in bfloat16 with attention and
            VAE slicing enabled.
        """
        print(f"[FLUX] Loading model: {model_id}")
        pipe = DiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            use_safetensors=True,
            low_cpu_mem_usage=True,
        )

        # Load to GPU (L4 has 24GB VRAM)
        pipe = pipe.to("cuda", dtype=torch.bfloat16)

        # Enable memory optimizations
        pipe.enable_attention_slicing()
        pipe.enable_vae_slicing()

        # xformers is optional: any failure to enable it is deliberately
        # treated as "not available" and loading continues without it.
        try:
            pipe.enable_xformers_memory_efficient_attention()
            print("[FLUX] xformers enabled")
        except Exception:
            print("[FLUX] xformers not available")

        return pipe

    def _enhance_prompt_for_3d(self, prompt: str) -> str:
        """Append fixed 3D-reference keywords to ``prompt``.

        Args:
            prompt: The user's text prompt.

        Returns:
            ``prompt`` followed by the comma-separated keywords, truncated to
            at most 500 characters. The cut is by character count, so a long
            prompt can lose some or all of the appended keywords.
        """
        enhancements = [
            "high detailed 3D model reference",
            "complete object visible",
            "white background",
            "professional quality render",
            "single centered object",
            "game asset style",
            "perfect for 3D reconstruction",
            "clear silhouette",
            "front facing view",
            "studio lighting",
            "clean edges",
            "PBR ready",
        ]
        enhanced = f"{prompt}, {', '.join(enhancements)}"
        return enhanced[:500]  # Limit length

    # NOTE(review): presumably requests a GPU slot for up to 35 seconds per
    # call on Hugging Face Spaces -- confirm against the `spaces` package docs.
    @spaces.GPU(duration=35)
    def generate(
        self,
        prompt: str,
        preset: QualityPreset,
        output_dir: Path
    ) -> Path:
        """Generate a 2D image from a text prompt and save it as a PNG.

        Args:
            prompt: Text prompt; it is passed through
                ``_enhance_prompt_for_3d`` before inference.
            preset: Quality preset supplying ``name``, ``flux_steps`` and
                ``flux_guidance``.
            output_dir: Directory for the image; created (with parents) if it
                does not exist.

        Returns:
            Path of the saved file, ``output_dir / "flux_<unix-seconds>.png"``.
            The timestamp has one-second resolution, so two calls finishing
            within the same second write to the same path.

        Raises:
            Exception: Any error from loading, inference, saving or cleanup is
                logged with a ``[FLUX] Error`` line and re-raised.
        """
        # Kept function-local, as in the original, so this block needs no
        # change to the module's import section.
        import time

        try:
            print(f"[FLUX] Generating image: {preset.name} quality")

            # Load model
            pipe = self._load_model(FLUX_MODELS["dev"])

            try:
                # Enhance prompt
                enhanced_prompt = self._enhance_prompt_for_3d(prompt)

                # Generate image (fixed 1440x960 output)
                image = pipe(
                    prompt=enhanced_prompt,
                    height=960,
                    width=1440,
                    num_inference_steps=preset.flux_steps,
                    guidance_scale=preset.flux_guidance,
                ).images[0]

                # Save image
                output_dir.mkdir(exist_ok=True, parents=True)
                output_path = output_dir / f"flux_{int(time.time())}.png"
                image.save(output_path)
                print(f"[FLUX] Image saved: {output_path}")

                return output_path
            finally:
                # Always hand the pipeline back for cleanup, including when
                # inference or saving fails; previously the error path skipped
                # this and left the loaded model uncleaned.
                self.memory_manager.cleanup_model(pipe)
        except Exception as e:
            print(f"[FLUX] Error: {e}")
            raise