PRIMA-demo / scripts /update_headers.py
mwmathis's picture
Deploy PRIMA Gradio app to Space
1800173 verified
#!/usr/bin/env python3
"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
Licensed under a modified MIT license
"""
import os
import sys
from pathlib import Path
# Standard attribution/license header, stored as the literal text of a
# triple-quoted docstring block. has_header() looks for this text and
# add_or_update_header() inserts it into files that lack it.
STANDARD_HEADER = '''"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation
Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
Licensed under a modified MIT license
"""'''
# Legacy header texts that are replaced by STANDARD_HEADER when found in a file.
# The only entry is an empty docstring (opening and closing triple quotes on
# consecutive lines); needs_header_update() matches it anywhere in the content.
# NOTE(review): presumably a placeholder left where real legacy header text was
# meant to be listed - confirm before relying on it.
OLD_HEADERS = [
'''"""
"""'''
]
def should_skip_file(file_path):
    """
    Determine if a file should be skipped for header addition.

    A file is skipped when any component of its path is one of the excluded
    directory names (VCS metadata, caches, virtual environments, build
    output), or when it is an ``__init__.py`` whose stripped content is
    shorter than 50 characters.

    Args:
        file_path: Path to check. May be a ``pathlib.Path``, a ``str`` or any
            other ``os.PathLike`` object.

    Returns:
        True if the file should be skipped, False otherwise
    """
    # Normalise so that plain strings work as well as Path objects.
    file_path = Path(file_path)

    skip_dirs = {'.git', '__pycache__', '.pytest_cache', 'venv', 'env', '.tox', 'build', 'dist', '.eggs'}

    # Skip if any path component is an excluded directory name.
    if not skip_dirs.isdisjoint(file_path.parts):
        return True

    # Skip __init__.py files that are very short (likely just imports).
    if file_path.name == '__init__.py':
        try:
            content = file_path.read_text(encoding='utf-8')
        except (OSError, UnicodeDecodeError):
            # Best effort: an unreadable file is not skipped here, so that
            # later processing of the file can surface the problem.
            return False
        return len(content.strip()) < 50

    return False
def has_header(content):
    """
    Check if content already has the standard header or a valid variant.

    Two forms are accepted:

    * the exact text of ``STANDARD_HEADER`` anywhere in the content, or
    * a leading docstring (optionally preceded by a shebang line) whose
      first 15 lines contain the project title, the author list and the
      license line, possibly with additional text in between.

    Args:
        content: File content to check

    Returns:
        True if the file has the standard header or acceptable variant, False otherwise
    """
    # Check for exact match
    if STANDARD_HEADER.strip() in content:
        return True

    # Check for a header that carries additional content but still holds
    # the key elements.
    lines = content.split('\n')

    # A shebang must stay on the first line, so the header docstring may
    # legitimately start on the line after it.
    if lines[0].startswith('#!'):
        lines = lines[1:]

    if len(lines) < 3:
        return False

    # Check if it starts with a docstring
    if not lines[0].strip().startswith('"""'):
        return False

    # Check for key header components in the first 15 lines. These strings
    # must mirror the corresponding lines of STANDARD_HEADER.
    header_section = '\n'.join(lines[:15])
    required_elements = [
        'PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation',
        'Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis',
        'Licensed under a modified MIT license',
    ]
    return all(elem in header_section for elem in required_elements)
def needs_header_update(content):
    """
    Look for a legacy header inside the given file content.

    Each entry of ``OLD_HEADERS`` is compared with surrounding whitespace
    stripped; entries are tried in list order and the first match wins.

    Args:
        content: File content to check

    Returns:
        The matching entry of ``OLD_HEADERS`` (unstripped), or None when no
        legacy header is present
    """
    return next(
        (legacy for legacy in OLD_HEADERS if legacy.strip() in content),
        None,
    )
def _insert_standard_header(lines):
    """Return a copy of ``lines`` with STANDARD_HEADER inserted at the proper place."""
    new_lines = []
    insert_index = 0

    # If file starts with shebang, keep it at the top
    if lines[0].startswith('#!'):
        new_lines.append(lines[0])
        insert_index = 1

    # Look for a 'from __future__' import among the first 10 lines; only
    # comments and blank lines may precede it for it to be detected.
    # NOTE(review): a module whose own docstring precedes the future import
    # is not detected (the scan stops at the docstring), so the header is
    # placed above that docstring - verify such files still compile.
    future_import_index = None
    for i in range(insert_index, min(len(lines), 10)):
        line = lines[i].strip()
        if line.startswith('from __future__'):
            future_import_index = i
            break
        if line and not line.startswith('#'):
            # Found real code that isn't a future import
            break

    if future_import_index is not None:
        # Keep everything up to and including the future import first, then
        # place the header AFTER it.
        new_lines.extend(lines[insert_index:future_import_index + 1])
        insert_index = future_import_index + 1

    new_lines.append(STANDARD_HEADER)
    new_lines.append('')
    new_lines.extend(lines[insert_index:])
    return new_lines


def add_or_update_header(file_path, check_only=False):
    """
    Add or update the header in a single file.

    The file is left alone when it already has the header, or when its
    stripped content is 10 characters or shorter. A legacy header (see
    ``OLD_HEADERS``) is replaced in place; otherwise the standard header is
    inserted at the top, after a shebang line and/or an early
    ``from __future__`` import when present.

    Args:
        file_path: Path to the file to update
        check_only: If True, only check without modifying; the returned
            status then describes what would be done

    Returns:
        Tuple of (status, message) where status is 'ok', 'updated', 'added', or 'error'
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # Check if file already has the correct header
        if has_header(content):
            return ('ok', 'Already has correct header')

        # Check if file has an old header that needs replacing
        old_header = needs_header_update(content)
        if old_header:
            if not check_only:
                # Detection matches the stripped text, so replace exactly that
                # text, and only its first occurrence, to avoid reporting
                # 'updated' after a no-op or rewriting unrelated later matches.
                new_content = content.replace(old_header.strip(), STANDARD_HEADER, 1)
                with open(file_path, 'w', encoding='utf-8') as f:
                    f.write(new_content)
            return ('updated', 'Replaced old header with standard header')

        # File has no header. Empty or very short files are left untouched.
        if len(content.strip()) <= 10:
            return ('ok', 'Skipped (file too short or empty)')

        if not check_only:
            new_content = '\n'.join(_insert_standard_header(content.split('\n')))
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(new_content)
        return ('added', 'Added standard header')

    except Exception as e:
        # Any failure is reported as a status so one bad file does not abort
        # the whole run.
        return ('error', f"Error processing file: {e}")
def find_and_process_headers(root_dir, check_only=False):
    """
    Walk a directory tree and run the header check/update on each Python file.

    Files rejected by ``should_skip_file`` are ignored. Every other ``*.py``
    file is handed to ``add_or_update_header``; changed (or to-be-changed)
    files and failures are printed with their path relative to the root.

    Args:
        root_dir: Root directory to search from
        check_only: If True, only check without modifying files

    Returns:
        Dictionary mapping each status ('ok', 'updated', 'added', 'error')
        to a list of (path, message) tuples
    """
    root_path = Path(root_dir)
    stats = {key: [] for key in ('ok', 'updated', 'added', 'error')}

    for py_file in root_path.rglob('*.py'):
        if should_skip_file(py_file):
            continue

        status, message = add_or_update_header(py_file, check_only)
        stats[status].append((py_file, message))

        # Unchanged files are recorded but produce no output.
        if status == 'ok':
            continue

        rel_path = py_file.relative_to(root_path)
        if status == 'error':
            print(f"✗ {rel_path}: {message}")
        else:
            print(f"{'[CHECK]' if check_only else '✓'} {rel_path}: {message}")

    return stats
def main():
    """
    Main function to run the header update script.

    Command line: ``update_headers.py [ROOT_DIR] [--check]``. The first
    argument that does not start with ``--`` is the root directory, wherever
    it appears on the command line; the current working directory is used
    when none is given. With ``--check`` no file is modified.

    Returns:
        Process exit code: 0 on success; 1 when the root is not a directory,
        when any file could not be processed, or (in ``--check`` mode) when
        at least one file needs a header change.
    """
    args = sys.argv[1:]
    check_only = '--check' in args

    # Accept the root directory before or after the option flags.
    positional = [arg for arg in args if not arg.startswith('--')]
    root_dir = Path(positional[0]) if positional else Path(os.getcwd())

    if not root_dir.is_dir():
        # Without this, a mistyped path would scan nothing and report success.
        print(f"✗ Not a directory: {root_dir}")
        return 1

    mode = "Checking" if check_only else "Processing"
    print(f"{mode} files for headers in: {root_dir}")
    print("-" * 60)

    stats = find_and_process_headers(root_dir, check_only)

    print("-" * 60)

    # Print summary
    exit_code = 0
    total_changes = len(stats['updated']) + len(stats['added'])

    if check_only:
        if total_changes > 0:
            print(f"\n⚠ Found {total_changes} file(s) needing header updates:")
            for file_path, msg in stats['updated'] + stats['added']:
                print(f"  - {file_path.relative_to(root_dir)}: {msg}")
            exit_code = 1
        elif not stats['error']:
            # Only claim success when every file could actually be checked.
            print("\n✓ All Python files have correct headers!")
    else:
        if total_changes > 0:
            print(f"\n✓ Successfully processed {total_changes} file(s):")
            if stats['updated']:
                print(f"  - Updated: {len(stats['updated'])} file(s)")
            if stats['added']:
                print(f"  - Added headers: {len(stats['added'])} file(s)")
        else:
            print("\n✓ No files needed header updates.")

    # Errors are reported, and fail the run, in both modes.
    if stats['error']:
        print(f"\n✗ Errors: {len(stats['error'])} file(s)")
        for file_path, msg in stats['error']:
            print(f"  - {file_path.relative_to(root_dir)}: {msg}")
        exit_code = 1

    return exit_code
# Script entry point: run main() only when executed directly (not on import)
# and use its return value as the process exit status.
if __name__ == '__main__':
    sys.exit(main())