import re
import warnings
import zlib
from typing import List, Dict, Any, Optional

import numpy as np
import torch
import torch.nn as nn
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import AutoConfig

from preprocessor import preprocess_for_classification
|
|
|
|
class LSTMClassifier(nn.Module):
    """LSTM-based Arabic text classifier.

    Pipeline: token ids -> embedding -> (optionally bidirectional) stacked
    LSTM -> dropout -> linear layer producing one logit per class.

    The submodule names (``embedding``, ``lstm``, ``fc``, ``dropout``) are
    part of the checkpoint format: state-dict keys are derived from them, so
    they must not be renamed.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, num_layers=2, bidirectional=False):
        """Build the network.

        Args:
            vocab_size: Number of rows in the embedding table.
            embedding_dim: Size of each token embedding.
            hidden_dim: Hidden size of every LSTM layer (per direction).
            output_dim: Number of output logits (classes).
            num_layers: Number of stacked LSTM layers.
            bidirectional: If True, run the LSTM in both directions and feed
                the concatenated final states to the linear layer.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.bidirectional = bidirectional
        # nn.LSTM applies dropout only *between* stacked layers; asking for it
        # on a single-layer LSTM has no effect and makes PyTorch emit a
        # UserWarning, so it is only enabled when there is a layer boundary.
        inter_layer_dropout = 0.3 if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=self.bidirectional,
        )
        # A bidirectional LSTM yields one final state per direction.
        fc_input_dim = hidden_dim * 2 if self.bidirectional else hidden_dim
        self.fc = nn.Linear(fc_input_dim, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return class logits for a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids, batch dimension first
                (the LSTM is built with ``batch_first=True``).

        Returns:
            Tensor of logits with one row per input sequence and
            ``output_dim`` columns.
        """
        embedded = self.embedding(x)
        # Only the final hidden state is used; per-step outputs and the cell
        # state are discarded.
        _, (hidden, _) = self.lstm(embedded)
        if self.bidirectional:
            # The last two entries of ``hidden`` are the forward and backward
            # final states of the top layer.
            features = torch.cat((hidden[-2], hidden[-1]), dim=1)
        else:
            features = hidden[-1]
        return self.fc(self.dropout(features))
|
|
|
|
class ModernClassifier:
    """Modern Arabic text classifier supporting BERT and LSTM models.

    The classifier loads either a BERT sequence-classification checkpoint
    (safetensors) or an LSTM checkpoint (``.pth``) and exposes single-text and
    batch prediction over a fixed set of six topic labels.

    Attributes set by ``__init__``:
        model_type: Lower-cased model kind, ``'bert'`` or ``'lstm'``.
        model_path: Path of the checkpoint file.
        config_path: Stored as given; not consulted by the loaders below.
        device: CUDA device when available, otherwise CPU.
        classes: NumPy array of label names; index == model output index.
        model_name: ``f"{model_type}_classifier"`` using the argument as given.
        model: The loaded network, in eval mode, on ``device``.
        tokenizer: (BERT only) the selected Hugging Face tokenizer.
        vocab: (LSTM only) token -> id mapping from the checkpoint, or ``{}``.
    """

    # Tokenizers are tried in this order; the first loadable one whose
    # vocabulary is small enough wins.
    _TOKENIZER_CANDIDATES = (
        'asafaya/bert-base-arabic',
        'aubmindlab/bert-base-arabertv02',
        'aubmindlab/bert-base-arabertv2',
    )
    _MAX_TOKENIZER_VOCAB = 32000
    _CONFIG_SOURCE = 'aubmindlab/bert-base-arabertv2'
    _BERT_MAX_LENGTH = 512
    _LSTM_MAX_LENGTH = 100
    # Bucket count for hashed token ids when the embedding size is unknown.
    _DEFAULT_HASH_BUCKETS = 10000

    def __init__(self, model_type: str, model_path: str, config_path: Optional[str] = None):
        """Load the requested model.

        Args:
            model_type: ``'bert'`` or ``'lstm'`` (case-insensitive).
            model_path: Checkpoint file to load.
            config_path: Optional path kept on the instance for callers.

        Raises:
            ValueError: If ``model_type`` is not supported.
            RuntimeError: If the checkpoint cannot be loaded.
        """
        self.model_type = model_type.lower()
        self.model_path = model_path
        self.config_path = config_path
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.classes = np.array(['culture', 'economy', 'international', 'local', 'religion', 'sports'])

        if self.model_type == 'bert':
            self._load_bert_model()
        elif self.model_type == 'lstm':
            self._load_lstm_model()
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

        self.model_name = f"{model_type}_classifier"

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def _find_tokenizer(self, local_files_only: bool):
        """Return the first usable candidate tokenizer, or None.

        A candidate is usable when it loads and its vocabulary has at most
        ``_MAX_TOKENIZER_VOCAB`` entries. Load failures are skipped on purpose
        (best effort across several candidates).
        """
        verb = "Using" if local_files_only else "Downloaded"
        for tokenizer_name in self._TOKENIZER_CANDIDATES:
            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    tokenizer_name, local_files_only=local_files_only
                )
                vocab_size = len(tokenizer.vocab)
            except Exception:
                continue
            if vocab_size <= self._MAX_TOKENIZER_VOCAB:
                print(f"{verb} tokenizer: {tokenizer_name} (vocab size: {vocab_size})")
                return tokenizer
        return None

    def _build_bert_config(self, vocab_size: int):
        """Build a model config sized to the checkpoint's vocabulary.

        Tries the reference config from the local cache, then from the hub,
        and finally falls back to a hand-written BERT-base configuration.
        """
        num_labels = len(self.classes)
        for extra_kwargs in ({'local_files_only': True}, {}):
            try:
                return AutoConfig.from_pretrained(
                    self._CONFIG_SOURCE,
                    num_labels=num_labels,
                    vocab_size=vocab_size,
                    **extra_kwargs,
                )
            except Exception:
                continue

        from transformers import BertConfig
        return BertConfig(
            vocab_size=vocab_size,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            num_labels=num_labels,
        )

    @staticmethod
    def _warn_on_key_mismatch(load_result, label: str) -> None:
        """Warn when a non-strict ``load_state_dict`` skipped or ignored keys."""
        missing = list(load_result.missing_keys)
        unexpected = list(load_result.unexpected_keys)
        if missing or unexpected:
            warnings.warn(
                f"{label} checkpoint does not fully match the model: "
                f"{len(missing)} missing key(s) {missing[:5]}, "
                f"{len(unexpected)} unexpected key(s) {unexpected[:5]}",
                RuntimeWarning,
                stacklevel=3,
            )

    def _load_bert_model(self):
        """Load BERT model from safetensors.

        Selects a tokenizer (local cache first, then download), reads the
        checkpoint, sizes the model's vocabulary from the checkpoint's word
        embedding matrix and loads the weights non-strictly.

        Raises:
            RuntimeError: On any failure; the original error is chained.
        """
        try:
            self.tokenizer = self._find_tokenizer(local_files_only=True)
            if self.tokenizer is None:
                self.tokenizer = self._find_tokenizer(local_files_only=False)
            if self.tokenizer is None:
                raise RuntimeError("No compatible Arabic BERT tokenizer found with 32K vocabulary")

            state_dict = load_file(self.model_path)
            embed_key = next(
                (k for k in state_dict if 'embeddings.word_embeddings.weight' in k),
                None,
            )
            if embed_key is None:
                raise ValueError(
                    "checkpoint contains no 'embeddings.word_embeddings.weight' tensor"
                )
            checkpoint_vocab_size = state_dict[embed_key].shape[0]

            config = self._build_bert_config(checkpoint_vocab_size)
            self.model = AutoModelForSequenceClassification.from_config(config)
            self.model.resize_token_embeddings(checkpoint_vocab_size)
            load_result = self.model.load_state_dict(state_dict, strict=False)
            self._warn_on_key_mismatch(load_result, "BERT")
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Error loading BERT model: {e}") from e

    def _load_lstm_model(self):
        """Load LSTM model from .pth file.

        The architecture (vocabulary size, embedding size, hidden size, number
        of layers and direction) is inferred from the tensors in the
        checkpoint. The checkpoint may be a bare state dict or a dict holding
        ``'model_state_dict'`` and optionally ``'vocab'``.

        Raises:
            RuntimeError: On any failure; the original error is chained.
        """
        try:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(self.model_path, map_location=self.device)
            state_dict = checkpoint.get('model_state_dict', checkpoint)

            vocab_size, embedding_dim = state_dict['embedding.weight'].shape
            _, hidden_dim = state_dict['lstm.weight_hh_l0'].shape

            # One forward-direction input weight exists per stacked layer.
            layer_pattern = re.compile(r'lstm\.weight_ih_l(\d+)$')
            layer_ids = {
                int(found.group(1))
                for found in map(layer_pattern.match, state_dict)
                if found
            }
            num_layers = len(layer_ids)

            # Reverse-direction weights are only present for bidirectional
            # LSTMs; assuming bidirectional unconditionally makes the layer
            # shapes disagree with unidirectional checkpoints.
            bidirectional = any(
                key.startswith('lstm.') and key.endswith('_reverse')
                for key in state_dict
            )

            output_dim = len(self.classes)
            self.model = LSTMClassifier(vocab_size, embedding_dim, hidden_dim,
                                        output_dim, num_layers=num_layers,
                                        bidirectional=bidirectional)
            load_result = self.model.load_state_dict(state_dict, strict=False)
            self._warn_on_key_mismatch(load_result, "LSTM")
            self.model.to(self.device)
            self.model.eval()

            self.vocab = checkpoint.get('vocab', {})
            # Hashed fallback ids must stay inside the embedding table.
            self._hash_buckets = int(vocab_size)
            if not self.vocab:
                warnings.warn(
                    "LSTM checkpoint has no 'vocab' mapping; token ids will be "
                    "derived from a hash and are unlikely to match training",
                    RuntimeWarning,
                    stacklevel=2,
                )
        except Exception as e:
            raise RuntimeError(f"Error loading LSTM model: {e}") from e

    # ------------------------------------------------------------------
    # Encoding
    # ------------------------------------------------------------------

    def _encode_for_bert(self, cleaned) -> Dict[str, torch.Tensor]:
        """Tokenize already-cleaned text (a string or a list of strings).

        Token ids that fall outside the model's embedding table are clamped
        into range so that a tokenizer/model vocabulary mismatch does not
        crash the forward pass.
        """
        inputs = self.tokenizer(
            cleaned,
            return_tensors='pt',
            truncation=True,
            padding=True,
            max_length=self._BERT_MAX_LENGTH,
        )

        input_ids = inputs['input_ids']
        model_vocab_size = self.model.config.vocab_size
        if input_ids.max().item() >= model_vocab_size:
            inputs['input_ids'] = torch.clamp(input_ids, 0, model_vocab_size - 1)

        return {key: value.to(self.device) for key, value in inputs.items()}

    def _tokens_to_indices(self, tokens: List[str]) -> List[int]:
        """Map tokens to exactly ``_LSTM_MAX_LENGTH`` ids (truncate / zero-pad).

        With a vocabulary, unknown tokens map to 0. Without one, ids come from
        a CRC32 of the UTF-8 token reduced modulo the embedding size: unlike
        the built-in ``hash``, this is stable across processes and always a
        valid embedding index.
        """
        limit = self._LSTM_MAX_LENGTH
        kept = tokens[:limit]
        vocab = getattr(self, 'vocab', None)
        if vocab:
            indices = [vocab.get(token, 0) for token in kept]
        else:
            buckets = getattr(self, '_hash_buckets', self._DEFAULT_HASH_BUCKETS)
            indices = [zlib.crc32(token.encode('utf-8')) % buckets for token in kept]
        indices.extend([0] * (limit - len(indices)))
        return indices

    def _preprocess_text_for_bert(self, text: str) -> Dict[str, torch.Tensor]:
        """Preprocess text for BERT model.

        Returns the tokenizer output for the cleaned text as tensors on
        ``self.device``.
        """
        cleaned_text = preprocess_for_classification(text)
        return self._encode_for_bert(cleaned_text)

    def _preprocess_text_for_lstm(self, text: str) -> torch.Tensor:
        """Preprocess text for LSTM model.

        Returns a ``(1, _LSTM_MAX_LENGTH)`` long tensor of token ids on
        ``self.device``; tokens are the whitespace-separated cleaned words.
        """
        cleaned_text = preprocess_for_classification(text)
        indices = self._tokens_to_indices(cleaned_text.split())
        return torch.tensor([indices], dtype=torch.long).to(self.device)

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------

    def _predict_probabilities(self, cleaned_texts: List[str]) -> np.ndarray:
        """Run the model on cleaned texts; return one probability row per text."""
        with torch.no_grad():
            if self.model_type == 'bert':
                inputs = self._encode_for_bert(cleaned_texts)
                logits = self.model(**inputs).logits
            elif self.model_type == 'lstm':
                batch_indices = [
                    self._tokens_to_indices(cleaned.split())
                    for cleaned in cleaned_texts
                ]
                batch_tensor = torch.tensor(batch_indices, dtype=torch.long).to(self.device)
                logits = self.model(batch_tensor)
            else:
                raise ValueError(f"Unsupported model type: {self.model_type}")

            return torch.softmax(logits, dim=-1).cpu().numpy()

    def _build_result(self, probabilities: np.ndarray, cleaned_text: str) -> Dict[str, Any]:
        """Assemble the prediction dictionary for one probability vector."""
        prediction_index = int(np.argmax(probabilities))
        prob_distribution = {
            str(class_label): float(probability)
            for class_label, probability in zip(self.classes, probabilities)
        }
        return {
            "prediction": str(self.classes[prediction_index]),
            "prediction_index": prediction_index,
            "confidence": float(probabilities[prediction_index]),
            "probability_distribution": prob_distribution,
            "cleaned_text": cleaned_text,
            "model_used": self.model_name,
            "prediction_metadata": {
                "max_probability": float(np.max(probabilities)),
                "min_probability": float(np.min(probabilities)),
                # Small epsilon keeps log() finite for zero probabilities.
                "entropy": float(-np.sum(probabilities * np.log(probabilities + 1e-10))),
                "num_classes": len(probabilities),
                "model_type": self.model_type,
                "device": str(self.device)
            },
        }

    def predict(self, text: str) -> Dict[str, Any]:
        """Predict class with full probability distribution and metadata.

        Args:
            text: Raw input text.

        Returns:
            Dict with ``prediction``, ``prediction_index``, ``confidence``,
            ``probability_distribution``, ``cleaned_text``, ``model_used`` and
            ``prediction_metadata``.
        """
        cleaned_text = preprocess_for_classification(text)
        probabilities = self._predict_probabilities([cleaned_text])
        return self._build_result(probabilities[0], cleaned_text)

    def predict_batch(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Predict classes for multiple texts using true batch processing.

        Args:
            texts: Raw input texts.

        Returns:
            One result dict per input (same layout as ``predict``), in input
            order; an empty list for empty input.
        """
        if not texts:
            return []

        cleaned_texts = [preprocess_for_classification(text) for text in texts]
        probabilities = self._predict_probabilities(cleaned_texts)
        return [
            self._build_result(probs, cleaned_text)
            for probs, cleaned_text in zip(probabilities, cleaned_texts)
        ]

    def get_model_info(self) -> Dict[str, Any]:
        """Get model information and capabilities."""
        return {
            "model_name": self.model_name,
            "model_type": self.model_type,
            "model_path": self.model_path,
            "num_classes": len(self.classes),
            "classes": self.classes.tolist(),
            "device": str(self.device),
            "has_predict_proba": True,
            "framework": "pytorch",
            "modern_model": True
        }
|
|