File size: 5,935 Bytes
0f23dca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# -*- coding: utf-8 -*-
"""
weights_sweep.py — Auto-tuning legal tag weights via W&B Sweeps

کارکرد اجرایی:
- هر ران: وزن‌ها از config → نوشتن legal_entity_weights.json → اجرای GoldenBuilder روی یک زیرمجموعه کوچک →
  محاسبهٔ pass_rate (نرخ قبولی گیت کیفیت) → لاگ متریک‌ها و آرتیفکت‌ها روی W&B.
- در پایان Sweep، از داشبورد W&B «بهترین Run» را انتخاب کنید و وزن‌ها را تثبیت نمایید
  (یا از آرتیفکت همان Run دانلود کنید و جایگزین legal_entity_weights.json نمایید).

پیش‌نیاز:
- فایل golden_builder.py در ریشه ریپو
- Secrets: WANDB_API_KEY در HF Spaces
- requirements: wandb, transformers, torch, ...

پارامترها (قابل‌تنظیم از طریق env یا UI):
- TUNE_DATA: مسیر فایل JSON/JSONL داده
- TUNE_TEXT_KEY: کلید متن در داده (پیش‌فرض "متن_کامل")
- TUNE_MAX_SAMPLES: تعداد نمونهٔ کوچک برای هر ران (پیش‌فرض 120)
- TUNE_BATCH: batch size Builder (پیش‌فرض 2)
- TUNE_COUNT: تعداد ران در sweep (پیش‌فرض 16)
- WANDB_PROJECT, WANDB_ENTITY: پروژه/ورک‌اسپیس W&B
"""

import os
import json
import random
from typing import Dict, List

import wandb

# فضای جست‌وجو: در صورت نیاز بازه‌ها را سخت‌گیرانه‌تر/وسیع‌تر کنید
SWEEP_CONFIG = {
    "method": "bayes",  # "random" یا "grid" هم قابل استفاده است
    "metric": {"name": "pass_rate", "goal": "maximize"},
    "parameters": {
        "STATUTE":  {"min": 0.8, "max": 1.4},
        "COURT":    {"min": 0.6, "max": 1.2},
        "CRIME":    {"min": 0.9, "max": 1.6},
        "CIVIL":    {"min": 0.5, "max": 1.2},
        "PROCED":   {"min": 0.5, "max": 1.0},
        "PARTY":    {"min": 0.4, "max": 0.9},
        "BUSINESS": {"min": 0.4, "max": 0.9},
    }
}

DEFAULT_TEXT_KEY = "متن_کامل"

def write_weights_file(weights: Dict[str, float], path: str = "legal_entity_weights.json"):
    with open(path, "w", encoding="utf-8") as f:
        json.dump({k: float(v) for k, v in weights.items()}, f, ensure_ascii=False, indent=2)

def sample_data(path: str, text_key: str, max_samples: int) -> List[dict]:
    from golden_builder import load_json_or_jsonl
    data = load_json_or_jsonl(path)
    data = [r for r in data if isinstance(r, dict) and text_key in r and isinstance(r[text_key], str) and len(r[text_key].strip()) > 20]
    random.shuffle(data)
    return data[:max_samples]

def run_once(data_path: str, text_key: str, max_samples: int, batch_size: int):
    """
    یک اجرای واحد Agent: وزن‌ها ← Builder → pass_rate
    """
    cfg = wandb.config
    weights = {
        "STATUTE": cfg.STATUTE,
        "COURT": cfg.COURT,
        "CRIME": cfg.CRIME,
        "CIVIL": cfg.CIVIL,
        "PROCED": cfg.PROCED,
        "PARTY": cfg.PARTY,
        "BUSINESS": cfg.BUSINESS,
    }
    write_weights_file(weights)  # این فایل توسط GoldenBuilder خوانده می‌شود

    from golden_builder import GoldenBuilder, save_jsonl
    rows_in = sample_data(data_path, text_key, max_samples=max_samples)

    if not rows_in:
        wandb.log({"pass_rate": 0.0, "kept": 0, "processed": 0})
        wandb.summary.update({"weights": weights})
        return

    # برای سرعت/پایداری: mt5-base کافی است؛ اگر مدل دیگری می‌خواهید، پارامتر کنید
    gb = GoldenBuilder(model_name="google/mt5-base")
    rows_out = gb.build(rows_in, text_key=text_key, batch_size=batch_size)

    processed = len(rows_in)
    kept = len(rows_out)
    pass_rate = kept / max(processed, 1)

    # لاگ متریک‌ها + وزن‌ها
    wandb.log({
        "pass_rate": pass_rate,
        "kept": kept,
        "processed": processed
    })
    wandb.summary.update({"weights": weights})

    # آرتیفکت خروجی نمونه (اختیاری ولی مفید برای ارزیابی کیفی)
    outp = f"/tmp/gb_out_{wandb.run.id}.jsonl"
    save_jsonl(rows_out, outp)
    art = wandb.Artifact("gb-sample", type="dataset")
    art.add_file(outp)
    wandb.log_artifact(art)

def run_sweep(
    data_path: str,
    text_key: str = DEFAULT_TEXT_KEY,
    max_samples: int = 120,
    batch_size: int = 2,
    project: str = "mahoon-legal-ai",
    entity: str = None,
    count: int = 16
):
    os.environ.setdefault("WANDB_PROJECT", project)
    if entity: os.environ.setdefault("WANDB_ENTITY", entity)

    # ایجاد Sweep
    sweep_id = wandb.sweep(SWEEP_CONFIG, project=os.getenv("WANDB_PROJECT", project), entity=os.getenv("WANDB_ENTITY", entity))

    def _agent():
        wandb.init(project=os.getenv("WANDB_PROJECT", project),
                   entity=os.getenv("WANDB_ENTITY", entity),
                   name="weights-tune")
        run_once(data_path=data_path, text_key=text_key, max_samples=max_samples, batch_size=batch_size)

    # اجرای تعداد مشخصی Agent-run
    wandb.agent(sweep_id, function=_agent, count=count)

if __name__ == "__main__":
    # اجرای خط فرمان/محلی:
    # export WANDB_API_KEY=<توکن واقعی>
    # python weights_sweep.py
    data = os.getenv("TUNE_DATA", "./sample.jsonl")
    text_key = os.getenv("TUNE_TEXT_KEY", DEFAULT_TEXT_KEY)
    max_samples = int(os.getenv("TUNE_MAX_SAMPLES", "120"))
    count = int(os.getenv("TUNE_COUNT", "16"))
    batch_size = int(os.getenv("TUNE_BATCH", "2"))
    project = os.getenv("WANDB_PROJECT", "mahoon-legal-ai")
    entity = os.getenv("WANDB_ENTITY", None)

    run_sweep(
        data_path=data,
        text_key=text_key,
        max_samples=max_samples,
        batch_size=batch_size,
        project=project,
        entity=entity,
        count=count
    )