"""

Handler for the email classifier Inference Endpoint.

"""

import pickle

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from huggingface_hub import hf_hub_download

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Inference Endpoints passes the local model directory as `path`;
        # it is accepted for compatibility but unused here, since everything
        # is loaded explicitly from the Hub in load_model().
        self.model = None
        self.tokenizer = None
        self.encoder = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_model()
    
    def load_model(self):
        """Load the model, tokenizer, and label encoder."""
        try:
            # Load the model and tokenizer from the Hub
            self.model = AutoModelForSequenceClassification.from_pretrained("vertigoq3/email-classifier-bert")
            self.tokenizer = AutoTokenizer.from_pretrained("vertigoq3/email-classifier-bert")
            
            # Move the model to the target device and set evaluation mode
            self.model.to(self.device)
            self.model.eval()
            
            # Download and unpickle the label encoder stored alongside the model
            encoder_path = hf_hub_download(
                repo_id="vertigoq3/email-classifier-bert",
                filename="label_encoder.pkl"
            )
            
            with open(encoder_path, "rb") as f:
                self.encoder = pickle.load(f)
                
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
    
    def __call__(self, inputs):
        """Process an inference request and return the predicted class."""
        try:
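            # Accept either a raw string or the {"inputs": ...} payload shape
            # that Inference Endpoints sends for text requests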
            if isinstance(inputs, str):
                text = inputs
            elif isinstance(inputs, dict) and "inputs" in inputs:
                text = inputs["inputs"]
            else:
                text = str(inputs)
            
            # Tokenize the input text
            tokenized = self.tokenizer(
                text, 
                return_tensors="pt", 
                truncation=True, 
                padding=True, 
                max_length=512
            )
            tokenized = {k: v.to(self.device) for k, v in tokenized.items()}
            
            # Classify (no gradient tracking needed for inference)
            with torch.no_grad():
                outputs = self.model(**tokenized)
                logits = outputs.logits
                probabilities = torch.softmax(logits, dim=-1)
                predicted_class_id = torch.argmax(probabilities, dim=-1).item()
                predicted_class = self.encoder.inverse_transform([predicted_class_id])[0]
                confidence = float(probabilities[0][predicted_class_id])
            
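            # Response: the top label, its probability, and the full
            # distribution keyed by the encoder's class names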
            return {
                "predicted_class": predicted_class,
                "confidence": confidence,
                "all_probabilities": {
                    self.encoder.classes_[i]: float(probabilities[0][i]) 
                    for i in range(len(self.encoder.classes_))
                }
            }
            
        except Exception as e:
            return {"error": str(e)}

# Create a global instance (also used by the local smoke test below)
handler = EndpointHandler()
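
# A minimal local smoke test: a sketch of how the handler might be exercised
# outside of Inference Endpoints. The example email text is made up, and the
# returned labels depend on the classes the model was trained with.
if __name__ == "__main__":
    sample = {"inputs": "Hello, I have not received my order yet. Can you help?"}
    print(handler(sample))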