Spaces:
Runtime error
Runtime error
| """ | |
| Predict the summary for a given input text. | |
| """ | |
| import torch | |
| from nltk import word_tokenize | |
| import dataloader | |
| from model import Decoder, Encoder, EncoderDecoderModel | |
| # The data has to be loaded in order to get the Vectoriser > save "words" to a file and load it afterwards ?? | |
| ### TO BE CHANGED SO THAT ONLY THE VECTORISER HAS TO BE LOADED | |
| # NOTE(review): these look like top-level statements, so they presumably run at import time and read both JSONL extracts - confirm. | |
| data1 = dataloader.Data("data/train_extract.jsonl") | |
| data2 = dataloader.Data("data/dev_extract.jsonl") | |
| # NOTE(review): train_dataset and dev_dataset are not referenced again in the visible part of this file; check other importers before removing them. | |
| train_dataset = data1.make_dataset() | |
| dev_dataset = data2.make_dataset() | |
| # The words handed to the Vectoriser come from data1 (the train extract) only. | |
| words = data1.get_words() | |
| vectoriser = dataloader.Vectoriser(words) | |
| word_counts = vectoriser.word_count | |
# Ready-to-use models keyed by device string, so the network is built and the
# weights are read from disk once per process rather than on every call.
_MODEL_CACHE = {}


def _load_model(device):
    """Return the evaluation-mode model for *device*, building it on first use."""
    cache_key = str(device)
    model = _MODEL_CACHE.get(cache_key)
    if model is None:
        vocab_size = len(vectoriser.idx_to_token) + 1
        # 256, 512, 0.5: presumably embedding size, hidden size and dropout -
        # TODO confirm against model.Encoder / model.Decoder.
        encoder = Encoder(vocab_size, 256, 512, 0.5, device).to(device)
        decoder = Decoder(vocab_size, 256, 512, 0.5, device).to(device)
        model = EncoderDecoderModel(encoder, decoder, device)
        model.load_state_dict(torch.load("model/model.pt", map_location=device))
        model.eval()
        model.to(device)
        _MODEL_CACHE[cache_key] = model
    return model


def inferenceAPI(text: str) -> str:
    """
    Predict the summary for an input text
    --------
    Parameter
        text: str
            the text to summarize
    Return
        str
            The summary for the input text

    The model is built and its weights are loaded from "model/model.pt" on
    the first call only; later calls on the same device reuse that instance.
    Errors raised while building or loading the model propagate unchanged.
    """
    # Keep the tokens in their own name so `text` stays a str as annotated.
    tokens = word_tokenize(text)

    # Run on the GPU when one is available, otherwise on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model(device)

    # Vectorise the tokens and move them next to the model.
    source = vectoriser.encode(tokens)
    source = source.to(device)

    # Forward pass without gradient tracking. The original's second, bare
    # `output.to(device)` discarded its result and has been dropped.
    with torch.no_grad():
        output = model(source).to(device)

    return vectoriser.decode(output)
| # if __name__ == "__main__": | |
| # # inference() | |
| # print(inferenceAPI("If you choose to use these attributes in logged messages, you need to exercise some care. In the above example, for instance, the Formatter has been set up with a format string which expects ‘clientip’ and ‘user’ in the attribute dictionary of the LogRecord. If these are missing, the message will not be logged because a string formatting exception will occur. So in this case, you always need to pass the extra dictionary with these keys.")) | |