hajimammad commited on
Commit
f90a910
·
verified ·
1 Parent(s): 12e69cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +375 -129
app.py CHANGED
@@ -1,36 +1,32 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- Mahoun — Ultimate Legal AI (Single-File, Modular, Polished UI)
4
- هستهٔ جدید ماحون با ادغام اجزای قبلی (RAG پیشرفته + Training برای Seq2Seq و Causal) و UI زیباتر.
5
 
6
  ویژگی‌ها:
7
- - Multi-Architecture: "seq2seq" (T5/MT5/FLAN-T5) و "causal" (Mistral/LLaMA).
8
- - Loader/Generator یکپارچه + Prompt تطبیقی برحسب معماری.
9
- - RAG پیشرفته با ChromaDB (پیکربندی مسیر، نام کالکشن، top_k، threshold، قطع متن).
10
- - Training کامل برای هر دو معماری (Trainer, EarlyStopping, bf16/fp16, gradient_accumulation).
11
- - Gradio UI بازطراحی‌شده (تم تمیز، کارت‌ها، مثال‌ها، وضعیت زنده، کنترل‌های تولید، انتخاب مدل/معماری/دیتابیس).
12
-
13
- حداقل نیازمندی‌ها (requirements.txt):
14
- transformers>=4.44.0
15
- sentencepiece
16
- accelerate
17
- bitsandbytes
18
- chromadb
19
- sentence-transformers
20
- scikit-learn
21
- gradio
22
- torch>=2.1
23
  """
 
24
  from __future__ import annotations
25
- import os, json, gc, warnings, textwrap
26
  from dataclasses import dataclass, field
27
  from pathlib import Path
28
  from typing import List, Dict, Optional, Tuple
29
 
 
30
  import torch
31
  from torch.utils.data import Dataset
32
  from sklearn.model_selection import train_test_split
33
 
 
 
 
 
34
  from transformers import (
35
  AutoTokenizer,
36
  AutoModelForSeq2SeqLM,
@@ -41,9 +37,18 @@ from transformers import (
41
  DataCollatorForSeq2Seq,
42
  )
43
 
 
44
  import chromadb
45
  from sentence_transformers import SentenceTransformer
46
- import gradio as gr
 
 
 
 
 
 
 
 
47
 
48
  warnings.filterwarnings("ignore")
49
 
@@ -60,6 +65,7 @@ class ModelConfig:
60
  temperature: float = 0.7
61
  top_p: float = 0.9
62
  num_beams: int = 4
 
63
 
64
  @dataclass
65
  class RAGConfig:
@@ -67,19 +73,28 @@ class RAGConfig:
67
  persist_dir: str = "./chroma_db"
68
  collection: str = "legal_articles"
69
  top_k: int = 5
70
- similarity_threshold: float = 0.66 # 0..1 (بزرگ‌تر=سخت‌گیرتر)
71
- context_char_limit: int = 300 # حداکثر کاراکتر هر ماده در Context
 
72
 
73
  @dataclass
74
  class TrainConfig:
75
  output_dir: str = "./mahoon_model"
76
  seed: int = 42
77
  test_size: float = 0.1
78
- epochs: int = 2
79
  batch_size: int = 2
80
  grad_accum: int = 2
81
  lr: float = 3e-5
82
  use_bf16: bool = True
 
 
 
 
 
 
 
 
83
 
84
  @dataclass
85
  class SystemConfig:
@@ -93,26 +108,43 @@ class SystemConfig:
93
  def set_seed_all(seed: int = 42):
94
  import random
95
  random.seed(seed)
 
96
  torch.manual_seed(seed)
97
- torch.cuda.manual_seed_all(seed)
98
-
99
-
100
- def read_jsonl_files(paths: List[str]) -> List[Dict]:
101
- data: List[Dict] = []
102
- for p in paths:
103
- if not p:
104
- continue
105
- with open(p, 'r', encoding='utf-8') as f:
106
- for line in f:
107
- s = line.strip()
108
- if not s:
109
- continue
110
- try:
111
- obj = json.loads(s)
112
- data.append(obj)
113
- except json.JSONDecodeError:
114
- continue
115
- return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  # ==========================
118
  # RAG
@@ -127,7 +159,6 @@ class LegalRAG:
127
  def init(self):
128
  Path(self.cfg.persist_dir).mkdir(parents=True, exist_ok=True)
129
  self.client = chromadb.PersistentClient(path=self.cfg.persist_dir)
130
- # get_or_create برای سازگاری نسخه‌های مختلف chroma
131
  try:
132
  self.collection = self.client.get_or_create_collection(self.cfg.collection)
133
  except Exception:
@@ -137,6 +168,32 @@ class LegalRAG:
137
  self.collection = self.client.create_collection(self.cfg.collection)
138
  self.embedder = SentenceTransformer(self.cfg.embedding_model)
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  def retrieve(self, query: str) -> List[Dict]:
141
  if not self.collection:
142
  return []
@@ -147,8 +204,11 @@ class LegalRAG:
147
  include=["documents","metadatas","distances"],
148
  )
149
  out = []
150
- for i,(doc, meta, dist) in enumerate(zip(res.get('documents',[['']])[0], res.get('metadatas',[['']])[0], res.get('distances',[[1.0]])[0])):
151
- sim = 1 - float(dist)
 
 
 
152
  if sim >= self.cfg.similarity_threshold:
153
  out.append({
154
  "article_id": (meta or {}).get("article_id", f"unk_{i}"),
@@ -176,19 +236,27 @@ class ModelLoader:
176
 
177
  def load(self):
178
  self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name)
179
- dtype = torch.bfloat16 if torch.cuda.is_available() else None
 
 
 
 
 
 
180
  if self.cfg.architecture == "seq2seq":
181
- self.model = AutoModelForSeq2SeqLM.from_pretrained(
182
- self.cfg.model_name, device_map="auto" if torch.cuda.is_available() else None, torch_dtype=dtype
183
- )
184
  elif self.cfg.architecture == "causal":
185
- self.model = AutoModelForCausalLM.from_pretrained(
186
- self.cfg.model_name, device_map="auto" if torch.cuda.is_available() else None, torch_dtype=dtype
187
- )
188
- if self.tokenizer.pad_token is None and hasattr(self.tokenizer, 'eos_token'):
189
  self.tokenizer.pad_token = self.tokenizer.eos_token
190
  else:
191
  raise ValueError("Unsupported architecture")
 
 
 
 
 
 
192
  return self
193
 
194
  class Generator:
@@ -208,7 +276,7 @@ class Generator:
208
  num_beams=self.cfg.num_beams,
209
  early_stopping=True,
210
  )
211
- else: # causal
212
  prompt = f"{context}\nسوال: {question}\nپاسخ:" if context else f"سوال: {question}\nپاسخ:"
213
  enc = self.tk(prompt, return_tensors="pt", truncation=True, max_length=self.cfg.max_input_length)
214
  enc = {k: v.to(self.model.device) for k,v in enc.items()}
@@ -278,50 +346,126 @@ class CausalJSONLDataset(Dataset):
278
  text = self.items[idx]
279
  enc = self.tk(text, max_length=self.max_inp, padding="max_length", truncation=True)
280
  input_ids = torch.tensor(enc["input_ids"])
281
- return {"input_ids": input_ids, "attention_mask": torch.tensor(enc["attention_mask"]), "labels": input_ids.clone()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
  # ==========================
284
  # Trainer Manager
285
  # ==========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  class TrainerManager:
287
  def __init__(self, syscfg: SystemConfig, loader: ModelLoader):
288
  self.cfg = syscfg
289
  self.loader = loader
290
 
291
- def train_seq2seq(self, train_paths: List[str], use_rag: bool = True):
292
- set_seed_all(self.cfg.train.seed)
293
- data = read_jsonl_files(train_paths)
294
- train, val = train_test_split(data, test_size=self.cfg.train.test_size, random_state=self.cfg.train.seed)
295
- rag = LegalRAG(self.cfg.rag) if use_rag else None
296
- if rag:
297
- rag.init()
298
- ds_tr = Seq2SeqJSONLDataset(train, self.loader.tokenizer, self.cfg.model.max_input_length, self.cfg.model.max_target_length, rag)
299
- ds_va = Seq2SeqJSONLDataset(val, self.loader.tokenizer, self.cfg.model.max_input_length, self.cfg.model.max_target_length, None)
300
- collator = DataCollatorForSeq2Seq(tokenizer=self.loader.tokenizer, model=self.loader.model)
301
  fp16_ok = torch.cuda.is_available() and (not self.cfg.train.use_bf16)
302
- bf16_ok = torch.cuda.is_available() and self.cfg.train.use_bf16
303
- args = TrainingArguments(
 
304
  output_dir=self.cfg.train.output_dir,
305
  num_train_epochs=self.cfg.train.epochs,
306
  learning_rate=self.cfg.train.lr,
307
  per_device_train_batch_size=self.cfg.train.batch_size,
308
  per_device_eval_batch_size=self.cfg.train.batch_size,
309
  gradient_accumulation_steps=self.cfg.train.grad_accum,
310
- warmup_ratio=0.05,
311
- weight_decay=0.01,
312
- evaluation_strategy="epoch",
313
- save_strategy="epoch",
314
- save_total_limit=2,
315
  load_best_model_at_end=True,
316
  metric_for_best_model="eval_loss",
317
- predict_with_generate=True,
318
- generation_max_length=self.cfg.model.max_target_length,
319
- generation_num_beams=self.cfg.model.num_beams,
320
- logging_steps=50,
321
- report_to="none",
322
  fp16=fp16_ok,
323
  bf16=bf16_ok,
 
 
 
 
 
 
324
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  trainer = Trainer(
326
  model=self.loader.model,
327
  args=args,
@@ -330,6 +474,7 @@ class TrainerManager:
330
  data_collator=collator,
331
  tokenizer=self.loader.tokenizer,
332
  callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
 
333
  )
334
  trainer.train()
335
  trainer.save_model(self.cfg.train.output_dir)
@@ -339,32 +484,15 @@ class TrainerManager:
339
  set_seed_all(self.cfg.train.seed)
340
  data = read_jsonl_files(train_paths)
341
  train, val = train_test_split(data, test_size=self.cfg.train.test_size, random_state=self.cfg.train.seed)
342
- rag = LegalRAG(self.cfg.rag) if use_rag else None
 
343
  if rag:
344
  rag.init()
 
345
  ds_tr = CausalJSONLDataset(train, self.loader.tokenizer, self.cfg.model.max_input_length, rag)
346
  ds_va = CausalJSONLDataset(val, self.loader.tokenizer, self.cfg.model.max_input_length, None)
347
- fp16_ok = torch.cuda.is_available() and (not self.cfg.train.use_bf16)
348
- bf16_ok = torch.cuda.is_available() and self.cfg.train.use_bf16
349
- args = TrainingArguments(
350
- output_dir=self.cfg.train.output_dir,
351
- num_train_epochs=self.cfg.train.epochs,
352
- learning_rate=self.cfg.train.lr,
353
- per_device_train_batch_size=self.cfg.train.batch_size,
354
- per_device_eval_batch_size=self.cfg.train.batch_size,
355
- gradient_accumulation_steps=self.cfg.train.grad_accum,
356
- warmup_ratio=0.05,
357
- weight_decay=0.01,
358
- evaluation_strategy="epoch",
359
- save_strategy="epoch",
360
- save_total_limit=2,
361
- load_best_model_at_end=True,
362
- metric_for_best_model="eval_loss",
363
- logging_steps=50,
364
- report_to="none",
365
- fp16=fp16_ok,
366
- bf16=bf16_ok,
367
- )
368
  trainer = Trainer(
369
  model=self.loader.model,
370
  args=args,
@@ -372,13 +500,14 @@ class TrainerManager:
372
  eval_dataset=ds_va,
373
  tokenizer=self.loader.tokenizer,
374
  callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
 
375
  )
376
  trainer.train()
377
  trainer.save_model(self.cfg.train.output_dir)
378
  self.loader.tokenizer.save_pretrained(self.cfg.train.output_dir)
379
 
380
  # ==========================
381
- # App (Gradio)
382
  # ==========================
383
  class LegalApp:
384
  def __init__(self, scfg: Optional[SystemConfig] = None):
@@ -387,6 +516,15 @@ class LegalApp:
387
  self.loader: Optional[ModelLoader] = None
388
  self.gen: Optional[Generator] = None
389
 
 
 
 
 
 
 
 
 
 
390
  # --- core actions ---
391
  def load(self, model_name: str, arch: str, use_rag: bool, persist_dir: str, collection: str, top_k: int, threshold: float):
392
  # configure
@@ -396,11 +534,14 @@ class LegalApp:
396
  self.scfg.rag.collection = collection
397
  self.scfg.rag.top_k = int(top_k)
398
  self.scfg.rag.similarity_threshold = float(threshold)
 
 
399
  # load model
400
  self.loader = ModelLoader(self.scfg.model).load()
401
  self.gen = Generator(self.loader, self.scfg.model)
 
402
  # load rag
403
- msg_rag = "RAG غیر فعال"
404
  if use_rag:
405
  try:
406
  self.rag = LegalRAG(self.scfg.rag)
@@ -408,19 +549,33 @@ class LegalApp:
408
  msg_rag = "RAG آماده است"
409
  except Exception as e:
410
  msg_rag = f"RAG خطا: {e}"
 
411
  return f"مدل بارگذاری شد: {model_name} ({arch})\n{msg_rag}"
412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413
  def answer(self, question: str, use_rag: bool, max_new_tokens: int, temperature: float, top_p: float, num_beams: int):
414
  if not question.strip():
415
  return "لطفاً سوال خود را وارد کنید.", ""
416
  if not self.gen:
417
  return "ابتدا مدل/RAG را بارگذاری کنید.", ""
418
- # update runtime params
419
  self.scfg.model.max_new_tokens = int(max_new_tokens)
420
  self.scfg.model.temperature = float(temperature)
421
  self.scfg.model.top_p = float(top_p)
422
  self.scfg.model.num_beams = int(num_beams)
423
- arts = self.rag.retrieve(question) if (use_rag and self.rag.collection) else []
424
  ctx = self.rag.build_context(arts) if arts else ""
425
  ans = self.gen.generate(question, ctx)
426
  refs = ""
@@ -428,53 +583,99 @@ class LegalApp:
428
  refs = "\n\n" + "\n".join([f"**ماده {a['article_id']}** (شباهت: {a['similarity']:.2f})\n{a['text'][:380]}..." for a in arts])
429
  return ans, refs
430
 
431
- def train(self, model_name: str, arch: str, files: List[gr.File], use_rag: bool, epochs: int, batch: int, lr: float):
 
 
432
  self.scfg.model.model_name = model_name
433
  self.scfg.model.architecture = arch
434
  self.scfg.train.epochs = int(epochs)
435
  self.scfg.train.batch_size = int(batch)
436
  self.scfg.train.lr = float(lr)
437
- # ensure loader
 
 
 
 
438
  self.loader = ModelLoader(self.scfg.model).load()
439
- # train
440
- paths = [f.name for f in files] if files else []
 
 
 
441
  tm = TrainerManager(self.scfg, self.loader)
 
 
 
442
  if arch == "seq2seq":
443
  tm.train_seq2seq(paths, use_rag=use_rag)
444
  else:
445
  tm.train_causal(paths, use_rag=use_rag)
 
 
446
  return f"✅ آموزش کامل شد و در {self.scfg.train.output_dir} ذخیره شد."
447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  # --- UI ---
449
  def build_ui(self):
 
 
450
  default_models = {
451
  "Seq2Seq (mt5-base)": ("google/mt5-base", "seq2seq"),
452
  "Seq2Seq (t5-fa-base)": ("HooshvareLab/t5-fa-base", "seq2seq"),
453
  "Seq2Seq (flan-t5-base)": ("google/flan-t5-base", "seq2seq"),
454
  "Causal (Mistral-7B Instruct)": ("mistralai/Mistral-7B-Instruct-v0.2", "causal"),
455
  }
 
456
  with gr.Blocks(title="ماحون — مشاور حقوقی هوشمند", theme=gr.themes.Soft(primary_hue="green", secondary_hue="gray")) as app:
457
  gr.HTML("""
