Spaces:
Sleeping
Sleeping
Upload training/eval_baseline.py with huggingface_hub
Browse files- training/eval_baseline.py +320 -0
training/eval_baseline.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
eval_baseline.py — LLM baseline evaluation (no fine-tuning)
|
| 3 |
+
Loads MODEL_NAME in 4-bit, evaluates on all EVAL_TASKS, saves results to
|
| 4 |
+
runs/baselines/MODEL-DATETIME/ and uploads to HF Hub.
|
| 5 |
+
|
| 6 |
+
Usage (HF Spaces / Colab with GPU):
|
| 7 |
+
MODEL_NAME=Qwen/Qwen2.5-7B-Instruct python eval_baseline.py
|
| 8 |
+
HF_TOKEN=hf_... MODEL_NAME=meta-llama/Meta-Llama-3-8B-Instruct python eval_baseline.py
|
| 9 |
+
"""
|
| 10 |
+
import os, json, re, datetime, time
|
| 11 |
+
import requests
|
| 12 |
+
import matplotlib
|
| 13 |
+
matplotlib.use('Agg')
|
| 14 |
+
import matplotlib.pyplot as plt
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
# Base URL of the environment Space; this script calls its /health, /reset
# and /step endpoints.
ENV_URL = 'https://pathikreet-ap-clerk-env.hf.space'
# Hub id of the model to evaluate; override with the MODEL_NAME env var.
MODEL_NAME = os.environ.get('MODEL_NAME', 'Qwen/Qwen2.5-7B-Instruct')
SEEDS = [42, 99, 7]  # 3 seeds per task → mean score per task
EVAL_TASKS = [
    'easy_perfect_match', 'easy_no_po_found',
    'medium_quantity_shortfall', 'medium_price_discrepancy',
    'medium_split_delivery', 'medium_vendor_mismatch',
    'hard_policy_violation', 'hard_duplicate_invoice',
    'hard_partial_po_match', 'hard_tax_discrepancy',
    'long_invoice_dispute', 'long_policy_migration',
    'long_batch_reconciliation', 'long_manager_chain',
    'long_fraud_investigation', 'long_audit_trail',
    'long_multi_vendor_split',
]
# Difficulty tier of every task.  Each task id is '<tier>_<name>', so the tier
# is derived from the prefix rather than kept as a second hand-written table
# that could drift out of sync with EVAL_TASKS.
TASK_DIFFICULTY = {task: task.partition('_')[0] for task in EVAL_TASKS}
# Plot colour per tier, and the order in which tiers are reported and plotted.
DIFF_COLORS = {'easy': '#3fb950', 'medium': '#d29922', 'hard': '#f85149', 'long': '#a371f7'}
DIFF_ORDER = ['easy', 'medium', 'hard', 'long']

# System message sent with every episode; it pins the JSON schema the model
# must answer with.
SYSTEM_PROMPT = """You are an AI Accounts Payable Clerk. Review the invoice, PO, and GRN, then output ONLY valid JSON:
{"decision": "APPROVE_FULL"|"APPROVE_PARTIAL"|"REJECT"|"ESCALATE"|"QUERY_VENDOR",
"approved_amount": <float>,
"reason_code": "MATCH_CONFIRMED"|"QUANTITY_MISMATCH"|"PRICE_DISCREPANCY"|"POLICY_VIOLATION"|"NO_PO_FOUND"|"DUPLICATE_INVOICE"|"VENDOR_MISMATCH"|"TAX_DISCREPANCY"|"PENDING_CLARIFICATION"|"MANAGER_REVIEW",
"explanation": "<cite specific $ amounts>"}"""
|
| 50 |
+
|
| 51 |
+
# Vocabulary an action must use to count as well-formed.
VALID_DECISIONS = {'APPROVE_FULL', 'APPROVE_PARTIAL', 'REJECT', 'ESCALATE', 'QUERY_VENDOR', 'HOLD'}
VALID_REASON_CODES = {'MATCH_CONFIRMED', 'QUANTITY_MISMATCH', 'PRICE_DISCREPANCY', 'POLICY_VIOLATION',
                      'NO_PO_FOUND', 'DUPLICATE_INVOICE', 'VENDOR_MISMATCH', 'TAX_DISCREPANCY',
                      'PENDING_CLARIFICATION', 'MANAGER_REVIEW'}


def _is_valid_action(candidate):
    """Return True when *candidate* is a dict carrying a well-formed action."""
    if not isinstance(candidate, dict):
        return False
    decision = candidate.get('decision')
    reason = candidate.get('reason_code')
    amount = candidate.get('approved_amount')
    explanation = candidate.get('explanation', '')
    return (
        # isinstance guards keep unhashable values (lists, dicts) away from
        # the set-membership tests.
        isinstance(decision, str) and decision in VALID_DECISIONS
        and isinstance(reason, str) and reason in VALID_REASON_CODES
        # bool is a subclass of int; true/false is not a monetary amount.
        and isinstance(amount, (int, float)) and not isinstance(amount, bool)
        and isinstance(explanation, str) and len(explanation) > 10
    )


def parse_action(raw):
    """Extract the action JSON object from raw model output.

    Markdown code fences are stripped, then every ``{`` in the text is tried
    as the start of a JSON value until one decodes to a well-formed action.
    Decoding a single value from a start offset (rather than taking the
    greedy first-``{``-to-last-``}`` span) means braces in trailing
    commentary no longer invalidate an otherwise correct answer.

    Args:
        raw: Text generated by the model.

    Returns:
        ``(action, ok)``.  ``ok`` is True and ``action`` is the decoded dict
        when a well-formed action was found; otherwise ``ok`` is False and
        ``action`` is a fixed REJECT / NO_PO_FOUND placeholder.
    """
    clean = re.sub(r'```(?:json)?\s*|\s*```', '', raw).strip()
    decoder = json.JSONDecoder()
    start = clean.find('{')
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(clean, start)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            candidate = None
        if _is_valid_action(candidate):
            return candidate, True
        start = clean.find('{', start + 1)
    return {'decision': 'REJECT', 'approved_amount': 0.0,
            'reason_code': 'NO_PO_FOUND', 'explanation': 'parse error'}, False
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def eval_one(model, tokenizer, task_id, seed):
    """Run one episode of *task_id* with *seed* and return its score.

    Resets the environment, renders the observation as a chat prompt,
    samples up to 250 new tokens from the model, parses the reply into an
    action and submits it with a single /step call.

    Args:
        model: Causal LM exposing ``eval()`` and ``generate()``.
        tokenizer: Matching tokenizer with a chat template.
        task_id: Environment task identifier sent to /reset.
        seed: Seed sent to /reset.  It is not applied to the sampler here,
            so generation itself is not seeded by this function.

    Returns:
        ``(score, raw_prefix, decision, fmt_ok)``: the reward score as a
        float, the first 120 characters of the model output, the submitted
        decision, and whether the output parsed as a well-formed action.
        Any exception is printed and reported as
        ``(0.01, '', 'ERROR', False)`` so one bad episode does not abort
        the whole run.
    """
    import torch
    model.eval()
    try:
        reset_resp = requests.post(f'{ENV_URL}/reset',
                                   json={'task_id': task_id, 'seed': seed}, timeout=20)
        # Report HTTP failures as such instead of a KeyError on the body.
        reset_resp.raise_for_status()
        reset = reset_resp.json()
        obs, sid = reset['observation'], reset['session_id']

        msgs = [{'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': obs_to_prompt(obs)}]
        text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        # Put the inputs wherever the model lives; falls back to the former
        # hard-coded 'cuda' for objects without a ``device`` attribute.
        inputs = tokenizer(text, return_tensors='pt').to(getattr(model, 'device', 'cuda'))
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=250, temperature=0.1, do_sample=True)
        # Decode only the newly generated tokens, not the echoed prompt.
        prompt_len = inputs['input_ids'].shape[1]
        raw = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)

        action, fmt_ok = parse_action(raw)
        step_resp = requests.post(f'{ENV_URL}/step',
                                  json={'session_id': sid, 'action': action},
                                  timeout=20)
        step_resp.raise_for_status()
        score = float(step_resp.json()['reward']['score'])
        return score, raw[:120], action.get('decision', '?'), fmt_ok
    except Exception as e:
        print(f' error: {e}')
        return 0.01, '', 'ERROR', False
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _login_to_hub():
    """Log in to the Hugging Face Hub if a token is set in the environment."""
    hf_token = os.environ.get('HF_TOKEN') or os.environ.get('HUGGING_FACE_HUB_TOKEN')
    if hf_token:
        from huggingface_hub import login
        login(token=hf_token, add_to_git_credential=False)
        print('[AUTH] Logged in.')


def _wait_for_env(attempts=12, delay=10):
    """Poll the environment's /health endpoint until it answers with JSON.

    Raises:
        RuntimeError: if no attempt succeeds (default: 12 × 10 s, ~2 min).
    """
    print(f'[ENV] Waking {ENV_URL}...')
    for attempt in range(attempts):
        try:
            resp = requests.get(f'{ENV_URL}/health', timeout=30)
            # Require a 200 whose body looks like JSON; anything else is
            # treated as "not ready yet".
            if resp.status_code == 200 and resp.text.strip().startswith('{'):
                h = resp.json()
                print(f"[ENV] status={h['status']} tasks={h.get('total_tasks')}")
                return
            print(f'[ENV] attempt {attempt+1}: not ready (status={resp.status_code}), '
                  f'waiting {delay} s...')
        except Exception as e:
            print(f'[ENV] attempt {attempt+1}: {e}, waiting {delay} s...')
        time.sleep(delay)
    raise RuntimeError(
        f'Environment at {ENV_URL} did not become healthy after {attempts * delay} s.')


def _load_model():
    """Load MODEL_NAME with 4-bit NF4 quantization; return (model, tokenizer)."""
    print(f'[MODEL] Loading {MODEL_NAME} (4-bit NF4, no LoRA)...')
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',
                             bnb_4bit_compute_dtype=torch.bfloat16,
                             bnb_4bit_use_double_quant=True)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, quantization_config=bnb,
                                                 device_map='auto', trust_remote_code=True)
    print('[MODEL] Ready.')
    return model, tokenizer


def _evaluate_all(model, tokenizer):
    """Run every task in EVAL_TASKS once per seed.

    Returns:
        ``(results, parse_failures)`` where ``results`` maps
        task_id → {difficulty, scores, mean, decisions, fmt_rate} and
        ``parse_failures`` counts episodes whose output did not parse.
    """
    results = {}
    parse_failures = 0
    print(f'\n[EVAL] {len(EVAL_TASKS)} tasks × {len(SEEDS)} seeds = '
          f'{len(EVAL_TASKS)*len(SEEDS)} episodes\n')

    for task_id in EVAL_TASKS:
        diff = TASK_DIFFICULTY[task_id]
        scores, decisions, fmts = [], [], []
        for seed in SEEDS:
            score, raw, dec, fmt_ok = eval_one(model, tokenizer, task_id, seed)
            scores.append(score)
            decisions.append(dec)
            fmts.append(fmt_ok)
            if not fmt_ok:
                parse_failures += 1
            print(f' [{diff[:4]}] {task_id} seed={seed}: {score:.3f} {dec} fmt={fmt_ok}')
            print(f' {raw[:90]}')
            time.sleep(0.2)  # short pause between episodes
        results[task_id] = {
            'difficulty': diff,
            'scores': [round(s, 4) for s in scores],
            'mean': round(sum(scores) / len(scores), 4),
            'decisions': decisions,
            'fmt_rate': round(sum(fmts) / len(fmts), 3),
        }
    return results, parse_failures


def _summarize(results, parse_failures):
    """Print the per-tier and overall means; return (by_diff, overall)."""
    print('\n' + '='*70)
    by_diff = {}
    for v in results.values():
        by_diff.setdefault(v['difficulty'], []).append(v['mean'])
    for diff in DIFF_ORDER:
        ms = by_diff.get(diff, [])
        if ms:
            print(f" {diff:<8}: mean={sum(ms)/len(ms):.3f} tasks={[round(m,3) for m in ms]}")
    all_means = [v['mean'] for v in results.values()]
    overall = sum(all_means) / len(all_means)
    print(f" overall : mean={overall:.3f} "
          f"parse_failures={parse_failures}/{len(EVAL_TASKS)*len(SEEDS)}")
    print('='*70)
    return by_diff, overall


def _style_axis(ax, title='', xlabel='', ylabel=''):
    """Apply the dark theme shared by both plot panels."""
    ax.set_facecolor('#161b22')
    ax.tick_params(colors='#c9d1d9', labelsize=8)
    for sp in ax.spines.values():
        sp.set_color('#30363d')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.yaxis.grid(True, color='#21262d', linewidth=0.7)
    ax.set_axisbelow(True)
    if title:
        ax.set_title(title, color='#e6edf3', fontsize=11, fontweight='bold', pad=8)
    if xlabel:
        ax.set_xlabel(xlabel, color='#8b949e', fontsize=8)
    if ylabel:
        ax.set_ylabel(ylabel, color='#8b949e', fontsize=8)


def _save_plot(results, by_diff, overall, run_dir):
    """Draw the two-panel summary figure and save it; return the PNG path."""
    from matplotlib.patches import Patch

    # ── Plots ────────────────────────────────────────────────────────────────
    fig = plt.figure(figsize=(16, max(9, len(results) * 0.5 + 2)))
    fig.patch.set_facecolor('#0d1117')
    gs = fig.add_gridspec(1, 2, wspace=0.30)

    # Panel 1: per-task mean score (horizontal bars), ordered by difficulty.
    ax1 = fig.add_subplot(gs[0, 0])
    tasks = sorted(results,
                   key=lambda t: (DIFF_ORDER.index(results[t]['difficulty']), t))
    means = [results[t]['mean'] for t in tasks]
    colors = [DIFF_COLORS[results[t]['difficulty']] for t in tasks]
    short = [t.replace('easy_', '').replace('medium_', '').replace('hard_', '')
              .replace('long_', '').replace('_', ' ').title() for t in tasks]
    positions = list(range(len(tasks)))
    ax1.barh(positions, means, color=colors, alpha=0.85, edgecolor='#0d1117')
    ax1.set_yticks(positions)
    ax1.set_yticklabels(short, fontsize=8)
    ax1.set_xlim(0, 1.05)
    ax1.axvline(overall, color='#f78166', linestyle='--', linewidth=1.2,
                label=f'Overall mean: {overall:.3f}')
    ax1.axvline(0.5, color='#484f58', linestyle=':', linewidth=1)
    for i, m in enumerate(means):
        ax1.text(m + 0.01, i, f'{m:.3f}', va='center', color='#c9d1d9', fontsize=8)
    legend_els = [Patch(facecolor=c, label=d) for d, c in DIFF_COLORS.items()]
    legend_els.append(plt.Line2D([0], [0], color='#f78166', linestyle='--',
                                 label=f'Mean {overall:.3f}'))
    ax1.legend(handles=legend_els, fontsize=8, facecolor='#161b22',
               edgecolor='#30363d', labelcolor='#c9d1d9', loc='lower right')
    _style_axis(ax1, f'Untrained Baseline — Per-Task Mean Score ({len(SEEDS)} seeds)',
                xlabel='Mean Score [0.01 – 0.99]', ylabel='Task')

    # Panel 2: mean score per difficulty tier.
    ax2 = fig.add_subplot(gs[0, 1])
    diffs = [d for d in DIFF_ORDER if d in by_diff]
    d_means = [sum(by_diff[d]) / len(by_diff[d]) for d in diffs]
    d_colors = [DIFF_COLORS[d] for d in diffs]
    ax2.bar(diffs, d_means, color=d_colors, alpha=0.85, edgecolor='#0d1117', width=0.5)
    for i, m in enumerate(d_means):
        ax2.text(i, m + 0.02, f'{m:.3f}', ha='center', color='#c9d1d9', fontsize=10,
                 fontweight='bold')
    ax2.set_ylim(0, 1.05)
    ax2.axhline(overall, color='#f78166', linestyle='--', linewidth=1,
                label=f'Overall {overall:.3f}')
    ax2.legend(fontsize=8, facecolor='#161b22', edgecolor='#30363d', labelcolor='#c9d1d9')
    _style_axis(ax2, 'Mean Score by Difficulty Tier',
                xlabel='Difficulty Tier', ylabel='Mean Score [0.01 – 0.99]')

    model_short = MODEL_NAME.split('/')[-1]
    fig.suptitle(
        f'{model_short} — Untrained Baseline | 4-bit NF4 | {len(SEEDS)} seeds | '
        f'{len(EVAL_TASKS)} tasks | overall={overall:.3f} | '
        f'{datetime.datetime.now().strftime("%Y-%m-%d")}',
        color='#e6edf3', fontsize=10, y=1.01
    )
    fig.text(0.5, 0.0,
             'Baseline = model loaded 4-bit NF4 with no fine-tuning. '
             'Score range [0.01, 0.99]. Tasks: easy (green), medium (yellow), '
             'hard (red), long-horizon (purple).',
             ha='center', color='#8b949e', fontsize=7, style='italic')
    plot_path = os.path.join(run_dir, 'baseline_plot.png')
    fig.savefig(plot_path, dpi=130, bbox_inches='tight', facecolor=fig.get_facecolor())
    plt.close(fig)
    return plot_path


def _upload_run(run_dir, ts):
    """Best-effort upload of the run folder to the training Space repo."""
    model_short = MODEL_NAME.split('/')[-1]
    # Repo path is the run directory relative to the /app root.
    repo_run_path = os.path.relpath(run_dir, '/app')
    try:
        from huggingface_hub import HfApi
        api = HfApi()
        api.upload_folder(
            folder_path=run_dir,
            path_in_repo=repo_run_path,
            repo_id='Pathikreet/ap-commander-training',
            repo_type='space',
            commit_message=f'Baseline: {model_short} untrained {ts}',
        )
        print(f'[UPLOAD] {repo_run_path} → Pathikreet/ap-commander-training')
    except Exception as e:
        # Results are already on disk, so a failed upload only warns.
        print(f'[UPLOAD] skipped: {e}')


def main():
    """Evaluate the untrained MODEL_NAME on every task and persist the results.

    Steps: optional Hub login, wait for the environment to become healthy,
    load the model in 4-bit, run all tasks × seeds, print a summary, then
    write ``baseline_results.json`` and ``baseline_plot.png`` to
    ``/app/runs/baselines/<model>-<timestamp>/`` and try to upload that
    folder to the Hub.

    Returns:
        The dict that was written to ``baseline_results.json``.

    Raises:
        RuntimeError: if the environment never reports healthy.
    """
    _login_to_hub()

    model_slug = MODEL_NAME.split('/')[-1].lower().replace('.', '-')
    ts = datetime.datetime.now().strftime('%Y-%m-%d_%H%M')
    run_dir = f'/app/runs/baselines/{model_slug}-{ts}'
    os.makedirs(run_dir, exist_ok=True)
    print(f'[RUN] {MODEL_NAME} → {run_dir}')

    _wait_for_env()
    model, tokenizer = _load_model()

    # Evaluate
    results, parse_failures = _evaluate_all(model, tokenizer)

    # Summary
    by_diff, overall = _summarize(results, parse_failures)

    # Save JSON
    output = {
        'run_type': 'llm_baseline_no_finetuning',
        'model': MODEL_NAME,
        'quantization': '4-bit NF4 (BitsAndBytes)',
        'lora': None,
        'timestamp': datetime.datetime.now().isoformat(),
        'run_dir': run_dir,
        'env_url': ENV_URL,
        'seeds': SEEDS,
        'eval_tasks': EVAL_TASKS,
        'overall_mean': round(overall, 4),
        'parse_failures': parse_failures,
        'tasks': results,
        'by_difficulty': {d: round(sum(ms)/len(ms), 4) for d, ms in by_diff.items()},
    }
    json_path = os.path.join(run_dir, 'baseline_results.json')
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=2)
    print(f'[SAVED] {json_path}')

    plot_path = _save_plot(results, by_diff, overall, run_dir)
    print(f'[SAVED] {plot_path}')

    # Upload run folder to HF Space repo
    _upload_run(run_dir, ts)

    print(f'\n[DONE] Results in {run_dir}')
    return output
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
# Run the evaluation only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|