|
|
from fastapi import FastAPI, HTTPException |
|
|
from pydantic import BaseModel |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
from typing import Optional, List |
|
|
import time |
|
|
from datetime import datetime, timezone |
|
|
import os |
|
|
import warnings |
|
|
from huggingface_hub import hf_hub_download |
|
|
from contextlib import asynccontextmanager |
|
|
import uvicorn |
|
|
from dotenv import load_dotenv |
|
|
import shutil |
|
|
import joblib |
|
|
from pathlib import Path |
|
|
from transformers import BertTokenizer, BertModel |
|
|
from utils.model_classes import MHSA_GRU, MultiHeadSelfAttention |
|
|
|
|
|
# Read variables from a local .env file into the process environment, so
# that os.getenv() lookups elsewhere in this module (e.g. HF_TOKEN) see them.
load_dotenv()

# Suppress every Python warning for the whole process.
warnings.filterwarnings(action='ignore')
|
|
|
|
|
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Version strings echoed back in API responses.
API_VERSION = "1.0.0"
MODEL_VERSION = "MHSA-GRU-Transformer-v1.0"

# Hugging Face repository holding every artifact, keyed by the logical
# name that download_model_from_hub() accepts.
MODEL_REPO = dict(
    repo_id="camlas/toxicity",
    files=dict(
        classifier="mhsa_gru_classifier.pth",
        scaler="scaler.pkl",
        config="config.json",
        model_weights="model.safetensors",
        vocab="vocab.txt",
        tokenizer_config="tokenizer_config.json",
        special_tokens_map="special_tokens_map.json",
    ),
)

# Model objects; they stay None until load_all_models() populates them.
classifier = scaler = transformer_model = transformer_tokenizer = None

# Labels reported in response metadata.
EMBEDDING_TYPE = "Bert"
MODEL_NAME = "ProtBERT"
|
|
|
|
|
|
|
|
|
|
|
# Request body for POST /predict.
class SequenceRequest(BaseModel):
    # Raw sequence text; the endpoint upper-cases and strips it before use.
    sequence: str
|
|
|
|
|
|
|
|
# Request body for POST /predict/batch.
class BatchSequenceRequest(BaseModel):
    # Sequences to score; each one is processed independently.
    sequences: List[str]
|
|
|
|
|
|
|
|
# Response envelope shared by /predict, /predict/batch and /example.
class PredictionResponse(BaseModel):
    # Status code mirrored inside the JSON body.
    status_code: int
    # Short status word, e.g. "success".
    status: str
    success: bool
    # Prediction payload; the endpoints fill it on success.
    data: Optional[dict] = None
    # Human-readable error message and machine-readable code, when set.
    error: Optional[str] = None
    error_code: Optional[str] = None
    # ISO-8601 timestamp taken when the request started being handled.
    timestamp: str
    api_version: str
    # Handling time in milliseconds.
    processing_time_ms: float
|
|
|
|
|
|
|
|
# Response body of GET /health.
class HealthResponse(BaseModel):
    # Status code mirrored inside the JSON body.
    status_code: int
    # "healthy" or "unhealthy", as set by health_check().
    status: str
    service: str
    api_version: str
    model_version: str
    # True only when every required model object is loaded.
    models_loaded: bool
    models_loaded_count: int
    total_models_required: int
    # Per-model load status and origin.
    model_sources: dict
    # Summary of the Hugging Face repository in use.
    repository_info: dict
    # String form of the torch device, e.g. "cpu".
    device: str
    timestamp: str
