hajimammad commited on
Commit
f1c980f
·
verified ·
1 Parent(s): bb573b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +255 -144
app.py CHANGED
@@ -1,14 +1,34 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- Mahoon Legal AI — Causal-only Generation + Hybrid RAG + W&B Training + Weight Tuning
 
 
 
 
 
 
4
  پیش‌نیازها:
5
- - golden_builder.py
6
- - weights_sweep.py
7
- - Secrets: WANDB_API_KEY
 
8
  """
9
 
10
  from __future__ import annotations
11
- import os, sys, re, json, time, pickle, zipfile, warnings
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  from dataclasses import dataclass, field
13
  from pathlib import Path
14
  from typing import List, Dict, Optional
@@ -21,19 +41,35 @@ from sklearn.model_selection import train_test_split
21
  import gradio as gr
22
  warnings.filterwarnings("ignore")
23
 
24
- # ====== ML & NLP ======
25
  import transformers as tf
26
  from transformers import (
27
  AutoTokenizer, AutoModelForCausalLM,
28
  Trainer, TrainingArguments, EarlyStoppingCallback
29
  )
30
 
31
- # RAG stack
32
  import chromadb
33
- from chromadb.config import Settings # <-- برای خاموش کردن Telemetry
34
  from rank_bm25 import BM25Okapi
35
  from sentence_transformers import CrossEncoder, SentenceTransformer, util as st_util
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # ========= Persian normalization =========
38
  ZWNJ = "\u200c"
39
  AR_DIGITS = "٠١٢٣٤٥٦٧٨٩"
@@ -57,8 +93,8 @@ def normalize_fa(s: str) -> str:
57
  @dataclass
58
  class ModelConfig:
59
  model_name: str = "Qwen/Qwen2.5-7B-Instruct"
60
- max_input_length: int = 4096
61
- max_new_tokens: int = 512
62
  temperature: float = 0.7
63
  top_p: float = 0.9
64
  do_sample: bool = True
@@ -68,9 +104,9 @@ class ModelConfig:
68
  class RAGConfig:
69
  persist_dir: str = "./chroma_db"
70
  collection: str = "legal_articles"
71
- top_k: int = 8
72
- similarity_threshold: float = 0.60
73
- context_char_limit: int = 280
74
  enable: bool = True
75
  reranker_name: str = "Alibaba-NLP/gte-multilingual-reranker-base"
76
 
@@ -95,7 +131,7 @@ class TrainConfig:
95
  save_total_limit: int = 2
96
  report_to: str = "wandb"
97
  max_grad_norm: float = 1.0
98
- use_4bit: bool = True
99
  max_seq_len: int = 2048
100
 
101
  @dataclass
@@ -129,6 +165,23 @@ def log_deps():
129
  except Exception as e:
130
  print("[deps] warn:", e, flush=True)
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  # ==========================
133
  # RAG: Chroma + BM25 + CrossEncoder reranker
134
  # ==========================
@@ -144,7 +197,6 @@ class LegalRAG:
144
 
145
  def init(self):
146
  Path(self.cfg.persist_dir).mkdir(parents=True, exist_ok=True)
147
- # خاموش کردن تله‌متری Chroma
148
  self.client = chromadb.PersistentClient(
149
  path=self.cfg.persist_dir,
150
  settings=Settings(anonymized_telemetry=False)
@@ -154,13 +206,12 @@ class LegalRAG:
154
  except Exception:
155
  try: self.collection = self.client.get_collection(self.cfg.collection)
156
  except Exception: self.collection = self.client.create_collection(self.cfg.collection)
157
- # reranker
158
  try:
159
- dev = "cuda" if torch.cuda.is_available() else "cpu"
160
- self.reranker = CrossEncoder(self.cfg.reranker_name, device=dev)
161
  except Exception:
162
  self.reranker = None
163
- # BM25
164
  if Path(self.bm25_path).exists():
165
  with open(self.bm25_path, "rb") as f:
166
  obj = pickle.load(f)
@@ -174,7 +225,6 @@ class LegalRAG:
174
  pickle.dump({"bm25": self.bm25, "ids": self.bm25_ids}, f)
175
 
176
  def index_jsonl(self, jsonl_path: str, id_key="article_id", text_key="text"):
177
- """ایندکس با تضمین یکتایی ID: ارقام Normalize و در صورت تکرار، پسوند __dN اضافه می‌شود."""
178
  if not self.collection: self.init()
179
 
180
  seen: Dict[str, int] = {}
@@ -224,7 +274,7 @@ class LegalRAG:
224
  if not self.collection: return []
225
  qn = normalize_fa(query)
226
 
227
- # Dense via Chroma
228
  try:
229
  res = self.collection.query(
230
  query_texts=[qn],
@@ -265,10 +315,14 @@ class LegalRAG:
265
  merged = [a for a in pool.values() if a.get("text") and len(a["text"]) > 15]
266
  merged = [a for a in merged if a.get("similarity", 0) >= self.cfg.similarity_threshold]
267
 
268
- # rerank
269
  if merged and self.reranker:
270
  pairs = [(qn, a["text"]) for a in merged]
271
- scores = self.reranker.predict(pairs)
 
 
 
 
272
  for a, s in zip(merged, scores): a["score"] = float(s)
273
  merged = sorted(merged, key=lambda x: x.get("score", 0), reverse=True)[: self.cfg.top_k]
274
  else:
@@ -318,7 +372,7 @@ def ensure_chroma_ready(persist_dir="./chroma_db", collection="legal_articles")
318
  return "پایگاه RAG موجود نیست و منبع خامی هم برای ساخت پیدا نشد."
319
 
320
  # ==========================
321
- # Loader + Generator (Causal-only)
322
  # ==========================
323
  class CausalLoader:
324
  def __init__(self, mcfg: ModelConfig):
@@ -330,15 +384,20 @@ class CausalLoader:
330
  self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
331
  if self.tokenizer.pad_token is None and hasattr(self.tokenizer, "eos_token"):
332
  self.tokenizer.pad_token = self.tokenizer.eos_token
333
- kwargs = {}
334
- if torch.cuda.is_available():
335
- kwargs["device_map"] = "auto"
336
- kwargs["torch_dtype"] = torch.bfloat16 if bf16_supported() else torch.float16
337
- kwargs["low_cpu_mem_usage"] = True # <-- امن‌سازی حافظه
338
- self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
339
- if self.cfg.gradient_checkpointing and hasattr(self.model, "gradient_checkpointing_enable"):
340
- try: self.model.gradient_checkpointing_enable()
341
- except Exception: pass
 
 
 
 
 
342
  return self
343
 
344
  class Generator:
@@ -353,16 +412,34 @@ class Generator:
353
  if context: parts.append(f"<|system|>\nاز منابع زیر استفاده کن و استنادی پاسخ بده:\n{context}")
354
  parts.append(f"<|user|>\n{question}")
355
  prompt = "\n".join(parts) + "\n<|assistant|>\n"
356
- enc = self.tk(prompt, return_tensors="pt", truncation=True, max_length=self.cfg.max_input_length).to(self.model.device)
357
- with torch.no_grad():
358
- out = self.model.generate(
359
- **enc,
360
- max_new_tokens=self.cfg.max_new_tokens,
361
- do_sample=self.cfg.do_sample,
362
- temperature=self.cfg.temperature,
363
- top_p=self.cfg.top_p,
364
- pad_token_id=self.tk.pad_token_id or self.tk.eos_token_id,
365
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  return self.tk.decode(out[0], skip_special_tokens=True)
367
 
368
  # ==========================
@@ -538,7 +615,7 @@ def deduplicate_jsonl(in_path: str, out_path: str, sim_threshold: float = 0.90,
538
  return len(kept)
539
 
540
  # ==========================
541
- # App (Gradio)
542
  # ==========================
543
  class LegalApp:
544
  def __init__(self, scfg: Optional[SystemConfig] = None):
@@ -554,7 +631,7 @@ class LegalApp:
554
  if p: paths.append(p)
555
  return paths
556
 
557
- # Core
558
  def load(self, model_name: str):
559
  self.loader = CausalLoader(self.scfg.model).load(model_name)
560
  self.gen = Generator(self.loader, self.scfg.model)
@@ -568,7 +645,10 @@ class LegalApp:
568
  msg_rag = f"RAG خطا: {e}"
569
  return f"مدل بارگذاری شد: {model_name}\n{msg_rag}"
570
 
571
- def build_index(self, laws_file: gr.File, id_key: str, text_key: str):
 
 
 
572
  if not self.scfg.rag.enable: return "RAG غیرفعال است."
573
  try:
574
  self.rag.init()
@@ -578,28 +658,32 @@ class LegalApp:
578
  except Exception as e:
579
  return f"خطا در ایندکس: {e}"
580
 
581
- def answer(self, question: str, system_prompt: str, use_rag: bool, max_new_tokens: int, temperature: float, top_p: float):
582
- if not question.strip(): return "لطفاً سوال خود را وارد کنید.", ""
583
- if not self.gen: return "ابتدا مدل را بارگذاری کنید.", ""
584
- self.scfg.model.max_new_tokens = int(max_new_tokens)
585
- self.scfg.model.temperature = float(temperature)
586
- self.scfg.model.top_p = float(top_p)
587
-
588
- arts = self.rag.retrieve(question) if (use_rag and self.scfg.rag.enable and self.rag.collection) else []
589
- # محدودسازی رفرنس‌ها
590
- max_refs = 4
591
- if arts: arts = arts[:max_refs]
592
- ctx = self.rag.build_context(arts) if arts else ""
593
- ans = self.gen.generate(question, ctx, system_prompt)
594
-
595
- refs = ""
596
- if arts:
597
- refs = "\n\n" + "\n".join([f"**ماده {a['article_id']}** (شباهت: {a.get('similarity',0):.2f})\n{a['text'][:320]}..." for a in arts])
598
- return ans, refs
 
 
599
 
600
  def train(self, model_name: str, files: List[gr.File], use_rag: bool, epochs: int, batch: int, lr: float,
601
  use_wandb: bool, wandb_project: str, wandb_entity: str, run_name: str,
602
- progress=gr.Progress(track_tqdm=True)):
 
 
603
  progress(0.05, desc="راه‌اندازی")
604
  self.scfg.train.epochs = int(epochs)
605
  self.scfg.train.batch_size = int(batch)
@@ -623,28 +707,9 @@ class LegalApp:
623
  progress(0.95, desc="ذخیرهٔ آرتیفکت‌ها")
624
  return f"✅ آموزش کامل شد و در {self.scfg.train.output_dir} ذخیره شد."
625
 
626
- # Dataset Builder
627
- def build_dataset(self, raw_file, text_key: str, model_ckpt: str, batch_size: int, max_samples: int | None):
628
- try:
629
- from golden_builder import load_json_or_jsonl, save_jsonl, GoldenBuilder
630
- except Exception as e:
631
- return None, f"❌ golden_builder.py یافت نشد/قابل import نیست: {e}"
632
- path = getattr(raw_file, "name", None) or getattr(raw_file, "path", None)
633
- if not path: return None, "⚠️ فایل ورودی معتبر نیست."
634
- try:
635
- data = load_json_or_jsonl(path)
636
- if max_samples and int(max_samples) > 0: data = data[:int(max_samples)]
637
- gb = GoldenBuilder(model_name=model_ckpt)
638
- rows = gb.build(data, text_key=text_key, batch_size=int(batch_size))
639
- out_dir = "/tmp/mahoon_datasets"; Path(out_dir).mkdir(parents=True, exist_ok=True)
640
- out_path = f"{out_dir}/golden_{os.path.basename(path)}.jsonl"
641
- save_jsonl(rows, out_path)
642
- return out_path, f"✅ {len(rows)} رکورد تولید شد."
643
- except Exception as e:
644
- return None, f"❌ خطا در ساخت دیتاست: {e}"
645
-
646
- # Weight Tuning (W&B Sweep)
647
- def run_weight_tune(self, f, tk, ms, runs, bs, proj, ent):
648
  p = getattr(f, "name", None) or getattr(f, "path", None)
649
  if not p:
650
  return "⚠️ فایل داده نامعتبر است."
@@ -661,8 +726,9 @@ class LegalApp:
661
  except Exception as e:
662
  return f"❌ خطا در اجرای Sweep: {e}"
663
 
664
- # Pull بهترین وزن‌ها از W&B و ذخیره در legal_entity_weights.json
665
- def apply_best_weights(self, wandb_project: str, wandb_entity: str, metric: str = "pass_rate"):
 
666
  try:
667
  import wandb, json as _json
668
  except Exception as e:
@@ -699,6 +765,25 @@ class LegalApp:
699
  rid = getattr(best_run, "id", "unknown")
700
  return f"✅ وزن‌ها اعمال شد از Run `{rid}` با {metric}={best_val:.4f}. فایل: `legal_entity_weights.json`"
701
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
702
  # UI
703
  def build_ui(self):
704
  log_deps()
@@ -713,15 +798,18 @@ class LegalApp:
713
  "Mistral-7B Instruct (v0.3)": "mistralai/Mistral-7B-Instruct-v0.3",
714
  }
715
 
716
- with gr.Blocks(title="ماحون — مشاور حقوقی (Causal-only)") as app:
 
 
 
717
  gr.Markdown("""
