# Provenance: uploaded by bechir09 ("Upload folder using huggingface_hub"),
# commit 4d1bb75 (verified).
"""
🧠 ESG Model Integration Module
Connects the trained model with the Gradio application
This module provides the bridge between the trained ESG classifier
and the web application interface.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Optional, Tuple
from pathlib import Path
from dataclasses import dataclass
import warnings
# NOTE(review): importing this module installs a global "ignore" filter for
# every warning category (Warning and all subclasses), which also hides
# warnings raised by callers and other libraries in the same process.
# Consider narrowing this to the specific categories/modules that are noisy.
warnings.simplefilter("ignore")
@dataclass
class ModelConfig:
    """Configuration for the ESG classifier.

    Attributes:
        embed_dim: Size of the input embedding vector fed to the classifier.
        n_labels: Number of output logits produced by the classifier.
        hidden_dim: Width of the classifier's hidden layer.
        dropout: Dropout probability used in the classifier.
        labels: Ordered label names; entry ``i`` names classifier output ``i``.
            Defaults to ``['E', 'S', 'G', 'non_ESG']`` when not given.
        thresholds: Per-label decision thresholds compared against sigmoid
            scores. Defaults to the thresholds tuned during training when
            not given.
    """
    embed_dim: int = 4096
    n_labels: int = 4
    hidden_dim: int = 512
    dropout: float = 0.1
    labels: Optional[List[str]] = None
    thresholds: Optional[Dict[str, float]] = None

    def __post_init__(self):
        # Fill in defaults only for fields the caller left unset, so explicit
        # ``labels=`` / ``thresholds=`` arguments are honoured. A fresh list/dict
        # is built per instance so configs never share mutable state.
        if self.labels is None:
            self.labels = ['E', 'S', 'G', 'non_ESG']
        if self.thresholds is None:
            # Optimized thresholds from training
            self.thresholds = {
                'E': 0.352,
                'S': 0.456,
                'G': 0.398,
                'non_ESG': 0.512,
            }
class MLPClassifier(nn.Module):
    """Shallow two-layer MLP head over precomputed text embeddings.

    Layout (the ``net`` submodule indices determine the ``state_dict`` keys,
    so this must match the training checkpoint):

        Linear(embed_dim -> hidden_dim) -> BatchNorm1d -> ReLU -> Dropout
        -> Linear(hidden_dim -> n_labels)

    ``forward`` returns raw logits; no output activation is applied here.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        layers = [
            nn.Linear(config.embed_dim, config.hidden_dim),
            nn.BatchNorm1d(config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, config.n_labels),
        ]
        self.net = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init for every Linear weight; zero every Linear bias."""
        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for linear in linear_layers:
            nn.init.xavier_uniform_(linear.weight)
            if linear.bias is not None:
                nn.init.zeros_(linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of embeddings to per-label logits."""
        return self.net(x)
