# -*- coding: utf-8 -*- """ Mahoon Legal AI — Causal-only Generation + Hybrid RAG + W&B Training + Weight Tuning - پاسخ‌زایی: Qwen2.5-7B, Llama-3.1-8B, Mistral-7B (همه causal) - RAG: Chroma + BM25 + CrossEncoder reranker (gte-multilingual-reranker-base) - Dataset Ops: Builder (از golden_builder) + Cleaner/Deduper - Training: SFT/LoRA سبک روی causal + W&B logging/Artifacts - Tuning: Weight Tuning با W&B Sweep (weights_sweep.py) - UI: Gradio 5.47.0 نکته: در Settings → Secrets مقدار `WANDB_API_KEY` را ست کنید (مقدار واقعی؛ placeholder 🟡 نگذارید). """ from __future__ import annotations import os, sys, re, json, time, pickle, zipfile, warnings from dataclasses import dataclass, field from pathlib import Path from typing import List, Dict, Optional import numpy as np import torch from torch.utils.data import Dataset from sklearn.model_selection import train_test_split import gradio as gr warnings.filterwarnings("ignore") # ====== ML & NLP ====== import transformers as tf from transformers import ( AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, EarlyStoppingCallback ) # RAG stack import chromadb from rank_bm25 import BM25Okapi from sentence_transformers import CrossEncoder, SentenceTransformer, util as st_util # ========= Persian text normalization ========= ZWNJ = "\u200c" AR_DIGITS = "٠١٢٣٤٥٦٧٨٩" FA_DIGITS = "۰۱۲۳۴۵۶۷۸۹" EN_DIGITS = "0123456789" def normalize_fa(s: str) -> str: if not s: return s s = s.replace("\u064A", "ی").replace("\u0643", "ک") # ي/ك → ی/ک s = re.sub(r"[\u064B-\u065F\u0610-\u061A]", "", s) # حذف اعراب trans = {ord(a): e for a, e in zip(AR_DIGITS + FA_DIGITS, EN_DIGITS * 2)} s = s.translate(trans) s = re.sub(r"\s*‌\s*", ZWNJ, s) # ZWNJ s = re.sub(r"\s+", " ", s).strip() return s # ========================== # Configs # ========================== @dataclass class ModelConfig: model_name: str = "Qwen/Qwen2.5-7B-Instruct" max_input_length: int = 4096 max_new_tokens: int = 512 temperature: float = 0.7 top_p: float = 0.9 do_sample: bool = 
True gradient_checkpointing: bool = True @dataclass class RAGConfig: persist_dir: str = "./chroma_db" collection: str = "legal_articles" top_k: int = 8 similarity_threshold: float = 0.60 context_char_limit: int = 280 enable: bool = True reranker_name: str = "Alibaba-NLP/gte-multilingual-reranker-base" @dataclass class TrainConfig: base_model: str = "PartAI/Dorna-Llama3-8B-Instruct" alt_model_1: str = "zpm/Llama-3.1-PersianQA" hakim_model: str = "AI-Hoosh/HAKIM-7B" hooshvareh_model: str = "HooshvareLab/llama-fa-7b-instruct" output_dir: str = "./mahoon_causal_lora" seed: int = 42 test_size: float = 0.1 epochs: int = 2 batch_size: int = 2 grad_accum: int = 4 lr: float = 2e-4 warmup_ratio: float = 0.03 weight_decay: float = 0.0 logging_steps: int = 50 eval_strategy: str = "epoch" save_strategy: str = "epoch" save_total_limit: int = 2 report_to: str = "wandb" # W&B max_grad_norm: float = 1.0 use_4bit: bool = True # QLoRA 4-bit (در صورت افزودن PEFT/TRL) max_seq_len: int = 2048 @dataclass class SystemConfig: model: ModelConfig = field(default_factory=ModelConfig) rag: RAGConfig = field(default_factory=RAGConfig) train: TrainConfig = field(default_factory=TrainConfig) # ========================== # Helpers # ========================== def set_seed_all(seed: int = 42): import random random.seed(seed); np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def bf16_supported(): return torch.cuda.is_available() and getattr(torch.cuda, "is_bf16_supported", lambda: False)() def log_deps(): try: import accelerate, datasets print("[deps]", f"python={sys.version.split()[0]}", f"transformers={tf.__version__}", f"accelerate={accelerate.__version__}", f"datasets={datasets.__version__}", f"gradio={gr.__version__}", flush=True) except Exception as e: print("[deps] warn:", e, flush=True) # ========================== # RAG: Chroma + BM25 + CrossEncoder reranker # ========================== class LegalRAG: def __init__(self, cfg: 
RAGConfig): self.cfg = cfg self.client = None self.collection = None self.reranker: Optional[CrossEncoder] = None self.bm25 = None self.bm25_ids: List[str] = [] self.bm25_path = str(Path(self.cfg.persist_dir) / "bm25.pkl") def init(self): Path(self.cfg.persist_dir).mkdir(parents=True, exist_ok=True) self.client = chromadb.PersistentClient(path=self.cfg.persist_dir) try: self.collection = self.client.get_or_create_collection(self.cfg.collection) except Exception: try: self.collection = self.client.get_collection(self.cfg.collection) except Exception: self.collection = self.client.create_collection(self.cfg.collection) # reranker try: dev = "cuda" if torch.cuda.is_available() else "cpu" self.reranker = CrossEncoder(self.cfg.reranker_name, device=dev) except Exception: self.reranker = None # BM25 if Path(self.bm25_path).exists(): with open(self.bm25_path, "rb") as f: obj = pickle.load(f) self.bm25 = obj["bm25"]; self.bm25_ids = obj["ids"] def _rebuild_bm25(self, ids: List[str], docs: List[str]): corpus = [normalize_fa(d).split() for d in docs] self.bm25 = BM25Okapi(corpus) self.bm25_ids = ids with open(self.bm25_path, "wb") as f: pickle.dump({"bm25": self.bm25, "ids": self.bm25_ids}, f) def index_jsonl(self, jsonl_path: str, id_key="article_id", text_key="text"): if not self.collection: self.init() ids, docs, metas = [], [], [] with open(jsonl_path, "r", encoding="utf-8") as f: for i, line in enumerate(f): s = line.strip() if not s: continue try: obj = json.loads(s) except: continue aid = str(obj.get(id_key, f"auto_{i}")) txt = normalize_fa(str(obj.get(text_key, "")).strip()) if not txt: continue ids.append(aid); docs.append(txt); metas.append({"article_id": aid}) if not ids: return "هیچ سندی برای ایندکس یافت نشد." self.collection.upsert(ids=ids, documents=docs, metadatas=metas) self._rebuild_bm25(ids, docs) return f"✅ {len(ids)} سند ایندکس شد (Dense+BM25)." 
def retrieve(self, query: str) -> List[Dict]: if not self.collection: return [] qn = normalize_fa(query) # Dense via Chroma try: res = self.collection.query( query_texts=[qn], n_results=max(self.cfg.top_k * 3, 20), include=["documents", "metadatas", "distances"], ) out = [] docs = res.get("documents", [[]])[0] metas = res.get("metadatas", [[]])[0] dists = res.get("distances", [[1.0]])[0] for i, (doc, meta, dist) in enumerate(zip(docs, metas, dists)): sim = 1.0 - float(dist) out.append({"article_id": (meta or {}).get("article_id", f"unk_{i}"), "text": doc, "similarity": sim}) except Exception: out = [] # BM25 bm25_hits = [] if self.bm25 is not None and self.bm25_ids: scores = self.bm25.get_scores(normalize_fa(qn).split()) idxs = np.argsort(scores)[::-1][:max(self.cfg.top_k * 3, 20)] smax = float(scores.max() + 1e-8) for j in idxs: aid = self.bm25_ids[int(j)] try: got = self.collection.get(ids=[aid]) tdoc = got["documents"][0] except Exception: tdoc = "" bm25_hits.append({"article_id": aid, "text": tdoc, "similarity": float(scores[j]) / smax}) # union by id pool: Dict[str, Dict] = {} for a in out + bm25_hits: if a["article_id"] not in pool or a.get("similarity", 0) > pool[a["article_id"]].get("similarity", 0): pool[a["article_id"]] = a merged = [a for a in pool.values() if a.get("text") and len(a["text"]) > 15] # threshold merged = [a for a in merged if a.get("similarity", 0) >= self.cfg.similarity_threshold] # rerank if self.reranker and merged: pairs = [(qn, a["text"]) for a in merged] scores = self.reranker.predict(pairs) for a, s in zip(merged, scores): a["score"] = float(s) merged = sorted(merged, key=lambda x: x.get("score", 0), reverse=True)[: self.cfg.top_k] else: merged = sorted(merged, key=lambda x: x.get("similarity", 0), reverse=True)[: self.cfg.top_k] return merged def build_context(self, arts: List[Dict]) -> str: if not arts: return "" bullets = [f"• ماده {a['article_id']}: {a['text'][:self.cfg.context_char_limit]}..." 
for a in arts] return "مواد مرتبط:\n" + "\n".join(bullets) # ========= RAG bootstrap from repo ========= def parse_law_textfile_to_jsonl(txt_path: str, out_jsonl: str): pat = re.compile(r"(?:ماده|مادّه)\s+(\d+)\s*[:\-–]\s*(.+)") rows = [] with open(txt_path, "r", encoding="utf-8") as f: for line in f: s = line.strip() if not s: continue m = pat.match(s) if not m: continue aid = m.group(1) body = m.group(2).strip() if len(body) < 12: continue rows.append({"article_id": aid, "text": normalize_fa(body)}) if not rows: raise RuntimeError("هیچ ماده‌ای با الگوی تعریف‌شده پیدا نشد.") with open(out_jsonl, "w", encoding="utf-8") as g: for r in rows: g.write(json.dumps(r, ensure_ascii=False) + "\n") return len(rows) def ensure_chroma_ready(persist_dir="./chroma_db", collection="legal_articles") -> str: Path(persist_dir).mkdir(parents=True, exist_ok=True) if any(Path(persist_dir).glob("*")): return f"ChromaDB موجود است." zip_path = Path("./chroma_legal_db.zip") if zip_path.exists(): try: with zipfile.ZipFile(zip_path, "r") as z: z.extractall(persist_dir) return "ChromaDB از zip بازیابی شد." except Exception: pass txt_path = Path("./all_legal_sentences.txt") if txt_path.exists(): n = parse_law_textfile_to_jsonl(str(txt_path), "./laws.jsonl") rag_local = LegalRAG(RAGConfig(persist_dir=persist_dir, collection=collection)) rag_local.init() msg = rag_local.index_jsonl("./laws.jsonl", id_key="article_id", text_key="text") return f"از متن خام {n} رکورد استخراج شد. {msg}" return "پایگاه RAG موجود نیست و منبع خامی هم برای ساخت پیدا نشد." 
# ==========================
# Loader + Generator (Causal-only)
# ==========================
_DEFAULT_SYSTEM_PROMPT = "You are a helpful Persian legal assistant."
_INFER_CONTEXT_HEADER = "از منابع زیر استفاده کن و استنادی پاسخ بده:"
_TRAIN_CONTEXT_HEADER = "از منابع زیر استفاده کن:"


