# text2sign / inference.py
# Initial commit: add model and code (5e9417b) — xiaruize
import torch
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
# Add model code to path if needed
sys.path.append(os.path.join(os.path.dirname(__file__), "../text_to_sign"))
from pipeline import Text2SignPipeline
def generate_and_save(prompt, checkpoint_path, output_path, device="cuda"):
    """Generate sign-language video frames from *prompt* and save them as a filmstrip PNG.

    Args:
        prompt: Text prompt describing the sign to generate.
        checkpoint_path: Path to the pretrained Text2SignPipeline checkpoint.
        output_path: Where to write the filmstrip image.
        device: Torch device string, e.g. "cuda" or "cpu".
    """
    pipeline = Text2SignPipeline.from_pretrained(checkpoint_path, device=device)
    with torch.no_grad():
        video_frames = pipeline(prompt, num_inference_steps=50, guidance_scale=7.5)[0]
    # Save as a horizontal filmstrip, one subplot per frame.
    fig, axes = plt.subplots(1, len(video_frames), figsize=(2 * len(video_frames), 2))
    # plt.subplots returns a bare Axes (not an array) when ncols == 1;
    # np.atleast_1d keeps the iteration below uniform for single-frame output.
    axes = np.atleast_1d(axes)
    for ax, frame in zip(axes, video_frames):
        ax.imshow(frame)
        ax.axis('off')
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close(fig)  # release the figure so repeated calls don't accumulate memory
    print(f"Saved filmstrip to {output_path}")
if __name__ == "__main__":
    import argparse

    # CLI entry point: parse arguments and run a single generation.
    parser = argparse.ArgumentParser(
        description="Generate a sign-language video filmstrip from a text prompt."
    )
    parser.add_argument('--prompt', type=str, required=True,
                        help='Text prompt to generate sign language video')
    parser.add_argument('--checkpoint', type=str, default='checkpoint_epoch_70.pt',
                        help='Path to model checkpoint')
    parser.add_argument('--output', type=str, default='generated_filmstrip.png',
                        help='Output image path')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Device: cuda or cpu')
    args = parser.parse_args()
    generate_and_save(args.prompt, args.checkpoint, args.output, args.device)