|
|
|
|
|
"""
Script to download the TSN model checkpoint for GenVidBench.
"""
|
|
|
|
|
|
import os
|
|
|
import requests
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
def download_file(url, filename, timeout=30, chunk_size=8192):
    """Download ``url`` to ``filename`` while showing a tqdm progress bar.

    The body is streamed to ``<filename>.part`` and moved into place only
    after the transfer finishes, so an interrupted or failed download never
    leaves a truncated file at ``filename``.

    Args:
        url: Address fetched with an HTTP GET.
        filename: Destination path; also used as the progress-bar label.
        timeout: Seconds passed to ``requests.get`` as its ``timeout``.
        chunk_size: Number of bytes requested per streamed chunk.

    Raises:
        requests.HTTPError: The server answered with a 4xx/5xx status.
        requests.RequestException: Any other transport failure.
        OSError: The destination could not be written or renamed.
    """
    partial_path = f"{filename}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check an HTTP error page would be saved to disk
            # as if it were the requested file.
            response.raise_for_status()

            # content-length is optional; pass None rather than 0 so tqdm
            # shows plain counters instead of a bar with a bogus total.
            total_size = int(response.headers.get('content-length', 0)) or None

            with open(partial_path, 'wb') as f, tqdm(
                desc=filename,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    size = f.write(chunk)
                    pbar.update(size)

        # Publish the file only once every byte has been written.
        os.replace(partial_path, filename)
    except BaseException:
        # Also covers Ctrl-C: remove the partial file, then re-raise untouched.
        try:
            os.remove(partial_path)
        except FileNotFoundError:
            pass
        raise
|
|
|
|
|
|
def main():
    """Download the TSN model checkpoint into the ``checkpoints/`` directory.

    Does nothing except print a notice when the checkpoint file is already
    present. Download errors are reported on stdout together with a hint to
    fetch the file manually; they are not re-raised.
    """
    checkpoint_dir = 'checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint_url = "https://download.openmmlab.com/mmaction/v1.0/recognition/tsn/tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb/tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb_20220818-2692d16c.pth"
    # Derive the local name from the URL so the two can never drift apart.
    checkpoint_name = checkpoint_url.rsplit('/', 1)[-1]
    checkpoint_path = f"{checkpoint_dir}/{checkpoint_name}"

    if os.path.exists(checkpoint_path):
        print(f"Checkpoint already exists: {checkpoint_path}")
        return

    print("Downloading TSN model checkpoint...")
    print(f"URL: {checkpoint_url}")
    print(f"Destination: {checkpoint_path}")

    try:
        download_file(checkpoint_url, checkpoint_path)
    except Exception as e:
        # Top-level boundary of the script: report and let the user recover.
        print(f"❌ Error downloading checkpoint: {e}")
        print("Please download the checkpoint manually and place it in the checkpoints/ directory")
    else:
        print(f"✅ Successfully downloaded checkpoint to {checkpoint_path}")
|
|
|
|
|
|
# Start the download only when run as a script, never as an import side effect.
if __name__ == "__main__":
    main()
|
|
|
|