def _build_prompt(tokenizer, question: str, context: str = "", system_prompt: str = "",
                  context_header: str = _INFER_CONTEXT_HEADER) -> tuple[str, bool]:
    """Render a prompt that ends exactly where the assistant's reply should start.

    Uses the tokenizer's chat template when it has one (Qwen/Llama/Mistral all do);
    the system prompt and retrieved context are merged into one system message, and
    folded into the user turn for templates that reject a system role. Without a
    template, falls back to plain ``<|system|>/<|user|>/<|assistant|>`` tags.

    Returns ``(text, templated)``. Templated text already contains the model's special
    tokens, so it must be tokenized with ``add_special_tokens=False`` (no double BOS).
    """
    system_parts = []
    if system_prompt:
        system_parts.append(system_prompt)
    if context:
        system_parts.append(f"{context_header}\n{context}")

    if getattr(tokenizer, "chat_template", None):
        attempts = []
        if system_parts:
            attempts.append([
                {"role": "system", "content": "\n\n".join(system_parts)},
                {"role": "user", "content": question},
            ])
            attempts.append([{"role": "user", "content": "\n\n".join(system_parts + [question])}])
        else:
            attempts.append([{"role": "user", "content": question}])
        for messages in attempts:
            try:
                text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                return text, True
            except Exception:
                continue  # e.g. a template that only allows alternating user/assistant

    parts = [f"<|system|>\n{p}" for p in system_parts]
    parts.append(f"<|user|>\n{question}")
    return "\n".join(parts) + "\n<|assistant|>\n", False