718
  <div style='text-align:center;padding:18px'>
719
- <h1 style='margin-bottom:4px'>ماحون — Persian Legal (Causal-only)</h1>
720
  <p style='color:#666'>Hybrid RAG • Qwen/Llama/Mistral • Dataset Ops • W&B Training • Weight Tuning</p>
721
  </div>
722
  """)
723
 
724
- # --- Tab: Consultation ---
725
  with gr.Tab("مشاوره"):
726
  with gr.Row():
727
  gen_model_dd = gr.Dropdown(choices=list(default_gen_models.keys()), value="Qwen2.5-7B Instruct", label="مدل تولید")
@@ -754,15 +842,16 @@ class LegalApp:
754
  ask_btn = gr.Button("پرسش", variant="primary")
755
  answer = gr.Markdown(label="پاسخ"); refs = gr.Markdown(label="مواد قانونی مرتبط")
756
 
757
- # --- Tab: Indexing ---
758
  with gr.Tab("ایندکس قوانین"):
759
  gr.Markdown("فایل JSONL قوانین را بارگذاری و ایندکس کنید (کلیدها: `article_id`, `text`).")
760
  laws_file = gr.File(label="فایل JSONL قوانین", file_types=[".jsonl"])
761
  id_key = gr.Textbox(value="article_id", label="کلید شناسه ماده")
762
  text_key = gr.Textbox(value="text", label="کلید متن ماده")
763
  index_btn = gr.Button("ایندکس‌سازی قوانین"); index_status = gr.Textbox(label="وضعیت ایندکس", interactive=False)
 
764
 
765
- # --- Tab: Dataset Builder ---
766
  with gr.Tab("ساخت دیتاست"):
767
  gr.Markdown("فایل خام (JSON/JSONL) → خروجی JSONL سازگار با `{input, output}` (از golden_builder).")
768
  raw_file = gr.File(label="فایل خام", file_types=[".json",".jsonl"])
@@ -779,8 +868,9 @@ class LegalApp:
779
  build_btn = gr.Button("ساخت دیتاست", variant="primary")
780
  out_file = gr.File(label="دانلود خروجی JSONL", interactive=False)
781
  build_status = gr.Textbox(label="وضعیت", interactive=False)
 
782
 
783
- # --- Tab: Dataset Cleaning ---
784
  with gr.Tab("پاکسازی دیتاست"):
785
  gr.Markdown("نرمال‌سازی فارسی + حذف تکراری‌های معنایی (cosine). ورودی: JSONL `{input, output}`.")
786
  raw_ds = gr.File(label="JSONL ورودی", file_types=[".jsonl"])
@@ -788,8 +878,9 @@ class LegalApp:
788
  clean_btn = gr.Button("اجرای پاکسازی", variant="primary")
789
  cleaned_out = gr.File(label="دانلود JSONL پاک", interactive=False)
790
  clean_status = gr.Markdown()
 
791
 
792
- # --- Tab: Training (W&B integrated) ---
793
  with gr.Tab("آموزش"):
794
  gr.Markdown("SFT/LoRA روی مدل‌های causal (فقط `{input, output}`) + W&B logging.")
795
  with gr.Row():
@@ -806,7 +897,6 @@ class LegalApp:
806
  model_train_id = gr.Textbox(value="AI-Hoosh/HAKIM-7B", label="HF Model ID (قابل ویرایش)")
807
  use_rag_train = gr.Checkbox(value=True, label="RAG-enhanced Training")
808
 
809
- # W&B controls
810
  use_wandb = gr.Checkbox(value=True, label="W&B logging فعال باشد؟")
811
  wandb_project = gr.Textbox(value="mahoon-legal-ai", label="WANDB_PROJECT")
812
  wandb_entity = gr.Textbox(value="", label="WANDB_ENTITY (اختیاری)")
@@ -820,8 +910,10 @@ class LegalApp:
820
  lr = gr.Number(value=2e-4, label="learning rate")
821
  train_btn = gr.Button("شروع آموزش", variant="primary")
822
  train_status = gr.Textbox(label="وضعیت آموزش", interactive=False)
 
 
823
 
824
- # --- Tab: Weight Tuning ---
825
  with gr.Tab("Weight Tuning"):
826
  gr.Markdown("تیون خودکار وزن‌های موجودیت با W&B Sweep. ابتدا در Settings→Secrets مقدار `WANDB_API_KEY` را ست کنید.")
827
  tune_file = gr.File(label="فایل داده (JSON/JSONL)", file_types=[".json",".jsonl"])
@@ -838,8 +930,10 @@ class LegalApp:
838
  gr.Markdown("اعمال خودکار بهترین وزن‌ها از داشبورد W&B (بر اساس بالاترین `pass_rate`).")
839
  metric_dd = gr.Dropdown(choices=["pass_rate"], value="pass_rate", label="متریک انتخاب بهترین Run")
840
  apply_btn = gr.Button("اعمال بهترین وزن‌ها از W&B", variant="secondary")
 
 
841
 
842
- # ---- Events ----
843
  def _resolve_gen(choice: str, override: str) -> str:
844
  return override.strip() if override.strip() else default_gen_models[choice]
845
 
@@ -851,62 +945,79 @@ class LegalApp:
851
  self.scfg.rag.similarity_threshold = float(th)
852
  return self.load(_resolve_gen(choice, override))
853
 
 
 
 
 
854
  load_btn.click(_on_load,
855
  inputs=[gen_model_dd, gen_model_id, use_rag, persist_dir, collection, top_k, threshold],
856
  outputs=status)
857
 
858
- ask_btn.click(lambda q, sys_p, rag, mnt, t, p: self.answer(q, sys_p, rag, mnt, t, p),
859
  inputs=[question, system_prompt, use_rag, max_new_tokens, temperature, top_p],
860
  outputs=[answer, refs])
861
 
862
- index_btn.click(lambda f, ik, tk: self.build_index(f, ik, tk),
863
- inputs=[laws_file, id_key, text_key], outputs=index_status)
 
 
864
 
865
- build_btn.click(lambda rf, tk, ckpt, bs, mx: self.build_dataset(rf, tk, ckpt, bs, mx),
 
 
866
  inputs=[raw_file, ds_text_key, model_ckpt, ds_batch_size, max_samples],
867
  outputs=[out_file, build_status])
868
 
869
- def _map_profile_to_id(profile: str, current_id: str) -> str:
870
- if current_id.strip(): return current_id.strip()
871
- if "Dorna" in profile: return "PartAI/Dorna-Llama3-8B-Instruct"
872
- if "PersianQA" in profile: return "zpm/Llama-3.1-PersianQA"
873
- if "HAKIM" in profile: return "AI-Hoosh/HAKIM-7B"
874
- if "Hooshvareh" in profile: return "HooshvareLab/llama-fa-7b-instruct"
875
- return "PartAI/Dorna-Llama3-8B-Instruct"
876
-
877
- train_btn.click(
878
- lambda prof, mid, files, rg, e, b, l, uw, wp, we, rn:
879
- self.train(_map_profile_to_id(prof, mid), files, rg, e, b, l, uw, wp, we, rn),
880
- inputs=[model_train_dd, model_train_id, train_files, use_rag_train, epochs, batch, lr,
881
- use_wandb, wandb_project, wandb_entity, run_name],
882
- outputs=train_status
883
- )
884
-
885
- clean_btn.click(
886
- lambda f, th: (
887
- (lambda _p, _out:
888
- ( _out,
889
- f"✅ دیتاست پاک شد. تعداد رکوردهای نهایی: **{deduplicate_jsonl(_p, _out, sim_threshold=float(th))}**" )
890
- )(
891
- getattr(f, "name", None) or getattr(f, "path", None),
892
- f"/tmp/cleaned_{int(time.time())}.jsonl"
893
- ) if (getattr(f, 'name', None) or getattr(f, 'path', None)) else (None, "⚠️ فایل نامعتبر.")
894
- ),
895
- inputs=[raw_ds, sim_th],
896
- outputs=[cleaned_out, clean_status]
897
- )
898
-
899
- run_tune.click(
900
- lambda f, tk, ms, runs, bs, proj, ent: self.run_weight_tune(f, tk, ms, runs, bs, proj, ent),
901
- inputs=[tune_file, tune_text_key, tune_max_samples, tune_runs, tune_batch, tune_proj, tune_entity],
902
- outputs=tune_status
903
- )
 
 
 
 
 
 
 
 
 
 
 
904
 
905
- apply_btn.click(
906
- lambda proj, ent, m: self.apply_best_weights(proj, ent, m),
907
- inputs=[tune_proj, tune_entity, metric_dd],
908
- outputs=tune_status
909
- )
910
 
911
  return app
912
 
@@ -917,7 +1028,7 @@ if __name__ == "__main__":
917
  app = LegalApp()
918
  ui = app.build_ui()
919
  try:
920
- ui = ui.queue() # بدون پارامتر سفارشی؛ پایدارتر
921
  except TypeError:
922
  pass
923
  ui.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ Mahoon Legal AI — Causal-only Generation + Hybrid RAG + W&B + ZeroGPU + Role Gating
4
+ safari-shojaei-goldasteh-dr.pasandi
5
+
6
+ - تب «مشاوره» برای همه تعاملی است.
7
+ - تب‌های «ایندکس»، «ساخت دیتاست»، «پاکسازی»، «آموزش»، «Weight Tuning» برای بازدیدکننده فقط نمایشی‌اند؛
8
+ و سمت‌سرور نیز گِیت نقش دارد (ادمین/بازدیدکننده).
9
+
10
  پیش‌نیازها:
11
+ - golden_builder.py , weights_sweep.py
12
+ - Settings → Secrets: WANDB_API_KEY (در صورت استفاده از W&B)
13
+ - Settings → Environment Variables: ADMIN_USERS (مثلاً: haji-mammad, teammate1)
14
+ - requirements.txt (ZeroGPU-ready) شامل spaces>=0.42.0
15
  """
