Update src/pipeline.py
Browse files- src/pipeline.py +18 -5
src/pipeline.py
CHANGED
|
@@ -1252,10 +1252,9 @@ def load_pipeline() -> StableDiffusionXLPipeline:
|
|
| 1252 |
torch_dtype=torch.float16,
|
| 1253 |
local_files_only=True,
|
| 1254 |
)
|
| 1255 |
-
pipeline.scheduler = UniPCMultistepScheduler.from_config(
|
| 1256 |
-
pipeline.scheduler.config,)
|
| 1257 |
pipeline.to("cuda")
|
| 1258 |
-
|
| 1259 |
config = CompilationConfig.Default()
|
| 1260 |
try:
|
| 1261 |
import xformers
|
|
@@ -1271,7 +1270,7 @@ def load_pipeline() -> StableDiffusionXLPipeline:
|
|
| 1271 |
|
| 1272 |
pipeline = compile(pipeline, config)
|
| 1273 |
for _ in range(4):
|
| 1274 |
-
pipeline(prompt="", num_inference_steps=
|
| 1275 |
|
| 1276 |
return pipeline
|
| 1277 |
|
|
@@ -1285,5 +1284,19 @@ def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> I
|
|
| 1285 |
width=request.width,
|
| 1286 |
height=request.height,
|
| 1287 |
generator=generator,
|
| 1288 |
-
num_inference_steps=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1289 |
).images[0]
|
|
|
|
| 1252 |
torch_dtype=torch.float16,
|
| 1253 |
local_files_only=True,
|
| 1254 |
)
|
| 1255 |
+
pipeline.scheduler = UniPCMultistepScheduler.from_config('./src',)
|
|
|
|
| 1256 |
pipeline.to("cuda")
|
| 1257 |
+
pipeline.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16).to('cuda')
|
| 1258 |
config = CompilationConfig.Default()
|
| 1259 |
try:
|
| 1260 |
import xformers
|
|
|
|
| 1270 |
|
| 1271 |
pipeline = compile(pipeline, config)
|
| 1272 |
for _ in range(4):
|
| 1273 |
+
pipeline(prompt="kamala harris", num_inference_steps=20)
|
| 1274 |
|
| 1275 |
return pipeline
|
| 1276 |
|
|
|
|
| 1284 |
width=request.width,
|
| 1285 |
height=request.height,
|
| 1286 |
generator=generator,
|
| 1287 |
+
num_inference_steps=2,
|
| 1288 |
+
).images[0]
|
| 1289 |
+
|
| 1290 |
+
|
| 1291 |
+
|
| 1292 |
+
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
|
| 1293 |
+
generator = Generator(pipeline.device).manual_seed(request.seed) if request.seed else None
|
| 1294 |
+
|
| 1295 |
+
return pipeline(
|
| 1296 |
+
prompt=request.prompt,
|
| 1297 |
+
negative_prompt=request.negative_prompt,
|
| 1298 |
+
width=request.width,
|
| 1299 |
+
height=request.height,
|
| 1300 |
+
generator=generator,
|
| 1301 |
+
num_inference_steps=2,
|
| 1302 |
).images[0]
|