sidmaz666 commited on
Commit
49a8215
·
verified ·
1 Parent(s): 42ec191

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -141
app.py CHANGED
@@ -16,15 +16,13 @@ from fastapi.responses import FileResponse
16
  from fastapi.middleware.cors import CORSMiddleware
17
  from pydantic import BaseModel, Field
18
 
19
- # Diffusers imports
20
  from diffusers import (
21
- StableDiffusionPipeline,
22
  StableDiffusionImg2ImgPipeline,
23
  StableDiffusionControlNetPipeline,
24
  ControlNetModel,
25
  LCMScheduler,
26
  )
27
- from optimum.intel import OVStableDiffusionPipeline
28
  from transformers import CLIPTokenizer
29
 
30
  # Configure logging
@@ -32,7 +30,7 @@ logging.basicConfig(level=logging.INFO)
32
  logger = logging.getLogger(__name__)
33
 
34
  # --- Configuration ---
35
- MODEL_PATH = os.environ.get("OV_MODEL_PATH", "/app/models/LCM-dreamshaper-v7-openvino")
36
  CONTROLNET_ID = "lllyasviel/sd-controlnet-canny"
37
  OUTPUT_DIR = Path("/tmp/outputs")
38
  LORA_CACHE_DIR = Path("/app/models/loras")
@@ -42,56 +40,47 @@ LORA_CACHE_DIR.mkdir(parents=True, exist_ok=True)
42
 
43
  # --- Pydantic Models for API Documentation ---
44
  class GenerationResponse(BaseModel):
45
- """Response after successful image generation."""
46
  status: Literal["success"] = "success"
47
  message: str = "Image generated successfully"
48
- image_base64: Optional[str] = Field(None, description="Base64 encoded image (optional)")
49
- image_url: Optional[str] = Field(None, description="Relative URL to download the image")
50
- seed: int = Field(..., description="The seed used for generation (for reproducibility)")
51
- parameters: dict = Field(..., description="All parameters used for this generation")
52
 
53
  class ErrorResponse(BaseModel):
54
- """Standard error response."""
55
  status: Literal["error"] = "error"
56
  message: str
57
  detail: Optional[str] = None
58
 
59
  class LoRAInfo(BaseModel):
60
- """Information about a LoRA."""
61
- id: str = Field(..., description="Hugging Face repository ID")
62
- name: str = Field(..., description="Display name")
63
- description: str = Field(..., description="What the LoRA does")
64
- suggested_strength: float = Field(..., description="Recommended strength")
65
- keywords: List[str] = Field(..., description="Suggested prompt keywords")
66
 
67
  # --- FastAPI App Setup ---
68
  app = FastAPI(
69
  title="LCM Dreamshaper v7 Image Generation API",
70
  description="""
71
- ## Fast, CPU-optimized image generation using LCM Dreamshaper v7 (OpenVINO)
72
 
73
  This API provides:
74
- - **`/generate`** - Text-to-image generation (OpenVINO optimized, ~20-30s)
75
- - **`/img2img`** - Image-to-image transformation (OpenVINO optimized)
76
- - **`/controlnet`** - Generate with structural guidance using ControlNet (PyTorch fallback)
77
  - **`/loras`** - List available style LoRAs
78
 
79
  ### Model Information
80
- - **Base Model**: `rupeshs/LCM-dreamshaper-v7-openvino` (1B parameters, 4-step LCM)
81
- - **Inference Engine**: OpenVINO for basic generation; PyTorch for LoRAs/ControlNet
82
- - **Average Generation Time**: 20-40 seconds on CPU (16GB RAM) for txt2img/img2img
83
- - **Recommended Steps**: 4 (optimized for LCM)
84
 
85
  ### Usage Notes
86
- - For LoRA requests, the system falls back to PyTorch (slightly slower but functional).
87
  - All image dimensions must be multiples of 8.
88
  - Seed is returned with every response; reuse it to reproduce the same image.
89
  """,
90
  version="1.0.0",
91
- contact={
92
- "name": "Your Name",
93
- "url": "https://huggingface.co/your-space",
94
- },
95
  )
96
 