|
|
|
|
|
|
|
|
|
|
|
def create_kmers(sequence: str, k: int = 6) -> str:
    """Convert a DNA sequence to k-mer tokens (for DNABERT).

    Slides a window of width ``k`` over ``sequence`` one character at a
    time and joins the resulting overlapping substrings with single spaces.

    Args:
        sequence: The sequence to tokenize.
        k: Window width; must be at least 1.

    Returns:
        The space-separated k-mers, or an empty string when ``sequence``
        is shorter than ``k``.

    Raises:
        ValueError: If ``k`` is smaller than 1. Such values would otherwise
            silently produce empty or wrong-length tokens.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")
    window_count = len(sequence) - k + 1
    return ' '.join(sequence[start:start + k] for start in range(window_count))
|
|
|
|
|
|
|
|
def ensure_models_directory() -> str:
    """Make sure the local ``models`` directory exists and return its path.

    Returns:
        The relative path ``"models"``, resolved against the current
        working directory.
    """
    models_dir = "models"
    if not os.path.exists(models_dir):
        # exist_ok=True keeps this safe if another worker creates the
        # directory between the existence check and this call.
        os.makedirs(models_dir, exist_ok=True)
        print(f"✅ Created {models_dir} directory")
    return models_dir
|
|
|
|
|
|
|
|
def download_model_from_hub(model_name: str) -> Optional[str]:
    """Download one model file from the Hugging Face Hub into ``models/``.

    A file already present in the local models directory is reused without
    contacting the Hub. Otherwise the file is fetched with
    ``hf_hub_download`` and copied next to the other model files.

    Args:
        model_name: Logical key in ``MODEL_REPO["files"]`` (for example
            ``"classifier"`` or ``"scaler"``).

    Returns:
        The local path of the file, or ``None`` when the key is unknown or
        the download/copy fails. Failures are printed, never raised.
    """
    if model_name not in MODEL_REPO["files"]:
        print(f"❌ Error downloading {model_name}: Unknown model: {model_name}")
        return None

    try:
        filename = MODEL_REPO["files"][model_name]
        repo_id = MODEL_REPO["repo_id"]
        models_dir = ensure_models_directory()
        local_path = os.path.join(models_dir, filename)

        if os.path.exists(local_path):
            print(f"✅ Found {model_name} in local models directory: {local_path}")
            return local_path

        print(f"📥 Downloading {model_name} ({filename}) from {repo_id}...")
        token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")

        if not token:
            print("⚠️ Warning: No HF token found. This may fail for private repositories.")

        cached_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type="model",
            token=token,
        )

        # Copy under a temporary name and rename into place. Copying
        # straight to local_path could leave a truncated file behind if the
        # process is interrupted, and the existence check above would then
        # accept that corrupt file on the next start.
        partial_path = f"{local_path}.{os.getpid()}.part"
        try:
            shutil.copy2(cached_path, partial_path)
            os.replace(partial_path, local_path)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)

        print(f"✅ {model_name} downloaded and stored!")
        return local_path

    except Exception as e:
        # Best-effort by design: callers treat None as "not available".
        print(f"❌ Error downloading {model_name}: {e}")
        return None
|
|
|
|
|
|
|
|
def extract_features_from_sequence(sequence: str):
    """Embed one sequence with the loaded ProtBERT model.

    The sequence is upper-cased, split into space-separated characters,
    tokenized (truncated to at most 512 tokens) and run through the
    transformer without gradient tracking.

    Args:
        sequence: Sequence text to embed.

    Returns:
        A NumPy array holding the hidden state at token position 0 for
        each row of the batch (presumably the [CLS] embedding - confirm
        against the tokenizer configuration).

    Raises:
        ValueError: If the transformer model or tokenizer is not loaded.
    """
    if not (transformer_model is not None and transformer_tokenizer is not None):
        raise ValueError("ProtBERT model not loaded")

    # The tokenizer is fed one character per whitespace-separated token.
    spaced_residues = ' '.join(sequence.upper())

    encoded = transformer_tokenizer(
        spaced_residues,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    on_device = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        hidden_states = transformer_model(**on_device).last_hidden_state

    first_token = hidden_states[:, 0, :]
    return first_token.cpu().numpy()
|
|
|
|
|
|
|
|
def load_all_models():
    """Load the tokenizer, transformer, classifier and scaler into globals.

    Each artifact is looked for in the local models directory first and
    downloaded from the Hugging Face repository when missing. The tokenizer
    and transformer additionally fall back to loading straight from the
    repository id if the local directory cannot be used.

    Returns:
        True when all four artifacts were loaded, False otherwise. Any
        exception is printed with its traceback and also yields False.
    """
    global classifier, scaler, transformer_model, transformer_tokenizer

    models_dir = ensure_models_directory()
    models_loaded = dict.fromkeys(
        ("classifier", "scaler", "transformer_model", "transformer_tokenizer"),
        False,
    )

    print(f"🚀 Loading models from {MODEL_REPO['repo_id']}...")
    print("=" * 60)

    try:
        # --- ProtBERT files -------------------------------------------
        print("📥 Downloading ProtBERT model files...")
        protbert_files = ("config", "model_weights", "vocab",
                          "tokenizer_config", "special_tokens_map")
        for file_key in protbert_files:
            # Results are not checked here; the loaders below have their
            # own fallback to the remote repository.
            download_model_from_hub(file_key)

        # --- Tokenizer ------------------------------------------------
        print("🔄 Loading ProtBERT tokenizer...")
        try:
            transformer_tokenizer = BertTokenizer.from_pretrained(
                models_dir,
                do_lower_case=False,
                local_files_only=True,
            )
            models_loaded["transformer_tokenizer"] = True
            print("✅ ProtBERT tokenizer loaded!")
        except Exception as tokenizer_err:
            print(f"❌ Error loading tokenizer: {tokenizer_err}")
            print("🔄 Trying to load tokenizer directly from HuggingFace...")
            hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
            transformer_tokenizer = BertTokenizer.from_pretrained(
                MODEL_REPO["repo_id"],
                do_lower_case=False,
                token=hf_token,
            )
            models_loaded["transformer_tokenizer"] = True
            print("✅ ProtBERT tokenizer loaded from HuggingFace!")

        # --- Transformer ----------------------------------------------
        print("🔄 Loading ProtBERT model...")
        try:
            transformer_model = BertModel.from_pretrained(
                models_dir,
                local_files_only=True,
            )
            models_loaded["transformer_model"] = True
            print("✅ ProtBERT model loaded!")
        except Exception as model_err:
            print(f"❌ Error loading model: {model_err}")
            print("🔄 Trying to load model directly from HuggingFace...")
            hf_token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
            transformer_model = BertModel.from_pretrained(
                MODEL_REPO["repo_id"],
                token=hf_token,
            )
            models_loaded["transformer_model"] = True
            print("✅ ProtBERT model loaded from HuggingFace!")

        transformer_model.to(device)
        transformer_model.eval()

        # --- Classifier -----------------------------------------------
        print("🔄 Loading classifier (MHSA-GRU)...")
        clf_path = os.path.join(models_dir, MODEL_REPO["files"]["classifier"])
        if not os.path.exists(clf_path):
            print("📥 Classifier not found locally, downloading...")
            clf_path = download_model_from_hub("classifier")

        if clf_path and os.path.exists(clf_path):
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # only acceptable while the repository contents are trusted.
            checkpoint = torch.load(clf_path, map_location=device, weights_only=False)

            # Fall back to 1024 when the checkpoint does not record its
            # input width.
            input_dim = checkpoint.get('input_dim', 1024)
            classifier = MHSA_GRU(input_dim, hidden_dim=256)

            # The checkpoint is either a wrapper dict or a bare state dict.
            state_dict = (
                checkpoint['model_state_dict']
                if 'model_state_dict' in checkpoint
                else checkpoint
            )
            classifier.load_state_dict(state_dict)

            classifier.to(device)
            classifier.eval()
            models_loaded["classifier"] = True
            print(f"✅ Classifier loaded! (input_dim: {input_dim})")

        # --- Scaler ---------------------------------------------------
        print("🔄 Loading feature scaler...")
        scaler_path = os.path.join(models_dir, MODEL_REPO["files"]["scaler"])
        if not os.path.exists(scaler_path):
            print("📥 Scaler not found locally, downloading...")
            scaler_path = download_model_from_hub("scaler")

        if scaler_path and os.path.exists(scaler_path):
            # NOTE(review): joblib.load is pickle-based as well.
            scaler = joblib.load(scaler_path)
            models_loaded["scaler"] = True
            print("✅ Scaler loaded!")

        # --- Summary --------------------------------------------------
        loaded_count = sum(models_loaded.values())
        total_count = len(models_loaded)

        print("\n📊 Model Loading Summary:")
        print(f" • Successfully loaded: {loaded_count}/{total_count}")
        print(f" • Repository: {MODEL_REPO['repo_id']}")
        print(f" • Embedding Model: {MODEL_NAME}")
        print(f" • Device: {device}")

        critical_models = ["classifier", "scaler", "transformer_model", "transformer_tokenizer"]
        if all(models_loaded[name] for name in critical_models):
            print("🎉 All critical models loaded successfully!")
            return True

        print("⚠️ Some critical models failed to load")
        print(f" Models status: {models_loaded}")
        return False

    except Exception as load_err:
        print(f"❌ Error loading models: {load_err}")
        import traceback
        traceback.print_exc()
        return False
|
|
|
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the models before the app starts serving; log on shutdown.

    A failed load only prints a warning; the application still starts.
    """
    print("🚀 Starting Toxicity Prediction API...")
    if not load_all_models():
        print("⚠️ Warning: Not all models loaded successfully")

    # Control returns to FastAPI here for the lifetime of the app.
    yield

    print("🔄 Shutting down API...")
