hajimammad committed on
Commit aa01b51 · verified · 1 Parent(s): 8282799

Upload app(7).py

Files changed (1)
  1. app(7).py +529 -0
app(7).py ADDED
@@ -0,0 +1,529 @@
# -*- coding: utf-8 -*-
"""
Mahoun — Ultimate Legal AI (Single-File, Modular, Polished UI)
The new Mahoun core, merging the previous components (advanced RAG + training for Seq2Seq and Causal models) with a cleaner UI.

Features:
- Multi-Architecture: "seq2seq" (T5/MT5/FLAN-T5) and "causal" (Mistral/LLaMA).
- Unified Loader/Generator + architecture-aware prompt formatting.
- Advanced RAG with ChromaDB (configurable path, collection name, top_k, threshold, context truncation).
- Full training for both architectures (Trainer, EarlyStopping, bf16/fp16, gradient accumulation).
- Redesigned Gradio UI (clean theme, cards, examples, live status, generation controls, model/architecture/database selection).

Minimum requirements (requirements.txt):
transformers>=4.44.0
sentencepiece
accelerate
bitsandbytes
chromadb
sentence-transformers
scikit-learn
gradio
torch>=2.1
"""
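
# Training data format (illustrative): the dataset classes below read JSONL files whose
# records carry "input" and "output" keys; records missing either key are skipped.
# A hypothetical record would look like:
#   {"input": "<legal question>", "output": "<expected answer>"}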
from __future__ import annotations
import os, json, gc, warnings, textwrap
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Dict, Optional, Tuple

import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    EarlyStoppingCallback,
    DataCollatorForSeq2Seq,
)

import chromadb
from sentence_transformers import SentenceTransformer
import gradio as gr

warnings.filterwarnings("ignore")

# ==========================
# Config
# ==========================
@dataclass
class ModelConfig:
    model_name: str = "google/mt5-base"
    architecture: str = "seq2seq"  # "seq2seq" | "causal"
    max_input_length: int = 1024
    max_target_length: int = 512
    max_new_tokens: int = 384
    temperature: float = 0.7
    top_p: float = 0.9
    num_beams: int = 4

@dataclass
class RAGConfig:
    embedding_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    persist_dir: str = "./chroma_db"
    collection: str = "legal_articles"
    top_k: int = 5
    similarity_threshold: float = 0.66  # 0..1 (higher = stricter)
    context_char_limit: int = 300       # max characters per article in the context

@dataclass
class TrainConfig:
    output_dir: str = "./mahoon_model"
    seed: int = 42
    test_size: float = 0.1
    epochs: int = 2
    batch_size: int = 2
    grad_accum: int = 2
    lr: float = 3e-5
    use_bf16: bool = True

@dataclass
class SystemConfig:
    model: ModelConfig = field(default_factory=ModelConfig)
    rag: RAGConfig = field(default_factory=RAGConfig)
    train: TrainConfig = field(default_factory=TrainConfig)

# ==========================
# Utils
# ==========================
def set_seed_all(seed: int = 42):
    import random
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def read_jsonl_files(paths: List[str]) -> List[Dict]:
    data: List[Dict] = []
    for p in paths:
        if not p:
            continue
        with open(p, 'r', encoding='utf-8') as f:
            for line in f:
                s = line.strip()
                if not s:
                    continue
                try:
                    obj = json.loads(s)
                    data.append(obj)
                except json.JSONDecodeError:
                    continue
    return data

# ==========================
# RAG
# ==========================
class LegalRAG:
    def __init__(self, cfg: RAGConfig):
        self.cfg = cfg
        self.client = None
        self.collection = None
        self.embedder: Optional[SentenceTransformer] = None

    def init(self):
        Path(self.cfg.persist_dir).mkdir(parents=True, exist_ok=True)
        self.client = chromadb.PersistentClient(path=self.cfg.persist_dir)
        # get_or_create for compatibility across chroma versions
        try:
            self.collection = self.client.get_or_create_collection(self.cfg.collection)
        except Exception:
            try:
                self.collection = self.client.get_collection(self.cfg.collection)
            except Exception:
                self.collection = self.client.create_collection(self.cfg.collection)
        self.embedder = SentenceTransformer(self.cfg.embedding_model)

    def retrieve(self, query: str) -> List[Dict]:
        if not self.collection:
            return []
        try:
            res = self.collection.query(
                query_texts=[query],
                n_results=self.cfg.top_k,
                include=["documents", "metadatas", "distances"],
            )
            out = []
            for i, (doc, meta, dist) in enumerate(zip(
                res.get('documents', [['']])[0],
                res.get('metadatas', [[{}]])[0],
                res.get('distances', [[1.0]])[0],
            )):
                sim = 1 - float(dist)
                if sim >= self.cfg.similarity_threshold:
                    out.append({
                        "article_id": (meta or {}).get("article_id", f"unk_{i}"),
                        "text": doc,
                        "similarity": sim,
                    })
            return out
        except Exception:
            return []

    def build_context(self, arts: List[Dict]) -> str:
        if not arts:
            return ""
        bullets = [f"• ماده {a['article_id']}: {a['text'][:self.cfg.context_char_limit]}..." for a in arts]
        return "مواد مرتبط:\n" + "\n".join(bullets)
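
# Illustrative helper (not used by the app): retrieve() assumes the ChromaDB collection has
# already been populated with legal articles. The sketch below shows one way such a collection
# could be filled; the shape of `articles` is an assumption, while the "article_id" metadata key
# matches what retrieve() reads. Note that query_texts embeds with the collection's own
# embedding function (Chroma's default unless one was supplied), not the SentenceTransformer
# held in self.embedder, which is loaded but not used in the query path as written.
def populate_collection_example(rag: LegalRAG, articles: List[Dict]) -> None:
    """Add records of the form {"article_id": str, "text": str} to the RAG collection."""
    if not rag.collection:
        rag.init()
    rag.collection.add(
        ids=[str(a["article_id"]) for a in articles],
        documents=[a["text"] for a in articles],
        metadatas=[{"article_id": str(a["article_id"])} for a in articles],
    )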

# ==========================
# Loader + Generator
# ==========================
class ModelLoader:
    def __init__(self, mcfg: ModelConfig):
        self.cfg = mcfg
        self.tokenizer = None
        self.model = None

    def load(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name)
        # use bf16 only when the GPU actually supports it
        dtype = torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else None
        if self.cfg.architecture == "seq2seq":
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                self.cfg.model_name,
                device_map="auto" if torch.cuda.is_available() else None,
                torch_dtype=dtype,
            )
        elif self.cfg.architecture == "causal":
            self.model = AutoModelForCausalLM.from_pretrained(
                self.cfg.model_name,
                device_map="auto" if torch.cuda.is_available() else None,
                torch_dtype=dtype,
            )
            if self.tokenizer.pad_token is None and hasattr(self.tokenizer, 'eos_token'):
                self.tokenizer.pad_token = self.tokenizer.eos_token
        else:
            raise ValueError("Unsupported architecture")
        return self

class Generator:
    def __init__(self, loader: ModelLoader, mcfg: ModelConfig):
        self.tk = loader.tokenizer
        self.model = loader.model
        self.cfg = mcfg

    def generate(self, question: str, context: str = "") -> str:
        if self.cfg.architecture == "seq2seq":
            inp = f"{context}\nسوال: {question}" if context else f"سوال: {question}"
            enc = self.tk(inp, return_tensors="pt", truncation=True, max_length=self.cfg.max_input_length)
            enc = {k: v.to(self.model.device) for k, v in enc.items()}
            out = self.model.generate(
                **enc,
                max_length=self.cfg.max_target_length,
                num_beams=self.cfg.num_beams,
                early_stopping=True,
            )
        else:  # causal
            prompt = f"{context}\nسوال: {question}\nپاسخ:" if context else f"سوال: {question}\nپاسخ:"
            enc = self.tk(prompt, return_tensors="pt", truncation=True, max_length=self.cfg.max_input_length)
            enc = {k: v.to(self.model.device) for k, v in enc.items()}
            out = self.model.generate(
                **enc,
                max_new_tokens=self.cfg.max_new_tokens,
                do_sample=True,
                temperature=self.cfg.temperature,
                top_p=self.cfg.top_p,
                pad_token_id=self.tk.pad_token_id or self.tk.eos_token_id,
            )
            # causal models echo the prompt; return only the newly generated tokens
            return self.tk.decode(out[0][enc["input_ids"].shape[1]:], skip_special_tokens=True)
        return self.tk.decode(out[0], skip_special_tokens=True)
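
# Minimal usage sketch (illustrative, outside the Gradio UI): load a model and answer a
# question with an optional retrieved context. Model name and question are placeholders.
#
#   cfg = ModelConfig(model_name="google/mt5-base", architecture="seq2seq")
#   gen = Generator(ModelLoader(cfg).load(), cfg)
#   print(gen.generate("<legal question>", context=""))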

# ==========================
# Datasets
# ==========================
class Seq2SeqJSONLDataset(Dataset):
    def __init__(self, data: List[Dict], tokenizer, max_inp: int, max_tgt: int, rag: Optional[LegalRAG] = None, enhance_every: int = 10):
        self.tk = tokenizer
        self.max_inp = max_inp
        self.max_tgt = max_tgt
        self.items = []
        for i, ex in enumerate(data):
            src = str(ex.get("input", "")).strip()
            tgt = str(ex.get("output", "")).strip()
            if not src or not tgt:
                continue
            inp = src
            if rag and i % enhance_every == 0:
                arts = rag.retrieve(src)
                ctx = rag.build_context(arts)
                if ctx:
                    inp = f"<CONTEXT>{ctx}</CONTEXT>\n<QUESTION>{src}</QUESTION>"
            self.items.append((inp, tgt))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        inp, tgt = self.items[idx]
        model_inputs = self.tk(inp, max_length=self.max_inp, padding="max_length", truncation=True)
        labels = self.tk(text_target=tgt, max_length=self.max_tgt, padding="max_length", truncation=True)
        # mask label padding so it does not contribute to the loss
        model_inputs["labels"] = [
            tok if tok != self.tk.pad_token_id else -100 for tok in labels["input_ids"]
        ]
        return model_inputs

class CausalJSONLDataset(Dataset):
    def __init__(self, data: List[Dict], tokenizer, max_inp: int, rag: Optional[LegalRAG] = None, enhance_every: int = 10):
        self.tk = tokenizer
        self.max_inp = max_inp
        self.items = []
        for i, ex in enumerate(data):
            src = str(ex.get("input", "")).strip()
            tgt = str(ex.get("output", "")).strip()
            if not src or not tgt:
                continue
            ctx = ""
            if rag and i % enhance_every == 0:
                arts = rag.retrieve(src)
                ctx = rag.build_context(arts)
            text = f"{ctx}\nسوال: {src}\nپاسخ: {tgt}" if ctx else f"سوال: {src}\nپاسخ: {tgt}"
            self.items.append(text)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        text = self.items[idx]
        enc = self.tk(text, max_length=self.max_inp, padding="max_length", truncation=True)
        input_ids = torch.tensor(enc["input_ids"])
        attention_mask = torch.tensor(enc["attention_mask"])
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100  # ignore padding positions in the loss
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# ==========================
# Trainer Manager
# ==========================
class TrainerManager:
    def __init__(self, syscfg: SystemConfig, loader: ModelLoader):
        self.cfg = syscfg
        self.loader = loader

    def train_seq2seq(self, train_paths: List[str], use_rag: bool = True):
        set_seed_all(self.cfg.train.seed)
        data = read_jsonl_files(train_paths)
        train, val = train_test_split(data, test_size=self.cfg.train.test_size, random_state=self.cfg.train.seed)
        rag = LegalRAG(self.cfg.rag) if use_rag else None
        if rag:
            rag.init()
        ds_tr = Seq2SeqJSONLDataset(train, self.loader.tokenizer, self.cfg.model.max_input_length, self.cfg.model.max_target_length, rag)
        ds_va = Seq2SeqJSONLDataset(val, self.loader.tokenizer, self.cfg.model.max_input_length, self.cfg.model.max_target_length, None)
        collator = DataCollatorForSeq2Seq(tokenizer=self.loader.tokenizer, model=self.loader.model)
        fp16_ok = torch.cuda.is_available() and (not self.cfg.train.use_bf16)
        bf16_ok = torch.cuda.is_available() and self.cfg.train.use_bf16
        # predict_with_generate / generation_* are only accepted by the Seq2Seq variants,
        # so use Seq2SeqTrainingArguments + Seq2SeqTrainer here.
        args = Seq2SeqTrainingArguments(
            output_dir=self.cfg.train.output_dir,
            num_train_epochs=self.cfg.train.epochs,
            learning_rate=self.cfg.train.lr,
            per_device_train_batch_size=self.cfg.train.batch_size,
            per_device_eval_batch_size=self.cfg.train.batch_size,
            gradient_accumulation_steps=self.cfg.train.grad_accum,
            warmup_ratio=0.05,
            weight_decay=0.01,
            eval_strategy="epoch",
            save_strategy="epoch",
            save_total_limit=2,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            predict_with_generate=True,
            generation_max_length=self.cfg.model.max_target_length,
            generation_num_beams=self.cfg.model.num_beams,
            logging_steps=50,
            report_to="none",
            fp16=fp16_ok,
            bf16=bf16_ok,
        )
        trainer = Seq2SeqTrainer(
            model=self.loader.model,
            args=args,
            train_dataset=ds_tr,
            eval_dataset=ds_va,
            data_collator=collator,
            tokenizer=self.loader.tokenizer,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
        )
        trainer.train()
        trainer.save_model(self.cfg.train.output_dir)
        self.loader.tokenizer.save_pretrained(self.cfg.train.output_dir)

    def train_causal(self, train_paths: List[str], use_rag: bool = True):
        set_seed_all(self.cfg.train.seed)
        data = read_jsonl_files(train_paths)
        train, val = train_test_split(data, test_size=self.cfg.train.test_size, random_state=self.cfg.train.seed)
        rag = LegalRAG(self.cfg.rag) if use_rag else None
        if rag:
            rag.init()
        ds_tr = CausalJSONLDataset(train, self.loader.tokenizer, self.cfg.model.max_input_length, rag)
        ds_va = CausalJSONLDataset(val, self.loader.tokenizer, self.cfg.model.max_input_length, None)
        fp16_ok = torch.cuda.is_available() and (not self.cfg.train.use_bf16)
        bf16_ok = torch.cuda.is_available() and self.cfg.train.use_bf16
        args = TrainingArguments(
            output_dir=self.cfg.train.output_dir,
            num_train_epochs=self.cfg.train.epochs,
            learning_rate=self.cfg.train.lr,
            per_device_train_batch_size=self.cfg.train.batch_size,
            per_device_eval_batch_size=self.cfg.train.batch_size,
            gradient_accumulation_steps=self.cfg.train.grad_accum,
            warmup_ratio=0.05,
            weight_decay=0.01,
            eval_strategy="epoch",
            save_strategy="epoch",
            save_total_limit=2,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            logging_steps=50,
            report_to="none",
            fp16=fp16_ok,
            bf16=bf16_ok,
        )
        trainer = Trainer(
            model=self.loader.model,
            args=args,
            train_dataset=ds_tr,
            eval_dataset=ds_va,
            tokenizer=self.loader.tokenizer,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
        )
        trainer.train()
        trainer.save_model(self.cfg.train.output_dir)
        self.loader.tokenizer.save_pretrained(self.cfg.train.output_dir)

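# Illustrative (not used by the app): training can also be driven without the UI.
# The JSONL path below is a placeholder.
#
#   cfg = SystemConfig()
#   loader = ModelLoader(cfg.model).load()
#   TrainerManager(cfg, loader).train_seq2seq(["./data/train.jsonl"], use_rag=False)
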
# ==========================
# App (Gradio)
# ==========================
class LegalApp:
    def __init__(self, scfg: Optional[SystemConfig] = None):
        self.scfg = scfg or SystemConfig()
        self.rag = LegalRAG(self.scfg.rag)
        self.loader: Optional[ModelLoader] = None
        self.gen: Optional[Generator] = None

    # --- core actions ---
    def load(self, model_name: str, arch: str, use_rag: bool, persist_dir: str, collection: str, top_k: int, threshold: float):
        # configure
        self.scfg.model.model_name = model_name
        self.scfg.model.architecture = arch
        self.scfg.rag.persist_dir = persist_dir
        self.scfg.rag.collection = collection
        self.scfg.rag.top_k = int(top_k)
        self.scfg.rag.similarity_threshold = float(threshold)
        # load model
        self.loader = ModelLoader(self.scfg.model).load()
        self.gen = Generator(self.loader, self.scfg.model)
        # load rag
        msg_rag = "RAG غیر فعال"
        if use_rag:
            try:
                self.rag = LegalRAG(self.scfg.rag)
                self.rag.init()
                msg_rag = "RAG آماده است"
            except Exception as e:
                msg_rag = f"RAG خطا: {e}"
        return f"مدل بارگذاری شد: {model_name} ({arch})\n{msg_rag}"

    def answer(self, question: str, use_rag: bool, max_new_tokens: int, temperature: float, top_p: float, num_beams: int):
        if not question.strip():
            return "لطفاً سوال خود را وارد کنید.", ""
        if not self.gen:
            return "ابتدا مدل/RAG را بارگذاری کنید.", ""
        # update runtime params
        self.scfg.model.max_new_tokens = int(max_new_tokens)
        self.scfg.model.temperature = float(temperature)
        self.scfg.model.top_p = float(top_p)
        self.scfg.model.num_beams = int(num_beams)
        arts = self.rag.retrieve(question) if (use_rag and self.rag.collection) else []
        ctx = self.rag.build_context(arts) if arts else ""
        ans = self.gen.generate(question, ctx)
        refs = ""
        if arts:
            refs = "\n\n" + "\n".join([f"**ماده {a['article_id']}** (شباهت: {a['similarity']:.2f})\n{a['text'][:380]}..." for a in arts])
        return ans, refs

    def train(self, model_name: str, arch: str, files: List[gr.File], use_rag: bool, epochs: int, batch: int, lr: float):
        self.scfg.model.model_name = model_name
        self.scfg.model.architecture = arch
        self.scfg.train.epochs = int(epochs)
        self.scfg.train.batch_size = int(batch)
        self.scfg.train.lr = float(lr)
        # ensure loader
        self.loader = ModelLoader(self.scfg.model).load()
        # train
        paths = [f.name for f in files] if files else []
        tm = TrainerManager(self.scfg, self.loader)
        if arch == "seq2seq":
            tm.train_seq2seq(paths, use_rag=use_rag)
        else:
            tm.train_causal(paths, use_rag=use_rag)
        return f"✅ آموزش کامل شد و در {self.scfg.train.output_dir} ذخیره شد."

    # --- UI ---
    def build_ui(self):
        default_models = {
            "Seq2Seq (mt5-base)": ("google/mt5-base", "seq2seq"),
            "Seq2Seq (t5-fa-base)": ("HooshvareLab/t5-fa-base", "seq2seq"),
            "Seq2Seq (flan-t5-base)": ("google/flan-t5-base", "seq2seq"),
            "Causal (Mistral-7B Instruct)": ("mistralai/Mistral-7B-Instruct-v0.2", "causal"),
        }
        with gr.Blocks(title="ماحون — مشاور حقوقی هوشمند", theme=gr.themes.Soft(primary_hue="green", secondary_hue="gray")) as app:
            gr.HTML("""
            <div style='text-align:center;padding:18px'>
                <h1 style='margin-bottom:4px'>ماحون — Ultimate Legal AI</h1>
                <p style='color:#666'>RAG • Seq2Seq/Causal • Training • Polished UI</p>
            </div>
            """)

            with gr.Tab("مشاوره"):
                with gr.Row():
                    model_dd = gr.Dropdown(choices=list(default_models.keys()), value="Seq2Seq (mt5-base)", label="مدل")
                    arch_info = gr.Markdown("""**راهنما:** مدل‌های Seq2Seq (MT5/T5) برای پاسخ‌های ساختاریافته عالی‌اند؛ مدل‌های Causal (Mistral) برای مکالمه طبیعی‌ترند.""")
                with gr.Row():
                    use_rag = gr.Checkbox(value=True, label="RAG فعال باشد؟")
                    persist_dir = gr.Textbox(value=self.scfg.rag.persist_dir, label="مسیر پایگاه ChromaDB")
                    collection = gr.Textbox(value=self.scfg.rag.collection, label="نام کالکشن")
                with gr.Row():
                    top_k = gr.Slider(1, 10, value=self.scfg.rag.top_k, step=1, label="Top‑K")
                    threshold = gr.Slider(0.3, 0.95, value=self.scfg.rag.similarity_threshold, step=0.01, label="حد آستانه شباهت")
                load_btn = gr.Button("بارگذاری مدل/RAG", variant="primary")
                status = gr.Textbox(label="وضعیت", interactive=False)

                with gr.Accordion("پارامترهای تولید", open=False):
                    max_new_tokens = gr.Slider(64, 1024, value=self.scfg.model.max_new_tokens, step=16, label="max_new_tokens")
                    temperature = gr.Slider(0.0, 1.5, value=self.scfg.model.temperature, step=0.05, label="temperature")
                    top_p = gr.Slider(0.1, 1.0, value=self.scfg.model.top_p, step=0.05, label="top_p")
                    num_beams = gr.Slider(1, 8, value=self.scfg.model.num_beams, step=1, label="num_beams (Seq2Seq)")

                question = gr.Textbox(lines=3, label="سوال حقوقی")
                examples = gr.Examples([
                    ["در صورت نقض قرارداد فروش، چه اقداماتی باید انجام دهم؟"],
                    ["آیا درج شرط عدم رقابت در قرارداد کار قانونی است؟"],
                    ["حق و حقوق کارگر در صورت اخراج فوری چیست؟"],
                    ["فرآیند طرح دعوای مطالبه مهریه چگونه است؟"],
                ], inputs=question, label="نمونه پرسش‌ها")
                ask_btn = gr.Button("پرسش", variant="primary")
                answer = gr.Markdown(label="پاسخ")
                refs = gr.Markdown(label="مواد قانونی مرتبط")

            with gr.Tab("آموزش"):
                gr.Markdown("برای آموزش، فایل‌های JSONL شامل کلیدهای `input` و `output` را بارگذاری کنید.")
                with gr.Row():
                    model_dd_train = gr.Dropdown(choices=list(default_models.keys()), value="Seq2Seq (mt5-base)", label="مدل")
                    use_rag_train = gr.Checkbox(value=True, label="RAG‑enhanced Training")
                train_files = gr.Files(label="JSONL Files", file_count="multiple", file_types=[".jsonl"])
                with gr.Row():
                    epochs = gr.Slider(1, 6, value=self.scfg.train.epochs, step=1, label="epochs")
                    batch = gr.Slider(1, 8, value=self.scfg.train.batch_size, step=1, label="batch per device")
                    lr = gr.Number(value=self.scfg.train.lr, label="learning rate")
                train_btn = gr.Button("شروع آموزش", variant="primary")
                train_status = gr.Textbox(label="وضعیت آموزش", interactive=False)

            # Events
            def _resolve(choice: str) -> Tuple[str, str]:
                return default_models[choice]

            load_btn.click(lambda choice, rag, pdir, coll, k, th: self.load(*_resolve(choice), rag, pdir, coll, k, th),
                           inputs=[model_dd, use_rag, persist_dir, collection, top_k, threshold], outputs=status)

            ask_btn.click(lambda q, rag, mnt, t, p, nb: self.answer(q, rag, mnt, t, p, nb),
                          inputs=[question, use_rag, max_new_tokens, temperature, top_p, num_beams], outputs=[answer, refs])

            train_btn.click(lambda choice, files, rag, e, b, l: self.train(*_resolve(choice), files, rag, e, b, l),
                            inputs=[model_dd_train, train_files, use_rag_train, epochs, batch, lr], outputs=train_status)
        return app

# ==========================
# Entrypoint
# ==========================
if __name__ == "__main__":
    app = LegalApp()
    ui = app.build_ui()
    ui.launch(share=True)