|
|
"""
|
|
|
馃殌 PERI BERT Classifier - FastAPI Backend para HuggingFace Space
|
|
|
|
|
|
API REST para clasificaci贸n de reflexiones 茅ticas sobre IA usando BERT fine-tuneado.
|
|
|
Soporta predicci贸n con MC Dropout para uncertainty quantification.
|
|
|
|
|
|
Endpoints:
|
|
|
- POST /predict - Clasificar una reflexi贸n
|
|
|
- POST /predict-batch - Clasificar m煤ltiples reflexiones
|
|
|
- GET /health - Health check
|
|
|
- GET /info - Informaci贸n del modelo
|
|
|
"""
|
|
|
|
|
|
from fastapi import FastAPI, HTTPException
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
from pydantic import BaseModel, Field
|
|
|
from typing import List, Optional, Dict, Any
|
|
|
import torch
|
|
|
from transformers import BertTokenizer, BertForSequenceClassification
|
|
|
import numpy as np
|
|
|
from pathlib import Path
|
|
|
import time
|
|
|
import logging
|
|
|
|
|
|
|
|
|
# Module-level logging: INFO level, named after this module.
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Label mapping: class index -> stable archetype identifier used by the API.
ARCHETYPE_LABELS = {
    0: "TECNOCRATA_OPTIMIZADOR",
    1: "HUMANISTA_CRITICO",
    2: "PRAGMATICO_EQUILIBRADO",
    3: "VISIONARIO_ADAPTATIVO",
    4: "ESCEPTICO_CONSERVADOR",
}

# Human-readable archetype names (Spanish, user-facing).
# Fix: the original strings were UTF-8 text mis-decoded as GBK (mojibake);
# restored to proper accented Spanish.
ARCHETYPE_NAMES = {
    0: "Tecnócrata Optimizador",
    1: "Humanista Crítico",
    2: "Pragmático Equilibrado",
    3: "Visionario Adaptativo",
    4: "Escéptico Conservador",
}

# Short user-facing description per archetype (Spanish, mojibake restored).
ARCHETYPE_DESCRIPTIONS = {
    0: "Confía en la eficiencia y objetividad de los sistemas automatizados",
    1: "Prioriza el bienestar humano y cuestiona activamente los sesgos tecnológicos",
    2: "Busca balance entre innovación tecnológica y consideraciones humanas",
    3: "Abraza la transformación tecnológica con enfoque adaptativo y progresista",
    4: "Mantiene una postura cautelosa y crítica hacia la adopción de IA",
}

# Inference configuration.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # GPU when available

MAX_LENGTH = 512  # BERT max sequence length (tokens)

MC_SAMPLES = 10   # stochastic forward passes for MC Dropout uncertainty
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReflectionInput(BaseModel):
    """Request body for single-reflection classification.

    Fix: Field descriptions had UTF-8/GBK mojibake; restored accented Spanish.
    """
    # Reflection text; 100-5000 chars enforced by Pydantic validation.
    text: str = Field(..., min_length=100, max_length=5000, description="Reflexión ética sobre IA")
    # When True, uncertainty is estimated via MC Dropout sampling.
    use_mc_dropout: bool = Field(default=True, description="Usar MC Dropout para uncertainty")
|
|
|
|
|
|
|
|
|
class BatchReflectionInput(BaseModel):
    """Request body for batch classification (up to 50 reflections).

    Fix: Field description had UTF-8/GBK mojibake; restored accented Spanish.
    """
    # NOTE(review): `max_items` is the Pydantic v1 spelling; under Pydantic v2
    # this should be `max_length` — confirm the installed Pydantic version.
    texts: List[str] = Field(..., max_items=50, description="Lista de reflexiones (máx 50)")
    use_mc_dropout: bool = Field(default=True, description="Usar MC Dropout para uncertainty")
|
|
|
|
|
|
|
|
|
class ArchetypeResult(BaseModel):
    """Classified archetype: stable id plus user-facing name and description."""
    id: str           # stable identifier, e.g. "TECNOCRATA_OPTIMIZADOR"
    name: str         # human-readable archetype name
    description: str  # short description of the archetype