|
|
|
|
|
|
|
|
# ASGI application. The advertised version comes from API_VERSION so that
# the OpenAPI metadata and the JSON responses cannot drift apart.
app = FastAPI(
    title="Toxicity Prediction API",
    description="API for toxicity prediction using MHSA-GRU with Transformer embeddings",
    version=API_VERSION,
    lifespan=lifespan,
)
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    # Landing endpoint: returns a static description of the API, its
    # routes, and sample request bodies.
    endpoint_summary = {
        "/predict": "POST - Predict toxicity for a single sequence",
        "/predict/batch": "POST - Predict toxicity for multiple sequences",
        "/example": "GET - Try the API with a hardcoded example sequence",
        "/health": "GET - Check API health and model status",
    }

    single_usage = {
        "method": "POST",
        "url": "/predict",
        "body": {"sequence": "MKTAYIAKQRQISFVKSHFSRQLE"},
    }

    batch_usage = {
        "method": "POST",
        "url": "/predict/batch",
        "body": {
            "sequences": [
                "MLLPATMSDKPDMAEIEKFDKSKLKKTETQEKNPLPSKETIEQEKQAGES",
                "MFGLPQQEVSEEEKRAHQEQTEKTLKQAAYVAAFLWVSPMIWHLVKKQWK",
            ]
        },
    }

    example_usage = {
        "method": "GET",
        "url": "/example",
        "description": "No input needed - just call this endpoint",
    }

    return {
        "message": "Toxicity Prediction API",
        "version": API_VERSION,
        "endpoints": endpoint_summary,
        "example_usage": {
            "single": single_usage,
            "batch": batch_usage,
            "example": example_usage,
        },
    }
