mahoon-legal-ai / app.py
hajimammad's picture
Rename app(6).py to app.py
f90240d verified
raw
history blame
28.6 kB
# -*- coding: utf-8 -*-
"""
Ultimate Legal AI System: Fine-tuning + RAG Integration
سیستم هوش مصنوعی حقوقی ماحون:
1. فاین‌تیون با 700K رای قضایی
2. RAG با 8000 ماده قانونی (chroma.sqlite3)
3. ترکیب برای بهترین مشاوره حقوقی
بهبودها:
- پشتیبانی از Seq2Seq (T5/MT5) و Causal (Mistral).
- کوانتایز 4-bit با bitsandbytes.
- لاگینگ با logging و مانیتورینگ با wandb.
- بهینه‌سازی برای Hugging Face Spaces.
"""
import os
import json
import logging
import torch
import numpy as np
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
import chromadb
from chromadb.config import Settings
# Fine-tuning imports
from transformers import (
AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM,
TrainingArguments, Trainer, EarlyStoppingCallback,
DataCollatorForSeq2Seq
)
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
# RAG imports
from sentence_transformers import SentenceTransformer
# UI and monitoring
import gradio as gr
import wandb
from bitsandbytes import quantize_model # برای کوانتایز
# تنظیم لاگینگ
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# ==============================================================================
# تنظیمات سیستم کامل
# ==============================================================================
@dataclass
class UltimateLegalSystemConfig:
# Fine-tuning settings
model_name: str = "google/mt5-base" # Compatible with T4 GPU
architecture: str = "seq2seq" # "seq2seq" یا "causal"
max_input_length: int = 1024
max_target_length: int = 512
batch_size_per_gpu: int = 4 # Reduced for T4 stability
num_gpus: int = 1
learning_rate: float = 5e-5
num_epochs: int = 3
quantize_bits: int = 4 # 4-bit quantization
use_wandb: bool = True # WandB monitoring
# RAG settings
embedding_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
chroma_db_path: str = "./chroma_db"
top_k_retrieval: int = 5
similarity_threshold: float = 0.7
# System paths
finetuned_model_path: str = "./ultimate_legal_model"
legal_db_path: str = "./legal_database"
cache_dir: str = "./cache"
# Generation settings
max_new_tokens: int = 512
temperature: float = 0.7
do_sample: bool = True
top_p: float = 0.9
# ==============================================================================
# سیستم RAG برای قوانین
# ==============================================================================
class LegalRAGSystem:
def __init__(self, config: UltimateLegalSystemConfig):
self.config = config
self.embedding_model = None
self.chroma_client = None
self.collection = None
self.legal_articles = {}
def setup_embedding_model(self):
logger.info("📚 بارگیری مدل embedding...")
try:
self.embedding_model = SentenceTransformer(
self.config.embedding_model,
cache_folder=self.config.cache_dir
)
logger.info("✅ مدل embedding آماده شد")
return True
except Exception as e:
logger.error(f"❌ خطا در بارگیری مدل embedding: {e}")
return False
def load_legal_database(self):
logger.info("📚 بارگیری دیتابیس قوانین...")
try:
db_path = Path(self.config.chroma_db_path)
db_path.mkdir(parents=True, exist_ok=True)
self.chroma_client = chromadb.PersistentClient(path=str(db_path))
try:
self.collection = self.chroma_client.get_collection(name="legal_articles")
logger.info(f"✅ Collection موجود بارگذاری شد")
except Exception:
logger.warning("⚠️ Collection موجود نیست، ایجاد می‌شود...")
self.collection = self.chroma_client.create_collection(
name="legal_articles",
metadata={"description": "مواد قانونی ایران"}
)
count = self.collection.count()
logger.info(f"📊 تعداد مواد قانونی: {count:,}")
if count == 0:
logger.warning("⚠️ دیتابیس خالی است! لطفاً قوانین را اضافه کنید.")
return False
return True
except Exception as e:
logger.error(f"❌ خطا در بارگیری دیتابیس: {e}")
return False
def add_legal_articles_to_db(self, legal_texts: Dict[str, str]):
logger.info("📝 اضافه کردن مواد قانونی به دیتابیس...")
if not self.collection:
logger.error("❌ Collection موجود نیست!")
return False
try:
documents, metadatas, ids = [], [], []
for article_id, text in legal_texts.items():
documents.append(text)
metadatas.append({"article_id": str(article_id), "type": "legal_article", "source": "official_laws"})
ids.append(f"article_{article_id}")
batch_size = 100
for i in range(0, len(documents), batch_size):
batch_end = min(i + batch_size, len(documents))
self.collection.add(
documents=documents[i:batch_end],
metadatas=metadatas[i:batch_end],
ids=ids[i:batch_end]
)
logger.info(f" Added batch {i//batch_size + 1}/{(len(documents)-1)//batch_size + 1}")
logger.info(f"✅ {len(documents)} ماده قانونی اضافه شد")
return True
except Exception as e:
logger.error(f"❌ خطا در اضافه کردن مواد: {e}")
return False
def retrieve_relevant_articles(self, query: str) -> List[Dict]:
if not self.collection:
return []
try:
results = self.collection.query(
query_texts=[query],
n_results=min(self.config.top_k_retrieval, self.collection.count()),
include=["documents", "metadatas", "distances"]
)
relevant_articles = []
if results and results['documents'] and len(results['documents'][0]) > 0:
for i, (doc, metadata, distance) in enumerate(zip(
results['documents'][0], results['metadatas'][0], results['distances'][0]
)):
similarity = 1 - (distance / 2)
if similarity >= self.config.similarity_threshold:
relevant_articles.append({
"article_id": metadata.get("article_id", f"unknown_{i}"),
"text": doc,
"similarity": similarity,
"source": metadata.get("source", "unknown")
})
return relevant_articles[:self.config.top_k_retrieval]
except Exception as e:
logger.warning(f"⚠️ خطا در جستجو: {e}")
return []
# ==============================================================================
# سیستم Fine-tuning بهینه شده
# ==============================================================================
class LegalDatasetWithRAG(Dataset):
def __init__(self, data: List[Dict], tokenizer, config: UltimateLegalSystemConfig, rag_system: Optional[LegalRAGSystem] = None):
self.tokenizer = tokenizer
self.config = config
self.rag_system = rag_system
logger.info(f"📊 آماده‌سازی {len(data)} نمونه با RAG enhancement...")
self.processed_data = self._enhance_with_rag(data)
def _enhance_with_rag(self, data: List[Dict]) -> List[Dict]:
enhanced_data = []
for i, item in enumerate(data):
if i % 1000 == 0:
logger.info(f" RAG Enhancement: {i}/{len(data)}")
question = item.get("questionTitle", item.get("question", ""))
answers = item.get("answers", item.get("answer", []))
if isinstance(answers, list) and len(answers) > 0:
answer = answers[0]
elif isinstance(answers, str):
answer = answers
else:
continue
if len(question.split()) < 3 or len(answer.split()) < 3:
continue
enhanced_item = {
"original_question": question,
"original_answer": answer,
"enhanced_input": question,
"enhanced_output": answer
}
if self.rag_system and i % 10 == 0:
try:
relevant_articles = self.rag_system.retrieve_relevant_articles(question)
if relevant_articles:
context = "\n".join([f"ماده {art['article_id']}: {art['text'][:200]}..." for art in relevant_articles[:2]])
enhanced_item["enhanced_input"] = f"سوال: {question}\nمواد مرتبط: {context}"
except Exception as e:
logger.warning(f"⚠️ خطا در enhancement RAG برای نمونه {i}: {e}")
enhanced_data.append(enhanced_item)
logger.info(f"✅ {len(enhanced_data)} نمونه آماده شد")
return enhanced_data
def __len__(self):
return len(self.processed_data)
def __getitem__(self, idx):
item = self.processed_data[idx]
model_inputs = self.tokenizer(
item["enhanced_input"],
max_length=self.config.max_input_length,
padding="max_length",
truncation=True,
return_tensors="pt"
)
with self.tokenizer.as_target_tokenizer():
labels = self.tokenizer(
item["enhanced_output"],
max_length=self.config.max_target_length,
padding="max_length",
truncation=True,
return_tensors="pt"
).input_ids
labels[labels == self.tokenizer.pad_token_id] = -100
return {
"input_ids": model_inputs.input_ids.squeeze(),
"attention_mask": model_inputs.attention_mask.squeeze(),
"labels": labels.squeeze()
}
# ==============================================================================
# سیستم یکپارچه Fine-tuning + RAG
# ==============================================================================
class UltimateLegalAI:
def __init__(self, config: UltimateLegalSystemConfig):
self.config = config
self.rag_system = LegalRAGSystem(config)
self.finetuned_model = None
self.tokenizer = None
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def setup_rag_system(self, legal_texts: Optional[Dict[str, str]] = None):
logger.info("🚀 راه‌اندازی سیستم RAG...")
if not self.rag_system.setup_embedding_model():
return False
if not self.rag_system.load_legal_database():
if legal_texts:
logger.info("\n📝 ایجاد دیتابیس جدید از متون قانونی...")
return self.rag_system.add_legal_articles_to_db(legal_texts)
else:
logger.warning("\n⚠️ هیچ دیتابیس قانونی موجود نیست!")
return False
return True
def finetune_model(self, jsonl_files: List[str], sample_size: Optional[int] = 10000):
logger.info("🎯 شروع فاین‌تیون با RAG enhancement...")
all_data = []
for jsonl_file in jsonl_files:
if not os.path.exists(jsonl_file):
logger.warning(f"⚠️ فایل {jsonl_file} موجود نیست")
continue
with open(jsonl_file, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
try:
all_data.append(json.loads(line))
except json.JSONDecodeError:
continue
if not all_data:
logger.error("❌ هیچ داده‌ای بارگیری نشد!")
return False
logger.info(f"📊 مجموع داده‌ها: {len(all_data):,} نمونه")
if sample_size and sample_size < len(all_data):
all_data = all_data[:sample_size]
logger.info(f"📊 استفاده از {sample_size} نمونه برای آموزش")
logger.info("\n🤖 بارگیری مدل پایه...")
try:
self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
logger.info("🔄 اعمال کوانتایز 4-bit...")
if self.config.architecture == "seq2seq":
model = AutoModelForSeq2SeqLM.from_pretrained(
self.config.model_name,
load_in_4bit=True if self.config.quantize_bits == 4 else False,
device_map="auto",
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
else: # causal
model = AutoModelForCausalLM.from_pretrained(
self.config.model_name,
load_in_4bit=True if self.config.quantize_bits == 4 else False,
device_map="auto",
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
except Exception as e:
logger.error(f"❌ خطا در بارگیری مدل: {e}")
return False
if self.config.use_wandb:
wandb.init(project="mahoon-legal-ai", config={
"epochs": self.config.num_epochs,
"batch_size": self.config.batch_size_per_gpu * self.config.num_gpus,
"lr": self.config.learning_rate
})
train_data, val_data = train_test_split(all_data, test_size=0.1, random_state=42)
train_dataset = LegalDatasetWithRAG(train_data, self.tokenizer, self.config, self.rag_system)
val_dataset = LegalDatasetWithRAG(val_data, self.tokenizer, self.config, None)
training_args = TrainingArguments(
output_dir=self.config.finetuned_model_path,
per_device_train_batch_size=self.config.batch_size_per_gpu,
per_device_eval_batch_size=self.config.batch_size_per_gpu,
gradient_accumulation_steps=4,
num_train_epochs=self.config.num_epochs,
learning_rate=self.config.learning_rate,
fp16=torch.cuda.is_available(),
gradient_checkpointing=True,
eval_strategy="steps", # تغییر از evaluation_strategy به eval_strategy
eval_steps=500,
save_steps=1000,
save_total_limit=2,
logging_steps=100,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
report_to="wandb" if self.config.use_wandb else "none",
warmup_steps=500,
weight_decay=0.01,
logging_dir='./logs',
remove_unused_columns=False
)
data_collator = DataCollatorForSeq2Seq(
tokenizer=self.tokenizer,
model=model,
label_pad_token_id=-100,
pad_to_multiple_of=8
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
tokenizer=self.tokenizer,
data_collator=data_collator,
callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)
logger.info("🚀 شروع آموزش مدل...")
try:
trainer.train()
trainer.save_model()
self.tokenizer.save_pretrained(self.config.finetuned_model_path)
if self.config.use_wandb:
wandb.finish()
logger.info("\n✅ آموزش با موفقیت انجام شد!")
return True
except Exception as e:
logger.error(f"❌ خطا در آموزش: {e}")
if self.config.use_wandb:
wandb.finish()
return False
def load_finetuned_model(self):
logger.info("📥 بارگیری مدل فاین‌تیون شده...")
try:
model_path = Path(self.config.finetuned_model_path)
if not model_path.exists():
logger.error(f"❌ مدل در مسیر {model_path} موجود نیست")
return False
self.tokenizer = AutoTokenizer.from_pretrained(str(model_path))
logger.info("🔄 اعمال کوانتایز 4-bit...")
if self.config.architecture == "seq2seq":
self.finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(
str(model_path),
load_in_4bit=True if self.config.quantize_bits == 4 else False,
device_map="auto",
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
else: # causal
self.finetuned_model = AutoModelForCausalLM.from_pretrained(
str(model_path),
load_in_4bit=True if self.config.quantize_bits == 4 else False,
device_map="auto",
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
self.finetuned_model.to(self.device)
logger.info("✅ مدل آماده است!")
return True
except Exception as e:
logger.error(f"❌ خطا در بارگیری مدل: {e}")
return False
def generate_legal_advice(self, question: str) -> Dict:
if not self.finetuned_model:
return {"error": "مدل بارگذاری نشده است"}
try:
relevant_articles = self.rag_system.retrieve_relevant_articles(question)
context = ""
if relevant_articles:
context = "مواد مرتبط:\n"
for art in relevant_articles[:3]:
context += f"• ماده {art['article_id']}: {art['text'][:200]}...\n"
context += "\n"
full_prompt = f"{context}سوال: {question}\n\nپاسخ:"
inputs = self.tokenizer.encode(
full_prompt,
return_tensors="pt",
max_length=self.config.max_input_length,
truncation=True
).to(self.device)
with torch.no_grad():
outputs = self.finetuned_model.generate(
inputs,
max_new_tokens=self.config.max_new_tokens,
temperature=self.config.temperature,
do_sample=self.config.do_sample,
top_p=self.config.top_p,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
num_beams=4,
early_stopping=True
)
generated_answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
return {
"question": question,
"answer": generated_answer,
"relevant_articles": relevant_articles,
"context_used": bool(relevant_articles)
}
except Exception as e:
logger.warning(f"⚠️ خطا در تولید پاسخ: {e}")
return {"error": f"خطا در تولید پاسخ: {str(e)}"}
# ==============================================================================
# رابط کاربری Gradio
# ==============================================================================
def create_gradio_interface(legal_ai: UltimateLegalAI):
def process_question(question):
if not question.strip():
return "لطفاً سوال خود را وارد کنید.", ""
result = legal_ai.generate_legal_advice(question)
if "error" in result:
return f"⚠️ {result['error']}", ""
answer = f"## 📝 پاسخ:\n\n{result['answer']}\n\n"
references = ""
if result.get('relevant_articles'):
references = "## 📚 مواد قانونی مرتبط:\n\n"
for art in result['relevant_articles'][:3]:
similarity_percent = art['similarity'] * 100
references += f"### **ماده {art['article_id']}** (شباهت: {similarity_percent:.0f}%)\n"
references += f"{art['text'][:300]}...\n\n---\n\n"
return answer, references
with gr.Blocks(
title="سیستم هوش مصنوعی حقوقی",
theme=gr.themes.Soft(primary_hue="green"),
css="""
.gradio-container { font-family: 'Vazir', 'B Nazanin', 'Tahoma', sans-serif; }
"""
) as interface:
gr.HTML("""
<div style="text-align: center; padding: 30px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px;">
<h1 style="margin: 0; font-size: 2.5em;">⚖️ سیستم هوش مصنوعی حقوقی پیشرفته</h1>
<p style="margin-top: 10px; font-size: 1.1em; opacity: 0.95;">
مبتنی بر 700,000 رای قضایی + 8,000 ماده قانونی
</p>
</div>
""")
with gr.Row():
with gr.Column():
question_input = gr.Textbox(
label="💭 سوال حقوقی خود را بپرسید",
placeholder="مثال: شرایط طلاق توافقی چیست؟",
lines=4, rtl=True
)
submit_btn = gr.Button("🔍 دریافت مشاوره", variant="primary")
gr.Examples(
examples=[
["شرایط طلاق توافقی چیست؟"],
["مراحل ثبت شرکت با مسئولیت محدود چگونه است؟"],
["در صورت نقض قرارداد اجاره چه اقداماتی می‌توان انجام داد؟"],
["حقوق کارگر در صورت اخراج غیرقانونی چیست؟"],
["نحوه تنظیم وصیت‌نامه رسمی چگونه است؟"]
],
inputs=question_input
)
with gr.Row():
with gr.Column():
answer_output = gr.Markdown(label="پاسخ مشاور", rtl=True)
with gr.Column():
references_output = gr.Markdown(label="مراجع قانونی", rtl=True)
submit_btn.click(fn=process_question, inputs=question_input, outputs=[answer_output, references_output])
gr.HTML("""
<div style="text-align: center; padding: 20px; margin-top: 30px; background: #f0f0f0; border-radius: 10px;">
<p style="color: #666; font-size: 0.9em;">⚠️ توجه: این سیستم صرفاً جنبه مشاوره‌ای دارد و جایگزین مشاور حقوقی نیست</p>
</div>
""")
return interface
# ==============================================================================
# توابع اصلی اجرا
# ==============================================================================
def main():
logger.info("🚀 راه‌اندازی سیستم هوش مصنوعی حقوقی نهایی...")
logger.info("=" * 60)
config = UltimateLegalSystemConfig()
legal_ai = UltimateLegalAI(config)
logger.info("\n1️⃣ راه‌اندازی سیستم RAG...")
if not legal_ai.setup_rag_system():
logger.error("❌ خطا در راه‌اندازی RAG")
logger.info("لطفاً دیتابیس قوانین را اضافه کنید")
return None
logger.info("\n2️⃣ بارگیری مدل فاین‌تیون شده...")
if not legal_ai.load_finetuned_model():
logger.warning("⚠️ مدل فاین‌تیون شده موجود نیست.")
logger.info("برای آموزش مدل از تابع train_model() استفاده کنید")
return legal_ai
logger.info("\n3️⃣ راه‌اندازی رابط کاربری...")
interface = create_gradio_interface(legal_ai)
logger.info("\n✅ سیستم آماده است!")
logger.info("=" * 60)
interface.launch(share=True, server_name="0.0.0.0", server_port=7860)
return legal_ai
def train_model(jsonl_files: List[str], sample_size: Optional[int] = 1000):
logger.info("🎓 شروع فرآیند آموزش مدل...")
logger.info("=" * 60)
config = UltimateLegalSystemConfig()
legal_ai = UltimateLegalAI(config)
logger.info("\n📚 راه‌اندازی سیستم RAG...")
legal_ai.setup_rag_system()
logger.info("\n🔧 شروع Fine-tuning...")
success = legal_ai.finetune_model(jsonl_files, sample_size)
if success:
logger.info("\n✅ آموزش با موفقیت انجام شد!")
logger.info(f"مدل در مسیر {config.finetuned_model_path} ذخیره شد")
else:
logger.error("\n❌ آموزش با خطا مواجه شد")
return legal_ai
# ==============================================================================
# نمونه استفاده
# ==============================================================================
if __name__ == "__main__":
import sys
if len(sys.argv) > 1 and sys.argv[1] == "train":
jsonl_files = ["data/train.jsonl", "data/valid.jsonl"] # سازگار با ساختار
train_model(jsonl_files, sample_size=1000)
else:
legal_ai = main()
print("""
╔══════════════════════════════════════════════════════════════╗
║ دستورالعمل استفاده از سیستم ║
╠══════════════════════════════════════════════════════════════╣
║ ║
║ 1️⃣ برای آموزش مدل: ║
║ python app.py train ║
║ ║
║ 2️⃣ برای اجرای سیستم: ║
║ python app.py ║
║ ║
║ 3️⃣ ویژگی‌ها: ║
║ ✅ Fine-tuning با 700K داده ║
║ ✅ RAG با 8000 ماده قانونی ║
║ ✅ رابط Gradio با طراحی شیک ║
║ ✅ پشتیبانی از Seq2Seq و Causal ║
║ ✅ کوانتایز 4-bit برای بهینه‌سازی ║
║ ║
╚══════════════════════════════════════════════════════════════╝
📌 نکات:
- فایل‌های JSONL باید در ./data باشند
- ChromaDB در ./chroma_db ذخیره می‌شود
- مدل در ./ultimate_legal_model ذخیره می‌شود
- برای HF، از T4 GPU استفاده کن
""")