| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| from typing import Optional, List |
| import time |
| from datetime import datetime, timezone |
| import os |
| import warnings |
| from huggingface_hub import hf_hub_download |
| from contextlib import asynccontextmanager |
| import uvicorn |
| from dotenv import load_dotenv |
| import shutil |
| import joblib |
| from pathlib import Path |
| from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel, DistilBertTokenizer, DistilBertModel |
|
|
| load_dotenv() |
| warnings.filterwarnings('ignore') |
|
|
| |
class MultiHeadSelfAttention(nn.Module):
    """Transformer-style self-attention sublayer.

    Applies multi-head attention where query, key and value are the same
    tensor, then a residual connection with dropout, followed by post-layer
    normalization.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.3):
        super().__init__()
        # batch_first=True -> tensors are shaped (batch, seq, embed_dim).
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Self-attention: the input attends to itself (q = k = v = x).
        attended, _ = self.attention(x, x, x)
        # Residual add + dropout, then normalize (post-LN ordering).
        return self.layer_norm(x + self.dropout(attended))
|
|
|
|
class MHSA_GRU(nn.Module):
    """Binary classifier combining Multi-Head Self-Attention blocks with a GRU.

    Pipeline: linear projection -> MHSA x2 -> GRU -> MHSA -> MLP head with
    batch-norm and dropout -> sigmoid probability in [0, 1].

    Args:
        input_dim: Size of the input feature vector (e.g. 1024 for ProtBERT).
        hidden_dim: Internal embedding width; must be divisible by num_heads.
        num_heads: Attention heads per MHSA block.
        num_gru_layers: Stacked GRU layers (inter-layer dropout only if > 1).
        dropout: Dropout probability used throughout.
    """

    def __init__(self, input_dim, hidden_dim=256, num_heads=8, num_gru_layers=2, dropout=0.3):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # NOTE: attribute names are part of the checkpoint's state_dict keys —
        # do not rename them.
        self.input_projection = nn.Linear(input_dim, hidden_dim)
        self.mhsa1 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)
        self.mhsa2 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)

        self.gru = nn.GRU(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_gru_layers,
            batch_first=True,
            # nn.GRU only applies dropout between stacked layers.
            dropout=dropout if num_gru_layers > 1 else 0,
            bidirectional=False
        )

        self.mhsa3 = MultiHeadSelfAttention(hidden_dim, num_heads, dropout)
        self.dropout = nn.Dropout(dropout)

        # MLP classification head: hidden -> hidden/2 -> hidden/4 -> 1 logit.
        self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc2 = nn.Linear(hidden_dim // 2, hidden_dim // 4)
        self.fc3 = nn.Linear(hidden_dim // 4, 1)

        self.bn1 = nn.BatchNorm1d(hidden_dim // 2)
        self.bn2 = nn.BatchNorm1d(hidden_dim // 4)

    def forward(self, x):
        """Map a (batch, input_dim) feature tensor to (batch, 1) probabilities."""
        x = self.input_projection(x)
        # Add a singleton sequence dimension: (batch, 1, hidden_dim).
        x = x.unsqueeze(1)

        x = self.mhsa1(x)
        x = self.mhsa2(x)
        # Final hidden state is unused; keep only the output sequence.
        gru_out, _ = self.gru(x)
        x = self.mhsa3(gru_out)
        # Take the last time step (the sequence length is 1 here anyway).
        x = x[:, -1, :]

        x = self.dropout(x)
        x = torch.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        x = torch.relu(self.bn2(self.fc2(x)))
        x = self.dropout(x)
        x = self.fc3(x)

        return torch.sigmoid(x)
|
|
|
|
| |
# Run inference on GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


API_VERSION = "1.0.0"
MODEL_VERSION = "MHSA-GRU-Transformer-v1.0"


# HuggingFace repository and the individual artifact files fetched from it.
MODEL_REPO = {
    "repo_id": "camlas/toxicity",
    "files": {
        "classifier": "mhsa_gru_classifier.pth",
        "scaler": "scaler.pkl",
        "config": "config.json",
        "model_weights": "model.safetensors",
        "vocab": "vocab.txt",
        "tokenizer_config": "tokenizer_config.json",
        "special_tokens_map": "special_tokens_map.json"
    }
}


# Global model handles, populated by load_all_models() during app startup.
classifier = None             # MHSA_GRU toxicity classifier (torch.nn.Module)
scaler = None                 # feature scaler loaded via joblib (presumably an sklearn scaler — confirm)
transformer_model = None      # ProtBERT encoder (transformers BertModel)
transformer_tokenizer = None  # ProtBERT tokenizer (transformers BertTokenizer)
EMBEDDING_TYPE = "Bert"
MODEL_NAME = "ProtBERT"
|
|
|
|
| |
class SequenceRequest(BaseModel):
    """Request body for /predict: a single protein sequence."""
    sequence: str  # raw amino-acid sequence; upper-cased and stripped server-side
|
|
|
|
class BatchSequenceRequest(BaseModel):
    """Request body for /predict/batch: multiple protein sequences."""
    sequences: List[str]  # processed independently, results returned in order
|
|
|
|
class PredictionResponse(BaseModel):
    """Standard response envelope for the prediction endpoints."""
    status_code: int                  # HTTP-style status mirrored inside the body
    status: str                       # "success" or "error"
    success: bool
    data: Optional[dict] = None       # prediction payload on success
    error: Optional[str] = None       # human-readable error message on failure
    error_code: Optional[str] = None  # machine-readable error identifier
    timestamp: str                    # ISO-8601 UTC timestamp of the request
    api_version: str
    processing_time_ms: float         # wall-clock handling time, milliseconds
|
|
|
|
class HealthResponse(BaseModel):
    """Response body for /health: model-loading and service status."""
    status_code: int            # 200 when all models loaded, else 503
    status: str                 # "healthy" or "unhealthy"
    service: str
    api_version: str
    model_version: str
    models_loaded: bool         # True only if all four critical models loaded
    models_loaded_count: int
    total_models_required: int
    model_sources: dict         # per-model load status and origin repository
    repository_info: dict       # HuggingFace repo metadata
    device: str                 # "cuda" or "cpu"
    timestamp: str              # ISO-8601 UTC timestamp
|
|
|
|
| |
def create_kmers(sequence, k=6):
    """Convert a sequence into overlapping k-mer tokens (DNABERT-style).

    Args:
        sequence: Input string (e.g. a DNA sequence).
        k: k-mer length (default 6, matching DNABERT-6).

    Returns:
        Space-separated overlapping k-mers; empty string when
        len(sequence) < k.
    """
    # Sliding window of width k; join at C speed instead of a manual
    # append loop.
    return ' '.join(sequence[i:i + k] for i in range(len(sequence) - k + 1))
|
|
|
|
def ensure_models_directory(models_dir: str = "models") -> str:
    """Ensure the local models directory exists and return its path.

    Args:
        models_dir: Directory to create if missing (default: "models",
            relative to the current working directory).

    Returns:
        The (now existing) directory path.
    """
    if not os.path.exists(models_dir):
        # exist_ok guards against a race with another worker creating it
        # between the check above and this call.
        os.makedirs(models_dir, exist_ok=True)
        print(f"✅ Created {models_dir} directory")
    return models_dir
|
|
|
|
def download_model_from_hub(model_name: str) -> Optional[str]:
    """Download one model file from the HuggingFace Hub into ./models.

    Looks up *model_name* in MODEL_REPO["files"]; returns the cached local
    path if the file already exists, otherwise downloads it (authenticating
    with HF_TOKEN / HUGGINGFACE_HUB_TOKEN when set) and copies it into the
    models directory.

    Args:
        model_name: Key into MODEL_REPO["files"] (e.g. "classifier").

    Returns:
        Local file path on success, None on any download failure.
    """
    try:
        if model_name not in MODEL_REPO["files"]:
            raise ValueError(f"Unknown model: {model_name}")

        filename = MODEL_REPO["files"][model_name]
        repo_id = MODEL_REPO["repo_id"]
        models_dir = ensure_models_directory()
        local_path = os.path.join(models_dir, filename)

        # Reuse a previously downloaded copy instead of hitting the network.
        if os.path.exists(local_path):
            print(f"✅ Found {model_name} in local models directory: {local_path}")
            return local_path

        # Fix: report the actual file being fetched (message previously
        # printed a literal "(unknown)" placeholder).
        print(f"📥 Downloading {model_name} ({filename}) from {repo_id}...")
        token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")

        if not token:
            print("⚠️ Warning: No HF token found. This may fail for private repositories.")

        temp_model_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type="model",
            token=token
        )

        # Copy out of the HF cache so the file lives alongside the app.
        shutil.copy2(temp_model_path, local_path)
        print(f"✅ {model_name} downloaded and stored!")
        return local_path

    except Exception as e:
        # Best-effort by design: callers treat None as "unavailable".
        print(f"❌ Error downloading {model_name}: {e}")
        return None
|
|
|
|
def extract_features_from_sequence(sequence: str):
    """Embed a protein sequence with ProtBERT and return its [CLS] vector.

    Args:
        sequence: Amino-acid sequence; upper-cased here before tokenization.

    Returns:
        numpy array of shape (1, hidden_size) — the [CLS] token embedding.

    Raises:
        ValueError: If the ProtBERT model/tokenizer have not been loaded.
    """
    global transformer_model, transformer_tokenizer

    if transformer_model is None or transformer_tokenizer is None:
        raise ValueError("ProtBERT model not loaded")

    # ProtBERT expects residues separated by spaces, e.g. "M K T L ...".
    spaced_sequence = ' '.join(list(sequence.upper()))

    encoded = transformer_tokenizer(
        spaced_sequence,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        model_output = transformer_model(**encoded)

    # The first token ([CLS]) embedding summarizes the whole sequence.
    cls_vector = model_output.last_hidden_state[:, 0, :]
    return cls_vector.cpu().numpy()
|
|
|
|
def load_all_models():
    """Load all models from HuggingFace Hub.

    Populates the module-level globals ``classifier``, ``scaler``,
    ``transformer_model`` and ``transformer_tokenizer``. The ProtBERT
    tokenizer/model are first loaded from the local ``models`` directory;
    on failure, loading falls back to fetching directly from the Hub.

    Returns:
        bool: True only when all four critical models loaded; False on any
        failure (the API still starts — /health reports the gap).
    """
    global classifier, scaler, transformer_model, transformer_tokenizer

    models_dir = ensure_models_directory()
    # Per-model success flags, summarized at the end.
    models_loaded = {
        "classifier": False,
        "scaler": False,
        "transformer_model": False,
        "transformer_tokenizer": False
    }


    print(f"🚀 Loading models from {MODEL_REPO['repo_id']}...")
    print("=" * 60)


    try:
        # --- Stage 1: fetch ProtBERT artifact files into ./models ---------
        print("📥 Downloading ProtBERT model files...")

        files_to_download = ["config", "model_weights", "vocab",
                             "tokenizer_config", "special_tokens_map"]

        # Failures here are tolerated: the from_pretrained fallbacks below
        # can still load directly from the Hub.
        for file_key in files_to_download:
            download_model_from_hub(file_key)


        # --- Stage 2: tokenizer (local first, then Hub fallback) ----------
        print("🔄 Loading ProtBERT tokenizer...")
        try:
            transformer_tokenizer = BertTokenizer.from_pretrained(
                models_dir,
                do_lower_case=False,  # amino-acid tokens are case-sensitive
                local_files_only=True
            )
            models_loaded["transformer_tokenizer"] = True
            print("✅ ProtBERT tokenizer loaded!")
        except Exception as e:
            print(f"❌ Error loading tokenizer: {e}")

            # Fallback: load straight from the Hub (may re-download).
            print("🔄 Trying to load tokenizer directly from HuggingFace...")
            token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
            transformer_tokenizer = BertTokenizer.from_pretrained(
                MODEL_REPO["repo_id"],
                do_lower_case=False,
                token=token
            )
            models_loaded["transformer_tokenizer"] = True
            print("✅ ProtBERT tokenizer loaded from HuggingFace!")


        # --- Stage 3: encoder model (local first, then Hub fallback) ------
        print("🔄 Loading ProtBERT model...")
        try:
            transformer_model = BertModel.from_pretrained(
                models_dir,
                local_files_only=True
            )
            models_loaded["transformer_model"] = True
            print("✅ ProtBERT model loaded!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")

            print("🔄 Trying to load model directly from HuggingFace...")
            token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
            transformer_model = BertModel.from_pretrained(
                MODEL_REPO["repo_id"],
                token=token
            )
            models_loaded["transformer_model"] = True
            print("✅ ProtBERT model loaded from HuggingFace!")

        transformer_model.to(device)
        transformer_model.eval()  # inference only; disables dropout


        # --- Stage 4: MHSA-GRU classifier checkpoint ----------------------
        print("🔄 Loading classifier (MHSA-GRU)...")
        clf_path = os.path.join(models_dir, MODEL_REPO["files"]["classifier"])

        if not os.path.exists(clf_path):
            print("📥 Classifier not found locally, downloading...")
            clf_path = download_model_from_hub("classifier")

        if clf_path and os.path.exists(clf_path):
            # weights_only=False: checkpoint may contain non-tensor metadata.
            # NOTE(review): full unpickling of an untrusted checkpoint is a
            # code-execution risk — acceptable only because the repo is trusted.
            checkpoint = torch.load(clf_path, map_location=device, weights_only=False)

            if 'input_dim' in checkpoint:
                input_dim = checkpoint['input_dim']
            else:
                # Default matches ProtBERT's hidden size.
                input_dim = 1024

            classifier = MHSA_GRU(input_dim, hidden_dim=256)

            # Support both wrapped ({'model_state_dict': ...}) and raw
            # state-dict checkpoint formats.
            if 'model_state_dict' in checkpoint:
                classifier.load_state_dict(checkpoint['model_state_dict'])
            else:
                classifier.load_state_dict(checkpoint)

            classifier.to(device)
            classifier.eval()
            models_loaded["classifier"] = True
            print(f"✅ Classifier loaded! (input_dim: {input_dim})")


        # --- Stage 5: feature scaler --------------------------------------
        print("🔄 Loading feature scaler...")
        scaler_path = os.path.join(models_dir, MODEL_REPO["files"]["scaler"])

        if not os.path.exists(scaler_path):
            print("📥 Scaler not found locally, downloading...")
            scaler_path = download_model_from_hub("scaler")

        if scaler_path and os.path.exists(scaler_path):
            scaler = joblib.load(scaler_path)
            models_loaded["scaler"] = True
            print("✅ Scaler loaded!")


        # --- Summary -------------------------------------------------------
        loaded_count = sum(models_loaded.values())
        total_count = len(models_loaded)


        print(f"\n📊 Model Loading Summary:")
        print(f"   • Successfully loaded: {loaded_count}/{total_count}")
        print(f"   • Repository: {MODEL_REPO['repo_id']}")
        print(f"   • Embedding Model: {MODEL_NAME}")
        print(f"   • Device: {device}")

        critical_models = ["classifier", "scaler", "transformer_model", "transformer_tokenizer"]
        critical_loaded = all(models_loaded[m] for m in critical_models)


        if critical_loaded:
            print("🎉 All critical models loaded successfully!")
            return True
        else:
            print("⚠️ Some critical models failed to load")
            print(f"   Models status: {models_loaded}")
            return False


    except Exception as e:
        print(f"❌ Error loading models: {e}")
        import traceback
        traceback.print_exc()
        return False
|
|
|
|
| |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load all models at startup, log at shutdown."""
    print("🚀 Starting Toxicity Prediction API...")
    success = load_all_models()
    if not success:
        # The API still starts; /health and /predict report the missing models.
        print("⚠️ Warning: Not all models loaded successfully")
    yield
    # Code after ``yield`` runs once during application shutdown.
    print("🔄 Shutting down API...")
