# PERI / app.py
# Author: DavidHospinal
# Commit 826299c: Fix Pydantic warnings: add protected_namespaces config
"""
馃殌 PERI BERT Classifier - FastAPI Backend para HuggingFace Space
API REST para clasificaci贸n de reflexiones 茅ticas sobre IA usando BERT fine-tuneado.
Soporta predicci贸n con MC Dropout para uncertainty quantification.
Endpoints:
- POST /predict - Clasificar una reflexi贸n
- POST /predict-batch - Clasificar m煤ltiples reflexiones
- GET /health - Health check
- GET /info - Informaci贸n del modelo
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
import torch
from transformers import BertTokenizer, BertForSequenceClassification
import numpy as np
from pathlib import Path
import time
import logging
# Configure module-level logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ============================================================================
# CONFIGURATION
# ============================================================================

# Class index -> stable archetype identifier (used as API ids)
ARCHETYPE_LABELS = {
    0: "TECNOCRATA_OPTIMIZADOR",
    1: "HUMANISTA_CRITICO",
    2: "PRAGMATICO_EQUILIBRADO",
    3: "VISIONARIO_ADAPTATIVO",
    4: "ESCEPTICO_CONSERVADOR",
}

# Class index -> human-readable display name
# NOTE(review): accented characters in these strings appear mojibake-encoded
# (UTF-8 read as GBK) — verify the file/source encoding before changing them.
ARCHETYPE_NAMES = {
    0: "Tecn贸crata Optimizador",
    1: "Humanista Cr铆tico",
    2: "Pragm谩tico Equilibrado",
    3: "Visionario Adaptativo",
    4: "Esc茅ptico Conservador",
}

# Class index -> one-line description shown to API consumers
ARCHETYPE_DESCRIPTIONS = {
    0: "Conf铆a en la eficiencia y objetividad de los sistemas automatizados",
    1: "Prioriza el bienestar humano y cuestiona activamente los sesgos tecnol贸gicos",
    2: "Busca balance entre innovaci贸n tecnol贸gica y consideraciones humanas",
    3: "Abraza la transformaci贸n tecnol贸gica con enfoque adaptativo y progresista",
    4: "Mantiene una postura cautelosa y cr铆tica hacia la adopci贸n de IA",
}

# Inference device: prefer GPU when available
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Maximum token length for BERT inputs
MAX_LENGTH = 512
MC_SAMPLES = 10  # Number of stochastic forward passes for MC Dropout
# ============================================================================
# PYDANTIC MODELS
# ============================================================================
class ReflectionInput(BaseModel):
    """Request body for single-reflection classification.

    Attributes:
        text: the ethical reflection to classify (100-5000 chars, enforced by Field).
        use_mc_dropout: when True, run several stochastic passes for uncertainty.
    """
    text: str = Field(..., min_length=100, max_length=5000, description="Reflexi贸n 茅tica sobre IA")
    use_mc_dropout: bool = Field(default=True, description="Usar MC Dropout para uncertainty")
class BatchReflectionInput(BaseModel):
    """Request body for batch classification.

    Attributes:
        texts: reflections to classify (at most 50 items).
        use_mc_dropout: when True, run MC Dropout passes for each text.
    """
    # Pydantic v2 (which this file targets, per model_config usage below):
    # `max_length` is the supported list-size constraint; `max_items` is the
    # deprecated v1 spelling and emits a warning.
    texts: List[str] = Field(..., max_length=50, description="Lista de reflexiones (m谩x 50)")
    use_mc_dropout: bool = Field(default=True, description="Usar MC Dropout para uncertainty")
class ArchetypeResult(BaseModel):
    """Classified archetype: stable id, display name, and short description."""
    id: str          # e.g. "TECNOCRATA_OPTIMIZADOR" (from ARCHETYPE_LABELS)
    name: str        # display name (from ARCHETYPE_NAMES)
    description: str # one-line description (from ARCHETYPE_DESCRIPTIONS)
class PredictionResponse(BaseModel):
    """Response body for a single prediction.

    `uncertainty` is populated only when MC Dropout was used (predictive
    entropy); `method` records which classification path produced the result.
    """
    archetype: ArchetypeResult
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confianza de la predicci贸n")
    uncertainty: Optional[float] = Field(None, ge=0.0, description="Incertidumbre (MC Dropout)")
    top3_predictions: List[Dict[str, Any]] = Field(..., description="Top 3 predicciones")
    inference_time_ms: float = Field(..., description="Tiempo de inferencia en milisegundos")
    method: str = Field(default="bert", description="M茅todo de clasificaci贸n")
class BatchPredictionResponse(BaseModel):
    """Response body for batch prediction: per-text results plus total wall time (ms)."""
    predictions: List[PredictionResponse]
    total_inference_time_ms: float
class HealthResponse(BaseModel):
    """Response body for GET /health."""
    # Allow field names starting with "model_" (Pydantic v2 reserves that
    # prefix by default and warns about it).
    model_config = {"protected_namespaces": ()}
    status: str         # always "healthy" when the endpoint responds
    model_loaded: bool  # whether the global classifier has been initialized
    device: str         # "cuda" or "cpu"
    timestamp: float    # server time (epoch seconds)
class InfoResponse(BaseModel):
    """Response body for GET /info: model metadata and supported archetypes."""
    # Allow field names starting with "model_" (Pydantic v2 reserves that prefix)
    model_config = {"protected_namespaces": ()}
    model_name: str
    num_classes: int
    max_length: int
    device: str
    mc_dropout_samples: int
    archetypes: List[Dict[str, str]]  # entries with keys: id, name, description
# ============================================================================
# MODEL LOADING
# ============================================================================
class BERTClassifier:
    """Wrapper around a fine-tuned BERT sequence classifier with optional MC Dropout.

    MC Dropout keeps dropout layers active at inference time and averages the
    class probabilities over several stochastic forward passes; the entropy of
    the averaged distribution serves as an uncertainty estimate.
    """

    def __init__(self, model_path: str):
        """Load tokenizer and model from `model_path` and move the model to DEVICE."""
        logger.info(f"Cargando modelo desde {model_path}...")
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.model.to(DEVICE)
        self.model.eval()
        logger.info(f"Modelo cargado exitosamente en {DEVICE}")

    def predict(
        self,
        text: str,
        use_mc_dropout: bool = True
    ) -> Dict[str, Any]:
        """Classify `text`, optionally with MC Dropout uncertainty.

        Returns:
            dict with keys: predicted_class, confidence, uncertainty, top3,
            inference_time_ms, all_probabilities.
        """
        start_time = time.time()
        # Tokenize to fixed-length tensors
        encoding = self.tokenizer(
            text,
            max_length=MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        input_ids = encoding["input_ids"].to(DEVICE)
        attention_mask = encoding["attention_mask"].to(DEVICE)

        if use_mc_dropout:
            # MC Dropout: several predictions with dropout kept active.
            # train() is the simplest way to enable dropout; BERT uses
            # LayerNorm (not BatchNorm), so train mode does not otherwise
            # change the forward computation.
            self.model.train()
            try:
                sampled_probs = []
                with torch.no_grad():
                    for _ in range(MC_SAMPLES):
                        outputs = self.model(
                            input_ids=input_ids,
                            attention_mask=attention_mask
                        )
                        probs = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]
                        sampled_probs.append(probs)
            finally:
                # Always restore eval mode, even if a forward pass raises,
                # so later non-MC predictions stay deterministic.
                self.model.eval()
            all_probs = np.array(sampled_probs)  # (MC_SAMPLES, num_classes)
            mean_probs = np.mean(all_probs, axis=0)
            predicted_class = int(np.argmax(mean_probs))
            confidence = float(mean_probs[predicted_class])
            # Uncertainty = predictive entropy of the mean distribution
            epsilon = 1e-10
            uncertainty = float(-np.sum(mean_probs * np.log(mean_probs + epsilon)))
        else:
            # Standard single deterministic forward pass
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
                probs = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]
            mean_probs = probs
            predicted_class = int(np.argmax(probs))
            confidence = float(probs[predicted_class])
            uncertainty = None
            all_probs = probs.reshape(1, -1)

        # Top-3 classes by mean probability, highest first
        top3_indices = np.argsort(mean_probs)[-3:][::-1]
        top3 = [
            {
                "archetype_id": ARCHETYPE_LABELS[int(idx)],
                "archetype_name": ARCHETYPE_NAMES[int(idx)],
                "probability": float(mean_probs[idx])
            }
            for idx in top3_indices
        ]
        inference_time = (time.time() - start_time) * 1000  # ms
        return {
            "predicted_class": predicted_class,
            "confidence": confidence,
            "uncertainty": uncertainty,
            "top3": top3,
            "inference_time_ms": inference_time,
            "all_probabilities": mean_probs.tolist()
        }
# Global classifier instance, populated once at application startup
classifier: Optional[BERTClassifier] = None


def load_model():
    """Locate the model directory and initialize the global classifier.

    Tries the HuggingFace Space layout first, then the local repo layout.

    Raises:
        RuntimeError: if no candidate directory exists on disk.
    """
    global classifier
    candidates = (
        Path("./model"),                                # HF Space
        Path("../../../models/peri-bert/best_model"),   # Local checkout
    )
    found = next((candidate for candidate in candidates if candidate.exists()), None)
    if found is None:
        logger.error("No se encontr贸 el modelo. Aseg煤rate de subirlo a HuggingFace Space.")
        raise RuntimeError("Model not found")
    classifier = BERTClassifier(str(found))
# ============================================================================
# FASTAPI APP
# ============================================================================
app = FastAPI(
    title="PERI BERT Classifier API",
    description="API REST para clasificaci贸n de arquetipos 茅ticos en reflexiones sobre IA",
    version="1.0.0",
    docs_url="/",  # Serve the Swagger UI at the root path
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # disallowed by the CORS spec (browsers reject it) — restrict origins in
    # production and verify the credentials requirement.
    allow_origins=["*"],  # In production, list the allowed domains explicitly
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.on_event("startup")
async def startup_event():
    """Load the BERT model once when the application starts.

    NOTE(review): `on_event` is deprecated in recent FastAPI releases in
    favor of lifespan handlers — consider migrating when upgrading FastAPI.
    """
    load_model()
@app.get("/health", response_model=HealthResponse)
async def health_check():
"""Health check endpoint"""
return HealthResponse(
status="healthy",
model_loaded=classifier is not None,
device=DEVICE,
timestamp=time.time()
)
@app.get("/info", response_model=InfoResponse)
async def model_info():
"""Informaci贸n del modelo"""
if classifier is None:
raise HTTPException(status_code=503, detail="Model not loaded")
archetypes = [
{
"id": ARCHETYPE_LABELS[i],
"name": ARCHETYPE_NAMES[i],
"description": ARCHETYPE_DESCRIPTIONS[i]
}
for i in range(5)
]
return InfoResponse(
model_name="bert-base-multilingual-cased (fine-tuned)",
num_classes=5,
max_length=MAX_LENGTH,
device=DEVICE,
mc_dropout_samples=MC_SAMPLES,
archetypes=archetypes
)
@app.post("/predict", response_model=PredictionResponse)
async def predict(input_data: ReflectionInput):
"""
Clasificar una reflexi贸n individual
Args:
input_data: Reflexi贸n y configuraci贸n
Returns:
Predicci贸n con arquetipo, confianza y m茅tricas
"""
if classifier is None:
raise HTTPException(status_code=503, detail="Model not loaded")
try:
result = classifier.predict(
text=input_data.text,
use_mc_dropout=input_data.use_mc_dropout
)
archetype_result = ArchetypeResult(
id=ARCHETYPE_LABELS[result["predicted_class"]],
name=ARCHETYPE_NAMES[result["predicted_class"]],
description=ARCHETYPE_DESCRIPTIONS[result["predicted_class"]]
)
return PredictionResponse(
archetype=archetype_result,
confidence=result["confidence"],
uncertainty=result["uncertainty"],
top3_predictions=result["top3"],
inference_time_ms=result["inference_time_ms"],
method="bert-mc-dropout" if input_data.use_mc_dropout else "bert"
)
except Exception as e:
logger.error(f"Error en predicci贸n: {str(e)}")
raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")
@app.post("/predict-batch", response_model=BatchPredictionResponse)
async def predict_batch(input_data: BatchReflectionInput):
"""
Clasificar m煤ltiples reflexiones en batch
Args:
input_data: Lista de reflexiones
Returns:
Lista de predicciones
"""
if classifier is None:
raise HTTPException(status_code=503, detail="Model not loaded")
if len(input_data.texts) == 0:
raise HTTPException(status_code=400, detail="Empty texts list")
start_time = time.time()
predictions = []
try:
for text in input_data.texts:
if len(text) < 100:
continue # Skip textos muy cortos
result = classifier.predict(
text=text,
use_mc_dropout=input_data.use_mc_dropout
)
archetype_result = ArchetypeResult(
id=ARCHETYPE_LABELS[result["predicted_class"]],
name=ARCHETYPE_NAMES[result["predicted_class"]],
description=ARCHETYPE_DESCRIPTIONS[result["predicted_class"]]
)
predictions.append(
PredictionResponse(
archetype=archetype_result,
confidence=result["confidence"],
uncertainty=result["uncertainty"],
top3_predictions=result["top3"],
inference_time_ms=result["inference_time_ms"],
method="bert-mc-dropout" if input_data.use_mc_dropout else "bert"
)
)
total_time = (time.time() - start_time) * 1000
return BatchPredictionResponse(
predictions=predictions,
total_inference_time_ms=total_time
)
except Exception as e:
logger.error(f"Error en batch prediction: {str(e)}")
raise HTTPException(status_code=500, detail=f"Batch prediction error: {str(e)}")
# ============================================================================
# MAIN (local testing entry point)
# ============================================================================
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,  # Standard HuggingFace Spaces port
        reload=True  # Auto-reload on code changes (development only)
    )