Update app.py
Browse files
app.py
CHANGED
|
@@ -32,7 +32,7 @@ logging.basicConfig(level=logging.INFO)
|
|
| 32 |
logger = logging.getLogger(__name__)
|
| 33 |
|
| 34 |
# --- Configuration ---
|
| 35 |
-
MODEL_PATH = os.environ.get("OV_MODEL_PATH", "/app/models/
|
| 36 |
CONTROLNET_ID = "lllyasviel/sd-controlnet-canny"
|
| 37 |
OUTPUT_DIR = Path("/tmp/outputs")
|
| 38 |
LORA_CACHE_DIR = Path("/app/models/loras")
|
|
@@ -162,28 +162,6 @@ AVAILABLE_LORAS = [
|
|
| 162 |
]
|
| 163 |
|
| 164 |
# --- Helper Functions ---
|
| 165 |
-
async def download_lora(lora_id: str) -> Path:
|
| 166 |
-
"""Download a LoRA from Hugging Face if not already cached."""
|
| 167 |
-
lora_path = LORA_CACHE_DIR / f"{lora_id.replace('/', '_')}.safetensors"
|
| 168 |
-
if lora_path.exists():
|
| 169 |
-
return lora_path
|
| 170 |
-
|
| 171 |
-
logger.info(f"Downloading LoRA: {lora_id}")
|
| 172 |
-
from huggingface_hub import hf_hub_download
|
| 173 |
-
try:
|
| 174 |
-
downloaded_path = hf_hub_download(
|
| 175 |
-
repo_id=lora_id,
|
| 176 |
-
filename="pytorch_lora_weights.safetensors",
|
| 177 |
-
cache_dir=str(LORA_CACHE_DIR)
|
| 178 |
-
)
|
| 179 |
-
# Create a symlink to our expected path
|
| 180 |
-
if not lora_path.exists():
|
| 181 |
-
os.symlink(downloaded_path, lora_path)
|
| 182 |
-
return lora_path
|
| 183 |
-
except Exception as e:
|
| 184 |
-
logger.error(f"Failed to download LoRA {lora_id}: {e}")
|
| 185 |
-
raise HTTPException(status_code=400, detail=f"LoRA {lora_id} not found or invalid")
|
| 186 |
-
|
| 187 |
def download_lora_sync(lora_id: str) -> Path:
|
| 188 |
"""Download a LoRA from Hugging Face if not already cached (synchronous)."""
|
| 189 |
lora_path = LORA_CACHE_DIR / f"{lora_id.replace('/', '_')}.safetensors"
|
|
@@ -206,6 +184,30 @@ def download_lora_sync(lora_id: str) -> Path:
|
|
| 206 |
logger.error(f"Failed to download LoRA {lora_id}: {e}")
|
| 207 |
raise HTTPException(status_code=400, detail=f"LoRA {lora_id} not found or invalid")
|
| 208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
def load_ov_pipeline():
|
| 210 |
"""Load the OpenVINO-optimized LCM pipeline from local path."""
|
| 211 |
global ov_pipeline, tokenizer
|
|
@@ -305,30 +307,6 @@ def cleanup_temp_files(*paths: Path):
|
|
| 305 |
except Exception as e:
|
| 306 |
logger.warning(f"Failed to delete {path}: {e}")
|
| 307 |
|
| 308 |
-
def apply_loras_to_pipe(pipe, lora_ids: str, lora_scales: Optional[str] = None) -> list:
|
| 309 |
-
"""Apply LoRAs to a PyTorch pipeline and return list of applied LoRAs."""
|
| 310 |
-
lora_list = []
|
| 311 |
-
if not lora_ids:
|
| 312 |
-
return lora_list
|
| 313 |
-
|
| 314 |
-
lora_ids_list = [lid.strip() for lid in lora_ids.split(",")]
|
| 315 |
-
if lora_scales:
|
| 316 |
-
scales_list = [float(s.strip()) for s in lora_scales.split(",")]
|
| 317 |
-
else:
|
| 318 |
-
scales_list = [1.0] * len(lora_ids_list)
|
| 319 |
-
|
| 320 |
-
if len(lora_ids_list) != len(scales_list):
|
| 321 |
-
raise HTTPException(status_code=400, detail="Number of LoRA IDs must match number of scales")
|
| 322 |
-
|
| 323 |
-
# Download and load each LoRA synchronously
|
| 324 |
-
for lora_id, scale in zip(lora_ids_list, scales_list):
|
| 325 |
-
lora_path = download_lora_sync(lora_id)
|
| 326 |
-
pipe.load_lora_weights(str(lora_path))
|
| 327 |
-
pipe.fuse_lora(lora_scale=scale)
|
| 328 |
-
lora_list.append({"id": lora_id, "scale": scale})
|
| 329 |
-
|
| 330 |
-
return lora_list
|
| 331 |
-
|
| 332 |
# --- API Endpoints ---
|
| 333 |
@app.get("/", include_in_schema=False)
|
| 334 |
async def root():
|
|
@@ -639,10 +617,10 @@ async def get_image(filename: str):
|
|
| 639 |
summary="Health check endpoint"
|
| 640 |
)
|
| 641 |
async def health_check():
|
| 642 |
-
|
| 643 |
return {
|
| 644 |
-
"status": "healthy" if
|
| 645 |
-
"base_model_loaded":
|
| 646 |
"available_loras": len(AVAILABLE_LORAS)
|
| 647 |
}
|
| 648 |
|
|
|
|
| 32 |
logger = logging.getLogger(__name__)
|
| 33 |
|
| 34 |
# --- Configuration ---
|
| 35 |
+
MODEL_PATH = os.environ.get("OV_MODEL_PATH", "/app/models/LCM-dreamshaper-v7-openvino")
|
| 36 |
CONTROLNET_ID = "lllyasviel/sd-controlnet-canny"
|
| 37 |
OUTPUT_DIR = Path("/tmp/outputs")
|
| 38 |
LORA_CACHE_DIR = Path("/app/models/loras")
|
|
|
|
| 162 |
]
|
| 163 |
|
| 164 |
# --- Helper Functions ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
def download_lora_sync(lora_id: str) -> Path:
|
| 166 |
"""Download a LoRA from Hugging Face if not already cached (synchronous)."""
|
| 167 |
lora_path = LORA_CACHE_DIR / f"{lora_id.replace('/', '_')}.safetensors"
|
|
|
|
| 184 |
logger.error(f"Failed to download LoRA {lora_id}: {e}")
|
| 185 |
raise HTTPException(status_code=400, detail=f"LoRA {lora_id} not found or invalid")
|
| 186 |
|
| 187 |
+
def apply_loras_to_pipe(pipe, lora_ids: str, lora_scales: Optional[str] = None) -> list:
|
| 188 |
+
"""Apply LoRAs to a PyTorch pipeline and return list of applied LoRAs."""
|
| 189 |
+
lora_list = []
|
| 190 |
+
if not lora_ids:
|
| 191 |
+
return lora_list
|
| 192 |
+
|
| 193 |
+
lora_ids_list = [lid.strip() for lid in lora_ids.split(",")]
|
| 194 |
+
if lora_scales:
|
| 195 |
+
scales_list = [float(s.strip()) for s in lora_scales.split(",")]
|
| 196 |
+
else:
|
| 197 |
+
scales_list = [1.0] * len(lora_ids_list)
|
| 198 |
+
|
| 199 |
+
if len(lora_ids_list) != len(scales_list):
|
| 200 |
+
raise HTTPException(status_code=400, detail="Number of LoRA IDs must match number of scales")
|
| 201 |
+
|
| 202 |
+
# Download and load each LoRA synchronously
|
| 203 |
+
for lora_id, scale in zip(lora_ids_list, scales_list):
|
| 204 |
+
lora_path = download_lora_sync(lora_id)
|
| 205 |
+
pipe.load_lora_weights(str(lora_path))
|
| 206 |
+
pipe.fuse_lora(lora_scale=scale)
|
| 207 |
+
lora_list.append({"id": lora_id, "scale": scale})
|
| 208 |
+
|
| 209 |
+
return lora_list
|
| 210 |
+
|
| 211 |
def load_ov_pipeline():
|
| 212 |
"""Load the OpenVINO-optimized LCM pipeline from local path."""
|
| 213 |
global ov_pipeline, tokenizer
|
|
|
|
| 307 |
except Exception as e:
|
| 308 |
logger.warning(f"Failed to delete {path}: {e}")
|
| 309 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
# --- API Endpoints ---
|
| 311 |
@app.get("/", include_in_schema=False)
|
| 312 |
async def root():
|
|
|
|
| 617 |
summary="Health check endpoint"
|
| 618 |
)
|
| 619 |
async def health_check():
|
| 620 |
+
model_index_exists = (Path(MODEL_PATH) / "model_index.json").exists()
|
| 621 |
return {
|
| 622 |
+
"status": "healthy" if model_index_exists else "degraded",
|
| 623 |
+
"base_model_loaded": model_index_exists,
|
| 624 |
"available_loras": len(AVAILABLE_LORAS)
|
| 625 |
}
|
| 626 |
|