Spaces:
Sleeping
Sleeping
File size: 5,935 Bytes
0f23dca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
# -*- coding: utf-8 -*-
"""
weights_sweep.py — Auto-tuning legal tag weights via W&B Sweeps
کارکرد اجرایی:
- هر ران: وزنها از config → نوشتن legal_entity_weights.json → اجرای GoldenBuilder روی یک زیرمجموعه کوچک →
محاسبهٔ pass_rate (نرخ قبولی گیت کیفیت) → لاگ متریکها و آرتیفکتها روی W&B.
- در پایان Sweep، از داشبورد W&B «بهترین Run» را انتخاب کنید و وزنها را تثبیت نمایید
(یا از آرتیفکت همان Run دانلود کنید و جایگزین legal_entity_weights.json نمایید).
پیشنیاز:
- فایل golden_builder.py در ریشه ریپو
- Secrets: WANDB_API_KEY در HF Spaces
- requirements: wandb, transformers, torch, ...
پارامترها (قابلتنظیم از طریق env یا UI):
- TUNE_DATA: مسیر فایل JSON/JSONL داده
- TUNE_TEXT_KEY: کلید متن در داده (پیشفرض "متن_کامل")
- TUNE_MAX_SAMPLES: تعداد نمونهٔ کوچک برای هر ران (پیشفرض 120)
- TUNE_BATCH: batch size Builder (پیشفرض 2)
- TUNE_COUNT: تعداد ران در sweep (پیشفرض 16)
- WANDB_PROJECT, WANDB_ENTITY: پروژه/ورکاسپیس W&B
"""
import os
import json
import random
from typing import Dict, List
import wandb
# فضای جستوجو: در صورت نیاز بازهها را سختگیرانهتر/وسیعتر کنید
SWEEP_CONFIG = {
"method": "bayes", # "random" یا "grid" هم قابل استفاده است
"metric": {"name": "pass_rate", "goal": "maximize"},
"parameters": {
"STATUTE": {"min": 0.8, "max": 1.4},
"COURT": {"min": 0.6, "max": 1.2},
"CRIME": {"min": 0.9, "max": 1.6},
"CIVIL": {"min": 0.5, "max": 1.2},
"PROCED": {"min": 0.5, "max": 1.0},
"PARTY": {"min": 0.4, "max": 0.9},
"BUSINESS": {"min": 0.4, "max": 0.9},
}
}
DEFAULT_TEXT_KEY = "متن_کامل"
def write_weights_file(weights: Dict[str, float], path: str = "legal_entity_weights.json"):
with open(path, "w", encoding="utf-8") as f:
json.dump({k: float(v) for k, v in weights.items()}, f, ensure_ascii=False, indent=2)
def sample_data(path: str, text_key: str, max_samples: int) -> List[dict]:
from golden_builder import load_json_or_jsonl
data = load_json_or_jsonl(path)
data = [r for r in data if isinstance(r, dict) and text_key in r and isinstance(r[text_key], str) and len(r[text_key].strip()) > 20]
random.shuffle(data)
return data[:max_samples]
def run_once(data_path: str, text_key: str, max_samples: int, batch_size: int):
"""
یک اجرای واحد Agent: وزنها ← Builder → pass_rate
"""
cfg = wandb.config
weights = {
"STATUTE": cfg.STATUTE,
"COURT": cfg.COURT,
"CRIME": cfg.CRIME,
"CIVIL": cfg.CIVIL,
"PROCED": cfg.PROCED,
"PARTY": cfg.PARTY,
"BUSINESS": cfg.BUSINESS,
}
write_weights_file(weights) # این فایل توسط GoldenBuilder خوانده میشود
from golden_builder import GoldenBuilder, save_jsonl
rows_in = sample_data(data_path, text_key, max_samples=max_samples)
if not rows_in:
wandb.log({"pass_rate": 0.0, "kept": 0, "processed": 0})
wandb.summary.update({"weights": weights})
return
# برای سرعت/پایداری: mt5-base کافی است؛ اگر مدل دیگری میخواهید، پارامتر کنید
gb = GoldenBuilder(model_name="google/mt5-base")
rows_out = gb.build(rows_in, text_key=text_key, batch_size=batch_size)
processed = len(rows_in)
kept = len(rows_out)
pass_rate = kept / max(processed, 1)
# لاگ متریکها + وزنها
wandb.log({
"pass_rate": pass_rate,
"kept": kept,
"processed": processed
})
wandb.summary.update({"weights": weights})
# آرتیفکت خروجی نمونه (اختیاری ولی مفید برای ارزیابی کیفی)
outp = f"/tmp/gb_out_{wandb.run.id}.jsonl"
save_jsonl(rows_out, outp)
art = wandb.Artifact("gb-sample", type="dataset")
art.add_file(outp)
wandb.log_artifact(art)
def run_sweep(
data_path: str,
text_key: str = DEFAULT_TEXT_KEY,
max_samples: int = 120,
batch_size: int = 2,
project: str = "mahoon-legal-ai",
entity: str = None,
count: int = 16
):
os.environ.setdefault("WANDB_PROJECT", project)
if entity: os.environ.setdefault("WANDB_ENTITY", entity)
# ایجاد Sweep
sweep_id = wandb.sweep(SWEEP_CONFIG, project=os.getenv("WANDB_PROJECT", project), entity=os.getenv("WANDB_ENTITY", entity))
def _agent():
wandb.init(project=os.getenv("WANDB_PROJECT", project),
entity=os.getenv("WANDB_ENTITY", entity),
name="weights-tune")
run_once(data_path=data_path, text_key=text_key, max_samples=max_samples, batch_size=batch_size)
# اجرای تعداد مشخصی Agent-run
wandb.agent(sweep_id, function=_agent, count=count)
if __name__ == "__main__":
# اجرای خط فرمان/محلی:
# export WANDB_API_KEY=<توکن واقعی>
# python weights_sweep.py
data = os.getenv("TUNE_DATA", "./sample.jsonl")
text_key = os.getenv("TUNE_TEXT_KEY", DEFAULT_TEXT_KEY)
max_samples = int(os.getenv("TUNE_MAX_SAMPLES", "120"))
count = int(os.getenv("TUNE_COUNT", "16"))
batch_size = int(os.getenv("TUNE_BATCH", "2"))
project = os.getenv("WANDB_PROJECT", "mahoon-legal-ai")
entity = os.getenv("WANDB_ENTITY", None)
run_sweep(
data_path=data,
text_key=text_key,
max_samples=max_samples,
batch_size=batch_size,
project=project,
entity=entity,
count=count
)
|