|
|
|
|
|
|
|
|
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: SequenceRequest):
    """Predict toxicity for a single sequence.

    The sequence is upper-cased and stripped, embedded with
    ``extract_features_from_sequence``, scaled, and scored by the
    classifier. A score above 0.5 is labelled "Toxic".

    Raises:
        HTTPException: 400 when the sequence is empty or shorter than 10
            characters, 503 when any model artifact (classifier, scaler,
            transformer or tokenizer) is not loaded, 500 on any other
            failure.
    """
    start_time = time.time()
    timestamp = datetime.now(timezone.utc).isoformat()

    def api_error(status_code, message, error_code):
        # Build the error envelope used by every failure path.
        return HTTPException(
            status_code=status_code,
            detail={
                "status_code": status_code,
                "status": "error",
                "success": False,
                "error": message,
                "error_code": error_code,
                "timestamp": timestamp,
                "api_version": API_VERSION,
                "processing_time_ms": round((time.time() - start_time) * 1000, 2),
            },
        )

    try:
        if not request.sequence:
            raise api_error(400, "No sequence provided", "MISSING_SEQUENCE")

        # The tokenizer is part of the readiness check: without it feature
        # extraction raises and the client would get a misleading 500.
        required = (classifier, scaler, transformer_model, transformer_tokenizer)
        if any(artifact is None for artifact in required):
            raise api_error(503, "Models not loaded properly", "MODEL_NOT_LOADED")

        sequence = request.sequence.upper().strip()
        if len(sequence) < 10:
            raise api_error(
                400,
                "Sequence too short (minimum 10 characters)",
                "SEQUENCE_TOO_SHORT",
            )

        features = extract_features_from_sequence(sequence)
        scaled_features = scaler.transform(features)
        features_tensor = torch.FloatTensor(scaled_features).to(device)

        with torch.no_grad():
            probability = classifier(features_tensor).cpu().numpy()[0, 0]

        predicted_label = "Toxic" if probability > 0.5 else "Non-Toxic"
        # Distance from the 0.5 decision boundary, rescaled to [0, 1].
        confidence = float(abs(probability - 0.5) * 2)

        if confidence > 0.8:
            confidence_level = "high"
        elif confidence > 0.6:
            confidence_level = "medium"
        else:
            confidence_level = "low"

        processing_time = round((time.time() - start_time) * 1000, 2)

        return PredictionResponse(
            status_code=200,
            status="success",
            success=True,
            data={
                # Long sequences are echoed back truncated to 100 characters.
                "sequence": sequence[:100] + "..." if len(sequence) > 100 else sequence,
                "sequence_length": len(sequence),
                "prediction": {
                    "predicted_class": predicted_label,
                    "confidence": confidence,
                    "confidence_level": confidence_level,
                    "toxicity_score": float(probability),
                    "non_toxicity_score": float(1 - probability),
                },
                "metadata": {
                    "embedding_model": MODEL_NAME,
                    "embedding_type": EMBEDDING_TYPE,
                    "model_version": MODEL_VERSION,
                    "device": str(device),
                },
            },
            timestamp=timestamp,
            api_version=API_VERSION,
            processing_time_ms=processing_time,
        )

    except HTTPException:
        # Deliberate API errors pass through untouched.
        raise
    except Exception as e:
        raise api_error(500, f"Internal server error: {str(e)}", "INTERNAL_ERROR")
