Update app.py
Browse files
app.py
CHANGED
|
@@ -7,7 +7,7 @@ from transformers import AutoModelWithLMHead, AutoTokenizer
|
|
| 7 |
# model = AutoModelWithLMHead.from_pretrained('model.py')
|
| 8 |
|
| 9 |
# -*- coding: utf-8 -*-
|
| 10 |
-
|
| 11 |
import pandas as pd
|
| 12 |
|
| 13 |
data = {'Question': ['What is the story about?',
|
|
@@ -40,7 +40,6 @@ df = pd.DataFrame(data)
|
|
| 40 |
|
| 41 |
# ! pip -q install transformers
|
| 42 |
|
| 43 |
-
from transformers import AutoModelWithLMHead, AutoTokenizer
|
| 44 |
import torch
|
| 45 |
import os
|
| 46 |
|
|
@@ -635,58 +634,58 @@ print(len(test_chatbot))
|
|
| 635 |
####################################
|
| 636 |
############Streamlit###############
|
| 637 |
|
| 638 |
-
st.set_page_config(
|
| 639 |
-
|
| 640 |
-
|
| 641 |
-
)
|
| 642 |
-
|
| 643 |
-
API_URL = "https://api-inference.huggingface.co/models/microsoft/DialoGPT-small"
|
| 644 |
-
#headers = {"Authorization": st.secrets['api_key']}
|
| 645 |
-
|
| 646 |
-
st.header("Hello - Welcome to COVID Doctor using DialoGPT")
|
| 647 |
-
st.markdown("[Github](https://github.com/rushic24/DialoGPT-Finetune)")
|
| 648 |
-
|
| 649 |
-
if 'generated' not in st.session_state:
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
if 'past' not in st.session_state:
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
def query(payload):
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
def get_text():
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
user_input = get_text()
|
| 676 |
-
|
| 677 |
-
if user_input:
|
| 678 |
-
|
| 679 |
-
|
| 680 |
-
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
|
| 687 |
-
|
| 688 |
-
if st.session_state['generated']:
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
|
|
|
| 7 |
# model = AutoModelWithLMHead.from_pretrained('model.py')
|
| 8 |
|
| 9 |
# -*- coding: utf-8 -*-
|
| 10 |
+
st.write("yoyoyo")
|
| 11 |
import pandas as pd
|
| 12 |
|
| 13 |
data = {'Question': ['What is the story about?',
|
|
|
|
| 40 |
|
| 41 |
# ! pip -q install transformers
|
| 42 |
|
|
|
|
| 43 |
import torch
|
| 44 |
import os
|
| 45 |
|
|
|
|
| 634 |
####################################
|
| 635 |
############Streamlit###############
|
| 636 |
|
| 637 |
+
# st.set_page_config(
|
| 638 |
+
# page_title="COVID Doctor using DialoGPT",
|
| 639 |
+
# page_icon=":robot:"
|
| 640 |
+
# )
|
| 641 |
+
|
| 642 |
+
# API_URL = "https://api-inference.huggingface.co/models/microsoft/DialoGPT-small"
|
| 643 |
+
# #headers = {"Authorization": st.secrets['api_key']}
|
| 644 |
+
|
| 645 |
+
# st.header("Hello - Welcome to COVID Doctor using DialoGPT")
|
| 646 |
+
# st.markdown("[Github](https://github.com/rushic24/DialoGPT-Finetune)")
|
| 647 |
+
|
| 648 |
+
# if 'generated' not in st.session_state:
|
| 649 |
+
# st.session_state['generated'] = []
|
| 650 |
+
|
| 651 |
+
# if 'past' not in st.session_state:
|
| 652 |
+
# st.session_state['past'] = []
|
| 653 |
+
|
| 654 |
+
# def query(payload):
|
| 655 |
+
# bot_input_ids = tokenizer.encode(payload["inputs"]["text"] + tokenizer.eos_token, return_tensors='pt')
|
| 656 |
+
|
| 657 |
+
# chat_history_ids = model.generate(
|
| 658 |
+
# bot_input_ids, max_length=100,
|
| 659 |
+
# pad_token_id=tokenizer.eos_token_id,
|
| 660 |
+
# no_repeat_ngram_size=3,
|
| 661 |
+
# do_sample=True,
|
| 662 |
+
# top_k=10,
|
| 663 |
+
# top_p=0.7,
|
| 664 |
+
# temperature = 0.8
|
| 665 |
+
# )
|
| 666 |
+
# output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
|
| 667 |
+
# return {"generated_text": output}
|
| 668 |
+
|
| 669 |
+
# def get_text():
|
| 670 |
+
# input_text = st.text_input("You: "," ", key="input")
|
| 671 |
+
# return input_text
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
# user_input = get_text()
|
| 675 |
+
|
| 676 |
+
# if user_input:
|
| 677 |
+
# output = query({
|
| 678 |
+
# "inputs": {
|
| 679 |
+
# "past_user_inputs": st.session_state.past,
|
| 680 |
+
# "generated_responses": st.session_state.generated,
|
| 681 |
+
# "text": user_input,
|
| 682 |
+
# },"parameters": {"repetition_penalty": 1.33},
|
| 683 |
+
# })
|
| 684 |
+
# st.session_state.past.append(user_input)
|
| 685 |
+
# st.session_state.generated.append(output["generated_text"])
|
| 686 |
+
|
| 687 |
+
# if st.session_state['generated']:
|
| 688 |
+
|
| 689 |
+
# for i in range(len(st.session_state['generated'])-1, -1, -1):
|
| 690 |
+
# message(st.session_state["generated"][i], key=str(i))
|
| 691 |
+
# message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
|