# Hugging Face Space: "Running on Zero"
import os
import sys
import platform
import torch

# -------------LOW VRAM TESTING -------------
# Debug aid only: uncomment the lines below to emulate a graphics card with little VRAM.
#
# Cap this process at 43% of the available VRAM:
# torch.cuda.set_per_process_memory_fraction(0.43)
# Optionally also cap the allocator's maximum split size (in MB):
# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,garbage_collection_threshold:0.8'

# -------------- INFO LOGGING ----------------
# Print a one-line summary of the Python / PyTorch / CUDA runtime.
print(
    f"[System Info] Python: {platform.python_version():<8} | "
    f"PyTorch: {torch.__version__:<8} | "
    f"CUDA: {torch.version.cuda if torch.cuda.is_available() else 'not available'}"
)
import logging


class TritonFilter(logging.Filter):
    """Logging filter that drops messages reporting that Triton is unavailable.

    A record is suppressed when its rendered message (compared case-insensitively)
    contains any phrase in ``TRITON_PHRASES``; every other record passes through.
    """

    # Lower-case substrings identifying Triton-availability noise.
    # Built once at class level instead of on every filter() call.
    TRITON_PHRASES = (
        "triton is not available",
        "matching triton is not available",
        "no module named 'triton'",
    )

    def filter(self, record):
        """Return False to drop a Triton-availability message, True to keep the record.

        Args:
            record: the ``logging.LogRecord`` being handled.
        """
        try:
            message = record.getMessage().lower()
        except Exception:
            # getMessage() raises on a msg/args mismatch. An exception raised from a
            # filter escapes the caller's logging call instead of going through the
            # handler's normal error reporting, so let the record through and let the
            # handler deal with it.
            return True
        return not any(phrase in message for phrase in self.TRITON_PHRASES)
# Shared log layout for the "trellis" logger and the root logger.
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
LOG_DATEFMT = '%H:%M:%S'  # time of day only, no date

# Dedicated "trellis" logger.
# Other scripts can use it via: logger = logging.getLogger("trellis")
logger = logging.getLogger("trellis")
logger.setLevel(logging.INFO)
logger.propagate = False  # Prevent messages from propagating to root (avoids double output)

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(LOG_FORMAT, datefmt=LOG_DATEFMT))
handler.addFilter(TritonFilter())  # hide Triton-availability noise on the "trellis" logger
logger.addHandler(handler)

# Configure the root logger.
# NOTE: basicConfig does nothing if the root logger already has handlers
# (e.g. installed by an imported library); the filter loop below still applies.
logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
    datefmt=LOG_DATEFMT,
    handlers=[logging.StreamHandler()],
)

# Apply TritonFilter to all root handlers. The distinct loop name keeps the
# module-level `handler` bound to the "trellis" handler created above instead
# of being overwritten with the last root handler.
root_logger = logging.getLogger()
for root_handler in root_logger.handlers:
    root_handler.addFilter(TritonFilter())
# -------------- CMD ARGS PARSE -----------------
# Options accepted on the command line when this script is launched.
import argparse

parser = argparse.ArgumentParser(description="Trellis API server")

# Numeric precision used when building the pipeline ("half"/"float16" trade accuracy for VRAM).
parser.add_argument(
    "--precision",
    default="full",
    choices=["full", "half", "float32", "float16"],
    help="Set the size of variables for pipeline, to save VRAM and gain performance",
)

# Network address the HTTP server binds to.
parser.add_argument(
    "--ip",
    default="127.0.0.1",
    type=str,
    help="Specify the IP address on which the server will listen (default: 127.0.0.1)",
)
parser.add_argument(
    "--port",
    default=7960,
    type=int,
    help="Specify the port on which the server will listen (default: 7960)",
)

cmd_args = parser.parse_args()
# -------------- PIPELINE SETUP ----------------
# Add the current working directory to the import path
# (presumably so project packages such as api_spz resolve from here — confirm).
var_cwd = os.getcwd()
sys.path.append(var_cwd)

print('')
logger.info("Trellis API Server is starting up:")
logger.info("Touching this window will pause it. If it happens, click inside it and press 'Enter' to unpause")
print('')

# Backend selection for the trellis pipeline; these must be in the environment
# BEFORE the pipeline is imported below.
os.environ.update(
    ATTN_BACKEND='xformers',  # alternative: 'flash-attn'
    SPCONV_ALGO='native',     # alternative: 'auto'
)

# state_manage is imported and the trellis pipeline initialized only after all
# of the configuration above is in place.
from api_spz.core.state_manage import state
state.initialize_pipeline(cmd_args.precision)
# -------------- API SERVER SETUP AND LAUNCH ----------------
import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from api_spz.routes import generation
from version import code_version


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler.

    On startup, logs the API version and the address the server listens on.
    On shutdown, calls ``state.cleanup()``.

    Args:
        app: the FastAPI application (required by the lifespan protocol; unused here).
    """
    print('')
    logger.info(f"Trellis API version {code_version}")
    logger.info(f"Trellis API Server is active and listening on {cmd_args.ip}:{cmd_args.port}")
    print('')
    try:
        yield
    finally:
        # Release pipeline state even if the server stops because of an
        # exception or cancellation delivered into the lifespan context.
        state.cleanup()  # shutdown
app = FastAPI(title="Trellis API", lifespan=lifespan)

# Permissive CORS: any origin, method and header may call the API.
# NOTE(review): wildcard origins combined with allow_credentials=True means
# credentialed cross-site requests are accepted from any site (Starlette
# reflects the request Origin in that case) — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register the generation endpoints.
app.include_router(generation.router)


if __name__ == "__main__":
    # uvicorn's own log output is limited to warnings and above.
    uvicorn.run(
        app,
        host=cmd_args.ip,
        port=cmd_args.port,
        log_level="warning",
    )