File size: 14,078 Bytes
f64d280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
826299c
 
f64d280
 
 
 
 
 
 
 
826299c
 
f64d280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
"""

馃殌 PERI BERT Classifier - FastAPI Backend para HuggingFace Space



API REST para clasificaci贸n de reflexiones 茅ticas sobre IA usando BERT fine-tuneado.

Soporta predicci贸n con MC Dropout para uncertainty quantification.



Endpoints:

- POST /predict - Clasificar una reflexi贸n

- POST /predict-batch - Clasificar m煤ltiples reflexiones

- GET /health - Health check

- GET /info - Informaci贸n del modelo

"""

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
import torch
from transformers import BertTokenizer, BertForSequenceClassification
import numpy as np
from pathlib import Path
import time
import logging

# Configure root logging and create the module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ============================================================================
# CONFIGURATION
# ============================================================================

# NOTE(review): the Spanish string values below appear mojibake-encoded
# (UTF-8 decoded as a legacy codepage, e.g. "Tecn贸crata"). They are returned
# verbatim by the API, so fix the file's encoding as a whole rather than
# editing individual values.

# Class index -> stable archetype identifier (exposed as the API "id" field).
ARCHETYPE_LABELS = {
    0: "TECNOCRATA_OPTIMIZADOR",
    1: "HUMANISTA_CRITICO",
    2: "PRAGMATICO_EQUILIBRADO",
    3: "VISIONARIO_ADAPTATIVO",
    4: "ESCEPTICO_CONSERVADOR",
}

# Class index -> human-readable display name.
ARCHETYPE_NAMES = {
    0: "Tecn贸crata Optimizador",
    1: "Humanista Cr铆tico",
    2: "Pragm谩tico Equilibrado",
    3: "Visionario Adaptativo",
    4: "Esc茅ptico Conservador",
}

# Class index -> short description returned alongside each prediction.
ARCHETYPE_DESCRIPTIONS = {
    0: "Conf铆a en la eficiencia y objetividad de los sistemas automatizados",
    1: "Prioriza el bienestar humano y cuestiona activamente los sesgos tecnol贸gicos",
    2: "Busca balance entre innovaci贸n tecnol贸gica y consideraciones humanas",
    3: "Abraza la transformaci贸n tecnol贸gica con enfoque adaptativo y progresista",
    4: "Mantiene una postura cautelosa y cr铆tica hacia la adopci贸n de IA",
}

# Device configuration: prefer GPU when available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Maximum token length for BERT input (inputs are padded/truncated to this).
MAX_LENGTH = 512
# Number of stochastic forward passes for MC Dropout uncertainty estimation.
MC_SAMPLES = 10

# ============================================================================
# PYDANTIC MODELS
# ============================================================================

class ReflectionInput(BaseModel):
    """Request body for single-text classification."""
    # min_length=100 mirrors the short-text filter in /predict-batch.
    text: str = Field(..., min_length=100, max_length=5000, description="Reflexi贸n 茅tica sobre IA")
    use_mc_dropout: bool = Field(default=True, description="Usar MC Dropout para uncertainty")


class BatchReflectionInput(BaseModel):
    """Request body for batch classification (up to 50 texts)."""
    # NOTE(review): `max_items` is the Pydantic v1 spelling; v2 uses
    # `max_length` for list constraints. This file also uses the v2
    # `model_config` idiom — confirm the installed pydantic version.
    texts: List[str] = Field(..., max_items=50, description="Lista de reflexiones (m谩x 50)")
    use_mc_dropout: bool = Field(default=True, description="Usar MC Dropout para uncertainty")


class ArchetypeResult(BaseModel):
    """Predicted archetype: stable id, display name, and description."""
    id: str
    name: str
    description: str


class PredictionResponse(BaseModel):
    """Response for a single prediction."""
    archetype: ArchetypeResult
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confianza de la predicci贸n")
    # Predictive entropy of the averaged class distribution; only populated
    # when MC Dropout was used, otherwise None.
    uncertainty: Optional[float] = Field(None, ge=0.0, description="Incertidumbre (MC Dropout)")
    top3_predictions: List[Dict[str, Any]] = Field(..., description="Top 3 predicciones")
    inference_time_ms: float = Field(..., description="Tiempo de inferencia en milisegundos")
    # "bert" for a plain forward pass, "bert-mc-dropout" when sampling was used.
    method: str = Field(default="bert", description="M茅todo de clasificaci贸n")


