Pathikreet commited on
Commit
61ca2ab
Β·
verified Β·
1 Parent(s): a9752a6

Upload training/eval_baseline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/eval_baseline.py +320 -0
training/eval_baseline.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ eval_baseline.py β€” LLM baseline evaluation (no fine-tuning)
3
+ Loads MODEL_NAME in 4-bit, evaluates on all EVAL_TASKS, saves results to
4
+ runs/baselines/MODEL-DATETIME/ and uploads to HF Hub.
5
+
6
+ Usage (HF Spaces / Colab with GPU):
7
+ MODEL_NAME=Qwen/Qwen2.5-7B-Instruct python eval_baseline.py
8
+ HF_TOKEN=hf_... MODEL_NAME=meta-llama/Meta-Llama-3-8B-Instruct python eval_baseline.py
9
+ """
10
+ import os, json, re, datetime, time
11
+ import requests
12
+ import matplotlib
13
+ matplotlib.use('Agg')
14
+ import matplotlib.pyplot as plt
15
+ import numpy as np
16
+
17
+ ENV_URL = 'https://pathikreet-ap-clerk-env.hf.space'
18
+ MODEL_NAME = os.environ.get('MODEL_NAME', 'Qwen/Qwen2.5-7B-Instruct')
19
+ SEEDS = [42, 99, 7] # 3 seeds per task β†’ mean score per task
20
+ EVAL_TASKS = [
21
+ 'easy_perfect_match', 'easy_no_po_found',
22
+ 'medium_quantity_shortfall', 'medium_price_discrepancy',
23
+ 'medium_split_delivery', 'medium_vendor_mismatch',
24
+ 'hard_policy_violation', 'hard_duplicate_invoice',
25
+ 'hard_partial_po_match', 'hard_tax_discrepancy',
26
+ 'long_invoice_dispute', 'long_policy_migration',
27
+ 'long_batch_reconciliation', 'long_manager_chain',
28
+ 'long_fraud_investigation', 'long_audit_trail',
29
+ 'long_multi_vendor_split',
30
+ ]
31
+ TASK_DIFFICULTY = {
32
+ 'easy_perfect_match': 'easy', 'easy_no_po_found': 'easy',
33
+ 'medium_quantity_shortfall':'medium','medium_price_discrepancy':'medium',
34
+ 'medium_split_delivery':'medium', 'medium_vendor_mismatch':'medium',
35
+ 'hard_policy_violation':'hard', 'hard_duplicate_invoice':'hard',
36
+ 'hard_partial_po_match':'hard', 'hard_tax_discrepancy':'hard',
37
+ 'long_invoice_dispute':'long', 'long_policy_migration':'long',
38
+ 'long_batch_reconciliation':'long', 'long_manager_chain':'long',
39
+ 'long_fraud_investigation':'long', 'long_audit_trail':'long',
40
+ 'long_multi_vendor_split':'long',
41
+ }
42
+ DIFF_COLORS = {'easy': '#3fb950', 'medium': '#d29922', 'hard': '#f85149', 'long': '#a371f7'}
43
+ DIFF_ORDER = ['easy', 'medium', 'hard', 'long']
44
+
45
+ SYSTEM_PROMPT = """You are an AI Accounts Payable Clerk. Review the invoice, PO, and GRN, then output ONLY valid JSON:
46
+ {"decision": "APPROVE_FULL"|"APPROVE_PARTIAL"|"REJECT"|"ESCALATE"|"QUERY_VENDOR",
47
+ "approved_amount": <float>,
48
+ "reason_code": "MATCH_CONFIRMED"|"QUANTITY_MISMATCH"|"PRICE_DISCREPANCY"|"POLICY_VIOLATION"|"NO_PO_FOUND"|"DUPLICATE_INVOICE"|"VENDOR_MISMATCH"|"TAX_DISCREPANCY"|"PENDING_CLARIFICATION"|"MANAGER_REVIEW",
49
+ "explanation": "<cite specific $ amounts>"}"""
50
+
51
+ VALID_DECISIONS = {'APPROVE_FULL','APPROVE_PARTIAL','REJECT','ESCALATE','QUERY_VENDOR','HOLD'}
52
+ VALID_REASON_CODES = {'MATCH_CONFIRMED','QUANTITY_MISMATCH','PRICE_DISCREPANCY','POLICY_VIOLATION',
53
+ 'NO_PO_FOUND','DUPLICATE_INVOICE','VENDOR_MISMATCH','TAX_DISCREPANCY',
54
+ 'PENDING_CLARIFICATION','MANAGER_REVIEW'}
55
+
56
+
57
+ def obs_to_prompt(obs):
58
+ inv = obs['invoice']
59
+ lines = '\n'.join(f" {li['description']}: qty={li['quantity']}, unit_price=${li['unit_price']:.2f}"
60
+ for li in inv.get('line_items', []))
61
+ pos = '\n'.join(
62
+ f" PO {p['po_number']} ({p['status']}) {p['vendor_name']}: " +
63
+ ', '.join(f"{l['description']} qty={l['ordered_quantity']} @${l['agreed_unit_price']:.2f}"
64
+ for l in p.get('lines', []))
65
+ for p in obs.get('purchase_orders', []))
66
+ grns = '\n'.join(
67
+ f" GRN {g['grn_id']}: " + ', '.join(f"{l['description']} recv={l['received_quantity']}"
68
+ for l in g.get('lines', []))
69
+ for g in obs.get('goods_receipts', []))
70
+ context = '\n'.join(f' {n}' for n in obs.get('context_notes', []))
71
+ paid = ', '.join(obs.get('paid_invoice_ids', []))
72
+ return (f"TASK: {obs['task_name']}\n{obs['task_description']}\n\n"
73
+ f"INVOICE {inv['invoice_id']} | {inv['vendor_name']} | ${inv['invoice_total']:,.2f}\n{lines}\n"
74
+ f"Freight: ${inv.get('freight_charge',0):.2f}\n\n"
75
+ f"PURCHASE ORDERS:\n{pos}\n\nGOODS RECEIPTS:\n{grns}\n"
76
+ + (f"PAID LEDGER: {paid}\n" if paid else "")
77
+ + (f"CONTEXT:\n{context}\n" if context else "")
78
+ + f"\nPOLICY:\n{obs['company_policy']}\n\nOutput JSON decision.")
79
+
80
+
81
+ def parse_action(raw):
82
+ clean = re.sub(r'```(?:json)?\s*|\s*```', '', raw).strip()
83
+ m = re.search(r'\{.*\}', clean, re.DOTALL)
84
+ if m:
85
+ try:
86
+ a = json.loads(m.group())
87
+ if (a.get('decision') in VALID_DECISIONS and
88
+ a.get('reason_code') in VALID_REASON_CODES and
89
+ isinstance(a.get('approved_amount'), (int, float)) and
90
+ len(a.get('explanation', '')) > 10):
91
+ return a, True
92
+ except Exception:
93
+ pass
94
+ return {'decision': 'REJECT', 'approved_amount': 0.0,
95
+ 'reason_code': 'NO_PO_FOUND', 'explanation': 'parse error'}, False
96
+
97
+
98
+ def eval_one(model, tokenizer, task_id, seed):
99
+ import torch
100
+ model.eval()
101
+ try:
102
+ reset = requests.post(f'{ENV_URL}/reset',
103
+ json={'task_id': task_id, 'seed': seed}, timeout=20).json()
104
+ obs, sid = reset['observation'], reset['session_id']
105
+ msgs = [{'role': 'system', 'content': SYSTEM_PROMPT},
106
+ {'role': 'user', 'content': obs_to_prompt(obs)}]
107
+ text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
108
+ inputs = tokenizer(text, return_tensors='pt').to('cuda')
109
+ with torch.no_grad():
110
+ out = model.generate(**inputs, max_new_tokens=250, temperature=0.1, do_sample=True)
111
+ raw = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
112
+ action, fmt_ok = parse_action(raw)
113
+ score = float(requests.post(f'{ENV_URL}/step',
114
+ json={'session_id': sid, 'action': action},
115
+ timeout=20).json()['reward']['score'])
116
+ return score, raw[:120], action.get('decision', '?'), fmt_ok
117
+ except Exception as e:
118
+ print(f' error: {e}')
119
+ return 0.01, '', 'ERROR', False
120
+
121
+
122
+ def main():
123
+ hf_token = os.environ.get('HF_TOKEN') or os.environ.get('HUGGING_FACE_HUB_TOKEN')
124
+ if hf_token:
125
+ from huggingface_hub import login
126
+ login(token=hf_token, add_to_git_credential=False)
127
+ print('[AUTH] Logged in.')
128
+
129
+ model_slug = MODEL_NAME.split('/')[-1].lower().replace('.', '-')
130
+ ts = datetime.datetime.now().strftime('%Y-%m-%d_%H%M')
131
+ run_dir = f'/app/runs/baselines/{model_slug}-{ts}'
132
+ os.makedirs(run_dir, exist_ok=True)
133
+ print(f'[RUN] {MODEL_NAME} β†’ {run_dir}')
134
+
135
+ print(f'[ENV] Waking {ENV_URL}...')
136
+ for attempt in range(12): # up to 2 min (12 Γ— 10 s)
137
+ try:
138
+ resp = requests.get(f'{ENV_URL}/health', timeout=30)
139
+ if resp.status_code == 200 and resp.text.strip().startswith('{'):
140
+ h = resp.json()
141
+ print(f"[ENV] status={h['status']} tasks={h.get('total_tasks')}")
142
+ break
143
+ print(f'[ENV] attempt {attempt+1}: not ready (status={resp.status_code}), waiting 10 s...')
144
+ except Exception as e:
145
+ print(f'[ENV] attempt {attempt+1}: {e}, waiting 10 s...')
146
+ time.sleep(10)
147
+ else:
148
+ raise RuntimeError(f'Environment at {ENV_URL} did not become healthy after 120 s.')
149
+
150
+ print(f'[MODEL] Loading {MODEL_NAME} (4-bit NF4, no LoRA)...')
151
+ import torch
152
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
153
+
154
+ bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type='nf4',
155
+ bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)
156
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
157
+ if tokenizer.pad_token is None:
158
+ tokenizer.pad_token = tokenizer.eos_token
159
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, quantization_config=bnb,
160
+ device_map='auto', trust_remote_code=True)
161
+ print('[MODEL] Ready.')
162
+
163
+ # Evaluate
164
+ results = {} # task_id β†’ {scores: [], mean: float, decisions: [], fmt_rate: float}
165
+ parse_failures = 0
166
+ print(f'\n[EVAL] {len(EVAL_TASKS)} tasks Γ— {len(SEEDS)} seeds = {len(EVAL_TASKS)*len(SEEDS)} episodes\n')
167
+
168
+ for task_id in EVAL_TASKS:
169
+ diff = TASK_DIFFICULTY[task_id]
170
+ scores, decisions, fmts = [], [], []
171
+ for seed in SEEDS:
172
+ score, raw, dec, fmt_ok = eval_one(model, tokenizer, task_id, seed)
173
+ scores.append(score)
174
+ decisions.append(dec)
175
+ fmts.append(fmt_ok)
176
+ if not fmt_ok:
177
+ parse_failures += 1
178
+ print(f' [{diff[:4]}] {task_id} seed={seed}: {score:.3f} {dec} fmt={fmt_ok}')
179
+ print(f' {raw[:90]}')
180
+ time.sleep(0.2)
181
+ results[task_id] = {
182
+ 'difficulty': diff,
183
+ 'scores': [round(s, 4) for s in scores],
184
+ 'mean': round(sum(scores) / len(scores), 4),
185
+ 'decisions': decisions,
186
+ 'fmt_rate': round(sum(fmts) / len(fmts), 3),
187
+ }
188
+
189
+ # Summary
190
+ print('\n' + '='*70)
191
+ by_diff = {}
192
+ for tid, v in results.items():
193
+ by_diff.setdefault(v['difficulty'], []).append(v['mean'])
194
+ for diff in DIFF_ORDER:
195
+ ms = by_diff.get(diff, [])
196
+ if ms:
197
+ print(f" {diff:<8}: mean={sum(ms)/len(ms):.3f} tasks={[round(m,3) for m in ms]}")
198
+ all_means = [v['mean'] for v in results.values()]
199
+ overall = sum(all_means) / len(all_means)
200
+ print(f" overall : mean={overall:.3f} parse_failures={parse_failures}/{len(EVAL_TASKS)*len(SEEDS)}")
201
+ print('='*70)
202
+
203
+ # Save JSON
204
+ output = {
205
+ 'run_type': 'llm_baseline_no_finetuning',
206
+ 'model': MODEL_NAME,
207
+ 'quantization': '4-bit NF4 (BitsAndBytes)',
208
+ 'lora': None,
209
+ 'timestamp': datetime.datetime.now().isoformat(),
210
+ 'run_dir': run_dir,
211
+ 'env_url': ENV_URL,
212
+ 'seeds': SEEDS,
213
+ 'eval_tasks': EVAL_TASKS,
214
+ 'overall_mean': round(overall, 4),
215
+ 'parse_failures': parse_failures,
216
+ 'tasks': results,
217
+ 'by_difficulty': {d: round(sum(ms)/len(ms), 4) for d, ms in by_diff.items()},
218
+ }
219
+ json_path = os.path.join(run_dir, 'baseline_results.json')
220
+ with open(json_path, 'w') as f:
221
+ json.dump(output, f, indent=2)
222
+ print(f'[SAVED] {json_path}')
223
+
224
+ # ── Plots ───────────────────────────────────────────────────────────────────
225
+ fig = plt.figure(figsize=(16, max(9, len(results) * 0.5 + 2)))
226
+ fig.patch.set_facecolor('#0d1117')
227
+ gs = fig.add_gridspec(1, 2, wspace=0.30)
228
+
229
+ def _dark(ax, title='', xlabel='', ylabel=''):
230
+ ax.set_facecolor('#161b22')
231
+ ax.tick_params(colors='#c9d1d9', labelsize=8)
232
+ for sp in ax.spines.values(): sp.set_color('#30363d')
233
+ ax.spines['top'].set_visible(False); ax.spines['right'].set_visible(False)
234
+ ax.yaxis.grid(True, color='#21262d', linewidth=0.7)
235
+ ax.set_axisbelow(True)
236
+ if title: ax.set_title(title, color='#e6edf3', fontsize=11, fontweight='bold', pad=8)
237
+ if xlabel: ax.set_xlabel(xlabel, color='#8b949e', fontsize=8)
238
+ if ylabel: ax.set_ylabel(ylabel, color='#8b949e', fontsize=8)
239
+
240
+ # Panel 1: Per-task mean score (horizontal bar), ordered by difficulty
241
+ ax1 = fig.add_subplot(gs[0, 0])
242
+ tasks = sorted(results.keys(),
243
+ key=lambda t: (DIFF_ORDER.index(results[t]['difficulty']), t))
244
+ means = [results[t]['mean'] for t in tasks]
245
+ colors = [DIFF_COLORS[results[t]['difficulty']] for t in tasks]
246
+ short = [t.replace('easy_','').replace('medium_','').replace('hard_','').replace('long_','')
247
+ .replace('_',' ').title() for t in tasks]
248
+ yp = range(len(tasks))
249
+ bars = ax1.barh(list(yp), means, color=colors, alpha=0.85, edgecolor='#0d1117')
250
+ ax1.set_yticks(list(yp))
251
+ ax1.set_yticklabels(short, fontsize=8)
252
+ ax1.set_xlim(0, 1.05)
253
+ ax1.axvline(overall, color='#f78166', linestyle='--', linewidth=1.2,
254
+ label=f'Overall mean: {overall:.3f}')
255
+ ax1.axvline(0.5, color='#484f58', linestyle=':', linewidth=1)
256
+ for i, m in enumerate(means):
257
+ ax1.text(m + 0.01, i, f'{m:.3f}', va='center', color='#c9d1d9', fontsize=8)
258
+ from matplotlib.patches import Patch
259
+ legend_els = [Patch(facecolor=c, label=d) for d, c in DIFF_COLORS.items()]
260
+ legend_els.append(plt.Line2D([0],[0], color='#f78166', linestyle='--',
261
+ label=f'Mean {overall:.3f}'))
262
+ ax1.legend(handles=legend_els, fontsize=8, facecolor='#161b22',
263
+ edgecolor='#30363d', labelcolor='#c9d1d9', loc='lower right')
264
+ _dark(ax1, f'Untrained Baseline β€” Per-Task Mean Score ({len(SEEDS)} seeds)',
265
+ xlabel='Mean Score [0.01 – 0.99]', ylabel='Task')
266
+
267
+ # Panel 2: Mean by difficulty
268
+ ax2 = fig.add_subplot(gs[0, 1])
269
+ diffs = [d for d in DIFF_ORDER if d in by_diff]
270
+ d_means = [sum(by_diff.get(d, [0])) / max(1, len(by_diff.get(d, [0]))) for d in diffs]
271
+ d_colors = [DIFF_COLORS[d] for d in diffs]
272
+ bars2 = ax2.bar(diffs, d_means, color=d_colors, alpha=0.85, edgecolor='#0d1117', width=0.5)
273
+ for i, (d, m) in enumerate(zip(diffs, d_means)):
274
+ ax2.text(i, m + 0.02, f'{m:.3f}', ha='center', color='#c9d1d9', fontsize=10,
275
+ fontweight='bold')
276
+ ax2.set_ylim(0, 1.05)
277
+ ax2.axhline(overall, color='#f78166', linestyle='--', linewidth=1,
278
+ label=f'Overall {overall:.3f}')
279
+ ax2.legend(fontsize=8, facecolor='#161b22', edgecolor='#30363d', labelcolor='#c9d1d9')
280
+ _dark(ax2, 'Mean Score by Difficulty Tier',
281
+ xlabel='Difficulty Tier', ylabel='Mean Score [0.01 – 0.99]')
282
+
283
+ model_short = MODEL_NAME.split('/')[-1]
284
+ fig.suptitle(
285
+ f'{model_short} β€” Untrained Baseline | 4-bit NF4 | {len(SEEDS)} seeds | '
286
+ f'{len(EVAL_TASKS)} tasks | overall={overall:.3f} | '
287
+ f'{datetime.datetime.now().strftime("%Y-%m-%d")}',
288
+ color='#e6edf3', fontsize=10, y=1.01
289
+ )
290
+ fig.text(0.5, 0.0,
291
+ 'Baseline = model loaded 4-bit NF4 with no fine-tuning. '
292
+ 'Score range [0.01, 0.99]. Tasks: easy (green), medium (yellow), hard (red), long-horizon (purple).',
293
+ ha='center', color='#8b949e', fontsize=7, style='italic')
294
+ plot_path = os.path.join(run_dir, 'baseline_plot.png')
295
+ plt.savefig(plot_path, dpi=130, bbox_inches='tight', facecolor=fig.get_facecolor())
296
+ plt.close()
297
+ print(f'[SAVED] {plot_path}')
298
+
299
+ # Upload run folder to HF Space repo
300
+ repo_run_path = run_dir.replace('/app/', '')
301
+ try:
302
+ from huggingface_hub import HfApi
303
+ api = HfApi()
304
+ api.upload_folder(
305
+ folder_path=run_dir,
306
+ path_in_repo=repo_run_path,
307
+ repo_id='Pathikreet/ap-commander-training',
308
+ repo_type='space',
309
+ commit_message=f'Baseline: {model_short} untrained {ts}',
310
+ )
311
+ print(f'[UPLOAD] {repo_run_path} β†’ Pathikreet/ap-commander-training')
312
+ except Exception as e:
313
+ print(f'[UPLOAD] skipped: {e}')
314
+
315
+ print(f'\n[DONE] Results in {run_dir}')
316
+ return output
317
+
318
+
319
+ if __name__ == '__main__':
320
+ main()