sidmaz666 commited on
Commit
3b8610c
·
verified ·
1 Parent(s): 81e0ab5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -50
app.py CHANGED
@@ -32,7 +32,7 @@ logging.basicConfig(level=logging.INFO)
32
  logger = logging.getLogger(__name__)
33
 
34
  # --- Configuration ---
35
- MODEL_PATH = os.environ.get("OV_MODEL_PATH", "/app/models/base")
36
  CONTROLNET_ID = "lllyasviel/sd-controlnet-canny"
37
  OUTPUT_DIR = Path("/tmp/outputs")
38
  LORA_CACHE_DIR = Path("/app/models/loras")
@@ -162,28 +162,6 @@ AVAILABLE_LORAS = [
162
  ]
163
 
164
  # --- Helper Functions ---
165
- async def download_lora(lora_id: str) -> Path:
166
- """Download a LoRA from Hugging Face if not already cached."""
167
- lora_path = LORA_CACHE_DIR / f"{lora_id.replace('/', '_')}.safetensors"
168
- if lora_path.exists():
169
- return lora_path
170
-
171
- logger.info(f"Downloading LoRA: {lora_id}")
172
- from huggingface_hub import hf_hub_download
173
- try:
174
- downloaded_path = hf_hub_download(
175
- repo_id=lora_id,
176
- filename="pytorch_lora_weights.safetensors",
177
- cache_dir=str(LORA_CACHE_DIR)
178
- )
179
- # Create a symlink to our expected path
180
- if not lora_path.exists():
181
- os.symlink(downloaded_path, lora_path)
182
- return lora_path
183
- except Exception as e:
184
- logger.error(f"Failed to download LoRA {lora_id}: {e}")
185
- raise HTTPException(status_code=400, detail=f"LoRA {lora_id} not found or invalid")
186
-
187
  def download_lora_sync(lora_id: str) -> Path:
188
  """Download a LoRA from Hugging Face if not already cached (synchronous)."""
189
  lora_path = LORA_CACHE_DIR / f"{lora_id.replace('/', '_')}.safetensors"
@@ -206,6 +184,30 @@ def download_lora_sync(lora_id: str) -> Path:
206
  logger.error(f"Failed to download LoRA {lora_id}: {e}")
207
  raise HTTPException(status_code=400, detail=f"LoRA {lora_id} not found or invalid")
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  def load_ov_pipeline():
210
  """Load the OpenVINO-optimized LCM pipeline from local path."""
211
  global ov_pipeline, tokenizer
@@ -305,30 +307,6 @@ def cleanup_temp_files(*paths: Path):
305
  except Exception as e:
306
  logger.warning(f"Failed to delete {path}: {e}")
307
 
308
- def apply_loras_to_pipe(pipe, lora_ids: str, lora_scales: Optional[str] = None) -> list:
309
- """Apply LoRAs to a PyTorch pipeline and return list of applied LoRAs."""
310
- lora_list = []
311
- if not lora_ids:
312
- return lora_list
313
-
314
- lora_ids_list = [lid.strip() for lid in lora_ids.split(",")]
315
- if lora_scales:
316
- scales_list = [float(s.strip()) for s in lora_scales.split(",")]
317
- else:
318
- scales_list = [1.0] * len(lora_ids_list)
319
-
320
- if len(lora_ids_list) != len(scales_list):
321
- raise HTTPException(status_code=400, detail="Number of LoRA IDs must match number of scales")
322
-
323
- # Download and load each LoRA synchronously
324
- for lora_id, scale in zip(lora_ids_list, scales_list):
325
- lora_path = download_lora_sync(lora_id)
326
- pipe.load_lora_weights(str(lora_path))
327
- pipe.fuse_lora(lora_scale=scale)
328
- lora_list.append({"id": lora_id, "scale": scale})
329
-
330
- return lora_list
331
-
332
  # --- API Endpoints ---
333
  @app.get("/", include_in_schema=False)
334
  async def root():
@@ -639,10 +617,10 @@ async def get_image(filename: str):
639
  summary="Health check endpoint"
640
  )
641
  async def health_check():
642
- base_model_exists = (Path("/app/models/base") / "openvino_model.xml").exists()
643
  return {
644
- "status": "healthy" if base_model_exists else "degraded",
645
- "base_model_loaded": base_model_exists,
646
  "available_loras": len(AVAILABLE_LORAS)
647
  }
648
 
 
32
  logger = logging.getLogger(__name__)
33
 
34
  # --- Configuration ---
35
+ MODEL_PATH = os.environ.get("OV_MODEL_PATH", "/app/models/LCM-dreamshaper-v7-openvino")
36
  CONTROLNET_ID = "lllyasviel/sd-controlnet-canny"
37
  OUTPUT_DIR = Path("/tmp/outputs")
38
  LORA_CACHE_DIR = Path("/app/models/loras")
 
162
  ]
163
 
164
  # --- Helper Functions ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  def download_lora_sync(lora_id: str) -> Path:
166
  """Download a LoRA from Hugging Face if not already cached (synchronous)."""
167
  lora_path = LORA_CACHE_DIR / f"{lora_id.replace('/', '_')}.safetensors"
 
184
  logger.error(f"Failed to download LoRA {lora_id}: {e}")
185
  raise HTTPException(status_code=400, detail=f"LoRA {lora_id} not found or invalid")
186
 
187
+ def apply_loras_to_pipe(pipe, lora_ids: str, lora_scales: Optional[str] = None) -> list:
188
+ """Apply LoRAs to a PyTorch pipeline and return list of applied LoRAs."""
189
+ lora_list = []
190
+ if not lora_ids:
191
+ return lora_list
192
+
193
+ lora_ids_list = [lid.strip() for lid in lora_ids.split(",")]
194
+ if lora_scales:
195
+ scales_list = [float(s.strip()) for s in lora_scales.split(",")]
196
+ else:
197
+ scales_list = [1.0] * len(lora_ids_list)
198
+
199
+ if len(lora_ids_list) != len(scales_list):
200
+ raise HTTPException(status_code=400, detail="Number of LoRA IDs must match number of scales")
201
+
202
+ # Download and load each LoRA synchronously
203
+ for lora_id, scale in zip(lora_ids_list, scales_list):
204
+ lora_path = download_lora_sync(lora_id)
205
+ pipe.load_lora_weights(str(lora_path))
206
+ pipe.fuse_lora(lora_scale=scale)
207
+ lora_list.append({"id": lora_id, "scale": scale})
208
+
209
+ return lora_list
210
+
211
  def load_ov_pipeline():
212
  """Load the OpenVINO-optimized LCM pipeline from local path."""
213
  global ov_pipeline, tokenizer
 
307
  except Exception as e:
308
  logger.warning(f"Failed to delete {path}: {e}")
309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  # --- API Endpoints ---
311
  @app.get("/", include_in_schema=False)
312
  async def root():
 
617
  summary="Health check endpoint"
618
  )
619
  async def health_check():
620
+ model_index_exists = (Path(MODEL_PATH) / "model_index.json").exists()
621
  return {
622
+ "status": "healthy" if model_index_exists else "degraded",
623
+ "base_model_loaded": model_index_exists,
624
  "available_loras": len(AVAILABLE_LORAS)
625
  }
626