458
  <div style='text-align:center;padding:18px'>
459
  <h1 style='margin-bottom:4px'>ماحون — Ultimate Legal AI</h1>
460
- <p style='color:#666'>RAG • Seq2Seq/Causal • Training • Polished UI</p>
461
  </div>
462
  """)
463
 
 
464
  with gr.Tab("مشاوره"):
465
  with gr.Row():
466
  model_dd = gr.Dropdown(choices=list(default_models.keys()), value="Seq2Seq (mt5-base)", label="مدل")
467
- arch_info = gr.Markdown("""**راهنما:** مدل‌های Seq2Seq (MT5/T5) برای پاسخ‌های ساختاریافته عالی‌اند؛ مدل‌های Causal (Mistral) برای مکالمه طبیعی‌ترند.""")
468
  with gr.Row():
469
  use_rag = gr.Checkbox(value=True, label="RAG فعال باشد؟")
470
- persist_dir = gr.Textbox(value=self.scfg.rag.persist_dir, label="مسیر پایگاه ChromaDB")
471
  collection = gr.Textbox(value=self.scfg.rag.collection, label="نام کالکشن")
472
  with gr.Row():
473
- top_k = gr.Slider(1, 10, value=self.scfg.rag.top_k, step=1, label="TopK")
474
- threshold = gr.Slider(0.3, 0.95, value=self.scfg.rag.similarity_threshold, step=0.01, label="حد آستانه شباهت")
475
  load_btn = gr.Button("بارگذاری مدل/RAG", variant="primary")
476
  status = gr.Textbox(label="وضعیت", interactive=False)
477
 
 
 
 
 
 
 
 
478
  with gr.Accordion("پارامترهای تولید", open=False):
479
  max_new_tokens = gr.Slider(64, 1024, value=self.scfg.model.max_new_tokens, step=16, label="max_new_tokens")
480
  temperature = gr.Slider(0.0, 1.5, value=self.scfg.model.temperature, step=0.05, label="temperature")
@@ -482,30 +683,56 @@ class LegalApp:
482
  num_beams = gr.Slider(1, 8, value=self.scfg.model.num_beams, step=1, label="num_beams (Seq2Seq)")
483
 
484
  question = gr.Textbox(lines=3, label="سوال حقوقی")
485
- examples = gr.Examples([
486
- ["در صورت نقض قرارداد فروش، چه اقداماتی باید انجام دهم؟"],
487
- ["آیا درج شرط عدم رقابت در قرارداد کار قانونی است؟"],
488
- ["حق و حقوق کارگر در صورت اخراج فوری چیست؟"],
489
- ["فرآیند طرح دعوای مطالبه مهریه چگونه است؟"],
490
- ], inputs=question, label="نمونه پرسش‌ها")
 
 
 
491
  ask_btn = gr.Button("پرسش", variant="primary")
492
  answer = gr.Markdown(label="پاسخ")
493
  refs = gr.Markdown(label="مواد قانونی مرتبط")
494
 
 
495
  with gr.Tab("آموزش"):
496
- gr.Markdown("برای آموزش، فایل‌های JSONL شامل کلیدهای `input` و `output` را بارگذاری کنید.")
497
  with gr.Row():
498
  model_dd_train = gr.Dropdown(choices=list(default_models.keys()), value="Seq2Seq (mt5-base)", label="مدل")
499
- use_rag_train = gr.Checkbox(value=True, label="RAGenhanced Training")
500
  train_files = gr.Files(label="JSONL Files", file_count="multiple", file_types=[".jsonl"])
501
  with gr.Row():
502
- epochs = gr.Slider(1, 6, value=self.scfg.train.epochs, step=1, label="epochs")
503
- batch = gr.Slider(1, 8, value=self.scfg.train.batch_size, step=1, label="batch per device")
504
  lr = gr.Number(value=self.scfg.train.lr, label="learning rate")
 
 
 
 
505
  train_btn = gr.Button("شروع آموزش", variant="primary")
506
  train_status = gr.Textbox(label="وضعیت آموزش", interactive=False)
507
 
508
- # Events
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
509
  def _resolve(choice: str) -> Tuple[str,str]:
510
  return default_models[choice]
511
 
@@ -513,17 +740,36 @@ class LegalApp:
513
  inputs=[model_dd, use_rag, persist_dir, collection, top_k, threshold], outputs=status)
514
 
515
  ask_btn.click(lambda q, rag, mnt, t, p, nb: self.answer(q, rag, mnt, t, p, nb),
516
- inputs=[question, use_rag, max_new_tokens, temperature, top_p, num_beams], outputs=[answer, refs])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
517
 
518
- train_btn.click(lambda choice, files, rag, e, b, l: self.train(*_resolve(choice), files, rag, e, b, l),
519
- inputs=[model_dd_train, train_files, use_rag_train, epochs, batch, lr], outputs=train_status)
520
  return app
521
 
522
  # ==========================
523
-
524
- # Entrypoint
525
  # ==========================
526
  if __name__ == "__main__":
527
  app = LegalApp()
528
  ui = app.build_ui()
529
- ui.launch(share=True)
 
 
 
 
 
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ Mahoun — Legal AI (RAG + Training + Dataset Builder) for HF Spaces / Gradio 5.46.x
 
4
 
5
  ویژگی‌ها:
6
+ - RAG با ChromaDB + ایندکس‌سازی JSONL قوانین
7
+ - آموزش Seq2Seq/Causal با Trainer (ایمن و عقب‌سازگار با TrainingArguments)
8
+ - متریک‌ها: ROUGE-L (Seq2Seq) و F1 ساده (Causal)
9
+ - تب «ساخت دیتاست» (داخل اپ): تولید JSONL سازگار با {input, output}
10
+ - Progress به‌صورت DI (gr.Progress(track_tqdm=True))
11
+ - لانچ پایدار برای Gradio 5.46.x (بدون concurrency_count)
12
+
 
 
 
 
 
 
 
 
 
13
  """
