# MediatorBot / app.py
# Author: peterpull
# Commit: c8fc5ce ("Update app.py")
# Size: 3.5 kB
import csv
import datetime
import os
import sys

import gradio as gr
from datasets import load_dataset
from gpt_index import GPTSimpleVectorIndex
from gradio import Interface, Textbox
from huggingface_hub import HfApi, HfFolder, snapshot_download
from langchain import OpenAI
# Expose the Space secret SECRET_CODE as OPENAI_API_KEY
# (presumably read by the OpenAI client used via gpt_index/langchain).
os.environ["OPENAI_API_KEY"] = os.environ['SECRET_CODE']

# Need to write to a persistent dataset because temp data cannot be stored on Spaces.
DATASET_REPO_URL = "https://huggingface.co/datasets/peterpull/MediatorBot"
DATASET_REPO_ID = "peterpull/MediatorBot"
DATA_FILENAME = "data.txt"
INDEX_FILENAME = "index_base_89MB.json"
DATA_FILE = os.path.join("data", DATA_FILENAME)
INDEX_FILE = os.path.join("data", INDEX_FILENAME)

# A write-access token is needed to push the chat log back to the dataset repo.
HF_TOKEN = os.environ.get("HF_TOKEN")
print("HF TOKEN is none?", HF_TOKEN is None)


class _DatasetRepo:
    """Local snapshot of the Hub dataset repo that can push the chat log back.

    ``path`` is the local folder holding the repo files; ``commit(message)``
    uploads the chat-log file from that folder back to the repo.
    """

    def __init__(self, repo_id, local_dir, token=None):
        self.repo_id = repo_id
        self.token = token
        # Download every file of the dataset repo (chat log and index) into local_dir.
        self.path = snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            local_dir=local_dir,
            token=token,
        )

    def commit(self, message, filename=DATA_FILENAME):
        """Upload ``filename`` from the local folder to the dataset repo as one commit."""
        HfApi(token=self.token).upload_file(
            path_or_fileobj=os.path.join(self.path, filename),
            path_in_repo=filename,
            repo_id=self.repo_id,
            repo_type="dataset",
            commit_message=message,
        )


# Clone the remote dataset repo into ./data, the folder INDEX_FILE and DATA_FILE point into.
dataset = _DatasetRepo(DATASET_REPO_ID, "data", HF_TOKEN)
dataset_folder = dataset.path
print(f"Dataset folder: {dataset_folder}")
print(f"Dataset files: {os.listdir(dataset_folder)}")
def generate_text(path=None) -> str:
    """Render the stored chat log as human-readable text.

    Rows are parsed as CSV in the order store_message writes them:
    ``timestamp,user message,chatbot reply``. Quoted fields may therefore
    contain commas. Rows without exactly three fields are skipped.

    Args:
        path: Chat-log file to read. Defaults to ``DATA_FILENAME`` inside
            ``dataset_folder``.

    Returns:
        One "Time / User / Chatbot" paragraph per valid row, or
        "No messages yet" when the file is missing or has no valid rows.
    """
    if path is None:
        path = os.path.join(dataset_folder, DATA_FILENAME)
    try:
        with open(path, newline="", encoding="utf-8") as file:
            rows = list(csv.reader(file))
    except FileNotFoundError:
        # The log does not exist until the first message is stored.
        return "No messages yet"

    entries = [
        f"Time: {timestamp}\nUser: {user}\nChatbot: {reply}\n\n"
        for timestamp, user, reply in (row for row in rows if len(row) == 3)
    ]
    return "".join(entries) or "No messages yet"
def store_message(chatinput: str, chatresponse: str, push_every: int = 1):
    """Append one question/answer pair to the chat log and push it to the Hub.

    The row is written as CSV ``timestamp,chatinput,chatresponse``, so commas,
    quotes or newlines inside a message are quoted instead of breaking the
    row. Nothing is written or pushed when either message is empty.

    Args:
        chatinput: The user's question.
        chatresponse: The chatbot's answer (stringified with ``str()`` on write).
        push_every: Push the chat log back to the dataset repo on every
            ``push_every``-th stored message. The default of 1 pushes after
            every message.

    Returns:
        The full chat log as rendered by ``generate_text()``.

    Raises:
        ValueError: If ``push_every`` is less than 1.
    """
    if push_every < 1:
        raise ValueError(f"push_every must be >= 1, got {push_every}")

    if chatinput and chatresponse:
        # Take a single timestamp so the stored row and the log line agree.
        timestamp = datetime.datetime.now()
        data_path = os.path.join(dataset_folder, DATA_FILENAME)
        with open(data_path, "a", newline="", encoding="utf-8") as file:
            # lineterminator="\n" keeps the same line endings as the existing rows.
            csv.writer(file, lineterminator="\n").writerow([timestamp, chatinput, chatresponse])
        print(f"Wrote to datafile: {timestamp},{chatinput},{chatresponse}\n")

        # Push back to the hub every push_every-th stored message.
        if store_message.count_calls % push_every == 0:
            print("Pushing back to Hugging Face model hub")
            try:
                dataset.commit("Added new chat data")  # Commit the changes
            except Exception as exc:
                # Best effort: a failed push must not throw away the user's answer.
                print(f"Push to Hugging Face hub failed: {exc!r}")
        store_message.count_calls += 1
    return generate_text()


# Counts stored messages (starting at one) so pushes can be spaced by push_every.
store_message.count_calls = 1
# Loads the index file, which holds the context data the chatbot answers from.
def get_index(index_file_path):
    """Load the GPTSimpleVectorIndex saved at ``index_file_path``.

    Args:
        index_file_path: Path to the JSON index file.

    Returns:
        The index loaded with ``GPTSimpleVectorIndex.load_from_disk``.

    Raises:
        SystemExit: With a non-zero status and an error message when the
            file does not exist; the app cannot answer anything without it.
    """
    if not os.path.exists(index_file_path):
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit(f"Error: '{index_file_path}' does not exist.")
    index_size = os.path.getsize(index_file_path)
    print(f"Size of {index_file_path}: {index_size} bytes")  # let me know how big the JSON file is
    return GPTSimpleVectorIndex.load_from_disk(index_file_path)
# Load the vector index once at start-up; chatbot() queries this shared instance.
index = get_index(index_file_path=INDEX_FILE)
# Passes the prompt to the index, stores the exchange, and returns the answer text.
def chatbot(input_text, mentioned_person='Mediator John Haynes', confidence_threshold=0.5):
    """Answer a user question from the indexed mediation material.

    Args:
        input_text: The user's question.
        mentioned_person: Persona the model is told to answer as.
        confidence_threshold: Currently unused; kept for interface compatibility.

    Returns:
        The ``response`` text of the ``index.query`` result.
    """
    prompt = f"You are {mentioned_person}. Answer this: {input_text}. Only reply from the contextual data, or say you don't know. At the end of your answer ask an insightful question."
    response = index.query(prompt, response_mode="default")
    answer = response.response
    # Store the plain answer text (store_message takes a str), not the Response object.
    store_message(input_text, answer)
    return answer
# The "about" text is shown as the interface description.
with open('about.txt', 'r', encoding='utf-8') as file:
    about = file.read()

iface = Interface(
    fn=chatbot,
    # The first positional argument of Textbox is its pre-filled value; a hint belongs in placeholder.
    inputs=Textbox(placeholder="Enter your question"),
    outputs="text",
    title="AI Chatbot trained on J. Haynes mediation material, v0.5",
    description=about,
)
iface.launch()