|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import os
|
| import torch
|
| import time
|
| import json
|
| import glob
|
| import matplotlib.pyplot as plt
|
| from transformers import AutoTokenizer
|
| from datasets import load_dataset
|
|
|
|
|
|
|
| from JiRackTernaryPyTorch_70b import JiRackTernary70B as JiRackTernary7B
|
| from JiRackTernaryPyTorch_70b import JiRackTernaryConfig
|
|
|
|
|
# Directory the training process writes *.pt checkpoints into.
BASE_CHECKPOINT_DIR = "checkpoints_jirack_7b"

# JSON file where per-checkpoint validation metrics accumulate.
LOG_FILE = "val_metrics_7b_cms.json"

# Output image for the loss-curve plot regenerated after each validation.
PLOT_FILE = "training_progress_7b.png"

# JSONL file of cultural fine-tune examples (one JSON object per line).
CULTURAL_FILE = "cultural_finetune.jsonl"

# Number of SlimPajama streaming samples evaluated per checkpoint.
VAL_SAMPLES_SLIM = 100

# Tokenizer truncation length for every validation example.
MAX_LENGTH = 2048
|
|
|
def load_cultural_questions(file_path):
    """Load cultural-evaluation records from a JSONL file.

    Args:
        file_path: Path to a JSON-lines file; each non-empty line is one
            JSON object (the validation loop later reads its "text" field).

    Returns:
        list: Decoded objects in file order, or an empty list when the
        file does not exist.
    """
    questions = []
    if os.path.exists(file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # Skip blank lines so a trailing newline or spacer line
                # doesn't crash json.loads('') mid-load.
                if line:
                    questions.append(json.loads(line))
    return questions
|
|
|
def update_plot(history):
    """Render the running loss curves to PLOT_FILE on a log-y axis.

    Does nothing until at least two history entries exist, so the line
    plot always has points to connect.
    """
    if len(history) < 2:
        return

    xs = [int(entry['step']) for entry in history]
    # label -> (history key, matplotlib format string); insertion order
    # preserves the original plot order (Slim first, Cultural second).
    series = {
        'SlimPajama Loss': ('slim_loss', 'b-o'),
        'Cultural Loss': ('cultural_loss', 'r-s'),
    }

    plt.figure(figsize=(10, 6))
    for label, (key, fmt) in series.items():
        plt.plot(xs, [entry[key] for entry in history], fmt, label=label)

    plt.title('JiRackTernary 7B: Training Dynamics')
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.yscale('log')
    plt.legend()
    plt.grid(alpha=0.3)
    plt.savefig(PLOT_FILE)
    plt.close()
|
|
|
def main():
    """Watch BASE_CHECKPOINT_DIR and validate the newest checkpoint in a loop.

    Forever: find the most recently modified *.pt checkpoint, and if it has
    not been validated yet, load it into a freshly built model, compute the
    average loss on a slice of streamed SlimPajama plus on the cultural
    question set, append the results to LOG_FILE, and refresh the plot.
    Sleeps between polls; never returns.
    """
    print(">>> Starting Validator for 7B Model...")

    # Llama-3 tokenizer has no pad token by default; reuse EOS so
    # tokenization with padding-aware code paths doesn't fail.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
    if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token

    cultural_questions = load_cultural_questions(CULTURAL_FILE)
    # Streaming dataset: each enumerate() below starts a fresh iterator,
    # so every checkpoint is scored on the same leading samples.
    val_stream = load_dataset("cerebras/SlimPajama-627B", split="train", streaming=True)

    # Checkpoint paths already validated this session (in-memory only;
    # restarts will re-validate the latest checkpoint).
    processed_checkpoints = set()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    while True:
        ckpt_files = glob.glob(os.path.join(BASE_CHECKPOINT_DIR, "*.pt"))
        if not ckpt_files:
            # Nothing written yet — poll again shortly.
            time.sleep(30)
            continue

        # Newest checkpoint by file modification time.
        last_ckpt = max(ckpt_files, key=os.path.getmtime)

        if last_ckpt not in processed_checkpoints:
            print(f"\n[VALIDATING 7B: {last_ckpt}]")
            try:
                # Model is rebuilt from scratch for every checkpoint so no
                # state leaks between validations.
                # NOTE(review): config is hard-coded to 32 layers / 4096
                # hidden — must match what the trainer saved; confirm.
                config = JiRackTernaryConfig(
                    num_hidden_layers=32,
                    hidden_size=4096,
                    num_attention_heads=32
                )
                model = JiRackTernary7B(config)

                checkpoint = torch.load(last_ckpt, map_location='cpu')
                model.load_state_dict(checkpoint['model_state_dict'])
                model.to(device)
                model.eval()

                # --- SlimPajama validation: mean LM loss over the first
                # VAL_SAMPLES_SLIM streamed examples. ---
                s_loss = 0
                for i, ex in enumerate(val_stream):
                    if i >= VAL_SAMPLES_SLIM: break
                    tok = tokenizer(ex["text"], truncation=True, max_length=MAX_LENGTH, return_tensors="pt").to(device)
                    with torch.no_grad():
                        s_loss += model(tok.input_ids, labels=tok.input_ids).loss.mean().item()
                avg_slim = s_loss / VAL_SAMPLES_SLIM

                # --- Cultural validation: mean LM loss over the full
                # cultural question set (0 when the file was absent). ---
                c_loss = 0
                for ex in cultural_questions:
                    tok = tokenizer(ex["text"], truncation=True, max_length=MAX_LENGTH, return_tensors="pt").to(device)
                    with torch.no_grad():
                        c_loss += model(tok.input_ids, labels=tok.input_ids).loss.mean().item()
                avg_cultural = c_loss / len(cultural_questions) if cultural_questions else 0

                # Prefer the step recorded in the checkpoint; fall back to
                # parsing it out of the filename (e.g. "..._1000.pt").
                step_val = checkpoint.get('step', last_ckpt.split('_')[-1].replace('.pt', ''))
                history = []
                if os.path.exists(LOG_FILE):
                    with open(LOG_FILE, 'r') as f: history = json.load(f)

                history.append({
                    "step": step_val,
                    "slim_loss": avg_slim,
                    "cultural_loss": avg_cultural,
                    "timestamp": time.time()
                })

                # Rewrite the full history each time and refresh the plot.
                with open(LOG_FILE, 'w') as f: json.dump(history, f, indent=4)
                update_plot(history)

                print(f"📊 7B Stats | Step: {step_val} | Slim: {avg_slim:.4f} | Cultural: {avg_cultural:.4f}")

                # Only mark processed on success; a failed attempt will be
                # retried on the next poll.
                processed_checkpoints.add(last_ckpt)
                del model
                torch.cuda.empty_cache()

            except Exception as e:
                # Best-effort: log and keep watching rather than crash the
                # long-running validator.
                print(f"❌ 7B Val Error: {e}")

        time.sleep(60)
|
|
|
# Run the validator loop only when executed as a script.
if __name__ == "__main__":
    main()