14
+
15
  from __future__ import annotations
16
+ import os, sys, json, warnings
17
  from dataclasses import dataclass, field
18
  from pathlib import Path
19
  from typing import List, Dict, Optional, Tuple
20
 
21
+ import numpy as np
22
  import torch
23
  from torch.utils.data import Dataset
24
  from sklearn.model_selection import train_test_split
25
 
26
+ import gradio as gr
27
+ from packaging import version
28
+
29
+ import transformers as tf
30
  from transformers import (
31
  AutoTokenizer,
32
  AutoModelForSeq2SeqLM,
 
37
  DataCollatorForSeq2Seq,
38
  )
39
 
40
+ # RAG stack
41
  import chromadb
42
  from sentence_transformers import SentenceTransformer
43
+
44
+ # Optional metrics
45
+ try:
46
+ from evaluate import load as eval_load
47
+ except Exception:
48
+ eval_load = None
49
+
50
+ # Dataset builder (ماژول داخلی اپ)
51
+ from golden_builder import load_json_or_jsonl, save_jsonl, GoldenBuilder
52
 
53
  warnings.filterwarnings("ignore")
54
 
 
65
  temperature: float = 0.7
66
  top_p: float = 0.9
67
  num_beams: int = 4
68
+ gradient_checkpointing: bool = True
69
 
