|
|
| |
| |
|
|
| import json |
| import torch |
| import torch.nn as nn |
| from transformers.utils import is_torch_available |
|
|
| |
def simple_tokenizer(text):
    """Lower-case *text* and split it on runs of whitespace.

    Returns the list of tokens; an empty or whitespace-only string yields ``[]``.
    """
    lowered = text.lower()
    return lowered.split()
|
|
class SimpleClassifier(nn.Module):
    """Embedding -> single-layer LSTM -> linear head text classifier.

    The LSTM's final hidden state is projected to ``output_dim`` scores.
    The submodule attribute names (``embedding``, ``lstm``, ``fc``) define the
    state-dict keys, so they must match any saved checkpoint.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        # Layers are created in a fixed order: under a fixed RNG seed this
        # keeps parameter initialisation reproducible.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):
        """Return unnormalised class scores for one sequence of token indices.

        The first dimension of ``text`` is treated as the sequence. The embedded
        sequence is viewed as ``(seq_len, 1, -1)``, i.e. a batch of one in the
        LSTM's default (seq, batch, feature) layout.
        NOTE(review): assumes ``text`` is shaped ``(seq_len,)`` or
        ``(seq_len, 1)``; other shapes give the LSTM a wrong feature size.
        The result has shape ``(1, output_dim)``.
        """
        seq_len = len(text)
        vectors = self.embedding(text).view(seq_len, 1, -1)
        _, (h_n, _) = self.lstm(vectors)
        # h_n is (num_layers=1, batch=1, hidden_dim); drop the layer axis.
        last_hidden = h_n.squeeze(0)
        return self.fc(last_hidden)
|
|
class InferenceHandler:
    """Load a ``SimpleClassifier`` from local files and classify single texts.

    Lifecycle: call :meth:`initialize` once, then for each request run
    ``preprocess`` -> ``inference`` -> ``postprocess``.
    """

    # Class index -> label; any index not listed maps to "Inconnu".
    LABELS = {0: "Animaux", 1: "Capitales"}

    def __init__(self):
        self.initialized = False
        self.word_to_idx = None  # token -> vocabulary index, loaded from vocab.json
        self.model = None  # SimpleClassifier instance, built in initialize()

    @staticmethod
    def _load_json(path):
        """Read and parse a UTF-8 encoded JSON file."""
        # JSON is UTF-8 by specification; without an explicit encoding, open()
        # uses the platform locale and can mis-decode non-ASCII vocabulary.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    def initialize(self, context):
        """Load the vocabulary, model config and weights; switch to eval mode.

        Reads ``vocab.json``, ``config.json`` and ``pytorch_model.bin`` from the
        current working directory.

        Args:
            context: Accepted for handler-API compatibility; not used here.
                NOTE(review): if the serving framework exposes the model
                directory through ``context``, consider resolving paths from it.
        """
        self.word_to_idx = self._load_json("vocab.json")
        config = self._load_json("config.json")

        self.model = SimpleClassifier(
            vocab_size=config['vocab_size'],
            embedding_dim=config['embedding_dim'],
            hidden_dim=config['hidden_dim'],
            output_dim=config['output_dim'],
        )

        # map_location="cpu": preprocess() builds CPU tensors and the model is
        # never moved to another device, so a checkpoint saved from a GPU has to
        # land on the CPU (this also makes loading work on CPU-only hosts).
        # NOTE(security): torch.load unpickles the file; only load trusted
        # checkpoints, or pass weights_only=True where the torch version allows.
        state_dict = torch.load("pytorch_model.bin", map_location="cpu")
        self.model.load_state_dict(state_dict)

        self.model.eval()
        self.initialized = True

    def preprocess(self, inputs):
        """Convert a request payload into a ``(num_tokens, 1)`` LongTensor.

        Args:
            inputs: Dict-like payload whose ``"inputs"`` entry holds the text.

        Returns:
            ``torch.long`` tensor of vocabulary indices, shape ``(num_tokens, 1)``.

        Raises:
            ValueError: if the text is missing, empty or whitespace-only.
        """
        text = inputs.get("inputs", "")
        # Whitespace-only text tokenizes to no tokens; an empty tensor would
        # otherwise fail later with an opaque reshape error inside the model.
        if not text or not text.strip():
            raise ValueError("Aucun texte fourni pour l'inférence.")

        tokens = simple_tokenizer(text)
        # Out-of-vocabulary tokens fall back to index 0.
        # NOTE(review): assumes index 0 is the unknown/padding id in vocab.json.
        token_indices = [self.word_to_idx.get(token, 0) for token in tokens]

        return torch.tensor(token_indices, dtype=torch.long).view(-1, 1)

    def inference(self, input_tensor):
        """Run the model forward with gradient tracking disabled.

        Returns the model's raw output scores; softmax is applied in
        ``postprocess``.
        """
        with torch.no_grad():
            return self.model(input_tensor)

    def postprocess(self, outputs):
        """Turn model scores into ``[{"label": str, "score": float}]``.

        Args:
            outputs: Score tensor of shape ``(1, num_classes)``. ``argmax(...)
                .item()`` requires exactly one row, so batches are not supported.

        Returns:
            One-element list with the predicted label and its softmax probability.
        """
        prediction = torch.argmax(outputs, dim=1).item()
        probabilities = outputs.softmax(dim=1)
        predicted_label = self.LABELS.get(prediction, "Inconnu")
        return [{"label": predicted_label, "score": probabilities[0][prediction].item()}]
|
|