File size: 3,304 Bytes
0e805d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""FLUX.1-dev 2D image generation."""

# CRITICAL: Import spaces BEFORE torch/CUDA packages
import spaces

import time
from pathlib import Path

import torch
from diffusers import DiffusionPipeline

from core.config import FLUX_MODELS, QualityPreset
from utils.memory import MemoryManager


class FluxGenerator:
    """Generates 2D images using FLUX.1-dev for downstream 3D reconstruction."""

    def __init__(self):
        # Centralized GPU-memory bookkeeping; used to free the pipeline after each run.
        self.memory_manager = MemoryManager()

    def _load_model(self, model_id: str) -> DiffusionPipeline:
        """Load a FLUX pipeline onto the GPU.

        Deliberately NOT cached: keeping a pipeline resident between calls
        would exhaust VRAM (OOM), so each generation loads fresh and the
        caller is responsible for cleanup.

        Args:
            model_id: Hugging Face model identifier (e.g. the "dev" entry
                from FLUX_MODELS).

        Returns:
            The pipeline moved to CUDA in bfloat16 with memory
            optimizations enabled.
        """
        print(f"[FLUX] Loading model: {model_id}")

        pipe = DiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            use_safetensors=True,
            low_cpu_mem_usage=True
        )

        # Load to GPU (L4 has 24GB VRAM)
        pipe = pipe.to("cuda", dtype=torch.bfloat16)

        # Trade a little speed for a smaller peak-memory footprint.
        pipe.enable_attention_slicing()
        pipe.enable_vae_slicing()

        # xformers is an optional accelerator; absence is not an error.
        try:
            pipe.enable_xformers_memory_efficient_attention()
            print("[FLUX] xformers enabled")
        except Exception:
            print("[FLUX] xformers not available")

        return pipe

    def _enhance_prompt_for_3d(self, prompt: str) -> str:
        """Append 3D-reconstruction-friendly modifiers to a user prompt.

        The modifiers bias generation toward a single, centered, cleanly
        silhouetted object on a white background — the kind of image that
        image-to-3D pipelines handle best.

        Args:
            prompt: Raw user prompt.

        Returns:
            The enhanced prompt, truncated to 500 characters.
        """
        enhancements = [
            "high detailed 3D model reference",
            "complete object visible",
            "white background",
            "professional quality render",
            "single centered object",
            "game asset style",
            "perfect for 3D reconstruction",
            "clear silhouette",
            "front facing view",
            "studio lighting",
            "clean edges",
            "PBR ready",
        ]

        enhanced = f"{prompt}, {', '.join(enhancements)}"
        return enhanced[:500]  # Limit length

    @spaces.GPU(duration=35)
    def generate(
        self,
        prompt: str,
        preset: QualityPreset,
        output_dir: Path
    ) -> Path:
        """Generate a 2D image from a text prompt and save it as a PNG.

        Args:
            prompt: User text prompt; enhanced internally for 3D use.
            preset: Quality preset supplying flux_steps / flux_guidance.
            output_dir: Directory for the output PNG (created if missing).

        Returns:
            Path to the saved image file.

        Raises:
            Exception: Re-raised from the underlying pipeline after logging.
        """
        print(f"[FLUX] Generating image: {preset.name} quality")

        # Load model first; everything after must guarantee its cleanup.
        pipe = self._load_model(FLUX_MODELS["dev"])
        try:
            enhanced_prompt = self._enhance_prompt_for_3d(prompt)

            image = pipe(
                prompt=enhanced_prompt,
                height=960,
                width=1440,
                num_inference_steps=preset.flux_steps,
                guidance_scale=preset.flux_guidance
            ).images[0]

            output_dir.mkdir(exist_ok=True, parents=True)
            # Timestamped filename; collisions only possible within the
            # same second, which a single sequential Space call can't hit.
            output_path = output_dir / f"flux_{int(time.time())}.png"
            image.save(output_path)

            print(f"[FLUX] Image saved: {output_path}")
            return output_path

        except Exception as e:
            print(f"[FLUX] Error: {e}")
            raise
        finally:
            # BUGFIX: previously cleanup only ran on success, so a failed
            # generation leaked the pipeline in VRAM and OOM'd the next call.
            self.memory_manager.cleanup_model(pipe)