GRN / test_pipe.py
hanjian.thu123
[bugfix] local tested
6a5e9aa
raw
history blame contribute delete
734 Bytes
import torch
from grn_pipeline import GRNPipeline
# Smoke-test script: build a GRNPipeline from local checkpoints, generate one
# image from a text prompt, then pause so the result can be inspected by hand.

# Load the pipeline from the local weight files.
pipe = GRNPipeline.from_pretrained(
    model_path='/tmp/weights/9a8a674133266e996d8d56e784a10d67.pth',
    vae_path='/tmp/weights/HBQ_tokenizer_64dim_M4.ckpt',
    text_encoder_ckpt='/tmp/weights/umt5-xxl',
    torch_dtype=torch.bfloat16,
)

# Move the pipeline to the device.
pipe = pipe.to('cuda')

# Generate an image.
result = pipe(
    prompt="A cute cat playing in the garden",
    guidance_scale=3.0,
    num_inference_steps=50,
    width=1024,
    height=1024,
    content_type='image',
    seed=42,  # presumably fixed so repeated runs are comparable -- TODO confirm
)
image = result.images[0]

# Pause for interactive inspection of `result` / `image`.
# The built-in breakpoint() enters pdb by default and honours the
# PYTHONBREAKPOINT environment variable, so an unattended run can skip the
# stop with PYTHONBREAKPOINT=0 without editing this script.
breakpoint()
# # Generate a video
# result = pipe(
# prompt="A dog chasing a butterfly",
# content_type='video'
# )
# video = result.videos[0]