# O-TITANS-Gemma3 / otitans_surgery.py
# Uploaded by paperscarecrow in commit f4f6207 ("Upload 7 files", verified).
import os
import sys
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
# 1. Resolve where this script lives and which checkpoint to load.
# PROJECT_DIR: absolute path of the directory containing this file.
PROJECT_DIR = os.path.dirname(os.path.abspath(__file__))
# MODEL_PATH: passed to AutoModelForCausalLM.from_pretrained() in main(). Despite the
# name, this value is a Hugging Face Hub repo id ("org/name"), not a filesystem path.
MODEL_PATH = "mlabonne/gemma-3-12b-it-abliterated"
# Ensure the sibling otitans_core module is importable regardless of the current
# working directory. This must run before the import on the next line.
# NOTE(review): append() places PROJECT_DIR at the END of sys.path. When the file is run
# directly, Python already puts the script's directory first, so this mostly matters when
# the module is imported from elsewhere (another otitans_core on the path would then win).
sys.path.append(PROJECT_DIR)
from otitans_core import OLoRALinear
def inject_orthogonal_memory(
    model: nn.Module,
    target_modules=("q_proj", "v_proj"),
    rank: int = 8,
    alpha: float = 16.0,
) -> int:
    """
    Phase 4: The Surgeon's Scalpel.

    Walk every submodule of ``model`` and replace each ``nn.Linear`` whose
    qualified name matches one of ``target_modules`` with an ``OLoRALinear``
    wrapping the original layer. ``model`` is modified in place.

    Args:
        model: Module tree to operate on (modified in place).
        target_modules: Iterable of name suffixes to match. A module matches
            when its qualified name equals a target or ends with ``"." + target``.
            So ``"q_proj"`` matches ``"layers.0.self_attn.q_proj"`` but not
            ``"layers.0.self_attn.kq_proj"``. Dotted suffixes such as
            ``"self_attn.q_proj"`` are accepted too. A single bare string is
            treated as one target.
        rank: Forwarded unchanged to ``OLoRALinear`` as ``rank``.
        alpha: Forwarded unchanged to ``OLoRALinear`` as ``alpha``.

    Returns:
        Number of layers that were replaced (0 if nothing matched).
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(target_modules, str):
        target_modules = (target_modules,)
    targets = tuple(target_modules)
    # str.endswith accepts a tuple; the leading dot enforces a module-path boundary.
    dotted_suffixes = tuple("." + target for target in targets)

    injected_count = 0
    # Snapshot named_modules() into a list so that replacing children below
    # does not interfere with the traversal.
    for name, module in list(model.named_modules()):
        if not (name in targets or name.endswith(dotted_suffixes)):
            continue
        # Only plain Linear layers are wrapped; a matching name on any other
        # module type (e.g. a norm layer) is left untouched.
        if not isinstance(module, nn.Linear):
            continue
        wrapped_layer = OLoRALinear(base_layer=module, rank=rank, alpha=alpha)
        # rpartition yields parent_name == "" for a direct child of `model`.
        # get_submodule("") returns `model` itself, so top-level layers are
        # re-attached to the right parent.
        parent_name, _, child_name = name.rpartition(".")
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, child_name, wrapped_layer)
        injected_count += 1
    return injected_count
def main():
    """
    Load the base model, graft OLoRALinear adapters onto its attention
    projections, and print a trainable-parameter report.

    Raises:
        RuntimeError: If the surgery wrapped no layers. Without this check the
            script would report a "successful" graft with 0 trainable parameters.
    """
    print(f"[*] Opening the Forge. Loading base weights from:\n {MODEL_PATH}")
    # Load in bfloat16 (2 bytes per weight, half of float32) to preserve VRAM headroom.
    # NOTE(review): trust_remote_code=True lets from_pretrained execute Python code shipped
    # inside the model repository. Recent transformers releases support Gemma 3 natively,
    # so this flag is likely unnecessary. Confirm against the installed version and drop
    # it if so.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        trust_remote_code=True,
    )
    print("[*] Base architecture loaded. Freezing all 12 Billion primary parameters...")
    # Enforce a total freeze before making structural changes. This may duplicate work done
    # inside OLoRALinear (presumably it freezes its wrapped layer; verify in otitans_core).
    # Freezing here means anything left trainable afterwards was enabled by the injection.
    for param in model.parameters():
        param.requires_grad = False
    print("[*] Commencing OTITANS surgery. Injecting Orthogonal LoRA into Attention vectors...")
    # Target the Query and Value projections.
    # Design rationale: in recurrent memory theory, Queries ask for memory, Values store it.
    targets = ["q_proj", "v_proj"]
    replaced_layers = inject_orthogonal_memory(model, target_modules=targets, rank=8, alpha=16.0)
    # Fail loudly when nothing matched (e.g. fused QKV projections or different layer names).
    # Otherwise the report below would silently show zero trainable parameters.
    if replaced_layers == 0:
        raise RuntimeError(
            f"OTITANS surgery found no nn.Linear modules matching {targets}; "
            "inspect model.named_modules() for the actual projection names."
        )
    print(f"[*] Surgery complete. Successfully grafted {replaced_layers} OLoRALinear modules.")
    # Verify exactly how much of the network is actually trainable now.
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print("-" * 50)
    print(f"[*] Total Parameters: {total_params:,}")
    print(f"[*] Trainable OTITANS: {trainable_params:,}")
    print(f"[*] Optimization Footprint: {(trainable_params / total_params) * 100:.4f}%")
    print("-" * 50)


if __name__ == "__main__":
    main()