# gpt4all-conversations / conversations.py
# Uploaded by dangermouse77 ("Upload 3 files"), commit 394d97c (verified).
#!/usr/bin/python3
import json
import sys
from datetime import datetime, timezone

import numpy as np
import psycopg2
import requests
from sentence_transformers import SentenceTransformer
# Upper bound on how many trailing words of the chat log are sent as a prompt
# (applied through last_n_words in the main loop).
MAX_WORDS = 1024

# Sentence-embedding model used by save_message to encode every stored
# message. It is loaded once, at import time, and reused for all calls.
model = SentenceTransformer("all-MiniLM-L6-v2")
def last_n_words(text: str, n: int) -> str:
    """Return the last ``n`` whitespace-separated words of ``text``.

    The selected words are re-joined with single spaces, so newlines and runs
    of whitespace inside the kept tail are collapsed.

    :param text: Input text.
    :param n: Number of trailing words to keep.
    :return: The trailing words; the whole text (normalized) if it has fewer
        than ``n`` words; "" when ``n <= 0``.
    """
    # Guard needed: a bare ``words[-0:]`` slice is ``words[0:]`` and would
    # return every word, and a negative n would return a head slice.
    if n <= 0:
        return ""
    return " ".join(text.split()[-n:])
def get_chat_completion(content: str, port: int, *, model_name: str = "Default",
                        max_tokens: int = 2048, temperature: float = 0.28,
                        timeout=None) -> str:
    """Send ``content`` as a single user message to a local chat-completions server.

    The request is POSTed to ``http://localhost:{port}/v1/chat/completions``
    with an OpenAI-style chat payload.

    :param content: Prompt text, sent as the only user message.
    :param port: Local TCP port the server listens on.
    :param model_name: Value of the request's ``model`` field.
    :param max_tokens: Value of the request's ``max_tokens`` field.
    :param temperature: Value of the request's ``temperature`` field.
    :param timeout: Passed to ``requests.post``; None (the default) waits
        indefinitely, matching the previous behavior.
    :return: The first choice's message content; "No response content" when
        the reply has no choices, no message, or a null/missing content; or
        "Request failed with status code N" for any non-200 status.
    :raises requests.RequestException: on connection errors or timeouts.
    """
    url = f"http://localhost:{port}/v1/chat/completions"
    payload = {
        "model": model_name,
        "messages": [{"role": "user", "content": content}],
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    # json= serializes the payload and sets the application/json header.
    response = requests.post(url, json=payload, timeout=timeout)
    if response.status_code != 200:
        return f"Request failed with status code {response.status_code}"

    # `or` covers both a missing key and an empty list / null value, which
    # previously raised IndexError or leaked None to the caller.
    choices = response.json().get("choices") or [{}]
    message = choices[0].get("message") or {}
    reply = message.get("content")
    return reply if reply is not None else "No response content"
# database connection
def get_db_connection(dbname="d_conversations", user="postgres", password="********",
                      host="pg-db", port="5432"):
    """Open a psycopg2 connection to the conversations database.

    The defaults are the settings this script has always used; pass keyword
    arguments to connect somewhere else without editing the function.
    NOTE(review): the default password looks like a redacted placeholder —
    confirm how the real credential is supplied.

    :return: An open connection, or None if connecting raised (the error is
        printed).
    """
    try:
        return psycopg2.connect(
            dbname=dbname,
            user=user,
            password=password,
            host=host,
            port=port,
        )
    except Exception as e:
        # Broad on purpose: the caller only checks for None and reports a
        # generic failure.
        print(f"Error in connection to database: {e}")
        return None
def save_message(db_connection, message, chatbot_id, model_id):
    """Insert one message, with its embedding, into the ``conversation`` table.

    :param db_connection: Open database connection (psycopg2-style API).
    :param message: Message text to store and embed with the module-level ``model``.
    :param chatbot_id: ID of the message's author (joined to ``chatbot.id`` elsewhere).
    :param model_id: Model ID stored alongside the message.
    :return: The new row's ``id`` (taken from ``conversation_seq``), or None if
        embedding or the insert failed; the transaction is rolled back and the
        error printed in that case.
    """
    query = """
        INSERT INTO conversation (id, chatbot_id, message, create_dt, model_id, embeds)
        VALUES (nextval('conversation_seq'), %s, %s, %s, %s, %s)
        RETURNING id;
    """
    try:
        # Encoding inside the try: an embedding failure is now reported like
        # any other save failure instead of propagating to the caller.
        embeds = model.encode(message).tolist()
        # Naive UTC timestamp — the same value the deprecated datetime.utcnow()
        # produced, so stored create_dt values are unchanged.
        create_dt = datetime.now(timezone.utc).replace(tzinfo=None)
        with db_connection.cursor() as cursor:
            cursor.execute(query, (chatbot_id, message, create_dt, model_id, embeds))
            conversation_id = cursor.fetchone()[0]
            db_connection.commit()
            return conversation_id
    except Exception as e:
        db_connection.rollback()
        print(f"Error while inserting: {e}")
        return None
def get_last_n_messages(db_connection, n):
    """
    Retrieves the last n messages from the conversation table and formats them as a chat log.

    Messages are listed oldest first, each as ``<chatbot name> - message``,
    separated by a blank line.

    :param db_connection: Database connection object
    :param n: Maximum number of most recent messages (by conversation id) to include
    :return: Formatted string with the last n messages; "\\n" when there are no
        messages; None if the query raised (the error is printed)
    """
    query = """
        SELECT cb.name, c.message
        FROM conversation c
        JOIN chatbot cb ON c.chatbot_id = cb.id
        ORDER BY c.id DESC
        LIMIT %s;
    """
    try:
        with db_connection.cursor() as cursor:
            cursor.execute(query, (n,))
            rows = cursor.fetchall()
    except Exception as e:
        print(f"Error while retrieving messages: {e}")
        return None

    if not rows:
        # Kept from the original implementation for callers that concatenate
        # the result.
        return "\n"
    # Rows arrive newest first; reverse so the log reads chronologically.
    # Works for any number of rows (the old code only handled up to 3).
    return "\n\n".join(f"<{name}> - {message}" for name, message in reversed(rows))
def get_chatbot_model_port(db_connection, chatbot_id):
    """
    Look up which model and server port a chatbot is configured to use.

    :param db_connection: Open database connection (psycopg2-style cursor API)
    :param chatbot_id: Primary key of the row in the chatbot table
    :return: The (model_id, port) row exactly as fetched, or None when no such
        chatbot exists or the query raised (a message is printed either way)
    """
    lookup_sql = "SELECT model_id, port FROM chatbot WHERE id = %s;"
    try:
        with db_connection.cursor() as cur:
            cur.execute(lookup_sql, (chatbot_id,))
            row = cur.fetchone()
        if not row:
            print(f"No chatbot found with ID {chatbot_id}")
            return None
        return row  # (model_id, port)
    except Exception as err:
        print(f"Error retrieving chatbot info: {err}")
        return None
# Command-line entry point: store the user's question, then let two chatbots
# take turns replying to the latest message until interrupted or an error.
if len(sys.argv) < 2:
    print("Usage: python3 conversations.py '<your question here>'")
    sys.exit(1)

conn = get_db_connection()
if conn is None:
    print("Connection to database failed")
    sys.exit(1)

user_input = sys.argv[1]
# chatbot_id 1 / model_id 0 are used for the human's message.
# NOTE(review): assumed to be the "user" row in the chatbot table — confirm.
if save_message(conn, user_input, 1, 0) is None:
    # Without the seed question the loop would silently continue whatever
    # conversation was last stored in the table.
    print("Saving the question failed")
    conn.close()
    sys.exit(1)

# chatbot_a is prompted to propose follow-up questions; chatbot_b answers.
chatbot_a = 4
chatbot_b = 7
current_chatbot = chatbot_a
try:
    while True:
        # Most recent message in the conversation, from any participant.
        last_messages = get_last_n_messages(conn, 1)
        if last_messages is None:
            # Retrieval failed (already printed); stop instead of crashing
            # on string concatenation with None.
            break
        print(f"last messages: {last_messages}")

        # Alternate speakers; since we start at chatbot_a, chatbot_b goes first.
        current_chatbot = chatbot_a if current_chatbot == chatbot_b else chatbot_b
        print(f"asking chatbot {current_chatbot}")

        model_port_info = get_chatbot_model_port(conn, current_chatbot)
        if not model_port_info:
            print(f"Could not retrieve info from {current_chatbot}")
            break
        model_id, port = model_port_info
        print("retrieved model and port")

        if current_chatbot == chatbot_a:
            print("ask for question")
            last_messages += "\n\nSuggest one very short factual follow-up question that has not been answered yet or cannot be found inspired by the previous conversation and excerpts.\n"
        response = get_chat_completion(last_n_words(last_messages, MAX_WORDS), port)

        if save_message(conn, response, current_chatbot, model_id) is None:
            # The error was already printed; continuing would hand the same
            # message to the other bot and desynchronize the turns.
            break

        print("answer was:")
        print(get_last_n_messages(conn, 1))
except KeyboardInterrupt:
    print("\nConversation interrupted by human.")
finally:
    conn.close()