Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """ | |
| Mahoun — Ultimate Legal AI (Single-File, Modular, Polished UI) | |
| هستهٔ جدید ماحون با ادغام اجزای قبلی (RAG پیشرفته + Training برای Seq2Seq و Causal) و UI زیباتر. | |
| ویژگیها: | |
| - Multi-Architecture: "seq2seq" (T5/MT5/FLAN-T5) و "causal" (Mistral/LLaMA). | |
| - Loader/Generator یکپارچه + Prompt تطبیقی برحسب معماری. | |
| - RAG پیشرفته با ChromaDB (پیکربندی مسیر، نام کالکشن، top_k، threshold، قطع متن). | |
| - Training کامل برای هر دو معماری (Trainer, EarlyStopping, bf16/fp16, gradient_accumulation). | |
| - Gradio UI بازطراحیشده (تم تمیز، کارتها، مثالها، وضعیت زنده، کنترلهای تولید، انتخاب مدل/معماری/دیتابیس). | |
| حداقل نیازمندیها (requirements.txt): | |
| transformers>=4.44.0 | |
| sentencepiece | |
| accelerate | |
| bitsandbytes | |
| chromadb | |
| sentence-transformers | |
| scikit-learn | |
| gradio | |
| torch>=2.1 | |
| """ | |
| from __future__ import annotations | |
| import os, json, gc, warnings, textwrap | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import List, Dict, Optional, Tuple | |
| import torch | |
| from torch.utils.data import Dataset | |
| from sklearn.model_selection import train_test_split | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSeq2SeqLM, | |
| AutoModelForCausalLM, | |
| Trainer, | |
| TrainingArguments, | |
| EarlyStoppingCallback, | |
| DataCollatorForSeq2Seq, | |
| ) | |
| import chromadb | |
| from sentence_transformers import SentenceTransformer | |
| import gradio as gr | |
| warnings.filterwarnings("ignore") | |
| # ========================== | |
| # Config | |
| # ========================== | |
@dataclass
class ModelConfig:
    """Model selection and text-generation settings.

    Declared as a dataclass so that each instance owns its field values and
    can be constructed with keyword overrides, e.g.
    ``ModelConfig(model_name="google/flan-t5-base")``.
    """

    # Hugging Face model identifier handed to ``from_pretrained``.
    model_name: str = "google/mt5-base"
    # Which model family to load: "seq2seq" | "causal".
    architecture: str = "seq2seq"
    # Tokenizer truncation length for prompts / training inputs.
    max_input_length: int = 1024
    # Truncation length for seq2seq targets and ``max_length`` for seq2seq generation.
    max_target_length: int = 512
    # Number of tokens generated for causal models.
    max_new_tokens: int = 384
    # Sampling temperature (causal generation).
    temperature: float = 0.7
    # Nucleus-sampling cutoff (causal generation).
    top_p: float = 0.9
    # Beam width (seq2seq generation).
    num_beams: int = 4
@dataclass
class RAGConfig:
    """Settings for the ChromaDB-backed retrieval component.

    Declared as a dataclass so instances can be built with keyword overrides
    and carry their own field values.
    """

    # Sentence-Transformers model name loaded by the retriever.
    embedding_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    # Directory of the persistent ChromaDB store.
    persist_dir: str = "./chroma_db"
    # Name of the ChromaDB collection holding the legal articles.
    collection: str = "legal_articles"
    # Number of results requested per query.
    top_k: int = 5
    # Minimum similarity for a hit to be kept, 0..1 (larger = stricter).
    similarity_threshold: float = 0.66
    # Maximum number of characters of each article placed in the context.
    context_char_limit: int = 300
@dataclass
class TrainConfig:
    """Hyper-parameters for fine-tuning.

    Declared as a dataclass so instances can be built with keyword overrides
    and carry their own field values.
    """

    # Directory where checkpoints, the final model and the tokenizer are written.
    output_dir: str = "./mahoon_model"
    # Seed used for the RNGs and for the train/validation split.
    seed: int = 42
    # Fraction of the records held out for validation.
    test_size: float = 0.1
    # Number of training epochs.
    epochs: int = 2
    # Per-device batch size (training and evaluation).
    batch_size: int = 2
    # Gradient accumulation steps.
    grad_accum: int = 2
    # Learning rate.
    lr: float = 3e-5
    # On CUDA: True selects bf16 mixed precision, False selects fp16.
    use_bf16: bool = True
