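# financial-intelligence-ai / scripts/fix_comparison_table.py
# Uploaded by Vaibuzzz via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 10ff0db (verified).
# Rewrites the "# COMPARISON TABLE" code cell of a Jupyter notebook in place.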
import json
import sys
# Notebook to patch: the first command-line argument if one is given, otherwise
# the author's local copy (the original hard-coded location, kept as default so
# running the script with no arguments behaves exactly as before).
DEFAULT_NOTEBOOK_PATH = r'c:\Users\vaibhav patil\.gemini\antigravity\scratch\financial-doc-extractor\scripts\financial-doc-training (1).ipynb'
path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_NOTEBOOK_PATH

# Read the .ipynb file as raw JSON so its cell list can be edited in place.
with open(path, 'r', encoding='utf-8') as f:
    nb = json.load(f)
# Replacement source for the notebook's "# COMPARISON TABLE" cell.
# When executed, the cell expects `json`, `base_results`, `ft_results`
# (dict-like, read via .get with a default of 0) and `test_gt` to already be
# defined by earlier notebook cells. It prints a fixed-width comparison table
# and saves the raw results to comparison_results.json.
# Deltas use the `+` sign flag (`:>+9.1f`) so the sign sits next to the digits
# and every row is exactly as wide as the header.
# NOTE: "\\n" below is deliberate: it becomes "\n" inside the cell's code.
new_source = """# COMPARISON TABLE
print("\\n" + "="*90)
print(" 📊 MODEL COMPARISON: Base vs Fine-Tuned")
print("="*90)
print(f"{'Metric':<45} {'Base Qwen':>15} {'Fine-Tuned':>15} {'Δ':>10}")
print(f"{'─'*90}")
# Group 1: Structure metrics
for label, key in [("L1: Valid JSON Format", "L1_JSON"), ("L2: Schema Compliance", "L2_Schema"),
                   ("L4: Anomaly Detection F1", "L4_F1")]:
    bv, fv = base_results.get(key,0), ft_results.get(key,0)
    print(f"{label:<45} {bv:>14.1f}% {fv:>14.1f}% {fv-bv:>+9.1f}%")
print(f"{'─'*90}")
print(" 🏆 ACCURACY SUCCESS METRICS")
print(f"{'─'*90}")
# Group 2: Accuracy metrics
bv_f, fv_f = base_results.get('L3_Fields',0), ft_results.get('L3_Fields',0)
print(f"{'L3: Field-Level Accuracy (Partial Credit)':<45} {bv_f:>14.1f}% {fv_f:>14.1f}% {fv_f-bv_f:>+9.1f}%")
bv, fv = base_results.get('L5_E2E',0), ft_results.get('L5_E2E',0)
print(f"{'⭐ L5: End-to-End Success (100% Perfect Doc)':<45} {bv:>14.1f}% {fv:>14.1f}% {fv-bv:>+9.1f}%")
print(f"{'─'*90}")
print("\\n📝 Metric Explanations:")
print(" • Field-Level Accuracy: Calculates overall accuracy per field. (e.g. 24/25 fields right = 96%)")
print(" • End-to-End Success : Extremely strict binary pass/fail. The ENTIRE document must have 0 errors.")
print(f"\\nEvaluated on {len(test_gt)} held-out test documents")
# Save
with open('comparison_results.json', 'w') as f:
    json.dump({'base': base_results, 'fine_tuned': ft_results, 'test_size': len(test_gt)}, f, indent=2)
print(f"✅ Saved to comparison_results.json")
"""
# Swap in `new_source` for every code cell whose source contains the marker.
# `new_source` itself begins with the marker, so re-running this script
# matches the already-patched cell again and simply rewrites it.
MARKER = '# COMPARISON TABLE'
modifications = 0
for cell in nb.get('cells', []):
    if cell.get('cell_type') != 'code':
        continue
    source = cell.get('source')
    # A cell's `source` may be stored as one string or as a list of line
    # strings; keep whichever representation the cell already uses.
    if isinstance(source, list):
        if MARKER not in ''.join(source):
            continue
        # keepends=True keeps each line's terminator, so ''.join(...) == new_source
        # (and no spurious empty trailing element is produced).
        cell['source'] = new_source.splitlines(keepends=True)
    elif isinstance(source, str):
        if MARKER not in source:
            continue
        cell['source'] = new_source
    else:
        continue
    modifications += 1

if modifications > 0:
    with open(path, 'w', encoding='utf-8') as f:
        # Write in the layout Jupyter itself uses (1-space indent, non-ASCII
        # kept as UTF-8 rather than \uXXXX escapes, trailing newline) so the
        # patch shows up as a small diff instead of a one-line rewrite.
        json.dump(nb, f, indent=1, ensure_ascii=False)
        f.write('\n')
    print('Successfully updated COMPARISON TABLE format.')
else:
    print('Could not find COMPARISON TABLE.')