hajimammad commited on
Commit
0f55eea
·
verified ·
1 Parent(s): fb092cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +582 -1477
app.py CHANGED
@@ -1,38 +1,35 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- Mahoon Legal AI Enhanced Version
4
- Features:
5
- - Improved memory management and resource cleanup
6
- - Caching system for models and embeddings
7
- - Enhanced security and input validation
8
- - Better error handling and logging
9
- - Metrics and monitoring
10
- - Thread safety improvements
11
- - Configuration validation with Pydantic
12
- - Comprehensive testing support
13
- - Gradio UI with advanced features
 
 
14
  """
15
 
16
  from __future__ import annotations
17
- import os
18
- import json
19
- import warnings
20
- import hashlib
21
- import threading
22
- import time
23
- from contextlib import contextmanager
24
  from dataclasses import dataclass, field
25
  from pathlib import Path
26
- from typing import List, Dict, Optional, Tuple, Any, Union
27
- from datetime import datetime
28
- from functools import lru_cache
29
- import logging
30
 
 
31
  import torch
32
  from torch.utils.data import Dataset
33
  from sklearn.model_selection import train_test_split
34
- from pydantic import BaseModel, validator, Field
35
 
 
 
 
 
36
  from transformers import (
37
  AutoTokenizer,
38
  AutoModelForSeq2SeqLM,
@@ -41,1576 +38,684 @@ from transformers import (
41
  TrainingArguments,
42
  EarlyStoppingCallback,
43
  DataCollatorForSeq2Seq,
44
- TrainerCallback
45
  )
46
 
 
47
  import chromadb
48
  from sentence_transformers import SentenceTransformer
49
- import gradio as gr
50
 
51
- warnings.filterwarnings("ignore")
 
 
 
 
52
 
53
- # Configure logging
54
- logging.basicConfig(
55
- level=logging.INFO,
56
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
57
- )
58
- logger = logging.getLogger(__name__)
59
 
60
  # ==========================
61
- # Enhanced Config with Validation
62
  # ==========================
63
- class ModelConfig(BaseModel):
64
- model_name: str = "persiannlp/parsi-t5-base"
65
- architecture: str = "seq2seq"
66
- max_input_length: int = Field(default=1024, ge=64, le=4096)
67
- max_target_length: int = Field(default=512, ge=32, le=2048)
68
- max_new_tokens: int = Field(default=512, ge=32, le=1024)
69
- temperature: float = Field(default=0.7, ge=0.0, le=2.0)
70
- top_p: float = Field(default=0.9, ge=0.1, le=1.0)
71
- num_beams: int = Field(default=4, ge=1, le=8)
72
- use_bf16: bool = True
73
-
74
- @validator('architecture')
75
- def validate_architecture(cls, v):
76
- if v not in ['seq2seq', 'causal']:
77
- raise ValueError('architecture must be seq2seq or causal')
78
- return v
79
-
80
- class Config:
81
- validate_assignment = True
82
 
83
- class SystemConfig(BaseModel):
84
- model: ModelConfig = Field(default_factory=ModelConfig)
85
  embedding_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
86
- chroma_db_path: str = "./chroma.sqlite3"
87
- top_k_retrieval: int = Field(default=5, ge=1, le=20)
88
- similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0)
89
- cache_dir: str = "./cache"
90
- output_dir: str = "./mahoon_legal_model"
 
 
 
 
 
91
  seed: int = 42
92
- train_test_ratio: float = Field(default=0.1, ge=0.05, le=0.3)
93
- batch_size: int = Field(default=2, ge=1, le=16)
94
- grad_accum: int = Field(default=2, ge=1, le=8)
95
- epochs: int = Field(default=2, ge=1, le=10)
96
- lr: float = Field(default=3e-5, ge=1e-6, le=1e-3)
97
- max_file_size_mb: int = Field(default=10, ge=1, le=100)
98
- max_lines_per_file: int = Field(default=10000, ge=100, le=100000)
99
- request_timeout: int = Field(default=30, ge=5, le=300)
100
-
101
- class Config:
102
- validate_assignment = True
 
 
 
103
 
104
- # ==========================
105
- # Metrics and Monitoring
106
- # ==========================
107
  @dataclass
108
- class SystemMetrics:
109
- requests_count: int = 0
110
- avg_response_time: float = 0.0
111
- error_count: int = 0
112
- success_count: int = 0
113
- memory_usage_mb: float = 0.0
114
- last_updated: datetime = field(default_factory=datetime.now)
115
- active_models: List[str] = field(default_factory=list)
116
-
117
- class MetricsCollector:
118
- def __init__(self):
119
- self.metrics = SystemMetrics()
120
- self._lock = threading.Lock()
121
-
122
- def record_request(self, response_time: float, success: bool = True):
123
- with self._lock:
124
- self.metrics.requests_count += 1
125
- if success:
126
- self.metrics.success_count += 1
127
- else:
128
- self.metrics.error_count += 1
129
-
130
- # Update average response time
131
- total_requests = self.metrics.requests_count
132
- old_avg = self.metrics.avg_response_time
133
- self.metrics.avg_response_time = (old_avg * (total_requests - 1) + response_time) / total_requests
134
- self.metrics.last_updated = datetime.now()
135
-
136
- def update_memory_usage(self):
137
- if torch.cuda.is_available():
138
- memory_mb = torch.cuda.memory_allocated() / 1024 / 1024
139
- self.metrics.memory_usage_mb = memory_mb
140
-
141
- def get_metrics(self) -> Dict[str, Any]:
142
- with self._lock:
143
- return {
144
- "requests_total": self.metrics.requests_count,
145
- "success_rate": self.metrics.success_count / max(self.metrics.requests_count, 1) * 100,
146
- "avg_response_time": round(self.metrics.avg_response_time, 2),
147
- "error_count": self.metrics.error_count,
148
- "memory_usage_mb": round(self.metrics.memory_usage_mb, 2),
149
- "active_models": self.metrics.active_models.copy(),
150
- "last_updated": self.metrics.last_updated.isoformat()
151
- }
152
-
153
- # Global metrics instance
154
- metrics = MetricsCollector()
155
 
156
  # ==========================
157
- # Enhanced Utilities
158
  # ==========================
159
  def set_seed_all(seed: int = 42):
160
  import random
161
  random.seed(seed)
 
162
  torch.manual_seed(seed)
163
- torch.cuda.manual_seed_all(seed)
164
- logger.info(f"Set random seed to {seed}")
165
 
166
- def validate_file_security(file_path: str, max_size_mb: int = 10, max_lines: int = 10000) -> Tuple[bool, str]:
167
- """Enhanced file validation with security checks"""
168
  try:
169
- path = Path(file_path)
170
-
171
- # Check if file exists and is readable
172
- if not path.exists() or not path.is_file():
173
- return False, "فایل وجود ندارد یا قابل خواندن نیست"
174
-
175
- # Check file extension
176
- if path.suffix.lower() != '.jsonl':
177
- return False, "فقط فایل‌های .jsonl پذیرفته می‌شوند"
178
-
179
- # Check file size
180
- size_mb = path.stat().st_size / (1024 * 1024)
181
- if size_mb > max_size_mb:
182
- return False, f"حجم فایل نباید از {max_size_mb} مگابایت بیشتر باشد"
183
-
184
- # Validate content structure
185
- line_count = 0
186
- with open(path, 'r', encoding='utf-8') as f:
187
- for line_num, line in enumerate(f, 1):
188
- line = line.strip()
189
- if not line:
190
- continue
191
-
192
- line_count += 1
193
- if line_count > max_lines:
194
- return False, f"فایل نباید بیش از {max_lines} خط داشته باشد"
195
-
196
- # Validate JSON structure
197
- try:
198
- data = json.loads(line)
199
- if not isinstance(data, dict):
200
- return False, f"خط {line_num}: فرمت JSON نامعتبر"
201
-
202
- if 'input' not in data or 'output' not in data:
203
- return False, f"خط {line_num}: کلیدهای 'input' و 'output' الزامی هستند"
204
-
205
- # Check content length
206
- if len(str(data['input'])) > 2048 or len(str(data['output'])) > 2048:
207
- return False, f"خط {line_num}: طول محتوا بیش از حد مجاز"
208
-
209
- except json.JSONDecodeError:
210
- return False, f"خط {line_num}: فرمت JSON نام��تبر"
211
-
212
- if line_count == 0:
213
- return False, "فایل خالی است"
214
-
215
- return True, f"فایل معتبر است ({line_count} خط)"
216
-
217
  except Exception as e:
218
- logger.error(f"File validation error: {e}")
219
- return False, f"خطا در بررسی فایل: {str(e)}"
220
-
221
- def read_jsonl_files_safe(paths: List[str], cfg: SystemConfig) -> Tuple[List[Dict], List[str]]:
222
- """Safe JSONL file reading with validation"""
223
- data: List[Dict] = []
224
- errors: List[str] = []
225
-
226
- for path in paths:
227
- # Validate file first
228
- is_valid, msg = validate_file_security(path, cfg.max_file_size_mb, cfg.max_lines_per_file)
229
- if not is_valid:
230
- errors.append(f"{Path(path).name}: {msg}")
231
- continue
232
-
233
- try:
234
- with open(path, 'r', encoding='utf-8') as f:
235
- for line_num, line in enumerate(f, 1):
236
- line = line.strip()
237
- if not line:
238
- continue
239
-
240
- try:
241
- obj = json.loads(line)
242
- # Sanitize input
243
- obj['input'] = str(obj['input']).strip()
244
- obj['output'] = str(obj['output']).strip()
245
-
246
- if obj['input'] and obj['output']:
247
- data.append(obj)
248
- except json.JSONDecodeError:
249
- errors.append(f"{Path(path).name} line {line_num}: JSON decode error")
250
-
251
- except Exception as e:
252
- errors.append(f"{Path(path).name}: {str(e)}")
253
-
254
- logger.info(f"Loaded {len(data)} samples from {len(paths)} files")
255
- return data, errors
256
-
257
- # ==========================
258
- # Model Cache System
259
- # ==========================
260
- class ModelCache:
261
- _instances: Dict[str, Any] = {}
262
- _lock = threading.Lock()
263
- _access_times: Dict[str, float] = {}
264
- _max_cache_size = 3 # Maximum models to keep in cache
265
-
266
- @classmethod
267
- def _generate_key(cls, model_name: str, architecture: str) -> str:
268
- return hashlib.md5(f"{model_name}_{architecture}".encode()).hexdigest()[:16]
269
-
270
- @classmethod
271
- def get_model(cls, model_name: str, architecture: str, model_config: ModelConfig):
272
- key = cls._generate_key(model_name, architecture)
273
-
274
- with cls._lock:
275
- if key in cls._instances:
276
- cls._access_times[key] = time.time()
277
- logger.info(f"Model loaded from cache: {model_name}")
278
- return cls._instances[key]
279
-
280
- # Cleanup old models if cache is full
281
- if len(cls._instances) >= cls._max_cache_size:
282
- cls._cleanup_cache()
283
-
284
- # Load new model
285
- try:
286
- loader = ModelLoader(model_config)
287
- loader.load()
288
- cls._instances[key] = loader
289
- cls._access_times[key] = time.time()
290
-
291
- # Update metrics
292
- if model_name not in metrics.metrics.active_models:
293
- metrics.metrics.active_models.append(model_name)
294
-
295
- logger.info(f"Model loaded and cached: {model_name}")
296
- return loader
297
-
298
- except Exception as e:
299
- logger.error(f"Failed to load model {model_name}: {e}")
300
- raise
301
-
302
- @classmethod
303
- def _cleanup_cache(cls):
304
- """Remove least recently used model"""
305
- if not cls._access_times:
306
- return
307
-
308
- # Find least recently used model
309
- lru_key = min(cls._access_times.keys(), key=lambda k: cls._access_times[k])
310
-
311
- # Clean up resources
312
- if lru_key in cls._instances:
313
- loader = cls._instances[lru_key]
314
- cls._cleanup_model_resources(loader)
315
- del cls._instances[lru_key]
316
- del cls._access_times[lru_key]
317
- logger.info(f"Removed model from cache: {lru_key}")
318
-
319
- @classmethod
320
- def _cleanup_model_resources(cls, loader):
321
- """Clean up model resources"""
322
- try:
323
- if hasattr(loader, 'model') and hasattr(loader.model, 'cpu'):
324
- loader.model.cpu()
325
- if torch.cuda.is_available():
326
- torch.cuda.empty_cache()
327
- except Exception as e:
328
- logger.warning(f"Error cleaning up model resources: {e}")
329
-
330
- @classmethod
331
- def clear_cache(cls):
332
- """Clear all cached models"""
333
- with cls._lock:
334
- for loader in cls._instances.values():
335
- cls._cleanup_model_resources(loader)
336
- cls._instances.clear()
337
- cls._access_times.clear()
338
- metrics.metrics.active_models.clear()
339
- logger.info("Model cache cleared")
340
 
341
  # ==========================
342
- # Enhanced RAG System
343
  # ==========================
344
- class LegalRAGSystem:
345
- def __init__(self, cfg: SystemConfig):
346
  self.cfg = cfg
347
- self.embedding_model: Optional[SentenceTransformer] = None
348
  self.client = None
349
  self.collection = None
350
- self._lock = threading.Lock()
351
 
352
- @contextmanager
353
- def _safe_operation(self, operation_name: str):
354
- """Context manager for safe RAG operations"""
355
- start_time = time.time()
356
  try:
357
- yield
358
- except Exception as e:
359
- logger.error(f"RAG {operation_name} failed: {e}")
360
- metrics.record_request(time.time() - start_time, success=False)
361
- raise
362
- else:
363
- metrics.record_request(time.time() - start_time, success=True)
364
-
365
- def setup_embedding(self):
366
- if self.embedding_model is None:
367
- try:
368
- self.embedding_model = SentenceTransformer(
369
- self.cfg.embedding_model,
370
- cache_folder=self.cfg.cache_dir
371
- )
372
- logger.info(f"Embedding model loaded: {self.cfg.embedding_model}")
373
- except Exception as e:
374
- logger.error(f"Failed to load embedding model: {e}")
375
- raise
376
-
377
- def load_chroma(self) -> Tuple[bool, str]:
378
- with self._safe_operation("load_chroma"):
379
  try:
380
- base_path = str(Path(self.cfg.chroma_db_path).parent)
381
- os.makedirs(base_path, exist_ok=True)
382
-
383
- self.client = chromadb.PersistentClient(path=base_path)
 
 
 
 
 
 
 
 
 
 
 
384
  try:
385
- self.collection = self.client.get_collection("legal_articles")
386
- count = self.collection.count()
387
- logger.info(f"Loaded existing collection with {count} documents")
388
- return count > 0, f"مجموعه موجود با {count} سند بارگذاری شد"
389
- except Exception:
390
- self.collection = self.client.create_collection(
391
- "legal_articles",
392
- metadata={"description": "مواد قانونی"}
393
- )
394
- logger.info("Created new collection")
395
- return False, "مجموعه جدید ایجاد شد"
396
-
397
- except Exception as e:
398
- logger.error(f"ChromaDB initialization failed: {e}")
399
- return False, f"خطا در بارگذاری پایگاه داده: {str(e)}"
400
 
401
  def retrieve(self, query: str) -> List[Dict]:
402
- if not self.collection or not query.strip():
403
  return []
404
-
405
- with self._safe_operation("retrieve"):
406
- try:
407
- # Sanitize query
408
- query = query.strip()[:500] # Limit query length
409
-
410
- result = self.collection.query(
411
- query_texts=[query],
412
- n_results=self.cfg.top_k_retrieval,
413
- include=["documents", "metadatas", "distances"]
414
- )
415
-
416
- articles = []
417
- if result['documents'] and result['documents'][0]:
418
- for i, (doc, meta, dist) in enumerate(zip(
419
- result['documents'][0],
420
- result['metadatas'][0],
421
- result['distances'][0]
422
- )):
423
- similarity = max(0, min(1, 1 - dist)) # Normalize similarity
424
- if similarity >= self.cfg.similarity_threshold:
425
- articles.append({
426
- "article_id": meta.get("article_id", f"unknown_{i}"),
427
- "text": str(doc)[:500], # Limit text length
428
- "similarity": round(similarity, 3),
429
- })
430
-
431
- logger.info(f"Retrieved {len(articles)} relevant articles")
432
- return articles
433
-
434
- except Exception as e:
435
- logger.error(f"Article retrieval failed: {e}")
436
- return []
437
-
438
- @staticmethod
439
- def build_context(articles: List[Dict], limit_chars: int = 500) -> str:
440
- if not articles:
441
- return ""
442
-
443
- context_parts = []
444
- total_chars = 0
445
-
446
- for article in articles:
447
- text = article['text'][:limit_chars]
448
- part = f"• ماده {article['article_id']}: {text}"
449
-
450
- if total_chars + len(part) > limit_chars * 3: # Max total context
451
- break
452
-
453
- context_parts.append(part)
454
- total_chars += len(part)
455
-
456
- return "مواد مرتبط:\n" + "\n".join(context_parts)
457
-
458
- # ==========================
459
- # Enhanced Formalizer
460
- # ==========================
461
- class Formalizer:
462
- def __init__(self, model_name="erfan226/persian-t5-formality-transfer", device=None):
463
- self.model_name = model_name
464
- self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
465
- self.tokenizer = None
466
- self.model = None
467
- self._initialized = False
468
- self._lock = threading.Lock()
469
-
470
- def _initialize(self):
471
- """Lazy initialization of formalizer model"""
472
- if self._initialized:
473
- return
474
-
475
- with self._lock:
476
- if self._initialized: # Double-check pattern
477
- return
478
-
479
- try:
480
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
481
- self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name).to(self.device)
482
- self._initialized = True
483
- logger.info("Formalizer model initialized")
484
- except Exception as e:
485
- logger.error(f"Formalizer initialization failed: {e}")
486
- raise
487
-
488
- def formalize(self, text: str, max_len: int = 512) -> str:
489
- if not text or not text.strip():
490
- return text
491
-
492
- self._initialize()
493
-
494
  try:
495
- # Sanitize and limit input
496
- text = text.strip()[:max_len]
497
-
498
- inputs = self.tokenizer(
499
- text,
500
- return_tensors="pt",
501
- truncation=True,
502
- max_length=max_len
503
- ).to(self.device)
504
-
505
- with torch.no_grad():
506
- outputs = self.model.generate(
507
- **inputs,
508
- max_length=max_len,
509
- num_beams=4,
510
- early_stopping=True
511
- )
512
-
513
- result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
514
- logger.debug(f"Formalized text: {text[:50]}... -> {result[:50]}...")
515
- return result
516
 
517
- except Exception as e:
518
- logger.error(f"Text formalization failed: {e}")
519
- return text # Return original text on error
 
 
520
 
521
  # ==========================
522
- # Enhanced Model Loader
523
  # ==========================
524
  class ModelLoader:
525
- def __init__(self, model_config: ModelConfig):
526
- self.cfg = model_config
527
  self.tokenizer = None
528
  self.model = None
529
- self._loaded = False
530
-
531
- def _is_persianmind(self, name: str) -> bool:
532
- return "PersianMind" in name or "universitytehran/PersianMind" in name
533
 
534
- @contextmanager
535
- def _gpu_memory_context(self):
536
- """Context manager for GPU memory management"""
537
- initial_memory = 0
 
 
538
  if torch.cuda.is_available():
539
- initial_memory = torch.cuda.memory_allocated()
 
 
 
 
 
 
 
 
 
540
 
541
- try:
542
- yield
543
- finally:
544
- if torch.cuda.is_available():
545
- final_memory = torch.cuda.memory_allocated()
546
- logger.info(f"Memory change: {(final_memory - initial_memory) / 1024**2:.1f} MB")
547
- metrics.update_memory_usage()
548
-
549
- def load(self, prefer_quantized: bool = True):
550
- if self._loaded:
551
- return self
552
-
553
- with self._gpu_memory_context():
554
  try:
555
- self._load_tokenizer()
556
- self._load_model(prefer_quantized)
557
- self._loaded = True
558
- logger.info(f"Successfully loaded {self.cfg.model_name}")
559
- return self
560
-
561
- except Exception as e:
562
- logger.error(f"Model loading failed: {e}")
563
- self._cleanup()
564
- raise
565
-
566
- def _load_tokenizer(self):
567
- """Load tokenizer with error handling"""
568
- try:
569
- self.tokenizer = AutoTokenizer.from_pretrained(
570
- self.cfg.model_name,
571
- use_fast=True,
572
- trust_remote_code=True
573
- )
574
- logger.info("Tokenizer loaded successfully")
575
- except Exception as e:
576
- logger.error(f"Tokenizer loading failed: {e}")
577
- raise
578
-
579
- def _load_model(self, prefer_quantized: bool):
580
- """Load model with quantization support"""
581
- device_map = "auto" if torch.cuda.is_available() else None
582
- cuda_available = torch.cuda.is_available()
583
- dtype = torch.bfloat16 if (cuda_available and self.cfg.use_bf16) else (
584
- torch.float16 if cuda_available else torch.float32
585
- )
586
-
587
- # Try quantized loading for PersianMind causal models
588
- if (self.cfg.architecture == "causal" and
589
- self._is_persianmind(self.cfg.model_name) and
590
- prefer_quantized and cuda_available):
591
-
592
- if self._try_quantized_loading(device_map, dtype):
593
- return
594
-
595
- # Standard loading
596
- self._load_standard_model(device_map, dtype)
597
-
598
- def _try_quantized_loading(self, device_map, dtype) -> bool:
599
- """Try loading model with quantization"""
600
- # Try 8-bit first
601
- try:
602
- self.model = AutoModelForCausalLM.from_pretrained(
603
- self.cfg.model_name,
604
- device_map=device_map,
605
- load_in_8bit=True,
606
- torch_dtype=dtype,
607
- trust_remote_code=True
608
- )
609
- self._setup_pad_token()
610
- logger.info("Model loaded with 8-bit quantization")
611
- return True
612
- except Exception as e:
613
- logger.warning(f"8-bit loading failed: {e}")
614
-
615
- # Try 4-bit
616
- try:
617
- self.model = AutoModelForCausalLM.from_pretrained(
618
- self.cfg.model_name,
619
- device_map=device_map,
620
- load_in_4bit=True,
621
- bnb_4bit_use_double_quant=True,
622
- bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16,
623
- torch_dtype=dtype,
624
- trust_remote_code=True
625
- )
626
- self._setup_pad_token()
627
- logger.info("Model loaded with 4-bit quantization")
628
- return True
629
- except Exception as e:
630
- logger.warning(f"4-bit loading failed: {e}")
631
-
632
- return False
633
-
634
- def _load_standard_model(self, device_map, dtype):
635
- """Load model with standard precision"""
636
- try:
637
- if self.cfg.architecture == "seq2seq":
638
- self.model = AutoModelForSeq2SeqLM.from_pretrained(
639
- self.cfg.model_name,
640
- device_map=device_map,
641
- torch_dtype=dtype,
642
- trust_remote_code=True
643
- )
644
- elif self.cfg.architecture == "causal":
645
- self.model = AutoModelForCausalLM.from_pretrained(
646
- self.cfg.model_name,
647
- device_map=device_map,
648
- torch_dtype=dtype,
649
- trust_remote_code=True
650
- )
651
- self._setup_pad_token()
652
- else:
653
- raise ValueError(f"Unsupported architecture: {self.cfg.architecture}")
654
-
655
- logger.info("Model loaded with standard precision")
656
-
657
- except Exception as e:
658
- logger.error(f"Standard model loading failed: {e}")
659
- raise
660
-
661
- def _setup_pad_token(self):
662
- """Setup pad token for causal models"""
663
- if (self.tokenizer.pad_token is None and
664
- hasattr(self.tokenizer, 'eos_token') and
665
- self.tokenizer.eos_token):
666
- self.tokenizer.pad_token = self.tokenizer.eos_token
667
-
668
- def _cleanup(self):
669
- """Clean up resources on failure"""
670
- try:
671
- if self.model and hasattr(self.model, 'cpu'):
672
- self.model.cpu()
673
- if torch.cuda.is_available():
674
- torch.cuda.empty_cache()
675
- except Exception as e:
676
- logger.warning(f"Cleanup error: {e}")
677
-
678
- # ==========================
679
- # Enhanced Generator
680
- # ==========================
681
- class UnifiedGenerator:
682
- def __init__(self, loader: ModelLoader):
683
- self.loader = loader
684
- self.tokenizer = loader.tokenizer
685
  self.model = loader.model
686
- self.cfg = loader.cfg
687
-
688
- def generate(self, question: str, context: str = "") -> Tuple[str, str]:
689
- """Generate response with comprehensive error handling"""
690
- if not question or not question.strip():
691
- return "لطفاً سوال خود را وارد کنید.", "EMPTY_QUERY"
692
-
693
- if not self.model or not self.tokenizer:
694
- return "مدل بارگذاری نشده است.", "MODEL_NOT_LOADED"
695
-
696
- start_time = time.time()
697
- try:
698
- # Sanitize inputs
699
- question = question.strip()[:self.cfg.max_input_length // 2]
700
- context = context.strip()[:self.cfg.max_input_length // 2]
701
-
702
- if self.cfg.architecture == "seq2seq":
703
- result = self._generate_seq2seq(question, context)
704
- else:
705
- result = self._generate_causal(question, context)
706
-
707
- response_time = time.time() - start_time
708
- metrics.record_request(response_time, success=True)
709
-
710
- logger.info(f"Generated response in {response_time:.2f}s")
711
- return result, ""
712
-
713
- except torch.cuda.OutOfMemoryError:
714
- error_msg = "حافظه GPU کافی نیست. لطفاً پارامترها را کاهش دهید."
715
- logger.error("CUDA out of memory error")
716
- metrics.record_request(time.time() - start_time, success=False)
717
- return error_msg, "CUDA_OOM"
718
-
719
- except Exception as e:
720
- error_msg = "خطای غیرمنتظره در تولید پاسخ رخ داد."
721
- logger.error(f"Generation error: {e}")
722
- metrics.record_request(time.time() - start_time, success=False)
723
- return error_msg, str(e)
724
-
725
- def _generate_seq2seq(self, question: str, context: str) -> str:
726
- """Generate response using seq2seq model"""
727
- input_text = f"{context}\nسوال: {question}" if context else f"سوال: {question}"
728
-
729
- inputs = self.tokenizer(
730
- input_text,
731
- return_tensors="pt",
732
- truncation=True,
733
- max_length=self.cfg.max_input_length
734
- )
735
- inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
736
-
737
- with torch.no_grad():
738
- outputs = self.model.generate(
739
- **inputs,
740
  max_length=self.cfg.max_target_length,
741
  num_beams=self.cfg.num_beams,
742
  early_stopping=True,
743
- no_repeat_ngram_size=2,
744
- do_sample=False
745
  )
746
-
747
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
748
-
749
- # Clean up response (remove input echo if present)
750
- if input_text in response:
751
- response = response.replace(input_text, "").strip()
752
-
753
- return response or "پاسخی تولید نشد."
754
-
755
- def _generate_causal(self, question: str, context: str) -> str:
756
- """Generate response using causal model"""
757
- prompt = f"{context}\nسوال: {question}\nپاسخ:" if context else f"سوال: {question}\nپاسخ:"
758
-
759
- inputs = self.tokenizer(
760
- prompt,
761
- return_tensors="pt",
762
- truncation=True,
763
- max_length=self.cfg.max_input_length
764
- )
765
- inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
766
- input_length = inputs['input_ids'].shape[1]
767
-
768
- with torch.no_grad():
769
- outputs = self.model.generate(
770
- **inputs,
771
  max_new_tokens=self.cfg.max_new_tokens,
772
  do_sample=True,
773
- temperature=max(0.1, self.cfg.temperature), # Ensure min temperature
774
  top_p=self.cfg.top_p,
775
- pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
776
- repetition_penalty=1.1,
777
- no_repeat_ngram_size=3
778
  )
779
-
780
- # Extract only the generated part
781
- generated_tokens = outputs[0][input_length:]
782
- response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
783
-
784
- # Clean up response
785
- response = response.strip()
786
- if not response:
787
- return "پاسخی تولید نشد."
788
-
789
- # Remove any remaining prompt artifacts
790
- response = response.split("سوال:")[0].strip()
791
-
792
- return response
793
 
794
  # ==========================
795
- # Enhanced Datasets
796
  # ==========================
797
  class Seq2SeqJSONLDataset(Dataset):
798
- def __init__(self, data: List[Dict], tokenizer, max_input: int, max_target: int):
799
- self.tokenizer = tokenizer
800
- self.max_input = max_input
801
- self.max_target = max_target
802
-
803
- # Filter and validate data
804
  self.items = []
805
- for item in data:
806
- src = str(item.get("input", "")).strip()
807
- tgt = str(item.get("output", "")).strip()
808
-
809
- if src and tgt and len(src) > 5 and len(tgt) > 5: # Minimum length check
810
- self.items.append((src, tgt))
811
-
812
- logger.info(f"Seq2Seq dataset created with {len(self.items)} samples")
 
 
 
 
813
 
814
  def __len__(self):
815
  return len(self.items)
816
 
817
  def __getitem__(self, idx):
818
- source_text, target_text = self.items[idx]
819
-
820
- # Tokenize inputs
821
- model_inputs = self.tokenizer(
822
- source_text,
823
- max_length=self.max_input,
824
- padding="max_length",
825
- truncation=True,
826
- return_tensors="pt"
827
- )
828
-
829
- # Tokenize targets
830
- labels = self.tokenizer(
831
- text_target=target_text,
832
- max_length=self.max_target,
833
- padding="max_length",
834
- truncation=True,
835
- return_tensors="pt"
836
- )
837
-
838
- # Convert to proper format
839
- return {
840
- "input_ids": model_inputs["input_ids"].flatten(),
841
- "attention_mask": model_inputs["attention_mask"].flatten(),
842
- "labels": labels["input_ids"].flatten()
843
- }
844
 
845
  class CausalJSONLDataset(Dataset):
846
- def __init__(self, data: List[Dict], tokenizer, max_length: int):
847
- self.tokenizer = tokenizer
848
- self.max_length = max_length
849
-
850
- # Process data
851
  self.items = []
852
- for item in data:
853
- src = str(item.get("input", "")).strip()
854
- tgt = str(item.get("output", "")).strip()
855
-
856
- if src and tgt and len(src) > 5 and len(tgt) > 5:
857
- formatted_text = f"سوال: {src}\nپاسخ: {tgt}"
858
- self.items.append(formatted_text)
859
-
860
- logger.info(f"Causal dataset created with {len(self.items)} samples")
 
 
861
 
862
  def __len__(self):
863
  return len(self.items)
864
 
865
  def __getitem__(self, idx):
866
  text = self.items[idx]
867
-
868
- encoding = self.tokenizer(
869
- text,
870
- max_length=self.max_length,
871
- padding="max_length",
872
- truncation=True,
873
- return_tensors="pt"
874
- )
875
-
876
- input_ids = encoding["input_ids"].flatten()
877
- attention_mask = encoding["attention_mask"].flatten()
878
-
879
  labels = input_ids.clone()
880
- labels[attention_mask == 0] = -100
881
-
882
- return {
883
- "input_ids": input_ids,
884
- "attention_mask": attention_mask,
885
- "labels": labels
886
- }
887
 
888
  # ==========================
889
- # Enhanced Progress Callback
890
  # ==========================
891
- class GradioProgressCallback(TrainerCallback):
892
- def __init__(self, progress: gr.Progress, status_textbox: gr.Textbox):
893
- self.progress = progress
894
- self.status_textbox = status_textbox
895
- self.total_steps = None
896
- self.start_time = None
897
- self.last_update = 0
898
-
899
- def on_train_begin(self, args, state, control, **kwargs):
900
- self.total_steps = state.max_steps
901
- self.start_time = time.time()
902
- self.progress(0, desc="آموزش شروع شد 🚀")
903
- self.status_textbox.update(value="آموزش شروع شد...")
904
-
905
- def on_step_end(self, args, state, control, **kwargs):
906
- if not self.total_steps or time.time() - self.last_update < 1.0: # Throttle updates
907
- return
908
-
909
- self.last_update = time.time()
910
-
911
- # Calculate progress
912
- progress_pct = min(100, int((state.global_step / self.total_steps) * 100))
913
-
914
- # Estimate remaining time
915
- elapsed = time.time() - self.start_time
916
- if state.global_step > 0:
917
- avg_time_per_step = elapsed / state.global_step
918
- remaining_steps = self.total_steps - state.global_step
919
- eta_seconds = avg_time_per_step * remaining_steps
920
- eta_minutes = int(eta_seconds / 60)
921
- eta_str = f" (تخمین باقی‌مانده: {eta_minutes} دقیقه)" if eta_minutes > 0 else ""
922
- else:
923
- eta_str = ""
924
-
925
- # Update progress
926
- self.progress(progress_pct, desc=f"آموزش: {progress_pct}%")
927
-
928
- # Update status with more details
929
- current_lr = state.learning_rate if hasattr(state, 'learning_rate') else args.learning_rate
930
- status_msg = (f"Step {state.global_step}/{self.total_steps} → {progress_pct}%{eta_str}\n"
931
- f"Learning Rate: {current_lr:.2e}")
932
-
933
- if hasattr(state, 'log_history') and state.log_history:
934
- last_log = state.log_history[-1]
935
- if 'train_loss' in last_log:
936
- status_msg += f"\nTrain Loss: {last_log['train_loss']:.4f}"
937
- if 'eval_loss' in last_log:
938
- status_msg += f"\nEval Loss: {last_log['eval_loss']:.4f}"
939
-
940
- self.status_textbox.update(value=status_msg)
941
-
942
- def on_evaluate(self, args, state, control, **kwargs):
943
- if hasattr(state, 'log_history') and state.log_history:
944
- last_log = state.log_history[-1]
945
- if 'eval_loss' in last_log:
946
- self.status_textbox.update(
947
- value=f"ارزیابی انجام شد - Loss: {last_log['eval_loss']:.4f}"
948
- )
949
-
950
- def on_train_end(self, args, state, control, **kwargs):
951
- total_time = time.time() - self.start_time
952
- total_minutes = int(total_time / 60)
953
-
954
- self.progress(100, desc="آموزش تکمیل شد ✅")
955
- self.status_textbox.update(
956
- value=f"آموزش با موفقیت تکمیل شد ✅\n"
957
- f"زمان کل: {total_minutes} دقیقه\n"
958
- f"کل Steps: {state.global_step}"
959
- )
960
 
961
  # ==========================
962
- # Enhanced Trainer Manager
963
  # ==========================
964
- class TrainerManager:
965
- def __init__(self, system_config: SystemConfig, model_loader: ModelLoader):
966
- self.cfg = system_config
967
- self.loader = model_loader
968
-
969
- def train(self, train_paths: List[str], extra_callbacks: List = None) -> Tuple[bool, str]:
970
- """Main training method with comprehensive error handling"""
971
- if extra_callbacks is None:
972
- extra_callbacks = []
973
-
974
- try:
975
- # Validate training files
976
- data, errors = read_jsonl_files_safe(train_paths, self.cfg)
977
-
978
- if errors:
979
- error_msg = "خطاهای فایل:\n" + "\n".join(errors[:5]) # Show first 5 errors
980
- return False, error_msg
981
-
982
- if len(data) < 10:
983
- return False, f"داده کافی نیست. حداقل 10 نمونه نیاز است (موجود: {len(data)})"
984
-
985
- # Set random seed
986
- set_seed_all(self.cfg.seed)
987
-
988
- # Split data
989
- train_data, val_data = train_test_split(
990
- data,
991
- test_size=self.cfg.train_test_ratio,
992
- random_state=self.cfg.seed
993
- )
994
-
995
- logger.info(f"Training samples: {len(train_data)}, Validation samples: {len(val_data)}")
996
-
997
- # Train based on architecture
998
- if self.cfg.model.architecture == "seq2seq":
999
- success, msg = self._train_seq2seq(train_data, val_data, extra_callbacks)
1000
- else:
1001
- success, msg = self._train_causal(train_data, val_data, extra_callbacks)
1002
-
1003
- if success:
1004
- # Save configuration
1005
- self._save_training_config()
1006
-
1007
- return success, msg
1008
-
1009
- except Exception as e:
1010
- logger.error(f"Training failed: {e}")
1011
- return False, f"خطا در آموزش: {str(e)}"
1012
-
1013
- def _train_seq2seq(self, train_data: List[Dict], val_data: List[Dict],
1014
- extra_callbacks: List) -> Tuple[bool, str]:
1015
- """Train seq2seq model"""
1016
- try:
1017
- # Create datasets
1018
- train_dataset = Seq2SeqJSONLDataset(
1019
- train_data, self.loader.tokenizer,
1020
- self.cfg.model.max_input_length,
1021
- self.cfg.model.max_target_length
1022
- )
1023
-
1024
- val_dataset = Seq2SeqJSONLDataset(
1025
- val_data, self.loader.tokenizer,
1026
- self.cfg.model.max_input_length,
1027
- self.cfg.model.max_target_length
1028
- )
1029
-
1030
- # Data collator
1031
- data_collator = DataCollatorForSeq2Seq(
1032
- tokenizer=self.loader.tokenizer,
1033
- model=self.loader.model,
1034
- padding=True
1035
- )
1036
-
1037
- # Training arguments
1038
- training_args = self._get_training_args()
1039
- training_args.predict_with_generate = True
1040
- training_args.generation_max_length = self.cfg.model.max_target_length
1041
- training_args.generation_num_beams = self.cfg.model.num_beams
1042
-
1043
- # Create trainer
1044
- trainer = Trainer(
1045
- model=self.loader.model,
1046
- args=training_args,
1047
- train_dataset=train_dataset,
1048
- eval_dataset=val_dataset,
1049
- data_collator=data_collator,
1050
- tokenizer=self.loader.tokenizer,
1051
- callbacks=self._get_callbacks(extra_callbacks)
1052
- )
1053
-
1054
- # Train
1055
- trainer.train()
1056
-
1057
- # Save model
1058
- trainer.save_model(self.cfg.output_dir)
1059
- self.loader.tokenizer.save_pretrained(self.cfg.output_dir)
1060
-
1061
- return True, "مدل Seq2Seq با موفقیت آموزش داده شد"
1062
-
1063
- except Exception as e:
1064
- logger.error(f"Seq2Seq training failed: {e}")
1065
- return False, f"خطا در آموزش Seq2Seq: {str(e)}"
1066
-
1067
- def _train_causal(self, train_data: List[Dict], val_data: List[Dict],
1068
- extra_callbacks: List) -> Tuple[bool, str]:
1069
- """Train causal language model"""
1070
- try:
1071
- # Create datasets
1072
- train_dataset = CausalJSONLDataset(
1073
- train_data, self.loader.tokenizer,
1074
- self.cfg.model.max_input_length
1075
- )
1076
-
1077
- val_dataset = CausalJSONLDataset(
1078
- val_data, self.loader.tokenizer,
1079
- self.cfg.model.max_input_length
1080
- )
1081
-
1082
- # Training arguments
1083
- training_args = self._get_training_args()
1084
-
1085
- # Create trainer
1086
- trainer = Trainer(
1087
- model=self.loader.model,
1088
- args=training_args,
1089
- train_dataset=train_dataset,
1090
- eval_dataset=val_dataset,
1091
- tokenizer=self.loader.tokenizer,
1092
- callbacks=self._get_callbacks(extra_callbacks)
1093
- )
1094
-
1095
- # Train
1096
- trainer.train()
1097
-
1098
- # Save model
1099
- trainer.save_model(self.cfg.output_dir)
1100
- self.loader.tokenizer.save_pretrained(self.cfg.output_dir)
1101
 
1102
- return True, "مدل Causal با موفقیت آموزش داده شد"
 
 
 
1103
 
1104
- except Exception as e:
1105
- logger.error(f"Causal training failed: {e}")
1106
- return False, f"خطا در آموزش Causal: {str(e)}"
1107
-
1108
- def _get_training_args(self) -> TrainingArguments:
1109
- """Get training arguments with optimized settings"""
1110
- return TrainingArguments(
1111
- output_dir=self.cfg.output_dir,
1112
- num_train_epochs=self.cfg.epochs,
1113
- learning_rate=self.cfg.lr,
1114
- per_device_train_batch_size=self.cfg.batch_size,
1115
- per_device_eval_batch_size=self.cfg.batch_size,
1116
- gradient_accumulation_steps=self.cfg.grad_accum,
1117
- warmup_ratio=0.05,
1118
- weight_decay=0.01,
1119
- evaluation_strategy="epoch",
1120
- eval_steps=500,
1121
- save_strategy="epoch",
1122
- save_total_limit=3, # Keep more checkpoints
1123
  load_best_model_at_end=True,
1124
  metric_for_best_model="eval_loss",
1125
- greater_is_better=False,
1126
- logging_steps=50,
1127
- logging_dir=f"{self.cfg.output_dir}/logs",
1128
- report_to="none",
1129
- bf16=self.cfg.model.use_bf16 if torch.cuda.is_available() else False,
1130
- fp16=(not self.cfg.model.use_bf16) if torch.cuda.is_available() else False,
1131
- dataloader_drop_last=True,
1132
- remove_unused_columns=False,
1133
- gradient_checkpointing=True, # Save memory
 
1134
  )
1135
-
1136
- def _get_callbacks(self, extra_callbacks: List) -> List:
1137
- """Get training callbacks"""
1138
- callbacks = [
1139
- EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01)
1140
- ]
1141
- callbacks.extend(extra_callbacks)
1142
- return callbacks
1143
-
1144
- def _save_training_config(self):
1145
- """Save training configuration"""
1146
- try:
1147
- config_path = Path(self.cfg.output_dir) / "training_config.json"
1148
- config_dict = self.cfg.dict()
1149
- config_dict['training_timestamp'] = datetime.now().isoformat()
1150
- config_dict['training_completed'] = True
1151
-
1152
- with open(config_path, 'w', encoding='utf-8') as f:
1153
- json.dump(config_dict, f, ensure_ascii=False, indent=2)
1154
-
1155
- logger.info(f"Training config saved to {config_path}")
1156
- except Exception as e:
1157
- logger.warning(f"Failed to save training config: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1158
 
1159
  # ==========================
1160
- # Enhanced Legal App
1161
  # ==========================
1162
  class LegalApp:
1163
- def __init__(self, system_config: Optional[SystemConfig] = None):
1164
- self.cfg = system_config or SystemConfig()
1165
- self.rag = LegalRAGSystem(self.cfg)
1166
- self.formalizer: Optional[Formalizer] = None
1167
- self._current_loader: Optional[ModelLoader] = None
1168
- self._current_generator: Optional[UnifiedGenerator] = None
1169
- self._lock = threading.Lock()
1170
-
1171
- def _ensure_model(self, model_name: str, architecture: str) -> Tuple[bool, str]:
1172
- """Ensure model is loaded with thread safety"""
1173
- with self._lock:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1174
  try:
1175
- # Update config
1176
- self.cfg.model.model_name = model_name
1177
- self.cfg.model.architecture = architecture
1178
-
1179
- # Get model from cache
1180
- self._current_loader = ModelCache.get_model(model_name, architecture, self.cfg.model)
1181
- self._current_generator = UnifiedGenerator(self._current_loader)
1182
-
1183
- return True, f"مدل بارگذاری شد: {model_name} ({architecture})"
1184
-
1185
  except Exception as e:
1186
- logger.error(f"Model loading failed: {e}")
1187
- return False, f"خطا در بارگذاری مدل: {str(e)}"
1188
-
1189
- def _ensure_rag(self) -> Tuple[bool, str]:
1190
- """Ensure RAG system is ready"""
1191
- try:
1192
- self.rag.setup_embedding()
1193
- success, message = self.rag.load_chroma()
1194
- return success, message
1195
- except Exception as e:
1196
- logger.error(f"RAG setup failed: {e}")
1197
- return False, f"خطا در راه‌اندازی RAG: {str(e)}"
1198
-
1199
- def _ensure_formalizer(self) -> str:
1200
- """Ensure formalizer is ready"""
1201
- try:
1202
- if not self.formalizer:
1203
- self.formalizer = Formalizer()
1204
- return "Formalizer آماده است."
1205
- except Exception as e:
1206
- logger.error(f"Formalizer setup failed: {e}")
1207
- return f"خطا در راه‌اندازی Formalizer: {str(e)}"
1208
-
1209
- # Event handlers
1210
- def handle_load_model(self, model_choice: str, use_rag: bool) -> str:
1211
- """Handle model loading"""
1212
- try:
1213
- model_configs = self._get_model_configs()
1214
- if model_choice not in model_configs:
1215
- return "مدل نامعتبر انتخاب شده"
1216
 
1217
- model_name, architecture = model_configs[model_choice]
1218
-
1219
- # Load model
1220
- success, model_msg = self._ensure_model(model_name, architecture)
1221
- if not success:
1222
- return model_msg
1223
-
1224
- # Setup RAG if requested
1225
- rag_msg = ""
1226
- if use_rag:
1227
- rag_success, rag_msg = self._ensure_rag()
1228
- rag_msg = f"\nRAG: {rag_msg}"
1229
- else:
1230
- rag_msg = "\nRAG: غیر فعال"
1231
-
1232
- return f"{model_msg}{rag_msg}"
1233
-
1234
- except Exception as e:
1235
- logger.error(f"Model loading handler failed: {e}")
1236
- return f"خطا در بارگذاری: {str(e)}"
1237
-
1238
- def handle_generate_response(self, question: str, use_rag: bool, use_formalizer: bool,
1239
- max_new_tokens: int, temperature: float, top_p: float,
1240
- num_beams: int) -> Tuple[str, str, str]: # response, references, metrics
1241
- """Handle response generation"""
1242
- if not question or not question.strip():
1243
- return "لطفاً سوال خود را وارد کنید.", "", ""
1244
-
1245
- if not self._current_generator:
1246
- return "ابتدا مدل را بارگذاری کنید.", "", ""
1247
-
1248
- start_time = time.time()
1249
 
 
 
 
1250
  try:
1251
- # Update generation parameters
1252
- self.cfg.model.max_new_tokens = max(32, min(1024, int(max_new_tokens)))
1253
- self.cfg.model.temperature = max(0.1, min(2.0, float(temperature)))
1254
- self.cfg.model.top_p = max(0.1, min(1.0, float(top_p)))
1255
- self.cfg.model.num_beams = max(1, min(8, int(num_beams)))
1256
-
1257
- # Apply input formalization if requested
1258
- processed_question = question
1259
- if use_formalizer:
1260
- formalizer_msg = self._ensure_formalizer()
1261
- if "خطا" not in formalizer_msg and self.formalizer:
1262
- processed_question = self.formalizer.formalize(question)
1263
-
1264
- # Retrieve relevant articles if RAG is enabled
1265
- articles = []
1266
- if use_rag and self.rag.collection:
1267
- articles = self.rag.retrieve(processed_question)
1268
-
1269
- # Build context
1270
- context = LegalRAGSystem.build_context(articles) if articles else ""
1271
-
1272
- # Generate response
1273
- response, error = self._current_generator.generate(processed_question, context)
1274
-
1275
- # Build references section
1276
- references = ""
1277
- if articles:
1278
- ref_parts = []
1279
- for article in articles[:3]: # Limit to top 3 references
1280
- ref_parts.append(
1281
- f"**ماده {article['article_id']}** (شباهت: {article['similarity']:.2f})\n"
1282
- f"{article['text'][:400]}{'...' if len(article['text']) > 400 else ''}"
1283
- )
1284
- references = "\n\n".join(ref_parts)
1285
-
1286
- # Generate metrics info
1287
- elapsed_time = time.time() - start_time
1288
- metrics_info = f"زمان پردازش: {elapsed_time:.2f}s"
1289
- if articles:
1290
- metrics_info += f" | مواد یافت شده: {len(articles)}"
1291
- if use_formalizer:
1292
- metrics_info += " | فرمالایزر فعال"
1293
-
1294
- return response, references, metrics_info
1295
-
1296
  except Exception as e:
1297
- logger.error(f"Response generation failed: {e}")
1298
- error_time = time.time() - start_time
1299
- metrics.record_request(error_time, success=False)
1300
- return f"خطا در تولید پاسخ: {str(e)}", "", f"خطا پس از {error_time:.2f}s"
1301
-
1302
- def handle_training(self, model_choice: str, uploaded_files, use_rag_training: bool,
1303
- epochs: int, batch_size: int, learning_rate: float,
1304
- progress: gr.Progress, status_textbox: gr.Textbox) -> str:
1305
- """Handle model training"""
1306
- try:
1307
- # Validate inputs
1308
- if not uploaded_files:
1309
- return "لطفاً فایل‌های آموزشی را بارگذاری کنید."
1310
-
1311
- # Get model config
1312
- model_configs = self._get_model_configs()
1313
- if model_choice not in model_configs:
1314
- return "مدل نامعتبر انتخاب شده"
1315
-
1316
- model_name, architecture = model_configs[model_choice]
1317
-
1318
- # Load model for training
1319
- success, msg = self._ensure_model(model_name, architecture)
1320
- if not success:
1321
- return f"خطا در بارگذاری مدل: {msg}"
1322
-
1323
- # Update training config
1324
- self.cfg.epochs = max(1, min(10, int(epochs)))
1325
- self.cfg.batch_size = max(1, min(16, int(batch_size)))
1326
- self.cfg.lr = max(1e-6, min(1e-3, float(learning_rate)))
1327
-
1328
- # Setup RAG if requested
1329
- if use_rag_training:
1330
- rag_success, rag_msg = self._ensure_rag()
1331
- if not rag_success:
1332
- logger.warning(f"RAG setup failed for training: {rag_msg}")
1333
-
1334
- # Get file paths (gr.File with type="filepath" returns list[str])
1335
- file_paths = uploaded_files
1336
-
1337
- if not file_paths:
1338
- return "فایل‌های معتبر یافت نشد."
1339
-
1340
- # Create trainer
1341
- trainer_manager = TrainerManager(self.cfg, self._current_loader)
1342
-
1343
- # Create progress callback
1344
- progress_callback = GradioProgressCallback(progress, status_textbox)
1345
-
1346
- # Start training
1347
- success, result_msg = trainer_manager.train(file_paths, [progress_callback])
1348
-
1349
- if success:
1350
- # Clear model cache to force reload of trained model
1351
- ModelCache.clear_cache()
1352
- return f"✅ {result_msg}\nمدل در مسیر '{self.cfg.output_dir}' ذخیره شد."
1353
- else:
1354
- return f"❌ {result_msg}"
1355
-
1356
- except Exception as e:
1357
- logger.error(f"Training handler failed: {e}")
1358
- return f"خطا در آموزش: {str(e)}"
1359
-
1360
- def get_system_status(self) -> str:
1361
- """Get system status information"""
1362
- try:
1363
- status_parts = []
1364
-
1365
- # Model status
1366
- if self._current_loader:
1367
- status_parts.append(f"✅ مدل فعال: {self.cfg.model.model_name}")
1368
- else:
1369
- status_parts.append("❌ مدل بارگذاری نشده")
1370
-
1371
- # RAG status
1372
- if self.rag.collection:
1373
- doc_count = self.rag.collection.count()
1374
- status_parts.append(f"✅ RAG فعال ({doc_count} سند)")
1375
- else:
1376
- status_parts.append("❌ RAG غیر فعال")
1377
-
1378
- # System metrics
1379
- sys_metrics = metrics.get_metrics()
1380
- status_parts.append(f"📊 درخواست‌ها: {sys_metrics['requests_total']}")
1381
- status_parts.append(f"📈 نرخ موفقیت: {sys_metrics['success_rate']:.1f}%")
1382
- status_parts.append(f"⏱️ زمان متوسط: {sys_metrics['avg_response_time']}s")
1383
-
1384
- if torch.cuda.is_available():
1385
- memory_mb = torch.cuda.memory_allocated() / 1024 / 1024
1386
- status_parts.append(f"🖥️ حافظه GPU: {memory_mb:.1f} MB")
1387
-
1388
- return "\n".join(status_parts)
1389
 
1390
- except Exception as e:
1391
- return f"خطا در دریافت وضعیت: {str(e)}"
1392
 
1393
- def _get_model_configs(self) -> Dict[str, Tuple[str, str]]:
1394
- """Get available model configurations"""
1395
- return {
1396
- "Seq2Seq (parsi-t5-base)": ("persiannlp/parsi-t5-base", "seq2seq"),
1397
  "Seq2Seq (mt5-base)": ("google/mt5-base", "seq2seq"),
1398
- "Causal (Mistral-7B)": ("mistralai/Mistral-7B-Instruct-v0.2", "causal"),
1399
- "Causal (PersianMind-v1.0)": ("universitytehran/PersianMind-v1.0", "causal"),
 
1400
  }
1401
 
1402
- def build_ui(self) -> gr.Blocks:
1403
- """Build enhanced Gradio interface"""
1404
- model_choices = list(self._get_model_configs().keys())
1405
-
1406
- with gr.Blocks(
1407
- title="ماحون — مشاور حقوقی هوشمند",
1408
- theme=gr.themes.Soft(),
1409
- css="""
1410
- .status-box { font-family: 'Courier New', monospace; font-size: 12px; }
1411
- .metrics-box { background-color: #f0f0f0; padding: 10px; border-radius: 5px; }
1412
- """
1413
- ) as app:
1414
-
1415
  gr.HTML("""
1416
- <div style='text-align: center; margin-bottom: 20px;'>
1417
- <h1>ماحون — مشاور حقوقی هوشمند 🏛️</h1>
1418
- <p>سیستم پیشرفته مشاوره حقوقی با قابلیت RAG، Fine-tuning و هوش مصنوعی</p>
1419
  </div>
1420
  """)
1421
 
1422
- # System Status
1423
- with gr.Accordion("وضعیت سیستم", open=False):
1424
- system_status = gr.Markdown(
1425
- value=self.get_system_status(),
1426
- elem_classes=["status-box"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1427
  )
1428
- refresh_status_btn = gr.Button("🔄 بروزرسانی وضعیت", size="sm")
1429
-
1430
- with gr.Tabs() as tabs:
1431
- # Consultation Tab
1432
- with gr.Tab("💬 مشاوره") as advice_tab:
1433
- with gr.Row():
1434
- with gr.Column(scale=2):
1435
- model_dropdown = gr.Dropdown(
1436
- choices=model_choices,
1437
- value=model_choices[0],
1438
- label="انتخاب مدل",
1439
- info="نوع مدل مورد نظر را انتخاب کنید"
1440
- )
1441
- with gr.Column(scale=1):
1442
- use_rag_checkbox = gr.Checkbox(
1443
- value=True,
1444
- label="استفاده از RAG",
1445
- info="بازیابی مواد قانونی مرتبط"
1446
- )
1447
- use_formalizer_checkbox = gr.Checkbox(
1448
- value=False,
1449
- label="رسمی‌سازی ورودی",
1450
- info="تبدیل متن غیررسمی به رسمی"
1451
- )
1452
-
1453
- load_model_btn = gr.Button("🚀 بارگذاری مدل/RAG", variant="primary", size="lg")
1454
- load_status = gr.Textbox(
1455
- label="وضعیت بارگذاری",
1456
- interactive=False,
1457
- elem_classes=["status-box"]
1458
- )
1459
-
1460
- # Generation Parameters
1461
- with gr.Accordion("⚙️ پارامترهای تولید", open=False):
1462
- with gr.Row():
1463
- max_new_tokens = gr.Slider(
1464
- minimum=32, maximum=1024, value=self.cfg.model.max_new_tokens,
1465
- step=16, label="حداکثر توکن‌های جدید"
1466
- )
1467
- temperature = gr.Slider(
1468
- minimum=0.1, maximum=2.0, value=self.cfg.model.temperature,
1469
- step=0.05, label="دما (خلاقیت)"
1470
- )
1471
- with gr.Row():
1472
- top_p = gr.Slider(
1473
- minimum=0.1, maximum=1.0, value=self.cfg.model.top_p,
1474
- step=0.05, label="Top-p (تنوع)"
1475
- )
1476
- num_beams = gr.Slider(
1477
- minimum=1, maximum=8, value=self.cfg.model.num_beams,
1478
- step=1, label="تعداد Beam"
1479
- )
1480
-
1481
- # Input/Output
1482
- with gr.Row():
1483
- with gr.Column(scale=1):
1484
- question_input = gr.Textbox(
1485
- label="سوال حقوقی خود را وارد کنید",
1486
- placeholder="مثال: شرایط فسخ قرارداد اجاره چیست؟",
1487
- lines=3
1488
- )
1489
- submit_btn = gr.Button("🔍 دریافت پاسخ", variant="primary")
1490
- with gr.Column(scale=1):
1491
- response_output = gr.Textbox(
1492
- label="پاسخ سیستم",
1493
- lines=8,
1494
- interactive=False
1495
- )
1496
- references_output = gr.Textbox(
1497
- label="مراجع حقوقی مرتبط",
1498
- lines=6,
1499
- interactive=False
1500
- )
1501
- metrics_output = gr.Textbox(
1502
- label="معیارهای عملکرد",
1503
- lines=1,
1504
- interactive=False,
1505
- elem_classes=["metrics-box"]
1506
- )
1507
-
1508
- # Training Tab
1509
- with gr.Tab("🎓 آموزش مدل") as training_tab:
1510
- with gr.Row():
1511
- with gr.Column(scale=1):
1512
- train_model_dropdown = gr.Dropdown(
1513
- choices=model_choices,
1514
- value=model_choices[0],
1515
- label="انتخاب مدل برای آموزش"
1516
- )
1517
- use_rag_training_checkbox = gr.Checkbox(
1518
- value=True,
1519
- label="استفاده از RAG در آموزش",
1520
- info="استفاده از مواد قانونی در آموزش"
1521
- )
1522
- train_file_upload = gr.File(
1523
- label="بارگذاری فایل‌ها�� آموزشی (JSONL)",
1524
- file_types=[".jsonl"],
1525
- type="filepath",
1526
- file_count="multiple"
1527
- )
1528
- with gr.Column(scale=1):
1529
- with gr.Accordion("⚙️ پارامترهای آموزش", open=False):
1530
- train_epochs = gr.Slider(
1531
- minimum=1, maximum=10, value=self.cfg.epochs,
1532
- step=1, label="تعداد Epoch"
1533
- )
1534
- train_batch_size = gr.Slider(
1535
- minimum=1, maximum=16, value=self.cfg.batch_size,
1536
- step=1, label="اندازه Batch"
1537
- )
1538
- train_lr = gr.Slider(
1539
- minimum=1e-6, maximum=1e-3, value=self.cfg.lr,
1540
- step=1e-5, label="نرخ یادگیری"
1541
- )
1542
-
1543
- train_btn = gr.Button("🎯 شروع آموزش", variant="primary")
1544
- train_status = gr.Textbox(
1545
- label="وضعیت آموزش",
1546
- interactive=False,
1547
- elem_classes=["status-box"]
1548
- )
1549
- train_progress = gr.Progress(label="پیشرفت آموزش")
1550
-
1551
- # Event handlers
1552
- load_model_btn.click(
1553
- fn=lambda m, r: self.handle_load_model(m, r),
1554
- inputs=[model_dropdown, use_rag_checkbox],
1555
- outputs=load_status
1556
- )
1557
-
1558
- submit_btn.click(
1559
- fn=lambda q, r, f, m, t, p, b: self.handle_generate_response(
1560
- q, r, f, m, t, p, b
1561
- ),
1562
- inputs=[
1563
- question_input,
1564
- use_rag_checkbox,
1565
- use_formalizer_checkbox,
1566
- max_new_tokens,
1567
- temperature,
1568
- top_p,
1569
- num_beams
1570
- ],
1571
- outputs=[response_output, references_output, metrics_output]
1572
- )
1573
-
1574
- refresh_status_btn.click(
1575
- fn=lambda: self.get_system_status(),
1576
- outputs=system_status
1577
- )
1578
 
1579
  train_btn.click(
1580
- fn=lambda m, f, r, e, b, lr, p, s: self.handle_training(
1581
- m, f, r, e, b, lr, p, s
1582
- ),
1583
- inputs=[
1584
- train_model_dropdown,
1585
- train_file_upload,
1586
- use_rag_training_checkbox,
1587
- train_epochs,
1588
- train_batch_size,
1589
- train_lr,
1590
- train_progress,
1591
- train_status
1592
- ],
1593
  outputs=train_status
1594
  )
1595
-
1596
  return app
1597
 
1598
  # ==========================
1599
- # Main Application
1600
  # ==========================
1601
- def main():
1602
- """Main entry point for the application"""
1603
- # Initialize system
1604
  app = LegalApp()
1605
-
1606
- # Build and launch UI
1607
  ui = app.build_ui()
1608
- ui.launch(
1609
- server_name="0.0.0.0",
1610
- server_port=7860,
1611
- inbrowser=True,
1612
- share=False
1613
- )
1614
-
1615
- if __name__ == "__main__":
1616
- main()
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ Mahoun Legal AI (RAG + Training + Metrics) for HF Spaces / Gradio 5
4
+ - سازگار با Gradio 5.x و Transformers >= 4.44
5
+ - TrainingArguments ایمن با عقب‌سازگاری (safe_training_args)
6
+ - RAG با ChromaDB + ایندکس‌سازی JSONL قوانین
7
+ - متریک‌ها: ROUGE-L (seq2seq) و F1 ساده (causal)
8
+ - ماسک پدینگ روی labels در معماری علّی
9
+ - Progress به‌صورت DI: progress=gr.Progress(track_tqdm=True)
10
+
11
+ ساختار ورودی دیتاست آموزش:
12
+ JSONL با کلیدهای "input" و "output"
13
+
14
+ ساختار ورودی قوانین برای ایندکس:
15
+ JSONL با کلیدهای (پیش‌فرض) "article_id" و "text"
16
  """
17
 
18
  from __future__ import annotations
19
+ import os, sys, json, warnings
 
 
 
 
 
 
20
  from dataclasses import dataclass, field
21
  from pathlib import Path
22
+ from typing import List, Dict, Optional, Tuple
 
 
 
23
 
24
+ import numpy as np
25
  import torch
26
  from torch.utils.data import Dataset
27
  from sklearn.model_selection import train_test_split
 
28
 
29
+ import gradio as gr
30
+ from packaging import version
31
+
32
+ import transformers as tf
33
  from transformers import (
34
  AutoTokenizer,
35
  AutoModelForSeq2SeqLM,
 
38
  TrainingArguments,
39
  EarlyStoppingCallback,
40
  DataCollatorForSeq2Seq,
 
41
  )
42
 
43
+ # RAG stack
44
  import chromadb
45
  from sentence_transformers import SentenceTransformer
 
46
 
47
+ # Optional metrics
48
+ try:
49
+ from evaluate import load as eval_load
50
+ except Exception:
51
+ eval_load = None
52
 
53
+ warnings.filterwarnings("ignore")
 
 
 
 
 
54
 
55
  # ==========================
56
+ # Config
57
  # ==========================
58
+ @dataclass
59
+ class ModelConfig:
60
+ model_name: str = "google/mt5-base"
61
+ architecture: str = "seq2seq" # "seq2seq" | "causal"
62
+ max_input_length: int = 1024
63
+ max_target_length: int = 512
64
+ max_new_tokens: int = 384
65
+ temperature: float = 0.7
66
+ top_p: float = 0.9
67
+ num_beams: int = 4
68
+ gradient_checkpointing: bool = True
 
 
 
 
 
 
 
 
69
 
70
+ @dataclass
71
+ class RAGConfig:
72
  embedding_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
73
+ persist_dir: str = "./chroma_db"
74
+ collection: str = "legal_articles"
75
+ top_k: int = 5
76
+ similarity_threshold: float = 0.66 # 0..1
77
+ context_char_limit: int = 300
78
+ enable: bool = True
79
+
80
+ @dataclass
81
+ class TrainConfig:
82
+ output_dir: str = "./mahoon_model"
83
  seed: int = 42
84
+ test_size: float = 0.1
85
+ epochs: int = 3
86
+ batch_size: int = 2
87
+ grad_accum: int = 2
88
+ lr: float = 3e-5
89
+ use_bf16: bool = True
90
+ weight_decay: float = 0.01
91
+ warmup_ratio: float = 0.05
92
+ logging_steps: int = 50
93
+ eval_strategy: str = "epoch" # "steps" | "epoch"
94
+ save_strategy: str = "epoch"
95
+ save_total_limit: int = 2
96
+ report_to: str = "none" # "none" | "wandb"
97
+ max_grad_norm: float = 1.0
98
 
 
 
 
99
  @dataclass
100
+ class SystemConfig:
101
+ model: ModelConfig = field(default_factory=ModelConfig)
102
+ rag: RAGConfig = field(default_factory=RAGConfig)
103
+ train: TrainConfig = field(default_factory=TrainConfig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  # ==========================
106
+ # Utils
107
  # ==========================
108
  def set_seed_all(seed: int = 42):
109
  import random
110
  random.seed(seed)
111
+ np.random.seed(seed)
112
  torch.manual_seed(seed)
113
+ if torch.cuda.is_available():
114
+ torch.cuda.manual_seed_all(seed)
115
 
116
+ def log_deps():
 
117
  try:
118
+ import accelerate, datasets
119
+ print("[deps]",
120
+ f"python={sys.version.split()[0]}",
121
+ f"transformers={tf.__version__}",
122
+ f"accelerate={accelerate.__version__}",
123
+ f"datasets={datasets.__version__}",
124
+ f"gradio={gr.__version__}",
125
+ flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  except Exception as e:
127
+ print("[deps] warn:", e, flush=True)
128
+
129
+ def bf16_supported():
130
+ return torch.cuda.is_available() and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported()
131
+
132
+ def safe_training_args(**kwargs):
133
+ """
134
+ Wrapper برای سازگاری با نسخه‌های قدیمی‌تر Transformers (قبل از 4.4):
135
+ - evaluation_strategy -> evaluate_during_training
136
+ - حذف کلیدهای جدید که ممکن است ناشناخته باشند
137
+ """
138
+ tf_ver = version.parse(tf.__version__)
139
+ k = dict(kwargs)
140
+ if tf_ver < version.parse("4.4.0"):
141
+ eval_strat = k.pop("evaluation_strategy", None)
142
+ k["evaluate_during_training"] = bool(eval_strat and str(eval_strat).lower() != "no")
143
+ for rm in ["save_strategy","load_best_model_at_end","metric_for_best_model",
144
+ "greater_is_better","predict_with_generate","generation_max_length",
145
+ "generation_num_beams","report_to","max_grad_norm"]:
146
+ k.pop(rm, None)
147
+ return TrainingArguments(**k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  # ==========================
150
+ # RAG
151
  # ==========================
152
+ class LegalRAG:
153
+ def __init__(self, cfg: RAGConfig):
154
  self.cfg = cfg
 
155
  self.client = None
156
  self.collection = None
157
+ self.embedder: Optional[SentenceTransformer] = None
158
 
159
+ def init(self):
160
+ Path(self.cfg.persist_dir).mkdir(parents=True, exist_ok=True)
161
+ self.client = chromadb.PersistentClient(path=self.cfg.persist_dir)
 
162
  try:
163
+ self.collection = self.client.get_or_create_collection(self.cfg.collection)
164
+ except Exception:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  try:
166
+ self.collection = self.client.get_collection(self.cfg.collection)
167
+ except Exception:
168
+ self.collection = self.client.create_collection(self.cfg.collection)
169
+ self.embedder = SentenceTransformer(self.cfg.embedding_model)
170
+
171
+ def index_jsonl(self, jsonl_path: str, id_key="article_id", text_key="text"):
172
+ """ایندکس‌سازی اولیه قوانین از JSONL: هر خط یک شیء {article_id, text, ...}."""
173
+ if not self.collection or not self.embedder:
174
+ self.init()
175
+ ids, docs, metas = [], [], []
176
+ with open(jsonl_path, "r", encoding="utf-8") as f:
177
+ for i, line in enumerate(f):
178
+ s = line.strip()
179
+ if not s:
180
+ continue
181
  try:
182
+ obj = json.loads(s)
183
+ except:
184
+ continue
185
+ aid = str(obj.get(id_key, f"auto_{i}"))
186
+ txt = str(obj.get(text_key, "")).strip()
187
+ if not txt:
188
+ continue
189
+ ids.append(aid)
190
+ docs.append(txt)
191
+ metas.append({"article_id": aid})
192
+ if not ids:
193
+ return "هیچ سندی برای ایندکس پیدا نشد."
194
+ self.collection.upsert(ids=ids, documents=docs, metadatas=metas)
195
+ return f" {len(ids)} سند قانونی ایندکس شد."
 
196
 
197
  def retrieve(self, query: str) -> List[Dict]:
198
+ if not self.collection:
199
  return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  try:
201
+ res = self.collection.query(
202
+ query_texts=[query],
203
+ n_results=self.cfg.top_k,
204
+ include=["documents","metadatas","distances"],
205
+ )
206
+ out = []
207
+ docs = res.get("documents", [[]])[0]
208
+ metas = res.get("metadatas", [[]])[0]
209
+ dists = res.get("distances", [[1.0]])[0]
210
+ for i, (doc, meta, dist) in enumerate(zip(docs, metas, dists)):
211
+ sim = 1.0 - float(dist)
212
+ if sim >= self.cfg.similarity_threshold:
213
+ out.append({
214
+ "article_id": (meta or {}).get("article_id", f"unk_{i}"),
215
+ "text": doc,
216
+ "similarity": sim,
217
+ })
218
+ return out
219
+ except Exception:
220
+ return []
 
221
 
222
+ def build_context(self, arts: List[Dict]) -> str:
223
+ if not arts:
224
+ return ""
225
+ bullets = [f"• ماده {a['article_id']}: {a['text'][:self.cfg.context_char_limit]}..." for a in arts]
226
+ return "مواد مرتبط:\n" + "\n".join(bullets)
227
 
228
  # ==========================
229
+ # Loader + Generator
230
  # ==========================
231
class ModelLoader:
    """Loads a tokenizer/model pair for the configured architecture."""

    def __init__(self, mcfg: ModelConfig):
        self.cfg = mcfg
        self.tokenizer = None
        self.model = None

    def load(self):
        """Instantiate tokenizer and model, pick a dtype, enable checkpointing.

        Returns ``self`` so the call can be chained
        (``ModelLoader(cfg).load()``).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name)

        # Dtype selection. NOTE(review): bf16 is only chosen when gradient
        # checkpointing is also enabled — confirm this coupling is intended
        # rather than an accident.
        cuda_available = torch.cuda.is_available()
        prefer_bf16 = bf16_supported() and self.cfg.gradient_checkpointing
        if prefer_bf16:
            dtype = torch.bfloat16
        elif cuda_available:
            dtype = torch.float16
        else:
            dtype = None
        load_kwargs = {"torch_dtype": dtype}
        if cuda_available:
            load_kwargs["device_map"] = "auto"

        arch = self.cfg.architecture
        if arch == "seq2seq":
            self.model = AutoModelForSeq2SeqLM.from_pretrained(self.cfg.model_name, **load_kwargs)
        elif arch == "causal":
            self.model = AutoModelForCausalLM.from_pretrained(self.cfg.model_name, **load_kwargs)
            # Decoder-only models often ship without a pad token; reuse EOS.
            if self.tokenizer.pad_token is None and hasattr(self.tokenizer, "eos_token"):
                self.tokenizer.pad_token = self.tokenizer.eos_token
        else:
            raise ValueError("Unsupported architecture")

        if self.cfg.gradient_checkpointing and hasattr(self.model, "gradient_checkpointing_enable"):
            try:
                self.model.gradient_checkpointing_enable()
            except Exception:
                # Best-effort: not every model implements checkpointing.
                pass
        return self
261
+
262
class Generator:
    """Inference wrapper around a loaded tokenizer/model pair.

    Dispatches on ``cfg.architecture``: beam search for seq2seq models,
    sampling for causal models.
    """

    def __init__(self, loader: ModelLoader, mcfg: ModelConfig):
        self.tk = loader.tokenizer
        self.model = loader.model
        self.cfg = mcfg

    def _encode(self, text: str):
        """Tokenize *text* and move all tensors onto the model's device."""
        enc = self.tk(text, return_tensors="pt", truncation=True,
                      max_length=self.cfg.max_input_length)
        return {k: v.to(self.model.device) for k, v in enc.items()}

    def generate(self, question: str, context: str = "") -> str:
        """Generate an answer for *question*, optionally grounded in *context*.

        Returns the decoded answer text only (the prompt is not echoed).
        """
        if self.cfg.architecture == "seq2seq":
            inp = f"{context}\nسوال: {question}" if context else f"سوال: {question}"
            enc = self._encode(inp)
            with torch.no_grad():  # inference only: skip autograd bookkeeping
                out = self.model.generate(
                    **enc,
                    max_length=self.cfg.max_target_length,
                    num_beams=self.cfg.num_beams,
                    early_stopping=True,
                )
            return self.tk.decode(out[0], skip_special_tokens=True)

        prompt = f"{context}\nسوال: {question}\nپاسخ:" if context else f"سوال: {question}\nپاسخ:"
        enc = self._encode(prompt)
        with torch.no_grad():
            out = self.model.generate(
                **enc,
                max_new_tokens=self.cfg.max_new_tokens,
                do_sample=True,
                temperature=self.cfg.temperature,
                top_p=self.cfg.top_p,
                pad_token_id=self.tk.pad_token_id or self.tk.eos_token_id,
            )
        # Bug fix: decoder-only generate() returns prompt + completion;
        # decoding out[0] wholesale made the UI repeat the question.
        # Slice off the prompt tokens and decode only the new ones.
        new_tokens = out[0][enc["input_ids"].shape[1]:]
        return self.tk.decode(new_tokens, skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
  # ==========================
294
+ # Datasets
295
  # ==========================
296
class Seq2SeqJSONLDataset(Dataset):
    """(input, target) pairs for seq2seq fine-tuning, read from JSONL dicts.

    Rows missing either ``input`` or ``output`` are dropped. When a RAG
    instance is supplied, every ``enhance_every``-th kept row is wrapped in
    <CONTEXT>/<QUESTION> tags with retrieved legal context.
    """

    def __init__(self, data: List[Dict], tokenizer, max_inp: int, max_tgt: int,
                 rag: Optional[LegalRAG] = None, enhance_every: int = 10):
        self.tk = tokenizer
        self.max_inp = max_inp
        self.max_tgt = max_tgt
        self.items = []
        for idx, example in enumerate(data):
            question = str(example.get("input", "")).strip()
            answer = str(example.get("output", "")).strip()
            if not (question and answer):
                continue  # skip incomplete rows
            source = question
            if rag and idx % enhance_every == 0:
                context = rag.build_context(rag.retrieve(question))
                if context:
                    source = f"<CONTEXT>{context}</CONTEXT>\n<QUESTION>{question}</QUESTION>"
            self.items.append((source, answer))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        source, answer = self.items[idx]
        features = self.tk(source, max_length=self.max_inp,
                           padding="max_length", truncation=True)
        target = self.tk(text_target=answer, max_length=self.max_tgt,
                         padding="max_length", truncation=True)
        features["labels"] = target["input_ids"]
        return features
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
class CausalJSONLDataset(Dataset):
    """Prompt+answer strings for causal-LM fine-tuning, read from JSONL dicts.

    Rows missing either ``input`` or ``output`` are dropped. When a RAG
    instance is supplied, every ``enhance_every``-th kept row is prefixed
    with retrieved legal context.
    """

    def __init__(self, data: List[Dict], tokenizer, max_inp: int,
                 rag: Optional[LegalRAG] = None, enhance_every: int = 10):
        self.tk = tokenizer
        self.max_inp = max_inp
        self.items = []
        for idx, example in enumerate(data):
            question = str(example.get("input", "")).strip()
            answer = str(example.get("output", "")).strip()
            if not (question and answer):
                continue  # skip incomplete rows
            context = ""
            if rag and idx % enhance_every == 0:
                context = rag.build_context(rag.retrieve(question))
            if context:
                sample = f"{context}\nسوال: {question}\nپاسخ: {answer}"
            else:
                sample = f"سوال: {question}\nپاسخ: {answer}"
            self.items.append(sample)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        sample = self.items[idx]
        enc = self.tk(sample, max_length=self.max_inp,
                      padding="max_length", truncation=True)
        input_ids = torch.tensor(enc["input_ids"])
        attn = torch.tensor(enc["attention_mask"])
        labels = input_ids.clone()
        labels[attn == 0] = -100  # padding mask for loss
        return {"input_ids": input_ids, "attention_mask": attn, "labels": labels}
 
 
 
 
 
353
 
354
  # ==========================
355
+ # Metrics
356
  # ==========================
357
def build_metrics_fn(arch: str, tokenizer):
    """Return a ``compute_metrics`` callable for the given architecture.

    ``seq2seq`` → ROUGE-L (0.0 when the ``rouge`` metric is unavailable);
    anything else → a simple whitespace-token micro-F1 ("f1_simple").
    """
    rouge = eval_load("rouge") if eval_load else None

    def _decode(eval_pred):
        """Decode predictions and labels, replacing -100 with the pad id."""
        preds, labels = eval_pred
        if isinstance(preds, tuple):
            preds = preds[0]
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        return decoded_preds, decoded_labels

    def compute_metrics_seq2seq(eval_pred):
        if rouge is None:
            return {"rougeL": 0.0}
        preds, refs = _decode(eval_pred)
        preds = [p.strip() for p in preds]
        refs = [r.strip() for r in refs]
        scores = rouge.compute(predictions=preds, references=refs, rouge_types=["rougeL"])
        return {"rougeL": float(scores.get("rougeL", 0.0))}

    def compute_metrics_causal(eval_pred):
        preds, refs = _decode(eval_pred)
        tp = fp = fn = 0
        for pred, gold in zip(preds, refs):
            p_tokens, g_tokens = set(pred.split()), set(gold.split())
            tp += len(p_tokens & g_tokens)
            fp += len(p_tokens - g_tokens)
            fn += len(g_tokens - p_tokens)
        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        return {"f1_simple": float(f1)}

    return compute_metrics_seq2seq if arch == "seq2seq" else compute_metrics_causal
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
 
399
  # ==========================
400
+ # Trainer Manager
401
  # ==========================
402
def read_jsonl_files(paths: List[str]) -> List[Dict]:
    """Read and merge one or more JSONL training files.

    Blank lines and lines that fail to parse as JSON are skipped silently.
    Lines that parse but are not JSON objects (e.g. a bare string or list)
    are also skipped — downstream datasets call ``.get`` on each record, so
    only dicts are valid.

    Args:
        paths: File paths to read; falsy entries (``None``, ``""``) are
            ignored. A path that points to a missing file still raises
            ``FileNotFoundError``, matching the previous behavior.

    Returns:
        All successfully parsed dict records, in file/line order.
    """
    data: List[Dict] = []
    for p in paths:
        if not p:
            continue
        with open(p, "r", encoding="utf-8") as f:
            for line in f:
                s = line.strip()
                if not s:
                    continue
                try:
                    obj = json.loads(s)
                except json.JSONDecodeError:
                    continue
                # Bug fix: previously any valid JSON value was appended,
                # violating the declared List[Dict] contract.
                if isinstance(obj, dict):
                    data.append(obj)
    return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
 
419
class TrainerManager:
    """Orchestrates fine-tuning for both seq2seq and causal architectures.

    The two public entry points (``train_seq2seq`` / ``train_causal``) were
    near-duplicates; shared steps are factored into ``_prepare`` and
    ``_fit`` so the per-architecture methods only differ where they must.
    """

    def __init__(self, syscfg: SystemConfig, loader: ModelLoader):
        self.cfg = syscfg
        self.loader = loader

    def _args_common(self, is_seq2seq: bool):
        """Build the TrainingArguments shared by both training modes."""
        # fp16 and bf16 are mutually exclusive: bf16 wins when requested
        # and supported, otherwise fall back to fp16 on CUDA.
        fp16_ok = torch.cuda.is_available() and (not self.cfg.train.use_bf16)
        bf16_ok = bf16_supported() and self.cfg.train.use_bf16
        return safe_training_args(
            output_dir=self.cfg.train.output_dir,
            num_train_epochs=self.cfg.train.epochs,
            learning_rate=self.cfg.train.lr,
            per_device_train_batch_size=self.cfg.train.batch_size,
            per_device_eval_batch_size=self.cfg.train.batch_size,
            gradient_accumulation_steps=self.cfg.train.grad_accum,
            warmup_ratio=self.cfg.train.warmup_ratio,
            weight_decay=self.cfg.train.weight_decay,
            evaluation_strategy=self.cfg.train.eval_strategy,
            save_strategy=self.cfg.train.save_strategy,
            save_total_limit=self.cfg.train.save_total_limit,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            logging_steps=self.cfg.train.logging_steps,
            report_to=([] if self.cfg.train.report_to == "none" else [self.cfg.train.report_to]),
            fp16=fp16_ok,
            bf16=bf16_ok,
            max_grad_norm=self.cfg.train.max_grad_norm,
            # Generation-specific kwargs only apply to seq2seq evaluation.
            **({
                "predict_with_generate": True,
                "generation_max_length": self.cfg.model.max_target_length,
                "generation_num_beams": self.cfg.model.num_beams,
            } if is_seq2seq else {}),
        )

    def _prepare(self, train_paths: List[str], use_rag: bool):
        """Seed, load/split JSONL data, and optionally init RAG.

        Returns ``(train_rows, val_rows, rag_or_None)``.
        """
        set_seed_all(self.cfg.train.seed)
        data = read_jsonl_files(train_paths)
        train, val = train_test_split(
            data, test_size=self.cfg.train.test_size, random_state=self.cfg.train.seed
        )
        rag = LegalRAG(self.cfg.rag) if (use_rag and self.cfg.rag.enable) else None
        if rag:
            rag.init()
        return train, val, rag

    def _fit(self, trainer):
        """Run training and persist model + tokenizer to the output dir."""
        trainer.train()
        trainer.save_model(self.cfg.train.output_dir)
        self.loader.tokenizer.save_pretrained(self.cfg.train.output_dir)

    def train_seq2seq(self, train_paths: List[str], use_rag: bool = True):
        """Fine-tune a seq2seq model on JSONL input/output pairs."""
        train, val, rag = self._prepare(train_paths, use_rag)
        # RAG augmentation on train only; eval stays un-augmented.
        ds_tr = Seq2SeqJSONLDataset(train, self.loader.tokenizer,
                                    self.cfg.model.max_input_length,
                                    self.cfg.model.max_target_length, rag)
        ds_va = Seq2SeqJSONLDataset(val, self.loader.tokenizer,
                                    self.cfg.model.max_input_length,
                                    self.cfg.model.max_target_length, None)
        collator = DataCollatorForSeq2Seq(tokenizer=self.loader.tokenizer, model=self.loader.model)
        trainer = Trainer(
            model=self.loader.model,
            args=self._args_common(is_seq2seq=True),
            train_dataset=ds_tr,
            eval_dataset=ds_va,
            data_collator=collator,
            tokenizer=self.loader.tokenizer,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
            compute_metrics=build_metrics_fn("seq2seq", self.loader.tokenizer),
        )
        self._fit(trainer)

    def train_causal(self, train_paths: List[str], use_rag: bool = True):
        """Fine-tune a causal LM on JSONL input/output pairs."""
        train, val, rag = self._prepare(train_paths, use_rag)
        ds_tr = CausalJSONLDataset(train, self.loader.tokenizer, self.cfg.model.max_input_length, rag)
        ds_va = CausalJSONLDataset(val, self.loader.tokenizer, self.cfg.model.max_input_length, None)
        trainer = Trainer(
            model=self.loader.model,
            args=self._args_common(is_seq2seq=False),
            train_dataset=ds_tr,
            eval_dataset=ds_va,
            tokenizer=self.loader.tokenizer,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
            compute_metrics=build_metrics_fn("causal", self.loader.tokenizer),
        )
        self._fit(trainer)
508
 
509
  # ==========================
510
+ # App (Gradio 5)
511
  # ==========================
512
  class LegalApp:
513
+ def __init__(self, scfg: Optional[SystemConfig] = None):
514
+ self.scfg = scfg or SystemConfig()
515
+ self.rag = LegalRAG(self.scfg.rag)
516
+ self.loader: Optional[ModelLoader] = None
517
+ self.gen: Optional[Generator] = None
518
+
519
+ # --- helpers ---
520
+ def _file_paths(self, files: List[gr.File]) -> List[str]:
521
+ paths = []
522
+ for f in (files or []):
523
+ p = getattr(f, "name", None) or getattr(f, "path", None)
524
+ if p:
525
+ paths.append(p)
526
+ return paths
527
+
528
+ # --- core actions ---
529
+ def load(self, model_name: str, arch: str, use_rag: bool, persist_dir: str, collection: str, top_k: int, threshold: float):
530
+ # configure
531
+ self.scfg.model.model_name = model_name
532
+ self.scfg.model.architecture = arch
533
+ self.scfg.rag.persist_dir = persist_dir
534
+ self.scfg.rag.collection = collection
535
+ self.scfg.rag.top_k = int(top_k)
536
+ self.scfg.rag.similarity_threshold = float(threshold)
537
+ self.scfg.rag.enable = bool(use_rag)
538
+
539
+ # load model
540
+ self.loader = ModelLoader(self.scfg.model).load()
541
+ self.gen = Generator(self.loader, self.scfg.model)
542
+
543
+ # load rag
544
+ msg_rag = "RAG غیرفعال"
545
+ if use_rag:
546
  try:
547
+ self.rag = LegalRAG(self.scfg.rag)
548
+ self.rag.init()
549
+ msg_rag = "RAG آماده است"
 
 
 
 
 
 
 
550
  except Exception as e:
551
+ msg_rag = f"RAG خطا: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
552
 
553
+ return f"مدل بارگذاری شد: {model_name} ({arch})\n{msg_rag}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
554
 
555
+ def build_index(self, laws_file: gr.File, id_key: str, text_key: str):
556
+ if not self.scfg.rag.enable:
557
+ return "RAG غیرفعال است."
558
  try:
559
+ self.rag.init()
560
+ p = getattr(laws_file, "name", None) or getattr(laws_file, "path", None)
561
+ if not p:
562
+ return "فایل قوانین معتبر نیست."
563
+ res = self.rag.index_jsonl(p, id_key=id_key, text_key=text_key)
564
+ return res
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
565
  except Exception as e:
566
+ return f"خطا در ایندکس: {e}"
567
+
568
+ def answer(self, question: str, use_rag: bool, max_new_tokens: int, temperature: float, top_p: float, num_beams: int):
569
+ if not question.strip():
570
+ return "لطفاً سوال خود را وارد کنید.", ""
571
+ if not self.gen:
572
+ return "ابتدا مدل/RAG را بارگذاری کنید.", ""
573
+ # runtime params
574
+ self.scfg.model.max_new_tokens = int(max_new_tokens)
575
+ self.scfg.model.temperature = float(temperature)
576
+ self.scfg.model.top_p = float(top_p)
577
+ self.scfg.model.num_beams = int(num_beams)
578
+ arts = self.rag.retrieve(question) if (use_rag and self.scfg.rag.enable and self.rag.collection) else []
579
+ ctx = self.rag.build_context(arts) if arts else ""
580
+ ans = self.gen.generate(question, ctx)
581
+ refs = ""
582
+ if arts:
583
+ refs = "\n\n" + "\n".join([f"**ماده {a['article_id']}** (شباهت: {a['similarity']:.2f})\n{a['text'][:380]}..." for a in arts])
584
+ return ans, refs
585
+
586
+ def train(self, model_name: str, arch: str, files: List[gr.File], use_rag: bool, epochs: int, batch: int, lr: float,
587
+ wd: float, warmup: float, report_to: str, progress=gr.Progress(track_tqdm=True)):
588
+ progress(0.0, desc="راه‌اندازی")
589
+ self.scfg.model.model_name = model_name
590
+ self.scfg.model.architecture = arch
591
+ self.scfg.train.epochs = int(epochs)
592
+ self.scfg.train.batch_size = int(batch)
593
+ self.scfg.train.lr = float(lr)
594
+ self.scfg.train.weight_decay = float(wd)
595
+ self.scfg.train.warmup_ratio = float(warmup)
596
+ self.scfg.train.report_to = report_to
597
+
598
+ progress(0.1, desc="بارگذاری مدل/توکنایزر")
599
+ self.loader = ModelLoader(self.scfg.model).load()
600
+
601
+ paths = self._file_paths(files)
602
+ if not paths:
603
+ return "⚠️ هیچ فایل JSONL برای آموزش انتخاب نشده."
604
+
605
+ tm = TrainerManager(self.scfg, self.loader)
606
+ set_seed_all(self.scfg.train.seed)
607
+
608
+ progress(0.3, desc="آماده‌سازی دیتاست‌ها و RAG")
609
+ if arch == "seq2seq":
610
+ tm.train_seq2seq(paths, use_rag=use_rag)
611
+ else:
612
+ tm.train_causal(paths, use_rag=use_rag)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
613
 
614
+ progress(0.95, desc="ذخیرهٔ آرتیفکت‌ها")
615
+ return f"✅ آموزش کامل شد و در {self.scfg.train.output_dir} ذخیره شد."
616
 
617
+ # --- UI ---
618
+ def build_ui(self):
619
+ log_deps()
620
+ default_models = {
621
  "Seq2Seq (mt5-base)": ("google/mt5-base", "seq2seq"),
622
+ "Seq2Seq (t5-fa-base)": ("HooshvareLab/t5-fa-base", "seq2seq"),
623
+ "Seq2Seq (flan-t5-base)": ("google/flan-t5-base", "seq2seq"),
624
+ "Causal (Mistral-7B Instruct)": ("mistralai/Mistral-7B-Instruct-v0.2", "causal"),
625
  }
626
 
627
+ with gr.Blocks(title="ماحون مشاور حقوقی هوشمند", theme=gr.themes.Soft(primary_hue="green", secondary_hue="gray")) as app:
 
 
 
 
 
 
 
 
 
 
 
 
628
  gr.HTML("""
629
+ <div style='text-align:center;padding:18px'>
630
+ <h1 style='margin-bottom:4px'>ماحون — Ultimate Legal AI</h1>
631
+ <p style='color:#666'>RAG Seq2Seq/Causal Training Metrics</p>
632
  </div>
633
  """)
634
 
635
+ with gr.Tab("مشاوره"):
636
+ with gr.Row():
637
+ model_dd = gr.Dropdown(choices=list(default_models.keys()), value="Seq2Seq (mt5-base)", label="مدل")
638
+ gr.Markdown("**راهنما:** Seq2Seq برای پاسخ‌های ساختاریافته؛ Causal برای مکالمه طبیعی‌تر.")
639
+ with gr.Row():
640
+ use_rag = gr.Checkbox(value=True, label="RAG فعال باشد؟")
641
+ persist_dir = gr.Textbox(value=self.scfg.rag.persist_dir, label="مسیر ChromaDB")
642
+ collection = gr.Textbox(value=self.scfg.rag.collection, label="نام کالکشن")
643
+ with gr.Row():
644
+ top_k = gr.Slider(1, 15, value=self.scfg.rag.top_k, step=1, label="Top-K")
645
+ threshold = gr.Slider(0.3, 0.95, value=self.scfg.rag.similarity_threshold, step=0.01, label="آستانه شباهت")
646
+ load_btn = gr.Button("بارگذاری مدل/RAG", variant="primary")
647
+ status = gr.Textbox(label="وضعیت", interactive=False)
648
+
649
+ with gr.Accordion("ساخت ایندکس قوانین (اختیاری)", open=False):
650
+ laws_file = gr.File(label="فایل JSONL قوانین", file_types=[".jsonl"])
651
+ id_key = gr.Textbox(value="article_id", label="کلید شناسه ماده")
652
+ text_key = gr.Textbox(value="text", label="کلید متن ماده")
653
+ index_btn = gr.Button("ایندکس‌سازی قوانین")
654
+ index_status = gr.Textbox(label="وضعیت ایندکس", interactive=False)
655
+
656
+ with gr.Accordion("پارامترهای تولید", open=False):
657
+ max_new_tokens = gr.Slider(64, 1024, value=self.scfg.model.max_new_tokens, step=16, label="max_new_tokens")
658
+ temperature = gr.Slider(0.0, 1.5, value=self.scfg.model.temperature, step=0.05, label="temperature")
659
+ top_p = gr.Slider(0.1, 1.0, value=self.scfg.model.top_p, step=0.05, label="top_p")
660
+ num_beams = gr.Slider(1, 8, value=self.scfg.model.num_beams, step=1, label="num_beams (Seq2Seq)")
661
+
662
+ question = gr.Textbox(lines=3, label="سوال حقوقی")
663
+ gr.Examples(
664
+ examples=[
665
+ ["در صورت نقض قرارداد فروش، چه اقداماتی باید انجام دهم؟"],
666
+ ["آیا درج شرط عدم رقابت در قرارداد کار قانونی است؟"],
667
+ ["حق و حقوق کارگر در صورت اخراج فوری چیست؟"],
668
+ ["فرآیند طرح دعوای مطالبه مهریه چگونه است؟"],
669
+ ],
670
+ inputs=question, label="نمونه پرسش‌ها"
671
  )
672
+ ask_btn = gr.Button("پرسش", variant="primary")
673
+ answer = gr.Markdown(label="پاسخ")
674
+ refs = gr.Markdown(label="مواد قانونی مرتبط")
675
+
676
+ with gr.Tab("آموزش"):
677
+ gr.Markdown("فایل‌های JSONL با کلیدهای `input` و `output` را بارگذاری کنید.")
678
+ with gr.Row():
679
+ model_dd_train = gr.Dropdown(choices=list(default_models.keys()), value="Seq2Seq (mt5-base)", label="مدل")
680
+ use_rag_train = gr.Checkbox(value=True, label="RAG-enhanced Training")
681
+ train_files = gr.Files(label="JSONL Files", file_count="multiple", file_types=[".jsonl"])
682
+ with gr.Row():
683
+ epochs = gr.Slider(1, 8, value=self.scfg.train.epochs, step=1, label="epochs")
684
+ batch = gr.Slider(1, 16, value=self.scfg.train.batch_size, step=1, label="batch per device")
685
+ lr = gr.Number(value=self.scfg.train.lr, label="learning rate")
686
+ with gr.Row():
687
+ wd = gr.Number(value=self.scfg.train.weight_decay, label="weight decay")
688
+ warmup = gr.Slider(0.0, 0.2, value=self.scfg.train.warmup_ratio, step=0.01, label="warmup ratio")
689
+ report_to = gr.Dropdown(choices=["none","wandb"], value=self.scfg.train.report_to, label="report_to")
690
+ train_btn = gr.Button("شروع آموزش", variant="primary")
691
+ train_status = gr.Textbox(label="وضعیت آموزش", interactive=False)
692
+
693
+ # رویدادها
694
+ def _resolve(choice: str) -> Tuple[str,str]:
695
+ return default_models[choice]
696
+
697
+ load_btn.click(lambda choice, rag, pdir, coll, k, th: self.load(*_resolve(choice), rag, pdir, coll, k, th),
698
+ inputs=[model_dd, use_rag, persist_dir, collection, top_k, threshold], outputs=status)
699
+
700
+ ask_btn.click(lambda q, rag, mnt, t, p, nb: self.answer(q, rag, mnt, t, p, nb),
701
+ inputs=[question, use_rag, max_new_tokens, temperature, top_p, num_beams],
702
+ outputs=[answer, refs])
703
+
704
+ index_btn.click(lambda f, ik, tk: self.build_index(f, ik, tk),
705
+ inputs=[laws_file, id_key, text_key], outputs=index_status)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
706
 
707
  train_btn.click(
708
+ lambda choice, files, rag, e, b, l, _wd, _wu, _r:
709
+ self.train(*_resolve(choice), files, rag, e, b, l, _wd, _wu, _r),
710
+ inputs=[model_dd_train, train_files, use_rag_train, epochs, batch, lr, wd, warmup, report_to],
 
 
 
 
 
 
 
 
 
 
711
  outputs=train_status
712
  )
 
713
  return app
714
 
715
  # ==========================
716
+ # Entrypoint for HF Spaces
717
  # ==========================
718
if __name__ == "__main__":
    app = LegalApp()
    ui = app.build_ui()
    # Bug fix: Gradio 4+ removed queue(concurrency_count=...); passing it
    # raises TypeError at startup on the Gradio 5 this app targets. Use the
    # replacement kwarg and fall back for legacy Gradio 3 installs.
    try:
        queued = ui.queue(default_concurrency_limit=2)
    except TypeError:
        queued = ui.queue(concurrency_count=2)
    queued.launch(server_name="0.0.0.0", server_port=7860)