File size: 1,986 Bytes
9b57ce7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import torch
# fal_client is an optional dependency: only gen_with_api / on_queue_update
# need it, so a missing package is recorded as None instead of failing import.
try:
    import fal_client
except ImportError:
    fal_client = None
from diffusers import AutoPipelineForText2Image, HunyuanVideoPipeline, DiffusionPipeline
import json
import diffusers
from functools import partial
import os
# The fal API key is supplied via the FAL_KEY environment variable; prefer
# exporting it in the shell:  export FAL_KEY="YOUR_API_KEY"
# The placeholder below is only applied when nothing was exported, so a real
# key set by the user is never overwritten.
os.environ.setdefault('FAL_KEY', 'YOUR_API_KEY')
def _read_class_name(pipe_name):
    """Return the ``_class_name`` entry of ``<pipe_name>/model_index.json``.

    Only works when *pipe_name* is a local directory: a missing file raises
    ``FileNotFoundError`` and a missing entry raises ``KeyError``.
    """
    index_path = os.path.join(pipe_name, 'model_index.json')
    with open(index_path, encoding='utf-8') as index_file:
        return json.load(index_file)['_class_name']


def init_multiple_pipelines(pipe_name, pipe_init_kwargs, num_devices, device_id=None):
    """Load one copy of a diffusers pipeline per CUDA device.

    Args:
        pipe_name: Model id or local path passed to ``from_pretrained``. The
            fallback loader additionally needs a local directory containing
            ``model_index.json``.
        pipe_init_kwargs: Extra keyword arguments forwarded to
            ``from_pretrained`` (e.g. ``torch_dtype``).
        num_devices: Number of pipeline copies to build; copy ``i`` is moved
            to ``cuda:i``.
        device_id: If given, place the single copy on ``cuda:<device_id>``;
            ``num_devices`` must then be 1.

    Returns:
        A list of ``num_devices`` pipelines, in device order.

    Raises:
        ValueError: If ``device_id`` is given while ``num_devices != 1``.
    """
    if device_id is not None and num_devices != 1:
        raise ValueError(
            f'device_id={device_id} pins a single pipeline, but num_devices={num_devices}'
        )
    pipelines = []
    for i in range(num_devices):
        device = f'cuda:{device_id if device_id is not None else i}'
        try:
            pipeline = AutoPipelineForText2Image.from_pretrained(pipe_name, **pipe_init_kwargs)
        except Exception:
            # AutoPipelineForText2Image only resolves the text-to-image families
            # it knows (presumably not video pipelines such as HunyuanVideo), so
            # fall back to the concrete class named in model_index.json. If this
            # also fails, the original error stays attached as __context__.
            pipeline_class = getattr(diffusers, _read_class_name(pipe_name))
            pipeline = pipeline_class.from_pretrained(pipe_name, **pipe_init_kwargs)
        # Device placement is kept outside the try: a device error (e.g. CUDA
        # OOM) is not a loading failure and must not trigger a second load.
        pipelines.append(pipeline.to(device))
    return pipelines
def init_pipeline_from_names(pipe_names, weight_dtype):
    """Map each model name to a text-to-image pipeline loaded from it.

    Every pipeline is created with ``torch_dtype=weight_dtype``; no device
    placement is done here. A repeated name is loaded again and the later
    pipeline replaces the earlier one in the result.
    """
    return {
        model_name: AutoPipelineForText2Image.from_pretrained(
            model_name, torch_dtype=weight_dtype
        )
        for model_name in pipe_names
    }
def on_queue_update(update):
    """Queue-status callback for ``fal_client.subscribe``.

    Prints the ``"message"`` of every log entry carried by an in-progress
    update; any other kind of update is ignored.
    """
    if not isinstance(update, fal_client.InProgress):
        return
    messages = (entry["message"] for entry in update.logs)
    for message in messages:
        print(message)
def gen_with_api(pipe_names, generation_kwargs):
    """Run a generation request through ``fal_client.subscribe`` and wait for it.

    Args:
        pipe_names: Application id passed as the first argument to
            ``fal_client.subscribe`` (despite the plural name, presumably a
            single id string — confirm against callers).
        generation_kwargs: Request arguments forwarded as ``arguments=``.

    Returns:
        Whatever ``fal_client.subscribe`` returns for the finished request.

    Raises:
        ImportError: If the optional ``fal_client`` package is not installed.
    """
    # fal_client is set to None at import time when the package is missing;
    # fail with a clear message instead of an AttributeError on None.
    if fal_client is None:
        raise ImportError(
            'gen_with_api requires the optional fal_client package '
            '(pip install fal-client)'
        )
    return fal_client.subscribe(
        pipe_names,
        arguments=generation_kwargs,
        with_logs=True,
        on_queue_update=on_queue_update,
    )