Update app.py
Browse files
app.py
CHANGED
|
@@ -16,15 +16,13 @@ from fastapi.responses import FileResponse
|
|
| 16 |
from fastapi.middleware.cors import CORSMiddleware
|
| 17 |
from pydantic import BaseModel, Field
|
| 18 |
|
| 19 |
-
# Diffusers imports
|
| 20 |
from diffusers import (
|
| 21 |
-
|
| 22 |
StableDiffusionImg2ImgPipeline,
|
| 23 |
StableDiffusionControlNetPipeline,
|
| 24 |
ControlNetModel,
|
| 25 |
LCMScheduler,
|
| 26 |
)
|
| 27 |
-
from optimum.intel import OVStableDiffusionPipeline
|
| 28 |
from transformers import CLIPTokenizer
|
| 29 |
|
| 30 |
# Configure logging
|
|
@@ -32,7 +30,7 @@ logging.basicConfig(level=logging.INFO)
|
|
| 32 |
logger = logging.getLogger(__name__)
|
| 33 |
|
| 34 |
# --- Configuration ---
|
| 35 |
-
|
| 36 |
CONTROLNET_ID = "lllyasviel/sd-controlnet-canny"
|
| 37 |
OUTPUT_DIR = Path("/tmp/outputs")
|
| 38 |
LORA_CACHE_DIR = Path("/app/models/loras")
|
|
@@ -42,56 +40,47 @@ LORA_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
|
| 42 |
|
| 43 |
# --- Pydantic Models for API Documentation ---
|
| 44 |
class GenerationResponse(BaseModel):
|
| 45 |
-
"""Response after successful image generation."""
|
| 46 |
status: Literal["success"] = "success"
|
| 47 |
message: str = "Image generated successfully"
|
| 48 |
-
image_base64: Optional[str] = Field(None
|
| 49 |
-
image_url: Optional[str] = Field(None
|
| 50 |
-
seed: int = Field(...
|
| 51 |
-
parameters: dict = Field(...
|
| 52 |
|
| 53 |
class ErrorResponse(BaseModel):
|
| 54 |
-
"""Standard error response."""
|
| 55 |
status: Literal["error"] = "error"
|
| 56 |
message: str
|
| 57 |
detail: Optional[str] = None
|
| 58 |
|
| 59 |
class LoRAInfo(BaseModel):
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
keywords: List[str] = Field(..., description="Suggested prompt keywords")
|
| 66 |
|
| 67 |
# --- FastAPI App Setup ---
|
| 68 |
app = FastAPI(
|
| 69 |
title="LCM Dreamshaper v7 Image Generation API",
|
| 70 |
description="""
|
| 71 |
-
## Fast, CPU-optimized image generation using LCM Dreamshaper v7
|
| 72 |
|
| 73 |
This API provides:
|
| 74 |
-
- **`/generate`** - Text-to-image generation (
|
| 75 |
-
- **`/img2img`** - Image-to-image transformation
|
| 76 |
-
- **`/controlnet`** - Generate with structural guidance using ControlNet
|
| 77 |
- **`/loras`** - List available style LoRAs
|
| 78 |
|
| 79 |
### Model Information
|
| 80 |
-
- **Base Model**: `
|
| 81 |
-
- **Inference Engine**:
|
| 82 |
-
- **Average Generation Time**: 20-40 seconds on CPU (16GB RAM)
|
| 83 |
-
- **Recommended Steps**: 4 (optimized for LCM)
|
| 84 |
|
| 85 |
### Usage Notes
|
| 86 |
-
- For LoRA requests, the system falls back to PyTorch (slightly slower but functional).
|
| 87 |
- All image dimensions must be multiples of 8.
|
| 88 |
- Seed is returned with every response; reuse it to reproduce the same image.
|
| 89 |
""",
|
| 90 |
version="1.0.0",
|
| 91 |
-
contact={
|
| 92 |
-
"name": "Your Name",
|
| 93 |
-
"url": "https://huggingface.co/your-space",
|
| 94 |
-
},
|
| 95 |
)
|
| 96 |
|
| 97 |
app.add_middleware(
|
|
@@ -102,13 +91,11 @@ app.add_middleware(
|
|
| 102 |
)
|
| 103 |
|
| 104 |
# --- Global Variables for Models ---
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
controlnet_pipeline = None # PyTorch ControlNet pipeline
|
| 109 |
tokenizer = None
|
| 110 |
|
| 111 |
-
# --- Available LoRAs (Pre-defined) ---
|
| 112 |
AVAILABLE_LORAS = [
|
| 113 |
{
|
| 114 |
"id": "prithiviraj1710/pixel-art",
|
|
@@ -163,7 +150,6 @@ AVAILABLE_LORAS = [
|
|
| 163 |
|
| 164 |
# --- Helper Functions ---
|
| 165 |
def download_lora_sync(lora_id: str) -> Path:
|
| 166 |
-
"""Download a LoRA from Hugging Face if not already cached (synchronous)."""
|
| 167 |
lora_path = LORA_CACHE_DIR / f"{lora_id.replace('/', '_')}.safetensors"
|
| 168 |
if lora_path.exists():
|
| 169 |
return lora_path
|
|
@@ -176,7 +162,6 @@ def download_lora_sync(lora_id: str) -> Path:
|
|
| 176 |
filename="pytorch_lora_weights.safetensors",
|
| 177 |
cache_dir=str(LORA_CACHE_DIR)
|
| 178 |
)
|
| 179 |
-
# Create a symlink to our expected path for easy future access
|
| 180 |
if not lora_path.exists():
|
| 181 |
os.symlink(downloaded_path, lora_path)
|
| 182 |
return lora_path
|
|
@@ -185,7 +170,6 @@ def download_lora_sync(lora_id: str) -> Path:
|
|
| 185 |
raise HTTPException(status_code=400, detail=f"LoRA {lora_id} not found or invalid")
|
| 186 |
|
| 187 |
def apply_loras_to_pipe(pipe, lora_ids: str, lora_scales: Optional[str] = None) -> list:
|
| 188 |
-
"""Apply LoRAs to a PyTorch pipeline and return list of applied LoRAs."""
|
| 189 |
lora_list = []
|
| 190 |
if not lora_ids:
|
| 191 |
return lora_list
|
|
@@ -199,7 +183,6 @@ def apply_loras_to_pipe(pipe, lora_ids: str, lora_scales: Optional[str] = None)
|
|
| 199 |
if len(lora_ids_list) != len(scales_list):
|
| 200 |
raise HTTPException(status_code=400, detail="Number of LoRA IDs must match number of scales")
|
| 201 |
|
| 202 |
-
# Download and load each LoRA synchronously
|
| 203 |
for lora_id, scale in zip(lora_ids_list, scales_list):
|
| 204 |
lora_path = download_lora_sync(lora_id)
|
| 205 |
pipe.load_lora_weights(str(lora_path))
|
|
@@ -208,74 +191,58 @@ def apply_loras_to_pipe(pipe, lora_ids: str, lora_scales: Optional[str] = None)
|
|
| 208 |
|
| 209 |
return lora_list
|
| 210 |
|
| 211 |
-
def
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
MODEL_PATH,
|
| 218 |
-
ov_config={"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "INFERENCE_NUM_THREADS": "4"},
|
| 219 |
-
compile=False
|
| 220 |
-
)
|
| 221 |
-
ov_pipeline.reshape(batch_size=1, height=512, width=512, num_images_per_prompt=1)
|
| 222 |
-
ov_pipeline.compile()
|
| 223 |
-
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
| 224 |
-
logger.info("OpenVINO pipeline loaded and compiled.")
|
| 225 |
-
return ov_pipeline
|
| 226 |
-
|
| 227 |
-
def load_torch_pipelines():
|
| 228 |
-
"""Load PyTorch pipelines for LoRA support."""
|
| 229 |
-
global torch_txt2img, torch_img2img
|
| 230 |
-
if torch_txt2img is None:
|
| 231 |
-
logger.info("Loading PyTorch pipelines (for LoRA support)...")
|
| 232 |
-
# Use the original Dreamshaper v7 model
|
| 233 |
-
model_id = "Lykon/dreamshaper-7"
|
| 234 |
-
pipe = StableDiffusionPipeline.from_pretrained(
|
| 235 |
-
model_id,
|
| 236 |
torch_dtype=torch.float32,
|
| 237 |
safety_checker=None
|
| 238 |
)
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
safety_checker=None,
|
| 251 |
feature_extractor=None,
|
| 252 |
)
|
| 253 |
-
|
| 254 |
-
logger.info("
|
| 255 |
-
return
|
| 256 |
|
| 257 |
def load_controlnet_pipeline():
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
if controlnet_pipeline is None:
|
| 261 |
logger.info("Loading ControlNet pipeline...")
|
| 262 |
controlnet = ControlNetModel.from_pretrained(
|
| 263 |
CONTROLNET_ID,
|
| 264 |
torch_dtype=torch.float32
|
| 265 |
)
|
| 266 |
-
|
| 267 |
-
|
| 268 |
controlnet=controlnet,
|
| 269 |
torch_dtype=torch.float32,
|
| 270 |
safety_checker=None
|
| 271 |
)
|
| 272 |
-
|
| 273 |
-
|
| 274 |
logger.info("ControlNet pipeline loaded.")
|
| 275 |
-
return
|
| 276 |
|
| 277 |
def apply_canny_edge(image: Image.Image, low_threshold: int = 100, high_threshold: int = 200) -> Image.Image:
|
| 278 |
-
"""Apply Canny edge detection to an image."""
|
| 279 |
image_np = np.array(image)
|
| 280 |
image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
|
| 281 |
edges = cv2.Canny(image_np, low_threshold, high_threshold)
|
|
@@ -283,7 +250,6 @@ def apply_canny_edge(image: Image.Image, low_threshold: int = 100, high_threshol
|
|
| 283 |
return Image.fromarray(edges)
|
| 284 |
|
| 285 |
async def save_upload_file(upload_file: UploadFile) -> Path:
|
| 286 |
-
"""Save an uploaded file to a temporary location."""
|
| 287 |
temp_dir = Path("/tmp/uploads")
|
| 288 |
temp_dir.mkdir(exist_ok=True)
|
| 289 |
file_path = temp_dir / f"{uuid.uuid4()}_{upload_file.filename}"
|
|
@@ -293,13 +259,11 @@ async def save_upload_file(upload_file: UploadFile) -> Path:
|
|
| 293 |
return file_path
|
| 294 |
|
| 295 |
def image_to_base64(image: Image.Image, format: str = "PNG") -> str:
|
| 296 |
-
"""Convert a PIL Image to a base64 string."""
|
| 297 |
buffered = BytesIO()
|
| 298 |
image.save(buffered, format=format)
|
| 299 |
return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
| 300 |
|
| 301 |
def cleanup_temp_files(*paths: Path):
|
| 302 |
-
"""Remove temporary files."""
|
| 303 |
for path in paths:
|
| 304 |
try:
|
| 305 |
if path.exists():
|
|
@@ -354,26 +318,11 @@ async def text_to_image(
|
|
| 354 |
generator = torch.Generator(device="cpu").manual_seed(seed)
|
| 355 |
lora_list = []
|
| 356 |
|
| 357 |
-
|
| 358 |
if lora_ids:
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
try:
|
| 363 |
-
image = pipe_txt2img(
|
| 364 |
-
prompt=prompt,
|
| 365 |
-
negative_prompt=negative_prompt,
|
| 366 |
-
num_inference_steps=steps,
|
| 367 |
-
guidance_scale=guidance_scale,
|
| 368 |
-
generator=generator,
|
| 369 |
-
height=height,
|
| 370 |
-
width=width
|
| 371 |
-
).images[0]
|
| 372 |
-
finally:
|
| 373 |
-
pipe_txt2img.unfuse_lora()
|
| 374 |
-
else:
|
| 375 |
-
# Use fast OpenVINO pipeline
|
| 376 |
-
pipe = load_ov_pipeline()
|
| 377 |
image = pipe(
|
| 378 |
prompt=prompt,
|
| 379 |
negative_prompt=negative_prompt,
|
|
@@ -383,6 +332,9 @@ async def text_to_image(
|
|
| 383 |
height=height,
|
| 384 |
width=width
|
| 385 |
).images[0]
|
|
|
|
|
|
|
|
|
|
| 386 |
|
| 387 |
output_filename = f"txt2img_{uuid.uuid4()}.png"
|
| 388 |
output_path = OUTPUT_DIR / output_filename
|
|
@@ -452,33 +404,24 @@ async def image_to_image(
|
|
| 452 |
generator = torch.Generator(device="cpu").manual_seed(seed)
|
| 453 |
lora_list = []
|
| 454 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 455 |
try:
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
guidance_scale=guidance_scale,
|
| 466 |
-
generator=generator
|
| 467 |
-
).images[0]
|
| 468 |
-
pipe_img2img.unfuse_lora()
|
| 469 |
-
else:
|
| 470 |
-
pipe = load_ov_pipeline()
|
| 471 |
-
output_image = pipe(
|
| 472 |
-
prompt=prompt,
|
| 473 |
-
image=init_image,
|
| 474 |
-
strength=strength,
|
| 475 |
-
negative_prompt=negative_prompt,
|
| 476 |
-
num_inference_steps=steps,
|
| 477 |
-
guidance_scale=guidance_scale,
|
| 478 |
-
generator=generator
|
| 479 |
-
).images[0]
|
| 480 |
finally:
|
| 481 |
cleanup_temp_files(input_path)
|
|
|
|
|
|
|
| 482 |
|
| 483 |
output_filename = f"img2img_{uuid.uuid4()}.png"
|
| 484 |
output_path = OUTPUT_DIR / output_filename
|
|
@@ -617,20 +560,25 @@ async def get_image(filename: str):
|
|
| 617 |
summary="Health check endpoint"
|
| 618 |
)
|
| 619 |
async def health_check():
|
| 620 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
return {
|
| 622 |
-
"status":
|
| 623 |
-
"base_model_loaded":
|
| 624 |
"available_loras": len(AVAILABLE_LORAS)
|
| 625 |
}
|
| 626 |
|
| 627 |
@app.on_event("startup")
|
| 628 |
async def startup_event():
|
| 629 |
-
"
|
| 630 |
-
logger.info("Starting up, pre-loading OpenVINO model...")
|
| 631 |
try:
|
| 632 |
-
|
| 633 |
-
logger.info("
|
| 634 |
except Exception as e:
|
| 635 |
-
logger.error(f"Failed to load
|
| 636 |
-
# PyTorch pipelines are loaded on-demand to save memory
|
|
|
|
| 16 |
from fastapi.middleware.cors import CORSMiddleware
|
| 17 |
from pydantic import BaseModel, Field
|
| 18 |
|
|
|
|
| 19 |
from diffusers import (
|
| 20 |
+
DiffusionPipeline,
|
| 21 |
StableDiffusionImg2ImgPipeline,
|
| 22 |
StableDiffusionControlNetPipeline,
|
| 23 |
ControlNetModel,
|
| 24 |
LCMScheduler,
|
| 25 |
)
|
|
|
|
| 26 |
from transformers import CLIPTokenizer
|
| 27 |
|
| 28 |
# Configure logging
|
|
|
|
| 30 |
logger = logging.getLogger(__name__)
|
| 31 |
|
| 32 |
# --- Configuration ---
|
| 33 |
+
MODEL_ID = "SimianLuo/LCM_Dreamshaper_v7"
|
| 34 |
CONTROLNET_ID = "lllyasviel/sd-controlnet-canny"
|
| 35 |
OUTPUT_DIR = Path("/tmp/outputs")
|
| 36 |
LORA_CACHE_DIR = Path("/app/models/loras")
|
|
|
|
| 40 |
|
| 41 |
# --- Pydantic Models for API Documentation ---
|
| 42 |
class GenerationResponse(BaseModel):
    """Response body returned after an image has been generated successfully."""

    # Discriminator value; always "success" for this model.
    status: Literal["success"] = "success"
    message: str = "Image generated successfully"
    # Both image fields are optional: either may be omitted from a response.
    image_base64: Optional[str] = Field(
        None, description="Base64-encoded image data, if included in the response"
    )
    image_url: Optional[str] = Field(
        None, description="URL of the generated image, if provided"
    )
    seed: int = Field(
        ..., description="Seed used for this generation; reuse it to reproduce the same image"
    )
    parameters: dict = Field(
        ..., description="Generation parameters used for this request"
    )
|
| 49 |
|
| 50 |
class ErrorResponse(BaseModel):
    """Standard error response body."""

    # Discriminator value; always "error" for this model.
    status: Literal["error"] = "error"
    message: str = Field(..., description="Short summary of the error")
    detail: Optional[str] = Field(
        None, description="Additional detail about the error, if available"
    )
|
| 54 |
|
| 55 |
class LoRAInfo(BaseModel):
    """Describes one entry of the pre-defined LoRA catalogue."""

    # NOTE: `id` shadows the builtin, but it is the public API field name and
    # must stay as-is for response compatibility.
    id: str = Field(..., description="Identifier of the LoRA (e.g. 'prithiviraj1710/pixel-art')")
    name: str = Field(..., description="Display name of the LoRA")
    description: str = Field(..., description="Short description of the LoRA")
    suggested_strength: float = Field(..., description="Suggested strength (scale) for this LoRA")
    keywords: List[str] = Field(..., description="Suggested prompt keywords")
|
|
|
|
| 61 |
|
| 62 |
# --- FastAPI App Setup ---
|
| 63 |
app = FastAPI(
|
| 64 |
title="LCM Dreamshaper v7 Image Generation API",
|
| 65 |
description="""
|
| 66 |
+
## Fast, CPU-optimized image generation using LCM Dreamshaper v7
|
| 67 |
|
| 68 |
This API provides:
|
| 69 |
+
- **`/generate`** - Text-to-image generation (~20-30s on CPU)
|
| 70 |
+
- **`/img2img`** - Image-to-image transformation
|
| 71 |
+
- **`/controlnet`** - Generate with structural guidance using ControlNet
|
| 72 |
- **`/loras`** - List available style LoRAs
|
| 73 |
|
| 74 |
### Model Information
|
| 75 |
+
- **Base Model**: `SimianLuo/LCM_Dreamshaper_v7` (1B parameters, 4-step LCM)
|
| 76 |
+
- **Inference Engine**: PyTorch with LCM Scheduler
|
| 77 |
+
- **Average Generation Time**: 20-40 seconds on CPU (16GB RAM)
|
|
|
|
| 78 |
|
| 79 |
### Usage Notes
|
|
|
|
| 80 |
- All image dimensions must be multiples of 8.
|
| 81 |
- Seed is returned with every response; reuse it to reproduce the same image.
|
| 82 |
""",
|
| 83 |
version="1.0.0",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
)
|
| 85 |
|
| 86 |
app.add_middleware(
|
|
|
|
| 91 |
)
|
| 92 |
|
| 93 |
# --- Global Variables for Models ---
|
| 94 |
+
txt2img_pipe = None
|
| 95 |
+
img2img_pipe = None
|
| 96 |
+
controlnet_pipe = None
|
|
|
|
| 97 |
tokenizer = None
|
| 98 |
|
|
|
|
| 99 |
AVAILABLE_LORAS = [
|
| 100 |
{
|
| 101 |
"id": "prithiviraj1710/pixel-art",
|
|
|
|
| 150 |
|
| 151 |
# --- Helper Functions ---
|
| 152 |
def download_lora_sync(lora_id: str) -> Path:
|
|
|
|
| 153 |
lora_path = LORA_CACHE_DIR / f"{lora_id.replace('/', '_')}.safetensors"
|
| 154 |
if lora_path.exists():
|
| 155 |
return lora_path
|
|
|
|
| 162 |
filename="pytorch_lora_weights.safetensors",
|
| 163 |
cache_dir=str(LORA_CACHE_DIR)
|
| 164 |
)
|
|
|
|
| 165 |
if not lora_path.exists():
|
| 166 |
os.symlink(downloaded_path, lora_path)
|
| 167 |
return lora_path
|
|
|
|
| 170 |
raise HTTPException(status_code=400, detail=f"LoRA {lora_id} not found or invalid")
|
| 171 |
|
| 172 |
def apply_loras_to_pipe(pipe, lora_ids: str, lora_scales: Optional[str] = None) -> list:
|
|
|
|
| 173 |
lora_list = []
|
| 174 |
if not lora_ids:
|
| 175 |
return lora_list
|
|
|
|
| 183 |
if len(lora_ids_list) != len(scales_list):
|
| 184 |
raise HTTPException(status_code=400, detail="Number of LoRA IDs must match number of scales")
|
| 185 |
|
|
|
|
| 186 |
for lora_id, scale in zip(lora_ids_list, scales_list):
|
| 187 |
lora_path = download_lora_sync(lora_id)
|
| 188 |
pipe.load_lora_weights(str(lora_path))
|
|
|
|
| 191 |
|
| 192 |
return lora_list
|
| 193 |
|
| 194 |
+
def load_txt2img_pipeline():
    """Return the shared text-to-image pipeline, loading it on first use.

    The pipeline and the CLIP tokenizer are built in local variables and only
    published to the module globals once every step has succeeded. If any step
    raises, ``txt2img_pipe`` stays ``None`` and the next call retries the load
    instead of returning a pipeline while ``tokenizer`` is still unset.

    Returns:
        The cached pipeline loaded from ``MODEL_ID`` and moved to the CPU.
    """
    global txt2img_pipe, tokenizer
    if txt2img_pipe is not None:
        return txt2img_pipe

    logger.info(f"Loading text-to-image pipeline from {MODEL_ID}...")
    pipe = DiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        safety_checker=None,
    )
    pipe.to("cpu")
    clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # Publish only after everything loaded, so a failure leaves no partial state.
    txt2img_pipe = pipe
    tokenizer = clip_tokenizer
    logger.info("Text-to-image pipeline loaded.")
    return txt2img_pipe
|
| 207 |
+
|
| 208 |
+
def load_img2img_pipeline():
    """Return the shared image-to-image pipeline, building it on first use.

    The pipeline is assembled from the components of the text-to-image
    pipeline (VAE, text encoder, tokenizer, UNet and scheduler), so those
    objects are shared between the two pipelines rather than loaded again.
    The global is assigned only after construction has fully succeeded, so a
    failure leaves ``img2img_pipe`` as ``None`` and the next call retries.

    Returns:
        The cached ``StableDiffusionImg2ImgPipeline`` moved to the CPU.
    """
    global img2img_pipe
    if img2img_pipe is not None:
        return img2img_pipe

    logger.info("Loading image-to-image pipeline...")
    base = load_txt2img_pipeline()
    pipe = StableDiffusionImg2ImgPipeline(
        vae=base.vae,
        text_encoder=base.text_encoder,
        tokenizer=base.tokenizer,
        unet=base.unet,
        scheduler=base.scheduler,
        safety_checker=None,
        feature_extractor=None,
    )
    pipe.to("cpu")

    # Publish only after the pipeline is fully constructed.
    img2img_pipe = pipe
    logger.info("Image-to-image pipeline loaded.")
    return img2img_pipe
|
| 225 |
|
| 226 |
def load_controlnet_pipeline():
    """Return the shared ControlNet pipeline, loading it on first use.

    Loads the ControlNet weights from ``CONTROLNET_ID``, builds a
    ``StableDiffusionControlNetPipeline`` on top of ``MODEL_ID`` and replaces
    its scheduler with an ``LCMScheduler`` created from the existing scheduler
    config. The global is assigned only after the scheduler swap and device
    move have succeeded, so a failure part-way through cannot cache a pipeline
    that still carries the original scheduler.

    Returns:
        The cached ControlNet pipeline moved to the CPU.
    """
    global controlnet_pipe
    if controlnet_pipe is not None:
        return controlnet_pipe

    logger.info("Loading ControlNet pipeline...")
    controlnet = ControlNetModel.from_pretrained(
        CONTROLNET_ID,
        torch_dtype=torch.float32,
    )
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        MODEL_ID,
        controlnet=controlnet,
        torch_dtype=torch.float32,
        safety_checker=None,
    )
    pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    pipe.to("cpu")

    # Publish only after the pipeline is fully configured.
    controlnet_pipe = pipe
    logger.info("ControlNet pipeline loaded.")
    return controlnet_pipe
|
| 244 |
|
| 245 |
def apply_canny_edge(image: Image.Image, low_threshold: int = 100, high_threshold: int = 200) -> Image.Image:
|
|
|
|
| 246 |
image_np = np.array(image)
|
| 247 |
image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
|
| 248 |
edges = cv2.Canny(image_np, low_threshold, high_threshold)
|
|
|
|
| 250 |
return Image.fromarray(edges)
|
| 251 |
|
| 252 |
async def save_upload_file(upload_file: UploadFile) -> Path:
|
|
|
|
| 253 |
temp_dir = Path("/tmp/uploads")
|
| 254 |
temp_dir.mkdir(exist_ok=True)
|
| 255 |
file_path = temp_dir / f"{uuid.uuid4()}_{upload_file.filename}"
|
|
|
|
| 259 |
return file_path
|
| 260 |
|
| 261 |
def image_to_base64(image: Image.Image, format: str = "PNG") -> str:
    """Encode *image* in the given *format* and return the bytes as base64 text."""
    with BytesIO() as stream:
        image.save(stream, format=format)
        encoded = base64.b64encode(stream.getvalue())
    return encoded.decode("utf-8")
|
| 265 |
|
| 266 |
def cleanup_temp_files(*paths: Path):
|
|
|
|
| 267 |
for path in paths:
|
| 268 |
try:
|
| 269 |
if path.exists():
|
|
|
|
| 318 |
generator = torch.Generator(device="cpu").manual_seed(seed)
|
| 319 |
lora_list = []
|
| 320 |
|
| 321 |
+
pipe = load_txt2img_pipeline()
|
| 322 |
if lora_ids:
|
| 323 |
+
lora_list = apply_loras_to_pipe(pipe, lora_ids, lora_scales)
|
| 324 |
+
|
| 325 |
+
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
image = pipe(
|
| 327 |
prompt=prompt,
|
| 328 |
negative_prompt=negative_prompt,
|
|
|
|
| 332 |
height=height,
|
| 333 |
width=width
|
| 334 |
).images[0]
|
| 335 |
+
finally:
|
| 336 |
+
if lora_ids:
|
| 337 |
+
pipe.unfuse_lora()
|
| 338 |
|
| 339 |
output_filename = f"txt2img_{uuid.uuid4()}.png"
|
| 340 |
output_path = OUTPUT_DIR / output_filename
|
|
|
|
| 404 |
generator = torch.Generator(device="cpu").manual_seed(seed)
|
| 405 |
lora_list = []
|
| 406 |
|
| 407 |
+
pipe = load_img2img_pipeline()
|
| 408 |
+
if lora_ids:
|
| 409 |
+
lora_list = apply_loras_to_pipe(pipe, lora_ids, lora_scales)
|
| 410 |
+
|
| 411 |
try:
|
| 412 |
+
output_image = pipe(
|
| 413 |
+
prompt=prompt,
|
| 414 |
+
image=init_image,
|
| 415 |
+
strength=strength,
|
| 416 |
+
negative_prompt=negative_prompt,
|
| 417 |
+
num_inference_steps=steps,
|
| 418 |
+
guidance_scale=guidance_scale,
|
| 419 |
+
generator=generator
|
| 420 |
+
).images[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
finally:
|
| 422 |
cleanup_temp_files(input_path)
|
| 423 |
+
if lora_ids:
|
| 424 |
+
pipe.unfuse_lora()
|
| 425 |
|
| 426 |
output_filename = f"img2img_{uuid.uuid4()}.png"
|
| 427 |
output_path = OUTPUT_DIR / output_filename
|
|
|
|
| 560 |
summary="Health check endpoint"
|
| 561 |
)
|
| 562 |
async def health_check():
    """Report whether the base text-to-image pipeline can be obtained.

    Calls ``load_txt2img_pipeline()``; if it raises, the failure is logged and
    the service is reported as degraded instead of the request failing.

    Returns:
        A dict with ``status`` ("healthy" or "degraded"),
        ``base_model_loaded`` (bool) and ``available_loras`` (the number of
        entries in ``AVAILABLE_LORAS``).
    """
    try:
        load_txt2img_pipeline()
    except Exception:
        # Catch Exception rather than using a bare `except:` so that
        # cancellation and interpreter-exit signals are not swallowed, and
        # keep the traceback so a degraded status can be diagnosed.
        logger.exception("Health check failed to obtain the text-to-image pipeline")
        model_loaded = False
    else:
        model_loaded = True

    status = "healthy" if model_loaded else "degraded"
    return {
        "status": status,
        "base_model_loaded": model_loaded,
        "available_loras": len(AVAILABLE_LORAS),
    }
|
| 576 |
|
| 577 |
@app.on_event("startup")
async def startup_event():
    """Pre-load the text-to-image pipeline when the application starts.

    This is best effort: a load failure is logged with its traceback but not
    re-raised, so the application still starts.
    """
    logger.info("Starting up, pre-loading text-to-image model...")
    try:
        load_txt2img_pipeline()
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception(f"Failed to pre-load model: {e}")
    else:
        logger.info("Text-to-image model is ready.")
|
|
|