|
|
|
|
|
|
|
|
@app.post("/predict/batch", response_model=PredictionResponse)
async def predict_batch(request: BatchSequenceRequest):
    """
    Predict toxicity for multiple sequences at once.

    Example request body:
    {
        "sequences": [
            "MLLPATMSDKPDMAEIEKFDKSKLKKTETQEKNPLPSKETIEQEKQAGES",
            "MFGLPQQEVSEEEKRAHQEQTEKTLKQAAYVAAFLWVSPMIWHLVKKQWK"
        ]
    }
    """
    start_time = time.time()
    timestamp = datetime.now(timezone.utc).isoformat()

    def api_error(status_code, message, error_code):
        # Build the error envelope used by every request-level failure.
        return HTTPException(
            status_code=status_code,
            detail={
                "status_code": status_code,
                "status": "error",
                "success": False,
                "error": message,
                "error_code": error_code,
                "timestamp": timestamp,
                "api_version": API_VERSION,
                "processing_time_ms": round((time.time() - start_time) * 1000, 2),
            },
        )

    def preview(text):
        # Echo at most the first 100 characters of a sequence.
        return text[:100] + "..." if len(text) > 100 else text

    try:
        if not request.sequences:
            raise api_error(400, "No sequences provided", "MISSING_SEQUENCES")

        # The tokenizer is part of the readiness check: without it every
        # item would fail individually while the request still reported
        # HTTP 200 / success.
        required = (classifier, scaler, transformer_model, transformer_tokenizer)
        if any(artifact is None for artifact in required):
            raise api_error(503, "Models not loaded properly", "MODEL_NOT_LOADED")

        results = []

        # Indices are 1-based in the response.
        for idx, seq in enumerate(request.sequences, 1):
            try:
                sequence = seq.upper().strip()

                if len(sequence) < 10:
                    results.append({
                        "sequence_index": idx,
                        "sequence": preview(sequence),
                        "sequence_length": len(sequence),
                        "error": "Sequence too short (minimum 10 characters)",
                        "predicted_class": None,
                        "toxicity_score": None,
                        "confidence": None,
                    })
                    continue

                features = extract_features_from_sequence(sequence)
                scaled_features = scaler.transform(features)
                features_tensor = torch.FloatTensor(scaled_features).to(device)

                with torch.no_grad():
                    probability = classifier(features_tensor).cpu().numpy()[0, 0]

                predicted_label = "Toxic" if probability > 0.5 else "Non-Toxic"
                # Distance from the 0.5 decision boundary, rescaled to [0, 1].
                confidence = float(abs(probability - 0.5) * 2)

                if confidence > 0.8:
                    confidence_level = "high"
                elif confidence > 0.6:
                    confidence_level = "medium"
                else:
                    confidence_level = "low"

                results.append({
                    "sequence_index": idx,
                    "sequence": preview(sequence),
                    "sequence_length": len(sequence),
                    "predicted_class": predicted_label,
                    "toxicity_score": float(probability),
                    "non_toxicity_score": float(1 - probability),
                    "confidence": confidence,
                    "confidence_level": confidence_level,
                    "error": None,
                })

            except Exception as e:
                # One bad sequence must not abort the rest of the batch;
                # record the failure against the raw input instead.
                results.append({
                    "sequence_index": idx,
                    "sequence": preview(seq),
                    "sequence_length": len(seq),
                    "error": f"Error processing sequence: {str(e)}",
                    "predicted_class": None,
                    "toxicity_score": None,
                    "confidence": None,
                })

        processing_time = round((time.time() - start_time) * 1000, 2)

        # An item counts as successful when it carries a predicted class.
        successful_predictions = sum(
            1 for r in results if r.get("predicted_class") is not None
        )

        return PredictionResponse(
            status_code=200,
            status="success",
            success=True,
            data={
                "total_sequences": len(request.sequences),
                "successful_predictions": successful_predictions,
                "failed_predictions": len(request.sequences) - successful_predictions,
                "results": results,
                "metadata": {
                    "embedding_model": MODEL_NAME,
                    "embedding_type": EMBEDDING_TYPE,
                    "model_version": MODEL_VERSION,
                    "device": str(device),
                },
            },
            timestamp=timestamp,
            api_version=API_VERSION,
            processing_time_ms=processing_time,
        )

    except HTTPException:
        # Deliberate API errors pass through untouched.
        raise
    except Exception as e:
        raise api_error(500, f"Internal server error: {str(e)}", "INTERNAL_ERROR")