70
  @dataclass
71
  class RAGConfig:
 
73
  persist_dir: str = "./chroma_db"
74
  collection: str = "legal_articles"
75
  top_k: int = 5
76
+ similarity_threshold: float = 0.66 # 0..1
77
+ context_char_limit: int = 300
78
+ enable: bool = True
79
 
80
  @dataclass
81
  class TrainConfig:
82
  output_dir: str = "./mahoon_model"
83
  seed: int = 42
84
  test_size: float = 0.1
85
+ epochs: int = 3
86
  batch_size: int = 2
87
  grad_accum: int = 2
88
  lr: float = 3e-5
89
  use_bf16: bool = True
90
+ weight_decay: float = 0.01
91
+ warmup_ratio: float = 0.05
92
+ logging_steps: int = 50
93
+ eval_strategy: str = "epoch" # "steps" | "epoch"
94
+ save_strategy: str = "epoch"
95
+ save_total_limit: int = 2
96
+ report_to: str = "none" # "none" | "wandb"
97
+ max_grad_norm: float = 1.0
98
 
99
  @dataclass
100
  class SystemConfig:
 
108
  def set_seed_all(seed: int = 42):
109
  import random
110
  random.seed(seed)
111
+ np.random.seed(seed)
112
  torch.manual_seed(seed)
113
+ if torch.cuda.is_available():
114
+ torch.cuda.manual_seed_all(seed)
115
+
116
+ def log_deps():
117
+ try:
118
+ import accelerate, datasets
119
+ print("[deps]",
120
+ f"python={sys.version.split()[0]}",
121
+ f"transformers={tf.__version__}",
122
+ f"accelerate={accelerate.__version__}",
123
+ f"datasets={datasets.__version__}",
124
+ f"gradio={gr.__version__}",
125
+ flush=True)
126
+ except Exception as e:
127
+ print("[deps] warn:", e, flush=True)
128
+
129
+ def bf16_supported():
130
+ return torch.cuda.is_available() and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported()
131
+
132
+ def safe_training_args(**kwargs):
133
+ """
134
+ Wrapper برای سازگاری با نسخه‌های قدیمی‌تر Transformers (قبل از 4.4):
135
+ - evaluation_strategy -> evaluate_during_training
136
+ - حذف کلیدهای جدید که ممکن است ناشناخته باشند
137
+ """
138
+ tf_ver = version.parse(tf.__version__)
139
+ k = dict(kwargs)
140
+ if tf_ver < version.parse("4.4.0"):
141
+ eval_strat = k.pop("evaluation_strategy", None)
142
+ k["evaluate_during_training"] = bool(eval_strat and str(eval_strat).lower() != "no")
143
+ for rm in ["save_strategy","load_best_model_at_end","metric_for_best_model",
144
+ "greater_is_better","predict_with_generate","generation_max_length",
145
+ "generation_num_beams","report_to","max_grad_norm"]:
146
+ k.pop(rm, None)
147
+ return TrainingArguments(**k)
148
 
