Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """ | |
| Mahoun — Legal AI (RAG + Training + Metrics) for HF Spaces / Gradio 5 | |
| - سازگار با Gradio 5.x و Transformers >= 4.44 | |
| - TrainingArguments ایمن با عقبسازگاری (safe_training_args) | |
| - RAG با ChromaDB + ایندکسسازی JSONL قوانین | |
| - متریکها: ROUGE-L (seq2seq) و F1 ساده (causal) | |
| - ماسک پدینگ روی labels در معماری علّی | |
| - Progress بهصورت DI: progress=gr.Progress(track_tqdm=True) | |
| ساختار ورودی دیتاست آموزش: | |
| JSONL با کلیدهای "input" و "output" | |
| ساختار ورودی قوانین برای ایندکس: | |
| JSONL با کلیدهای (پیشفرض) "article_id" و "text" | |
| """ | |
| from __future__ import annotations | |
| import os, sys, json, warnings | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import List, Dict, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import Dataset | |
| from sklearn.model_selection import train_test_split | |
| import gradio as gr | |
| from packaging import version | |
| import transformers as tf | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSeq2SeqLM, | |
| AutoModelForCausalLM, | |
| Trainer, | |
| TrainingArguments, | |
| EarlyStoppingCallback, | |
| DataCollatorForSeq2Seq, | |
| ) | |
| # RAG stack | |
| import chromadb | |
| from sentence_transformers import SentenceTransformer | |
| # Optional metrics | |
| try: | |
| from evaluate import load as eval_load | |
| except Exception: | |
| eval_load = None | |
# Silence every library warning (deprecation notices, tokenizer hints) so the
# Space log stays readable.
warnings.simplefilter("ignore")
| # ========================== | |
| # Config | |
| # ========================== | |
@dataclass
class ModelConfig:
    """Model selection and text-generation settings.

    Declared as a dataclass so it can be built with keyword overrides
    (``ModelConfig(architecture="causal")``) and used as a
    ``field(default_factory=...)`` inside another dataclass.

    Attributes:
        model_name: Hub id or local path handed to ``from_pretrained``.
        architecture: ``"seq2seq"`` or ``"causal"``; selects the AutoModel class.
        max_input_length: tokenizer truncation length (tokens) for inputs/prompts.
        max_target_length: maximum target length in tokens (training labels and
            evaluation-time generation).
        max_new_tokens: maximum number of newly generated tokens per answer.
        temperature: sampling temperature for causal generation.
        top_p: nucleus-sampling threshold for causal generation.
        num_beams: beam width for seq2seq decoding.
        gradient_checkpointing: enable gradient checkpointing when the model is loaded.
    """
    model_name: str = "google/mt5-base"
    architecture: str = "seq2seq"  # "seq2seq" | "causal"
    max_input_length: int = 1024
    max_target_length: int = 512
    max_new_tokens: int = 384
    temperature: float = 0.7
    top_p: float = 0.9
    num_beams: int = 4
    gradient_checkpointing: bool = True
@dataclass
class RAGConfig:
    """Retrieval (ChromaDB + SentenceTransformer) settings.

    Declared as a dataclass so it supports keyword construction and can be used
    as a ``field(default_factory=...)`` inside ``SystemConfig``.

    Attributes:
        embedding_model: SentenceTransformer model used for retrieval embeddings.
        persist_dir: directory of the persistent Chroma database.
        collection: name of the Chroma collection holding the articles.
        top_k: number of nearest articles requested per query.
        similarity_threshold: minimum similarity (0..1) for an article to be kept.
        context_char_limit: characters of each article copied into the prompt context.
        enable: master switch for retrieval.
    """
    embedding_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    persist_dir: str = "./chroma_db"
    collection: str = "legal_articles"
    top_k: int = 5
    similarity_threshold: float = 0.66  # 0..1
    context_char_limit: int = 300
    enable: bool = True
@dataclass
class TrainConfig:
    """Fine-tuning hyper-parameters forwarded to the Hugging Face ``Trainer``.

    Declared as a dataclass so it supports keyword construction and can be used
    as a ``field(default_factory=...)`` inside ``SystemConfig``.

    Attributes:
        output_dir: where checkpoints, the final model and the tokenizer are saved.
        seed: RNG seed (also used for the train/validation split).
        test_size: fraction of records held out for validation.
        epochs, batch_size, grad_accum, lr, weight_decay, warmup_ratio,
        logging_steps, save_total_limit, max_grad_norm: standard optimisation knobs.
        use_bf16: request bf16 mixed precision (only applied on GPUs that support it).
        eval_strategy / save_strategy: ``"steps"`` or ``"epoch"``.
        report_to: ``"none"`` or a logging integration such as ``"wandb"``.
    """
    output_dir: str = "./mahoon_model"
    seed: int = 42
    test_size: float = 0.1
    epochs: int = 3
    batch_size: int = 2
    grad_accum: int = 2
    lr: float = 3e-5
    use_bf16: bool = True
    weight_decay: float = 0.01
    warmup_ratio: float = 0.05
    logging_steps: int = 50
    eval_strategy: str = "epoch"  # "steps" | "epoch"
    save_strategy: str = "epoch"
    save_total_limit: int = 2
    report_to: str = "none"  # "none" | "wandb"
    max_grad_norm: float = 1.0
