File size: 4,435 Bytes
0d775d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import os
import sys
import platform
import torch

# -------------LOW VRAM TESTING -------------


# Only used for debugging, to emulate low-VRAM graphics cards:
#
#torch.cuda.set_per_process_memory_fraction(0.43)  # Limit to 43% of my available VRAM, for testing.
# And/or set maximum split size (in MB)
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,garbage_collection_threshold:0.8'


# -------------- INFO LOGGING ----------------

# Startup banner: Python, PyTorch and CUDA versions on a single line.
# The CUDA field reads 'not available' when torch.cuda.is_available() is false.
print(
    " | ".join(
        (
            f"[System Info] Python: {platform.python_version():<8}",
            f"PyTorch: {torch.__version__:<8}",
            f"CUDA: {torch.version.cuda if torch.cuda.is_available() else 'not available'}",
        )
    )
)

import logging

class TritonFilter(logging.Filter):
    """Logging filter that drops records mentioning a missing Triton install.

    A record is rejected when its formatted message, compared
    case-insensitively, contains any phrase in ``_BLOCKED_PHRASES``.
    """

    # Lower-case substrings; matched against the lower-cased message.
    _BLOCKED_PHRASES = (
        "triton is not available",
        "matching triton is not available",
        "no module named 'triton'",
    )

    def filter(self, record):
        """Return False to suppress *record*, True to let it through."""
        text = record.getMessage().lower()
        return all(phrase not in text for phrase in self._BLOCKED_PHRASES)
    
# Dedicated "trellis" logger with its own console handler. Propagation is
# switched off so its records are not passed on to the root logger.
# Other scripts can reuse it by doing 'logger = logging.getLogger("trellis")'.
handler = logging.StreamHandler()
handler.addFilter(TritonFilter())  # hide Triton availability notices here too
handler.setFormatter(
    logging.Formatter(
        fmt='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%H:%M:%S',  # time of day only
    )
)

logger = logging.getLogger("trellis")
logger.propagate = False
logger.setLevel(logging.INFO)
logger.addHandler(handler)

# Configure root logger:
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%H:%M:%S',  # Only show time
    handlers=[logging.StreamHandler()]
)
# Apply TritonFilter to all root handlers. The loop variable is deliberately
# not named `handler`, so the module-level `handler` (the "trellis" logger's
# stream handler) is not rebound to one of the root handlers.
root_logger = logging.getLogger()
for root_handler in root_logger.handlers:
    root_handler.addFilter(TritonFilter())



# -------------- CMD ARGS PARSE -----------------

# read command-line arguments, passed into this script when launching it:
import argparse
parser = argparse.ArgumentParser(description="Trellis API server")

# --precision: one of four accepted spellings; defaults to "full".
parser.add_argument(
    "--precision",
    default="full",
    choices=["full", "half", "float32", "float16"],
    help="Set the size of variables for pipeline, to save VRAM and gain performance",
)

# --ip / --port: address the server listens on.
parser.add_argument(
    "--ip",
    default="127.0.0.1",
    type=str,
    help="Specify the IP address on which the server will listen (default: 127.0.0.1)",
)
parser.add_argument(
    "--port",
    default=7960,
    type=int,
    help="Specify the port on which the server will listen (default: 7960)",
)

# Parsed once, at module load, from the process command line.
cmd_args = parser.parse_args()


# -------------- PIPELINE SETUP ----------------

# Make modules under the current working directory importable.
var_cwd = os.getcwd()
sys.path.append(var_cwd)

# Startup notice, framed by blank lines on stdout.
print('')
logger.info("Trellis API Server is starting up:")
logger.info("Touching this window will pause it.  If it happens, click inside it and press 'Enter' to unpause")
print('')

# Configure environment, BEFORE including trellis pipeline
os.environ.update({
    'ATTN_BACKEND': 'xformers',  # or 'flash-attn'
    'SPCONV_ALGO': 'native',     # or 'auto'
})

# Import `state` from state_manage and initialize the Trellis pipeline.
# The import is deliberately placed here, only AFTER all of the setup above
# (logging, command-line parsing, environment variables), rather than at the
# top of the file.
from api_spz.core.state_manage import state 
# Initialize the pipeline with the value given via --precision.
state.initialize_pipeline(cmd_args.precision)


# -------------- API SERVER SETUP AND LAUNCH ----------------

import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from api_spz.routes import generation
from version import code_version

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Announce startup, then clean up shared state when the app stops.

    Passed to ``FastAPI(lifespan=...)``: the code before ``yield`` runs when
    the context is entered, the cleanup runs when it is exited.

    Args:
        app: The FastAPI application (not used directly).

    Yields:
        None. The application runs while this generator is suspended.
    """
    print('')
    logger.info(f"Trellis API version {code_version}")
    logger.info(f"Trellis API Server is active and listening on {cmd_args.ip}:{cmd_args.port}")
    print('')
    try:
        yield
    finally:
        # Run cleanup even when an exception (or cancellation) is thrown into
        # the generator at the yield point; without try/finally the shutdown
        # step would be skipped in that case.
        state.cleanup()  # shutdown

# The ASGI application; `lifespan` supplies the startup/shutdown hooks.
app = FastAPI(title="Trellis API", lifespan=lifespan)



# CORS: every origin, method and header is accepted, with credentials enabled.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended if the server is ever bound to a
# non-loopback address via --ip.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Add the generation router
app.include_router(generation.router)

if __name__ == "__main__":
    # Launch the server only when this file is executed as a script.
    # uvicorn's own log output is limited to warnings and above.
    uvicorn.run(
        app,
        host=cmd_args.ip,
        port=cmd_args.port,
        log_level="warning",
    )