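"""Text-to-image model handler built on top of Hugging Face diffusers.

Wraps AutoPipelineForTextToImage with device-aware loading and a simple
infer() entry point for single-image generation.
"""
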
import torch
from diffusers import AutoPipelineForTextToImage

class ModelHandler:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_id = "kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers"
        self.pipeline = None
        self.load_model()

    def load_model(self):
        """
        Loads the model pipeline. Uses float16 for GPU to save memory.
        """
        try:
            print(f"Loading model: {self.model_id} on {self.device}...")
            
            dtype = torch.float16 if self.device == "cuda" else torch.float32
            
            # AutoPipeline handles the architecture detection automatically
            self.pipeline = AutoPipelineForTextToImage.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
                use_safetensors=True
            )
            
            if self.device == "cuda":
                self.pipeline.to("cuda")
                # Optional: if VRAM is limited (e.g. < 8 GB), call
                # enable_model_cpu_offload() *instead of* .to("cuda") above,
                # since offloading manages device placement itself:
                # self.pipeline.enable_model_cpu_offload()
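                # A more aggressive alternative is enable_sequential_cpu_offload(),
                # which lowers peak VRAM further at a significant speed cost.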
            
            print("Model loaded successfully.")
            
        except Exception as e:
            print(f"Error loading model: {e}")
            # Re-raise so the caller can decide how to recover; a bare `raise`
            # preserves the original traceback.
            raise

    def infer(self, prompt, negative_prompt, width, height, num_inference_steps, guidance_scale, seed, progress_callback=None):
        """
        Runs inference on the loaded pipeline.
        """
        if self.pipeline is None:
            self.load_model()

        # Seed a generator on the target device so a given seed reproduces its image.
        generator = torch.Generator(device=self.device).manual_seed(int(seed))
        
        # Progress reporting. Note: the `callback`/`callback_steps` pair is the
        # legacy diffusers API; newer releases deprecate it in favor of
        # `callback_on_step_end`, which receives the pipeline as its first
        # argument and must return `callback_kwargs`. Match the installed version.
        def callback_dynamic(step, timestep, latents):
            if progress_callback:
                progress_callback((step, num_inference_steps))

        image = self.pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
            # callback=callback_dynamic, # Optional: enable for granular progress updates
            # callback_steps=1
        ).images[0]
        
        return image
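

# --- Usage sketch (illustrative, not part of the original module) ---
# Drives the handler end to end. The prompt, resolution, step count, and
# output filename below are assumed example values.
if __name__ == "__main__":
    handler = ModelHandler()
    result = handler.infer(
        prompt="A watercolor lighthouse at dawn, soft morning light",
        negative_prompt="blurry, low quality, watermark",
        width=768,
        height=768,
        num_inference_steps=30,
        guidance_scale=5.0,
        seed=42,
    )
    result.save("output.png")
    print("Saved output.png")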