class CausalLoader:
    """Load a Hugging Face causal LM and its tokenizer with GPU-friendly defaults."""

    def __init__(self, mcfg: ModelConfig):
        self.cfg = mcfg
        self.tokenizer = None
        self.model = None

    def load(self, model_name: str):
        """Load ``model_name`` and return ``self`` so calls can be chained.

        On CUDA the weights use ``device_map="auto"`` in bf16 when supported,
        otherwise fp16; on CPU the library defaults apply.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        # Many causal tokenizers ship without a pad token; reuse EOS so batching works.
        if self.tokenizer.pad_token is None and self.tokenizer.eos_token is not None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        kwargs = {}
        if torch.cuda.is_available():
            kwargs["device_map"] = "auto"
            kwargs["torch_dtype"] = torch.bfloat16 if bf16_supported() else torch.float16
        self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
        if self.cfg.gradient_checkpointing and hasattr(self.model, "gradient_checkpointing_enable"):
            try:
                self.model.gradient_checkpointing_enable()
            except Exception as e:  # best effort: not every architecture supports it
                print("[loader] gradient checkpointing not enabled:", e, flush=True)
        return self


class Generator:
    """Answer questions with a loaded causal LM, optionally grounded on RAG context."""

    def __init__(self, loader: CausalLoader, mcfg: ModelConfig):
        self.tk = loader.tokenizer
        self.model = loader.model
        self.cfg = mcfg  # shared object: UI changes to sampling params apply immediately

    def generate(self, question: str, context: str = "",
                 system_prompt: str = "You are a helpful Persian legal assistant.") -> str:
        """Generate an answer and return only the newly generated text.

        Prompts longer than ``max_input_length`` tokens keep their tail, which holds
        the question and the assistant cue. A temperature of 0 (or
        ``do_sample=False``) means greedy decoding.
        """
        prompt, templated = _build_prompt(self.tk, question, context, system_prompt)
        enc = self.tk(prompt, return_tensors="pt", add_special_tokens=not templated)
        max_in = int(self.cfg.max_input_length)
        enc = {k: v[:, -max_in:] for k, v in enc.items()}
        enc = {k: v.to(self.model.device) for k, v in enc.items()}
        prompt_len = enc["input_ids"].shape[1]

        pad_id = self.tk.pad_token_id if self.tk.pad_token_id is not None else self.tk.eos_token_id
        gen_kwargs = {"max_new_tokens": int(self.cfg.max_new_tokens), "pad_token_id": pad_id}
        if self.cfg.do_sample and float(self.cfg.temperature) > 0:
            gen_kwargs.update(do_sample=True, temperature=float(self.cfg.temperature),
                              top_p=float(self.cfg.top_p))
        else:
            # Sampling with temperature 0 is rejected by transformers; use greedy instead.
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            out = self.model.generate(**enc, **gen_kwargs)
        # generate() returns prompt + continuation; decode only the continuation.
        return self.tk.decode(out[0][prompt_len:], skip_special_tokens=True).strip()


# ==========================
# Datasets & Trainer (Causal-only, W&B)
# ==========================
def read_jsonl_files(paths: List[str]) -> List[Dict]:
    """Read JSON objects from several JSONL files, skipping blank/invalid/non-object lines."""
    data: List[Dict] = []
    for p in paths:
        if not p:
            continue
        with open(p, "r", encoding="utf-8") as f:
            for line in f:
                s = line.strip()
                if not s:
                    continue
                try:
                    obj = json.loads(s)
                except json.JSONDecodeError:
                    continue
                if isinstance(obj, dict):
                    data.append(obj)
    return data


class CausalJSONLDataset(Dataset):
    """Pre-tokenized SFT examples built from ``{"input", "output"}`` records.

    Each example is ``prompt + answer + EOS`` with the prompt rendered by the same
    helper used at inference. Prompt tokens get label ``-100``, so the loss covers
    only the answer. Every ``enhance_every``-th record is enriched with retrieved
    context when ``rag`` is given. Examples are unpadded; batches are padded by
    ``_CausalPadCollator``.
    """

    def __init__(self, data: List[Dict], tokenizer, max_len: int,
                 rag: Optional[LegalRAG] = None, enhance_every: int = 8):
        self.tk = tokenizer
        self.max_len = max_len
        self.items: List[Dict[str, List[int]]] = []
        every = max(1, int(enhance_every))  # guard against modulo-by-zero
        eos_id = tokenizer.eos_token_id
        for i, ex in enumerate(data):
            src = normalize_fa(str(ex.get("input", "")).strip())
            tgt = normalize_fa(str(ex.get("output", "")).strip())
            if not src or not tgt:
                continue
            ctx = ""
            if rag is not None and i % every == 0:
                ctx = rag.build_context(rag.retrieve(src))
            prompt, templated = _build_prompt(tokenizer, src, ctx, _DEFAULT_SYSTEM_PROMPT, _TRAIN_CONTEXT_HEADER)
            prompt_ids = tokenizer(prompt, add_special_tokens=not templated)["input_ids"]
            if len(prompt_ids) >= max_len:
                # No room for answer tokens; an all-masked example yields a NaN loss.
                continue
            answer_ids = tokenizer(tgt, add_special_tokens=False)["input_ids"]
            if eos_id is not None:
                answer_ids = answer_ids + [eos_id]  # teach the model where to stop
            ids = (prompt_ids + answer_ids)[:max_len]
            labels = [-100] * len(prompt_ids) + ids[len(prompt_ids):]
            self.items.append({"input_ids": ids, "attention_mask": [1] * len(ids), "labels": labels})

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return {k: torch.tensor(v, dtype=torch.long) for k, v in self.items[idx].items()}


class _CausalPadCollator:
    """Right-pad a batch to its longest example (labels padded with -100)."""

    def __init__(self, pad_token_id: int):
        self.pad_id = pad_token_id

    def __call__(self, features):
        longest = max(int(f["input_ids"].shape[0]) for f in features)

        def pad(t, value):
            return torch.cat([t, t.new_full((longest - t.shape[0],), value)])

        return {
            "input_ids": torch.stack([pad(f["input_ids"], self.pad_id) for f in features]),
            "attention_mask": torch.stack([pad(f["attention_mask"], 0) for f in features]),
            "labels": torch.stack([pad(f["labels"], -100) for f in features]),
        }


def safe_training_args(**kwargs):
    """Build ``TrainingArguments``, adapting keywords renamed across transformers versions.

    ``evaluation_strategy`` became ``eval_strategy`` (and the old name was later
    removed); whichever spelling the caller used is mapped to what is installed.
    """
    import inspect

    accepted = set(inspect.signature(TrainingArguments).parameters)
    for old, new in (("evaluation_strategy", "eval_strategy"),):
        if old in kwargs and new in accepted:
            kwargs[new] = kwargs.pop(old)
        elif new in kwargs and new not in accepted and old in accepted:
            kwargs[old] = kwargs.pop(new)
    return TrainingArguments(**kwargs)


class TrainerManager:
    """Supervised fine-tuning of the loaded causal LM with the HF Trainer (+ optional W&B)."""

    def __init__(self, syscfg: SystemConfig, loader: CausalLoader):
        self.cfg = syscfg
        self.loader = loader

    def train_causal(self, train_paths: List[str], use_rag: bool = True, use_wandb: bool = True,
                     wandb_project: str = "mahoon-legal-ai", wandb_entity: str = "",
                     run_name: str = "mahoon_causal_lora"):
        """Fine-tune on JSONL ``{input, output}`` files and save to ``train.output_dir``.

        If W&B cannot be initialised (missing package/API key) training continues
        without it. Raises ValueError when the data cannot form a training split.
        Returns the output directory.
        """
        import inspect

        tcfg = self.cfg.train
        model, tok = self.loader.model, self.loader.tokenizer
        set_seed_all(tcfg.seed)

        data = read_jsonl_files(train_paths)
        if len(data) < 2:
            raise ValueError("حداقل دو رکورد معتبر JSONL برای تقسیم آموزش/اعتبارسنجی لازم است.")
        train, val = train_test_split(data, test_size=tcfg.test_size, random_state=tcfg.seed)

        rag = LegalRAG(self.cfg.rag) if (use_rag and self.cfg.rag.enable) else None
        if rag is not None:
            rag.init()
        ds_tr = CausalJSONLDataset(train, tok, tcfg.max_seq_len, rag)
        ds_va = CausalJSONLDataset(val, tok, tcfg.max_seq_len, None)
        if len(ds_tr) == 0:
            raise ValueError("هیچ نمونهٔ آموزشی معتبری با کلیدهای input/output ساخته نشد.")
        has_eval = len(ds_va) > 0

        bf16_ok = bf16_supported()
        # fp16 AMP needs fp32 master weights: on fp16-loaded weights the GradScaler
        # fails with "Attempting to unscale FP16 gradients".
        params_fp32 = next(model.parameters()).dtype == torch.float32
        fp16_ok = torch.cuda.is_available() and not bf16_ok and params_fp32

        # ---------- W&B ----------
        wb = None
        if use_wandb:
            # Assign (not setdefault) so values typed in the UI win over earlier runs.
            os.environ["WANDB_PROJECT"] = wandb_project or "mahoon-legal-ai"
            if wandb_entity:
                os.environ["WANDB_ENTITY"] = wandb_entity
            os.environ.pop("WANDB_DISABLED", None)
            try:
                import wandb

                wandb.init(
                    project=os.environ["WANDB_PROJECT"],
                    entity=os.getenv("WANDB_ENTITY") or None,
                    name=run_name,
                    config={
                        "base_model": getattr(model, "name_or_path", ""),
                        "epochs": tcfg.epochs,
                        "batch": tcfg.batch_size,
                        "grad_accum": tcfg.grad_accum,
                        "lr": tcfg.lr,
                        "max_seq_len": tcfg.max_seq_len,
                        "use_rag": use_rag,
                    },
                )
                wb = wandb
            except Exception as e:
                print("[wandb] init failed, training without W&B:", e, flush=True)
        if wb is None:
            os.environ["WANDB_DISABLED"] = "true"

        args = safe_training_args(
            output_dir=tcfg.output_dir,
            num_train_epochs=tcfg.epochs,
            learning_rate=tcfg.lr,
            per_device_train_batch_size=tcfg.batch_size,
            per_device_eval_batch_size=tcfg.batch_size,
            gradient_accumulation_steps=tcfg.grad_accum,
            warmup_ratio=tcfg.warmup_ratio,
            weight_decay=tcfg.weight_decay,
            evaluation_strategy=tcfg.eval_strategy if has_eval else "no",
            save_strategy=tcfg.save_strategy,
            save_total_limit=tcfg.save_total_limit,
            load_best_model_at_end=has_eval,
            metric_for_best_model="eval_loss" if has_eval else None,
            logging_steps=tcfg.logging_steps,
            # report_to already installs WandbCallback; adding another logs twice.
            report_to=["wandb"] if wb is not None else ["none"],
            run_name=run_name,
            fp16=fp16_ok,
            bf16=bf16_ok,
            max_grad_norm=tcfg.max_grad_norm,
        )
        callbacks = [EarlyStoppingCallback(early_stopping_patience=2)] if has_eval else []

        pad_id = tok.pad_token_id if tok.pad_token_id is not None else (tok.eos_token_id or 0)
        trainer_kwargs = dict(
            model=model,
            args=args,
            train_dataset=ds_tr,
            eval_dataset=ds_va if has_eval else None,
            data_collator=_CausalPadCollator(pad_id),
            callbacks=callbacks,
        )
        # Newer transformers renamed Trainer(tokenizer=...) to processing_class=...
        if "processing_class" in inspect.signature(Trainer.__init__).parameters:
            trainer_kwargs["processing_class"] = tok
        else:
            trainer_kwargs["tokenizer"] = tok
        trainer = Trainer(**trainer_kwargs)

        try:
            trainer.train()
            trainer.save_model(tcfg.output_dir)
            tok.save_pretrained(tcfg.output_dir)
            if wb is not None:
                try:
                    art = wb.Artifact("mahoon-model", type="model")
                    art.add_dir(tcfg.output_dir)
                    wb.log_artifact(art)
                except Exception as e:
                    print("[wandb] artifact upload failed:", e, flush=True)
        finally:
            if wb is not None:  # close the run even when training fails
                try:
                    wb.finish()
                except Exception:
                    pass
        return tcfg.output_dir


# ==========================
# Dataset utilities (Cleaner/Deduper)
# ==========================
def deduplicate_jsonl(in_path: str, out_path: str, sim_threshold: float = 0.90,
                      text_keys=("input", "output")) -> int:
    """Normalize ``text_keys`` and drop near-duplicate records, keeping first occurrences.

    Records are compared by the cosine similarity of their ``input`` embeddings
    (multilingual MiniLM); each kept record removes every later record at or above
    ``sim_threshold``. Writes the survivors to ``out_path`` and returns their count.
    Raises RuntimeError when the input has no valid JSON objects.
    """
    rows = []
    with open(in_path, "r", encoding="utf-8") as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            try:
                obj = json.loads(s)
            except json.JSONDecodeError:
                continue
            if not isinstance(obj, dict):
                continue
            for k in text_keys:
                if k in obj:
                    obj[k] = normalize_fa(str(obj[k]))
            rows.append(obj)
    if not rows:
        raise RuntimeError("هیچ رکورد معتبری در ورودی نبود.")

    model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
    embs = model.encode([str(r.get("input", "")) for r in rows], convert_to_tensor=True,
                        show_progress_bar=False, normalize_embeddings=True)
    removed = torch.zeros(len(rows), dtype=torch.bool, device=embs.device)
    kept = []
    for i, row in enumerate(rows):
        if removed[i]:
            continue
        kept.append(row)
        removed |= st_util.cos_sim(embs[i], embs)[0] >= sim_threshold

    with open(out_path, "w", encoding="utf-8") as g:
        for r in kept:
            g.write(json.dumps(r, ensure_ascii=False) + "\n")
    return len(kept)


# ==========================
# App (Gradio)
# ==========================
def _upload_path(f) -> Optional[str]:
    """Return the local path behind a Gradio upload value, or ``None``.

    Gradio 4/5 file components pass plain path strings by default
    (``type="filepath"``); older versions passed tempfile wrappers with ``.name``,
    and some objects expose ``.path``. All three are accepted.
    """
    if f is None:
        return None
    if isinstance(f, (str, os.PathLike)):
        return os.fspath(f) or None
    return getattr(f, "name", None) or getattr(f, "path", None)


class LegalApp:
    """Glue between the Gradio UI and the loader/generator, RAG, training and dataset tools."""

    def __init__(self, scfg: Optional[SystemConfig] = None):
        self.scfg = scfg or SystemConfig()
        self.rag = LegalRAG(self.scfg.rag)
        self.loader: Optional[CausalLoader] = None
        self.gen: Optional[Generator] = None

    def _file_paths(self, files: List[gr.File]) -> List[str]:
        """Paths of the uploaded files (accepts a single value or a list)."""
        if isinstance(files, (str, os.PathLike)):
            files = [files]
        return [p for p in (_upload_path(f) for f in (files or [])) if p]

    def _release_model(self):
        """Drop references to the current model and free cached GPU memory."""
        import gc

        self.gen = None
        self.loader = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Core
    def load(self, model_name: str):
        """Load the answer model (replacing any previous one) and (re)open RAG."""
        # Free the previous model first: two 7B models rarely fit on one GPU.
        self._release_model()
        try:
            loader = CausalLoader(self.scfg.model).load(model_name)
        except Exception as e:
            return f"❌ خطا در بارگذاری مدل: {e}"
        self.loader = loader
        self.gen = Generator(loader, self.scfg.model)
        msg_rag = "RAG غیرفعال"
        if self.scfg.rag.enable:
            try:
                self.rag = LegalRAG(self.scfg.rag)
                self.rag.init()
                msg_rag = "RAG آماده است"
            except Exception as e:
                msg_rag = f"RAG خطا: {e}"
        return f"مدل بارگذاری شد: {model_name}\n{msg_rag}"

    def build_index(self, laws_file: gr.File, id_key: str, text_key: str):
        """Index an uploaded laws JSONL file; returns a status message."""
        if not self.scfg.rag.enable:
            return "RAG غیرفعال است."
        try:
            self.rag.init()
            p = _upload_path(laws_file)
            if not p:
                return "فایل قوانین معتبر نیست."
            return self.rag.index_jsonl(p, id_key=id_key, text_key=text_key)
        except Exception as e:
            return f"خطا در ایندکس: {e}"

    def answer(self, question: str, system_prompt: str, use_rag: bool,
               max_new_tokens: int, temperature: float, top_p: float):
        """Answer a question; returns ``(answer_markdown, references_markdown)``."""
        if not (question or "").strip():
            return "لطفاً سوال خود را وارد کنید.", ""
        if self.gen is None:
            return "ابتدا مدل را بارگذاری کنید.", ""
        self.scfg.model.max_new_tokens = int(max_new_tokens)
        self.scfg.model.temperature = float(temperature)
        self.scfg.model.top_p = float(top_p)
        want_rag = use_rag and self.scfg.rag.enable and self.rag.collection is not None
        arts = self.rag.retrieve(question) if want_rag else []
        ctx = self.rag.build_context(arts) if arts else ""
        ans = self.gen.generate(question, ctx, system_prompt)
        refs = ""
        if arts:
            refs = "\n\n" + "\n".join(
                f"**ماده {a['article_id']}** (شباهت: {a['similarity']:.2f})\n{a['text'][:380]}..."
                for a in arts
            )
        return ans, refs

    def train(self, model_name: str, files: List[gr.File], use_rag: bool, epochs: int,
              batch: int, lr: float, use_wandb: bool, wandb_project: str, wandb_entity: str,
              run_name: str, progress=gr.Progress(track_tqdm=True)):
        """Fine-tune ``model_name`` on uploaded JSONL files; returns a status message.

        The chat model (if any) is unloaded first to free GPU memory; load it again
        from the consultation tab after training.
        """
        # Validate inputs before spending minutes loading a 7B model.
        paths = self._file_paths(files)
        if not paths:
            return "⚠️ هیچ فایل JSONL برای آموزش انتخاب نشده."
        progress(0.05, desc="راه‌اندازی")
        self.scfg.train.epochs = int(epochs)
        self.scfg.train.batch_size = int(batch)
        self.scfg.train.lr = float(lr)
        try:
            progress(0.10, desc="بارگذاری مدل/توکنایزر")
            self._release_model()
            self.loader = CausalLoader(self.scfg.model).load(model_name)
            tm = TrainerManager(self.scfg, self.loader)
            set_seed_all(self.scfg.train.seed)
            progress(0.30, desc="آماده‌سازی دیتاست‌ها و RAG (اختیاری)")
            tm.train_causal(paths, use_rag=use_rag, use_wandb=use_wandb, wandb_project=wandb_project,
                            wandb_entity=wandb_entity, run_name=run_name)
        except Exception as e:
            return f"❌ خطا در آموزش: {e}"
        progress(0.95, desc="ذخیرهٔ آرتیفکت‌ها")
        return f"✅ آموزش کامل شد و در {self.scfg.train.output_dir} ذخیره شد."

    # Dataset Builder (external golden_builder module)
    def build_dataset(self, raw_file, text_key: str, model_ckpt: str, batch_size: int, max_samples: int | None):
        """Build a ``{input, output}`` JSONL via ``golden_builder``; returns ``(path|None, message)``."""
        try:
            from golden_builder import load_json_or_jsonl, save_jsonl, GoldenBuilder
        except Exception as e:
            return None, f"❌ golden_builder.py یافت نشد/قابل import نیست: {e}"
        path = _upload_path(raw_file)
        if not path:
            return None, "⚠️ فایل ورودی معتبر نیست."
        try:
            data = load_json_or_jsonl(path)
            if max_samples and int(max_samples) > 0:
                data = data[: int(max_samples)]
            gb = GoldenBuilder(model_name=model_ckpt)
            rows = gb.build(data, text_key=text_key, batch_size=int(batch_size))
            out_dir = Path("/tmp/mahoon_datasets")
            out_dir.mkdir(parents=True, exist_ok=True)
            # Use the stem so "data.json" becomes "golden_data.jsonl", not "golden_data.json.jsonl".
            out_path = str(out_dir / f"golden_{Path(path).stem}.jsonl")
            save_jsonl(rows, out_path)
            return out_path, f"✅ {len(rows)} رکورد تولید شد."
        except Exception as e:
            return None, f"❌ خطا در ساخت دیتاست: {e}"

    # Weight Tuning (W&B Sweep)
    def run_weight_tune(self, f, tk, ms, runs, bs, proj, ent):
        """Run the external ``weights_sweep.run_sweep``; returns a status message."""
        p = _upload_path(f)
        if not p:
            return "⚠️ فایل داده نامعتبر است."
        try:
            from weights_sweep import run_sweep
        except Exception as e:
            return f"❌ weights_sweep.py یافت نشد/قابل import نیست: {e}"
        # Assign (not setdefault) so values typed in the UI win over earlier runs.
        os.environ["WANDB_PROJECT"] = proj or "mahoon-legal-ai"
        if ent:
            os.environ["WANDB_ENTITY"] = ent
        try:
            run_sweep(data_path=p, text_key=tk, max_samples=int(ms), batch_size=int(bs),
                      project=proj, entity=ent, count=int(runs))
            return "✅ Sweep اجرا شد. بهترین Run را در W&B بررسی و وزن‌ها را تثبیت کنید."
        except Exception as e:
            return f"❌ خطا در اجرای Sweep: {e}"

    # UI
    def build_ui(self):
        """Build the Gradio Blocks app (consultation, indexing, dataset ops, training, tuning)."""
        log_deps()
        try:
            print("[rag-bootstrap]", ensure_chroma_ready(self.scfg.rag.persist_dir, self.scfg.rag.collection), flush=True)
        except Exception as e:
            print("[rag-bootstrap] error:", e, flush=True)

        default_gen_models = {
            "Qwen2.5-7B Instruct": "Qwen/Qwen2.5-7B-Instruct",
            "Llama-3.1-8B Instruct": "meta-llama/Llama-3.1-8B-Instruct",
            "Mistral-7B Instruct (v0.3)": "mistralai/Mistral-7B-Instruct-v0.3",
        }
        default_gen_choice = "Qwen2.5-7B Instruct"

        with gr.Blocks(title="ماحون — مشاور حقوقی (Causal-only)") as app:
            gr.Markdown("""