16
 
17
  from __future__ import annotations
18
+
19
+ # --- Telemetry hard-off + ZeroGPU SDK (must be before chroma import) ---
20
+ import os, logging
21
+ os.environ["CHROMA_TELEMETRY_ENABLED"] = "false"
22
+ os.environ["ANONYMIZED_TELEMETRY"] = "false"
23
+
24
+ import spaces # ZeroGPU SDK
25
+
26
+ # (اختیاری) کاهش نویز لاگ‌ها
27
+ logging.getLogger("chromadb").setLevel(logging.ERROR)
28
+ logging.getLogger("posthog").setLevel(logging.CRITICAL)
29
+ # -----------------------------------------------------------------------
30
+
31
+ import sys, re, json, time, pickle, zipfile, warnings
32
  from dataclasses import dataclass, field
33
  from pathlib import Path
34
  from typing import List, Dict, Optional
 
41
  import gradio as gr
42
  warnings.filterwarnings("ignore")
43
 
44
+ # ====== Transformers ======
45
  import transformers as tf
46
  from transformers import (
47
  AutoTokenizer, AutoModelForCausalLM,
48
  Trainer, TrainingArguments, EarlyStoppingCallback
49
  )
50
 
51
+ # ====== RAG stack ======
52
  import chromadb
53
+ from chromadb.config import Settings
54
  from rank_bm25 import BM25Okapi