149
  # ==========================
150
  # RAG
 
159
  def init(self):
160
  Path(self.cfg.persist_dir).mkdir(parents=True, exist_ok=True)
161
  self.client = chromadb.PersistentClient(path=self.cfg.persist_dir)
 
162
  try:
163
  self.collection = self.client.get_or_create_collection(self.cfg.collection)
164
  except Exception:
 
168
  self.collection = self.client.create_collection(self.cfg.collection)
169
  self.embedder = SentenceTransformer(self.cfg.embedding_model)
170
 
171
+ def index_jsonl(self, jsonl_path: str, id_key="article_id", text_key="text"):
172
+ """ایندکس‌سازی اولیه قوانین از JSONL: هر خط یک شیء {article_id, text, ...}."""
173
+ if not self.collection or not self.embedder:
174
+ self.init()
175
+ ids, docs, metas = [], [], []
176
+ with open(jsonl_path, "r", encoding="utf-8") as f:
177
+ for i, line in enumerate(f):
178
+ s = line.strip()
179
+ if not s:
180
+ continue
181
+ try:
182
+ obj = json.loads(s)
183
+ except:
184
+ continue
185
+ aid = str(obj.get(id_key, f"auto_{i}"))
186
+ txt = str(obj.get(text_key, "")).strip()
187
+ if not txt:
188
+ continue
189
+ ids.append(aid)
190
+ docs.append(txt)
191
+ metas.append({"article_id": aid})
192
+ if not ids:
193
+ return "هیچ سندی برای ایندکس پیدا نشد."
194
+ self.collection.upsert(ids=ids, documents=docs, metadatas=metas)
195
+ return f"✅ {len(ids)} سند قانونی ایندکس شد."
196
+
197
  def retrieve(self, query: str) -> List[Dict]:
198
  if not self.collection:
199
  return []
 
204
  include=["documents","metadatas","distances"],
205
  )
206
  out = []
207
+ docs = res.get("documents", [[]])[0]
208
+ metas = res.get("metadatas", [[]])[0]
209
+ dists = res.get("distances", [[1.0]])[0]
210
+ for i, (doc, meta, dist) in enumerate(zip(docs, metas, dists)):
211
+ sim = 1.0 - float(dist)
212
  if sim >= self.cfg.similarity_threshold:
213
  out.append({
214
  "article_id": (meta or {}).get("article_id", f"unk_{i}"),
 
236
 
237
  def load(self):
238
  self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name)
239
+ # dtype انتخاب هوشمند
240
+ use_bf16 = bf16_supported() and self.cfg.gradient_checkpointing
241
+ dtype = torch.bfloat16 if use_bf16 else (torch.float16 if torch.cuda.is_available() else None)
242
+ model_kwargs = {"torch_dtype": dtype}
243
+ if torch.cuda.is_available():
244
+ model_kwargs["device_map"] = "auto"
245
+
246
  if self.cfg.architecture == "seq2seq":
247
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(self.cfg.model_name, **model_kwargs)
 
 
248
  elif self.cfg.architecture == "causal":
249
+ self.model = AutoModelForCausalLM.from_pretrained(self.cfg.model_name, **model_kwargs)
250
+ if self.tokenizer.pad_token is None and hasattr(self.tokenizer, "eos_token"):
 
 
251
  self.tokenizer.pad_token = self.tokenizer.eos_token
252
  else:
253
  raise ValueError("Unsupported architecture")
254
+
255
+ if self.cfg.gradient_checkpointing and hasattr(self.model, "gradient_checkpointing_enable"):
256
+ try:
257
+ self.model.gradient_checkpointing_enable()
258
+ except Exception:
259
+ pass
260
  return self
261
 
262
  class Generator:
 
276
  num_beams=self.cfg.num_beams,
277
  early_stopping=True,
278
  )
279
+ else:
280
  prompt = f"{context}\nسوال: {question}\nپاسخ:" if context else f"سوال: {question}\nپاسخ:"
281
  enc = self.tk(prompt, return_tensors="pt", truncation=True, max_length=self.cfg.max_input_length)
282
  enc = {k: v.to(self.model.device) for k,v in enc.items()}
 
346
  text = self.items[idx]
347
  enc = self.tk(text, max_length=self.max_inp, padding="max_length", truncation=True)
348
  input_ids = torch.tensor(enc["input_ids"])
349
+ attn = torch.tensor(enc["attention_mask"])
350
+ labels = input_ids.clone()
351
+ labels[attn == 0] = -100 # padding mask for loss
352
+ return {"input_ids": input_ids, "attention_mask": attn, "labels": labels}
353
+
354
+ # ==========================
355
+ # Metrics
356
+ # ==========================
357
+ def build_metrics_fn(arch: str, tokenizer):
358
+ rouge = eval_load("rouge") if eval_load else None
359
+
360
+ def _postprocess(preds):
361
+ if isinstance(preds, (list, tuple)):
362
+ return [p.strip() for p in preds]
363
+ return preds
364
+
365
+ def compute_metrics_seq2seq(eval_pred):
366
+ if rouge is None:
367
+ return {"rougeL": 0.0}
368
+ preds, labels = eval_pred
369
+ if isinstance(preds, tuple):
370
+ preds = preds[0]
371
+ decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
372
+ labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
373
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
374
+ decoded_preds = _postprocess(decoded_preds)
375
+ decoded_labels = _postprocess(decoded_labels)
376
+ r = rouge.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=["rougeL"])
377
+ return {"rougeL": float(r.get("rougeL", 0.0))}
378
+
379
+ def compute_metrics_causal(eval_pred):
380
+ preds, labels = eval_pred
381
+ if isinstance(preds, tuple):
382
+ preds = preds[0]
383
+ decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
384
+ labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
385
+ decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
386
+ tp = fp = fn = 0
387
+ for p, g in zip(decoded_preds, decoded_labels):
388
+ p_set, g_set = set(p.split()), set(g.split())
389
+ tp += len(p_set & g_set)
390
+ fp += len(p_set - g_set)
391
+ fn += len(g_set - p_set)
392
+ precision = tp / (tp + fp + 1e-8)
393
+ recall = tp / (tp + fn + 1e-8)
394
+ f1 = 2 * precision * recall / (precision + recall + 1e-8)
395
+ return {"f1_simple": float(f1)}
396
+
397
+ return compute_metrics_seq2seq if arch == "seq2seq" else compute_metrics_causal
398
 
399
  # ==========================
400
  # Trainer Manager
401
  # ==========================