|
|
|
|
|
|
|
|
|
class PredictionResponse(BaseModel):
    """Response body for a single prediction.

    Fix: Field descriptions had UTF-8/GBK mojibake; restored accented Spanish.
    """
    archetype: ArchetypeResult
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confianza de la predicción")
    # None when MC Dropout was not used.
    uncertainty: Optional[float] = Field(None, ge=0.0, description="Incertidumbre (MC Dropout)")
    top3_predictions: List[Dict[str, Any]] = Field(..., description="Top 3 predicciones")
    inference_time_ms: float = Field(..., description="Tiempo de inferencia en milisegundos")
    method: str = Field(default="bert", description="Método de clasificación")
|
|
|
|
|
|
|
|
|
class BatchPredictionResponse(BaseModel):
    """Response body for batch prediction."""
    predictions: List[PredictionResponse]  # one entry per accepted input text
    total_inference_time_ms: float         # wall-clock time for the whole batch
|
|
|
|
|
|
|
|
|
class HealthResponse(BaseModel):
    """Health check response."""
    # Allow field names starting with "model_" (Pydantic v2 protected namespace).
    model_config = {"protected_namespaces": ()}

    status: str         # "healthy" when the service is up
    model_loaded: bool  # whether the classifier singleton is ready
    device: str         # "cuda" or "cpu"
    timestamp: float    # server time (time.time())
|
|
|
|
|
|
|
|
|
class InfoResponse(BaseModel):
    """Model metadata returned by GET /info."""
    # Allow field names starting with "model_" (Pydantic v2 protected namespace).
    model_config = {"protected_namespaces": ()}

    model_name: str                   # base checkpoint + fine-tune note
    num_classes: int                  # number of archetype classes
    max_length: int                   # tokenizer max sequence length
    device: str                       # "cuda" or "cpu"
    mc_dropout_samples: int           # forward passes used for MC Dropout
    archetypes: List[Dict[str, str]]  # id/name/description per archetype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BERTClassifier:
    """Wrapper around a fine-tuned BERT sequence classifier with MC Dropout support.

    MC Dropout: the model is temporarily switched to train mode so dropout layers
    stay stochastic, MC_SAMPLES forward passes are averaged, and the predictive
    entropy of the mean distribution is reported as the uncertainty estimate.
    """

    def __init__(self, model_path: str):
        """Load tokenizer and classification model from *model_path* onto DEVICE."""
        logger.info(f"Cargando modelo desde {model_path}...")
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.model.to(DEVICE)
        self.model.eval()
        logger.info(f"Modelo cargado exitosamente en {DEVICE}")

    def predict(
        self,
        text: str,
        use_mc_dropout: bool = True
    ) -> Dict[str, Any]:
        """Run a prediction with or without MC Dropout.

        Args:
            text: input reflection to classify.
            use_mc_dropout: average MC_SAMPLES stochastic passes when True.

        Returns:
            dict with keys: predicted_class, confidence, uncertainty, top3,
            inference_time_ms, all_probabilities.
        """
        start_time = time.time()

        # Tokenize to a fixed-length (MAX_LENGTH) padded/truncated tensor pair.
        encoding = self.tokenizer(
            text,
            max_length=MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        input_ids = encoding["input_ids"].to(DEVICE)
        attention_mask = encoding["attention_mask"].to(DEVICE)

        if use_mc_dropout:
            # train() keeps dropout active during the stochastic passes.
            # Bug fix: restore eval mode in a finally block — previously an
            # exception mid-sampling left the model permanently in train mode,
            # silently randomizing all subsequent deterministic predictions.
            self.model.train()
            try:
                sample_probs = []
                with torch.no_grad():
                    for _ in range(MC_SAMPLES):
                        outputs = self.model(
                            input_ids=input_ids,
                            attention_mask=attention_mask
                        )
                        probs = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]
                        sample_probs.append(probs)
            finally:
                self.model.eval()

            mean_probs = np.mean(np.array(sample_probs), axis=0)
            predicted_class = int(np.argmax(mean_probs))
            confidence = float(mean_probs[predicted_class])

            # Predictive entropy of the averaged distribution; epsilon guards log(0).
            epsilon = 1e-10
            uncertainty = float(-np.sum(mean_probs * np.log(mean_probs + epsilon)))
        else:
            # Single deterministic pass (model is already in eval mode).
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
            mean_probs = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]
            predicted_class = int(np.argmax(mean_probs))
            confidence = float(mean_probs[predicted_class])
            uncertainty = None
            # (dead `all_probs = probs.reshape(1, -1)` from the original removed)

        # Top-3 classes by mean probability, highest first.
        top3_indices = np.argsort(mean_probs)[-3:][::-1]
        top3 = [
            {
                "archetype_id": ARCHETYPE_LABELS[int(idx)],
                "archetype_name": ARCHETYPE_NAMES[int(idx)],
                "probability": float(mean_probs[idx])
            }
            for idx in top3_indices
        ]

        inference_time = (time.time() - start_time) * 1000  # milliseconds

        return {
            "predicted_class": predicted_class,
            "confidence": confidence,
            "uncertainty": uncertainty,
            "top3": top3,
            "inference_time_ms": inference_time,
            "all_probabilities": mean_probs.tolist()
        }