55
  from sentence_transformers import CrossEncoder, SentenceTransformer, util as st_util
56
 
57
+ # ---- Monkeypatch Chroma telemetry (fallback) ----
58
+ try:
59
+ import chromadb.telemetry as _ctel
60
+ try: _ctel.client = None
61
+ except Exception: pass
62
+ for _n in ("capture", "capture_event"):
63
+ if hasattr(_ctel, _n):
64
+ try: setattr(_ctel, _n, lambda *a, **k: None)
65
+ except Exception: pass
66
+ if hasattr(_ctel, "Telemetry"):
67
+ try: _ctel.Telemetry().capture = lambda *a, **k: None
68
+ except Exception: pass
69
+ except Exception:
70
+ pass
71
+ # -------------------------------------------------
72
+
73
  # ========= Persian normalization =========
74
  ZWNJ = "\u200c"
75
  AR_DIGITS = "٠١٢٣٤٥٦٧٨٩"
 
93
  @dataclass
94
  class ModelConfig:
95
  model_name: str = "Qwen/Qwen2.5-7B-Instruct"
96
+ max_input_length: int = 3072
97
+ max_new_tokens: int = 256
98
  temperature: float = 0.7
99
  top_p: float = 0.9
100
  do_sample: bool = True
 
104
  class RAGConfig:
105
  persist_dir: str = "./chroma_db"
106
  collection: str = "legal_articles"
107
+ top_k: int = 6
108
+ similarity_threshold: float = 0.68
109
+ context_char_limit: int = 260
110
  enable: bool = True
111
  reranker_name: str = "Alibaba-NLP/gte-multilingual-reranker-base"
112
 
 
131
  save_total_limit: int = 2
132
  report_to: str = "wandb"
133
  max_grad_norm: float = 1.0
134
+ use_4bit: bool = False
135
  max_seq_len: int = 2048
136
 
137
  @dataclass
 
165
  except Exception as e:
166
  print("[deps] warn:", e, flush=True)
167
 
168
+ # ==========================
169
+ # Role gating helpers
170
+ # ==========================
171
+ def _get_username(request: gr.Request) -> str | None:
172
+ try:
173
+ return getattr(request, "username", None)
174
+ except Exception:
175
+ return None
176
+
177
+ def is_admin(request: gr.Request) -> bool:
178
+ uname = _get_username(request)
179
+ if not uname:
180
+ return False
181
+ author = os.getenv("SPACE_AUTHOR_NAME", "").strip()
182
+ allow = {u.strip() for u in os.getenv("ADMIN_USERS", "").split(",") if u.strip()}
183
+ return (uname == author) or (uname in allow)
184
+
185
  # ==========================
186
  # RAG: Chroma + BM25 + CrossEncoder reranker
187
  # ==========================
 
197
 
198
  def init(self):
199
  Path(self.cfg.persist_dir).mkdir(parents=True, exist_ok=True)
 