402
+ def read_jsonl_files(paths: List[str]) -> List[Dict]:
403
+ data: List[Dict] = []
404
+ for p in paths:
405
+ if not p:
406
+ continue
407
+ with open(p, 'r', encoding='utf-8') as f:
408
+ for line in f:
409
+ s = line.strip()
410
+ if not s:
411
+ continue
412
+ try:
413
+ obj = json.loads(s)
414
+ data.append(obj)
415
+ except json.JSONDecodeError:
416
+ continue
417
+ return data
418
+
419
  class TrainerManager:
420
  def __init__(self, syscfg: SystemConfig, loader: ModelLoader):
421
  self.cfg = syscfg
422
  self.loader = loader
423
 
424
+ def _args_common(self, is_seq2seq: bool):
 
 
 
 
 
 
 
 
 
425
  fp16_ok = torch.cuda.is_available() and (not self.cfg.train.use_bf16)
426
+ bf16_ok = bf16_supported() and self.cfg.train.use_bf16
427
+
428
+ args = safe_training_args(
429
  output_dir=self.cfg.train.output_dir,
430
  num_train_epochs=self.cfg.train.epochs,
431
  learning_rate=self.cfg.train.lr,
432
  per_device_train_batch_size=self.cfg.train.batch_size,
433
  per_device_eval_batch_size=self.cfg.train.batch_size,
434
  gradient_accumulation_steps=self.cfg.train.grad_accum,
435
+ warmup_ratio=self.cfg.train.warmup_ratio,
436
+ weight_decay=self.cfg.train.weight_decay,
437
+ evaluation_strategy=self.cfg.train.eval_strategy,
438
+ save_strategy=self.cfg.train.save_strategy,
439
+ save_total_limit=self.cfg.train.save_total_limit,
440
  load_best_model_at_end=True,
441
  metric_for_best_model="eval_loss",
442
+ logging_steps=self.cfg.train.logging_steps,
443
+ report_to=([] if self.cfg.train.report_to == "none" else [self.cfg.train.report_to]),
 
 
 
444
  fp16=fp16_ok,
445
  bf16=bf16_ok,
446
+ max_grad_norm=self.cfg.train.max_grad_norm,
447
+ **({
448
+ "predict_with_generate": True,
449
+ "generation_max_length": self.cfg.model.max_target_length,
450
+ "generation_num_beams": self.cfg.model.num_beams
451
+ } if is_seq2seq else {})
452
  )
453
+ return args
454
+
455
+ def train_seq2seq(self, train_paths: List[str], use_rag: bool = True):
456
+ set_seed_all(self.cfg.train.seed)
457
+ data = read_jsonl_files(train_paths)
458
+ train, val = train_test_split(data, test_size=self.cfg.train.test_size, random_state=self.cfg.train.seed)
459
+
460
+ rag = LegalRAG(self.cfg.rag) if (use_rag and self.cfg.rag.enable) else None
461
+ if rag:
462
+ rag.init()
463
+
464
+ ds_tr = Seq2SeqJSONLDataset(train, self.loader.tokenizer, self.cfg.model.max_input_length, self.cfg.model.max_target_length, rag)
465
+ ds_va = Seq2SeqJSONLDataset(val, self.loader.tokenizer, self.cfg.model.max_input_length, self.cfg.model.max_target_length, None)
466
+ collator = DataCollatorForSeq2Seq(tokenizer=self.loader.tokenizer, model=self.loader.model)
467
+
468
+ args = self._args_common(is_seq2seq=True)
469
  trainer = Trainer(
470
  model=self.loader.model,
471
  args=args,
 
474
  data_collator=collator,
475
  tokenizer=self.loader.tokenizer,
476
  callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
477
+ compute_metrics=build_metrics_fn("seq2seq", self.loader.tokenizer)
478
  )
479
  trainer.train()
480
  trainer.save_model(self.cfg.train.output_dir)
 
484
  set_seed_all(self.cfg.train.seed)
485
  data = read_jsonl_files(train_paths)
486
  train, val = train_test_split(data, test_size=self.cfg.train.test_size, random_state=self.cfg.train.seed)
487
+
488
+ rag = LegalRAG(self.cfg.rag) if (use_rag and self.cfg.rag.enable) else None
489
  if rag:
490
  rag.init()
491
+
492
  ds_tr = CausalJSONLDataset(train, self.loader.tokenizer, self.cfg.model.max_input_length, rag)
493
  ds_va = CausalJSONLDataset(val, self.loader.tokenizer, self.cfg.model.max_input_length, None)
