Spaces:
Runtime error
Runtime error
| from flask import Flask, render_template, request, jsonify | |
| import os | |
| os.environ['TRANSFORMERS_CACHE'] = '/code/cache/' | |
| #os.environ['SENTENCE_TRANSFORMERS_HOME'] = './.cache' | |
| #from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from transformers import T5Tokenizer, T5ForConditionalGeneration | |
| #import numpy as np | |
| from transformers import AdamW | |
| #import pandas as pd | |
| import torch | |
| import pytorch_lightning as pl | |
| from pytorch_lightning.callbacks import ModelCheckpoint | |
| from torch.nn.utils.rnn import pad_sequence | |
# Pretrained checkpoint family and the sequence-length limits used by the
# tokenizer call and the generation call further down in this module.
MODEL_NAME = 't5-base'
INPUT_MAX_LEN = 512
OUTPUT_MAX_LEN = 512

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Earlier DialoGPT setup, kept commented out for reference:
#tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
#model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=512)

# Flask application object; template auto-reload is switched on both on the
# Jinja environment and in the Flask config.
app = Flask(__name__)
app.jinja_env.auto_reload = True
app.config['TEMPLATES_AUTO_RELOAD'] = True
# Without a route decorator Flask never dispatches to this view, so it is
# registered at the site root.
@app.route("/")
def index():
    """Serve the chat page.

    Returns:
        The rendered ``chat.html`` template.
    """
    return render_template('chat.html')
# The view must be registered with the app to be reachable at all.
# NOTE(review): "/get" is assumed to be the URL that chat.html submits the
# message to -- confirm against the template. POST is required because the
# message is read from form data.
@app.route("/get", methods=["GET", "POST"])
def chat():
    """Answer one chat message submitted by the front end.

    Reads the ``msg`` form field from the current request and passes it to
    ``get_Chat_response``.

    Returns:
        Whatever ``get_Chat_response`` returns for the submitted message.

    Raises:
        KeyError subclass raised by ``request.form`` when ``msg`` is missing
        (Flask turns this into a 400 Bad Request response).
    """
    # Read the field directly instead of aliasing it to a local called
    # ``input``, which shadowed the builtin of the same name.
    msg = request.form["msg"]
    return get_Chat_response(msg)
class T5Model(pl.LightningModule):
    """Lightning module wrapping a pretrained T5 conditional-generation model.

    The wrapped Hugging Face model is stored on ``self.model``; the attribute
    name is part of the checkpoint layout and must not change.
    """

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(
            MODEL_NAME, return_dict=True
        )

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped model.

        Args:
            input_ids: token ids passed to the model as ``input_ids``.
            attention_mask: mask passed to the model as ``attention_mask``.
            labels: optional targets passed to the model as ``labels``.

        Returns:
            A ``(loss, logits)`` tuple taken from the model output.
        """
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return output.loss, output.logits

    def _batch_loss(self, batch):
        """Return the loss for one batch with ``input_ids``, ``attention_mask`` and ``target`` keys."""
        loss, _logits = self(
            batch["input_ids"],
            batch["attention_mask"],
            batch["target"],
        )
        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        Returns:
            ``{'loss': loss}``.
        """
        loss = self._batch_loss(batch)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch.

        Returns:
            ``{'val_loss': loss}``.
        """
        loss = self._batch_loss(batch)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return {'val_loss': loss}

    def configure_optimizers(self):
        """Return the AdamW optimizer over all parameters (lr = 1e-4).

        Uses ``torch.optim.AdamW`` because the ``transformers.AdamW`` class
        is deprecated in favour of the PyTorch implementation. ``eps`` and
        ``weight_decay`` are passed explicitly to keep the values the
        deprecated class used by default, since PyTorch's own defaults differ.
        """
        return torch.optim.AdamW(
            self.parameters(),
            lr=0.0001,
            eps=1e-6,
            weight_decay=0.0,
        )
# Restore the fine-tuned weights from the checkpoint file, mapping stored
# tensors onto DEVICE, then freeze the module: it is only used to generate
# replies in this app, never trained.
train_model = T5Model.load_from_checkpoint(
    checkpoint_path='best-model-version.ckpt',
    map_location=DEVICE,
)
train_model.freeze()
def get_Chat_response(question):
    """Generate the model's reply to one user message.

    Args:
        question: text to tokenize and feed to the model.

    Returns:
        The decoded generated sequence(s) joined into a single string
        (one sequence is requested, so this is that sequence's text).
    """
    encoding = tokenizer(
        question,
        add_special_tokens=True,
        max_length=INPUT_MAX_LEN,
        padding='max_length',
        truncation='only_first',
        return_attention_mask=True,
        return_tensors="pt",
    )

    # The tokenizer returns CPU tensors; move them to wherever the model's
    # weights actually live so generation does not fail with a device
    # mismatch when the checkpoint was loaded onto a GPU.
    device = train_model.model.device

    generated_ids = train_model.model.generate(
        input_ids=encoding["input_ids"].to(device),
        attention_mask=encoding["attention_mask"].to(device),
        # Cap the *generated* length with the output limit, not the input
        # limit (both are currently 512, so the value is unchanged).
        max_length=OUTPUT_MAX_LEN,
        num_beams=4,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )

    return "".join(
        tokenizer.decode(
            sequence_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for sequence_ids in generated_ids
    )
| #def get_Chat_response(text): | |
| # | |
| # # Let's chat for 5 lines | |
| # for step in range(5): | |
| # # encode the new user input, add the eos_token and return a tensor in Pytorch | |
| # new_user_input_ids = tokenizer.encode(str(text) + tokenizer.eos_token, return_tensors='pt') | |
| # | |
| # # append the new user input tokens to the chat history | |
| # bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids | |
| # | |
| # # generate a response while limiting the total chat history to 1000 tokens | |
| # chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id) | |
| # | |
| # # pretty print last output tokens from bot | |
| # return tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True) | |
if __name__ == '__main__':
    # Debug mode is opt-in through the FLASK_DEBUG environment variable
    # instead of being hard-coded on. The interactive debugger allows
    # arbitrary code execution for anyone who can reach the server, and the
    # debug reloader starts a second process that would load the model again.
    # NOTE(review): this changes the default from debug-on to debug-off;
    # set FLASK_DEBUG=1 locally to get the previous behaviour.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1", "true", "yes", "on",
    }
    app.run(debug=debug_enabled)