@dataclass
class SystemConfig:
    """Top-level configuration bundling model, retrieval and training settings.

    The ``@dataclass`` decorator is required here: ``field(default_factory=...)``
    is only turned into a real per-instance default by the dataclass machinery.
    Without it the attributes would be raw ``dataclasses.Field`` objects.
    """

    # Each instance gets its own freshly constructed sub-config.
    model: ModelConfig = field(default_factory=ModelConfig)
    rag: RAGConfig = field(default_factory=RAGConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
| # ========================== | |
| # Utils | |
| # ========================== | |
def set_seed_all(seed: int = 42):
    """Seed Python's ``random`` module and PyTorch (CPU and CUDA) with ``seed``."""
    import random

    # Same order as before: Python RNG, torch CPU RNG, then every CUDA device.
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
def read_jsonl_files(paths: List[str]) -> List[Dict]:
    """Load every JSON object found in the given JSON Lines files.

    Args:
        paths: File paths to read; falsy entries (``""``, ``None``) are skipped.

    Returns:
        The parsed JSON objects in file order. Blank lines, lines that are not
        valid JSON, and lines whose JSON value is not an object are skipped,
        so callers can rely on every element being a ``dict``.

    Raises:
        OSError: If a path cannot be opened.
    """
    records: List[Dict] = []
    for path in paths or []:
        if not path:
            continue
        # "utf-8-sig" also reads plain UTF-8, but strips a leading BOM that
        # would otherwise make the first record fail to parse.
        with open(path, "r", encoding="utf-8-sig") as handle:
            for raw in handle:
                line = raw.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError:
                    # Best-effort loading: malformed lines are ignored.
                    continue
                # Lists, strings and numbers are valid JSON but are not records.
                if isinstance(obj, dict):
                    records.append(obj)
    return records
| # ========================== | |
| # RAG | |
| # ========================== | |
class LegalRAG:
    """Retrieve legal articles from a persistent ChromaDB collection.

    ``init()`` must be called before ``retrieve()`` returns anything; until
    then ``collection`` is ``None`` and retrieval yields an empty list.
    """

    def __init__(self, cfg: RAGConfig):
        self.cfg = cfg
        self.client = None
        self.collection = None
        self.embedder: Optional[SentenceTransformer] = None

    def init(self):
        """Open the persistent store, the collection and the embedding model."""
        Path(self.cfg.persist_dir).mkdir(parents=True, exist_ok=True)
        self.client = chromadb.PersistentClient(path=self.cfg.persist_dir)
        # Prefer get_or_create; fall back to get / create for compatibility
        # with different Chroma versions.
        try:
            self.collection = self.client.get_or_create_collection(self.cfg.collection)
        except Exception:
            try:
                self.collection = self.client.get_collection(self.cfg.collection)
            except Exception:
                self.collection = self.client.create_collection(self.cfg.collection)
        # NOTE(review): the embedder is loaded here but retrieve() queries with
        # ``query_texts``, i.e. through the collection's own embedding function.
        # Confirm which model the collection was populated with before
        # switching to ``query_embeddings``.
        self.embedder = SentenceTransformer(self.cfg.embedding_model)

    def retrieve(self, query: str) -> List[Dict]:
        """Return the hits for ``query`` that pass the similarity threshold.

        Each hit is a dict with ``article_id``, ``text`` and ``similarity``,
        where similarity is computed as ``1 - distance``.
        NOTE(review): that is only a similarity in 0..1 for distance metrics
        bounded by 1 — confirm the collection's metric.

        Retrieval is best-effort: any error is logged and an empty list is
        returned, so a retrieval failure never blocks answer generation.
        """
        import logging

        if self.collection is None:
            return []
        try:
            res = self.collection.query(
                query_texts=[query],
                n_results=self.cfg.top_k,
                include=["documents", "metadatas", "distances"],
            )
            # Every field is a list with one entry per query text; we sent one.
            docs = (res.get("documents") or [[]])[0] or []
            metas = (res.get("metadatas") or [[]])[0] or []
            dists = (res.get("distances") or [[]])[0] or []
            hits: List[Dict] = []
            for i, (doc, dist) in enumerate(zip(docs, dists)):
                sim = 1 - float(dist)
                if sim < self.cfg.similarity_threshold:
                    continue
                meta = metas[i] if i < len(metas) else None
                hits.append({
                    "article_id": (meta or {}).get("article_id", f"unk_{i}"),
                    "text": doc,
                    "similarity": sim,
                })
            return hits
        except Exception as exc:
            # Keep the best-effort contract, but leave a trace for debugging.
            logging.getLogger(__name__).warning("RAG retrieval failed: %s", exc)
            return []

    def build_context(self, arts: List[Dict]) -> str:
        """Format retrieved articles as a bulleted context block for the prompt.

        Each article text is cut to ``cfg.context_char_limit`` characters; an
        ellipsis is appended only when text was actually cut off. Returns an
        empty string when ``arts`` is empty.
        """
        if not arts:
            return ""
        limit = self.cfg.context_char_limit
        bullets = []
        for art in arts:
            # A stored document may be missing (None); treat it as empty text.
            text = str(art.get("text") or "")
            snippet = text[:limit] + ("..." if len(text) > limit else "")
            bullets.append(f"• ماده {art['article_id']}: {snippet}")
        return "مواد مرتبط:\n" + "\n".join(bullets)
| # ========================== | |
| # Loader + Generator | |
| # ========================== | |
class ModelLoader:
    """Load the tokenizer and model described by a ``ModelConfig``."""

    def __init__(self, mcfg: ModelConfig):
        self.cfg = mcfg
        self.tokenizer = None
        self.model = None

    def load(self):
        """Load tokenizer and model for the configured architecture.

        On CUDA the model is loaded in bfloat16 with ``device_map="auto"``;
        otherwise the library defaults are used.

        Returns:
            ``self``, so the call can be chained after construction.

        Raises:
            ValueError: If ``cfg.architecture`` is neither ``"seq2seq"`` nor
                ``"causal"``. This is checked before anything is loaded, so an
                invalid setting fails fast and leaves the loader untouched.
        """
        arch = self.cfg.architecture
        if arch not in ("seq2seq", "causal"):
            raise ValueError("Unsupported architecture")

        on_gpu = torch.cuda.is_available()
        self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name)
        model_cls = AutoModelForSeq2SeqLM if arch == "seq2seq" else AutoModelForCausalLM
        self.model = model_cls.from_pretrained(
            self.cfg.model_name,
            device_map="auto" if on_gpu else None,
            torch_dtype=torch.bfloat16 if on_gpu else None,
        )
        # Causal tokenizers may have no pad token; reuse EOS so that padding
        # (used by the training datasets) works.
        if (
            arch == "causal"
            and self.tokenizer.pad_token is None
            and hasattr(self.tokenizer, "eos_token")
        ):
            self.tokenizer.pad_token = self.tokenizer.eos_token
        return self
