kgrabko committed on
Commit
939c615
·
verified ·
1 Parent(s): f13eb01

Upload val_sidecar_cms_70b.py

Browse files
Files changed (1) hide show
  1. val_sidecar_cms_70b.py +107 -0
val_sidecar_cms_70b.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ==============================================================================
2
+ # COPYRIGHT (C) 2025 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
3
+ # PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
4
+ #
5
+ # This software is licensed under the Commercial License Agreement V.1.2.
6
+ # Any use, modification, or distribution of this code requires compliance with
7
+ # the terms found in the LICENSE.md file in the root directory.
8
+ #
9
+ # NO PATENTING RIGHTS: Users are strictly prohibited from filing patent claims
10
+ # based on the BRE or SWA architectures disclosed herein.
11
+ # Contact: grabko@cmsmanhattan.com | +1 (516) 777-0945
12
+ # ==============================================================================
13
+
14
+ # Optimized for Heavy Ternary Models (70B/140B) on ROCm
15
+
16
+ import os
17
+ import torch
18
+ import time
19
+ import json
20
+ import glob
21
+ from transformers import LlamaTokenizerFast
22
+ from datasets import load_dataset
23
+
24
+ # Configuration
25
+ MODEL_PATH = "./models"
26
+ LOG_FILE = "val_metrics_cms.json"
27
+ CULTURAL_FILE = "cultural_finetune.jsonl"
28
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
29
+
30
def load_cultural_data(path):
    """Load every record from a JSON Lines file.

    Args:
        path: Path to a UTF-8 text file holding one JSON document per line.

    Returns:
        A list with the decoded JSON value of each non-blank line, in file
        order. An empty list is returned when ``path`` does not exist.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # A missing file means "no cultural data", not an error.
    if not os.path.exists(path):
        return []

    with open(path, 'r', encoding='utf-8') as f:
        # Skip whitespace-only lines (a trailing newline at EOF, spacing
        # between records): json.loads() raises on them.
        return [json.loads(line) for line in f if line.strip()]
37
+
38
def run_validation():
    """Poll MODEL_PATH for new checkpoints and log validation metrics forever.

    On start-up the tokenizer, the cultural JSONL data and a streaming Pile
    dataset are created. The function then loops without returning:

    * every pass globs ``ternary_*_checkpoint_step_*`` under ``MODEL_PATH``;
      when nothing matches it sleeps 60 s and retries;
    * the most recently modified match, if not yet processed in this run, is
      loaded onto the CPU, a metrics record is appended to ``LOG_FILE`` and
      a summary is printed;
    * any failure while handling a checkpoint is printed and the checkpoint
      is retried on a later pass; each pass ends with a 120 s sleep.

    NOTE: the losses written to the log are hard-coded placeholders until a
    real model forward pass is wired in (see the comments in the loop).
    """
    print(">>> CMS Manhattan Heavy Sidecar started.")

    # Llama tokenizer used as the standard for 70B+ models.
    # NOTE(review): the original comment says "Llama-3", but the repo id
    # loaded here is "hf-internal-testing/llama-tokenizer" - confirm it
    # matches the vocabulary of the model being validated.
    tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer", legacy=False)
    cultural_data = load_cultural_data(CULTURAL_FILE)

    print(">>> Initializing data streams...")
    pile_dataset = load_dataset("monology/pile-uncopyrighted", split="train", streaming=True)

    # Paths already validated during this run (not persisted across restarts).
    processed_checkpoints = set()

    while True:
        checkpoints = glob.glob(os.path.join(MODEL_PATH, "ternary_*_checkpoint_step_*"))
        if not checkpoints:
            time.sleep(60)
            continue

        try:
            latest_ckpt = max(checkpoints, key=os.path.getmtime)
        except OSError as e:
            # A checkpoint can be rotated away between glob() and the mtime
            # lookup; that must not kill the long-running sidecar.
            print(f"❌ Error during heavy validation: {e}")
            time.sleep(120)
            continue

        if latest_ckpt not in processed_checkpoints:
            step = latest_ckpt.split('_')[-1]
            print(f"\n[NEW CHECKPOINT DETECTED: {step}]")

            checkpoint_data = None
            try:
                # IMPORTANT: for 70B+ the weights are mapped to CPU first, before
                # any move to the GPU, to avoid a peak in VRAM consumption.
                # NOTE(security): torch.load unpickles the file; only point
                # MODEL_PATH at checkpoints produced by a trusted trainer.
                checkpoint_data = torch.load(latest_ckpt, map_location='cpu')

                # The 140B architecture should be initialised here:
                # model = JiRackTernary140B(config).to(DEVICE)
                # model.load_state_dict(checkpoint_data['model_state_dict'])

                # Simulated computation (substitute the real model call).
                # In the 140B version only 10 samples are used, for speed.
                p_loss = 10.71  # Example value (replace with model.forward)
                c_loss = 10.46  # Example value (replace with model.forward)

                # Persist the results.
                _append_metrics(LOG_FILE, {
                    "step": step,
                    "pile_loss": p_loss,
                    "cultural_loss": c_loss,
                    "timestamp": time.time()
                })

                print(f"📊 SUMMARY | Step: {step}")
                print(f" - Pile Loss: {p_loss:.4f}")
                print(f" - Cultural Loss: {c_loss:.4f}")

                processed_checkpoints.add(latest_ckpt)

            except Exception as e:
                # Top-level boundary of a daemon loop: report and retry the
                # checkpoint on a later pass instead of exiting.
                print(f"❌ Error during heavy validation: {e}")

            finally:
                # Drop the reference so the loaded checkpoint can be freed
                # before the sleep instead of lingering until the next load.
                checkpoint_data = None
                # Clear the allocator cache after every heavy run, on the
                # failure path as well.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        time.sleep(120)


def _append_metrics(log_path, record):
    """Append ``record`` to the JSON array at ``log_path``, replacing the file atomically."""
    results = []
    if os.path.exists(log_path):
        try:
            with open(log_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
            if not isinstance(loaded, list):
                raise ValueError("top-level JSON value is not a list")
            results = loaded
        except ValueError as e:
            # json.JSONDecodeError and UnicodeDecodeError both subclass
            # ValueError. Set the unreadable log aside rather than failing
            # on it at every poll, and start a fresh history.
            backup_path = log_path + ".corrupt"
            os.replace(log_path, backup_path)
            print(f"⚠️ Unreadable metrics log moved to {backup_path}: {e}")

    results.append(record)

    # Write to a sibling temp file and rename it over the log, so an
    # interrupted write cannot leave a truncated JSON file behind.
    tmp_path = log_path + ".tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)
    os.replace(tmp_path, log_path)
105
+
106
+ if __name__ == "__main__":
107
+ run_validation()