@dataclass
class SystemConfig:
    """Aggregate of the model, retrieval and training settings.

    ``@dataclass`` is required: ``field(default_factory=...)`` only takes effect
    inside a dataclass. On a plain class each attribute would be the ``Field``
    object itself, and ``SystemConfig().model.model_name`` would fail.
    Each instance gets its own fresh sub-config objects.
    """
    model: ModelConfig = field(default_factory=ModelConfig)
    rag: RAGConfig = field(default_factory=RAGConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
| # ========================== | |
| # Utils | |
| # ========================== | |
def set_seed_all(seed: int = 42):
    """Seed Python's, NumPy's and PyTorch's RNGs (plus every CUDA device, if any).

    Args:
        seed: value passed unchanged to each generator.
    """
    import random as _random

    # Same order as the usual recipe: stdlib, NumPy, then the torch CPU generator.
    for seeder in (_random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def log_deps():
    """Print the Python version and key library versions on one ``[deps]`` line.

    Each package is looked up independently through installed-distribution
    metadata. A missing optional package is reported as ``missing`` instead of
    suppressing the whole report, and nothing heavy (e.g. ``datasets``) is
    imported just to read its version.
    """
    from importlib import metadata

    parts = [f"python={sys.version.split()[0]}"]
    for name in ("transformers", "accelerate", "datasets", "gradio"):
        try:
            parts.append(f"{name}={metadata.version(name)}")
        except metadata.PackageNotFoundError:
            parts.append(f"{name}=missing")
    print("[deps]", *parts, flush=True)
def bf16_supported():
    """Return True only when a CUDA GPU is present and it supports bfloat16."""
    if not torch.cuda.is_available():
        return False
    # Very old torch builds lack the probe entirely; treat that as "unsupported".
    probe = getattr(torch.cuda, "is_bf16_supported", None)
    return bool(probe and probe())
def safe_training_args(**kwargs):
    """Build training arguments that the installed Transformers version accepts.

    * Transformers < 4.4 (legacy path, unchanged): ``evaluation_strategy`` becomes
      the boolean ``evaluate_during_training`` and newer keys are dropped.
    * ``evaluation_strategy`` / ``eval_strategy`` are passed under whichever name
      the chosen arguments class declares, preferring ``eval_strategy``. The old
      spelling was removed in Transformers 4.46 and raises ``TypeError`` there.
    * Generation keys (``predict_with_generate``, ``generation_max_length``,
      ``generation_num_beams``) are only declared by ``Seq2SeqTrainingArguments``.
      When any of them is given, that subclass is used.
    * Any remaining keyword the class does not declare is dropped with a printed
      warning instead of raising ``TypeError``.

    Returns:
        A ``TrainingArguments`` (or ``Seq2SeqTrainingArguments``) instance.
    """
    from dataclasses import fields as dc_fields

    generation_keys = ("predict_with_generate", "generation_max_length", "generation_num_beams")
    tf_ver = version.parse(tf.__version__)
    k = dict(kwargs)
    if tf_ver < version.parse("4.4.0"):
        eval_strat = k.pop("evaluation_strategy", None)
        k["evaluate_during_training"] = bool(eval_strat and str(eval_strat).lower() != "no")
        for rm in ["save_strategy", "load_best_model_at_end", "metric_for_best_model",
                   "greater_is_better", "report_to", "max_grad_norm", *generation_keys]:
            k.pop(rm, None)

    args_cls = TrainingArguments
    if any(key in k for key in generation_keys):
        try:
            from transformers import Seq2SeqTrainingArguments
        except ImportError:
            pass
        else:
            args_cls = Seq2SeqTrainingArguments

    accepted = {f.name for f in dc_fields(args_cls) if f.init}
    if "evaluation_strategy" in k and "eval_strategy" in accepted and "eval_strategy" not in k:
        k["eval_strategy"] = k.pop("evaluation_strategy")
    elif "eval_strategy" in k and "eval_strategy" not in accepted and "evaluation_strategy" in accepted:
        k["evaluation_strategy"] = k.pop("eval_strategy")

    dropped = sorted(key for key in k if key not in accepted)
    for key in dropped:
        del k[key]
    if dropped:
        print("[safe_training_args] ignoring unsupported arguments:", ", ".join(dropped), flush=True)
    return args_cls(**k)
| # ========================== | |
| # RAG | |
| # ========================== | |
class LegalRAG:
    """Article retrieval over a persistent ChromaDB collection.

    Articles and queries are embedded with the SentenceTransformer named in
    ``cfg.embedding_model``. The vectors are passed to Chroma explicitly, so the
    collection's built-in default embedding function is never used. Embeddings
    are L2-normalised, which lets ``_similarity`` turn Chroma distances into a
    cosine similarity compared against ``cfg.similarity_threshold``.
    """

    # Number of articles embedded and upserted per Chroma call.
    _UPSERT_BATCH = 512

    def __init__(self, cfg: "RAGConfig"):
        self.cfg = cfg
        self.client = None
        self.collection = None
        self.embedder: Optional[SentenceTransformer] = None
        # Model name ``embedder`` was loaded from; avoids reloading on every ``init``.
        self._embedder_name: Optional[str] = None

    def init(self):
        """Open the persistent client and collection, and load the embedder if needed.

        An existing collection is reused as-is. A missing one is created with
        cosine distance (``hnsw:space="cosine"``).
        """
        Path(self.cfg.persist_dir).mkdir(parents=True, exist_ok=True)
        self.client = chromadb.PersistentClient(path=self.cfg.persist_dir)
        try:
            self.collection = self.client.get_collection(self.cfg.collection)
        except Exception:
            # "Not found" is raised with different exception types across Chroma releases.
            self.collection = self.client.get_or_create_collection(
                self.cfg.collection, metadata={"hnsw:space": "cosine"}
            )
        if self.embedder is None or self._embedder_name != self.cfg.embedding_model:
            self.embedder = SentenceTransformer(self.cfg.embedding_model)
            self._embedder_name = self.cfg.embedding_model

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Encode *texts* with the configured SentenceTransformer as unit-length vectors."""
        vecs = self.embedder.encode(
            texts, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=False
        )
        return np.asarray(vecs, dtype=np.float32).tolist()

    def _similarity(self, distance) -> float:
        """Convert a Chroma distance between unit vectors into a cosine similarity.

        Chroma's ``"l2"`` space (the default when no ``hnsw:space`` is set) reports
        squared Euclidean distance, which equals ``2 - 2*cos`` for unit vectors.
        ``"cosine"`` reports ``1 - cos`` and ``"ip"`` reports ``1 - dot``.
        """
        meta = getattr(self.collection, "metadata", None) or {}
        d = float(distance)
        if meta.get("hnsw:space", "l2") == "l2":
            return 1.0 - d / 2.0
        return 1.0 - d

    def index_jsonl(self, jsonl_path: str, id_key="article_id", text_key="text"):
        """Index statute articles from a JSONL file (one ``{id_key, text_key, ...}`` object per line).

        Blank, malformed, non-object and empty-text lines are skipped. Lines
        without an id get ``auto_<line index>``. When an id repeats, the last
        occurrence wins; Chroma rejects duplicate ids inside a single upsert.
        A UTF-8 BOM at the start of the file is tolerated.

        Returns:
            A user-facing status message.
        """
        if self.collection is None or self.embedder is None:
            self.init()
        records: Dict[str, str] = {}
        with open(jsonl_path, "r", encoding="utf-8-sig") as f:
            for i, line in enumerate(f):
                s = line.strip()
                if not s:
                    continue
                try:
                    obj = json.loads(s)
                except json.JSONDecodeError:
                    continue
                if not isinstance(obj, dict):
                    continue
                aid = str(obj.get(id_key, f"auto_{i}"))
                raw = obj.get(text_key)
                txt = "" if raw is None else str(raw).strip()
                if not txt:
                    continue
                records[aid] = txt
        if not records:
            return "هیچ سندی برای ایندکس پیدا نشد."
        ids = list(records)
        docs = [records[aid] for aid in ids]
        step = self._UPSERT_BATCH
        for start in range(0, len(ids), step):
            batch_ids = ids[start:start + step]
            batch_docs = docs[start:start + step]
            self.collection.upsert(
                ids=batch_ids,
                documents=batch_docs,
                embeddings=self._embed(batch_docs),
                metadatas=[{"article_id": aid} for aid in batch_ids],
            )
        return f"✅ {len(ids)} سند قانونی ایندکس شد."

    def retrieve(self, query: str) -> List[Dict]:
        """Return up to ``cfg.top_k`` articles whose similarity reaches the threshold.

        Each hit is ``{"article_id", "text", "similarity"}``. Retrieval problems are
        logged and yield an empty list, so callers can always fall back to no context.
        """
        if self.collection is None or self.embedder is None:
            return []
        try:
            # Older Chroma releases raise when n_results exceeds the collection size.
            n = min(int(self.cfg.top_k), int(self.collection.count()))
            if n <= 0:
                return []
            res = self.collection.query(
                query_embeddings=self._embed([query]),
                n_results=n,
                include=["documents", "metadatas", "distances"],
            )
        except Exception as e:
            print("[rag] retrieve failed:", e, flush=True)
            return []
        docs = (res.get("documents") or [[]])[0]
        metas = (res.get("metadatas") or [[]])[0]
        dists = (res.get("distances") or [[]])[0]
        out = []
        for i, (doc, meta, dist) in enumerate(zip(docs, metas, dists)):
            sim = self._similarity(dist)
            if sim >= self.cfg.similarity_threshold:
                out.append({
                    "article_id": (meta or {}).get("article_id", f"unk_{i}"),
                    "text": doc,
                    "similarity": sim,
                })
        return out

    def build_context(self, arts: List[Dict]) -> str:
        """Format retrieved articles as a bulleted prompt context (empty string when none)."""
        if not arts:
            return ""
        bullets = [f"• ماده {a['article_id']}: {a['text'][:self.cfg.context_char_limit]}..." for a in arts]
        return "مواد مرتبط:\n" + "\n".join(bullets)
| # ========================== | |
| # Loader + Generator | |
| # ========================== | |
class ModelLoader:
    """Load the tokenizer and a seq2seq or causal LM described by a ``ModelConfig``."""

    def __init__(self, mcfg: "ModelConfig"):
        self.cfg = mcfg
        self.tokenizer = None
        self.model = None

    @staticmethod
    def _pick_dtype():
        """bf16 on GPUs that support it, fp16 on other GPUs, library default (fp32) on CPU.

        The choice depends only on the hardware. Earlier code also required
        ``gradient_checkpointing`` for bf16, which silently fell back to fp16
        on bf16-capable GPUs.
        """
        if bf16_supported():
            return torch.bfloat16
        if torch.cuda.is_available():
            return torch.float16
        return None

    def load(self):
        """Load tokenizer and model (sharded with ``device_map="auto"`` on GPU).

        For causal models without a pad token, the EOS token is reused as padding.

        Returns:
            ``self``, so ``ModelLoader(cfg).load()`` can be chained.

        Raises:
            ValueError: if ``cfg.architecture`` is not "seq2seq" or "causal";
                checked before anything is downloaded.
        """
        if self.cfg.architecture not in ("seq2seq", "causal"):
            raise ValueError("Unsupported architecture")
        self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name)
        model_kwargs = {"torch_dtype": self._pick_dtype()}
        if torch.cuda.is_available():
            model_kwargs["device_map"] = "auto"
        if self.cfg.architecture == "seq2seq":
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.cfg.model_name, **model_kwargs)
        else:
            self.model = AutoModelForCausalLM.from_pretrained(self.cfg.model_name, **model_kwargs)
            if self.tokenizer.pad_token is None and getattr(self.tokenizer, "eos_token", None) is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.cfg.gradient_checkpointing and hasattr(self.model, "gradient_checkpointing_enable"):
            try:
                self.model.gradient_checkpointing_enable()
            except Exception as e:
                # Best-effort: some architectures do not support checkpointing.
                print("[loader] gradient checkpointing unavailable:", e, flush=True)
        return self
class Generator:
    """Answer a question with the loaded model, optionally conditioned on RAG context.

    Prompt formats match this module's training datasets:

    * seq2seq: ``<CONTEXT>..</CONTEXT>\\n<QUESTION>..</QUESTION>``, or the bare
      question when there is no context.
    * causal: an optional context line, then the question/answer template.

    Generation settings are read from the shared config on every call, so UI
    changes apply immediately.
    """

    def __init__(self, loader: "ModelLoader", mcfg: "ModelConfig"):
        self.tk = loader.tokenizer
        self.model = loader.model
        self.cfg = mcfg

    def _encode(self, prompt: str) -> Dict:
        """Tokenize *prompt* (truncated to ``max_input_length``) onto the model's device."""
        enc = self.tk(prompt, return_tensors="pt", truncation=True, max_length=self.cfg.max_input_length)
        return {k: v.to(self.model.device) for k, v in enc.items()}

    def generate(self, question: str, context: str = "") -> str:
        """Generate an answer for *question*; *context* is retrieved article text (may be empty)."""
        if self.cfg.architecture == "seq2seq":
            prompt = f"<CONTEXT>{context}</CONTEXT>\n<QUESTION>{question}</QUESTION>" if context else question
            enc = self._encode(prompt)
            out = self.model.generate(
                **enc,
                max_new_tokens=self.cfg.max_new_tokens,
                num_beams=self.cfg.num_beams,
                early_stopping=self.cfg.num_beams > 1,
            )
            return self.tk.decode(out[0], skip_special_tokens=True)

        prompt = f"{context}\nسوال: {question}\nپاسخ:" if context else f"سوال: {question}\nپاسخ:"
        enc = self._encode(prompt)
        # Token id 0 is a valid pad id, so test for None rather than truthiness.
        pad_id = self.tk.pad_token_id if self.tk.pad_token_id is not None else self.tk.eos_token_id
        gen_kwargs = {"max_new_tokens": self.cfg.max_new_tokens, "pad_token_id": pad_id}
        if self.cfg.temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=self.cfg.temperature, top_p=self.cfg.top_p)
        else:
            # temperature == 0 means deterministic decoding; sampling would reject it.
            gen_kwargs["do_sample"] = False
        out = self.model.generate(**enc, **gen_kwargs)
        # Decoder-only models echo the prompt; return only the continuation.
        prompt_len = enc["input_ids"].shape[-1]
        return self.tk.decode(out[0][prompt_len:], skip_special_tokens=True)
| # ========================== | |
| # Datasets | |
| # ========================== | |
class Seq2SeqJSONLDataset(Dataset):
    """Encoder-decoder training examples built from ``{"input", "output"}`` records.

    Records that are not objects, or whose input/output is empty, are skipped.
    When a retriever is given, every ``enhance_every``-th record (by position in
    *data*) gets retrieved context in a ``<CONTEXT>/<QUESTION>`` wrapper. A value
    of ``enhance_every <= 0`` disables augmentation.

    Inputs and labels are padded to fixed lengths. Padded label positions are set
    to ``-100`` so the loss ignores them; already-padded labels are left untouched
    by ``DataCollatorForSeq2Seq``.
    """

    def __init__(self, data: List[Dict], tokenizer, max_inp: int, max_tgt: int,
                 rag: Optional["LegalRAG"] = None, enhance_every: int = 10):
        self.tk = tokenizer
        self.max_inp = max_inp
        self.max_tgt = max_tgt
        self.items = []
        for i, ex in enumerate(data):
            if not isinstance(ex, dict):
                continue
            src = str(ex.get("input", "")).strip()
            tgt = str(ex.get("output", "")).strip()
            if not src or not tgt:
                continue
            inp = src
            if rag and enhance_every > 0 and i % enhance_every == 0:
                ctx = rag.build_context(rag.retrieve(src))
                if ctx:
                    inp = f"<CONTEXT>{ctx}</CONTEXT>\n<QUESTION>{src}</QUESTION>"
            self.items.append((inp, tgt))

    def __len__(self):
        return len(self.items)

    def _mask_padding(self, labels) -> List[int]:
        """Return label ids with padded positions replaced by -100."""
        ids = list(labels["input_ids"])
        mask = labels.get("attention_mask") if hasattr(labels, "get") else None
        if mask is not None:
            return [tok if keep else -100 for tok, keep in zip(ids, mask)]
        pad_id = getattr(self.tk, "pad_token_id", None)
        return [-100 if tok == pad_id else tok for tok in ids]

    def __getitem__(self, idx):
        inp, tgt = self.items[idx]
        model_inputs = self.tk(inp, max_length=self.max_inp, padding="max_length", truncation=True)
        labels = self.tk(text_target=tgt, max_length=self.max_tgt, padding="max_length", truncation=True)
        model_inputs["labels"] = self._mask_padding(labels)
        return model_inputs
class CausalJSONLDataset(Dataset):
    """Decoder-only training examples built from ``{"input", "output"}`` records.

    Each record becomes an optional context line followed by the question/answer
    template, terminated by the tokenizer's EOS token when it has one. Records
    that are not objects, or whose input/output is empty, are skipped. Retrieval
    augmentation applies to every ``enhance_every``-th record (``<= 0`` disables it).

    Labels equal the input ids, with padded positions (attention_mask == 0) set to
    -100. Because masking uses the attention mask, the trailing EOS keeps its label
    even when the pad token equals EOS, so the model learns to stop.
    """

    def __init__(self, data: List[Dict], tokenizer, max_inp: int,
                 rag: Optional["LegalRAG"] = None, enhance_every: int = 10):
        self.tk = tokenizer
        self.max_inp = max_inp
        self.items = []
        eos = getattr(tokenizer, "eos_token", None) or ""
        for i, ex in enumerate(data):
            if not isinstance(ex, dict):
                continue
            src = str(ex.get("input", "")).strip()
            tgt = str(ex.get("output", "")).strip()
            if not src or not tgt:
                continue
            ctx = ""
            if rag and enhance_every > 0 and i % enhance_every == 0:
                ctx = rag.build_context(rag.retrieve(src))
            text = f"{ctx}\nسوال: {src}\nپاسخ: {tgt}" if ctx else f"سوال: {src}\nپاسخ: {tgt}"
            self.items.append(text + eos)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        text = self.items[idx]
        enc = self.tk(text, max_length=self.max_inp, padding="max_length", truncation=True)
        input_ids = torch.tensor(enc["input_ids"])
        attn = torch.tensor(enc["attention_mask"])
        labels = input_ids.clone()
        labels[attn == 0] = -100  # padding mask for loss
        return {"input_ids": input_ids, "attention_mask": attn, "labels": labels}
| # ========================== | |
| # Metrics | |
| # ========================== | |
def build_metrics_fn(arch: str, tokenizer):
    """Return a ``compute_metrics`` callable for the Trainer.

    * seq2seq: ROUGE-L over decoded predictions and labels. Text is tokenized on
      whitespace because ROUGE's default tokenizer keeps only ``[a-z0-9]`` and
      would score Persian text as empty. Returns ``{"rougeL": 0.0}`` when the
      ``evaluate`` ROUGE metric is unavailable.
    * causal: a simple bag-of-words F1 between teacher-forced predictions and
      labels (``{"f1_simple": ...}``).

    Predictions may be token ids or raw logits (argmax is taken). ``-100``
    entries, used for ignored label positions and for padding gathered
    generations, are mapped to a real pad id before decoding.
    """
    rouge = None
    if eval_load is not None:
        try:
            rouge = eval_load("rouge")
        except Exception as e:
            print("[metrics] ROUGE unavailable:", e, flush=True)

    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = getattr(tokenizer, "eos_token_id", None)
    if pad_id is None:
        pad_id = 0

    def _postprocess(preds):
        if isinstance(preds, (list, tuple)):
            return [p.strip() for p in preds]
        return preds

    def _to_ids(arr):
        """Token-id array with -100 placeholders replaced by the pad id."""
        arr = np.asarray(arr)
        return np.where(arr != -100, arr, pad_id)

    def _split(eval_pred):
        """Return (predictions as a 2-D id array, labels array)."""
        preds, labels = eval_pred[0], eval_pred[1]
        if isinstance(preds, tuple):
            preds = preds[0]
        preds = np.asarray(preds)
        if preds.ndim == 3:  # raw logits (batch, seq, vocab)
            preds = preds.argmax(-1)
        return preds, np.asarray(labels)

    def _whitespace_tokens(text):
        return text.split()

    def compute_metrics_seq2seq(eval_pred):
        if rouge is None:
            return {"rougeL": 0.0}
        preds, labels = _split(eval_pred)
        decoded_preds = _postprocess(tokenizer.batch_decode(_to_ids(preds), skip_special_tokens=True))
        decoded_labels = _postprocess(tokenizer.batch_decode(_to_ids(labels), skip_special_tokens=True))
        try:
            r = rouge.compute(predictions=decoded_preds, references=decoded_labels,
                              rouge_types=["rougeL"], tokenizer=_whitespace_tokens)
        except TypeError:  # older `evaluate` without the `tokenizer` argument
            r = rouge.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=["rougeL"])
        return {"rougeL": float(r.get("rougeL", 0.0))}

    def compute_metrics_causal(eval_pred):
        preds, labels = _split(eval_pred)
        if preds.ndim == 2 and preds.shape == labels.shape and preds.shape[1] > 1:
            # Teacher-forced output at position t predicts token t+1: align, then
            # drop predictions at positions whose label is ignored (-100).
            preds, labels = preds[:, :-1], labels[:, 1:]
            preds = np.where(labels != -100, preds, -100)
        decoded_preds = tokenizer.batch_decode(_to_ids(preds), skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(_to_ids(labels), skip_special_tokens=True)
        tp = fp = fn = 0
        for p, g in zip(decoded_preds, decoded_labels):
            p_set, g_set = set(p.split()), set(g.split())
            tp += len(p_set & g_set)
            fp += len(p_set - g_set)
            fn += len(g_set - p_set)
        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        return {"f1_simple": float(f1)}

    return compute_metrics_seq2seq if arch == "seq2seq" else compute_metrics_causal
| # ========================== | |
| # Trainer Manager | |
| # ========================== | |
def read_jsonl_files(paths: List[str]) -> List[Dict]:
    """Read JSON objects from one or more JSONL files, in order.

    Empty or falsy paths are skipped. Blank lines, malformed JSON and
    non-object values (lists, numbers, strings) are ignored, so every returned
    record supports ``.get``. A UTF-8 BOM (common in files saved by Windows
    editors) is stripped instead of making the first record unparseable.
    """
    data: List[Dict] = []
    for p in paths:
        if not p:
            continue
        with open(p, "r", encoding="utf-8-sig") as f:
            for line in f:
                s = line.strip()
                if not s:
                    continue
                try:
                    obj = json.loads(s)
                except json.JSONDecodeError:
                    continue
                if isinstance(obj, dict):
                    data.append(obj)
    return data
class TrainerManager:
    """Fine-tune ``loader.model`` on JSONL files with ``input``/``output`` keys.

    Both entry points:
    * seed all RNGs;
    * split the records into train/validation;
    * optionally build a ``LegalRAG`` retriever to augment training inputs;
    * train with early stopping on ``eval_loss``;
    * save the model and tokenizer to ``cfg.train.output_dir``.
    """

    def __init__(self, syscfg: "SystemConfig", loader: "ModelLoader"):
        self.cfg = syscfg
        self.loader = loader

    def _args_common(self, is_seq2seq: bool):
        """Build training arguments shared by both architectures.

        Mixed precision: bf16 when requested and supported by the GPU. Otherwise
        fp16 AMP, but only for fp32 weights: the fp16 gradient scaler cannot
        unscale half-precision gradients, and ``ModelLoader`` loads fp16/bf16
        weights on GPU.
        """
        bf16_ok = bf16_supported() and self.cfg.train.use_bf16
        first_param = next(iter(self.loader.model.parameters()), None)
        weights_fp32 = first_param is None or first_param.dtype == torch.float32
        fp16_ok = torch.cuda.is_available() and not bf16_ok and weights_fp32
        generation = {
            "predict_with_generate": True,
            "generation_max_length": self.cfg.model.max_target_length,
            "generation_num_beams": self.cfg.model.num_beams,
        } if is_seq2seq else {}
        return safe_training_args(
            output_dir=self.cfg.train.output_dir,
            num_train_epochs=self.cfg.train.epochs,
            learning_rate=self.cfg.train.lr,
            per_device_train_batch_size=self.cfg.train.batch_size,
            per_device_eval_batch_size=self.cfg.train.batch_size,
            gradient_accumulation_steps=self.cfg.train.grad_accum,
            warmup_ratio=self.cfg.train.warmup_ratio,
            weight_decay=self.cfg.train.weight_decay,
            evaluation_strategy=self.cfg.train.eval_strategy,
            save_strategy=self.cfg.train.save_strategy,
            save_total_limit=self.cfg.train.save_total_limit,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            logging_steps=self.cfg.train.logging_steps,
            report_to=([] if self.cfg.train.report_to == "none" else [self.cfg.train.report_to]),
            fp16=fp16_ok,
            bf16=bf16_ok,
            max_grad_norm=self.cfg.train.max_grad_norm,
            **generation,
        )

    def _prepare(self, train_paths: List[str], use_rag: bool):
        """Seed, load and split the records; build the optional training retriever.

        Raises:
            ValueError: when fewer than two usable records are found.
        """
        set_seed_all(self.cfg.train.seed)
        data = read_jsonl_files(train_paths)
        if len(data) < 2:
            raise ValueError(f"At least 2 training records are required, found {len(data)}.")
        train, val = train_test_split(data, test_size=self.cfg.train.test_size, random_state=self.cfg.train.seed)
        rag = LegalRAG(self.cfg.rag) if (use_rag and self.cfg.rag.enable) else None
        if rag:
            rag.init()
        return train, val, rag

    def _tokenizer_kwarg(self, trainer_cls) -> Dict:
        """Pass the tokenizer as ``processing_class`` where supported (``tokenizer=`` is deprecated)."""
        import inspect

        try:
            params = inspect.signature(trainer_cls.__init__).parameters
        except (TypeError, ValueError):
            params = {}
        key = "processing_class" if "processing_class" in params else "tokenizer"
        return {key: self.loader.tokenizer}

    def _fit_and_save(self, trainer):
        """Run training and persist the (best) model and the tokenizer."""
        trainer.train()
        trainer.save_model(self.cfg.train.output_dir)
        self.loader.tokenizer.save_pretrained(self.cfg.train.output_dir)

    def train_seq2seq(self, train_paths: List[str], use_rag: bool = True):
        """Fine-tune an encoder-decoder model; evaluation generates text for ROUGE-L."""
        train, val, rag = self._prepare(train_paths, use_rag)
        tok = self.loader.tokenizer
        mcfg = self.cfg.model
        ds_tr = Seq2SeqJSONLDataset(train, tok, mcfg.max_input_length, mcfg.max_target_length, rag)
        ds_va = Seq2SeqJSONLDataset(val, tok, mcfg.max_input_length, mcfg.max_target_length, None)
        collator = DataCollatorForSeq2Seq(tokenizer=tok, model=self.loader.model)
        args = self._args_common(is_seq2seq=True)
        trainer_cls = Trainer
        if getattr(args, "predict_with_generate", False):
            # Only Seq2SeqTrainer generates during evaluation; a plain Trainer
            # would hand full-vocabulary logits to compute_metrics.
            from transformers import Seq2SeqTrainer
            trainer_cls = Seq2SeqTrainer
        trainer = trainer_cls(
            model=self.loader.model,
            args=args,
            train_dataset=ds_tr,
            eval_dataset=ds_va,
            data_collator=collator,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
            compute_metrics=build_metrics_fn("seq2seq", tok),
            **self._tokenizer_kwarg(trainer_cls),
        )
        self._fit_and_save(trainer)

    def train_causal(self, train_paths: List[str], use_rag: bool = True):
        """Fine-tune a decoder-only model; evaluation scores teacher-forced token predictions."""
        train, val, rag = self._prepare(train_paths, use_rag)
        tok = self.loader.tokenizer
        ds_tr = CausalJSONLDataset(train, tok, self.cfg.model.max_input_length, rag)
        ds_va = CausalJSONLDataset(val, tok, self.cfg.model.max_input_length, None)
        args = self._args_common(is_seq2seq=False)

        def _argmax_logits(logits, labels):
            # Keep token ids only, so evaluation does not accumulate (batch, seq, vocab) logits.
            if isinstance(logits, tuple):
                logits = logits[0]
            return logits.argmax(dim=-1)

        trainer = Trainer(
            model=self.loader.model,
            args=args,
            train_dataset=ds_tr,
            eval_dataset=ds_va,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
            compute_metrics=build_metrics_fn("causal", tok),
            preprocess_logits_for_metrics=_argmax_logits,
            **self._tokenizer_kwarg(Trainer),
        )
        self._fit_and_save(trainer)
| # ========================== | |
| # App (Gradio 5) | |
| # ========================== | |
class LegalApp:
    """Gradio front-end: load a model and retriever, index statutes, answer questions, fine-tune."""

    def __init__(self, scfg: Optional[SystemConfig] = None):
        self.scfg = scfg or SystemConfig()
        self.rag = LegalRAG(self.scfg.rag)
        self.loader: Optional[ModelLoader] = None
        self.gen: Optional[Generator] = None

    # --- helpers ---
    @staticmethod
    def _to_path(f) -> Optional[str]:
        """Return a filesystem path for one Gradio file value, or None.

        Gradio 4/5 pass uploads as plain path strings (``type="filepath"``, the
        default). Older releases passed tempfile-like objects exposing ``.name``
        or ``.path``. Both forms are accepted.
        """
        if f is None:
            return None
        if isinstance(f, (str, os.PathLike)):
            return os.fspath(f) or None
        return getattr(f, "name", None) or getattr(f, "path", None)

    def _file_paths(self, files: List[gr.File]) -> List[str]:
        """Paths of all uploaded files; accepts a list/tuple or a single value."""
        if files is None:
            return []
        if not isinstance(files, (list, tuple)):
            files = [files]
        return [p for p in (self._to_path(f) for f in files) if p]

    # --- core actions ---
    def load(self, model_name: str, arch: str, use_rag: bool, persist_dir: str, collection: str, top_k: int, threshold: float):
        """Apply the UI settings, load the model, and (optionally) open the retriever.

        Returns:
            A status message for the UI. Retriever errors are reported in the
            message, while model-loading errors propagate to Gradio.
        """
        self.scfg.model.model_name = model_name
        self.scfg.model.architecture = arch
        self.scfg.rag.persist_dir = persist_dir
        self.scfg.rag.collection = collection
        self.scfg.rag.top_k = int(top_k)
        self.scfg.rag.similarity_threshold = float(threshold)
        self.scfg.rag.enable = bool(use_rag)
        self.loader = ModelLoader(self.scfg.model).load()
        self.gen = Generator(self.loader, self.scfg.model)
        msg_rag = "RAG غیرفعال"
        if use_rag:
            try:
                self.rag = LegalRAG(self.scfg.rag)
                self.rag.init()
                msg_rag = "RAG آماده است"
            except Exception as e:
                msg_rag = f"RAG خطا: {e}"
        return f"مدل بارگذاری شد: {model_name} ({arch})\n{msg_rag}"

    def build_index(self, laws_file: gr.File, id_key: str, text_key: str):
        """Index an uploaded statutes JSONL file; returns a status message."""
        if not self.scfg.rag.enable:
            return "RAG غیرفعال است."
        p = self._to_path(laws_file)
        if not p:
            return "فایل قوانین معتبر نیست."
        try:
            self.rag.init()
            # Empty key boxes fall back to the documented default keys.
            return self.rag.index_jsonl(p, id_key=id_key or "article_id", text_key=text_key or "text")
        except Exception as e:
            return f"خطا در ایندکس: {e}"

    def answer(self, question: str, use_rag: bool, max_new_tokens: int, temperature: float, top_p: float, num_beams: int):
        """Answer *question*; returns (answer text, markdown list of the retrieved articles)."""
        if not (question or "").strip():
            return "لطفاً سوال خود را وارد کنید.", ""
        if not self.gen:
            return "ابتدا مدل/RAG را بارگذاری کنید.", ""
        # Runtime generation parameters are read by the Generator on every call.
        self.scfg.model.max_new_tokens = int(max_new_tokens)
        self.scfg.model.temperature = float(temperature)
        self.scfg.model.top_p = float(top_p)
        self.scfg.model.num_beams = int(num_beams)
        arts = self.rag.retrieve(question) if (use_rag and self.scfg.rag.enable and self.rag.collection) else []
        ctx = self.rag.build_context(arts) if arts else ""
        ans = self.gen.generate(question, ctx)
        refs = ""
        if arts:
            refs = "\n\n" + "\n".join([f"**ماده {a['article_id']}** (شباهت: {a['similarity']:.2f})\n{a['text'][:380]}..." for a in arts])
        return ans, refs

    def train(self, model_name: str, arch: str, files: List[gr.File], use_rag: bool, epochs: int, batch: int, lr: float,
              wd: float, warmup: float, report_to: str, progress=gr.Progress(track_tqdm=True)):
        """Fine-tune *model_name* on the uploaded JSONL files; returns a status message.

        The uploads are validated before any model is downloaded. After training,
        the consultation tab serves the freshly trained weights.
        """
        paths = self._file_paths(files)
        if not paths:
            return "⚠️ هیچ فایل JSONL برای آموزش انتخاب نشده."
        progress(0.0, desc="راهاندازی")
        self.scfg.model.model_name = model_name
        self.scfg.model.architecture = arch
        self.scfg.train.epochs = int(epochs)
        self.scfg.train.batch_size = int(batch)
        self.scfg.train.lr = float(lr)
        self.scfg.train.weight_decay = float(wd)
        self.scfg.train.warmup_ratio = float(warmup)
        self.scfg.train.report_to = report_to
        progress(0.1, desc="بارگذاری مدل/توکنایزر")
        self.loader = ModelLoader(self.scfg.model).load()
        tm = TrainerManager(self.scfg, self.loader)
        set_seed_all(self.scfg.train.seed)
        progress(0.3, desc="آمادهسازی دیتاستها و RAG")
        if arch == "seq2seq":
            tm.train_seq2seq(paths, use_rag=use_rag)
        else:
            tm.train_causal(paths, use_rag=use_rag)
        progress(0.95, desc="ذخیرهٔ آرتیفکتها")
        # The shared model config now describes the trained architecture, so the
        # previous Generator (if any) must not keep serving the old model.
        self.gen = Generator(self.loader, self.scfg.model)
        return f"✅ آموزش کامل شد و در {self.scfg.train.output_dir} ذخیره شد."

    # --- UI ---
    def build_ui(self):
        """Build the Gradio Blocks app: consultation tab (load / index / ask) and training tab."""
        log_deps()
        default_models = {
            "Seq2Seq (mt5-base)": ("google/mt5-base", "seq2seq"),
            "Seq2Seq (t5-fa-base)": ("HooshvareLab/t5-fa-base", "seq2seq"),
            "Seq2Seq (flan-t5-base)": ("google/flan-t5-base", "seq2seq"),
            "Causal (Mistral-7B Instruct)": ("mistralai/Mistral-7B-Instruct-v0.2", "causal"),
        }
        with gr.Blocks(title="ماحون — مشاور حقوقی هوشمند", theme=gr.themes.Soft(primary_hue="green", secondary_hue="gray")) as app:
            gr.HTML("""
            <div style='text-align:center;padding:18px'>
              <h1 style='margin-bottom:4px'>ماحون — Ultimate Legal AI</h1>
              <p style='color:#666'>RAG • Seq2Seq/Causal • Training • Metrics</p>
            </div>
            """)
            with gr.Tab("مشاوره"):
                with gr.Row():
                    model_dd = gr.Dropdown(choices=list(default_models.keys()), value="Seq2Seq (mt5-base)", label="مدل")
                    gr.Markdown("**راهنما:** Seq2Seq برای پاسخهای ساختاریافته؛ Causal برای مکالمه طبیعیتر.")
                with gr.Row():
                    use_rag = gr.Checkbox(value=True, label="RAG فعال باشد؟")
                    persist_dir = gr.Textbox(value=self.scfg.rag.persist_dir, label="مسیر ChromaDB")
                    collection = gr.Textbox(value=self.scfg.rag.collection, label="نام کالکشن")
                with gr.Row():
                    top_k = gr.Slider(1, 15, value=self.scfg.rag.top_k, step=1, label="Top-K")
                    threshold = gr.Slider(0.3, 0.95, value=self.scfg.rag.similarity_threshold, step=0.01, label="آستانه شباهت")
                load_btn = gr.Button("بارگذاری مدل/RAG", variant="primary")
                status = gr.Textbox(label="وضعیت", interactive=False)
                with gr.Accordion("ساخت ایندکس قوانین (اختیاری)", open=False):
                    laws_file = gr.File(label="فایل JSONL قوانین", file_types=[".jsonl"])
                    id_key = gr.Textbox(value="article_id", label="کلید شناسه ماده")
                    text_key = gr.Textbox(value="text", label="کلید متن ماده")
                    index_btn = gr.Button("ایندکسسازی قوانین")
                    index_status = gr.Textbox(label="وضعیت ایندکس", interactive=False)
                with gr.Accordion("پارامترهای تولید", open=False):
                    max_new_tokens = gr.Slider(64, 1024, value=self.scfg.model.max_new_tokens, step=16, label="max_new_tokens")
                    temperature = gr.Slider(0.0, 1.5, value=self.scfg.model.temperature, step=0.05, label="temperature")
                    top_p = gr.Slider(0.1, 1.0, value=self.scfg.model.top_p, step=0.05, label="top_p")
                    num_beams = gr.Slider(1, 8, value=self.scfg.model.num_beams, step=1, label="num_beams (Seq2Seq)")
                question = gr.Textbox(lines=3, label="سوال حقوقی")
                gr.Examples(
                    examples=[
                        ["در صورت نقض قرارداد فروش، چه اقداماتی باید انجام دهم؟"],
                        ["آیا درج شرط عدم رقابت در قرارداد کار قانونی است؟"],
                        ["حق و حقوق کارگر در صورت اخراج فوری چیست؟"],
                        ["فرآیند طرح دعوای مطالبه مهریه چگونه است؟"],
                    ],
                    inputs=question, label="نمونه پرسشها"
                )
                ask_btn = gr.Button("پرسش", variant="primary")
                answer = gr.Markdown(label="پاسخ")
                refs = gr.Markdown(label="مواد قانونی مرتبط")
            with gr.Tab("آموزش"):
                gr.Markdown("فایلهای JSONL با کلیدهای `input` و `output` را بارگذاری کنید.")
                with gr.Row():
                    model_dd_train = gr.Dropdown(choices=list(default_models.keys()), value="Seq2Seq (mt5-base)", label="مدل")
                    use_rag_train = gr.Checkbox(value=True, label="RAG-enhanced Training")
                train_files = gr.Files(label="JSONL Files", file_count="multiple", file_types=[".jsonl"])
                with gr.Row():
                    epochs = gr.Slider(1, 8, value=self.scfg.train.epochs, step=1, label="epochs")
                    batch = gr.Slider(1, 16, value=self.scfg.train.batch_size, step=1, label="batch per device")
                    lr = gr.Number(value=self.scfg.train.lr, label="learning rate")
                with gr.Row():
                    wd = gr.Number(value=self.scfg.train.weight_decay, label="weight decay")
                    warmup = gr.Slider(0.0, 0.2, value=self.scfg.train.warmup_ratio, step=0.01, label="warmup ratio")
                    report_to = gr.Dropdown(choices=["none", "wandb"], value=self.scfg.train.report_to, label="report_to")
                train_btn = gr.Button("شروع آموزش", variant="primary")
                train_status = gr.Textbox(label="وضعیت آموزش", interactive=False)

            # Event wiring
            def _resolve(choice: str) -> Tuple[str, str]:
                return default_models[choice]

            def _on_train(choice, files, rag, e, b, l, _wd, _wu, _r, progress=gr.Progress(track_tqdm=True)):
                # Gradio injects a live progress tracker only when the registered
                # handler itself declares a gr.Progress default; a plain lambda loses it.
                return self.train(*_resolve(choice), files, rag, e, b, l, _wd, _wu, _r, progress=progress)

            load_btn.click(lambda choice, rag, pdir, coll, k, th: self.load(*_resolve(choice), rag, pdir, coll, k, th),
                           inputs=[model_dd, use_rag, persist_dir, collection, top_k, threshold], outputs=status)
            ask_btn.click(lambda q, rag, mnt, t, p, nb: self.answer(q, rag, mnt, t, p, nb),
                          inputs=[question, use_rag, max_new_tokens, temperature, top_p, num_beams],
                          outputs=[answer, refs])
            index_btn.click(lambda f, ik, tk: self.build_index(f, ik, tk),
                            inputs=[laws_file, id_key, text_key], outputs=index_status)
            train_btn.click(
                _on_train,
                inputs=[model_dd_train, train_files, use_rag_train, epochs, batch, lr, wd, warmup, report_to],
                outputs=train_status,
            )
        return app
| # ========================== | |
| # Entrypoint for HF Spaces | |
| # ========================== | |
if __name__ == "__main__":
    # Entrypoint for Hugging Face Spaces: build the UI and serve it on port 7860.
    legal_app = LegalApp()
    demo = legal_app.build_ui()
    # Enable Gradio's request queue (Gradio 5 has no `concurrency_count`); if this
    # Gradio release rejects the call, launch without the queue.
    try:
        demo = demo.queue()
    except TypeError:
        pass
    demo.launch(server_name="0.0.0.0", server_port=7860)