Files changed (1)
  1. app(8).py +1616 -0
app(8).py ADDED
@@ -0,0 +1,1616 @@
# -*- coding: utf-8 -*-
"""
Mahoon Legal AI — Enhanced Version
Features:
- Improved memory management and resource cleanup
- Caching system for models and embeddings
- Enhanced security and input validation
- Better error handling and logging
- Metrics and monitoring
- Thread safety improvements
- Configuration validation with Pydantic
- Comprehensive testing support
- Gradio UI with advanced features
"""

from __future__ import annotations
import os
import json
import warnings
import hashlib
import threading
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Any
from datetime import datetime
import logging

import torch
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from pydantic import BaseModel, validator, Field

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    EarlyStoppingCallback,
    DataCollatorForSeq2Seq,
    TrainerCallback
)

import chromadb
from sentence_transformers import SentenceTransformer
import gradio as gr

warnings.filterwarnings("ignore")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# ==========================
# Enhanced Config with Validation
# ==========================
class ModelConfig(BaseModel):
    model_name: str = "persiannlp/parsi-t5-base"
    architecture: str = "seq2seq"
    max_input_length: int = Field(default=1024, ge=64, le=4096)
    max_target_length: int = Field(default=512, ge=32, le=2048)
    max_new_tokens: int = Field(default=512, ge=32, le=1024)
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)
    top_p: float = Field(default=0.9, ge=0.1, le=1.0)
    num_beams: int = Field(default=4, ge=1, le=8)
    use_bf16: bool = True

    @validator('architecture')
    def validate_architecture(cls, v):
        if v not in ['seq2seq', 'causal']:
            raise ValueError('architecture must be seq2seq or causal')
        return v

    class Config:
        validate_assignment = True

class SystemConfig(BaseModel):
    model: ModelConfig = Field(default_factory=ModelConfig)
    embedding_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    chroma_db_path: str = "./chroma.sqlite3"
    top_k_retrieval: int = Field(default=5, ge=1, le=20)
    similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0)
    cache_dir: str = "./cache"
    output_dir: str = "./mahoon_legal_model"
    seed: int = 42
    train_test_ratio: float = Field(default=0.1, ge=0.05, le=0.3)
    batch_size: int = Field(default=2, ge=1, le=16)
    grad_accum: int = Field(default=2, ge=1, le=8)
    epochs: int = Field(default=2, ge=1, le=10)
    lr: float = Field(default=3e-5, ge=1e-6, le=1e-3)
    max_file_size_mb: int = Field(default=10, ge=1, le=100)
    max_lines_per_file: int = Field(default=10000, ge=100, le=100000)
    request_timeout: int = Field(default=30, ge=5, le=300)

    class Config:
        validate_assignment = True

# ==========================
# Metrics and Monitoring
# ==========================
@dataclass
class SystemMetrics:
    requests_count: int = 0
    avg_response_time: float = 0.0
    error_count: int = 0
    success_count: int = 0
    memory_usage_mb: float = 0.0
    last_updated: datetime = field(default_factory=datetime.now)
    active_models: List[str] = field(default_factory=list)

class MetricsCollector:
    def __init__(self):
        self.metrics = SystemMetrics()
        self._lock = threading.Lock()

    def record_request(self, response_time: float, success: bool = True):
        with self._lock:
            self.metrics.requests_count += 1
            if success:
                self.metrics.success_count += 1
            else:
                self.metrics.error_count += 1

            # Update running average: new_avg = (old_avg * (n - 1) + x) / n
            total_requests = self.metrics.requests_count
            old_avg = self.metrics.avg_response_time
            self.metrics.avg_response_time = (old_avg * (total_requests - 1) + response_time) / total_requests
            self.metrics.last_updated = datetime.now()

    def update_memory_usage(self):
        if torch.cuda.is_available():
            memory_mb = torch.cuda.memory_allocated() / 1024 / 1024
            self.metrics.memory_usage_mb = memory_mb

    def get_metrics(self) -> Dict[str, Any]:
        with self._lock:
            return {
                "requests_total": self.metrics.requests_count,
                "success_rate": self.metrics.success_count / max(self.metrics.requests_count, 1) * 100,
                "avg_response_time": round(self.metrics.avg_response_time, 2),
                "error_count": self.metrics.error_count,
                "memory_usage_mb": round(self.metrics.memory_usage_mb, 2),
                "active_models": self.metrics.active_models.copy(),
                "last_updated": self.metrics.last_updated.isoformat()
            }

# Global metrics instance
metrics = MetricsCollector()
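
# --------------------------------------------------------------------------
# Illustrative usage sketch (editor addition, never called by the app):
# MetricsCollector keeps a running average instead of storing every sample,
# via new_avg = (old_avg * (n - 1) + x) / n. For timings 0.2s, 0.4s, 0.6s
# the average is 0.4s; a fourth request of 1.0s moves it to 0.55s.
# --------------------------------------------------------------------------
def _demo_metrics_usage():
    collector = MetricsCollector()
    for elapsed in (0.2, 0.4, 0.6):
        collector.record_request(elapsed, success=True)
    collector.record_request(1.0, success=False)
    snapshot = collector.get_metrics()
    assert snapshot["requests_total"] == 4
    assert snapshot["success_rate"] == 75.0  # 3 of 4 requests succeeded
    assert snapshot["avg_response_time"] == 0.55
    return snapshot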

# ==========================
# Enhanced Utilities
# ==========================
def set_seed_all(seed: int = 42):
    import random
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    logger.info(f"Set random seed to {seed}")

def validate_file_security(file_path: str, max_size_mb: int = 10, max_lines: int = 10000) -> Tuple[bool, str]:
    """Enhanced file validation with security checks"""
    try:
        path = Path(file_path)

        # Check if file exists and is readable
        if not path.exists() or not path.is_file():
            return False, "فایل وجود ندارد یا قابل خواندن نیست"

        # Check file extension
        if path.suffix.lower() != '.jsonl':
            return False, "فقط فایل‌های .jsonl پذیرفته می‌شوند"

        # Check file size
        size_mb = path.stat().st_size / (1024 * 1024)
        if size_mb > max_size_mb:
            return False, f"حجم فایل نباید از {max_size_mb} مگابایت بیشتر باشد"

        # Validate content structure
        line_count = 0
        with open(path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue

                line_count += 1
                if line_count > max_lines:
                    return False, f"فایل نباید بیش از {max_lines} خط داشته باشد"

                # Validate JSON structure
                try:
                    data = json.loads(line)
                    if not isinstance(data, dict):
                        return False, f"خط {line_num}: فرمت JSON نامعتبر"

                    if 'input' not in data or 'output' not in data:
                        return False, f"خط {line_num}: کلیدهای 'input' و 'output' الزامی هستند"

                    # Check content length
                    if len(str(data['input'])) > 2048 or len(str(data['output'])) > 2048:
                        return False, f"خط {line_num}: طول محتوا بیش از حد مجاز"

                except json.JSONDecodeError:
                    return False, f"خط {line_num}: فرمت JSON نامعتبر"

        if line_count == 0:
            return False, "فایل خالی است"

        return True, f"فایل معتبر است ({line_count} خط)"

    except Exception as e:
        logger.error(f"File validation error: {e}")
        return False, f"خطا در بررسی فایل: {str(e)}"
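
# --------------------------------------------------------------------------
# Editor sketch (illustrative, never called by the app): the training-file
# schema validate_file_security accepts is one JSON object per line with
# "input" and "output" keys, each value at most 2048 characters. The file
# name and sample rows below are hypothetical.
# --------------------------------------------------------------------------
def _demo_write_valid_jsonl(path: str = "sample_train.jsonl") -> Tuple[bool, str]:
    rows = [
        {"input": "شرایط فسخ قرارداد اجاره چیست؟", "output": "نمونه پاسخ حقوقی"},
        {"input": "مهلت واخواهی چقدر است؟", "output": "نمونه پاسخ حقوقی"},
    ]
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
    return validate_file_security(path, max_size_mb=10, max_lines=10000)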

def read_jsonl_files_safe(paths: List[str], cfg: SystemConfig) -> Tuple[List[Dict], List[str]]:
    """Safe JSONL file reading with validation"""
    data: List[Dict] = []
    errors: List[str] = []

    for path in paths:
        # Validate file first
        is_valid, msg = validate_file_security(path, cfg.max_file_size_mb, cfg.max_lines_per_file)
        if not is_valid:
            errors.append(f"{Path(path).name}: {msg}")
            continue

        try:
            with open(path, 'r', encoding='utf-8') as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if not line:
                        continue

                    try:
                        obj = json.loads(line)
                        # Sanitize input (validation guarantees the keys;
                        # .get guards against a missing key anyway)
                        obj['input'] = str(obj.get('input', '')).strip()
                        obj['output'] = str(obj.get('output', '')).strip()

                        if obj['input'] and obj['output']:
                            data.append(obj)
                    except json.JSONDecodeError:
                        errors.append(f"{Path(path).name} line {line_num}: JSON decode error")

        except Exception as e:
            errors.append(f"{Path(path).name}: {str(e)}")

    logger.info(f"Loaded {len(data)} samples from {len(paths)} files")
    return data, errors

# ==========================
# Model Cache System
# ==========================
class ModelCache:
    _instances: Dict[str, Any] = {}
    _lock = threading.Lock()
    _access_times: Dict[str, float] = {}
    _max_cache_size = 3  # Maximum models to keep in cache

    @classmethod
    def _generate_key(cls, model_name: str, architecture: str) -> str:
        return hashlib.md5(f"{model_name}_{architecture}".encode()).hexdigest()[:16]

    @classmethod
    def get_model(cls, model_name: str, architecture: str, model_config: ModelConfig):
        key = cls._generate_key(model_name, architecture)

        with cls._lock:
            if key in cls._instances:
                cls._access_times[key] = time.time()
                logger.info(f"Model loaded from cache: {model_name}")
                return cls._instances[key]

            # Cleanup old models if cache is full
            if len(cls._instances) >= cls._max_cache_size:
                cls._cleanup_cache()

            # Load new model
            try:
                loader = ModelLoader(model_config)
                loader.load()
                cls._instances[key] = loader
                cls._access_times[key] = time.time()

                # Update metrics
                if model_name not in metrics.metrics.active_models:
                    metrics.metrics.active_models.append(model_name)

                logger.info(f"Model loaded and cached: {model_name}")
                return loader

            except Exception as e:
                logger.error(f"Failed to load model {model_name}: {e}")
                raise

    @classmethod
    def _cleanup_cache(cls):
        """Remove least recently used model"""
        if not cls._access_times:
            return

        # Find least recently used model
        lru_key = min(cls._access_times.keys(), key=lambda k: cls._access_times[k])

        # Clean up resources
        if lru_key in cls._instances:
            loader = cls._instances[lru_key]
            cls._cleanup_model_resources(loader)
            del cls._instances[lru_key]
            del cls._access_times[lru_key]
            logger.info(f"Removed model from cache: {lru_key}")

    @classmethod
    def _cleanup_model_resources(cls, loader):
        """Clean up model resources"""
        try:
            if hasattr(loader, 'model') and hasattr(loader.model, 'cpu'):
                loader.model.cpu()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as e:
            logger.warning(f"Error cleaning up model resources: {e}")

    @classmethod
    def clear_cache(cls):
        """Clear all cached models"""
        with cls._lock:
            for loader in cls._instances.values():
                cls._cleanup_model_resources(loader)
            cls._instances.clear()
            cls._access_times.clear()
            metrics.metrics.active_models.clear()
            logger.info("Model cache cleared")
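
# --------------------------------------------------------------------------
# Editor sketch (illustrative, never called by the app): the eviction rule
# ModelCache._cleanup_cache applies, shown on a plain dict of access times
# so no real model has to be loaded. The smallest timestamp is evicted.
# --------------------------------------------------------------------------
def _demo_lru_eviction() -> str:
    access_times = {"model_a": 1.0, "model_b": 3.0, "model_c": 2.0}
    lru_key = min(access_times.keys(), key=lambda k: access_times[k])
    assert lru_key == "model_a"  # oldest access is evicted first
    return lru_key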

# ==========================
# Enhanced RAG System
# ==========================
class LegalRAGSystem:
    def __init__(self, cfg: SystemConfig):
        self.cfg = cfg
        self.embedding_model: Optional[SentenceTransformer] = None
        self.client = None
        self.collection = None
        self._lock = threading.Lock()

    @contextmanager
    def _safe_operation(self, operation_name: str):
        """Context manager for safe RAG operations"""
        start_time = time.time()
        try:
            yield
        except Exception as e:
            logger.error(f"RAG {operation_name} failed: {e}")
            metrics.record_request(time.time() - start_time, success=False)
            raise
        else:
            metrics.record_request(time.time() - start_time, success=True)

    def setup_embedding(self):
        if self.embedding_model is None:
            try:
                self.embedding_model = SentenceTransformer(
                    self.cfg.embedding_model,
                    cache_folder=self.cfg.cache_dir
                )
                logger.info(f"Embedding model loaded: {self.cfg.embedding_model}")
            except Exception as e:
                logger.error(f"Failed to load embedding model: {e}")
                raise

    def load_chroma(self) -> Tuple[bool, str]:
        with self._safe_operation("load_chroma"):
            try:
                base_path = str(Path(self.cfg.chroma_db_path).parent)
                os.makedirs(base_path, exist_ok=True)

                self.client = chromadb.PersistentClient(path=base_path)
                try:
                    self.collection = self.client.get_collection("legal_articles")
                    count = self.collection.count()
                    logger.info(f"Loaded existing collection with {count} documents")
                    return count > 0, f"مجموعه موجود با {count} سند بارگذاری شد"
                except Exception:
                    self.collection = self.client.create_collection(
                        "legal_articles",
                        metadata={"description": "مواد قانونی"}
                    )
                    logger.info("Created new collection")
                    return False, "مجموعه جدید ایجاد شد"

            except Exception as e:
                logger.error(f"ChromaDB initialization failed: {e}")
                return False, f"خطا در بارگذاری پایگاه داده: {str(e)}"

    def retrieve(self, query: str) -> List[Dict]:
        if not self.collection or not query.strip():
            return []

        with self._safe_operation("retrieve"):
            try:
                # Sanitize query
                query = query.strip()[:500]  # Limit query length

                # Embed with the configured SentenceTransformer when it is
                # loaded, so queries use the same model as the indexed
                # documents; otherwise fall back to Chroma's default
                # embedding function
                if self.embedding_model is not None:
                    query_embeddings = self.embedding_model.encode([query]).tolist()
                    result = self.collection.query(
                        query_embeddings=query_embeddings,
                        n_results=self.cfg.top_k_retrieval,
                        include=["documents", "metadatas", "distances"]
                    )
                else:
                    result = self.collection.query(
                        query_texts=[query],
                        n_results=self.cfg.top_k_retrieval,
                        include=["documents", "metadatas", "distances"]
                    )

                articles = []
                if result['documents'] and result['documents'][0]:
                    for i, (doc, meta, dist) in enumerate(zip(
                        result['documents'][0],
                        result['metadatas'][0],
                        result['distances'][0]
                    )):
                        similarity = max(0, min(1, 1 - dist))  # Normalize, assuming cosine distance
                        if similarity >= self.cfg.similarity_threshold:
                            articles.append({
                                "article_id": meta.get("article_id", f"unknown_{i}"),
                                "text": str(doc)[:500],  # Limit text length
                                "similarity": round(similarity, 3),
                            })

                logger.info(f"Retrieved {len(articles)} relevant articles")
                return articles

            except Exception as e:
                logger.error(f"Article retrieval failed: {e}")
                return []

    @staticmethod
    def build_context(articles: List[Dict], limit_chars: int = 500) -> str:
        if not articles:
            return ""

        context_parts = []
        total_chars = 0

        for article in articles:
            text = article['text'][:limit_chars]
            part = f"• ماده {article['article_id']}: {text}"

            if total_chars + len(part) > limit_chars * 3:  # Max total context
                break

            context_parts.append(part)
            total_chars += len(part)

        return "مواد مرتبط:\n" + "\n".join(context_parts)
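
# --------------------------------------------------------------------------
# Editor sketch (illustrative, never called by the app): build_context is a
# pure function, so it can be exercised without an embedding model or a
# ChromaDB instance. Article texts here are placeholders.
# --------------------------------------------------------------------------
def _demo_build_context() -> str:
    articles = [
        {"article_id": "10", "text": "متن نمونه ماده دهم", "similarity": 0.91},
        {"article_id": "219", "text": "متن نمونه ماده ۲۱۹", "similarity": 0.84},
    ]
    context = LegalRAGSystem.build_context(articles, limit_chars=200)
    assert context.startswith("مواد مرتبط:")
    return context  # one bullet line per article, capped at 3x limit_chars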

# ==========================
# Enhanced Formalizer
# ==========================
class Formalizer:
    def __init__(self, model_name="erfan226/persian-t5-formality-transfer", device=None):
        self.model_name = model_name
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = None
        self.model = None
        self._initialized = False
        self._lock = threading.Lock()

    def _initialize(self):
        """Lazy initialization of formalizer model"""
        if self._initialized:
            return

        with self._lock:
            if self._initialized:  # Double-check pattern
                return

            try:
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)
                self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name).to(self.device)
                self._initialized = True
                logger.info("Formalizer model initialized")
            except Exception as e:
                logger.error(f"Formalizer initialization failed: {e}")
                raise

    def formalize(self, text: str, max_len: int = 512) -> str:
        if not text or not text.strip():
            return text

        self._initialize()

        try:
            # Sanitize and limit input
            text = text.strip()[:max_len]

            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=max_len
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=max_len,
                    num_beams=4,
                    early_stopping=True
                )

            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            logger.debug(f"Formalized text: {text[:50]}... -> {result[:50]}...")
            return result

        except Exception as e:
            logger.error(f"Text formalization failed: {e}")
            return text  # Return original text on error

# ==========================
# Enhanced Model Loader
# ==========================
class ModelLoader:
    def __init__(self, model_config: ModelConfig):
        self.cfg = model_config
        self.tokenizer = None
        self.model = None
        self._loaded = False

    def _is_persianmind(self, name: str) -> bool:
        return "PersianMind" in name or "universitytehran/PersianMind" in name

    @contextmanager
    def _gpu_memory_context(self):
        """Context manager for GPU memory management"""
        initial_memory = 0
        if torch.cuda.is_available():
            initial_memory = torch.cuda.memory_allocated()

        try:
            yield
        finally:
            if torch.cuda.is_available():
                final_memory = torch.cuda.memory_allocated()
                logger.info(f"Memory change: {(final_memory - initial_memory) / 1024**2:.1f} MB")
                metrics.update_memory_usage()

    def load(self, prefer_quantized: bool = True):
        if self._loaded:
            return self

        with self._gpu_memory_context():
            try:
                self._load_tokenizer()
                self._load_model(prefer_quantized)
                self._loaded = True
                logger.info(f"Successfully loaded {self.cfg.model_name}")
                return self

            except Exception as e:
                logger.error(f"Model loading failed: {e}")
                self._cleanup()
                raise

    def _load_tokenizer(self):
        """Load tokenizer with error handling"""
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.cfg.model_name,
                use_fast=True,
                trust_remote_code=True
            )
            logger.info("Tokenizer loaded successfully")
        except Exception as e:
            logger.error(f"Tokenizer loading failed: {e}")
            raise

    def _load_model(self, prefer_quantized: bool):
        """Load model with quantization support"""
        device_map = "auto" if torch.cuda.is_available() else None
        cuda_available = torch.cuda.is_available()
        dtype = torch.bfloat16 if (cuda_available and self.cfg.use_bf16) else (
            torch.float16 if cuda_available else torch.float32
        )

        # Try quantized loading for PersianMind causal models
        if (self.cfg.architecture == "causal" and
                self._is_persianmind(self.cfg.model_name) and
                prefer_quantized and cuda_available):

            if self._try_quantized_loading(device_map, dtype):
                return

        # Standard loading
        self._load_standard_model(device_map, dtype)

    def _try_quantized_loading(self, device_map, dtype) -> bool:
        """Try loading model with quantization"""
        # Try 8-bit first
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.cfg.model_name,
                device_map=device_map,
                load_in_8bit=True,
                torch_dtype=dtype,
                trust_remote_code=True
            )
            self._setup_pad_token()
            logger.info("Model loaded with 8-bit quantization")
            return True
        except Exception as e:
            logger.warning(f"8-bit loading failed: {e}")

        # Try 4-bit
        try:
            self.model = AutoModelForCausalLM.from_pretrained(
                self.cfg.model_name,
                device_map=device_map,
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16,
                torch_dtype=dtype,
                trust_remote_code=True
            )
            self._setup_pad_token()
            logger.info("Model loaded with 4-bit quantization")
            return True
        except Exception as e:
            logger.warning(f"4-bit loading failed: {e}")

        return False

    def _load_standard_model(self, device_map, dtype):
        """Load model with standard precision"""
        try:
            if self.cfg.architecture == "seq2seq":
                self.model = AutoModelForSeq2SeqLM.from_pretrained(
                    self.cfg.model_name,
                    device_map=device_map,
                    torch_dtype=dtype,
                    trust_remote_code=True
                )
            elif self.cfg.architecture == "causal":
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.cfg.model_name,
                    device_map=device_map,
                    torch_dtype=dtype,
                    trust_remote_code=True
                )
                self._setup_pad_token()
            else:
                raise ValueError(f"Unsupported architecture: {self.cfg.architecture}")

            logger.info("Model loaded with standard precision")

        except Exception as e:
            logger.error(f"Standard model loading failed: {e}")
            raise

    def _setup_pad_token(self):
        """Setup pad token for causal models"""
        if (self.tokenizer.pad_token is None and
                hasattr(self.tokenizer, 'eos_token') and
                self.tokenizer.eos_token):
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _cleanup(self):
        """Clean up resources on failure"""
        try:
            if self.model and hasattr(self.model, 'cpu'):
                self.model.cpu()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception as e:
            logger.warning(f"Cleanup error: {e}")

# ==========================
# Enhanced Generator
# ==========================
class UnifiedGenerator:
    def __init__(self, loader: ModelLoader):
        self.loader = loader
        self.tokenizer = loader.tokenizer
        self.model = loader.model
        self.cfg = loader.cfg

    def generate(self, question: str, context: str = "") -> Tuple[str, str]:
        """Generate response with comprehensive error handling"""
        if not question or not question.strip():
            return "لطفاً سوال خود را وارد کنید.", "EMPTY_QUERY"

        if not self.model or not self.tokenizer:
            return "مدل بارگذاری نشده است.", "MODEL_NOT_LOADED"

        start_time = time.time()
        try:
            # Sanitize inputs
            question = question.strip()[:self.cfg.max_input_length // 2]
            context = context.strip()[:self.cfg.max_input_length // 2]

            if self.cfg.architecture == "seq2seq":
                result = self._generate_seq2seq(question, context)
            else:
                result = self._generate_causal(question, context)

            response_time = time.time() - start_time
            metrics.record_request(response_time, success=True)

            logger.info(f"Generated response in {response_time:.2f}s")
            return result, ""

        except torch.cuda.OutOfMemoryError:
            error_msg = "حافظه GPU کافی نیست. لطفاً پارامترها را کاهش دهید."
            logger.error("CUDA out of memory error")
            metrics.record_request(time.time() - start_time, success=False)
            return error_msg, "CUDA_OOM"

        except Exception as e:
            error_msg = "خطای غیرمنتظره در تولید پاسخ رخ داد."
            logger.error(f"Generation error: {e}")
            metrics.record_request(time.time() - start_time, success=False)
            return error_msg, str(e)

    def _generate_seq2seq(self, question: str, context: str) -> str:
        """Generate response using seq2seq model"""
        input_text = f"{context}\nسوال: {question}" if context else f"سوال: {question}"

        inputs = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            max_length=self.cfg.max_input_length
        )
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=self.cfg.max_target_length,
                num_beams=self.cfg.num_beams,
                early_stopping=True,
                no_repeat_ngram_size=2,
                do_sample=False
            )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Clean up response (remove input echo if present)
        if input_text in response:
            response = response.replace(input_text, "").strip()

        return response or "پاسخی تولید نشد."

    def _generate_causal(self, question: str, context: str) -> str:
        """Generate response using causal model"""
        prompt = f"{context}\nسوال: {question}\nپاسخ:" if context else f"سوال: {question}\nپاسخ:"

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.cfg.max_input_length
        )
        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
        input_length = inputs['input_ids'].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=self.cfg.max_new_tokens,
                do_sample=True,
                temperature=max(0.1, self.cfg.temperature),  # Ensure min temperature
                top_p=self.cfg.top_p,
                pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
                repetition_penalty=1.1,
                no_repeat_ngram_size=3
            )

        # Extract only the generated part
        generated_tokens = outputs[0][input_length:]
        response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)

        # Clean up response
        response = response.strip()
        if not response:
            return "پاسخی تولید نشد."

        # Remove any remaining prompt artifacts
        response = response.split("سوال:")[0].strip()

        return response
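
# --------------------------------------------------------------------------
# Editor sketch (illustrative, never called by the app): the two prompt
# layouts UnifiedGenerator builds. The seq2seq path feeds "سوال: ..."
# (optionally preceded by RAG context); the causal path appends "پاسخ:" so
# generation continues from the answer marker and the prompt echo can be
# sliced off by token position.
# --------------------------------------------------------------------------
def _demo_prompt_layouts() -> Tuple[str, str]:
    question = "شرایط فسخ قرارداد اجاره چیست؟"
    context = "مواد مرتبط:\n• ماده 10: متن نمونه"
    seq2seq_prompt = f"{context}\nسوال: {question}"
    causal_prompt = f"{context}\nسوال: {question}\nپاسخ:"
    return seq2seq_prompt, causal_prompt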

# ==========================
# Enhanced Datasets
# ==========================
class Seq2SeqJSONLDataset(Dataset):
    def __init__(self, data: List[Dict], tokenizer, max_input: int, max_target: int):
        self.tokenizer = tokenizer
        self.max_input = max_input
        self.max_target = max_target

        # Filter and validate data
        self.items = []
        for item in data:
            src = str(item.get("input", "")).strip()
            tgt = str(item.get("output", "")).strip()

            if src and tgt and len(src) > 5 and len(tgt) > 5:  # Minimum length check
                self.items.append((src, tgt))

        logger.info(f"Seq2Seq dataset created with {len(self.items)} samples")

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        source_text, target_text = self.items[idx]

        # Tokenize inputs
        model_inputs = self.tokenizer(
            source_text,
            max_length=self.max_input,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Tokenize targets
        labels = self.tokenizer(
            text_target=target_text,
            max_length=self.max_target,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Mask label padding with -100 so the loss ignores padded positions
        label_ids = labels["input_ids"].flatten()
        label_ids[label_ids == self.tokenizer.pad_token_id] = -100

        return {
            "input_ids": model_inputs["input_ids"].flatten(),
            "attention_mask": model_inputs["attention_mask"].flatten(),
            "labels": label_ids
        }

class CausalJSONLDataset(Dataset):
    def __init__(self, data: List[Dict], tokenizer, max_length: int):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Process data
        self.items = []
        for item in data:
            src = str(item.get("input", "")).strip()
            tgt = str(item.get("output", "")).strip()

            if src and tgt and len(src) > 5 and len(tgt) > 5:
                formatted_text = f"سوال: {src}\nپاسخ: {tgt}"
                self.items.append(formatted_text)

        logger.info(f"Causal dataset created with {len(self.items)} samples")

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        text = self.items[idx]

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        input_ids = encoding["input_ids"].flatten()
        attention_mask = encoding["attention_mask"].flatten()

        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }
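
# --------------------------------------------------------------------------
# Editor sketch (illustrative, never called by the app): the label-masking
# convention both datasets rely on. Positions set to -100 are ignored by
# PyTorch's cross-entropy loss, so padding never contributes to training.
# Token ids below are arbitrary; 0 stands in for the pad id.
# --------------------------------------------------------------------------
def _demo_label_masking() -> torch.Tensor:
    input_ids = torch.tensor([5, 6, 7, 0, 0])
    attention_mask = torch.tensor([1, 1, 1, 0, 0])
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100
    assert labels.tolist() == [5, 6, 7, -100, -100]
    return labels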

# ==========================
# Enhanced Progress Callback
# ==========================
class GradioProgressCallback(TrainerCallback):
    """Reports training progress through the gr.Progress object that Gradio
    injects into the event handler. Component .update() calls cannot push to
    the UI from inside a Trainer callback, so detailed status goes to the
    logger and the final summary is returned by the handler instead."""

    def __init__(self, progress: gr.Progress):
        self.progress = progress
        self.total_steps = None
        self.start_time = None
        self.last_update = 0

    def on_train_begin(self, args, state, control, **kwargs):
        self.total_steps = state.max_steps
        self.start_time = time.time()
        self.progress(0.0, desc="آموزش شروع شد 🚀")

    def on_step_end(self, args, state, control, **kwargs):
        if not self.total_steps or time.time() - self.last_update < 1.0:  # Throttle updates
            return

        self.last_update = time.time()

        # Calculate progress as a fraction (gr.Progress expects 0..1)
        fraction = min(1.0, state.global_step / self.total_steps)
        progress_pct = int(fraction * 100)

        # Estimate remaining time
        elapsed = time.time() - self.start_time
        if state.global_step > 0:
            avg_time_per_step = elapsed / state.global_step
            remaining_steps = self.total_steps - state.global_step
            eta_minutes = int(avg_time_per_step * remaining_steps / 60)
            eta_str = f" (تخمین باقی‌مانده: {eta_minutes} دقیقه)" if eta_minutes > 0 else ""
        else:
            eta_str = ""

        # Update progress bar
        self.progress(fraction, desc=f"آموزش: {progress_pct}%{eta_str}")

        # Log detailed status ('loss' is the key Trainer logs during training)
        status_msg = f"Step {state.global_step}/{self.total_steps} → {progress_pct}%{eta_str}"
        if state.log_history:
            last_log = state.log_history[-1]
            if 'loss' in last_log:
                status_msg += f" | Train Loss: {last_log['loss']:.4f}"
            if 'eval_loss' in last_log:
                status_msg += f" | Eval Loss: {last_log['eval_loss']:.4f}"
        logger.info(status_msg)

    def on_evaluate(self, args, state, control, **kwargs):
        if state.log_history and 'eval_loss' in state.log_history[-1]:
            logger.info(f"Evaluation done - Loss: {state.log_history[-1]['eval_loss']:.4f}")

    def on_train_end(self, args, state, control, **kwargs):
        total_minutes = int((time.time() - self.start_time) / 60)
        self.progress(1.0, desc="آموزش تکمیل شد ✅")
        logger.info(f"Training completed in {total_minutes} min, {state.global_step} steps")

# ==========================
# Enhanced Trainer Manager
# ==========================
class TrainerManager:
    def __init__(self, system_config: SystemConfig, model_loader: ModelLoader):
        self.cfg = system_config
        self.loader = model_loader

    def train(self, train_paths: List[str], extra_callbacks: List = None) -> Tuple[bool, str]:
        """Main training method with comprehensive error handling"""
        if extra_callbacks is None:
            extra_callbacks = []

        try:
            # Validate training files
            data, errors = read_jsonl_files_safe(train_paths, self.cfg)

            if errors:
                error_msg = "خطاهای فایل:\n" + "\n".join(errors[:5])  # Show first 5 errors
                return False, error_msg

            if len(data) < 10:
                return False, f"داده کافی نیست. حداقل 10 نمونه نیاز است (موجود: {len(data)})"

            # Set random seed
            set_seed_all(self.cfg.seed)

            # Split data
            train_data, val_data = train_test_split(
                data,
                test_size=self.cfg.train_test_ratio,
                random_state=self.cfg.seed
            )

            logger.info(f"Training samples: {len(train_data)}, Validation samples: {len(val_data)}")

            # Train based on architecture
            if self.cfg.model.architecture == "seq2seq":
                success, msg = self._train_seq2seq(train_data, val_data, extra_callbacks)
            else:
                success, msg = self._train_causal(train_data, val_data, extra_callbacks)

            if success:
                # Save configuration
                self._save_training_config()

            return success, msg

        except Exception as e:
            logger.error(f"Training failed: {e}")
            return False, f"خطا در آموزش: {str(e)}"

    def _train_seq2seq(self, train_data: List[Dict], val_data: List[Dict],
                       extra_callbacks: List) -> Tuple[bool, str]:
        """Train seq2seq model"""
        try:
            # Create datasets
            train_dataset = Seq2SeqJSONLDataset(
                train_data, self.loader.tokenizer,
                self.cfg.model.max_input_length,
                self.cfg.model.max_target_length
            )

            val_dataset = Seq2SeqJSONLDataset(
                val_data, self.loader.tokenizer,
                self.cfg.model.max_input_length,
                self.cfg.model.max_target_length
            )

            # Data collator
            data_collator = DataCollatorForSeq2Seq(
                tokenizer=self.loader.tokenizer,
                model=self.loader.model,
                padding=True
            )

            # predict_with_generate needs Seq2SeqTrainingArguments and
            # Seq2SeqTrainer; assigning the attribute on plain
            # TrainingArguments is silently ignored by Trainer
            training_args = self._get_training_args(seq2seq=True)

            trainer = Seq2SeqTrainer(
                model=self.loader.model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=val_dataset,
                data_collator=data_collator,
                tokenizer=self.loader.tokenizer,
                callbacks=self._get_callbacks(extra_callbacks)
            )

            # Train
            trainer.train()

            # Save model
            trainer.save_model(self.cfg.output_dir)
            self.loader.tokenizer.save_pretrained(self.cfg.output_dir)

            return True, "مدل Seq2Seq با موفقیت آموزش داده شد"

        except Exception as e:
            logger.error(f"Seq2Seq training failed: {e}")
            return False, f"خطا در آموزش Seq2Seq: {str(e)}"

    def _train_causal(self, train_data: List[Dict], val_data: List[Dict],
                      extra_callbacks: List) -> Tuple[bool, str]:
        """Train causal language model"""
        try:
            # Create datasets
            train_dataset = CausalJSONLDataset(
                train_data, self.loader.tokenizer,
                self.cfg.model.max_input_length
            )

            val_dataset = CausalJSONLDataset(
                val_data, self.loader.tokenizer,
                self.cfg.model.max_input_length
            )

            # Training arguments
            training_args = self._get_training_args()

            # Create trainer
            trainer = Trainer(
                model=self.loader.model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=val_dataset,
                tokenizer=self.loader.tokenizer,
                callbacks=self._get_callbacks(extra_callbacks)
            )

            # Train
            trainer.train()

            # Save model
            trainer.save_model(self.cfg.output_dir)
            self.loader.tokenizer.save_pretrained(self.cfg.output_dir)

            return True, "مدل Causal با موفقیت آموزش داده شد"

        except Exception as e:
            logger.error(f"Causal training failed: {e}")
            return False, f"خطا در آموزش Causal: {str(e)}"

    def _get_training_args(self, seq2seq: bool = False) -> TrainingArguments:
        """Get training arguments with optimized settings"""
        common = dict(
            output_dir=self.cfg.output_dir,
            num_train_epochs=self.cfg.epochs,
            learning_rate=self.cfg.lr,
            per_device_train_batch_size=self.cfg.batch_size,
            per_device_eval_batch_size=self.cfg.batch_size,
            gradient_accumulation_steps=self.cfg.grad_accum,
            warmup_ratio=0.05,
            weight_decay=0.01,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            save_total_limit=3,  # Keep more checkpoints
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            logging_steps=50,
            logging_dir=f"{self.cfg.output_dir}/logs",
            report_to="none",
            bf16=self.cfg.model.use_bf16 if torch.cuda.is_available() else False,
            fp16=(not self.cfg.model.use_bf16) if torch.cuda.is_available() else False,
            dataloader_drop_last=True,
            remove_unused_columns=False,
            gradient_checkpointing=True,  # Save memory
        )
        if seq2seq:
            return Seq2SeqTrainingArguments(
                predict_with_generate=True,
                generation_max_length=self.cfg.model.max_target_length,
                generation_num_beams=self.cfg.model.num_beams,
                **common,
            )
        return TrainingArguments(**common)

    def _get_callbacks(self, extra_callbacks: List) -> List:
        """Get training callbacks"""
        callbacks = [
            EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01)
        ]
        callbacks.extend(extra_callbacks)
        return callbacks

    def _save_training_config(self):
        """Save training configuration"""
        try:
            config_path = Path(self.cfg.output_dir) / "training_config.json"
            config_dict = self.cfg.dict()
            config_dict['training_timestamp'] = datetime.now().isoformat()
            config_dict['training_completed'] = True

            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(config_dict, f, ensure_ascii=False, indent=2)

            logger.info(f"Training config saved to {config_path}")
        except Exception as e:
            logger.warning(f"Failed to save training config: {e}")

# ==========================
# Enhanced Legal App
# ==========================
class LegalApp:
    def __init__(self, system_config: Optional[SystemConfig] = None):
        self.cfg = system_config or SystemConfig()
        self.rag = LegalRAGSystem(self.cfg)
        self.formalizer: Optional[Formalizer] = None
        self._current_loader: Optional[ModelLoader] = None
        self._current_generator: Optional[UnifiedGenerator] = None
        self._lock = threading.Lock()

    def _ensure_model(self, model_name: str, architecture: str) -> Tuple[bool, str]:
        """Ensure model is loaded with thread safety"""
        with self._lock:
            try:
                # Update config
                self.cfg.model.model_name = model_name
                self.cfg.model.architecture = architecture

                # Get model from cache
                self._current_loader = ModelCache.get_model(model_name, architecture, self.cfg.model)
                self._current_generator = UnifiedGenerator(self._current_loader)

                return True, f"مدل بارگذاری شد: {model_name} ({architecture})"

            except Exception as e:
                logger.error(f"Model loading failed: {e}")
                return False, f"خطا در بارگذاری مدل: {str(e)}"

    def _ensure_rag(self) -> Tuple[bool, str]:
        """Ensure RAG system is ready"""
        try:
            self.rag.setup_embedding()
            success, message = self.rag.load_chroma()
            return success, message
        except Exception as e:
            logger.error(f"RAG setup failed: {e}")
            return False, f"خطا در راه‌اندازی RAG: {str(e)}"

    def _ensure_formalizer(self) -> str:
        """Ensure formalizer is ready"""
        try:
            if not self.formalizer:
                self.formalizer = Formalizer()
            return "Formalizer آماده است."
        except Exception as e:
            logger.error(f"Formalizer setup failed: {e}")
            return f"خطا در راه‌اندازی Formalizer: {str(e)}"

    # Event handlers
    def handle_load_model(self, model_choice: str, use_rag: bool) -> str:
        """Handle model loading"""
        try:
            model_configs = self._get_model_configs()
            if model_choice not in model_configs:
                return "مدل نامعتبر انتخاب شده"

            model_name, architecture = model_configs[model_choice]

            # Load model
            success, model_msg = self._ensure_model(model_name, architecture)
            if not success:
                return model_msg

            # Setup RAG if requested
            if use_rag:
                rag_success, rag_msg = self._ensure_rag()
                rag_msg = f"\nRAG: {rag_msg}"
            else:
                rag_msg = "\nRAG: غیر فعال"

            return f"{model_msg}{rag_msg}"

        except Exception as e:
            logger.error(f"Model loading handler failed: {e}")
            return f"خطا در بارگذاری: {str(e)}"

    def handle_generate_response(self, question: str, use_rag: bool, use_formalizer: bool,
                                 max_new_tokens: int, temperature: float, top_p: float,
                                 num_beams: int) -> Tuple[str, str, str]:
        """Handle response generation; returns (response, references, metrics)"""
        if not question or not question.strip():
            return "لطفاً سوال خود را وارد کنید.", "", ""

        if not self._current_generator:
            return "ابتدا مدل را بارگذاری کنید.", "", ""

        start_time = time.time()

        try:
            # Update generation parameters
            self.cfg.model.max_new_tokens = max(32, min(1024, int(max_new_tokens)))
            self.cfg.model.temperature = max(0.1, min(2.0, float(temperature)))
            self.cfg.model.top_p = max(0.1, min(1.0, float(top_p)))
            self.cfg.model.num_beams = max(1, min(8, int(num_beams)))

            # Apply input formalization if requested
            processed_question = question
            if use_formalizer:
                formalizer_msg = self._ensure_formalizer()
                if "خطا" not in formalizer_msg and self.formalizer:
                    processed_question = self.formalizer.formalize(question)

            # Retrieve relevant articles if RAG is enabled
            articles = []
            if use_rag and self.rag.collection:
                articles = self.rag.retrieve(processed_question)

            # Build context
            context = LegalRAGSystem.build_context(articles) if articles else ""

            # Generate response
            response, error = self._current_generator.generate(processed_question, context)

            # Build references section
            references = ""
            if articles:
                ref_parts = []
                for article in articles[:3]:  # Limit to top 3 references
                    ref_parts.append(
                        f"**ماده {article['article_id']}** (شباهت: {article['similarity']:.2f})\n"
                        f"{article['text'][:400]}{'...' if len(article['text']) > 400 else ''}"
                    )
                references = "\n\n".join(ref_parts)

            # Generate metrics info
            elapsed_time = time.time() - start_time
            metrics_info = f"زمان پردازش: {elapsed_time:.2f}s"
            if articles:
                metrics_info += f" | مواد یافت شده: {len(articles)}"
            if use_formalizer:
                metrics_info += " | فرمالایزر فعال"

            return response, references, metrics_info

        except Exception as e:
            logger.error(f"Response generation failed: {e}")
            error_time = time.time() - start_time
            metrics.record_request(error_time, success=False)
            return f"خطا در تولید پاسخ: {str(e)}", "", f"خطا پس از {error_time:.2f}s"

    def handle_training(self, model_choice: str, uploaded_files, use_rag_training: bool,
                        epochs: int, batch_size: int, learning_rate: float,
                        progress: gr.Progress = gr.Progress()) -> str:
        """Handle model training. Gradio injects `progress` at call time
        because the default value is a gr.Progress instance, so it is not
        wired as a UI input."""
        try:
            # Validate inputs
            if not uploaded_files:
                return "لطفاً فایل‌های آموزشی را بارگذاری کنید."

            # Get model config
            model_configs = self._get_model_configs()
            if model_choice not in model_configs:
                return "مدل نامعتبر انتخاب شده"

            model_name, architecture = model_configs[model_choice]

            # Load model for training
            success, msg = self._ensure_model(model_name, architecture)
            if not success:
                return f"خطا در بارگذاری مدل: {msg}"

            # Update training config
            self.cfg.epochs = max(1, min(10, int(epochs)))
            self.cfg.batch_size = max(1, min(16, int(batch_size)))
            self.cfg.lr = max(1e-6, min(1e-3, float(learning_rate)))

            # Setup RAG if requested
            if use_rag_training:
                rag_success, rag_msg = self._ensure_rag()
                if not rag_success:
                    logger.warning(f"RAG setup failed for training: {rag_msg}")

            # Get file paths (gr.File with type="filepath" returns list[str])
            file_paths = uploaded_files

            if not file_paths:
                return "فایل‌های معتبر یافت نشد."

            # Create trainer
            trainer_manager = TrainerManager(self.cfg, self._current_loader)

            # Create progress callback
            progress_callback = GradioProgressCallback(progress)

            # Start training
            success, result_msg = trainer_manager.train(file_paths, [progress_callback])

            if success:
                # Clear model cache to force reload of trained model
                ModelCache.clear_cache()
                return f"✅ {result_msg}\nمدل در مسیر '{self.cfg.output_dir}' ذخیره شد."
            else:
                return f"❌ {result_msg}"

        except Exception as e:
            logger.error(f"Training handler failed: {e}")
            return f"خطا در آموزش: {str(e)}"

    def get_system_status(self) -> str:
        """Get system status information"""
        try:
            status_parts = []

            # Model status
            if self._current_loader:
                status_parts.append(f"✅ مدل فعال: {self.cfg.model.model_name}")
            else:
                status_parts.append("❌ مدل بارگذاری نشده")

            # RAG status
            if self.rag.collection:
                doc_count = self.rag.collection.count()
                status_parts.append(f"✅ RAG فعال ({doc_count} سند)")
            else:
                status_parts.append("❌ RAG غیر فعال")

            # System metrics
            sys_metrics = metrics.get_metrics()
            status_parts.append(f"📊 درخواست‌ها: {sys_metrics['requests_total']}")
            status_parts.append(f"📈 نرخ موفقیت: {sys_metrics['success_rate']:.1f}%")
            status_parts.append(f"⏱️ زمان متوسط: {sys_metrics['avg_response_time']}s")

            if torch.cuda.is_available():
                memory_mb = torch.cuda.memory_allocated() / 1024 / 1024
                status_parts.append(f"🖥️ حافظه GPU: {memory_mb:.1f} MB")

            return "\n".join(status_parts)

        except Exception as e:
            return f"خطا در دریافت وضعیت: {str(e)}"

    def _get_model_configs(self) -> Dict[str, Tuple[str, str]]:
        """Get available model configurations"""
        return {
            "Seq2Seq (parsi-t5-base)": ("persiannlp/parsi-t5-base", "seq2seq"),
            "Seq2Seq (mt5-base)": ("google/mt5-base", "seq2seq"),
            "Causal (Mistral-7B)": ("mistralai/Mistral-7B-Instruct-v0.2", "causal"),
            "Causal (PersianMind-v1.0)": ("universitytehran/PersianMind-v1.0", "causal"),
        }

    def build_ui(self) -> gr.Blocks:
        """Build enhanced Gradio interface"""
        model_choices = list(self._get_model_configs().keys())

        with gr.Blocks(
            title="ماحون — مشاور حقوقی هوشمند",
            theme=gr.themes.Soft(),
            css="""
            .status-box { font-family: 'Courier New', monospace; font-size: 12px; }
            .metrics-box { background-color: #f0f0f0; padding: 10px; border-radius: 5px; }
            """
        ) as app:

            gr.HTML("""
            <div style='text-align: center; margin-bottom: 20px;'>
                <h1>ماحون — مشاور حقوقی هوشمند 🏛️</h1>
                <p>سیستم پیشرفته مشاوره حقوقی با قابلیت RAG، Fine-tuning و هوش مصنوعی</p>
            </div>
            """)

            # System Status
            with gr.Accordion("وضعیت سیستم", open=False):
                system_status = gr.Markdown(
                    value=self.get_system_status(),
                    elem_classes=["status-box"]
                )
                refresh_status_btn = gr.Button("🔄 بروزرسانی وضعیت", size="sm")

            with gr.Tabs():
                # Consultation Tab
                with gr.Tab("💬 مشاوره"):
                    with gr.Row():
                        with gr.Column(scale=2):
                            model_dropdown = gr.Dropdown(
                                choices=model_choices,
                                value=model_choices[0],
                                label="انتخاب مدل",
                                info="نوع مدل مورد نظر را انتخاب کنید"
                            )
                        with gr.Column(scale=1):
                            use_rag_checkbox = gr.Checkbox(
                                value=True,
                                label="استفاده از RAG",
                                info="بازیابی مواد قانونی مرتبط"
                            )
                            use_formalizer_checkbox = gr.Checkbox(
                                value=False,
                                label="رسمی‌سازی ورودی",
                                info="تبدیل متن غیررسمی به رسمی"
                            )

                    load_model_btn = gr.Button("🚀 بارگذاری مدل/RAG", variant="primary", size="lg")
                    load_status = gr.Textbox(
                        label="وضعیت بارگذاری",
                        interactive=False,
                        elem_classes=["status-box"]
                    )

                    # Generation Parameters
                    with gr.Accordion("⚙️ پارامترهای تولید", open=False):
                        with gr.Row():
                            max_new_tokens = gr.Slider(
                                minimum=32, maximum=1024, value=self.cfg.model.max_new_tokens,
                                step=16, label="حداکثر توکن‌های جدید"
                            )
                            temperature = gr.Slider(
                                minimum=0.1, maximum=2.0, value=self.cfg.model.temperature,
                                step=0.05, label="دما (خلاقیت)"
                            )
                        with gr.Row():
                            top_p = gr.Slider(
                                minimum=0.1, maximum=1.0, value=self.cfg.model.top_p,
                                step=0.05, label="Top-p (تنوع)"
                            )
                            num_beams = gr.Slider(
                                minimum=1, maximum=8, value=self.cfg.model.num_beams,
                                step=1, label="تعداد Beam"
                            )

                    # Input/Output
                    with gr.Row():
                        with gr.Column(scale=1):
                            question_input = gr.Textbox(
                                label="سوال حقوقی خود را وارد کنید",
                                placeholder="مثال: شرایط فسخ قرارداد اجاره چیست؟",
                                lines=3
                            )
                            submit_btn = gr.Button("🔍 دریافت پاسخ", variant="primary")
                        with gr.Column(scale=1):
                            response_output = gr.Textbox(
                                label="پاسخ سیستم",
                                lines=8,
                                interactive=False
                            )
                            references_output = gr.Textbox(
                                label="مراجع حقوقی مرتبط",
                                lines=6,
                                interactive=False
                            )
                            metrics_output = gr.Textbox(
                                label="معیارهای عملکرد",
                                lines=1,
                                interactive=False,
                                elem_classes=["metrics-box"]
                            )

                # Training Tab
                with gr.Tab("🎓 آموزش مدل"):
                    with gr.Row():
                        with gr.Column(scale=1):
                            train_model_dropdown = gr.Dropdown(
                                choices=model_choices,
                                value=model_choices[0],
                                label="انتخاب مدل برای آموزش"
                            )
                            use_rag_training_checkbox = gr.Checkbox(
                                value=True,
                                label="استفاده از RAG در آموزش",
                                info="استفاده از مواد قانونی در آموزش"
                            )
                            train_file_upload = gr.File(
                                label="بارگذاری فایل‌های آموزشی (JSONL)",
                                file_types=[".jsonl"],
                                type="filepath",
                                file_count="multiple"
                            )
                        with gr.Column(scale=1):
                            with gr.Accordion("⚙️ پارامترهای آموزش", open=False):
                                train_epochs = gr.Slider(
                                    minimum=1, maximum=10, value=self.cfg.epochs,
                                    step=1, label="تعداد Epoch"
                                )
                                train_batch_size = gr.Slider(
                                    minimum=1, maximum=16, value=self.cfg.batch_size,
                                    step=1, label="اندازه Batch"
                                )
                                train_lr = gr.Slider(
                                    minimum=1e-6, maximum=1e-3, value=self.cfg.lr,
                                    step=1e-5, label="نرخ یادگیری"
                                )

                            train_btn = gr.Button("🎯 شروع آموزش", variant="primary")
                            train_status = gr.Textbox(
                                label="وضعیت آموزش",
                                interactive=False,
                                elem_classes=["status-box"]
                            )
                            # Note: gr.Progress is not a layout component; it is
                            # injected into handle_training via its default argument

            # Event handlers (direct method references; the previous lambdas
            # only forwarded their arguments unchanged)
            load_model_btn.click(
                fn=self.handle_load_model,
                inputs=[model_dropdown, use_rag_checkbox],
                outputs=load_status
            )

            submit_btn.click(
                fn=self.handle_generate_response,
                inputs=[
                    question_input,
                    use_rag_checkbox,
                    use_formalizer_checkbox,
                    max_new_tokens,
                    temperature,
                    top_p,
                    num_beams
                ],
                outputs=[response_output, references_output, metrics_output]
            )

            refresh_status_btn.click(
                fn=self.get_system_status,
                outputs=system_status
            )

            train_btn.click(
                fn=self.handle_training,
                inputs=[
                    train_model_dropdown,
                    train_file_upload,
                    use_rag_training_checkbox,
                    train_epochs,
                    train_batch_size,
                    train_lr
                ],
                outputs=train_status
            )

        return app

# ==========================
# Main Application
# ==========================
def main():
    """Main entry point for the application"""
    # Initialize system
    app = LegalApp()

    # Build and launch UI
    ui = app.build_ui()
    ui.launch(
        server_name="0.0.0.0",
        server_port=7860,
        inbrowser=True,
        share=False
    )

if __name__ == "__main__":
    main()