import asyncio
import gc
import logging
import os
import random
import threading
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from Pipelines import ModelPipelineInitializer
from pydantic import BaseModel

from utils import RequestScopedPipeline, Utils

@dataclass
class ServerConfigModels:
    model: str = "stabilityai/stable-diffusion-3.5-medium"
    type_models: str = "t2im"
    constructor_pipeline: Optional[Type] = None
    custom_pipeline: Optional[Type] = None
    components: Optional[Dict[str, Any]] = None
    torch_dtype: Optional[torch.dtype] = None
    host: str = "0.0.0.0"
    port: int = 8500


server_config = ServerConfigModels()

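# Illustrative only: the defaults above can be overridden by constructing a new
# ServerConfigModels before the app is created, e.g.
#
#   server_config = ServerConfigModels(
#       model="stabilityai/stable-diffusion-3.5-medium",
#       type_models="t2im",
#       port=8600,
#   )
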
@asynccontextmanager
async def lifespan(app: FastAPI):
    logging.basicConfig(level=logging.INFO)
    app.state.logger = logging.getLogger("diffusers-server")
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
    os.environ["CUDA_LAUNCH_BLOCKING"] = "0"

    # Simple request/inference counters, protected by an asyncio lock.
    app.state.total_requests = 0
    app.state.active_inferences = 0
    app.state.metrics_lock = asyncio.Lock()
    app.state.metrics_task = None

    app.state.utils_app = Utils(
        host=server_config.host,
        port=server_config.port,
    )

    async def metrics_loop():
        # Background task: log the counters every 5 seconds until cancelled.
        try:
            while True:
                async with app.state.metrics_lock:
                    total = app.state.total_requests
                    active = app.state.active_inferences
                app.state.logger.info(f"[METRICS] total_requests={total} active_inferences={active}")
                await asyncio.sleep(5)
        except asyncio.CancelledError:
            app.state.logger.info("Metrics loop cancelled")
            raise

    app.state.metrics_task = asyncio.create_task(metrics_loop())

    try:
        yield
    finally:
        # Stop the metrics task first, then shut the pipeline down.
        task = app.state.metrics_task
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        try:
            stop_fn = getattr(model_pipeline, "stop", None) or getattr(model_pipeline, "close", None)
            if callable(stop_fn):
                await run_in_threadpool(stop_fn)
        except Exception as e:
            app.state.logger.warning(f"Error during pipeline shutdown: {e}")

        app.state.logger.info("Lifespan shutdown complete")

app = FastAPI(lifespan=lifespan)

logger = logging.getLogger("DiffusersServer.Pipelines")

# Build and start the model pipeline once at import time, then share it across requests.
initializer = ModelPipelineInitializer(
    model=server_config.model,
    type_models=server_config.type_models,
)
model_pipeline = initializer.initialize_pipeline()
model_pipeline.start()

request_pipe = RequestScopedPipeline(model_pipeline.pipeline)
pipeline_lock = threading.Lock()

logger.info(f"Pipeline initialized and ready to receive requests (model={server_config.model})")

app.state.MODEL_INITIALIZER = initializer
app.state.MODEL_PIPELINE = model_pipeline
app.state.REQUEST_PIPE = request_pipe
app.state.PIPELINE_LOCK = pipeline_lock

class JSONBodyQueryAPI(BaseModel):
    model: str | None = None
    prompt: str
    negative_prompt: str | None = None
    num_inference_steps: int = 28
    num_images_per_prompt: int = 1

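# Illustrative request body for /api/diffusers/inference (only "prompt" is required;
# the other fields fall back to the defaults above):
#
#   {
#       "prompt": "a watercolor painting of a fox",
#       "negative_prompt": "blurry, low quality",
#       "num_inference_steps": 28,
#       "num_images_per_prompt": 1
#   }
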
@app.middleware("http")
async def count_requests_middleware(request: Request, call_next):
    async with app.state.metrics_lock:
        app.state.total_requests += 1
    response = await call_next(request)
    return response

@app.get("/")
async def root():
    return {"message": "Welcome to the Diffusers Server"}

@app.post("/api/diffusers/inference")
async def api(json: JSONBodyQueryAPI):
    prompt = json.prompt
    negative_prompt = json.negative_prompt or ""
    num_steps = json.num_inference_steps
    num_images_per_prompt = json.num_images_per_prompt

    wrapper = app.state.MODEL_PIPELINE
    initializer = app.state.MODEL_INITIALIZER
    utils_app = app.state.utils_app

    if not wrapper or not wrapper.pipeline:
        raise HTTPException(500, "Model not initialized correctly")
    if not prompt.strip():
        raise HTTPException(400, "No prompt provided")

    def make_generator():
        # Fresh generator with a random seed for every request.
        g = torch.Generator(device=initializer.device)
        return g.manual_seed(random.randint(0, 10_000_000))

    req_pipe = app.state.REQUEST_PIPE

    def infer():
        gen = make_generator()
        return req_pipe.generate(
            prompt=prompt,
            negative_prompt=negative_prompt,
            generator=gen,
            num_inference_steps=num_steps,
            num_images_per_prompt=num_images_per_prompt,
            device=initializer.device,
            output_type="pil",
        )

    try:
        async with app.state.metrics_lock:
            app.state.active_inferences += 1

        # Run the blocking inference in a worker thread so the event loop stays free.
        output = await run_in_threadpool(infer)

        async with app.state.metrics_lock:
            app.state.active_inferences = max(0, app.state.active_inferences - 1)

        urls = [utils_app.save_image(img) for img in output.images]
        return {"response": urls}

    except Exception as e:
        async with app.state.metrics_lock:
            app.state.active_inferences = max(0, app.state.active_inferences - 1)
        logger.error(f"Error during inference: {e}")
        raise HTTPException(500, f"Error in processing: {e}")

    finally:
        # Release as much GPU memory as possible between requests.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.ipc_collect()
        gc.collect()

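# Illustrative invocation (assumes the server is running on the default host/port):
#
#   curl -X POST http://localhost:8500/api/diffusers/inference \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "a watercolor painting of a fox"}'
#
# The response is a JSON object of the form {"response": [<image URL>, ...]}.
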
@app.get("/images/{filename}")
async def serve_image(filename: str):
    utils_app = app.state.utils_app
    file_path = os.path.join(utils_app.image_dir, filename)
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Image not found")
    return FileResponse(file_path, media_type="image/png")

@app.get("/api/status")
async def get_status():
    memory_info = {}
    if torch.cuda.is_available():
        memory_allocated = torch.cuda.memory_allocated() / 1024**3
        memory_reserved = torch.cuda.memory_reserved() / 1024**3
        memory_info = {
            "memory_allocated_gb": round(memory_allocated, 2),
            "memory_reserved_gb": round(memory_reserved, 2),
            "device": torch.cuda.get_device_name(0),
        }

    return {"current_model": server_config.model, "type_models": server_config.type_models, "memory": memory_info}

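# Illustrative /api/status response on a CUDA machine (values are examples only):
#
#   {
#       "current_model": "stabilityai/stable-diffusion-3.5-medium",
#       "type_models": "t2im",
#       "memory": {
#           "memory_allocated_gb": 5.12,
#           "memory_reserved_gb": 6.0,
#           "device": "NVIDIA A100-SXM4-40GB"
#       }
#   }
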
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=server_config.host, port=server_config.port)