494
+
495
+ args = self._args_common(is_seq2seq=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
  trainer = Trainer(
497
  model=self.loader.model,
498
  args=args,
 
500
  eval_dataset=ds_va,
501
  tokenizer=self.loader.tokenizer,
502
  callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
503
+ compute_metrics=build_metrics_fn("causal", self.loader.tokenizer)
504
  )
505
  trainer.train()
506
  trainer.save_model(self.cfg.train.output_dir)
507
  self.loader.tokenizer.save_pretrained(self.cfg.train.output_dir)
508
 
509
  # ==========================
510
+ # App (Gradio 5)
511
  # ==========================
512
  class LegalApp:
513
  def __init__(self, scfg: Optional[SystemConfig] = None):
 
516
  self.loader: Optional[ModelLoader] = None
517
  self.gen: Optional[Generator] = None
518
 
519
+ # --- helpers ---
520
+ def _file_paths(self, files: List[gr.File]) -> List[str]:
521
+ paths = []
522
+ for f in (files or []):
523
+ p = getattr(f, "name", None) or getattr(f, "path", None)
524
+ if p:
525
+ paths.append(p)
526
+ return paths
527
+
528
  # --- core actions ---
529
  def load(self, model_name: str, arch: str, use_rag: bool, persist_dir: str, collection: str, top_k: int, threshold: float):
530
  # configure
 
534
  self.scfg.rag.collection = collection
535
  self.scfg.rag.top_k = int(top_k)
536
  self.scfg.rag.similarity_threshold = float(threshold)
537
+ self.scfg.rag.enable = bool(use_rag)
538
+
539
  # load model
540
  self.loader = ModelLoader(self.scfg.model).load()
541
  self.gen = Generator(self.loader, self.scfg.model)
542
+
543
  # load rag
544
+ msg_rag = "RAG غیرفعال"
545
  if use_rag:
546
  try:
547
  self.rag = LegalRAG(self.scfg.rag)
 
549
  msg_rag = "RAG آماده است"
550
  except Exception as e:
551
  msg_rag = f"RAG خطا: {e}"
552
+
553
  return f"مدل بارگذاری شد: {model_name} ({arch})\n{msg_rag}"
554
 
555
+ def build_index(self, laws_file: gr.File, id_key: str, text_key: str):
556
+ if not self.scfg.rag.enable:
557
+ return "RAG غیرفعال است."
558
+ try:
559
+ self.rag.init()
560
+ p = getattr(laws_file, "name", None) or getattr(laws_file, "path", None)
561
+ if not p:
562
+ return "فایل قوانین معتبر نیست."
563
+ res = self.rag.index_jsonl(p, id_key=id_key, text_key=text_key)
564
+ return res
565
+ except Exception as e:
566
+ return f"خطا در ایندکس: {e}"
567
+
568
  def answer(self, question: str, use_rag: bool, max_new_tokens: int, temperature: float, top_p: float, num_beams: int):
569
  if not question.strip():
570
  return "لطفاً سوال خود را وارد کنید.", ""
571
  if not self.gen:
572
  return "ابتدا مدل/RAG را بارگذاری کنید.", ""
573
+ # runtime params
574
  self.scfg.model.max_new_tokens = int(max_new_tokens)
575
  self.scfg.model.temperature = float(temperature)
576
  self.scfg.model.top_p = float(top_p)
577
  self.scfg.model.num_beams = int(num_beams)
578
+ arts = self.rag.retrieve(question) if (use_rag and self.scfg.rag.enable and self.rag.collection) else []
579
  ctx = self.rag.build_context(arts) if arts else ""
580
  ans = self.gen.generate(question, ctx)
581
  refs = ""
 
583
  refs = "\n\n" + "\n".join([f"**ماده {a['article_id']}** (شباهت: {a['similarity']:.2f})\n{a['text'][:380]}..." for a in arts])
584
  return ans, refs
585
 
586
+ def train(self, model_name: str, arch: str, files: List[gr.File], use_rag: bool, epochs: int, batch: int, lr: float,
587
+ wd: float, warmup: float, report_to: str, progress=gr.Progress(track_tqdm=True)):
588
+ progress(0.0, desc="راه‌اندازی")
589
  self.scfg.model.model_name = model_name
590
  self.scfg.model.architecture = arch
591
  self.scfg.train.epochs = int(epochs)
592
  self.scfg.train.batch_size = int(batch)
593
  self.scfg.train.lr = float(lr)
594
+ self.scfg.train.weight_decay = float(wd)
595
+ self.scfg.train.warmup_ratio = float(warmup)
596
+ self.scfg.train.report_to = report_to
597
+
598
+ progress(0.1, desc="بارگذاری مدل/توکنایزر")
599
  self.loader = ModelLoader(self.scfg.model).load()
600
+
601
+ paths = self._file_paths(files)
602
+ if not paths:
603
+ return "⚠️ هیچ فایل JSONL برای آموزش انتخاب نشده."
604
+
605
  tm = TrainerManager(self.scfg, self.loader)
606
+ set_seed_all(self.scfg.train.seed)
607
+
608
+ progress(0.3, desc="آماده‌سازی دیتاست‌ها و RAG")
609
  if arch == "seq2seq":
610
  tm.train_seq2seq(paths, use_rag=use_rag)
611
  else:
612
  tm.train_causal(paths, use_rag=use_rag)
613
+
614
+ progress(0.95, desc="ذخیرهٔ آرتیفکت‌ها")
615
  return f"✅ آموزش کامل شد و در {self.scfg.train.output_dir} ذخیره شد."
616
 
617
+ # --- Dataset Builder (داخل اپ) ---
618
+ def build_dataset(self, raw_file, text_key: str, model_ckpt: str,
619
+ batch_size: int, max_samples: int | None):
620
+ path = getattr(raw_file, "name", None) or getattr(raw_file, "path", None)
621
+ if not path:
622
+ return None, "⚠️ فایل ورودی معتبر نیست."
623
+ try:
624
+ data = load_json_or_jsonl(path)
625
+ if max_samples and int(max_samples) > 0:
626
+ data = data[:int(max_samples)]
627
+ gb = GoldenBuilder(model_name=model_ckpt)
628
+ rows = gb.build(data, text_key=text_key, batch_size=int(batch_size))
629
+ out_dir = "/tmp/mahoun_datasets"
630
+ Path(out_dir).mkdir(parents=True, exist_ok=True)
631
+ out_path = f"{out_dir}/golden_{os.path.basename(path)}.jsonl"
632
+ save_jsonl(rows, out_path)
633
+ msg = f"✅ {len(rows)} رکورد تولید شد. فایل آمادهٔ دانلود است."
634
+ return out_path, msg
635
+ except Exception as e:
636
+ return None, f"❌ خطا در ساخت دیتاست: {e}"
637
+
638
  # --- UI ---
639
  def build_ui(self):
640
+ log_deps()
641
+
642
  default_models = {
643
  "Seq2Seq (mt5-base)": ("google/mt5-base", "seq2seq"),
644
  "Seq2Seq (t5-fa-base)": ("HooshvareLab/t5-fa-base", "seq2seq"),
645
  "Seq2Seq (flan-t5-base)": ("google/flan-t5-base", "seq2seq"),
646
  "Causal (Mistral-7B Instruct)": ("mistralai/Mistral-7B-Instruct-v0.2", "causal"),
647
  }
648
+
649
  with gr.Blocks(title="ماحون — مشاور حقوقی هوشمند", theme=gr.themes.Soft(primary_hue="green", secondary_hue="gray")) as app:
650
  gr.HTML("""
651
  <div style='text-align:center;padding:18px'>
652
  <h1 style='margin-bottom:4px'>ماحون — Ultimate Legal AI</h1>
653
+ <p style='color:#666'>RAG • Seq2Seq/Causal • Training • Dataset Builder</p>
654
  </div>
655
  """)
656
 
657
+ # تب مشاوره
658
  with gr.Tab("مشاوره"):
659
  with gr.Row():
660
  model_dd = gr.Dropdown(choices=list(default_models.keys()), value="Seq2Seq (mt5-base)", label="مدل")
661
+ gr.Markdown("**راهنما:** Seq2Seq برای پاسخ‌های ساختاریافته؛ Causal برای مکالمه طبیعی‌تر.")
662
  with gr.Row():
663
  use_rag = gr.Checkbox(value=True, label="RAG فعال باشد؟")
664
+ persist_dir = gr.Textbox(value=self.scfg.rag.persist_dir, label="مسیر ChromaDB")
665
  collection = gr.Textbox(value=self.scfg.rag.collection, label="نام کالکشن")
666
  with gr.Row():
667
+ top_k = gr.Slider(1, 15, value=self.scfg.rag.top_k, step=1, label="Top-K")
668
+ threshold = gr.Slider(0.3, 0.95, value=self.scfg.rag.similarity_threshold, step=0.01, label="آستانه شباهت")
669
  load_btn = gr.Button("بارگذاری مدل/RAG", variant="primary")
670
  status = gr.Textbox(label="وضعیت", interactive=False)
671
 
672
+ with gr.Accordion("ساخت ایندکس قوانین (اختیاری)", open=False):
673
+ laws_file = gr.File(label="فایل JSONL قوانین", file_types=[".jsonl"])
674
+ id_key = gr.Textbox(value="article_id", label="کلید شناسه ماده")
675
+ text_key = gr.Textbox(value="text", label="کلید متن ماده")
676
+ index_btn = gr.Button("ایندکس‌سازی قوانین")
677
+ index_status = gr.Textbox(label="وضعیت ایندکس", interactive=False)
678
+
679
  with gr.Accordion("پارامترهای تولید", open=False):
680
  max_new_tokens = gr.Slider(64, 1024, value=self.scfg.model.max_new_tokens, step=16, label="max_new_tokens")
681
  temperature = gr.Slider(0.0, 1.5, value=self.scfg.model.temperature, step=0.05, label="temperature")
 
683
  num_beams = gr.Slider(1, 8, value=self.scfg.model.num_beams, step=1, label="num_beams (Seq2Seq)")
684
 
685
  question = gr.Textbox(lines=3, label="سوال حقوقی")
686
+ gr.Examples(
687
+ examples=[
688
+ ["در صورت نقض قرارداد فروش، چه اقداماتی باید انجام دهم؟"],
689
+ ["آیا درج شرط عدم رقابت در قرارداد کار قانونی است؟"],
690
+ ["حق و حقوق کارگر در صورت اخراج فوری چیست؟"],
691
+ ["فرآیند طرح دعوای مطالبه مهریه چگونه است؟"],
692
+ ],
693
+ inputs=question, label="نمونه پرسش‌ها"
694
+ )
695
  ask_btn = gr.Button("پرسش", variant="primary")
696
  answer = gr.Markdown(label="پاسخ")
697
  refs = gr.Markdown(label="مواد قانونی مرتبط")
698
 
699
+ # تب آموزش
700
  with gr.Tab("آموزش"):
701
+ gr.Markdown("فایل‌های JSONL با کلیدهای `input` و `output` را بارگذاری کنید.")
702
  with gr.Row():
703
  model_dd_train = gr.Dropdown(choices=list(default_models.keys()), value="Seq2Seq (mt5-base)", label="مدل")
704
+ use_rag_train = gr.Checkbox(value=True, label="RAG-enhanced Training")
705
  train_files = gr.Files(label="JSONL Files", file_count="multiple", file_types=[".jsonl"])
706
  with gr.Row():
707
+ epochs = gr.Slider(1, 8, value=self.scfg.train.epochs, step=1, label="epochs")
708
+ batch = gr.Slider(1, 16, value=self.scfg.train.batch_size, step=1, label="batch per device")
709
  lr = gr.Number(value=self.scfg.train.lr, label="learning rate")
710
+ with gr.Row():
711
+ wd = gr.Number(value=self.scfg.train.weight_decay, label="weight decay")
712
+ warmup = gr.Slider(0.0, 0.2, value=self.scfg.train.warmup_ratio, step=0.01, label="warmup ratio")
713
+ report_to = gr.Dropdown(choices=["none","wandb"], value=self.scfg.train.report_to, label="report_to")
714
  train_btn = gr.Button("شروع آموزش", variant="primary")
715
  train_status = gr.Textbox(label="وضعیت آموزش", interactive=False)
716
 
717
+ # تب ساخت دیتاست
718
+ with gr.Tab("ساخت دیتاست"):
719
+ gr.Markdown("فایل خام (JSON/JSONL) را بارگذاری کنید. خروجی سازگار با اپ `{input, output}` خواهد بود.")
720
+ raw_file = gr.File(label="فایل خام", file_types=[".json",".jsonl"])
721
+ with gr.Row():
722
+ ds_text_key = gr.Textbox(value="متن_کامل", label="کلید متن (text_key)")
723
+ model_ckpt = gr.Dropdown(
724
+ choices=["google/mt5-base", "google/flan-t5-base", "t5-base"],
725
+ value="google/mt5-base",
726
+ label="مدل خلاصه‌ساز"
727
+ )
728
+ with gr.Row():
729
+ ds_batch_size = gr.Slider(1, 16, value=4, step=1, label="Batch size")
730
+ max_samples = gr.Number(value=0, label="حداکثر نمونه (۰=همه)")
731
+ build_btn = gr.Button("ساخت دیتاست", variant="primary")
732
+ out_file = gr.File(label="دانلود خروجی JSONL", interactive=False)
733
+ build_status = gr.Textbox(label="وضعیت", interactive=False)
734
+
735
+ # رویدادها
736
  def _resolve(choice: str) -> Tuple[str,str]:
737
  return default_models[choice]
738
 
 
740
  inputs=[model_dd, use_rag, persist_dir, collection, top_k, threshold], outputs=status)
