File size: 9,364 Bytes
873f551
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
#!/usr/bin/env python3
"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation

Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
Licensed under a modified MIT license
"""

import os
import sys
from pathlib import Path

# Define the standard header for the project.
# This exact text (a module docstring wrapped in a triple-single-quoted
# literal) is written into files that lack a header; a file containing its
# text verbatim is treated as already compliant.
STANDARD_HEADER = '''"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation

Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis
Licensed under a modified MIT license
"""'''

# Old headers that should be replaced with STANDARD_HEADER.
# The single entry is an empty docstring: opening triple double-quotes, one
# blank line, closing triple double-quotes.
OLD_HEADERS = [
    '''"""

"""'''
]


def should_skip_file(file_path, root_dir=None):
    """
    Determine if a file should be skipped for header addition.

    A file is skipped when one of its path components is an excluded
    directory name (VCS metadata, caches, virtualenvs, build output), or when
    it is an ``__init__.py`` whose stripped content is shorter than 50
    characters.

    Args:
        file_path: Path (or str) to check
        root_dir: Optional directory the scan started from. When given and
            file_path lies below it, only the components below root_dir are
            compared with the excluded names, so a project that itself lives
            under e.g. ``/work/build/`` is not skipped wholesale. The default
            (None) checks every component of file_path.

    Returns:
        True if the file should be skipped, False otherwise
    """
    skip_dirs = frozenset({
        '.git', '__pycache__', '.pytest_cache', 'venv', 'env', '.tox',
        'build', 'dist', '.eggs',
    })
    file_path = Path(file_path)

    parts = file_path.parts
    if root_dir is not None:
        try:
            parts = file_path.relative_to(root_dir).parts
        except ValueError:
            # Not below root_dir: fall back to checking every component.
            pass

    # Skip if in excluded directory
    if any(part in skip_dirs for part in parts):
        return True

    # Skip __init__.py files that are typically minimal (likely just imports)
    if file_path.name == '__init__.py':
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                content = f.read()
        except (OSError, UnicodeDecodeError):
            # Unreadable here: don't skip, so the caller attempts the file and
            # reports the failure instead of it disappearing silently.
            return False
        return len(content.strip()) < 50

    return False


def has_header(content):
    """
    Check if content already has the standard header or a valid variant.

    A file is accepted when it contains STANDARD_HEADER verbatim anywhere, or
    when its leading docstring (found after any blank or ``#`` comment lines,
    e.g. a shebang or encoding line) contains the key PRIMA header elements
    within its first 15 lines.

    Args:
        content: File content to check

    Returns:
        True if the file has the standard header or acceptable variant, False otherwise
    """
    # Check for exact match
    if STANDARD_HEADER.strip() in content:
        return True

    # Check for header with additional content (like in sort.py): the leading
    # docstring must contain the key elements of STANDARD_HEADER.
    lines = content.split('\n')

    # The header may sit below a shebang / encoding line or other leading
    # comments (add_or_update_header keeps a shebang above it), so start at
    # the first line that is neither blank nor a comment.
    start = 0
    while start < len(lines):
        stripped = lines[start].strip()
        if stripped and not stripped.startswith('#'):
            break
        start += 1

    header_lines = lines[start:start + 15]
    if len(header_lines) < 3 or not header_lines[0].strip().startswith('"""'):
        return False

    header_section = '\n'.join(header_lines)
    # Must match the text of STANDARD_HEADER -- keep in sync when it changes.
    required_elements = (
        'PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation',
        'Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis',
        'Licensed under a modified MIT license',
    )
    return all(elem in header_section for elem in required_elements)


def needs_header_update(content, old_headers=None):
    """
    Check if content has an old header that needs updating.

    Args:
        content: File content to check
        old_headers: Iterable of outdated header strings to look for.
            Defaults to the module-level OLD_HEADERS.

    Returns:
        The matched old header text (with the pattern's surrounding whitespace
        stripped, i.e. exactly the text found in content), or None.
    """
    if old_headers is None:
        old_headers = OLD_HEADERS
    for old_header in old_headers:
        candidate = old_header.strip()
        # An empty pattern would "match" every file (and str.replace('') would
        # insert text between every character), so ignore it. Return the text
        # actually searched for, so a caller's content.replace(result, ...)
        # finds exactly what was matched here.
        if candidate and candidate in content:
            return candidate
    return None


def _header_insert_index(content, lines):
    """Return the index in *lines* before which the standard header is inserted.

    Lines before the returned index stay above the header.

    Raises:
        SyntaxError, ValueError: if the file mentions ``from __future__`` but
            cannot be parsed, so no safe position can be determined.
    """
    import ast
    import re

    if 'from __future__' in content:
        # A future statement may only be preceded by the module docstring,
        # comments, blank lines and other future statements. A header inserted
        # between two of them (or after an existing docstring but before one)
        # is a SyntaxError, so place it after the LAST future statement. ast
        # also copes with parenthesised, multi-line future imports.
        tree = ast.parse(content)
        future_ends = [
            getattr(node, 'end_lineno', None) or node.lineno
            for node in tree.body
            if isinstance(node, ast.ImportFrom) and node.module == '__future__'
        ]
        if future_ends:
            # ast line numbers are 1-based, so the last line number of the
            # statement equals the 0-based index of the line after it.
            return max(future_ends)

    insert_at = 0
    # If file starts with shebang, keep it at the top.
    if lines and lines[0].startswith('#!'):
        insert_at = 1
    # PEP 263: an encoding declaration is only honoured on line 1 or 2, so it
    # must stay above the header as well.
    coding_re = re.compile(r'^[ \t\f]*#.*?coding[:=][ \t]*[-_.a-zA-Z0-9]+')
    for index, line in enumerate(lines[:2]):
        if coding_re.match(line):
            insert_at = max(insert_at, index + 1)
    return insert_at


def _insert_header(content):
    """Return *content* with STANDARD_HEADER and a blank line inserted at a safe spot."""
    lines = content.split('\n')
    index = _header_insert_index(content, lines)
    return '\n'.join(lines[:index] + [STANDARD_HEADER, ''] + lines[index:])


def add_or_update_header(file_path, check_only=False):
    """
    Add or update the header in a single file.

    Args:
        file_path: Path to the file to update
        check_only: If True, only check without modifying

    Returns:
        Tuple of (status, message) where status is 'ok', 'updated', 'added', or 'error'.
        In check mode 'updated' / 'added' mean the change would be made.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # Check if file already has the correct header
        if has_header(content):
            return ('ok', 'Already has correct header')

        old_header = needs_header_update(content)
        if old_header:
            # Replace only the first occurrence: the old-header pattern is
            # generic enough to also match unrelated text further down.
            new_content = content.replace(old_header, STANDARD_HEADER, 1)
            result = ('updated', 'Replaced old header with standard header')
        elif len(content.strip()) > 10:
            # Computed in check mode too, so a file whose header position
            # cannot be determined is reported as an error in both modes.
            new_content = _insert_header(content)
            result = ('added', 'Added standard header')
        else:
            return ('ok', 'Skipped (file too short or empty)')

        if not check_only:
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(new_content)
        return result

    except Exception as e:
        # Per-file boundary: report the failure and let the scan continue.
        return ('error', f"Error processing file: {e}")


def find_and_process_headers(root_dir, check_only=False):
    """
    Find and process all Python files below root_dir.

    Each file not excluded by should_skip_file() is passed to
    add_or_update_header(). Files that were (or in check mode would be)
    changed, and files that failed, are printed as they are processed.

    Args:
        root_dir: Root directory to search from
        check_only: If True, only check without modifying files

    Returns:
        Dictionary mapping 'ok', 'updated', 'added' and 'error' to lists of
        (path, message) tuples
    """
    root_path = Path(root_dir)
    stats = {status: [] for status in ('ok', 'updated', 'added', 'error')}

    # rglob yields files in filesystem order; sort so the report (and the
    # order in which files are modified) is identical across runs/machines.
    for py_file in sorted(root_path.rglob('*.py')):
        # NOTE(review): should_skip_file() receives the full path here, so
        # directories ABOVE root_path (e.g. a checkout under .../build/) are
        # also matched against its excluded names -- confirm this is intended.
        if should_skip_file(py_file):
            continue

        status, message = add_or_update_header(py_file, check_only)
        stats[status].append((py_file, message))

        if status == 'error':
            print(f"✗ {py_file.relative_to(root_path)}: {message}")
        elif status in ('updated', 'added'):
            marker = '[CHECK]' if check_only else '✓'
            print(f"{marker} {py_file.relative_to(root_path)}: {message}")

    return stats


def main():
    """
    Run the header update script from the command line.

    Usage: ``<script> [ROOT_DIR] [--check]``. The flag and the optional root
    directory may appear in any order; ROOT_DIR defaults to the current
    working directory.

    Returns:
        Process exit status: 0 on success; 1 if any file could not be
        processed or (with --check) files need header changes; 2 if ROOT_DIR
        is not a directory.
    """
    args = sys.argv[1:]
    check_only = '--check' in args

    # The first non-flag argument is the root directory, wherever it appears
    # relative to --check (``--check src`` and ``src --check`` both work).
    positional = [arg for arg in args if not arg.startswith('--')]
    root_dir = Path(positional[0]) if positional else Path(os.getcwd())

    # Without this a mistyped path scans nothing and reports success.
    if not root_dir.is_dir():
        print(f"✗ Not a directory: {root_dir}")
        return 2

    mode = "Checking" if check_only else "Processing"
    print(f"{mode} files for headers in: {root_dir}")
    print("-" * 60)

    stats = find_and_process_headers(root_dir, check_only)

    print("-" * 60)

    # Print summary
    total_changes = len(stats['updated']) + len(stats['added'])

    if check_only:
        if total_changes > 0:
            print(f"\n⚠ Found {total_changes} file(s) needing header updates:")
            for file_path, msg in stats['updated'] + stats['added']:
                print(f"  - {file_path.relative_to(root_dir)}: {msg}")
        elif not stats['error']:
            print("\n✓ All Python files have correct headers!")
    elif total_changes > 0:
        print(f"\n✓ Successfully processed {total_changes} file(s):")
        if stats['updated']:
            print(f"  - Updated: {len(stats['updated'])} file(s)")
        if stats['added']:
            print(f"  - Added headers: {len(stats['added'])} file(s)")
    else:
        print("\n✓ No files needed header updates.")

    # Report failures in both modes, so --check cannot pass while some files
    # were never actually inspected.
    if stats['error']:
        print(f"\n✗ Errors: {len(stats['error'])} file(s)")
        for file_path, msg in stats['error']:
            print(f"  - {file_path.relative_to(root_dir)}: {msg}")
        return 1

    return 1 if check_only and total_changes > 0 else 0


if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    sys.exit(main())