""" 🧠 ESG Model Integration Module Connects the trained model with the Gradio application This module provides the bridge between the trained ESG classifier and the web application interface. """ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from typing import Dict, List, Optional, Tuple from pathlib import Path from dataclasses import dataclass import warnings warnings.filterwarnings('ignore') @dataclass class ModelConfig: """Configuration for ESG model""" embed_dim: int = 4096 n_labels: int = 4 hidden_dim: int = 512 dropout: float = 0.1 labels: List[str] = None thresholds: Dict[str, float] = None def __post_init__(self): self.labels = ['E', 'S', 'G', 'non_ESG'] # Optimized thresholds from training self.thresholds = { 'E': 0.352, 'S': 0.456, 'G': 0.398, 'non_ESG': 0.512 } class MLPClassifier(nn.Module): """ Shallow MLP classifier matching the training architecture. Architecture: embed_dim -> 512 -> n_labels """ def __init__(self, config: ModelConfig): super().__init__() self.config = config self.net = nn.Sequential( nn.Linear(config.embed_dim, config.hidden_dim), nn.BatchNorm1d(config.hidden_dim), nn.ReLU(), nn.Dropout(config.dropout), nn.Linear(config.hidden_dim, config.n_labels), ) self._init_weights() def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x) class ESGModelInference: """ Production-ready ESG model inference class. Handles embedding extraction and classification. """ def __init__( self, model_path: Optional[str] = None, embedding_model_name: str = "Qwen/Qwen3-Embedding-8B", device: str = "auto", use_fp16: bool = True, ): self.config = ModelConfig() # Set device if device == "auto": self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: self.device = torch.device(device) self.use_fp16 = use_fp16 and self.device.type == "cuda" self.embedding_model = None self.tokenizer = None self.classifier = None self.scaler = None # Load models if path provided if model_path: self.load_models(model_path, embedding_model_name) def load_embedding_model(self, model_name: str): """Load the embedding model (Qwen3-Embedding-8B)""" try: from transformers import AutoTokenizer, AutoModel print(f"Loading embedding model: {model_name}") self.tokenizer = AutoTokenizer.from_pretrained( model_name, padding_side='left', trust_remote_code=True, ) dtype = torch.float16 if self.use_fp16 else torch.float32 self.embedding_model = AutoModel.from_pretrained( model_name, torch_dtype=dtype, trust_remote_code=True, ).to(self.device) self.embedding_model.eval() print(f"✅ Embedding model loaded on {self.device}") except Exception as e: print(f"⚠️ Could not load embedding model: {e}") self.embedding_model = None def load_classifier(self, model_path: str): """Load the trained classifier weights""" try: self.classifier = MLPClassifier(self.config).to(self.device) state_dict = torch.load(model_path, map_location=self.device) self.classifier.load_state_dict(state_dict) self.classifier.eval() print(f"✅ Classifier loaded from {model_path}") except Exception as e: print(f"⚠️ Could not load classifier: {e}") self.classifier = None def load_models(self, model_path: str, embedding_model_name: str): """Load all models""" self.load_embedding_model(embedding_model_name) self.load_classifier(model_path) @torch.no_grad() def extract_embedding(self, text: str, instruction: str = None) -> torch.Tensor: """Extract embedding for a single text""" if self.embedding_model is None or self.tokenizer is None: raise RuntimeError("Embedding model not loaded") if instruction is None: instruction = ( "Instruct: Classify the following text into ESG categories: " "Environmental, Social, Governance, or non-ESG.\nQuery: " ) full_text = instruction + text encoded = self.tokenizer( [full_text], padding=True, truncation=True, max_length=512, return_tensors='pt', ).to(self.device) outputs = self.embedding_model(**encoded) # Last token pooling (Qwen3-Embedding style) attention_mask = encoded['attention_mask'] last_hidden_states = outputs.last_hidden_state left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) if left_padding: embedding = last_hidden_states[:, -1] else: seq_lens = attention_mask.sum(dim=1) - 1 batch_size = last_hidden_states.shape[0] embedding = last_hidden_states[ torch.arange(batch_size, device=self.device), seq_lens ] # L2 normalize embedding = F.normalize(embedding, p=2, dim=1) return embedding.float().cpu() @torch.no_grad() def predict(self, embedding: torch.Tensor) -> Dict: """Run classification on embedding""" if self.classifier is None: raise RuntimeError("Classifier not loaded") embedding = embedding.to(self.device) logits = self.classifier(embedding) probs = torch.sigmoid(logits).cpu().numpy()[0] # Apply thresholds predictions = [] scores = {} for i, label in enumerate(self.config.labels): scores[label] = float(probs[i]) if probs[i] >= self.config.thresholds[label]: predictions.append(label) # Default to non_ESG if no predictions if not predictions: predictions = ['non_ESG'] return { 'scores': scores, 'predictions': predictions, 'confidence': np.mean([scores[p] for p in predictions]), } def classify(self, text: str) -> Dict: """Full pipeline: text -> embedding -> classification""" embedding = self.extract_embedding(text) return self.predict(embedding) def batch_classify(self, texts: List[str], batch_size: int = 8) -> List[Dict]: """Classify multiple texts efficiently""" results = [] for i in range(0, len(texts), batch_size): batch_texts = texts[i:i + batch_size] for text in batch_texts: try: result = self.classify(text) except Exception as e: result = { 'scores': {l: 0.0 for l in self.config.labels}, 'predictions': ['non_ESG'], 'confidence': 0.0, 'error': str(e), } results.append(result) return results class LogisticRegressionEnsemble: """ Logistic Regression ensemble classifier (matches training approach). For use when the full embedding model isn't available. """ def __init__(self, model_dir: Optional[str] = None): self.config = ModelConfig() self.models = {} self.scaler = None if model_dir: self.load(model_dir) def load(self, model_dir: str): """Load trained logistic regression models""" import joblib model_dir = Path(model_dir) # Load scaler scaler_path = model_dir / 'scaler.joblib' if scaler_path.exists(): self.scaler = joblib.load(scaler_path) # Load per-class models for label in self.config.labels: model_path = model_dir / f'lr_{label}.joblib' if model_path.exists(): self.models[label] = joblib.load(model_path) def predict(self, embedding: np.ndarray) -> Dict: """Predict on pre-computed embedding""" if self.scaler: embedding = self.scaler.transform(embedding.reshape(1, -1)) scores = {} predictions = [] for label in self.config.labels: if label in self.models: prob = self.models[label].predict_proba(embedding)[0, 1] scores[label] = float(prob) if prob >= self.config.thresholds[label]: predictions.append(label) else: scores[label] = 0.0 if not predictions: predictions = ['non_ESG'] return { 'scores': scores, 'predictions': predictions, 'confidence': np.mean([scores[p] for p in predictions]), } # ═══════════════════════════════════════════════════════════════════════════════ # UTILITY FUNCTIONS # ═══════════════════════════════════════════════════════════════════════════════ def save_models_for_deployment( classifier: nn.Module, scaler, lr_models: Dict, output_dir: str, ): """Save all models for deployment""" import joblib output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) # Save PyTorch classifier torch.save( classifier.state_dict(), output_dir / 'mlp_classifier.pt' ) # Save scaler if scaler is not None: joblib.dump(scaler, output_dir / 'scaler.joblib') # Save LR models for label, model in lr_models.items(): joblib.dump(model, output_dir / f'lr_{label}.joblib') # Save config config = ModelConfig() config_dict = { 'embed_dim': config.embed_dim, 'n_labels': config.n_labels, 'hidden_dim': config.hidden_dim, 'dropout': config.dropout, 'labels': config.labels, 'thresholds': config.thresholds, } import json with open(output_dir / 'config.json', 'w') as f: json.dump(config_dict, f, indent=2) print(f"✅ Models saved to {output_dir}") if __name__ == "__main__": # Test the module print("ESG Model Integration Module") print(f"Config: {ModelConfig()}")