741
 
742
  ask_btn.click(lambda q, rag, mnt, t, p, nb: self.answer(q, rag, mnt, t, p, nb),
743
+ inputs=[question, use_rag, max_new_tokens, temperature, top_p, num_beams],
744
+ outputs=[answer, refs])
745
+
746
+ index_btn.click(lambda f, ik, tk: self.build_index(f, ik, tk),
747
+ inputs=[laws_file, id_key, text_key], outputs=index_status)
748
+
749
+ train_btn.click(
750
+ lambda choice, files, rag, e, b, l, _wd, _wu, _r:
751
+ self.train(*_resolve(choice), files, rag, e, b, l, _wd, _wu, _r),
752
+ inputs=[model_dd_train, train_files, use_rag_train, epochs, batch, lr, wd, warmup, report_to],
753
+ outputs=train_status
754
+ )
755
+
756
+ build_btn.click(
757
+ lambda rf, tk, ckpt, bs, mx: self.build_dataset(rf, tk, ckpt, bs, mx),
758
+ inputs=[raw_file, ds_text_key, model_ckpt, ds_batch_size, max_samples],
759
+ outputs=[out_file, build_status]
760
+ )
761
 
 
 
762
  return app
763
 
764
  # ==========================
765
+ # Entrypoint for HF Spaces
 
766
  # ==========================
767
  if __name__ == "__main__":
768
  app = LegalApp()
769
  ui = app.build_ui()
770
+ # Gradio 5.46.x: queue بدون پارامتر legacy
771
+ try:
772
+ ui = ui.queue()
773
+ except TypeError:
774
+ pass
775
+ ui.launch(server_name="0.0.0.0", server_port=7860)