class Generator:
    """Build the architecture-specific prompt and generate an answer."""

    def __init__(self, loader: ModelLoader, mcfg: ModelConfig):
        self.tk = loader.tokenizer
        self.model = loader.model
        self.cfg = mcfg

    def _encode(self, prompt: str):
        """Tokenize ``prompt`` (truncated) and move the tensors to the model device."""
        enc = self.tk(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.cfg.max_input_length,
        )
        return {k: v.to(self.model.device) for k, v in enc.items()}

    def generate(self, question: str, context: str = "") -> str:
        """Answer ``question``, optionally preceded by retrieved ``context``.

        Seq2seq models use beam search bounded by ``max_target_length``.
        Causal models sample with ``temperature``/``top_p``, or decode greedily
        when the temperature is not positive, and produce ``max_new_tokens``
        tokens.

        Returns:
            The decoded answer. For causal models only the newly generated
            continuation is returned, not the echoed prompt.
        """
        prefix = f"{context}\n" if context else ""

        if self.cfg.architecture == "seq2seq":
            enc = self._encode(f"{prefix}سوال: {question}")
            out = self.model.generate(
                **enc,
                max_length=self.cfg.max_target_length,
                num_beams=self.cfg.num_beams,
                early_stopping=True,
            )
            return self.tk.decode(out[0], skip_special_tokens=True)

        # causal
        enc = self._encode(f"{prefix}سوال: {question}\nپاسخ:")
        # Compare against None explicitly: a pad id of 0 is valid but falsy.
        pad_id = self.tk.pad_token_id
        if pad_id is None:
            pad_id = self.tk.eos_token_id
        if self.cfg.temperature > 0:
            sampling = {
                "do_sample": True,
                "temperature": self.cfg.temperature,
                "top_p": self.cfg.top_p,
            }
        else:
            # Sampling requires a strictly positive temperature; a setting of
            # 0 means "deterministic", so fall back to greedy decoding.
            sampling = {"do_sample": False}
        out = self.model.generate(
            **enc,
            max_new_tokens=self.cfg.max_new_tokens,
            pad_token_id=pad_id,
            **sampling,
        )
        # Decoder-only models return prompt + continuation; drop the prompt.
        prompt_len = enc["input_ids"].shape[-1]
        return self.tk.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