|
|
|
|
|
|
|
|
@app.get("/example", response_model=PredictionResponse)
async def predict_example():
    """
    Predict using a hardcoded example protein sequence.
    No input required - just call this endpoint to see how the API works.

    Example sequence: MLLPATMSDKPDMAEIEKFDKSKLKKTETQEKNPLPSKETIEQEKQAGES
    """
    start_time = time.time()
    timestamp = datetime.now(timezone.utc).isoformat()

    EXAMPLE_SEQUENCE = "MLLPATMSDKPDMAEIEKFDKSKLKKTETQEKNPLPSKETIEQEKQAGES"

    def api_error(status_code, message, error_code):
        # Build the error envelope used by every failure path.
        return HTTPException(
            status_code=status_code,
            detail={
                "status_code": status_code,
                "status": "error",
                "success": False,
                "error": message,
                "error_code": error_code,
                "timestamp": timestamp,
                "api_version": API_VERSION,
                "processing_time_ms": round((time.time() - start_time) * 1000, 2),
            },
        )

    try:
        # The tokenizer is part of the readiness check: without it feature
        # extraction raises and the client would get a misleading 500.
        required = (classifier, scaler, transformer_model, transformer_tokenizer)
        if any(artifact is None for artifact in required):
            raise api_error(503, "Models not loaded properly", "MODEL_NOT_LOADED")

        sequence = EXAMPLE_SEQUENCE.upper().strip()

        features = extract_features_from_sequence(sequence)
        scaled_features = scaler.transform(features)
        features_tensor = torch.FloatTensor(scaled_features).to(device)

        with torch.no_grad():
            probability = classifier(features_tensor).cpu().numpy()[0, 0]

        predicted_label = "Toxic" if probability > 0.5 else "Non-Toxic"
        # Distance from the 0.5 decision boundary, rescaled to [0, 1].
        confidence = float(abs(probability - 0.5) * 2)

        if confidence > 0.8:
            confidence_level = "high"
        elif confidence > 0.6:
            confidence_level = "medium"
        else:
            confidence_level = "low"

        processing_time = round((time.time() - start_time) * 1000, 2)

        return PredictionResponse(
            status_code=200,
            status="success",
            success=True,
            data={
                "note": "This is an example prediction using a hardcoded sequence",
                "sequence": sequence,
                "sequence_length": len(sequence),
                "prediction": {
                    "predicted_class": predicted_label,
                    "confidence": confidence,
                    "confidence_level": confidence_level,
                    "toxicity_score": float(probability),
                    "non_toxicity_score": float(1 - probability),
                },
                "metadata": {
                    "embedding_model": MODEL_NAME,
                    "embedding_type": EMBEDDING_TYPE,
                    "model_version": MODEL_VERSION,
                    "device": str(device),
                    "source": "hardcoded_example",
                },
            },
            timestamp=timestamp,
            api_version=API_VERSION,
            processing_time_ms=processing_time,
        )

    except HTTPException:
        # Deliberate API errors pass through untouched.
        raise
    except Exception as e:
        raise api_error(500, f"Internal server error: {str(e)}", "INTERNAL_ERROR")
