"""
Training Monitor - Check progress and evaluate completed models.
"""
import asyncio
import json
import sys

# Load credentials (e.g. the Tinker API key) from .env before any client is created.
from dotenv import load_dotenv
load_dotenv()

import tinker
from tinker import types
from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import get_tokenizer
BASE_MODEL = "meta-llama/Llama-3.1-8B"
VALID_CATEGORIES = {
"company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts",
"company.business_priorities", "company.tools_config", "company.performance_context",
"user.communication_style", "user.strategic_approach", "user.role_context",
"user.workflow_patterns", "user.session_history", "user.interaction_preferences",
"none"
}
def list_training_runs():
"""List all training runs and their checkpoints."""
service_client = tinker.ServiceClient()
rest_client = service_client.create_rest_client()
runs = rest_client.list_training_runs().result()
print("=" * 70)
print("TRAINING RUNS")
print("=" * 70)
    for run in runs.training_runs[:10]:  # inspect only the first 10 runs returned
ckpts = rest_client.list_checkpoints(run.training_run_id).result()
# Categorize checkpoints
sft_ckpts = [c for c in ckpts.checkpoints if 'sft' in c.checkpoint_id]
rl_ckpts = [c for c in ckpts.checkpoints if 'rl_' in c.checkpoint_id]
print(f"\nRun: {run.training_run_id}")
print(f" Last request: {run.last_request_time}")
print(f" SFT checkpoints: {len(sft_ckpts)}")
print(f" RL checkpoints: {len(rl_ckpts)}")
if rl_ckpts:
# Find the latest RL checkpoint
            latest = max(rl_ckpts, key=lambda c: c.time)
print(f" Latest RL: {latest.checkpoint_id}")
# Check if it's a final checkpoint
if 'final' in latest.checkpoint_id:
print(f" STATUS: RL COMPLETE")
print(f" Final checkpoint: tinker://{run.training_run_id}/{latest.checkpoint_id}")
async def quick_eval(checkpoint_path: str, n_samples: int = 20):
"""Quick evaluation of a checkpoint."""
service_client = tinker.ServiceClient()
tokenizer = get_tokenizer(BASE_MODEL)
renderer = renderers.get_renderer(name="llama3", tokenizer=tokenizer)
# Load test data
with open("training/processed_data/test_data.json", "r") as f:
test_data = json.load(f)
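    # Expected record shape, inferred from the fields read below (values illustrative):
    #   {"messages": [{"role": "user", "content": "..."},
    #                 {"role": "assistant", "content": "user.role_context"}],
    #    "categories": ["user.role_context"]}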
print(f"\nEvaluating: {checkpoint_path}")
print(f"Samples: {n_samples}")
sampling_client = service_client.create_sampling_client(model_path=checkpoint_path)
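    # Stop at the renderer's end-of-turn markers so sampling ends right after
    # the predicted category list.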
stop_sequences = renderer.get_stop_sequences()
correct = 0
total = 0
for example in test_data[:n_samples]:
gold = example.get("categories", [])
messages = example.get("messages", [])
prompt_messages = [m for m in messages if m.get("role") != "assistant"]
if not prompt_messages:
continue
prompt = renderer.build_generation_prompt(prompt_messages)
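        # Near-greedy decoding (temperature 0.1) keeps the eval close to
        # deterministic; 50 tokens is ample for a comma-separated category list.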
params = types.SamplingParams(max_tokens=50, temperature=0.1, stop=stop_sequences)
result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
response, success = renderer.parse_response(result.sequences[0].tokens)
predicted_text = response["content"] if success else ""
        predicted_set = {c.strip().lower() for c in predicted_text.split(",")
                         if c.strip().lower() in VALID_CATEGORIES}
        gold_set = {c.lower() for c in gold}
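        # "Any match" scoring: the sample counts as correct if at least one
        # predicted category appears in the gold label set.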
if predicted_set & gold_set:
correct += 1
total += 1
accuracy = correct / total if total > 0 else 0
print(f"Any Match Accuracy: {accuracy:.1%} ({correct}/{total})")
return accuracy
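
def exact_match_accuracy(predictions: list[set[str]], golds: list[set[str]]) -> float:
    """Stricter alternative to the any-match metric above (a sketch, not wired
    into quick_eval): a sample only counts when the predicted and gold category
    sets are identical."""
    if not predictions:
        return 0.0
    return sum(p == g for p, g in zip(predictions, golds)) / len(predictions)
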
def find_best_checkpoint():
"""Find the best completed RL checkpoint."""
service_client = tinker.ServiceClient()
rest_client = service_client.create_rest_client()
runs = rest_client.list_training_runs().result()
best_rl_checkpoint = None
best_sft_checkpoint = None
for run in runs.training_runs:
ckpts = rest_client.list_checkpoints(run.training_run_id).result()
for ckpt in ckpts.checkpoints:
if 'rl_final' in ckpt.checkpoint_id:
path = f"tinker://{run.training_run_id}/{ckpt.checkpoint_id}"
if best_rl_checkpoint is None or ckpt.time > best_rl_checkpoint[1]:
best_rl_checkpoint = (path, ckpt.time)
if 'sft_final_sampler' in ckpt.checkpoint_id:
path = f"tinker://{run.training_run_id}/{ckpt.checkpoint_id}"
if best_sft_checkpoint is None or ckpt.time > best_sft_checkpoint[1]:
best_sft_checkpoint = (path, ckpt.time)
return best_sft_checkpoint, best_rl_checkpoint
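# Example return value (run and checkpoint ids are illustrative):
#   (("tinker://run-abc/sft_final_sampler", <time>),
#    ("tinker://run-abc/rl_final", <time>))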
async def main():
    if len(sys.argv) > 1 and sys.argv[1] == "eval":
# Evaluate mode
sft_ckpt, rl_ckpt = find_best_checkpoint()
print("=" * 70)
print("CHECKPOINT EVALUATION")
print("=" * 70)
if sft_ckpt:
print(f"\nBest SFT: {sft_ckpt[0]}")
await quick_eval(sft_ckpt[0], n_samples=50)
if rl_ckpt:
print(f"\nBest RL: {rl_ckpt[0]}")
await quick_eval(rl_ckpt[0], n_samples=50)
else:
# List mode
list_training_runs()
sft_ckpt, rl_ckpt = find_best_checkpoint()
print("\n" + "=" * 70)
print("BEST CHECKPOINTS")
print("=" * 70)
if sft_ckpt:
print(f"\nSFT: {sft_ckpt[0]}")
print(f" Time: {sft_ckpt[1]}")
if rl_ckpt:
print(f"\nRL: {rl_ckpt[0]}")
print(f" Time: {rl_ckpt[1]}")
print("\nTo evaluate, run: python training/monitor.py eval")
if __name__ == "__main__":
asyncio.run(main())