class BatchPredictionResponse(BaseModel):
    """Response for batch prediction: one entry per classified text."""
    predictions: List[PredictionResponse]
    total_inference_time_ms: float


class HealthResponse(BaseModel):
    """Health check response."""
    # Disable Pydantic v2's protected "model_" namespace so the
    # `model_loaded` field does not trigger a warning.
    model_config = {"protected_namespaces": ()}

    status: str
    model_loaded: bool
    device: str
    timestamp: float


class InfoResponse(BaseModel):
    """Model metadata exposed by GET /info."""
    # Disable Pydantic v2's protected "model_" namespace so the
    # `model_name` field does not trigger a warning.
    model_config = {"protected_namespaces": ()}

    model_name: str
    num_classes: int
    max_length: int
    device: str
    mc_dropout_samples: int
    archetypes: List[Dict[str, str]]


# ============================================================================
# MODEL LOADING
# ============================================================================

class BERTClassifier:
    """Wrapper around a fine-tuned BERT sequence classifier.

    Supports a plain deterministic forward pass and MC Dropout inference
    (multiple stochastic passes with dropout enabled) for uncertainty
    quantification.
    """

    def __init__(self, model_path: str):
        """Load tokenizer and model from *model_path*; move model to DEVICE."""
        logger.info(f"Cargando modelo desde {model_path}...")
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(model_path)
        self.model.to(DEVICE)
        self.model.eval()
        logger.info(f"Modelo cargado exitosamente en {DEVICE}")

    def predict(
        self,
        text: str,
        use_mc_dropout: bool = True
    ) -> Dict[str, Any]:
        """Run inference on a single text, with or without MC Dropout.

        Args:
            text: input reflection to classify.
            use_mc_dropout: when True, run MC_SAMPLES stochastic forward
                passes with dropout enabled and average the probabilities.

        Returns:
            dict with keys: predicted_class, confidence, uncertainty,
            top3, inference_time_ms, all_probabilities.
        """
        start_time = time.time()

        # Tokenize to fixed-length (MAX_LENGTH) padded/truncated tensors.
        encoding = self.tokenizer(
            text,
            max_length=MAX_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        input_ids = encoding["input_ids"].to(DEVICE)
        attention_mask = encoding["attention_mask"].to(DEVICE)

        if use_mc_dropout:
            # train() re-enables dropout so each pass is stochastic. Wrap in
            # try/finally so an inference error cannot leave the shared model
            # stuck in training mode for subsequent requests (bug fix: the
            # original only restored eval() on the success path).
            self.model.train()
            try:
                sample_probs = []
                with torch.no_grad():
                    for _ in range(MC_SAMPLES):
                        outputs = self.model(
                            input_ids=input_ids,
                            attention_mask=attention_mask
                        )
                        sample_probs.append(
                            torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]
                        )
            finally:
                self.model.eval()  # Always return to evaluation mode

            # Average the per-sample distributions: (MC_SAMPLES, num_classes).
            mean_probs = np.mean(np.array(sample_probs), axis=0)
            predicted_class = int(np.argmax(mean_probs))
            confidence = float(mean_probs[predicted_class])

            # Uncertainty = predictive entropy of the mean distribution.
            epsilon = 1e-10  # Guard against log(0)
            uncertainty = float(-np.sum(mean_probs * np.log(mean_probs + epsilon)))

        else:
            # Standard single deterministic pass (dropout already disabled).
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
            mean_probs = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]
            predicted_class = int(np.argmax(mean_probs))
            confidence = float(mean_probs[predicted_class])
            uncertainty = None  # Not defined without MC sampling

        # Top-3 classes by mean probability, highest first.
        top3_indices = np.argsort(mean_probs)[-3:][::-1]
        top3 = [
            {
                "archetype_id": ARCHETYPE_LABELS[int(idx)],
                "archetype_name": ARCHETYPE_NAMES[int(idx)],
                "probability": float(mean_probs[idx])
            }
            for idx in top3_indices
        ]

        inference_time = (time.time() - start_time) * 1000  # ms

        return {
            "predicted_class": predicted_class,
            "confidence": confidence,
            "uncertainty": uncertainty,
            "top3": top3,
            "inference_time_ms": inference_time,
            "all_probabilities": mean_probs.tolist()
        }


# Global classifier instance, populated by load_model() at startup.
classifier: Optional[BERTClassifier] = None