|
|
|
|
|
|
|
|
|
|
|
|
classifier: Optional[BERTClassifier] = None
|
|
|
|
|
|
|
|
|
def load_model():
    """Load the BERT classifier into the module-level `classifier` singleton.

    Tries known model locations in order and raises RuntimeError if none exists.
    """
    global classifier

    # Candidate locations: HF Space layout first, then local dev checkout.
    model_paths = [
        Path("./model"),
        Path("../../../models/peri-bert/best_model"),
    ]

    # First existing candidate, or None.
    model_path = next((str(p) for p in model_paths if p.exists()), None)

    if model_path is None:
        # Fix: restored mojibake (UTF-8 decoded as GBK) in the log message.
        logger.error("No se encontró el modelo. Asegúrate de subirlo a HuggingFace Space.")
        raise RuntimeError("Model not found")

    classifier = BERTClassifier(model_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# FastAPI application; Swagger UI is served at the root path.
# Fix: restored mojibake (UTF-8 decoded as GBK) in the description string.
app = FastAPI(
    title="PERI BERT Classifier API",
    description="API REST para clasificación de arquetipos éticos en reflexiones sobre IA",
    version="1.0.0",
    docs_url="/",
)

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# contradictory under the CORS spec (browsers reject credentialed requests to
# a wildcard origin) — confirm whether credentials are actually required.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
|
@app.on_event("startup")
|
|
|
async def startup_event():
|
|
|
"""Cargar modelo al iniciar"""
|
|
|
load_model()
|
|
|
|
|
|
|
|
|
@app.get("/health", response_model=HealthResponse)
|
|
|
async def health_check():
|
|
|
"""Health check endpoint"""
|
|
|
return HealthResponse(
|
|
|
status="healthy",
|
|
|
model_loaded=classifier is not None,
|
|
|
device=DEVICE,
|
|
|
timestamp=time.time()
|
|
|
)
|
|
|
|
|
|
|
|
|
@app.get("/info", response_model=InfoResponse)
|
|
|
async def model_info():
|
|
|
"""Informaci贸n del modelo"""
|
|
|
if classifier is None:
|
|
|
raise HTTPException(status_code=503, detail="Model not loaded")
|
|
|
|
|
|
archetypes = [
|
|
|
{
|
|
|
"id": ARCHETYPE_LABELS[i],
|
|
|
"name": ARCHETYPE_NAMES[i],
|
|
|
"description": ARCHETYPE_DESCRIPTIONS[i]
|
|
|
}
|
|
|
for i in range(5)
|
|
|
]
|
|
|
|
|
|
return InfoResponse(
|
|
|
model_name="bert-base-multilingual-cased (fine-tuned)",
|
|
|
num_classes=5,
|
|
|
max_length=MAX_LENGTH,
|
|
|
device=DEVICE,
|
|
|
mc_dropout_samples=MC_SAMPLES,
|
|
|
archetypes=archetypes
|
|
|
)
|
|
|
|
|
|
|
|
|
@app.post("/predict", response_model=PredictionResponse)
|
|
|
async def predict(input_data: ReflectionInput):
|
|
|
"""
|
|
|
Clasificar una reflexi贸n individual
|
|
|
|
|
|
Args:
|
|
|
input_data: Reflexi贸n y configuraci贸n
|
|
|
|
|
|
Returns:
|
|
|
Predicci贸n con arquetipo, confianza y m茅tricas
|
|
|
"""
|
|
|
if classifier is None:
|
|
|
raise HTTPException(status_code=503, detail="Model not loaded")
|
|
|
|
|
|
try:
|
|
|
result = classifier.predict(
|
|
|
text=input_data.text,
|
|
|
use_mc_dropout=input_data.use_mc_dropout
|
|
|
)
|
|
|
|
|
|
archetype_result = ArchetypeResult(
|
|
|
id=ARCHETYPE_LABELS[result["predicted_class"]],
|
|
|
name=ARCHETYPE_NAMES[result["predicted_class"]],
|
|
|
description=ARCHETYPE_DESCRIPTIONS[result["predicted_class"]]
|
|
|
)
|
|
|
|
|
|
return PredictionResponse(
|
|
|
archetype=archetype_result,
|
|
|
confidence=result["confidence"],
|
|
|
uncertainty=result["uncertainty"],
|
|
|
top3_predictions=result["top3"],
|
|
|
inference_time_ms=result["inference_time_ms"],
|
|
|
method="bert-mc-dropout" if input_data.use_mc_dropout else "bert"
|
|
|
)
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error en predicci贸n: {str(e)}")
|
|
|
raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
|
|
|
|
|
|
|
|
|
@app.post("/predict-batch", response_model=BatchPredictionResponse)
|
|
|
async def predict_batch(input_data: BatchReflectionInput):
|
|
|
"""
|
|
|
Clasificar m煤ltiples reflexiones en batch
|
|
|
|
|
|
Args:
|
|
|
input_data: Lista de reflexiones
|
|
|
|
|
|
Returns:
|
|
|
Lista de predicciones
|
|
|
"""
|
|
|
if classifier is None:
|
|
|
raise HTTPException(status_code=503, detail="Model not loaded")
|
|
|
|
|
|
if len(input_data.texts) == 0:
|
|
|
raise HTTPException(status_code=400, detail="Empty texts list")
|
|
|
|
|
|
start_time = time.time()
|
|
|
predictions = []
|
|
|
|
|
|
try:
|
|
|
for text in input_data.texts:
|
|
|
if len(text) < 100:
|
|
|
continue
|
|
|
|
|
|
result = classifier.predict(
|
|
|
text=text,
|
|
|
use_mc_dropout=input_data.use_mc_dropout
|
|
|
)
|
|
|
|
|
|
archetype_result = ArchetypeResult(
|
|
|
id=ARCHETYPE_LABELS[result["predicted_class"]],
|
|
|
name=ARCHETYPE_NAMES[result["predicted_class"]],
|
|
|
description=ARCHETYPE_DESCRIPTIONS[result["predicted_class"]]
|
|
|
)
|
|
|
|
|
|
predictions.append(
|
|
|
PredictionResponse(
|
|
|
archetype=archetype_result,
|
|
|
confidence=result["confidence"],
|
|
|
uncertainty=result["uncertainty"],
|
|
|
top3_predictions=result["top3"],
|
|
|
inference_time_ms=result["inference_time_ms"],
|
|
|
method="bert-mc-dropout" if input_data.use_mc_dropout else "bert"
|
|
|
)
|
|
|
)
|
|
|
|
|
|
total_time = (time.time() - start_time) * 1000
|
|
|
|
|
|
return BatchPredictionResponse(
|
|
|
predictions=predictions,
|
|
|
total_inference_time_ms=total_time
|
|
|
)
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Error en batch prediction: {str(e)}")
|
|
|
raise HTTPException(status_code=500, detail=f"Batch prediction error: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
import uvicorn
|
|
|
|
|
|
uvicorn.run(
|
|
|
"app:app",
|
|
|
host="0.0.0.0",
|
|
|
port=7860,
|
|
|
reload=True
|
|
|
)
|
|
|
|