ماحون — Persian Legal (Causal-only)

Hybrid RAG • Qwen/Llama/Mistral • Dataset Ops • W&B Training • Weight Tuning

""")
            # --- Tab: Consultation ---
            with gr.Tab("مشاوره"):
                with gr.Row():
                    gen_model_dd = gr.Dropdown(choices=list(default_gen_models.keys()), value=default_gen_choice, label="مدل تولید")
                    gen_model_id = gr.Textbox(value=default_gen_models[default_gen_choice], label="Model ID (قابل ویرایش)")
                with gr.Row():
                    use_rag = gr.Checkbox(value=True, label="RAG فعال باشد؟")
                    persist_dir = gr.Textbox(value=self.scfg.rag.persist_dir, label="مسیر ChromaDB")
                    collection = gr.Textbox(value=self.scfg.rag.collection, label="نام کالکشن")
                with gr.Row():
                    top_k = gr.Slider(1, 15, value=self.scfg.rag.top_k, step=1, label="Top-K")
                    threshold = gr.Slider(0.3, 0.95, value=self.scfg.rag.similarity_threshold, step=0.01, label="آستانه شباهت")
                load_btn = gr.Button("بارگذاری مدل", variant="primary")
                status = gr.Textbox(label="وضعیت", interactive=False)
                with gr.Accordion("پارامترهای تولید", open=False):
                    system_prompt = gr.Textbox(value="You are a helpful Persian legal assistant.", label="System prompt")
                    max_new_tokens = gr.Slider(64, 2048, value=self.scfg.model.max_new_tokens, step=16, label="max_new_tokens")
                    temperature = gr.Slider(0.0, 1.5, value=self.scfg.model.temperature, step=0.05, label="temperature")
                    top_p = gr.Slider(0.1, 1.0, value=self.scfg.model.top_p, step=0.05, label="top_p")
                question = gr.Textbox(lines=3, label="سوال حقوقی")
                gr.Examples(
                    examples=[
                        ["در صورت نقض قرارداد EPC چه راهکارهای حقوقی دارم؟"],
                        ["آیا درج شرط عدم رقابت در قرارداد کار قانونی است؟"],
                        ["حق و حقوق کارگر در صورت اخراج فوری چیست؟"],
                    ],
                    inputs=question,
                    label="نمونه پرسش‌ها",
                )
                ask_btn = gr.Button("پرسش", variant="primary")
                answer = gr.Markdown(label="پاسخ")
                refs = gr.Markdown(label="مواد قانونی مرتبط")

            # --- Tab: Indexing ---
            with gr.Tab("ایندکس قوانین"):
                gr.Markdown("فایل JSONL قوانین را بارگذاری و ایندکس کنید (کلیدها: `article_id`, `text`).")
                laws_file = gr.File(label="فایل JSONL قوانین", file_types=[".jsonl"])
                id_key = gr.Textbox(value="article_id", label="کلید شناسه ماده")
                text_key = gr.Textbox(value="text", label="کلید متن ماده")
                index_btn = gr.Button("ایندکس‌سازی قوانین")
                index_status = gr.Textbox(label="وضعیت ایندکس", interactive=False)

            # --- Tab: Dataset Builder ---
            with gr.Tab("ساخت دیتاست"):
                gr.Markdown("فایل خام (JSON/JSONL) → خروجی JSONL سازگار با `{input, output}` (از golden_builder).")
                raw_file = gr.File(label="فایل خام", file_types=[".json", ".jsonl"])
                with gr.Row():
                    ds_text_key = gr.Textbox(value="متن_کامل", label="کلید متن (text_key)")
                    model_ckpt = gr.Dropdown(
                        choices=["google/mt5-base", "google/flan-t5-base", "t5-base"],
                        value="google/mt5-base",
                        label="مدل خلاصه‌ساز برای ساخت دیتاست (فقط Builder)",
                    )
                with gr.Row():
                    ds_batch_size = gr.Slider(1, 16, value=4, step=1, label="Batch size")
                    max_samples = gr.Number(value=0, label="حداکثر نمونه (۰=همه)")
                build_btn = gr.Button("ساخت دیتاست", variant="primary")
                out_file = gr.File(label="دانلود خروجی JSONL", interactive=False)
                build_status = gr.Textbox(label="وضعیت", interactive=False)

            # --- Tab: Dataset Cleaning ---
            with gr.Tab("پاکسازی دیتاست"):
                gr.Markdown("نرمال‌سازی فارسی + حذف تکراری‌های معنایی (cosine). ورودی: JSONL `{input, output}`.")
                raw_ds = gr.File(label="JSONL ورودی", file_types=[".jsonl"])
                sim_th = gr.Slider(0.80, 0.98, value=0.90, step=0.01, label="آستانه شباهت (cosine)")
                clean_btn = gr.Button("اجرای پاکسازی", variant="primary")
                cleaned_out = gr.File(label="دانلود JSONL پاک", interactive=False)
                clean_status = gr.Markdown()

            # --- Tab: Training (W&B integrated) ---
            with gr.Tab("آموزش"):
                gr.Markdown("SFT/LoRA روی مدل‌های causal (فقط `{input, output}`) + W&B logging.")
                with gr.Row():
                    model_train_dd = gr.Dropdown(
                        choices=[
                            "HAKIM (Editable ID below)",
                            "Hooshvareh (Editable ID below)",
                            "Dorna-Llama3-8B",
                            "PersianQA-8B",
                            "Custom (Editable ID below)",
                        ],
                        value="HAKIM (Editable ID below)",
                        label="پروفایل مدل",
                    )
                    model_train_id = gr.Textbox(value="AI-Hoosh/HAKIM-7B", label="HF Model ID (قابل ویرایش)")
                use_rag_train = gr.Checkbox(value=True, label="RAG-enhanced Training")
                # W&B controls
                use_wandb = gr.Checkbox(value=True, label="W&B logging فعال باشد؟")
                wandb_project = gr.Textbox(value="mahoon-legal-ai", label="WANDB_PROJECT")
                wandb_entity = gr.Textbox(value="", label="WANDB_ENTITY (اختیاری)")
                run_name = gr.Textbox(value="mahoon_causal_lora", label="Run name")
                gr.Markdown("راهنما: در Settings → Secrets مقدار `WANDB_API_KEY` را تنظیم کنید (مقدار واقعی).")
                train_files = gr.Files(label="JSONL Files", file_count="multiple", file_types=[".jsonl"])
                with gr.Row():
                    epochs = gr.Slider(1, 6, value=2, step=1, label="epochs")
                    batch = gr.Slider(1, 8, value=2, step=1, label="batch per device")
                    lr = gr.Number(value=2e-4, label="learning rate")
                train_btn = gr.Button("شروع آموزش", variant="primary")
                train_status = gr.Textbox(label="وضعیت آموزش", interactive=False)

            # --- Tab: Weight Tuning ---
            with gr.Tab("Weight Tuning"):
                gr.Markdown("تیون خودکار وزن‌های موجودیت با W&B Sweep. ابتدا در Settings→Secrets مقدار `WANDB_API_KEY` را ست کنید.")
                tune_file = gr.File(label="فایل داده (JSON/JSONL)", file_types=[".json", ".jsonl"])
                tune_text_key = gr.Textbox(value="متن_کامل", label="کلید متن")
                tune_max_samples = gr.Slider(50, 400, value=120, step=10, label="حداکثر نمونه")
                tune_runs = gr.Slider(4, 64, value=16, step=4, label="تعداد ران Sweep")
                tune_batch = gr.Slider(1, 4, value=2, step=1, label="batch size Builder")
                tune_proj = gr.Textbox(value="mahoon-legal-ai", label="WANDB_PROJECT")
                tune_entity = gr.Textbox(value="", label="WANDB_ENTITY (اختیاری)")
                run_tune = gr.Button("شروع Sweep", variant="primary")
                tune_status = gr.Markdown()

            # ---- Events ----
            def _resolve_gen(choice: str, override: str) -> str:
                override = (override or "").strip()
                return override or default_gen_models.get(choice, default_gen_models[default_gen_choice])

            def _on_load(choice, override, rag, pdir, coll, k, th):
                self.scfg.rag.enable = bool(rag)
                self.scfg.rag.persist_dir = (pdir or "").strip() or self.scfg.rag.persist_dir
                self.scfg.rag.collection = (coll or "").strip() or self.scfg.rag.collection
                self.scfg.rag.top_k = int(k)
                self.scfg.rag.similarity_threshold = float(th)
                return self.load(_resolve_gen(choice, override))

            # Sync the editable ID with the dropdown; otherwise the pre-filled Qwen ID
            # always overrides whatever model the user picks.
            gen_model_dd.change(lambda c: default_gen_models.get(c, ""), inputs=gen_model_dd, outputs=gen_model_id)
            load_btn.click(_on_load,
                           inputs=[gen_model_dd, gen_model_id, use_rag, persist_dir, collection, top_k, threshold],
                           outputs=status)
            ask_btn.click(lambda q, sys_p, rag, mnt, t, p: self.answer(q, sys_p, rag, mnt, t, p),
                          inputs=[question, system_prompt, use_rag, max_new_tokens, temperature, top_p],
                          outputs=[answer, refs])
            index_btn.click(lambda f, ik, tk: self.build_index(f, ik, tk),
                            inputs=[laws_file, id_key, text_key], outputs=index_status)
            build_btn.click(lambda rf, tk, ckpt, bs, mx: self.build_dataset(rf, tk, ckpt, bs, mx),
                            inputs=[raw_file, ds_text_key, model_ckpt, ds_batch_size, max_samples],
                            outputs=[out_file, build_status])

            def _map_profile_to_id(profile: str, current_id: str) -> str:
                current_id = (current_id or "").strip()
                if current_id:
                    return current_id
                profile = profile or ""
                if "Dorna" in profile:
                    return "PartAI/Dorna-Llama3-8B-Instruct"
                if "PersianQA" in profile:
                    return "zpm/Llama-3.1-PersianQA"
                if "HAKIM" in profile:
                    return "AI-Hoosh/HAKIM-7B"
                if "Hooshvareh" in profile:
                    return "HooshvareLab/llama-fa-7b-instruct"
                return "PartAI/Dorna-Llama3-8B-Instruct"

            def _on_profile_change(profile, current_id):
                # "Custom" keeps whatever the user typed; other profiles fill their default ID.
                if "Custom" in (profile or ""):
                    return current_id
                return _map_profile_to_id(profile, "")

            model_train_dd.change(_on_profile_change, inputs=[model_train_dd, model_train_id], outputs=model_train_id)

            # A named handler with a gr.Progress default lets Gradio inject a live progress tracker.
            def _on_train(prof, mid, files, rg, e, b, l, uw, wp, we, rn, progress=gr.Progress(track_tqdm=True)):
                return self.train(_map_profile_to_id(prof, mid), files, rg, e, b, l, uw, wp, we, rn, progress=progress)

            train_btn.click(
                _on_train,
                inputs=[model_train_dd, model_train_id, train_files, use_rag_train, epochs, batch, lr,
                        use_wandb, wandb_project, wandb_entity, run_name],
                outputs=train_status,
            )

            def _on_clean(f, th):
                p = _upload_path(f)
                if not p:
                    return None, "⚠️ فایل نامعتبر."
                out_path = f"/tmp/cleaned_{int(time.time())}.jsonl"
                try:
                    n = deduplicate_jsonl(p, out_path, sim_threshold=float(th))
                except Exception as e:
                    return None, f"❌ خطا در پاکسازی: {e}"
                return out_path, f"✅ دیتاست پاک شد. تعداد رکوردهای نهایی: **{n}**"

            clean_btn.click(_on_clean, inputs=[raw_ds, sim_th], outputs=[cleaned_out, clean_status])
            run_tune.click(
                lambda f, tk, ms, runs, bs, proj, ent: self.run_weight_tune(f, tk, ms, runs, bs, proj, ent),
                inputs=[tune_file, tune_text_key, tune_max_samples, tune_runs, tune_batch, tune_proj, tune_entity],
                outputs=tune_status,
            )
        return app


# ==========================
# Entrypoint
# ==========================
if __name__ == "__main__":
    app = LegalApp()
    ui = app.build_ui()
    try:
        ui = ui.queue()
    except TypeError:
        pass
    ui.launch(server_name="0.0.0.0", server_port=7860)