"""Rewrite the '# COMPARISON TABLE' cell of the training notebook in place.

Loads the notebook JSON at ``path``, replaces the source of every code cell
containing the '# COMPARISON TABLE' marker with ``new_source``, and writes the
notebook back only when at least one cell was changed.
"""
import json

# Notebook to patch. Hard-coded to the author's workspace; edit before running
# on another machine.
path = r'c:\Users\vaibhav patil\.gemini\antigravity\scratch\financial-doc-extractor\scripts\financial-doc-training (1).ipynb'
with open(path, 'r', encoding='utf-8') as f:
    nb = json.load(f)
# Replacement source for the '# COMPARISON TABLE' notebook cell. It prints a
# Base-vs-Fine-Tuned table from `base_results` / `ft_results` / `test_gt` (which
# must already exist in the notebook namespace) and saves them to
# comparison_results.json.
# NOTE: the glyphs in this text had been mangled by an encoding round-trip
# (UTF-8 read as a legacy code page); the rule, delta, bullet and check-mark
# glyphs are restored, and heading emoji that could not be recovered are omitted.
new_source = """# COMPARISON TABLE
print("\\n" + "="*90)
print(" MODEL COMPARISON: Base vs Fine-Tuned")
print("="*90)
print(f"{'Metric':<45} {'Base Qwen':>15} {'Fine-Tuned':>15} {'Δ':>10}")
print(f"{'─'*90}")
# Group 1: Structure metrics
for label, key in [("L1: Valid JSON Format", "L1_JSON"), ("L2: Schema Compliance", "L2_Schema"),
                   ("L4: Anomaly Detection F1", "L4_F1")]:
    bv, fv = base_results.get(key,0), ft_results.get(key,0)
    # ':+' keeps the sign attached to the number and every delta cell 10 chars wide
    print(f"{label:<45} {bv:>14.1f}% {fv:>14.1f}% {fv-bv:>+9.1f}%")
print(f"{'─'*90}")
print(" ACCURACY SUCCESS METRICS")
print(f"{'─'*90}")
# Group 2: Accuracy metrics
bv_f, fv_f = base_results.get('L3_Fields',0), ft_results.get('L3_Fields',0)
print(f"{'L3: Field-Level Accuracy (Partial Credit)':<45} {bv_f:>14.1f}% {fv_f:>14.1f}% {fv_f-bv_f:>+9.1f}%")
bv, fv = base_results.get('L5_E2E',0), ft_results.get('L5_E2E',0)
print(f"{'L5: End-to-End Success (100% Perfect Doc)':<45} {bv:>14.1f}% {fv:>14.1f}% {fv-bv:>+9.1f}%")
print(f"{'─'*90}")
print("\\nMetric Explanations:")
print(" • Field-Level Accuracy: Calculates overall accuracy per field. (e.g. 24/25 fields right = 96%)")
print(" • End-to-End Success : Extremely strict binary pass/fail. The ENTIRE document must have 0 errors.")
print(f"\\nEvaluated on {len(test_gt)} held-out test documents")
# Save
with open('comparison_results.json', 'w') as f:
    json.dump({'base': base_results, 'fine_tuned': ft_results, 'test_size': len(test_gt)}, f, indent=2)
print("✅ Saved to comparison_results.json")
"""
# Replace the body of every code cell containing the '# COMPARISON TABLE'
# marker with new_source. The cell keeps the representation the notebook already
# uses for it (list of lines, or a single string); cells with any other source
# type are left untouched.
modifications = 0
for cell in nb.get('cells', []):
    if cell.get('cell_type') != 'code':
        continue
    source = cell.get('source')
    if isinstance(source, list):
        source_str = ''.join(source)
    elif isinstance(source, str):
        source_str = source
    else:
        continue
    if '# COMPARISON TABLE' not in source_str:
        continue
    if isinstance(source, list):
        # Line-list form: every line keeps its '\n' except the last one.
        # splitlines(keepends=True) on the newline-stripped text yields exactly
        # that, without the empty trailing element split('\n') used to leave.
        cell['source'] = new_source.rstrip('\n').splitlines(keepends=True)
    else:
        cell['source'] = new_source
    modifications += 1
# Write the notebook back only when a cell was actually replaced, so a missing
# marker never rewrites (or reformats) the file.
if modifications > 0:
    with open(path, 'w', encoding='utf-8') as f:
        # 1-space indent, raw UTF-8 and a trailing newline keep the layout close
        # to how Jupyter saves .ipynb files, so the patch produces a readable
        # diff instead of collapsing the whole notebook onto one line.
        json.dump(nb, f, indent=1, ensure_ascii=False)
        f.write('\n')
    print('Successfully updated COMPARISON TABLE format.')
else:
    print('Could not find COMPARISON TABLE.')