|
|
|
|
# Application instance; the ``lifespan`` hook loads models before serving.
app = FastAPI(
    title="Toxicity Prediction API",
    description="API for toxicity prediction using MHSA-GRU with Transformer embeddings",
    version="1.0.0",
    lifespan=lifespan
)
|
|
|
|
| @app.get("/") |
| async def root(): |
| return { |
| "message": "Toxicity Prediction API", |
| "version": API_VERSION, |
| "endpoints": { |
| "/predict": "POST - Predict toxicity for a single sequence", |
| "/predict/batch": "POST - Predict toxicity for multiple sequences", |
| "/health": "GET - Check API health and model status" |
| } |
| } |
|
|
|
|
| @app.post("/predict", response_model=PredictionResponse) |
| async def predict(request: SequenceRequest): |
| start_time = time.time() |
| timestamp = datetime.now(timezone.utc).isoformat() |
|
|
| try: |
| if not request.sequence or len(request.sequence) == 0: |
| raise HTTPException( |
| status_code=400, |
| detail={ |
| "status_code": 400, |
| "status": "error", |
| "success": False, |
| "error": "No sequence provided", |
| "error_code": "MISSING_SEQUENCE", |
| "timestamp": timestamp, |
| "api_version": API_VERSION, |
| "processing_time_ms": round((time.time() - start_time) * 1000, 2) |
| } |
| ) |
|
|
| |
| if classifier is None or scaler is None or transformer_model is None: |
| raise HTTPException( |
| status_code=503, |
| detail={ |
| "status_code": 503, |
| "status": "error", |
| "success": False, |
| "error": "Models not loaded properly", |
| "error_code": "MODEL_NOT_LOADED", |
| "timestamp": timestamp, |
| "api_version": API_VERSION, |
| "processing_time_ms": round((time.time() - start_time) * 1000, 2) |
| } |
| ) |
|
|
| |
| sequence = request.sequence.upper().strip() |
| if len(sequence) < 10: |
| raise HTTPException( |
| status_code=400, |
| detail={ |
| "status_code": 400, |
| "status": "error", |
| "success": False, |
| "error": "Sequence too short (minimum 10 characters)", |
| "error_code": "SEQUENCE_TOO_SHORT", |
| "timestamp": timestamp, |
| "api_version": API_VERSION, |
| "processing_time_ms": round((time.time() - start_time) * 1000, 2) |
| } |
| ) |
|
|
| |
| features = extract_features_from_sequence(sequence) |
|
|
| |
| scaled_features = scaler.transform(features) |
|
|
| |
| features_tensor = torch.FloatTensor(scaled_features).to(device) |
| |
| with torch.no_grad(): |
| probability = classifier(features_tensor).cpu().numpy()[0, 0] |
|
|
| |
| prediction_class = 1 if probability > 0.5 else 0 |
| predicted_label = "Toxic" if prediction_class == 1 else "Non-Toxic" |
| confidence = float(abs(probability - 0.5) * 2) |
|
|
| |
| if confidence > 0.8: |
| confidence_level = "high" |
| elif confidence > 0.6: |
| confidence_level = "medium" |
| else: |
| confidence_level = "low" |
|
|
| processing_time = round((time.time() - start_time) * 1000, 2) |
|
|
| return PredictionResponse( |
| status_code=200, |
| status="success", |
| success=True, |
| data={ |
| "sequence": sequence[:100] + "..." if len(sequence) > 100 else sequence, |
| "sequence_length": len(sequence), |
| "prediction": { |
| "predicted_class": predicted_label, |
| "confidence": confidence, |
| "confidence_level": confidence_level, |
| "toxicity_score": float(probability), |
| "non_toxicity_score": float(1 - probability) |
| }, |
| "metadata": { |
| "embedding_model": MODEL_NAME, |
| "embedding_type": EMBEDDING_TYPE, |
| "model_version": MODEL_VERSION, |
| "device": str(device) |
| } |
| }, |
| timestamp=timestamp, |
| api_version=API_VERSION, |
| processing_time_ms=processing_time |
| ) |
|
|
| except HTTPException: |
| raise |
| except Exception as e: |
| processing_time = round((time.time() - start_time) * 1000, 2) |
| raise HTTPException( |
| status_code=500, |
| detail={ |
| "status_code": 500, |
| "status": "error", |
| "success": False, |
| "error": f"Internal server error: {str(e)}", |
| "error_code": "INTERNAL_ERROR", |
| "timestamp": timestamp, |
| "api_version": API_VERSION, |
| "processing_time_ms": processing_time |
| } |
| ) |
|
|
|
|
| @app.post("/predict/batch", response_model=PredictionResponse) |
| async def predict_batch(request: BatchSequenceRequest): |
| start_time = time.time() |
| timestamp = datetime.now(timezone.utc).isoformat() |
|
|
| try: |
| if not request.sequences or len(request.sequences) == 0: |
| raise HTTPException( |
| status_code=400, |
| detail={ |
| "status_code": 400, |
| "status": "error", |
| "success": False, |
| "error": "No sequences provided", |
| "error_code": "MISSING_SEQUENCES", |
| "timestamp": timestamp, |
| "api_version": API_VERSION, |
| "processing_time_ms": round((time.time() - start_time) * 1000, 2) |
| } |
| ) |
|
|
| |
| if classifier is None or scaler is None or transformer_model is None: |
| raise HTTPException( |
| status_code=503, |
| detail={ |
| "status_code": 503, |
| "status": "error", |
| "success": False, |
| "error": "Models not loaded properly", |
| "error_code": "MODEL_NOT_LOADED", |
| "timestamp": timestamp, |
| "api_version": API_VERSION, |
| "processing_time_ms": round((time.time() - start_time) * 1000, 2) |
| } |
| ) |
|
|
| results = [] |
| |
| for seq in request.sequences: |
| sequence = seq.upper().strip() |
| |
| |
| features = extract_features_from_sequence(sequence) |
| scaled_features = scaler.transform(features) |
| features_tensor = torch.FloatTensor(scaled_features).to(device) |
| |
| with torch.no_grad(): |
| probability = classifier(features_tensor).cpu().numpy()[0, 0] |
| |
| prediction_class = 1 if probability > 0.5 else 0 |
| predicted_label = "Toxic" if prediction_class == 1 else "Non-Toxic" |
| confidence = float(abs(probability - 0.5) * 2) |
| |
| results.append({ |
| "sequence": sequence[:100] + "..." if len(sequence) > 100 else sequence, |
| "sequence_length": len(sequence), |
| "predicted_class": predicted_label, |
| "toxicity_score": float(probability), |
| "confidence": confidence |
| }) |
|
|
| processing_time = round((time.time() - start_time) * 1000, 2) |
|
|
| return PredictionResponse( |
| status_code=200, |
| status="success", |
| success=True, |
| data={ |
| "total_sequences": len(request.sequences), |
| "results": results, |
| "metadata": { |
| "embedding_model": MODEL_NAME, |
| "embedding_type": EMBEDDING_TYPE, |
| "model_version": MODEL_VERSION, |
| "device": str(device) |
| } |
| }, |
| timestamp=timestamp, |
| api_version=API_VERSION, |
| processing_time_ms=processing_time |
| ) |
|
|
| except HTTPException: |
| raise |
| except Exception as e: |
| processing_time = round((time.time() - start_time) * 1000, 2) |
| raise HTTPException( |
| status_code=500, |
| detail={ |
| "status_code": 500, |
| "status": "error", |
| "success": False, |
| "error": f"Internal server error: {str(e)}", |
| "error_code": "INTERNAL_ERROR", |
| "timestamp": timestamp, |
| "api_version": API_VERSION, |
| "processing_time_ms": processing_time |
| } |
| ) |
|
|
|
|
| @app.get("/health", response_model=HealthResponse) |
| async def health_check(): |
| models_loaded = all([ |
| classifier is not None, |
| scaler is not None, |
| transformer_model is not None, |
| transformer_tokenizer is not None |
| ]) |
|
|
| model_sources = { |
| "classifier": { |
| "loaded": classifier is not None, |
| "source": "huggingface_hub", |
| "repository": MODEL_REPO["repo_id"] |
| }, |
| "scaler": { |
| "loaded": scaler is not None, |
| "source": "huggingface_hub", |
| "repository": MODEL_REPO["repo_id"] |
| }, |
| "transformer_model": { |
| "loaded": transformer_model is not None, |
| "model_name": MODEL_NAME, |
| "source": "huggingface_hub", |
| "repository": MODEL_REPO["repo_id"] |
| } |
| } |
|
|
| repository_info = { |
| "repository_id": MODEL_REPO["repo_id"], |
| "embedding_type": EMBEDDING_TYPE, |
| "model_name": MODEL_NAME, |
| "total_models": len(MODEL_REPO["files"]) |
| } |
|
|
| return HealthResponse( |
| status_code=200 if models_loaded else 503, |
| status="healthy" if models_loaded else "unhealthy", |
| service="Toxicity Prediction API", |
| api_version=API_VERSION, |
| model_version=MODEL_VERSION, |
| models_loaded=models_loaded, |
| models_loaded_count=sum(1 for source in model_sources.values() if source["loaded"]), |
| total_models_required=4, |
| model_sources=model_sources, |
| repository_info=repository_info, |
| device=str(device), |
| timestamp=datetime.now(timezone.utc).isoformat() |
| ) |
|
|
|
|
| if __name__ == "__main__": |
| uvicorn.run(app, host="0.0.0.0", port=8000) |