Spaces:
Sleeping
Sleeping
| from transformers import T5Tokenizer, T5ForConditionalGeneration | |
| from transformers import AdamW | |
| import pandas as pd | |
| import torch | |
| import pytorch_lightning as pl | |
| from pytorch_lightning.callbacks import ModelCheckpoint | |
| from torch.nn.utils.rnn import pad_sequence | |
| # from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler | |
# Seed through Lightning's helper before anything else is created.
pl.seed_everything(100)

# Pretrained checkpoint name shared by the tokenizer and the model.
MODEL_NAME = 't5-base'

# Sequence-length limits, in tokens.
INPUT_MAX_LEN = 128
OUTPUT_MAX_LEN = 128

# Use the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

tokenizer = T5Tokenizer.from_pretrained(
    MODEL_NAME,
    model_max_length=512,
)
class T5Model(pl.LightningModule):
    """LightningModule wrapping a pretrained T5 conditional-generation model."""

    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(
            MODEL_NAME,
            return_dict=True,
        )

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped model and return the pair ``(loss, logits)``."""
        result = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return result.loss, result.logits

    def _batch_loss(self, batch):
        """Return the loss for one batch; the batch's "target" entry is used as labels."""
        batch_loss, _ = self(
            batch["input_ids"],
            batch["attention_mask"],
            batch["target"],
        )
        return batch_loss

    def training_step(self, batch, batch_idx):
        """Compute and log ``train_loss`` for one training batch."""
        step_loss = self._batch_loss(batch)
        self.log("train_loss", step_loss, prog_bar=True, logger=True)
        return {'loss': step_loss}

    def validation_step(self, batch, batch_idx):
        """Compute and log ``val_loss`` for one validation batch."""
        step_loss = self._batch_loss(batch)
        self.log("val_loss", step_loss, prog_bar=True, logger=True)
        return {'val_loss': step_loss}

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters (lr = 1e-4)."""
        return AdamW(self.parameters(), lr=0.0001)
# Restore the fine-tuned weights from the checkpoint file, mapped onto DEVICE,
# then freeze the module since this script only runs inference.
train_model = T5Model.load_from_checkpoint(
    'best-model.ckpt',
    map_location=DEVICE,
)
train_model.freeze()
def generate_response(question):
    """Generate the chatbot's reply to *question*.

    The question is tokenized (padded/truncated to ``INPUT_MAX_LEN``), fed to
    the fine-tuned T5 model with 4-beam search, and the generated token ids
    are decoded back to text.

    Args:
        question: The user's message as a plain string.

    Returns:
        The decoded reply as a single string.
    """
    encoding = tokenizer(
        question,
        add_special_tokens=True,
        max_length=INPUT_MAX_LEN,
        padding='max_length',
        truncation='only_first',
        return_attention_mask=True,
        return_tensors="pt",
    )

    # The tokenizer returns CPU tensors; generate() requires them on the same
    # device as the model weights, otherwise it fails when running on CUDA.
    device = train_model.model.device

    generated_ids = train_model.model.generate(
        input_ids=encoding["input_ids"].to(device),
        attention_mask=encoding["attention_mask"].to(device),
        # Bound the reply by the output limit, not the input limit.
        max_length=OUTPUT_MAX_LEN,
        num_beams=4,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )

    preds = [
        tokenizer.decode(
            gen_id,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for gen_id in generated_ids
    ]
    return "".join(preds)
| import uuid | |
| import datetime | |
| import os | |
| import streamlit as st | |
| from streamlit_chat import message | |
| from pymongo.mongo_client import MongoClient | |
| from pymongo.server_api import ServerApi | |
# The MongoDB password is read from the environment so it never lives in the
# source file.
password = os.getenv("mongo_pass")
if password is None:
    # Fail with an explicit message instead of the opaque TypeError that
    # concatenating None into the URI would raise.
    raise RuntimeError(
        "Environment variable 'mongo_pass' is not set; cannot build the MongoDB URI."
    )

# NOTE(review): PyMongo expects credentials embedded in a URI to be
# percent-escaped — confirm the stored password contains no reserved
# characters or is already escaped.
uri = (
    f"mongodb+srv://rohank587:{password}"
    "@rkcluster.e3fpzja.mongodb.net/?retryWrites=true&w=majority"
)

# Create a new client and connect to the server
client = MongoClient(uri, server_api=ServerApi('1'))
st.title(":red[_Sarcastic_] Chatbot")

# Create the per-session conversation state the first time the script runs.
for _state_key, _state_default in (
    ('generated', []),
    ('past', []),
    ('messages', [{"role": "system", "content": "You are a helpful assistant."}]),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Chat history is rendered above the input widgets, so its container is
# created first.
response_container = st.container()
# container for text box
container = st.container()

with container, st.form(key='my_form', clear_on_submit=True):
    user_input = st.text_input(
        "You:",
        key='input',
        placeholder="Disclaimer: Be careful with punctuations like , ? . ! \"",
    )
    submit_button = st.form_submit_button(
        label='Send',
        use_container_width=True,
    )

# Two side-by-side action buttons.
col1, col2 = st.columns(2)
clear_button = col1.button(
    "Clear Conversation",
    key="clear",
    use_container_width=True,
)
save_button = col2.button(
    "Save Conversation",
    key="save",
    use_container_width=True,
)

down_id = st.text_input(
    'Enter ID to download chat',
    placeholder="Message ID",
)
# Offer a previously saved conversation for download, looked up by its ID.
if down_id:
    from pymongo.errors import PyMongoError

    info = client['rohank']['table1']
    data = None
    try:
        # Strip stray whitespace that tends to come along when pasting an ID.
        data = info.find_one({'message_id': down_id.strip()})
    except PyMongoError:
        st.error("Can't connect to MongoDB. Download Failed.")
    else:
        if data is None:
            # find_one() returns None when no document matches; without this
            # check the subscript below would raise a TypeError.
            st.warning("No saved chat found for this ID.")

    if data is not None:
        down_button = st.download_button(
            'Download chat',
            "\n".join(data['message']),
            file_name="sar_chat.txt",
        )
# "Clear Conversation" pressed: put the session state back to its initial
# values.
if clear_button:
    st.session_state['generated'], st.session_state['past'] = [], []
    st.session_state['messages'] = [
        {"role": "system", "content": "You are a helpful assistant."},
    ]
# "Save Conversation" pressed with a non-empty history: store the transcript
# in MongoDB under a fresh UUID and show that ID to the user.
if save_button and st.session_state['generated'] and st.session_state['past']:
    # Interleave the turns as "You: ..." / "Bot: ..." lines.
    chats = []
    for asked, answered in zip(st.session_state['past'], st.session_state['generated']):
        chats.append("You: " + asked)
        chats.append("Bot: " + answered)

    message_id = str(uuid.uuid4())
    saved_at = datetime.datetime.now()

    try:
        # Send a ping to confirm a successful connection
        client.admin.command('ping')
        info = client['rohank']['table1']
        info.insert_one({
            "time of saving": saved_at.strftime("%c"),
            "message_id": message_id,
            "message": chats,
        })
    except Exception:
        # UI boundary: any failure while talking to the database is reported
        # to the user rather than crashing the app.
        st.error("Can't connect to MongoDB. Save Failed.")
    else:
        # Report success only once the insert has actually gone through.
        st.success("Pinged your deployment. You successfully connected to MongoDB! Saved Successfully.")
        st.success("Copy this id "+message_id+" for downloading saved chat anytime anywhere and then paste it down below!")
# A non-empty message was submitted: generate the reply first, then record
# both sides of the exchange in the session history.
if submit_button and user_input:
    reply = generate_response(user_input)
    history = st.session_state
    history['past'].append(user_input)
    history['generated'].append(reply)
# Draw the conversation so far inside the history container, one user bubble
# followed by one bot bubble per turn.
if st.session_state['generated']:
    with response_container:
        for turn, bot_text in enumerate(st.session_state['generated']):
            message(st.session_state["past"][turn], is_user=True, key=f"{turn}_user")
            message(bot_text, key=str(turn))