|
|
|
|
|
import requests |
|
|
import json |
|
|
import sys |
|
|
import psycopg2 |
|
|
from datetime import datetime |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import numpy as np |
|
|
|
|
|
# Upper bound on how many trailing words of conversation context are passed
# to the chat model in a single prompt (applied via last_n_words()).
MAX_WORDS = 1024

# Sentence-embedding model, loaded once at import time; save_message()
# encodes every stored message with it to fill the `embeds` column.
model = SentenceTransformer("all-MiniLM-L6-v2")
|
|
|
|
|
def last_n_words(text: str, n: int) -> str:
    """Return the final *n* whitespace-separated words of *text*, space-joined.

    Runs of whitespace (including newlines) are collapsed to single spaces in
    the result. Returns "" when *n* is zero or negative.

    :param text: Arbitrary text to truncate from the front.
    :param n: Maximum number of trailing words to keep.
    :return: At most *n* words of *text*, joined by single spaces.
    """
    # Guard needed: words[-0:] is the whole list, and a negative n would
    # slice from the front instead of returning nothing.
    if n <= 0:
        return ""
    return " ".join(text.split()[-n:])
|
|
|
|
|
def get_chat_completion(content: str, port: int, timeout: float = 600.0) -> str:
    """Send a single-turn chat request to the completion server on localhost.

    Posts to ``http://localhost:{port}/v1/chat/completions`` with model
    "Default", ``max_tokens`` 2048 and temperature 0.28.

    :param content: Text sent as the only ``user`` message.
    :param port: Local TCP port of the completion server.
    :param timeout: Seconds ``requests`` waits to connect and between received
        bytes before raising a timeout error; keeps the caller from hanging
        forever on an unresponsive server.
    :return: The first choice's message content; "No response content" when
        the reply carries no usable choice/message/content; or
        "Request failed with status code <code>" for any non-200 status.
    :raises requests.RequestException: On connection errors or timeouts.
    """
    url = f"http://localhost:{port}/v1/chat/completions"
    payload = {
        "model": "Default",
        "messages": [{"role": "user", "content": content}],
        "max_tokens": 2048,
        "temperature": 0.28,
    }

    # json= serialises the payload and sets Content-Type: application/json.
    response = requests.post(url, json=payload, timeout=timeout)

    if response.status_code != 200:
        return f"Request failed with status code {response.status_code}"

    # `or` also covers an empty or null "choices" list, which would otherwise
    # raise IndexError/TypeError instead of falling back to the default text.
    choices = response.json().get("choices") or [{}]
    message = choices[0].get("message") or {}
    reply = message.get("content")
    return reply if reply is not None else "No response content"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_db_connection():
    """Open a connection to the conversations PostgreSQL database.

    Each setting can be overridden through the standard libpq environment
    variables (PGDATABASE, PGUSER, PGPASSWORD, PGHOST, PGPORT). When a variable
    is unset, the built-in default shown below is used. Supplying the password
    through PGPASSWORD avoids keeping a real credential in source.

    :return: An open psycopg2 connection, or None if the connection attempt
        failed (the error is printed).
    """
    import os

    try:
        return psycopg2.connect(
            dbname=os.environ.get("PGDATABASE", "d_conversations"),
            user=os.environ.get("PGUSER", "postgres"),
            password=os.environ.get("PGPASSWORD", "********"),
            host=os.environ.get("PGHOST", "pg-db"),
            port=os.environ.get("PGPORT", "5432"),
        )
    except psycopg2.Error as e:
        # Only database/driver errors are expected here; anything else is a
        # programming error and should surface with a traceback.
        print(f"Error in connection to database: {e}")
        return None
|
|
|
|
|
def save_message(db_connection, message, chatbot_id, model_id):
    """Embed *message* and insert it as a new row of the ``conversation`` table.

    :param db_connection: Open psycopg2 connection. It is committed on success
        and rolled back on failure.
    :param message: Text to store. It is also encoded with the module-level
        ``model`` to fill the ``embeds`` column.
    :param chatbot_id: Value stored in ``conversation.chatbot_id``.
    :param model_id: Value stored in ``conversation.model_id``.
    :return: The id of the inserted row (from ``nextval('conversation_seq')``),
        or None if the insert failed (the error is printed).
    """
    from datetime import timezone

    # Encoding happens before the transaction; an error here propagates.
    embeds = model.encode(message).tolist()

    query = """
    INSERT INTO conversation (id, chatbot_id, message, create_dt, model_id, embeds)
    VALUES (nextval('conversation_seq'), %s, %s, %s, %s, %s)
    RETURNING id;
    """
    # Naive UTC timestamp: same value datetime.utcnow() produced, without the
    # call that is deprecated since Python 3.12.
    create_dt = datetime.now(timezone.utc).replace(tzinfo=None)

    try:
        with db_connection.cursor() as cursor:
            cursor.execute(query, (chatbot_id, message, create_dt, model_id, embeds))
            conversation_id = cursor.fetchone()[0]
            db_connection.commit()
            return conversation_id
    except Exception as e:
        # Best-effort insert: undo the failed transaction so the connection
        # stays usable for the caller's next query.
        db_connection.rollback()
        print(f"Error while inserting: {e}")
        return None