def load_model():
    """Locate the model directory and initialise the global classifier.

    Raises:
        RuntimeError: if no candidate model directory exists.
    """
    global classifier

    # Candidate locations, checked in order: the HuggingFace Space layout
    # first, then the local development checkout.
    candidate_dirs = (
        Path("./model"),  # HF Space
        Path("../../../models/peri-bert/best_model"),  # Local
    )
    found = next((str(p) for p in candidate_dirs if p.exists()), None)

    if found is None:
        logger.error("No se encontr贸 el modelo. Aseg煤rate de subirlo a HuggingFace Space.")
        raise RuntimeError("Model not found")

    classifier = BERTClassifier(found)


# ============================================================================
# FASTAPI APP
# ============================================================================

app = FastAPI(
    title="PERI BERT Classifier API",
    description="API REST para clasificaci贸n de arquetipos 茅ticos en reflexiones sobre IA",
    version="1.0.0",
    docs_url="/",  # Serve the Swagger UI at the root path
)

# CORS middleware: wide open for the public demo deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, restrict to specific origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    """Cargar modelo al iniciar"""
    load_model()


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    return HealthResponse(
        status="healthy",
        model_loaded=classifier is not None,
        device=DEVICE,
        timestamp=time.time()
    )


@app.get("/info", response_model=InfoResponse)
async def model_info():
    """Informaci贸n del modelo"""
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    archetypes = [
        {
            "id": ARCHETYPE_LABELS[i],
            "name": ARCHETYPE_NAMES[i],
            "description": ARCHETYPE_DESCRIPTIONS[i]
        }
        for i in range(5)
    ]

    return InfoResponse(
        model_name="bert-base-multilingual-cased (fine-tuned)",
        num_classes=5,
        max_length=MAX_LENGTH,
        device=DEVICE,
        mc_dropout_samples=MC_SAMPLES,
        archetypes=archetypes
    )


@app.post("/predict", response_model=PredictionResponse)
async def predict(input_data: ReflectionInput):
    """

    Clasificar una reflexi贸n individual



    Args:

        input_data: Reflexi贸n y configuraci贸n



    Returns:

        Predicci贸n con arquetipo, confianza y m茅tricas

    """
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        result = classifier.predict(
            text=input_data.text,
            use_mc_dropout=input_data.use_mc_dropout
        )

        archetype_result = ArchetypeResult(
            id=ARCHETYPE_LABELS[result["predicted_class"]],
            name=ARCHETYPE_NAMES[result["predicted_class"]],
            description=ARCHETYPE_DESCRIPTIONS[result["predicted_class"]]
        )

        return PredictionResponse(
            archetype=archetype_result,
            confidence=result["confidence"],
            uncertainty=result["uncertainty"],
            top3_predictions=result["top3"],
            inference_time_ms=result["inference_time_ms"],
            method="bert-mc-dropout" if input_data.use_mc_dropout else "bert"
        )

    except Exception as e:
        logger.error(f"Error en predicci贸n: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction error: {str(e)}")


@app.post("/predict-batch", response_model=BatchPredictionResponse)
async def predict_batch(input_data: BatchReflectionInput):
    """

    Clasificar m煤ltiples reflexiones en batch



    Args:

        input_data: Lista de reflexiones



    Returns:

        Lista de predicciones

    """
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if len(input_data.texts) == 0:
        raise HTTPException(status_code=400, detail="Empty texts list")

    start_time = time.time()
    predictions = []

    try:
        for text in input_data.texts:
            if len(text) < 100:
                continue  # Skip textos muy cortos

            result = classifier.predict(
                text=text,
                use_mc_dropout=input_data.use_mc_dropout
            )

            archetype_result = ArchetypeResult(
                id=ARCHETYPE_LABELS[result["predicted_class"]],
                name=ARCHETYPE_NAMES[result["predicted_class"]],
                description=ARCHETYPE_DESCRIPTIONS[result["predicted_class"]]
            )

            predictions.append(
                PredictionResponse(
                    archetype=archetype_result,
                    confidence=result["confidence"],
                    uncertainty=result["uncertainty"],
                    top3_predictions=result["top3"],
                    inference_time_ms=result["inference_time_ms"],
                    method="bert-mc-dropout" if input_data.use_mc_dropout else "bert"
                )
            )

        total_time = (time.time() - start_time) * 1000

        return BatchPredictionResponse(
            predictions=predictions,
            total_inference_time_ms=total_time
        )

    except Exception as e:
        logger.error(f"Error en batch prediction: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Batch prediction error: {str(e)}")


# ============================================================================
# MAIN (for local testing)
# ============================================================================

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,  # Standard HuggingFace Spaces port
        reload=True
    )