200
  self.client = chromadb.PersistentClient(
201
  path=self.cfg.persist_dir,
202
  settings=Settings(anonymized_telemetry=False)
 
206
  except Exception:
207
  try: self.collection = self.client.get_collection(self.cfg.collection)
208
  except Exception: self.collection = self.client.create_collection(self.cfg.collection)
209
+
210
  try:
211
+ self.reranker = CrossEncoder(self.cfg.reranker_name, device="cpu")
 
212
  except Exception:
213
  self.reranker = None
214
+
215
  if Path(self.bm25_path).exists():
216
  with open(self.bm25_path, "rb") as f:
217
  obj = pickle.load(f)
 
225
  pickle.dump({"bm25": self.bm25, "ids": self.bm25_ids}, f)
226
 
227
  def index_jsonl(self, jsonl_path: str, id_key="article_id", text_key="text"):
 
228
  if not self.collection: self.init()
229
 
230
  seen: Dict[str, int] = {}
 
274
  if not self.collection: return []
275
  qn = normalize_fa(query)
276
 
277
+ # Dense
278
  try:
279
  res = self.collection.query(
280
  query_texts=[qn],
 
315
  merged = [a for a in pool.values() if a.get("text") and len(a["text"]) > 15]
316
  merged = [a for a in merged if a.get("similarity", 0) >= self.cfg.similarity_threshold]
317
 
318
+ # rerank (GPU only during predict)
319
  if merged and self.reranker:
320
  pairs = [(qn, a["text"]) for a in merged]
321
+ try:
322
+ with spaces.GPU(duration=30):
323
+ scores = self.reranker.predict(pairs)
324
+ except Exception:
325
+ scores = self.reranker.predict(pairs)
326
  for a, s in zip(merged, scores): a["score"] = float(s)
327
  merged = sorted(merged, key=lambda x: x.get("score", 0), reverse=True)[: self.cfg.top_k]
328
  else:
 
372
  return "پایگاه RAG موجود نیست و منبع خامی هم برای ساخت پیدا نشد."
373
 
374
  # ==========================
375
+ # Loader + Generator (Causal-only, ZeroGPU)
376
  # ==========================
377
  class CausalLoader:
378
  def __init__(self, mcfg: ModelConfig):
 
384
  self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
385
  if self.tokenizer.pad_token is None and hasattr(self.tokenizer, "eos_token"):
386
  self.tokenizer.pad_token = self.tokenizer.eos_token
387
+
388
+ try:
389
+ with spaces.GPU(duration=90):
390
+ kwargs = {"low_cpu_mem_usage": True}
391
+ if torch.cuda.is_available():
392
+ kwargs["device_map"] = "auto"
393
+ kwargs["torch_dtype"] = torch.bfloat16 if bf16_supported() else torch.float16
394
+ self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
395
+ if self.cfg.gradient_checkpointing and hasattr(self.model, "gradient_checkpointing_enable"):
396
+ try: self.model.gradient_checkpointing_enable()
397
+ except Exception: pass
398
+ except Exception:
399
+ self.model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
400
+
401
  return self
402
 
403
  class Generator:
 
412
  if context: parts.append(f"<|system|>\nاز منابع زیر استفاده کن و استنادی پاسخ بده:\n{context}")
413
  parts.append(f"<|user|>\n{question}")
414
  prompt = "\n".join(parts) + "\n<|assistant|>\n"
415
+
416
+ enc = self.tk(prompt, return_tensors="pt", truncation=True, max_length=self.cfg.max_input_length)
417
+
418
+ try:
419
+ with spaces.GPU(duration=60):
420
+ dev_model = next(self.model.parameters()).device if hasattr(self.model, "parameters") else "cpu"
421
+ inputs = {k: v.to(dev_model) for k, v in enc.items()}
422
+ with torch.no_grad():
423
+ out = self.model.generate(
424
+ **inputs,
425
+ max_new_tokens=self.cfg.max_new_tokens,
426
+ do_sample=self.cfg.do_sample,
427
+ temperature=self.cfg.temperature,
428
+ top_p=self.cfg.top_p,
429
+ pad_token_id=self.tk.pad_token_id or self.tk.eos_token_id,
430
+ )
431
+ except Exception:
432
+ inputs = {k: v for k, v in enc.items()}
433
+ with torch.no_grad():
434
+ out = self.model.generate(
435
+ **inputs,
436
+ max_new_tokens=min(self.cfg.max_new_tokens, 256),
437
+ do_sample=self.cfg.do_sample,
438
+ temperature=self.cfg.temperature,
439
+ top_p=self.cfg.top_p,
440
+ pad_token_id=self.tk.pad_token_id or self.tk.eos_token_id,
441
+ )
442
+
443
  return self.tk.decode(out[0], skip_special_tokens=True)
444
 
445
  # ==========================
 
615
  return len(kept)
616
 
617
  # ==========================
618
+ # App (Gradio) + Role Gating
619
  # ==========================
620
  class LegalApp:
621
  def __init__(self, scfg: Optional[SystemConfig] = None):
 
631
  if p: paths.append(p)
632
  return paths
633
 
634
+ # Core (مشاوره/لود آزاد است)
635
  def load(self, model_name: str):
636
  self.loader = CausalLoader(self.scfg.model).load(model_name)
637
  self.gen = Generator(self.loader, self.scfg.model)
 
645
  msg_rag = f"RAG خطا: {e}"
646
  return f"مدل بارگذاری شد: {model_name}\n{msg_rag}"
647
 
648
+ # --- گیت سمت‌سرور: فقط ادمین ---
649
+ def build_index(self, laws_file: gr.File, id_key: str, text_key: str, request: gr.Request):
650
+ if not is_admin(request):
651
+ return "🔒 این عملیات فقط برای مدیران فعال است."
652
  if not self.scfg.rag.enable: return "RAG غیرفعال است."
653
  try:
654
  self.rag.init()
 
658
  except Exception as e:
659
  return f"خطا در ایندکس: {e}"
660
 
661
+ def build_dataset(self, raw_file, text_key: str, model_ckpt: str, batch_size: int, max_samples: int | None, request: gr.Request):
662
+ if not is_admin(request):
663
+ return None, "🔒 این عملیات فقط برای مدیران فعال است."
664
+ try:
665
+ from golden_builder import load_json_or_jsonl, save_jsonl, GoldenBuilder
666
+ except Exception as e:
667
+ return None, f"❌ golden_builder.py یافت نشد/قابل import نیست: {e}"
668
+ path = getattr(raw_file, "name", None) or getattr(raw_file, "path", None)
669
+ if not path: return None, "⚠️ فایل ورودی معتبر نیست."
670
+ try:
671
+ data = load_json_or_jsonl(path)
672
+ if max_samples and int(max_samples) > 0: data = data[:int(max_samples)]
673
+ gb = GoldenBuilder(model_name=model_ckpt)
674
+ rows = gb.build(data, text_key=text_key, batch_size=int(batch_size))
675
+ out_dir = "/tmp/mahoon_datasets"; Path(out_dir).mkdir(parents=True, exist_ok=True)
676
+ out_path = f"{out_dir}/golden_{os.path.basename(path)}.jsonl"
677
+ save_jsonl(rows, out_path)
678
+ return out_path, f"✅ {len(rows)} رکورد تولید شد."
679
+ except Exception as e:
680
+ return None, f"❌ خطا در ساخت دیتاست: {e}"
681
 
682
  def train(self, model_name: str, files: List[gr.File], use_rag: bool, epochs: int, batch: int, lr: float,
683
  use_wandb: bool, wandb_project: str, wandb_entity: str, run_name: str,
684
+ progress=gr.Progress(track_tqdm=True), request: gr.Request = None):
685
+ if not is_admin(request):
686
+ return "🔒 این عملیات فقط برای مدیران فعال است."
687
  progress(0.05, desc="راه‌اندازی")
688
  self.scfg.train.epochs = int(epochs)
689
  self.scfg.train.batch_size = int(batch)
 
707
  progress(0.95, desc="ذخیرهٔ آرتیفکت‌ها")
708
  return f"✅ آموزش کامل شد و در {self.scfg.train.output_dir} ذخیره شد."
709
 
710
+ def run_weight_tune(self, f, tk, ms, runs, bs, proj, ent, request: gr.Request):
711
+ if not is_admin(request):
712
+ return "🔒 این عملیات فقط برای مدیران فعال است."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
713
  p = getattr(f, "name", None) or getattr(f, "path", None)
714
  if not p:
715
  return "⚠️ فایل داده نامعتبر است."
 
726
  except Exception as e:
727
  return f"❌ خطا در اجرای Sweep: {e}"
728
 
729
+ def apply_best_weights(self, wandb_project: str, wandb_entity: str, metric: str = "pass_rate", request: gr.Request = None):
730
+ if request is not None and not is_admin(request):
731
+ return "🔒 این عملیات فقط برای مدیران فعال است."
732
  try:
733
  import wandb, json as _json
734
  except Exception as e:
 
765
  rid = getattr(best_run, "id", "unknown")
766
  return f"✅ وزن‌ها اعمال شد از Run `{rid}` با {metric}={best_val:.4f}. فایل: `legal_entity_weights.json`"
767
 
768
+ # Consultation (عمومی)
769
+ def answer(self, question: str, system_prompt: str, use_rag: bool, max_new_tokens: int, temperature: float, top_p: float):
770
+ if not question.strip(): return "لطفاً سوال خود را وارد کنید.", ""
771
+ if not self.gen: return "ابتدا مدل را بارگذاری کنید.", ""
772
+ self.scfg.model.max_new_tokens = int(max_new_tokens)
773
+ self.scfg.model.temperature = float(temperature)
774
+ self.scfg.model.top_p = float(top_p)
775
+
776
+ arts = self.rag.retrieve(question) if (use_rag and self.scfg.rag.enable and self.rag.collection) else []
777
+ max_refs = 4
778
+ if arts: arts = arts[:max_refs]
779
+ ctx = self.rag.build_context(arts) if arts else ""
780
+ ans = self.gen.generate(question, ctx, system_prompt)
781
+
782
+ refs = ""
783
+ if arts:
784
+ refs = "\n\n" + "\n".join([f"**ماده {a['article_id']}** (شباهت: {a.get('similarity',0):.2f})\n{a['text'][:320]}..." for a in arts])
785
+ return ans, refs
786
+
787
  # UI
788
  def build_ui(self):
789
  log_deps()
 
798
  "Mistral-7B Instruct (v0.3)": "mistralai/Mistral-7B-Instruct-v0.3",
799
  }
800
 
801
+ with gr.Blocks(title="ماحون — مشاور حقوقی (Causal-only, ZeroGPU)") as app:
802
+ # بنر نقش
803
+ role_banner = gr.Markdown()
804
+
805
  gr.Markdown("""
806
  <div style='text-align:center;padding:18px'>
807
+ <h1 style='margin-bottom:4px'>ماحون — Persian Legal (Causal-only, ZeroGPU)</h1>
808
  <p style='color:#666'>Hybrid RAG • Qwen/Llama/Mistral • Dataset Ops • W&B Training • Weight Tuning</p>
809
  </div>
810
  """)
811
 
812
+ # --- Tab: Consultation (interactive for all) ---
813
  with gr.Tab("مشاوره"):
814
  with gr.Row():
815
  gen_model_dd = gr.Dropdown(choices=list(default_gen_models.keys()), value="Qwen2.5-7B Instruct", label="مدل تولید")
 
842
  ask_btn = gr.Button("پرسش", variant="primary")
843
  answer = gr.Markdown(label="پاسخ"); refs = gr.Markdown(label="مواد قانونی مرتبط")
844
 
845
+ # --- Tab: Indexing (view-only for visitors) ---
846
  with gr.Tab("ایندکس قوانین"):
847
  gr.Markdown("فایل JSONL قوانین را بارگذاری و ایندکس کنید (کلیدها: `article_id`, `text`).")
848
  laws_file = gr.File(label="فایل JSONL قوانین", file_types=[".jsonl"])
849
  id_key = gr.Textbox(value="article_id", label="کلید شناسه ماده")
850
  text_key = gr.Textbox(value="text", label="کلید متن ماده")
851
  index_btn = gr.Button("ایندکس‌سازی قوانین"); index_status = gr.Textbox(label="وضعیت ایندکس", interactive=False)
852
+ index_widgets = [laws_file, id_key, text_key, index_btn]
853
 
854
+ # --- Tab: Dataset Builder (view-only for visitors) ---
855
  with gr.Tab("ساخت دیتاست"):
856
  gr.Markdown("فایل خام (JSON/JSONL) → خروجی JSONL سازگار با `{input, output}` (از golden_builder).")
857
  raw_file = gr.File(label="فایل خام", file_types=[".json",".jsonl"])
 
868
  build_btn = gr.Button("ساخت دیتاست", variant="primary")
869
  out_file = gr.File(label="دانلود خروجی JSONL", interactive=False)
870
  build_status = gr.Textbox(label="وضعیت", interactive=False)
871
+ builder_widgets = [raw_file, ds_text_key, model_ckpt, ds_batch_size, max_samples, build_btn]
872
 
873
+ # --- Tab: Dataset Cleaning (view-only for visitors) ---
874
  with gr.Tab("پاکسازی دیتاست"):
875
  gr.Markdown("نرمال‌سازی فارسی + حذف تکراری‌های معنایی (cosine). ورودی: JSONL `{input, output}`.")
876
  raw_ds = gr.File(label="JSONL ورودی", file_types=[".jsonl"])
 
878
  clean_btn = gr.Button("اجرای پاکسازی", variant="primary")
879
  cleaned_out = gr.File(label="دانلود JSONL پاک", interactive=False)
880
  clean_status = gr.Markdown()
881
+ clean_widgets = [raw_ds, sim_th, clean_btn]
882
 
883
+ # --- Tab: Training (view-only for visitors) ---
884
  with gr.Tab("آموزش"):
885
  gr.Markdown("SFT/LoRA روی مدل‌های causal (فقط `{input, output}`) + W&B logging.")
886
  with gr.Row():
 
897
  model_train_id = gr.Textbox(value="AI-Hoosh/HAKIM-7B", label="HF Model ID (قابل ویرایش)")
898
  use_rag_train = gr.Checkbox(value=True, label="RAG-enhanced Training")
899
 
 
900
  use_wandb = gr.Checkbox(value=True, label="W&B logging فعال باشد؟")
901
  wandb_project = gr.Textbox(value="mahoon-legal-ai", label="WANDB_PROJECT")
902
  wandb_entity = gr.Textbox(value="", label="WANDB_ENTITY (اختیاری)")
 
910
  lr = gr.Number(value=2e-4, label="learning rate")
911
  train_btn = gr.Button("شروع آموزش", variant="primary")
912
  train_status = gr.Textbox(label="وضعیت آموزش", interactive=False)
913
+ train_widgets = [model_train_dd, model_train_id, use_rag_train, use_wandb, wandb_project, wandb_entity,
914
+ run_name, train_files, epochs, batch, lr, train_btn]
915
 
916
+ # --- Tab: Weight Tuning (view-only for visitors) ---
917
  with gr.Tab("Weight Tuning"):
918
  gr.Markdown("تیون خودکار وزن‌های موجودیت با W&B Sweep. ابتدا در Settings→Secrets مقدار `WANDB_API_KEY` را ست کنید.")
919
  tune_file = gr.File(label="فایل داده (JSON/JSONL)", file_types=[".json",".jsonl"])
 
930
  gr.Markdown("اعمال خودکار بهترین وزن‌ها از داشبورد W&B (بر اساس بالاترین `pass_rate`).")
931
  metric_dd = gr.Dropdown(choices=["pass_rate"], value="pass_rate", label="متریک انتخاب بهترین Run")
932
  apply_btn = gr.Button("اعمال بهترین وزن‌ها از W&B", variant="secondary")
933
+ tuning_widgets = [tune_file, tune_text_key, tune_max_samples, tune_runs, tune_batch,
934
+ tune_proj, tune_entity, run_tune, metric_dd, apply_btn]
935
 
936
+ # ---- Events (مشاوره آزاد / عملیاتِ ادمینی با گیت) ----
937
  def _resolve_gen(choice: str, override: str) -> str:
938
  return override.strip() if override.strip() else default_gen_models[choice]
939
 
 
945
  self.scfg.rag.similarity_threshold = float(th)
946
  return self.load(_resolve_gen(choice, override))
947
 
948
+ def _whoami(request: gr.Request):
949
+ u = _get_username(request) or "Visitor"
950
+ return f"👤 کاربر: **{u}** — دسترسی: {'مدیریتی' if is_admin(request) else 'بازدیدکننده (فقط مشاهده)'}"
951
+
952
  load_btn.click(_on_load,
953
  inputs=[gen_model_dd, gen_model_id, use_rag, persist_dir, collection, top_k, threshold],
954
  outputs=status)
955
 
956
+ ask_btn.click(self.answer,
957
  inputs=[question, system_prompt, use_rag, max_new_tokens, temperature, top_p],
958
  outputs=[answer, refs])
959
 
960
+ # ادمینی: استفاده از request injection (Gradio به‌طور خودکار تزریق می‌کند)
961
+ def _index_handler(f, ik, tk, request: gr.Request):
962
+ return self.build_index(f, ik, tk, request)
963
+ index_btn.click(_index_handler, inputs=[laws_file, id_key, text_key], outputs=index_status)
964
 
965
+ def _build_ds_handler(rf, tk, ckpt, bs, mx, request: gr.Request):
966
+ return self.build_dataset(rf, tk, ckpt, bs, mx, request)
967
+ build_btn.click(_build_ds_handler,
968
  inputs=[raw_file, ds_text_key, model_ckpt, ds_batch_size, max_samples],
969
  outputs=[out_file, build_status])
970
 
971
+ def _train_handler(prof, mid, files, rg, e, b, l, uw, wp, we, rn, request: gr.Request):
972
+ def _map_profile_to_id(profile: str, current_id: str) -> str:
973
+ if current_id.strip(): return current_id.strip()
974
+ if "Dorna" in profile: return "PartAI/Dorna-Llama3-8B-Instruct"
975
+ if "PersianQA" in profile: return "zpm/Llama-3.1-PersianQA"
976
+ if "HAKIM" in profile: return "AI-Hoosh/HAKIM-7B"
977
+ if "Hooshvareh" in profile: return "HooshvareLab/llama-fa-7b-instruct"
978
+ return "PartAI/Dorna-Llama3-8B-Instruct"
979
+ model_id = _map_profile_to_id(prof, mid)
980
+ return self.train(model_id, files, rg, e, b, l, uw, wp, we, rn, request=request)
981
+ train_btn.click(_train_handler,
982
+ inputs=[model_train_dd, model_train_id, train_files, use_rag_train, epochs, batch, lr,
983
+ use_wandb, wandb_project, wandb_entity, run_name],
984
+ outputs=train_status)
985
+
986
+ def _clean_handler(f, th):
987
+ p = getattr(f, "name", None) or getattr(f, "path", None)
988
+ if not p: return None, "⚠️ فایل نامعتبر."
989
+ outp = f"/tmp/cleaned_{int(time.time())}.jsonl"
990
+ n = deduplicate_jsonl(p, outp, sim_threshold=float(th))
991
+ return outp, f"✅ دیتاست پاک شد. تعداد رکوردهای نهایی: **{n}**"
992
+ clean_btn.click(_clean_handler, inputs=[raw_ds, sim_th], outputs=[cleaned_out, clean_status])
993
+
994
+ def _tune_handler(f, tk, ms, runs, bs, proj, ent, request: gr.Request):
995
+ return self.run_weight_tune(f, tk, ms, runs, bs, proj, ent, request)
996
+ run_tune.click(_tune_handler,
997
+ inputs=[tune_file, tune_text_key, tune_max_samples, tune_runs, tune_batch, tune_proj, tune_entity],
998
+ outputs=tune_status)
999
+
1000
+ def _apply_best_handler(proj, ent, m, request: gr.Request):
1001
+ return self.apply_best_weights(proj, ent, m, request)
1002
+ apply_btn.click(_apply_best_handler,
1003
+ inputs=[tune_proj, tune_entity, metric_dd],
1004
+ outputs=tune_status)
1005
+
1006
+ # --- Lock non-consultation tabs for visitors on load ---
1007
+ def _gate_all(request: gr.Request):
1008
+ admin = is_admin(request)
1009
+ role_txt = f"👤 کاربر: **{_get_username(request) or 'Visitor'}** — دسترسی: {'مدیریتی' if admin else 'بازدیدکننده (فقط مشاهده)'}"
1010
+ if not admin:
1011
+ lock = gr.update(interactive=False)
1012
+ updates = [lock] * (len(index_widgets) + len(builder_widgets) + len(clean_widgets) + len(train_widgets) + len(tuning_widgets))
1013
+ else:
1014
+ unlock = gr.update(interactive=True)
1015
+ updates = [unlock] * (len(index_widgets) + len(builder_widgets) + len(clean_widgets) + len(train_widgets) + len(tuning_widgets))
1016
+ return [role_txt] + updates
1017
 
1018
+ app.load(_whoami, inputs=None, outputs=role_banner)
1019
+ app.load(_gate_all, inputs=None,
1020
+ outputs=[role_banner] + index_widgets + builder_widgets + clean_widgets + train_widgets + tuning_widgets)
 
 
1021
 
1022
  return app
1023
 
 
1028
  app = LegalApp()
1029
  ui = app.build_ui()
1030
  try:
1031
+ ui = ui.queue() # پایدار برای ZeroGPU
1032
  except TypeError:
1033
  pass
1034
  ui.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)