|
|
|
|
|
|
|
|
def get_last_n_messages(db_connection, n):
    """Return the newest *n* conversation messages formatted as a chat log.

    Each message is rendered as ``<chatbot name> - message``. Entries are in
    chronological (oldest-first) order and separated by a blank line.

    :param db_connection: Database connection object.
    :param n: Maximum number of messages to fetch (used as the SQL LIMIT).
    :return: The formatted log; a single newline when there are no messages;
        or None if the query failed (the error is printed).
    """
    query = """
    SELECT cb.name, c.message
    FROM conversation c
    JOIN chatbot cb ON c.chatbot_id = cb.id
    ORDER BY c.id DESC
    LIMIT %s;
    """
    try:
        with db_connection.cursor() as cursor:
            cursor.execute(query, (n,))
            rows = cursor.fetchall()
    except Exception as e:
        print(f"Error while retrieving messages: {e}")
        return None

    if not rows:
        # Placeholder kept for an empty conversation.
        return "\n"

    # Rows arrive newest-first (ORDER BY id DESC); reverse them so the log
    # reads chronologically. Works for any number of rows, not just up to 3.
    return "\n\n".join(f"<{name}> - {text}" for name, text in reversed(rows))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_chatbot_model_port(db_connection, chatbot_id):
    """Look up the model id and port configured for a single chatbot.

    :param db_connection: Database connection object.
    :param chatbot_id: Primary key of the row to read from the ``chatbot`` table.
    :return: The ``(model_id, port)`` row as returned by the cursor, or None
        when no row matches or the query raises (the reason is printed).
    """
    sql = "SELECT model_id, port FROM chatbot WHERE id = %s;"
    try:
        with db_connection.cursor() as cur:
            cur.execute(sql, (chatbot_id,))
            row = cur.fetchone()
            if not row:
                print(f"No chatbot found with ID {chatbot_id}")
                return None
            return row
    except Exception as err:
        print(f"Error retrieving chatbot info: {err}")
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if len(sys.argv) < 2:
    print("Usage: python3 conversations.py '<your question here>'")
    sys.exit(1)

conn = get_db_connection()
if conn is None:
    print("Connection to database failed")
    sys.exit(1)

# Join all arguments so an unquoted multi-word question is not silently cut
# down to its first word; a single quoted argument is unchanged.
user_input = " ".join(sys.argv[1:])

# The two bots alternate. When chatbot_a's turn comes it is prompted for a
# follow-up question; chatbot_b receives the latest message as-is.
chatbot_a = 4
chatbot_b = 7
current_chatbot = chatbot_a

try:
    # Store the command-line question (chatbot_id 1, model_id 0) inside the
    # try so the connection is closed even if this first insert fails or is
    # interrupted.
    save_message(conn, user_input, 1, 0)

    while True:
        last_messages = get_last_n_messages(conn, 1)
        if last_messages is None:
            # get_last_n_messages already printed the underlying error.
            print("Could not retrieve previous messages")
            break
        print("last messages: " + last_messages)

        current_chatbot = chatbot_a if current_chatbot == chatbot_b else chatbot_b
        print(f"asking chatbot {current_chatbot}")

        model_port_info = get_chatbot_model_port(conn, current_chatbot)
        if not model_port_info:
            print(f"Could not retrieve info from {current_chatbot}")
            break

        model_id, port = model_port_info
        print("retrieved model and port")

        if current_chatbot == chatbot_a:
            print("ask for question")
            last_messages = last_messages + "\n\nSuggest one very short factual follow-up question that has not been answered yet or cannot be found inspired by the previous conversation and excerpts.\n"

        # Only the trailing MAX_WORDS words of the prompt are sent.
        response = get_chat_completion(last_n_words(last_messages, MAX_WORDS), port)
        save_message(conn, response, current_chatbot, model_id)

        print("answer was:")
        print(get_last_n_messages(conn, 1))

except KeyboardInterrupt:
    print("\nConversation interrupted by human.")

finally:
    conn.close()
|
|
|