| import openai |
| import streamlit as st |
| from streamlit_chat import message |
| from transformers import pipeline |
def _load_pipelines():
    """Build the Hugging Face pipelines used for post-conversation analysis.

    Returns:
        ``(summarizer, sentiment_task)`` — a summarization pipeline and a
        sentiment-analysis pipeline.
    """
    return (
        pipeline("summarization", model="philschmid/bart-large-cnn-samsum"),
        pipeline(
            "sentiment-analysis",
            model='cardiffnlp/twitter-roberta-base-sentiment-latest',
            tokenizer='cardiffnlp/twitter-roberta-base-sentiment-latest',
        ),
    )


# Streamlit re-executes this whole script on every widget interaction, so
# un-cached module-level pipeline() calls would reload both models on each
# rerun. Cache them once per server process; fall back to the older Streamlit
# API (or to no caching at all) when ``st.cache_resource`` is unavailable.
_cache_resource = (
    getattr(st, "cache_resource", None)
    or getattr(st, "experimental_singleton", None)
    or (lambda func: func)
)
summarizer, sentiment_task = _cache_resource(_load_pipelines)()

openai.api_key = st.secrets["openai_api_key"]
|
|
| from math import log |
|
|
# Completion resource used by `ask`, which calls `completion.create(...)` on it.
# NOTE(review): `openai.Completion` is the pre-1.0 OpenAI SDK interface —
# confirm the pinned openai version still provides it.
completion = openai.Completion()
|
|
|
|
# Instruction header that opens every transcript sent to the completion model.
start_prompt = '[Instruction] Act as a friendly, compassionate, insightful, and empathetic AI therapist named Joy. Joy listens and offers advice. End the conversation when the patient wishes to.'
# Joy's opening line: shown first in the UI and recorded as Joy's first turn
# in the transcript.
start_message = 'I am Joy, your AI therapist. How are you feeling today?'
# Speaker labels placed before Joy's turns and before the patient's turns.
start_sequence = "\nJoy:"
restart_sequence = "\n\nPatient:"
| |
def ask(question: str, chat_log, model='text-davinci-003', temp=0.9) -> "tuple[str, str]":
    """Send the patient's message to the completion model and return Joy's reply.

    Args:
        question: The patient's latest message.
        chat_log: The conversation so far, either as one string or as the list
            of transcript fragments kept in ``st.session_state['chat_log']``.
        model: Completion model name passed to the OpenAI API.
        temp: Sampling temperature.

    Returns:
        ``(answer, turn_log)`` where ``answer`` is the stripped model reply and
        ``turn_log`` is the transcript fragment for this exchange (patient turn
        followed by Joy's reply), meant to be appended to the chat log.
    """
    # Formatting a list inside an f-string would embed its repr (brackets,
    # quotes and escaped "\n") into the prompt. Each stored fragment already
    # begins with its own speaker separator, so join with no delimiter.
    history = chat_log if isinstance(chat_log, str) else ''.join(chat_log)
    prompt = f'{history}{restart_sequence} {question}{start_sequence}'

    response = completion.create(
        prompt=prompt,
        model=model,
        # Stop generation at either speaker label so the model does not go on
        # to write further turns itself.
        stop=["Patient:", 'Joy:'],
        temperature=temp,
        frequency_penalty=0.9,
        presence_penalty=1,
        top_p=1,
        best_of=1,
        max_tokens=170,
    )

    answer = response.choices[0].text.strip()
    # Named turn_log (not `log`) to avoid shadowing the module-level import.
    turn_log = f'{restart_sequence}{question}{start_sequence}{answer}'
    return str(answer), turn_log
|
|
def clean_chat_log(chat_log):
    """Flatten a chat transcript into a single line of text for the NLP pipelines.

    Everything before the first newline is dropped (in this app that is the
    ``[Instruction] ...`` prompt header), and every remaining newline is
    replaced by a space.

    Args:
        chat_log: Either a list of transcript fragments (as stored in
            ``st.session_state['chat_log']``) or an already-joined string.
            List fragments are joined with single spaces.

    Returns:
        The flattened transcript. If the text contains no newline there is no
        header line to strip, and the text is returned unchanged.
    """
    # A plain string must not go through ' '.join, which would put a space
    # between every character.
    text = chat_log if isinstance(chat_log, str) else ' '.join(chat_log)

    first_newline = text.find('\n')
    # str.find returns -1 when there is no newline; slicing from -1 would keep
    # only the final character, so strip the header only when one exists.
    if first_newline != -1:
        text = text[first_newline:]

    return text.replace('\n', ' ')
