|
|
import torch |
|
|
# fal_client is an optional dependency: when it is not installed the name is
# bound to None so the rest of the module can still be imported.
try:
    import fal_client
except ImportError:
    # Only a missing/unimportable package is tolerated here; a bare ``except``
    # would also swallow KeyboardInterrupt and SystemExit.
    fal_client = None
|
|
|
|
|
from diffusers import AutoPipelineForText2Image, HunyuanVideoPipeline, DiffusionPipeline |
|
|
import json |
|
|
import diffusers |
|
|
from functools import partial |
|
|
import os |
|
|
|
|
|
# Provide a placeholder FAL_KEY only when the environment does not already
# define one, so a real key exported by the user is never overwritten.
os.environ.setdefault('FAL_KEY', 'YOUR_API_KEY')
|
|
|
|
|
def init_multiple_pipelines(pipe_name, pipe_init_kwargs, num_devices, device_id=None):
    """Load one copy of a diffusers pipeline per CUDA device.

    ``AutoPipelineForText2Image`` is tried first. If loading or moving the
    pipeline to the device raises, the class named by ``_class_name`` in
    ``<pipe_name>/model_index.json`` is looked up on the ``diffusers`` module
    and used instead; that fallback reads the file from disk, so it needs
    ``pipe_name`` to be a local directory.

    Args:
        pipe_name: Model identifier or local path passed to ``from_pretrained``.
        pipe_init_kwargs: Keyword arguments forwarded to ``from_pretrained``.
        num_devices: Number of pipeline copies to create. Copy ``i`` is placed
            on ``cuda:i`` unless ``device_id`` is given.
        device_id: Optional explicit CUDA index. Only valid together with
            ``num_devices == 1``.

    Returns:
        A list with one pipeline per device, in creation order.

    Raises:
        AssertionError: If ``device_id`` is given and ``num_devices != 1``.
    """
    if device_id is not None:
        assert num_devices == 1, "device_id can only be combined with num_devices == 1"

    pipelines = []
    for i in range(num_devices):
        actual_device_id = device_id if device_id is not None else i
        device = f'cuda:{actual_device_id}'
        try:
            pipeline = AutoPipelineForText2Image.from_pretrained(pipe_name, **pipe_init_kwargs).to(device)
        except Exception:
            # The auto class could not provide the pipeline: fall back to the
            # concrete class recorded in the checkpoint's model_index.json.
            # Any error raised here is implicitly chained to the original one.
            config_path = os.path.join(pipe_name, 'model_index.json')
            with open(config_path) as config_file:  # close the handle deterministically
                config = json.load(config_file)
            pipeline_class = getattr(diffusers, config['_class_name'])
            pipeline = pipeline_class.from_pretrained(pipe_name, **pipe_init_kwargs).to(device)
        pipelines.append(pipeline)
    return pipelines
|
|
|
|
|
|
|
|
def init_pipeline_from_names(pipe_names, weight_dtype):
    """Build a ``{name: pipeline}`` mapping for the given model names.

    Each name in ``pipe_names`` is loaded with
    ``AutoPipelineForText2Image.from_pretrained`` using ``weight_dtype`` as
    ``torch_dtype``. The pipelines are returned as loaded; this function does
    not move them to any device.
    """
    return {
        model_name: AutoPipelineForText2Image.from_pretrained(model_name, torch_dtype=weight_dtype)
        for model_name in pipe_names
    }
|
|
|
|
|
|
|
|
def on_queue_update(update):
    """Print the log messages carried by an in-progress queue update.

    Updates that are not ``fal_client.InProgress`` instances are ignored.
    """
    if not isinstance(update, fal_client.InProgress):
        return
    for entry in update.logs:
        print(entry["message"])
|
|
|
|
|
def gen_with_api(pipe_names, generation_kwargs):
    """Run a generation request through ``fal_client.subscribe``.

    Args:
        pipe_names: Passed as the first positional argument to
            ``fal_client.subscribe``.
        generation_kwargs: Passed to ``fal_client.subscribe`` as ``arguments``.

    Returns:
        Whatever ``fal_client.subscribe`` returns.

    Raises:
        ImportError: If the optional ``fal_client`` package could not be
            imported when this module was loaded.
    """
    if fal_client is None:
        # The module-level import is best-effort; fail with a clear message
        # instead of an AttributeError on None.
        raise ImportError("fal_client is required for gen_with_api but is not installed")

    return fal_client.subscribe(
        pipe_names,
        arguments=generation_kwargs,
        with_logs=True,  # request logs so on_queue_update has messages to print
        on_queue_update=on_queue_update,
    )