|
|
|
|
|
|
|
|
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report which model artifacts are loaded and whether the API is ready.

    The service is "healthy" only when the classifier, scaler, transformer
    and tokenizer are all loaded. The loaded/required counts are derived
    from the same per-artifact table, so they always agree with the
    overall verdict.
    """
    repo_id = MODEL_REPO["repo_id"]

    model_sources = {
        "classifier": {
            "loaded": classifier is not None,
            "source": "huggingface_hub",
            "repository": repo_id,
        },
        "scaler": {
            "loaded": scaler is not None,
            "source": "huggingface_hub",
            "repository": repo_id,
        },
        "transformer_model": {
            "loaded": transformer_model is not None,
            "model_name": MODEL_NAME,
            "source": "huggingface_hub",
            "repository": repo_id,
        },
        # The tokenizer is required for inference, so it is reported too;
        # leaving it out allowed "unhealthy" with every listed model loaded.
        "transformer_tokenizer": {
            "loaded": transformer_tokenizer is not None,
            "model_name": MODEL_NAME,
            "source": "huggingface_hub",
            "repository": repo_id,
        },
    }

    loaded_count = sum(1 for source in model_sources.values() if source["loaded"])
    models_loaded = loaded_count == len(model_sources)

    repository_info = {
        "repository_id": repo_id,
        "embedding_type": EMBEDDING_TYPE,
        "model_name": MODEL_NAME,
        "total_models": len(MODEL_REPO["files"]),
    }

    # NOTE(review): status_code below is only a field of the JSON body; no
    # HTTP status is set here, so the transport status presumably stays at
    # the framework default even when the body says 503 - confirm.
    return HealthResponse(
        status_code=200 if models_loaded else 503,
        status="healthy" if models_loaded else "unhealthy",
        service="Toxicity Prediction API",
        api_version=API_VERSION,
        model_version=MODEL_VERSION,
        models_loaded=models_loaded,
        models_loaded_count=loaded_count,
        total_models_required=len(model_sources),
        model_sources=model_sources,
        repository_info=repository_info,
        device=str(device),
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
|
|
|
|
|
|
|
|
# Script entry point: serve the app on every network interface, port 8000.
if __name__ == "__main__":
    uvicorn.run(
        app,
        port=8000,
        host="0.0.0.0",
    )