xiaruize commited on
Commit
5e9417b
·
1 Parent(s): b59d808

Initial commit: add model and code

Browse files
Files changed (4) hide show
  1. README.md +39 -0
  2. checkpoint_epoch_70.pt +3 -0
  3. config.py +16 -0
  4. inference.py +33 -0
README.md ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Text2Sign: Lightweight Diffusion Model for Sign Language Video Generation
2
+
3
+ This repository contains the pretrained checkpoint and inference code for the Text2Sign model, a lightweight diffusion-based architecture for generating sign language videos from text prompts.
4
+
5
+ ## Model Overview
6
+ - **Architecture:** 3D UNet backbone with DiT (Diffusion Transformer) blocks and a custom Transformer-based text encoder.
7
+ - **Dataset:** Trained on How2Sign (ASL) video-text pairs.
8
+ - **Resolution:** 64x64 RGB, 16 frames per clip.
9
+ - **Checkpoint:** Provided at epoch 70.
10
+
11
+ ## Files
12
+ - `checkpoint_epoch_70.pt` — Pretrained model weights
13
+ - `config.py` — Model and generation configuration
14
+ - `inference.py` — Example script for generating sign language videos from text
15
+
16
+ ## Usage
17
+ 1. Install dependencies:
18
+ ```bash
19
+ pip install torch torchvision pillow matplotlib
20
+ ```
21
+ 2. Run the inference script:
22
+ ```bash
23
+ python inference.py --prompt "Hello world"
24
+ ```
25
+ This will generate a video for the given prompt and save a filmstrip image.
26
+
27
+ ## Citation
28
+ If you use this model, please cite:
29
+ ```
30
+ @article{xia2025text2sign,
31
+ title={Text2Sign: A Lightweight Diffusion Model for Text-to-Sign Language Video Generation},
32
+ author={Ruize Xia},
33
+ year={2025},
34
+ journal={arXiv preprint arXiv:2512.12345}
35
+ }
36
+ ```
37
+
38
+ ## License
39
+ MIT
checkpoint_epoch_70.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d326d34748be508edd346823e24454de1712939018133915393fd35f94a35c2
3
+ size 2386809873
config.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Example config for Text2Sign model
2
+
3
+ class ModelConfig:
4
+ vocab_size = 30522
5
+ max_text_length = 77
6
+ use_clip_text_encoder = False
7
+ # ... other model hyperparameters ...
8
+
9
+ class GenerationConfig:
10
+ num_inference_steps = 50
11
+ guidance_scale = 7.5
12
+ eta = 0.0
13
+ fps = 8
14
+ # ... other generation settings ...
15
+
16
+ # Add any additional config as needed for your model
inference.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import sys
6
+ import os
7
+
8
+ # Add model code to path if needed
9
+ sys.path.append(os.path.join(os.path.dirname(__file__), "../text_to_sign"))
10
+ from pipeline import Text2SignPipeline
11
+
12
+ def generate_and_save(prompt, checkpoint_path, output_path, device="cuda"):
13
+ pipeline = Text2SignPipeline.from_pretrained(checkpoint_path, device=device)
14
+ with torch.no_grad():
15
+ video_frames = pipeline(prompt, num_inference_steps=50, guidance_scale=7.5)[0]
16
+ # Save as filmstrip
17
+ fig, axes = plt.subplots(1, len(video_frames), figsize=(2*len(video_frames), 2))
18
+ for i, frame in enumerate(video_frames):
19
+ axes[i].imshow(frame)
20
+ axes[i].axis('off')
21
+ plt.tight_layout()
22
+ plt.savefig(output_path)
23
+ print(f"Saved filmstrip to {output_path}")
24
+
25
+ if __name__ == "__main__":
26
+ import argparse
27
+ parser = argparse.ArgumentParser()
28
+ parser.add_argument('--prompt', type=str, required=True, help='Text prompt to generate sign language video')
29
+ parser.add_argument('--checkpoint', type=str, default='checkpoint_epoch_70.pt', help='Path to model checkpoint')
30
+ parser.add_argument('--output', type=str, default='generated_filmstrip.png', help='Output image path')
31
+ parser.add_argument('--device', type=str, default='cuda', help='Device: cuda or cpu')
32
+ args = parser.parse_args()
33
+ generate_and_save(args.prompt, args.checkpoint, args.output, args.device)