class ESGModelInference:
    """
    Production-ready ESG model inference class.

    Wraps two stages:
      1. an embedding model (loaded via ``transformers.AutoModel``) that turns
         text into an L2-normalised, last-token-pooled vector, and
      2. an ``MLPClassifier`` that maps that vector to per-label sigmoid
         scores, which are compared against ``ModelConfig.thresholds``.

    Loading failures are reported with ``print`` and leave the corresponding
    attribute as ``None``; the inference methods then raise ``RuntimeError``.
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        embedding_model_name: str = "Qwen/Qwen3-Embedding-8B",
        device: str = "auto",
        use_fp16: bool = True,
    ):
        """
        Args:
            model_path: Path to the classifier ``state_dict``. When given, both
                the embedding model and the classifier are loaded immediately.
            embedding_model_name: Hugging Face model id of the embedding model.
            device: ``"auto"`` (CUDA if available, else CPU) or any string
                accepted by ``torch.device``.
            use_fp16: Load the embedding model in float16; only honoured on CUDA.
        """
        self.config = ModelConfig()
        if device == "auto":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        # Half precision is only used on GPU; CPU inference stays in float32.
        self.use_fp16 = use_fp16 and self.device.type == "cuda"
        self.embedding_model = None
        self.tokenizer = None
        self.classifier = None
        self.scaler = None  # kept for API compatibility; not used by this class
        if model_path:
            self.load_models(model_path, embedding_model_name)

    def load_embedding_model(self, model_name: str):
        """Load the tokenizer and embedding model (e.g. Qwen3-Embedding-8B).

        Best effort: on any failure the error is printed (not re-raised) and
        ``self.embedding_model`` is reset to ``None``.
        """
        try:
            from transformers import AutoTokenizer, AutoModel

            print(f"Loading embedding model: {model_name}")
            # Left padding keeps each sequence's final real token in the last
            # column, which the last-token pooling in extract_embedding uses.
            # NOTE: trust_remote_code executes code from the model repository.
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                padding_side='left',
                trust_remote_code=True,
            )
            dtype = torch.float16 if self.use_fp16 else torch.float32
            self.embedding_model = AutoModel.from_pretrained(
                model_name,
                torch_dtype=dtype,
                trust_remote_code=True,
            ).to(self.device)
            self.embedding_model.eval()
            print(f"✅ Embedding model loaded on {self.device}")
        except Exception as e:
            print(f"⚠️ Could not load embedding model: {e}")
            self.embedding_model = None

    def load_classifier(self, model_path: str):
        """Load trained ``MLPClassifier`` weights from a ``state_dict`` file.

        Best effort: on any failure the error is printed (not re-raised) and
        ``self.classifier`` is reset to ``None``.
        """
        try:
            self.classifier = MLPClassifier(self.config).to(self.device)
            # The checkpoint is a plain state_dict of tensors (see
            # save_models_for_deployment), so restrict unpickling to tensors
            # and primitive containers rather than arbitrary pickled objects.
            state_dict = torch.load(
                model_path, map_location=self.device, weights_only=True
            )
            self.classifier.load_state_dict(state_dict)
            # eval(): BatchNorm uses its running statistics and Dropout is
            # disabled, which single-sample inference requires.
            self.classifier.eval()
            print(f"✅ Classifier loaded from {model_path}")
        except Exception as e:
            print(f"⚠️ Could not load classifier: {e}")
            self.classifier = None

    def load_models(self, model_path: str, embedding_model_name: str):
        """Load the embedding model and then the classifier."""
        self.load_embedding_model(embedding_model_name)
        self.load_classifier(model_path)

    @torch.no_grad()
    def extract_embedding(
        self,
        text: str,
        instruction: Optional[str] = None,
        max_length: int = 512,
    ) -> torch.Tensor:
        """Embed a single text.

        Args:
            text: Text to embed.
            instruction: Prefix prepended to ``text``. Defaults to this
                module's ESG classification instruction.
            max_length: Token limit passed to the tokenizer (longer inputs are
                truncated).

        Returns:
            A float32 CPU tensor with one row (the single input text),
            L2-normalised along the feature dimension.

        Raises:
            RuntimeError: If the tokenizer or embedding model is not loaded.
        """
        if self.embedding_model is None or self.tokenizer is None:
            raise RuntimeError("Embedding model not loaded")
        if instruction is None:
            instruction = (
                "Instruct: Classify the following text into ESG categories: "
                "Environmental, Social, Governance, or non-ESG.\nQuery: "
            )
        encoded = self.tokenizer(
            [instruction + text],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        ).to(self.device)
        outputs = self.embedding_model(**encoded)

        # Last-token pooling (Qwen3-Embedding style): take the hidden state of
        # each sequence's final non-padding token.
        attention_mask = encoded['attention_mask']
        last_hidden_states = outputs.last_hidden_state
        # If every row attends to its last position the batch is left-padded,
        # so the final column already holds each sequence's last token.
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            embedding = last_hidden_states[:, -1]
        else:
            seq_lens = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            embedding = last_hidden_states[
                torch.arange(batch_size, device=self.device), seq_lens
            ]
        embedding = F.normalize(embedding, p=2, dim=1)
        return embedding.float().cpu()

    @torch.no_grad()
    def predict(self, embedding: torch.Tensor) -> Dict:
        """Classify one embedding.

        Args:
            embedding: A single embedding, 1-D ``(embed_dim,)`` or 2-D
                ``(1, embed_dim)``; tensors and array-likes (e.g. NumPy
                arrays) are accepted. For 2-D input with several rows only the
                first row is scored.

        Returns:
            Dict with ``'scores'`` (label -> sigmoid probability),
            ``'predictions'`` (labels whose score meets its threshold, or
            ``['non_ESG']`` when none do) and ``'confidence'`` (mean score of
            the predicted labels, as a Python float).

        Raises:
            RuntimeError: If the classifier is not loaded.
        """
        if self.classifier is None:
            raise RuntimeError("Classifier not loaded")
        # The classifier is built with float32 weights and its BatchNorm layer
        # needs a batch dimension, so normalise dtype, device and rank first.
        embedding = torch.as_tensor(embedding, dtype=torch.float32, device=self.device)
        if embedding.dim() == 1:
            embedding = embedding.unsqueeze(0)
        logits = self.classifier(embedding)
        probs = torch.sigmoid(logits).cpu().numpy()[0]

        predictions = []
        scores = {}
        for i, label in enumerate(self.config.labels):
            scores[label] = float(probs[i])
            if probs[i] >= self.config.thresholds[label]:
                predictions.append(label)
        # Default to non_ESG if no label passes its threshold.
        if not predictions:
            predictions = ['non_ESG']
        return {
            'scores': scores,
            'predictions': predictions,
            'confidence': float(np.mean([scores[p] for p in predictions])),
        }

    def classify(self, text: str) -> Dict:
        """Full pipeline: text -> embedding -> classification."""
        return self.predict(self.extract_embedding(text))

    def batch_classify(self, texts: List[str], batch_size: int = 8) -> List[Dict]:
        """Classify several texts, one at a time.

        A failure on one text does not abort the others: its result has zero
        scores, ``['non_ESG']``, ``confidence`` 0.0 and an ``'error'`` message.

        Args:
            texts: Texts to classify; one result is returned per text, in order.
            batch_size: Accepted for backward compatibility. Texts are
                processed individually, so it does not affect the results.
        """
        results = []
        for text in texts:
            try:
                result = self.classify(text)
            except Exception as e:
                result = {
                    'scores': {label: 0.0 for label in self.config.labels},
                    'predictions': ['non_ESG'],
                    'confidence': 0.0,
                    'error': str(e),
                }
            results.append(result)
        return results
class LogisticRegressionEnsemble:
    """
    Per-label logistic-regression classifier over pre-computed embeddings.

    Loads one fitted model per label (``lr_<label>.joblib``) plus an optional
    feature scaler (``scaler.joblib``) from a directory, and scores an
    embedding with each model's ``predict_proba``. Intended for use when the
    full embedding model isn't available.
    """

    def __init__(self, model_dir: Optional[str] = None):
        self.config = ModelConfig()
        self.models = {}  # label -> fitted estimator exposing predict_proba
        self.scaler = None
        if model_dir:
            self.load(model_dir)

    def load(self, model_dir: str):
        """Load the scaler and per-label models that exist in ``model_dir``.

        Missing files are skipped silently; labels without a model score 0.0
        in ``predict``.
        """
        # NOTE: joblib files are pickles and can execute code when loaded;
        # only point this at model directories you trust.
        import joblib

        model_dir = Path(model_dir)
        scaler_path = model_dir / 'scaler.joblib'
        if scaler_path.exists():
            self.scaler = joblib.load(scaler_path)
        for label in self.config.labels:
            model_path = model_dir / f'lr_{label}.joblib'
            if model_path.exists():
                self.models[label] = joblib.load(model_path)

    def predict(self, embedding: np.ndarray) -> Dict:
        """Score one pre-computed embedding.

        Args:
            embedding: A single embedding; any array-like that flattens to one
                row, e.g. shape ``(D,)`` or ``(1, D)``.

        Returns:
            Dict with ``'scores'`` (label -> probability, 0.0 for labels with
            no loaded model), ``'predictions'`` (labels meeting their
            threshold, or ``['non_ESG']`` if none do) and ``'confidence'``
            (mean score of the predicted labels, as a Python float).
        """
        # Estimators expect a 2-D (n_samples, n_features) input; reshape
        # unconditionally so the scaler-less path also receives one 2-D row.
        features = np.asarray(embedding).reshape(1, -1)
        if self.scaler is not None:
            features = self.scaler.transform(features)

        scores = {}
        predictions = []
        for label in self.config.labels:
            if label not in self.models:
                scores[label] = 0.0
                continue
            # Column 1 of predict_proba is the positive-class probability.
            prob = self.models[label].predict_proba(features)[0, 1]
            scores[label] = float(prob)
            if prob >= self.config.thresholds[label]:
                predictions.append(label)
        if not predictions:
            predictions = ['non_ESG']
        return {
            'scores': scores,
            'predictions': predictions,
            'confidence': float(np.mean([scores[p] for p in predictions])),
        }
# ═══════════════════════════════════════════════════════════════════════════════
# UTILITY FUNCTIONS
# ═══════════════════════════════════════════════════════════════════════════════
def save_models_for_deployment(
    classifier: nn.Module,
    scaler,
    lr_models: Dict,
    output_dir: str,
):
    """Write every deployable artefact to ``output_dir``.

    Files written:
        ``mlp_classifier.pt``  - ``classifier.state_dict()``
        ``scaler.joblib``      - ``scaler`` (only if it is not ``None``)
        ``lr_<label>.joblib``  - one file per entry in ``lr_models``
        ``config.json``        - the classifier's configuration

    Args:
        classifier: Trained classifier. Its ``config`` attribute (as set by
            ``MLPClassifier``) is written to ``config.json``; a default
            ``ModelConfig()`` is used if the module has no ``config``.
        scaler: Optional fitted feature scaler (any joblib-serialisable object).
        lr_models: Mapping ``label -> fitted model``; ``None`` is treated as empty.
        output_dir: Destination directory; created (with parents) if missing.
    """
    import json

    output_dir = Path(output_dir)
    lr_models = lr_models or {}
    # joblib is only needed for the sklearn-side artefacts; importing it only
    # then lets callers save just the MLP without that dependency installed.
    if scaler is not None or lr_models:
        import joblib
    output_dir.mkdir(parents=True, exist_ok=True)

    torch.save(classifier.state_dict(), output_dir / 'mlp_classifier.pt')
    if scaler is not None:
        joblib.dump(scaler, output_dir / 'scaler.joblib')
    for label, model in lr_models.items():
        joblib.dump(model, output_dir / f'lr_{label}.joblib')

    # Record the configuration the classifier was actually built with, so
    # config.json cannot drift from the saved weights.
    config = getattr(classifier, 'config', None)
    if config is None:
        config = ModelConfig()
    config_dict = {
        'embed_dim': config.embed_dim,
        'n_labels': config.n_labels,
        'hidden_dim': config.hidden_dim,
        'dropout': config.dropout,
        'labels': config.labels,
        'thresholds': config.thresholds,
    }
    with open(output_dir / 'config.json', 'w', encoding='utf-8') as f:
        json.dump(config_dict, f, indent=2)
    print(f"✅ Models saved to {output_dir}")
if __name__ == "__main__":
# Test the module
print("ESG Model Integration Module")
print(f"Config: {ModelConfig()}")