File size: 3,125 Bytes
ac2243f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 | import logging
import os
from dataclasses import dataclass, field
from typing import List
import torch
from pydantic import BaseModel
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
# Module-level logger; the SD3 pipeline loader reports which accelerator it is loading through it.
logger = logging.getLogger(__name__)
class TextToImageInput(BaseModel):
    """Request payload for a text-to-image generation call."""

    # NOTE(review): the field meanings below are inferred from their names and
    # are not validated here; confirm against the request handler.

    # Model identifier to generate with (presumably one of the preset SD3 /
    # SD3.5 repo ids).
    model: str
    # Text prompt describing the image to generate.
    prompt: str
    # Optional output size; presumably a "WIDTHxHEIGHT" string. TODO: confirm format.
    size: str | None = None
    # Optional count; presumably the number of images to generate.
    n: int | None = None
@dataclass
class PresetModels:
    """Catalogue of the model repo ids recognised for each Stable Diffusion family.

    Every instance receives its own freshly built lists, so mutating one
    instance's catalogue never leaks into another.
    """

    # Stable Diffusion 3 checkpoints.
    SD3: List[str] = field(
        default_factory=lambda: list(("stabilityai/stable-diffusion-3-medium",))
    )
    # Stable Diffusion 3.5 checkpoints.
    SD3_5: List[str] = field(
        default_factory=lambda: list(
            (
                "stabilityai/stable-diffusion-3.5-large",
                "stabilityai/stable-diffusion-3.5-large-turbo",
                "stabilityai/stable-diffusion-3.5-medium",
            )
        )
    )
class TextToImagePipelineSD3:
def __init__(self, model_path: str | None = None):
self.model_path = model_path or os.getenv("MODEL_PATH")
self.pipeline: StableDiffusion3Pipeline | None = None
self.device: str | None = None
def start(self):
if torch.cuda.is_available():
model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large"
logger.info("Loading CUDA")
self.device = "cuda"
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
model_path,
torch_dtype=torch.float16,
).to(device=self.device)
elif torch.backends.mps.is_available():
model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium"
logger.info("Loading MPS for Mac M Series")
self.device = "mps"
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
).to(device=self.device)
else:
raise Exception("No CUDA or MPS device available")
class ModelPipelineInitializer:
    """Map a model name and task type onto a concrete (not yet started) pipeline.

    Only text-to-image (``"t2im"``) is supported, and only for the model ids
    listed in ``PresetModels``.
    """

    def __init__(self, model: str = "", type_models: str = "t2im"):
        """Record the requested model and task type.

        Args:
            model: Model repo id. It must be one of the ``PresetModels``
                entries for :meth:`initialize_pipeline` to succeed.
            type_models: Task type. Only ``"t2im"`` is currently supported.
        """
        self.model = model
        self.type_models = type_models
        self.pipeline = None
        # Report the accelerator that is actually present. Previously
        # machines with neither CUDA nor MPS were reported as "mps".
        if torch.cuda.is_available():
            self.device = "cuda"
        elif torch.backends.mps.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"
        self.model_type = None

    def initialize_pipeline(self):
        """Build the pipeline object for ``self.model`` / ``self.type_models``.

        The returned pipeline has not loaded any weights yet; call its
        ``start()`` method to do so.

        Returns:
            The new pipeline, which is also stored on ``self.pipeline``.

        Raises:
            ValueError: if no model name was given, if ``type_models`` is not
                ``"t2im"``, or if the model is not a known SD3 / SD3.5 preset.
        """
        if not self.model:
            raise ValueError("Model name not provided")
        # Reject every unsupported task type up front. Previously only "t2v"
        # raised; any other value silently returned None.
        if self.type_models != "t2im":
            raise ValueError(f"Unsupported type_models: {self.type_models}")

        preset_models = PresetModels()
        # Recompute on every call so a changed self.model never reuses a
        # model_type left over from an earlier call.
        if self.model in preset_models.SD3:
            self.model_type = "SD3"
        elif self.model in preset_models.SD3_5:
            self.model_type = "SD3_5"
        else:
            self.model_type = None

        if self.model_type is None:
            raise ValueError(f"Model type {self.model_type} not supported for text-to-image")
        self.pipeline = TextToImagePipelineSD3(self.model)
        return self.pipeline
|