# HPSv3/generate/utils/utils.py
import torch
try:
    import fal_client
except ImportError:
    # fal_client is optional: only API-based generation needs it.
    # Leave a None sentinel so callers can detect its absence.
    fal_client = None
from diffusers import AutoPipelineForText2Image, HunyuanVideoPipeline, DiffusionPipeline
import json
import diffusers
from functools import partial
import os
# export FAL_KEY="YOUR_API_KEY"
# Use setdefault so a real key already present in the environment is not
# clobbered by this placeholder value.
os.environ.setdefault('FAL_KEY', 'YOUR_API_KEY')
def init_multiple_pipelines(pipe_name, pipe_init_kwargs, num_devices, device_id=None):
    """Load one copy of a diffusers pipeline per CUDA device.

    Args:
        pipe_name: Model id or local checkpoint path passed to ``from_pretrained``.
        pipe_init_kwargs: Extra keyword arguments forwarded to ``from_pretrained``.
        num_devices: Number of pipeline copies to create (one per device index).
        device_id: If given, pin the single pipeline to this device index
            (requires ``num_devices == 1``).

    Returns:
        list: Loaded pipelines, each moved to its ``cuda:<i>`` device.
    """
    pipelines = []
    if device_id is not None:
        assert num_devices == 1
    for i in range(num_devices):
        actual_device_id = device_id if device_id is not None else i
        try:
            pipeline = AutoPipelineForText2Image.from_pretrained(
                pipe_name, **pipe_init_kwargs
            ).to(f'cuda:{actual_device_id}')
        except Exception:
            # AutoPipeline could not resolve the pipeline class (e.g. a local
            # checkpoint not covered by the auto-mapping): read the class name
            # from model_index.json and instantiate that class directly.
            with open(os.path.join(pipe_name, 'model_index.json')) as f:
                config = json.load(f)
            pipeline_class = getattr(diffusers, config['_class_name'])
            pipeline = pipeline_class.from_pretrained(
                pipe_name, **pipe_init_kwargs
            ).to(f'cuda:{actual_device_id}')
        pipelines.append(pipeline)
    return pipelines
def init_pipeline_from_names(pipe_names, weight_dtype):
    """Build a mapping from model name to a text-to-image pipeline.

    Args:
        pipe_names: Iterable of model ids/paths to load.
        weight_dtype: torch dtype used as ``torch_dtype`` when loading.

    Returns:
        dict: ``{name: pipeline}`` for every name in ``pipe_names``.
    """
    return {
        name: AutoPipelineForText2Image.from_pretrained(name, torch_dtype=weight_dtype)
        for name in pipe_names
    }
def on_queue_update(update):
    """fal.ai queue callback: print log messages while a job is in progress.

    Args:
        update: A status object delivered by ``fal_client.subscribe``; only
            ``fal_client.InProgress`` updates carry printable logs.
    """
    # Guard against fal_client being None (its import at module load is
    # optional and may have failed) before touching fal_client.InProgress.
    if fal_client is not None and isinstance(update, fal_client.InProgress):
        for log in update.logs:
            print(log["message"])
def gen_with_api(pipe_names, generation_kwargs):
    """Run a generation job on the fal.ai API and block until it completes.

    Args:
        pipe_names: fal.ai application id to call (e.g. ``"fal-ai/flux/dev"``).
        generation_kwargs: Arguments forwarded to the remote endpoint.

    Returns:
        The job result returned by ``fal_client.subscribe``.

    Raises:
        RuntimeError: If the optional ``fal_client`` package is not installed.
    """
    # Fail with a clear message instead of AttributeError on the None module.
    if fal_client is None:
        raise RuntimeError(
            "fal_client is not installed; install it to use API-based generation."
        )
    result = fal_client.subscribe(
        pipe_names,
        arguments=generation_kwargs,
        with_logs=True,
        on_queue_update=on_queue_update,
    )
    return result