| # ========================== | |
| # Datasets | |
| # ========================== | |
class Seq2SeqJSONLDataset(Dataset):
    """Encoder-decoder training set built from ``{"input", "output"}`` records.

    Records with a missing or empty input or output are dropped. Inputs use
    the same prompt layout that ``Generator`` uses for seq2seq inference
    (``"سوال: …"``, optionally preceded by the retrieved context on its own
    line), so the model is fine-tuned on the prompts it will later be served.
    """

    def __init__(self, data: List[Dict], tokenizer, max_inp: int, max_tgt: int, rag: Optional[LegalRAG] = None, enhance_every: int = 10):
        """Pre-build the (prompt, target) pairs.

        Args:
            data: Parsed JSONL records.
            tokenizer: Tokenizer used in ``__getitem__``.
            max_inp: Maximum (padded) input length in tokens.
            max_tgt: Maximum (padded) target length in tokens.
            rag: Optional retriever; when given, every ``enhance_every``-th
                record gets retrieved context prepended.
            enhance_every: Enhancement stride; values <= 0 disable enhancement.
        """
        self.tk = tokenizer
        self.max_inp = max_inp
        self.max_tgt = max_tgt
        self.items = []
        for i, ex in enumerate(data):
            src = self._clean(ex.get("input"))
            tgt = self._clean(ex.get("output"))
            if not src or not tgt:
                continue
            ctx = ""
            # The stride is guarded so that enhance_every=0 cannot divide by zero.
            if rag and enhance_every > 0 and i % enhance_every == 0:
                ctx = rag.build_context(rag.retrieve(src))
            inp = f"{ctx}\nسوال: {src}" if ctx else f"سوال: {src}"
            self.items.append((inp, tgt))

    @staticmethod
    def _clean(value) -> str:
        """Return ``value`` as stripped text; ``None`` (JSON null) becomes ``""``."""
        return "" if value is None else str(value).strip()

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Tokenize one pair; padded label positions are set to -100."""
        inp, tgt = self.items[idx]
        model_inputs = self.tk(inp, max_length=self.max_inp, padding="max_length", truncation=True)
        labels = self.tk(text_target=tgt, max_length=self.max_tgt, padding="max_length", truncation=True)
        # -100 is the index the loss ignores; without this the model would be
        # trained to predict padding for every slot after the real target.
        pad_id = self.tk.pad_token_id
        model_inputs["labels"] = [(-100 if tok == pad_id else tok) for tok in labels["input_ids"]]
        return model_inputs
class CausalJSONLDataset(Dataset):
    """Decoder-only training set built from ``{"input", "output"}`` records.

    Each example is one text in the layout ``"سوال: …\\nپاسخ: …"``, optionally
    preceded by retrieved context on its own line. Records with a missing or
    empty input or output are dropped.
    """

    def __init__(self, data: List[Dict], tokenizer, max_inp: int, rag: Optional[LegalRAG] = None, enhance_every: int = 10):
        """Pre-build the training texts.

        Args:
            data: Parsed JSONL records.
            tokenizer: Tokenizer used in ``__getitem__``.
            max_inp: Maximum (padded) sequence length in tokens.
            rag: Optional retriever; when given, every ``enhance_every``-th
                record gets retrieved context prepended.
            enhance_every: Enhancement stride; values <= 0 disable enhancement.
        """
        self.tk = tokenizer
        self.max_inp = max_inp
        self.items = []
        for i, ex in enumerate(data):
            src = self._clean(ex.get("input"))
            tgt = self._clean(ex.get("output"))
            if not src or not tgt:
                continue
            ctx = ""
            # The stride is guarded so that enhance_every=0 cannot divide by zero.
            if rag and enhance_every > 0 and i % enhance_every == 0:
                ctx = rag.build_context(rag.retrieve(src))
            text = f"{ctx}\nسوال: {src}\nپاسخ: {tgt}" if ctx else f"سوال: {src}\nپاسخ: {tgt}"
            self.items.append(text)

    @staticmethod
    def _clean(value) -> str:
        """Return ``value`` as stripped text; ``None`` (JSON null) becomes ``""``."""
        return "" if value is None else str(value).strip()

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Tokenize one text; labels copy the inputs with padding set to -100."""
        text = self.items[idx]
        enc = self.tk(text, max_length=self.max_inp, padding="max_length", truncation=True)
        input_ids = torch.tensor(enc["input_ids"])
        attention_mask = torch.tensor(enc["attention_mask"])
        labels = input_ids.clone()
        # Mask by attention mask rather than by pad id: the pad token may be
        # the EOS token, and real EOS tokens must keep contributing to the loss.
        labels[attention_mask == 0] = -100
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
| # ========================== | |
| # Trainer Manager | |
| # ========================== | |
class TrainerManager:
    """Fine-tune the loaded model on JSONL data for either architecture."""

    def __init__(self, syscfg: SystemConfig, loader: ModelLoader):
        self.cfg = syscfg
        self.loader = loader

    def train_seq2seq(self, train_paths: List[str], use_rag: bool = True):
        """Fine-tune an encoder-decoder model and save it to ``train.output_dir``.

        Args:
            train_paths: JSONL files with ``input`` / ``output`` records.
            use_rag: Whether to prepend retrieved context to part of the
                training inputs (never applied to the validation split).

        Raises:
            ValueError: If fewer than two records are available to split.
        """
        train, val, rag = self._prepare(train_paths, use_rag)
        mcfg = self.cfg.model
        tk = self.loader.tokenizer
        ds_tr = Seq2SeqJSONLDataset(train, tk, mcfg.max_input_length, mcfg.max_target_length, rag)
        ds_va = Seq2SeqJSONLDataset(val, tk, mcfg.max_input_length, mcfg.max_target_length, None)
        collator = DataCollatorForSeq2Seq(tokenizer=tk, model=self.loader.model)
        self._fit(ds_tr, ds_va, data_collator=collator)

    def train_causal(self, train_paths: List[str], use_rag: bool = True):
        """Fine-tune a decoder-only model and save it to ``train.output_dir``.

        Args:
            train_paths: JSONL files with ``input`` / ``output`` records.
            use_rag: Whether to prepend retrieved context to part of the
                training texts (never applied to the validation split).

        Raises:
            ValueError: If fewer than two records are available to split.
        """
        train, val, rag = self._prepare(train_paths, use_rag)
        tk = self.loader.tokenizer
        ds_tr = CausalJSONLDataset(train, tk, self.cfg.model.max_input_length, rag)
        ds_va = CausalJSONLDataset(val, tk, self.cfg.model.max_input_length, None)
        self._fit(ds_tr, ds_va)

    def _prepare(self, train_paths: List[str], use_rag: bool):
        """Seed the RNGs, load and split the records, and set up the retriever."""
        set_seed_all(self.cfg.train.seed)
        data = read_jsonl_files(train_paths)
        if len(data) < 2:
            # Fail with a clear message instead of an opaque splitter error.
            raise ValueError("At least two training records are required for a train/validation split")
        train, val = train_test_split(
            data, test_size=self.cfg.train.test_size, random_state=self.cfg.train.seed
        )
        rag = LegalRAG(self.cfg.rag) if use_rag else None
        if rag:
            rag.init()
        return train, val, rag

    def _training_args(self) -> TrainingArguments:
        """Build the ``TrainingArguments`` shared by both architectures.

        The generation options (``predict_with_generate``,
        ``generation_max_length``, ``generation_num_beams``) are not set:
        they are fields of ``Seq2SeqTrainingArguments`` and the plain
        ``TrainingArguments`` constructor rejects them. They are not needed
        because model selection uses ``eval_loss`` only.
        """
        on_gpu = torch.cuda.is_available()
        use_bf16 = self.cfg.train.use_bf16
        return TrainingArguments(
            output_dir=self.cfg.train.output_dir,
            num_train_epochs=self.cfg.train.epochs,
            learning_rate=self.cfg.train.lr,
            per_device_train_batch_size=self.cfg.train.batch_size,
            per_device_eval_batch_size=self.cfg.train.batch_size,
            gradient_accumulation_steps=self.cfg.train.grad_accum,
            warmup_ratio=0.05,
            weight_decay=0.01,
            # ``eval_strategy`` replaces the deprecated ``evaluation_strategy``.
            eval_strategy="epoch",
            save_strategy="epoch",
            save_total_limit=2,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            logging_steps=50,
            report_to="none",
            # Mixed precision only on CUDA: bf16 when requested, fp16 otherwise.
            fp16=on_gpu and not use_bf16,
            bf16=on_gpu and use_bf16,
        )

    def _fit(self, ds_tr, ds_va, data_collator=None):
        """Train with early stopping, then save the model and tokenizer."""
        trainer = Trainer(
            model=self.loader.model,
            args=self._training_args(),
            train_dataset=ds_tr,
            eval_dataset=ds_va,
            # None makes the Trainer fall back to its default collator.
            data_collator=data_collator,
            tokenizer=self.loader.tokenizer,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
        )
        trainer.train()
        trainer.save_model(self.cfg.train.output_dir)
        self.loader.tokenizer.save_pretrained(self.cfg.train.output_dir)
