import torch
import torch.nn as nn


def convert_ptv3_to_bi_ptv3(
    model,
    verbose=False,
    *,
    exclusion_keywords=("embedding", "stem", "seg_head"),
    layer_cls=None,
):
    """Replace eligible ``nn.Linear`` layers of a PTV3 model with binary layers.

    The module tree is walked recursively. Every ``nn.Linear`` child whose
    dotted name does not contain one of ``exclusion_keywords`` is swapped, in
    place, for an instance of ``layer_cls`` built with the same
    ``in_features`` / ``out_features`` / bias setting. The old weight (and
    bias, if present) are copied into the new layer, and the new layer is
    placed on the same device, dtype and train/eval mode as the layer it
    replaces.

    Note from the original author: this conversion is intended to stay
    compatible with checkpoints that contain LSR scales but no ``ema_offset``.

    Args:
        model: Root module to convert. It is modified in place. The root
            itself is never replaced, only its descendants.
        verbose: When true, print one line per replaced layer and one line
            per module skipped because of the exclusion rules.
        exclusion_keywords: Substrings matched against each module's dotted
            name; a match keeps the module (and, because descendants inherit
            the name prefix, everything below it) untouched. A single string
            is treated as one keyword.
        layer_cls: Replacement layer class, called as
            ``layer_cls(in_features=..., out_features=..., bias=...)``.
            Defaults to ``BiLinearLSR`` from the sibling ``binary_layers``
            module. Assumed to be an ``nn.Module`` exposing ``weight`` and
            ``bias`` like ``nn.Linear`` does.

    Returns:
        The same ``model`` object, after conversion.
    """
    if layer_cls is None:
        # Imported lazily so the project-local module is only needed when the
        # default replacement layer is actually used.
        from .binary_layers import BiLinearLSR

        layer_cls = BiLinearLSR

    # A bare string would otherwise be iterated character by character.
    if isinstance(exclusion_keywords, str):
        exclusion_keywords = (exclusion_keywords,)
    else:
        exclusion_keywords = tuple(exclusion_keywords)

    print("Starting the conversion from PTV3 to Bi-PTV3...")

    def _build_replacement(old_layer):
        """Create a ``layer_cls`` twin of ``old_layer`` carrying its parameters."""
        new_layer = layer_cls(
            in_features=old_layer.in_features,
            out_features=old_layer.out_features,
            bias=(old_layer.bias is not None),
        )
        # Follow the source layer's placement: without this, converting a
        # model that already lives on an accelerator or in half/double
        # precision would leave the new layers on CPU / float32.
        new_layer.to(
            device=old_layer.weight.device,
            dtype=old_layer.weight.dtype,
        )
        with torch.no_grad():
            new_layer.weight.copy_(old_layer.weight)
            if old_layer.bias is not None:
                new_layer.bias.copy_(old_layer.bias)
        # Freshly built modules start in training mode; mirror the original.
        new_layer.train(old_layer.training)
        return new_layer

    def _recursive_replace(module, parent_name=""):
        """Swap eligible direct children of ``module``, then descend."""
        # Snapshot the children: the loop body reassigns attributes on
        # ``module`` while iterating.
        for child_name, child_module in list(module.named_children()):
            full_name = f"{parent_name}.{child_name}" if parent_name else child_name
            is_excluded = any(keyword in full_name for keyword in exclusion_keywords)

            if is_excluded or not isinstance(child_module, nn.Linear):
                if is_excluded and verbose:
                    print(f"INFO: Skipping module '{full_name}' due to exclusion rules.")
                _recursive_replace(child_module, parent_name=full_name)
                continue

            setattr(module, child_name, _build_replacement(child_module))
            if verbose:
                print(f"SUCCESS: Replaced '{full_name}' with BiLinearLSR.")

    _recursive_replace(model)

    print("Conversion to Bi-PTV3 complete! (Compatible with checkpoint without ema_offset)")
    return model