|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import os
|
| import torch
|
| import time
|
| import json
|
| import glob
|
| from transformers import LlamaTokenizerFast
|
| from datasets import load_dataset
|
|
|
|
|
# Directory watched for trainer-produced checkpoint files.
MODEL_PATH = "./models"

# JSON file where per-checkpoint validation metrics are accumulated.
LOG_FILE = "val_metrics_cms.json"

# JSONL file of cultural fine-tuning examples, one JSON object per line.
CULTURAL_FILE = "cultural_finetune.jsonl"

# Prefer GPU when present; note checkpoints are still loaded with map_location='cpu'.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
def load_cultural_data(path):
    """Load a JSONL file into a list of parsed records.

    Args:
        path: Path to a .jsonl file containing one JSON object per line.

    Returns:
        list: Parsed objects in file order; an empty list if the file
        does not exist.
    """
    data = []
    if os.path.exists(path):
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                # Robustness fix: skip blank/whitespace-only lines
                # (e.g. a trailing newline) which previously made
                # json.loads raise JSONDecodeError.
                stripped = line.strip()
                if stripped:
                    data.append(json.loads(stripped))
    return data
|
|
|
def run_validation():
    """Sidecar loop: watch MODEL_PATH for new ternary checkpoints and append
    validation metrics to LOG_FILE.

    Runs forever (intended to be launched next to the trainer as a separate
    process); each newly detected checkpoint is processed once, then the
    loop sleeps 120s between scans (60s while no checkpoint exists yet).

    NOTE(review): ``p_loss`` / ``c_loss`` below are hard-coded placeholders —
    the actual heavy evaluation over the pile/cultural data is not implemented
    yet, which is why ``tokenizer``, ``cultural_data`` and ``pile_dataset``
    are initialized but unused.
    """
    print(">>> CMS Manhattan Heavy Sidecar started.")

    # Prepared for the (future) real evaluation pass.
    tokenizer = LlamaTokenizerFast.from_pretrained(
        "hf-internal-testing/llama-tokenizer", legacy=False
    )
    cultural_data = load_cultural_data(CULTURAL_FILE)

    # Fix: was an f-string with no placeholders.
    print(">>> Initializing data streams...")
    pile_dataset = load_dataset(
        "monology/pile-uncopyrighted", split="train", streaming=True
    )

    # Checkpoint paths already evaluated, so each is processed exactly once.
    processed_checkpoints = set()

    while True:
        checkpoints = glob.glob(
            os.path.join(MODEL_PATH, "ternary_*_checkpoint_step_*")
        )
        if not checkpoints:
            # Nothing to validate yet; poll faster than the steady-state loop.
            time.sleep(60)
            continue

        # Most recently written checkpoint wins.
        latest_ckpt = max(checkpoints, key=os.path.getmtime)

        if latest_ckpt not in processed_checkpoints:
            # Filename ends in "...step_<N>", so the last '_' field is the step.
            step = latest_ckpt.split('_')[-1]
            print(f"\n[NEW CHECKPOINT DETECTED: {step}]")

            try:
                # SECURITY: torch.load unpickles arbitrary objects — only load
                # checkpoints produced by the trusted trainer process.
                checkpoint_data = torch.load(latest_ckpt, map_location='cpu')

                # Placeholder metrics until the real evaluation pass lands.
                p_loss = 10.71
                c_loss = 10.46

                # Release the (potentially large) checkpoint before file I/O
                # so its memory is not held across the rest of the cycle.
                del checkpoint_data

                # Append this cycle's metrics to the existing log, if any.
                results = []
                if os.path.exists(LOG_FILE):
                    with open(LOG_FILE, 'r') as f:
                        results = json.load(f)

                results.append({
                    "step": step,
                    "pile_loss": p_loss,
                    "cultural_loss": c_loss,
                    "timestamp": time.time()
                })

                with open(LOG_FILE, 'w') as f:
                    json.dump(results, f, indent=4)

                print(f"📊 SUMMARY | Step: {step}")
                print(f" - Pile Loss: {p_loss:.4f}")
                print(f" - Cultural Loss: {c_loss:.4f}")

                processed_checkpoints.add(latest_ckpt)

                # Bug fix: empty_cache() unconditionally can raise on
                # CPU-only hosts; only call it when CUDA is available.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            except Exception as e:
                # Best-effort loop: log and keep watching rather than die.
                print(f"❌ Error during heavy validation: {e}")

        time.sleep(120)
|
|
|
if __name__ == "__main__":
    # Entry point: start the (blocking, infinite) checkpoint-watching loop.
    run_validation()