97
  app.add_middleware(
@@ -102,13 +91,11 @@ app.add_middleware(
102
  )
103
 
104
  # --- Global Variables for Models ---
105
- ov_pipeline = None # OpenVINO pipeline (fast, no LoRA)
106
- torch_txt2img = None # PyTorch txt2img pipeline (for LoRA)
107
- torch_img2img = None # PyTorch img2img pipeline (for LoRA)
108
- controlnet_pipeline = None # PyTorch ControlNet pipeline
109
  tokenizer = None
110
 
111
- # --- Available LoRAs (Pre-defined) ---
112
  AVAILABLE_LORAS = [
113
  {
114
  "id": "prithiviraj1710/pixel-art",
@@ -163,7 +150,6 @@ AVAILABLE_LORAS = [
163
 
164
  # --- Helper Functions ---
165
  def download_lora_sync(lora_id: str) -> Path:
166
- """Download a LoRA from Hugging Face if not already cached (synchronous)."""
167
  lora_path = LORA_CACHE_DIR / f"{lora_id.replace('/', '_')}.safetensors"
168
  if lora_path.exists():
169
  return lora_path
@@ -176,7 +162,6 @@ def download_lora_sync(lora_id: str) -> Path:
176
  filename="pytorch_lora_weights.safetensors",
177
  cache_dir=str(LORA_CACHE_DIR)
178
  )
179
- # Create a symlink to our expected path for easy future access
180
  if not lora_path.exists():
181
  os.symlink(downloaded_path, lora_path)
182
  return lora_path
@@ -185,7 +170,6 @@ def download_lora_sync(lora_id: str) -> Path:
185
  raise HTTPException(status_code=400, detail=f"LoRA {lora_id} not found or invalid")
186
 
187
  def apply_loras_to_pipe(pipe, lora_ids: str, lora_scales: Optional[str] = None) -> list:
188
- """Apply LoRAs to a PyTorch pipeline and return list of applied LoRAs."""
189
  lora_list = []
190
  if not lora_ids:
191
  return lora_list
@@ -199,7 +183,6 @@ def apply_loras_to_pipe(pipe, lora_ids: str, lora_scales: Optional[str] = None)
199
  if len(lora_ids_list) != len(scales_list):
200
  raise HTTPException(status_code=400, detail="Number of LoRA IDs must match number of scales")
201
 
202
- # Download and load each LoRA synchronously
203
  for lora_id, scale in zip(lora_ids_list, scales_list):
204
  lora_path = download_lora_sync(lora_id)
205
  pipe.load_lora_weights(str(lora_path))
@@ -208,74 +191,58 @@ def apply_loras_to_pipe(pipe, lora_ids: str, lora_scales: Optional[str] = None)
208
 
209
  return lora_list
210
 
211
- def load_ov_pipeline():
212
- """Load the OpenVINO-optimized LCM pipeline from local path."""
213
- global ov_pipeline, tokenizer
214
- if ov_pipeline is None:
215
- logger.info(f"Loading OpenVINO pipeline from {MODEL_PATH}...")
216
- ov_pipeline = OVStableDiffusionPipeline.from_pretrained(
217
- MODEL_PATH,
218
- ov_config={"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1", "INFERENCE_NUM_THREADS": "4"},
219
- compile=False
220
- )
221
- ov_pipeline.reshape(batch_size=1, height=512, width=512, num_images_per_prompt=1)
222
- ov_pipeline.compile()
223
- tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
224
- logger.info("OpenVINO pipeline loaded and compiled.")
225
- return ov_pipeline
226
-
227
- def load_torch_pipelines():
228
- """Load PyTorch pipelines for LoRA support."""
229
- global torch_txt2img, torch_img2img
230
- if torch_txt2img is None:
231
- logger.info("Loading PyTorch pipelines (for LoRA support)...")
232
- # Use the original Dreamshaper v7 model
233
- model_id = "Lykon/dreamshaper-7"
234
- pipe = StableDiffusionPipeline.from_pretrained(
235
- model_id,
236
  torch_dtype=torch.float32,
237
  safety_checker=None
238
  )
239
- pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
240
- pipe.to("cpu")
241
- torch_txt2img = pipe
242
-
243
- # Create img2img pipeline from same components
244
- torch_img2img = StableDiffusionImg2ImgPipeline(
245
- vae=pipe.vae,
246
- text_encoder=pipe.text_encoder,
247
- tokenizer=pipe.tokenizer,
248
- unet=pipe.unet,
249
- scheduler=pipe.scheduler,
 
 
 
 
 
250
  safety_checker=None,
251
  feature_extractor=None,
252
  )
253
- torch_img2img.to("cpu")
254
- logger.info("PyTorch pipelines loaded.")
255
- return torch_txt2img, torch_img2img
256
 
257
  def load_controlnet_pipeline():
258
- """Load the ControlNet pipeline (PyTorch)."""
259
- global controlnet_pipeline
260
- if controlnet_pipeline is None:
261
  logger.info("Loading ControlNet pipeline...")
262
  controlnet = ControlNetModel.from_pretrained(
263
  CONTROLNET_ID,
264
  torch_dtype=torch.float32
265
  )
266
- controlnet_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
267
- "Lykon/dreamshaper-7",
268
  controlnet=controlnet,
269
  torch_dtype=torch.float32,
270
  safety_checker=None
271
  )
272
- controlnet_pipeline.scheduler = LCMScheduler.from_config(controlnet_pipeline.scheduler.config)
273
- controlnet_pipeline.to("cpu")
274
  logger.info("ControlNet pipeline loaded.")
275
- return controlnet_pipeline
276
 
277
  def apply_canny_edge(image: Image.Image, low_threshold: int = 100, high_threshold: int = 200) -> Image.Image:
278
- """Apply Canny edge detection to an image."""
279
  image_np = np.array(image)
280
  image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
281
  edges = cv2.Canny(image_np, low_threshold, high_threshold)
@@ -283,7 +250,6 @@ def apply_canny_edge(image: Image.Image, low_threshold: int = 100, high_threshol
283
  return Image.fromarray(edges)
284
 
285
  async def save_upload_file(upload_file: UploadFile) -> Path:
286
- """Save an uploaded file to a temporary location."""
287
  temp_dir = Path("/tmp/uploads")
288
  temp_dir.mkdir(exist_ok=True)
289
  file_path = temp_dir / f"{uuid.uuid4()}_{upload_file.filename}"
@@ -293,13 +259,11 @@ async def save_upload_file(upload_file: UploadFile) -> Path:
293
  return file_path
294
 
295
  def image_to_base64(image: Image.Image, format: str = "PNG") -> str:
296
- """Convert a PIL Image to a base64 string."""
297
  buffered = BytesIO()
298
  image.save(buffered, format=format)
299
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
300
 
301
  def cleanup_temp_files(*paths: Path):
302
- """Remove temporary files."""
303
  for path in paths:
304
  try:
305
  if path.exists():
@@ -354,26 +318,11 @@ async def text_to_image(
354
  generator = torch.Generator(device="cpu").manual_seed(seed)
355
  lora_list = []
356
 
357
- # Choose pipeline based on LoRA presence
358
  if lora_ids:
359
- # Use PyTorch pipeline with LoRA support
360
- pipe_txt2img, _ = load_torch_pipelines()
361
- lora_list = apply_loras_to_pipe(pipe_txt2img, lora_ids, lora_scales)
362
- try:
363
- image = pipe_txt2img(
364
- prompt=prompt,
365
- negative_prompt=negative_prompt,
366
- num_inference_steps=steps,
367
- guidance_scale=guidance_scale,
368
- generator=generator,
369
- height=height,
370
- width=width
371
- ).images[0]
372
- finally:
373
- pipe_txt2img.unfuse_lora()
374
- else:
375
- # Use fast OpenVINO pipeline
376
- pipe = load_ov_pipeline()
377
  image = pipe(
378
  prompt=prompt,
379
  negative_prompt=negative_prompt,
@@ -383,6 +332,9 @@ async def text_to_image(
383
  height=height,
384
  width=width
385
  ).images[0]
 
 
 
386
 
387
  output_filename = f"txt2img_{uuid.uuid4()}.png"
388
  output_path = OUTPUT_DIR / output_filename
@@ -452,33 +404,24 @@ async def image_to_image(
452
  generator = torch.Generator(device="cpu").manual_seed(seed)
453
  lora_list = []
454
 
 
 
 
 
455
  try:
456
- if lora_ids:
457
- _, pipe_img2img = load_torch_pipelines()
458
- lora_list = apply_loras_to_pipe(pipe_img2img, lora_ids, lora_scales)
459
- output_image = pipe_img2img(
460
- prompt=prompt,
461
- image=init_image,
462
- strength=strength,
463
- negative_prompt=negative_prompt,
464
- num_inference_steps=steps,
465
- guidance_scale=guidance_scale,
466
- generator=generator
467
- ).images[0]
468
- pipe_img2img.unfuse_lora()
469
- else:
470
- pipe = load_ov_pipeline()
471
- output_image = pipe(
472
- prompt=prompt,
473
- image=init_image,
474
- strength=strength,
475
- negative_prompt=negative_prompt,
476
- num_inference_steps=steps,
477
- guidance_scale=guidance_scale,
478
- generator=generator
479
- ).images[0]
480
  finally:
481
  cleanup_temp_files(input_path)
 
 
482
 
483
  output_filename = f"img2img_{uuid.uuid4()}.png"
484
  output_path = OUTPUT_DIR / output_filename
@@ -617,20 +560,25 @@ async def get_image(filename: str):
617
  summary="Health check endpoint"
618
  )
619
  async def health_check():
620
- model_index_exists = (Path(MODEL_PATH) / "model_index.json").exists()
 
 
 
 
 
 
 
621
  return {
622
- "status": "healthy" if model_index_exists else "degraded",
623
- "base_model_loaded": model_index_exists,
624
  "available_loras": len(AVAILABLE_LORAS)
625
  }
626
 
627
  @app.on_event("startup")
628
  async def startup_event():
629
- """Pre-load OpenVINO model to reduce first-request latency."""
630
- logger.info("Starting up, pre-loading OpenVINO model...")
631
  try:
632
- load_ov_pipeline()
633
- logger.info("OpenVINO pipeline ready.")
634
  except Exception as e:
635
- logger.error(f"Failed to load OpenVINO pipeline: {e}")
636
- # PyTorch pipelines are loaded on-demand to save memory
 
16
  from fastapi.middleware.cors import CORSMiddleware
17
  from pydantic import BaseModel, Field
18
 
 
19
  from diffusers import (
20
+ DiffusionPipeline,
21
  StableDiffusionImg2ImgPipeline,
22
  StableDiffusionControlNetPipeline,
23
  ControlNetModel,
24
  LCMScheduler,
25
  )
 
26
  from transformers import CLIPTokenizer
27
 
28
  # Configure logging
 
30
  logger = logging.getLogger(__name__)
31
 
32
  # --- Configuration ---
33
+ MODEL_ID = "SimianLuo/LCM_Dreamshaper_v7"
34
  CONTROLNET_ID = "lllyasviel/sd-controlnet-canny"
35
  OUTPUT_DIR = Path("/tmp/outputs")
36
  LORA_CACHE_DIR = Path("/app/models/loras")
 
40
 
41
  # --- Pydantic Models for API Documentation ---
42
  class GenerationResponse(BaseModel):
 
43
  status: Literal["success"] = "success"
44
  message: str = "Image generated successfully"
45
+ image_base64: Optional[str] = Field(None)
46
+ image_url: Optional[str] = Field(None)
47
+ seed: int = Field(...)
48
+ parameters: dict = Field(...)
49
 
50
  class ErrorResponse(BaseModel):
 
51
  status: Literal["error"] = "error"
52
  message: str
53
  detail: Optional[str] = None
54
 
55
  class LoRAInfo(BaseModel):
56
+ id: str
57
+ name: str
58
+ description: str
59
+ suggested_strength: float
60
+ keywords: List[str]
 
61
 
62
  # --- FastAPI App Setup ---
63
  app = FastAPI(
64
  title="LCM Dreamshaper v7 Image Generation API",
65
  description="""
66
+ ## Fast, CPU-optimized image generation using LCM Dreamshaper v7
67
 
68
  This API provides:
69
+ - **`/generate`** - Text-to-image generation (~20-30s on CPU)
70
+ - **`/img2img`** - Image-to-image transformation
71
+ - **`/controlnet`** - Generate with structural guidance using ControlNet
72
  - **`/loras`** - List available style LoRAs
73
 
74
  ### Model Information
75
+ - **Base Model**: `SimianLuo/LCM_Dreamshaper_v7` (1B parameters, 4-step LCM)
76
+ - **Inference Engine**: PyTorch with LCM Scheduler
77
+ - **Average Generation Time**: 20-40 seconds on CPU (16GB RAM)
 
78
 
79
  ### Usage Notes
 
80
  - All image dimensions must be multiples of 8.
81
  - Seed is returned with every response; reuse it to reproduce the same image.
82
  """,
83
  version="1.0.0",
 
 
 
 
84
  )
85
 
86
  app.add_middleware(
 
91
  )
92
 
93
  # --- Global Variables for Models ---
94
+ txt2img_pipe = None
95
+ img2img_pipe = None
96
+ controlnet_pipe = None
 
97
  tokenizer = None
98
 
 
99
  AVAILABLE_LORAS = [
100
  {
101
  "id": "prithiviraj1710/pixel-art",
 
150
 
151
  # --- Helper Functions ---
152
  def download_lora_sync(lora_id: str) -> Path:
 
153
  lora_path = LORA_CACHE_DIR / f"{lora_id.replace('/', '_')}.safetensors"
154
  if lora_path.exists():
155
  return lora_path
 
162
  filename="pytorch_lora_weights.safetensors",
163
  cache_dir=str(LORA_CACHE_DIR)
164
  )
 
165
  if not lora_path.exists():
166
  os.symlink(downloaded_path, lora_path)
167
  return lora_path
 
170
  raise HTTPException(status_code=400, detail=f"LoRA {lora_id} not found or invalid")
171
 
172
  def apply_loras_to_pipe(pipe, lora_ids: str, lora_scales: Optional[str] = None) -> list:
 
173
  lora_list = []
174
  if not lora_ids:
175
  return lora_list
 
183
  if len(lora_ids_list) != len(scales_list):
184
  raise HTTPException(status_code=400, detail="Number of LoRA IDs must match number of scales")
185
 
 
186
  for lora_id, scale in zip(lora_ids_list, scales_list):
187
  lora_path = download_lora_sync(lora_id)
188
  pipe.load_lora_weights(str(lora_path))
 
191
 
192
  return lora_list
193
 
194
+ def load_txt2img_pipeline():
195
+ global txt2img_pipe, tokenizer
196
+ if txt2img_pipe is None:
197
+ logger.info(f"Loading text-to-image pipeline from {MODEL_ID}...")
198
+ txt2img_pipe = DiffusionPipeline.from_pretrained(
199
+ MODEL_ID,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  torch_dtype=torch.float32,
201
  safety_checker=None
202
  )
203
+ txt2img_pipe.to("cpu")
204
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
205
+ logger.info("Text-to-image pipeline loaded.")
206
+ return txt2img_pipe
207
+
208
+ def load_img2img_pipeline():
209
+ global img2img_pipe
210
+ if img2img_pipe is None:
211
+ logger.info("Loading image-to-image pipeline...")
212
+ txt2img = load_txt2img_pipeline()
213
+ img2img_pipe = StableDiffusionImg2ImgPipeline(
214
+ vae=txt2img.vae,
215
+ text_encoder=txt2img.text_encoder,
216
+ tokenizer=txt2img.tokenizer,
217
+ unet=txt2img.unet,
218
+ scheduler=txt2img.scheduler,
219
  safety_checker=None,
220
  feature_extractor=None,
221
  )
222
+ img2img_pipe.to("cpu")
223
+ logger.info("Image-to-image pipeline loaded.")
224
+ return img2img_pipe
225
 
226
  def load_controlnet_pipeline():
227
+ global controlnet_pipe
228
+ if controlnet_pipe is None:
 
229
  logger.info("Loading ControlNet pipeline...")
230
  controlnet = ControlNetModel.from_pretrained(
231
  CONTROLNET_ID,
232
  torch_dtype=torch.float32
233
  )
234
+ controlnet_pipe = StableDiffusionControlNetPipeline.from_pretrained(
235
+ MODEL_ID,
236
  controlnet=controlnet,
237
  torch_dtype=torch.float32,
238
  safety_checker=None
239
  )
240
+ controlnet_pipe.scheduler = LCMScheduler.from_config(controlnet_pipe.scheduler.config)
241
+ controlnet_pipe.to("cpu")
242
  logger.info("ControlNet pipeline loaded.")
243
+ return controlnet_pipe
244
 
245
  def apply_canny_edge(image: Image.Image, low_threshold: int = 100, high_threshold: int = 200) -> Image.Image:
 
246
  image_np = np.array(image)
247
  image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY)
248
  edges = cv2.Canny(image_np, low_threshold, high_threshold)
 
250
  return Image.fromarray(edges)
251
 
252
  async def save_upload_file(upload_file: UploadFile) -> Path:
 
253
  temp_dir = Path("/tmp/uploads")
254
  temp_dir.mkdir(exist_ok=True)
255
  file_path = temp_dir / f"{uuid.uuid4()}_{upload_file.filename}"
 
259
  return file_path
260
 
261
  def image_to_base64(image: Image.Image, format: str = "PNG") -> str:
 
262
  buffered = BytesIO()
263
  image.save(buffered, format=format)
264
  return base64.b64encode(buffered.getvalue()).decode("utf-8")
265
 
266
  def cleanup_temp_files(*paths: Path):
 
267
  for path in paths:
268
  try:
269
  if path.exists():
 
318
  generator = torch.Generator(device="cpu").manual_seed(seed)
319
  lora_list = []
320
 
321
+ pipe = load_txt2img_pipeline()
322
  if lora_ids:
323
+ lora_list = apply_loras_to_pipe(pipe, lora_ids, lora_scales)
324
+
325
+ try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
  image = pipe(
327
  prompt=prompt,
328
  negative_prompt=negative_prompt,
 
332
  height=height,
333
  width=width
334
  ).images[0]
335
+ finally:
336
+ if lora_ids:
337
+ pipe.unfuse_lora()
338
 
339
  output_filename = f"txt2img_{uuid.uuid4()}.png"
340
  output_path = OUTPUT_DIR / output_filename
 
404
  generator = torch.Generator(device="cpu").manual_seed(seed)
405
  lora_list = []
406
 
407
+ pipe = load_img2img_pipeline()
408
+ if lora_ids:
409
+ lora_list = apply_loras_to_pipe(pipe, lora_ids, lora_scales)
410
+
411
  try:
412
+ output_image = pipe(
413
+ prompt=prompt,
414
+ image=init_image,
415
+ strength=strength,
416
+ negative_prompt=negative_prompt,
417
+ num_inference_steps=steps,
418
+ guidance_scale=guidance_scale,
419
+ generator=generator
420
+ ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  finally:
422
  cleanup_temp_files(input_path)
423
+ if lora_ids:
424
+ pipe.unfuse_lora()
425
 
426
  output_filename = f"img2img_{uuid.uuid4()}.png"
427
  output_path = OUTPUT_DIR / output_filename
 
560
  summary="Health check endpoint"
561
  )
562
  async def health_check():
563
+ try:
564
+ _ = load_txt2img_pipeline()
565
+ status = "healthy"
566
+ model_loaded = True
567
+ except:
568
+ status = "degraded"
569
+ model_loaded = False
570
+
571
  return {
572
+ "status": status,
573
+ "base_model_loaded": model_loaded,
574
  "available_loras": len(AVAILABLE_LORAS)
575
  }
576
 
577
  @app.on_event("startup")
578
  async def startup_event():
579
+ logger.info("Starting up, pre-loading text-to-image model...")
 
580
  try:
581
+ load_txt2img_pipeline()
582
+ logger.info("Text-to-image model is ready.")
583
  except Exception as e:
584
+ logger.error(f"Failed to pre-load model: {e}")