| import json |
|
|
# Notebook whose '# COMPARISON TABLE' cell is rewritten in place below.
NOTEBOOK_PATH = r'c:\Users\vaibhav patil\.gemini\antigravity\scratch\financial-doc-extractor\scripts\financial-doc-training (1).ipynb'
path = NOTEBOOK_PATH

# .ipynb files are JSON text; decode explicitly as UTF-8 instead of relying on
# the platform's default encoding.
with open(path, encoding='utf-8') as notebook_file:
    nb = json.load(notebook_file)
|
|
# Replacement body for the notebook's comparison cell. The patch loop below
# locates the cell by the '# COMPARISON TABLE' marker, so that line must stay
# in the cell text. The cell reads base_results, ft_results, test_gt and json
# from the notebook namespace; it does not define them itself.
new_source = """# COMPARISON TABLE
def _comparison_row(label, key):
    # Missing metrics fall back to 0 so a partial results dict still prints.
    bv, fv = base_results.get(key, 0), ft_results.get(key, 0)
    # The '+' sign option puts the sign inside the 9-char field, so positive and
    # negative rows are equally wide and line up with the 10-char header column.
    print(f"{label:<45} {bv:>14.1f}% {fv:>14.1f}% {fv - bv:>+9.1f}%")

print("\\n" + "=" * 90)
print(" 📊 MODEL COMPARISON: Base vs Fine-Tuned")
print("=" * 90)
print(f"{'Metric':<45} {'Base Qwen':>15} {'Fine-Tuned':>15} {'Δ':>10}")
print("─" * 90)

# Group 1: Structure metrics
for label, key in [("L1: Valid JSON Format", "L1_JSON"),
                   ("L2: Schema Compliance", "L2_Schema"),
                   ("L4: Anomaly Detection F1", "L4_F1")]:
    _comparison_row(label, key)

print("─" * 90)
print(" 📈 ACCURACY SUCCESS METRICS")
print("─" * 90)

# Group 2: Accuracy metrics
_comparison_row("L3: Field-Level Accuracy (Partial Credit)", "L3_Fields")
_comparison_row("⭐ L5: End-to-End Success (100% Perfect Doc)", "L5_E2E")

print("─" * 90)
print("\\n📝 Metric Explanations:")
print(" • Field-Level Accuracy: Calculates overall accuracy per field. (e.g. 24/25 fields right = 96%)")
print(" • End-to-End Success : Extremely strict binary pass/fail. The ENTIRE document must have 0 errors.")
print(f"\\nEvaluated on {len(test_gt)} held-out test documents")

# Save
with open('comparison_results.json', 'w') as f:
    json.dump({'base': base_results, 'fine_tuned': ft_results, 'test_size': len(test_gt)}, f, indent=2)
print("✅ Saved to comparison_results.json")
"""
|
|
# Replace the body of every code cell containing the '# COMPARISON TABLE'
# marker. A cell's source may be stored either as a list of lines or as one
# string; the replacement keeps whichever shape the cell already uses.
modifications = 0
for cell in nb.get('cells', []):
    if cell.get('cell_type') != 'code':
        continue
    source = cell.get('source')
    if isinstance(source, list):
        if '# COMPARISON TABLE' in ''.join(source):
            # keepends=True gives one element per line with its newline kept;
            # ''.join() of the result reproduces new_source exactly, without the
            # stray trailing '' element the old split/strip approach left behind.
            cell['source'] = new_source.splitlines(keepends=True)
            modifications += 1
    elif isinstance(source, str):
        if '# COMPARISON TABLE' in source:
            cell['source'] = new_source
            modifications += 1
|
|
# Write the notebook back only when a cell was actually replaced, so an
# unmatched run leaves the file untouched.
if modifications > 0:
    # Serialize before opening the file: opening with 'w' truncates it, so a
    # failure during encoding would otherwise leave an empty/corrupt notebook.
    # indent=1 and ensure_ascii=False mirror the settings nbformat writes
    # notebooks with, keeping the file multi-line (small diffs) and storing the
    # emoji/box-drawing characters literally instead of as \u escapes.
    text = json.dumps(nb, indent=1, ensure_ascii=False) + '\n'
    with open(path, 'w', encoding='utf-8') as f:
        f.write(text)
    print('Successfully updated COMPARISON TABLE format.')
else:
    print('Could not find COMPARISON TABLE.')
|
|