|
|
def summarize(chat_log):
    """Return an abstractive summary of the conversation.

    The transcript is flattened with ``clean_chat_log`` and fed to the
    module-level ``summarizer`` pipeline (``max_length=150``,
    ``do_sample=False``); the text of the first result is returned.
    """
    flattened = clean_chat_log(chat_log)
    results = summarizer(flattened, max_length=150, do_sample=False)
    first_result = results[0]
    return first_result['summary_text']
|
|
def analyze_sentiment(chat_log):
    """Classify the sentiment of the whole conversation.

    The transcript is flattened with ``clean_chat_log`` and passed as a single
    text to the module-level ``sentiment_task`` pipeline, whose raw output is
    returned unchanged.
    """
    return sentiment_task(clean_chat_log(chat_log))
|
|
def remove_backslashN(chat_log: list) -> list:
    """Return a new list in which every newline in each string becomes a space.

    The input list is not modified.
    """
    flattened = []
    for entry in chat_log:
        flattened.append(entry.replace('\n', ' '))
    return flattened
|
|
|
|
|
|
def _queue_user_input():
    """``on_change`` callback for the chat input box.

    Moves the submitted text into ``st.session_state['pending_input']`` and
    clears the widget. Each message is then processed exactly once, instead of
    being re-sent on every later rerun (slider moves, button clicks, ...).
    """
    st.session_state['pending_input'] = st.session_state['input']
    # Widget values may only be changed from a callback, which runs before the
    # widget is re-created on the next script run.
    st.session_state['input'] = ''


def main():
    """Render the chat UI and handle at most one patient turn per Streamlit run.

    Conversation state lives in ``st.session_state``:
      * ``generated`` - Joy's messages, starting with ``start_message``;
      * ``past`` - the patient's messages;
      * ``chat_log`` - transcript fragments sent back to the model as context;
      * ``summary`` - reset whenever the conversation is cleared.
    """
    st.title("Chat with Joy - the AI therapist!")
    col1, col2 = st.columns(2)
    temp = col1.slider("Bot-Creativeness", 0.0, 1.0, 0.9, 0.1)
    model = col2.selectbox("Model", ["text-davinci-003", "text-curie-001", "curie:ft-personal-2023-02-03-17-06-53"])

    if 'generated' not in st.session_state:
        st.session_state['generated'] = [start_message]
    if 'past' not in st.session_state:
        st.session_state['past'] = []
    if 'summary' not in st.session_state:
        st.session_state['summary'] = []
    if 'chat_log' not in st.session_state:
        st.session_state['chat_log'] = [start_prompt + start_sequence + start_message]

    # The button is offered only once Joy has replied at least twice beyond the
    # opening message.
    if len(st.session_state['generated']) > 2 and st.button("Clear and summarize", key='clear'):
        chat_log = clean_chat_log(st.session_state['chat_log'])
        summary = summarizer(chat_log, max_length=100, min_length=30, do_sample=False)
        st.write(summary[0]['summary_text'])
        user_messages = remove_backslashN(st.session_state['past'])
        st.write(sentiment_task(user_messages))
        st.session_state['generated'] = [start_message]
        st.session_state['past'] = []
        st.session_state['chat_log'] = [start_prompt + start_sequence + start_message]
        st.session_state['summary'] = []

    st.text_input("You:", key='input', on_change=_queue_user_input)

    # Consume the queued message (if any) so later reruns do not re-send it.
    user_input = ''
    if 'pending_input' in st.session_state:
        user_input = st.session_state['pending_input']
        del st.session_state['pending_input']

    if user_input:
        # Pass the transcript as text: formatting the fragment list directly
        # would put its Python repr into the prompt.
        history = ''.join(st.session_state['chat_log'])
        output, turn_log = ask(user_input, history, model=model, temp=temp)
        st.session_state['chat_log'].append(turn_log)
        st.session_state['past'].append(user_input)
        st.session_state['generated'].append(output)

    generated = st.session_state['generated']
    past = st.session_state['past']
    # Render newest first. generated[0] is the opening message, so each reply
    # generated[i + 1] appears directly above the patient message past[i] it
    # answered.
    for i in reversed(range(len(generated))):
        if i < len(past):
            message(past[i], is_user=True, key=str(i) + '_user')
        message(generated[i], key=str(i))
|
|
|
|
|
|
# Run the app when this file is executed as a script.
if __name__ == "__main__":
    main()
| |
|
|
| |
| |
| |
| |
|
|
|
|
| |