|
|
import torch |
|
|
from PIL import Image |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import sys |
|
|
import os |
|
|
|
|
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), "../text_to_sign")) |
|
|
from pipeline import Text2SignPipeline |
|
|
|
|
|
def generate_and_save(
    prompt,
    checkpoint_path,
    output_path,
    device="cuda",
    *,
    num_inference_steps=50,
    guidance_scale=7.5,
):
    """Generate a sign-language clip for *prompt* and save it as a filmstrip image.

    Loads ``Text2SignPipeline`` from *checkpoint_path*, samples one result for
    *prompt*, and writes its frames side by side (one row) to *output_path*.

    Args:
        prompt: Text prompt passed to the pipeline.
        checkpoint_path: Checkpoint passed to ``Text2SignPipeline.from_pretrained``.
        output_path: Destination image file for the filmstrip.
        device: Device string forwarded to the pipeline ("cuda" or "cpu").
        num_inference_steps: Sampling steps forwarded to the pipeline call.
        guidance_scale: Guidance scale forwarded to the pipeline call.

    Raises:
        ValueError: If the pipeline result contains no frames.
    """
    pipeline = Text2SignPipeline.from_pretrained(checkpoint_path, device=device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # NOTE(review): ``[0]`` presumably selects the first video of the
        # returned batch/tuple -- confirm against Text2SignPipeline.__call__.
        video_frames = pipeline(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )[0]

    _save_filmstrip(video_frames, output_path)
    print(f"Saved filmstrip to {output_path}")


def _save_filmstrip(frames, output_path):
    """Draw *frames* in a single row of axes and save the figure to *output_path*."""
    num_frames = len(frames)
    if num_frames == 0:
        raise ValueError("Pipeline returned no frames; nothing to plot.")

    # squeeze=False keeps ``axes`` a 2-D array even for a single frame; with
    # the default, plt.subplots(1, 1) returns a bare Axes that is not indexable.
    fig, axes = plt.subplots(
        1, num_frames, figsize=(2 * num_frames, 2), squeeze=False
    )
    try:
        for ax, frame in zip(axes[0], frames):
            # NOTE(review): assumes each frame is something imshow accepts
            # (e.g. HxW or HxWxC array-like) -- confirm the pipeline's output.
            ax.imshow(frame)
            ax.axis("off")
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: each option feeds one generate_and_save argument.
    cli = argparse.ArgumentParser()
    for flag, options in (
        ("--prompt", {"type": str, "required": True,
                      "help": "Text prompt to generate sign language video"}),
        ("--checkpoint", {"type": str, "default": "checkpoint_epoch_70.pt",
                          "help": "Path to model checkpoint"}),
        ("--output", {"type": str, "default": "generated_filmstrip.png",
                      "help": "Output image path"}),
        ("--device", {"type": str, "default": "cuda",
                      "help": "Device: cuda or cpu"}),
    ):
        cli.add_argument(flag, **options)

    opts = cli.parse_args()
    generate_and_save(
        prompt=opts.prompt,
        checkpoint_path=opts.checkpoint,
        output_path=opts.output,
        device=opts.device,
    )
|
|
|