| # ========================== | |
| # App (Gradio) | |
| # ========================== | |
class LegalApp:
    """Gradio application wiring together model loading, RAG, answering and training."""

    def __init__(self, scfg: Optional[SystemConfig] = None):
        self.scfg = scfg or SystemConfig()
        self.rag = LegalRAG(self.scfg.rag)
        self.loader: Optional[ModelLoader] = None
        self.gen: Optional[Generator] = None

    # --- core actions ---
    def _release_model(self):
        """Drop the current model and generator and free cached memory.

        Called before another model is loaded so that two models are not
        referenced by the app at the same time.
        """
        self.gen = None
        self.loader = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def load(self, model_name: str, arch: str, use_rag: bool, persist_dir: str, collection: str, top_k: int, threshold: float):
        """Apply the UI settings, load the model and (optionally) the retriever.

        Returns:
            A status message for the UI. A retriever failure is reported in
            the message instead of being raised; model loading errors propagate.
        """
        # configure
        self.scfg.model.model_name = model_name
        self.scfg.model.architecture = arch
        self.scfg.rag.persist_dir = persist_dir
        self.scfg.rag.collection = collection
        self.scfg.rag.top_k = int(top_k)
        self.scfg.rag.similarity_threshold = float(threshold)
        # load model (releasing any previously loaded one first)
        self._release_model()
        self.loader = ModelLoader(self.scfg.model).load()
        self.gen = Generator(self.loader, self.scfg.model)
        # load rag
        msg_rag = "RAG غیر فعال"
        if use_rag:
            try:
                self.rag = LegalRAG(self.scfg.rag)
                self.rag.init()
                msg_rag = "RAG آماده است"
            except Exception as e:
                msg_rag = f"RAG خطا: {e}"
        return f"مدل بارگذاری شد: {model_name} ({arch})\n{msg_rag}"

    def answer(self, question: str, use_rag: bool, max_new_tokens: int, temperature: float, top_p: float, num_beams: int):
        """Answer a question, optionally grounded in retrieved articles.

        Returns:
            ``(answer, references)``; ``references`` is a Markdown listing of
            the retrieved articles, or ``""`` when none were used.
        """
        if not (question or "").strip():
            return "لطفاً سوال خود را وارد کنید.", ""
        if not self.gen:
            return "ابتدا مدل/RAG را بارگذاری کنید.", ""
        # update runtime params (the generator reads this same config object)
        self.scfg.model.max_new_tokens = int(max_new_tokens)
        self.scfg.model.temperature = float(temperature)
        self.scfg.model.top_p = float(top_p)
        self.scfg.model.num_beams = int(num_beams)
        rag_ready = use_rag and self.rag.collection is not None
        arts = self.rag.retrieve(question) if rag_ready else []
        ctx = self.rag.build_context(arts) if arts else ""
        ans = self.gen.generate(question, ctx)
        refs = ""
        if arts:
            refs = "\n\n" + "\n".join([f"**ماده {a['article_id']}** (شباهت: {a['similarity']:.2f})\n{a['text'][:380]}..." for a in arts])
        return ans, refs

    def train(self, model_name: str, arch: str, files: List[gr.File], use_rag: bool, epochs: int, batch: int, lr: float):
        """Fine-tune the selected model on the uploaded JSONL files.

        ``files`` may contain plain path strings or file objects exposing a
        ``.name`` path; both are accepted. When no file is given, a message is
        returned and nothing is loaded or trained. After training, the
        generator is rebuilt on the fine-tuned model.

        Returns:
            A status message for the UI.
        """
        paths = [getattr(f, "name", f) for f in (files or [])]
        if not paths:
            return "لطفاً حداقل یک فایل JSONL بارگذاری کنید."
        self.scfg.model.model_name = model_name
        self.scfg.model.architecture = arch
        self.scfg.train.epochs = int(epochs)
        self.scfg.train.batch_size = int(batch)
        self.scfg.train.lr = float(lr)
        # ensure loader (releasing any previously loaded model first)
        self._release_model()
        self.loader = ModelLoader(self.scfg.model).load()
        # train
        tm = TrainerManager(self.scfg, self.loader)
        if arch == "seq2seq":
            tm.train_seq2seq(paths, use_rag=use_rag)
        else:
            tm.train_causal(paths, use_rag=use_rag)
        # Keep the generator in sync with the model that was just trained.
        self.gen = Generator(self.loader, self.scfg.model)
        return f"✅ آموزش کامل شد و در {self.scfg.train.output_dir} ذخیره شد."

    # --- UI ---
    def build_ui(self):
        """Build and return the Gradio Blocks interface (consultation and training tabs)."""
        default_models = {
            "Seq2Seq (mt5-base)": ("google/mt5-base", "seq2seq"),
            "Seq2Seq (t5-fa-base)": ("HooshvareLab/t5-fa-base", "seq2seq"),
            "Seq2Seq (flan-t5-base)": ("google/flan-t5-base", "seq2seq"),
            "Causal (Mistral-7B Instruct)": ("mistralai/Mistral-7B-Instruct-v0.2", "causal"),
        }
        with gr.Blocks(title="ماحون — مشاور حقوقی هوشمند", theme=gr.themes.Soft(primary_hue="green", secondary_hue="gray")) as app:
            gr.HTML("""
            <div style='text-align:center;padding:18px'>
            <h1 style='margin-bottom:4px'>ماحون — Ultimate Legal AI</h1>
            <p style='color:#666'>RAG • Seq2Seq/Causal • Training • Polished UI</p>
            </div>
            """)
            with gr.Tab("مشاوره"):
                with gr.Row():
                    model_dd = gr.Dropdown(choices=list(default_models.keys()), value="Seq2Seq (mt5-base)", label="مدل")
                    gr.Markdown("""**راهنما:** مدلهای Seq2Seq (MT5/T5) برای پاسخهای ساختاریافته عالیاند؛ مدلهای Causal (Mistral) برای مکالمه طبیعیترند.""")
                with gr.Row():
                    use_rag = gr.Checkbox(value=True, label="RAG فعال باشد؟")
                    persist_dir = gr.Textbox(value=self.scfg.rag.persist_dir, label="مسیر پایگاه ChromaDB")
                    collection = gr.Textbox(value=self.scfg.rag.collection, label="نام کالکشن")
                with gr.Row():
                    top_k = gr.Slider(1, 10, value=self.scfg.rag.top_k, step=1, label="Top‑K")
                    threshold = gr.Slider(0.3, 0.95, value=self.scfg.rag.similarity_threshold, step=0.01, label="حد آستانه شباهت")
                load_btn = gr.Button("بارگذاری مدل/RAG", variant="primary")
                status = gr.Textbox(label="وضعیت", interactive=False)
                with gr.Accordion("پارامترهای تولید", open=False):
                    max_new_tokens = gr.Slider(64, 1024, value=self.scfg.model.max_new_tokens, step=16, label="max_new_tokens")
                    temperature = gr.Slider(0.0, 1.5, value=self.scfg.model.temperature, step=0.05, label="temperature")
                    top_p = gr.Slider(0.1, 1.0, value=self.scfg.model.top_p, step=0.05, label="top_p")
                    num_beams = gr.Slider(1, 8, value=self.scfg.model.num_beams, step=1, label="num_beams (Seq2Seq)")
                question = gr.Textbox(lines=3, label="سوال حقوقی")
                gr.Examples([
                    ["در صورت نقض قرارداد فروش، چه اقداماتی باید انجام دهم؟"],
                    ["آیا درج شرط عدم رقابت در قرارداد کار قانونی است؟"],
                    ["حق و حقوق کارگر در صورت اخراج فوری چیست؟"],
                    ["فرآیند طرح دعوای مطالبه مهریه چگونه است؟"],
                ], inputs=question, label="نمونه پرسشها")
                ask_btn = gr.Button("پرسش", variant="primary")
                answer = gr.Markdown(label="پاسخ")
                refs = gr.Markdown(label="مواد قانونی مرتبط")
            with gr.Tab("آموزش"):
                gr.Markdown("برای آموزش، فایلهای JSONL شامل کلیدهای `input` و `output` را بارگذاری کنید.")
                with gr.Row():
                    model_dd_train = gr.Dropdown(choices=list(default_models.keys()), value="Seq2Seq (mt5-base)", label="مدل")
                    use_rag_train = gr.Checkbox(value=True, label="RAG‑enhanced Training")
                train_files = gr.Files(label="JSONL Files", file_count="multiple", file_types=[".jsonl"])
                with gr.Row():
                    epochs = gr.Slider(1, 6, value=self.scfg.train.epochs, step=1, label="epochs")
                    batch = gr.Slider(1, 8, value=self.scfg.train.batch_size, step=1, label="batch per device")
                    lr = gr.Number(value=self.scfg.train.lr, label="learning rate")
                train_btn = gr.Button("شروع آموزش", variant="primary")
                train_status = gr.Textbox(label="وضعیت آموزش", interactive=False)

            # Events
            def _resolve(choice: str) -> Tuple[str, str]:
                """Map a dropdown label to its (model name, architecture) pair."""
                return default_models[choice]

            def _on_load(choice, rag_enabled, pdir, coll, k, th):
                return self.load(*_resolve(choice), rag_enabled, pdir, coll, k, th)

            def _on_train(choice, files, rag_enabled, n_epochs, batch_size, learning_rate):
                return self.train(*_resolve(choice), files, rag_enabled, n_epochs, batch_size, learning_rate)

            load_btn.click(
                _on_load,
                inputs=[model_dd, use_rag, persist_dir, collection, top_k, threshold],
                outputs=status,
            )
            ask_btn.click(
                self.answer,
                inputs=[question, use_rag, max_new_tokens, temperature, top_p, num_beams],
                outputs=[answer, refs],
            )
            train_btn.click(
                _on_train,
                inputs=[model_dd_train, train_files, use_rag_train, epochs, batch, lr],
                outputs=train_status,
            )
        return app
| # ========================== | |
| # Entrypoint | |
| # ========================== | |
if __name__ == "__main__":
    # Script entry point: build the Gradio interface and launch it with
    # share=True.
    demo = LegalApp().build_ui()
    demo.launch(share=True)