# engage_quotation_chatbot / common_functions.py
# Last commit: "fix: reset project" (17344e8, verified), author avatar: jesshewyz
from Project import *
import pandas as pd
from io import StringIO
# from pathlib import Path
import psycopg2
from psycopg2.extras import RealDictCursor
import gradio as gr
import datetime
# Module-level project instance shared by every function in this file
# (created for the Engage project type).
quotation_project = Project(ProjectType.Engage)
# Session label returned by fetch_session when no session could be loaded.
no_active_session = "**Current Session**:None"
import json
import os
from dotenv import load_dotenv
from langtrace_python_sdk import langtrace
# Pull variables from the local .env file into the process environment, then
# start Langtrace tracing. The module refuses to import without an API key.
load_dotenv()
api_key = os.getenv("LANGTRACE_API_KEY")
if api_key is not None:
    langtrace.init(api_key=api_key)
else:
    raise ValueError("Environment variable 'LANGTRACE_API_KEY' is not set. Please set it in your .env file.")
def check_session_exists():
    """Return True when the shared project currently carries a session id."""
    current_session_id = quotation_project.session_id
    if current_session_id is None:
        return False
    return True
def get_db_connection():
    """Open and return a new psycopg2 connection configured from the environment.

    The settings are read from the ``DB_NAME``, ``DB_USER``, ``DB_PASSWORD``,
    ``DB_HOST`` and ``DB_PORT`` environment variables. The caller is
    responsible for closing the returned connection.
    """
    # (psycopg2 keyword, environment variable) pairs.
    settings = (
        ('dbname', 'DB_NAME'),
        ('user', 'DB_USER'),
        ('password', 'DB_PASSWORD'),
        ('host', 'DB_HOST'),
        ('port', 'DB_PORT'),
    )
    connect_kwargs = {keyword: os.getenv(env_name) for keyword, env_name in settings}
    return psycopg2.connect(**connect_kwargs)
def get_latest_components():
    """Fetch the components of the newest component version of each base project.

    For every ``base_project_name`` in ``base_project_component`` only the
    rows whose ``component_version`` equals that project's maximum version
    are selected.

    Returns:
        On success, the rows returned by ``fetchall()`` (one mapping per
        component with the keys ``base_project_name``, ``module``,
        ``submodule``, ``unit_type``, ``quantity`` and ``mandays_per_unit``).
        On any failure, a dict ``{'status': 'error', 'message': <text>}``;
        callers must check for this shape before using the result.
    """
    try:
        conn = get_db_connection()
        try:
            # The cursor context manager closes the cursor even if the query
            # raises; the ``finally`` below does the same for the connection.
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute("""
                    SELECT base_project_name,module,submodule,unit_type,quantity,mandays_per_unit
                    FROM base_project_component pc
                    WHERE (pc.base_project_name, pc.component_version) IN (
                        SELECT base_project_name, MAX(component_version)
                        FROM base_project_component
                        GROUP BY base_project_name
                    )
                    ORDER BY pc.base_project_name;
                """)
                return cur.fetchall()
        finally:
            conn.close()
    except Exception as e:
        return {
            'status': 'error',
            'message': str(e)
        }
def get_section_name_and_rubric_list():
    """Fetch the latest rubric version for the 'Engage' base project.

    Rows are ordered by section name and then by priority
    (high, med, low, anything else).

    Returns:
        On success, a tuple ``(section_name_list, rubric_list)`` where
        ``rubric_list`` is a list of plain dicts (one per rubric row) and
        ``section_name_list`` is a dict-keys view of the distinct section
        names in first-seen order.
        On any failure, a dict ``{'status': 'error', 'message': <text>}``.
        NOTE: that dict has exactly two keys, so it tuple-unpacks without an
        error into the strings ``'status'`` and ``'message'``; callers should
        test ``isinstance(result, dict)`` before unpacking.
    """
    try:
        conn = get_db_connection()
        try:
            # Cursor and connection are released even when the query fails.
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute("""
                    SELECT section_name, criteria, initial_question,explanation, priority, quantifiable_value
                    FROM base_project_rubric
                    WHERE LOWER(base_project_name) = LOWER('Engage')
                    AND rubric_version = (
                        SELECT MAX(rubric_version)
                        FROM base_project_rubric
                        WHERE LOWER(base_project_name) = LOWER('Engage')
                    )
                    ORDER BY section_name,
                    CASE priority
                        WHEN 'high' THEN 1
                        WHEN 'med' THEN 2
                        WHEN 'low' THEN 3
                        ELSE 4
                    END;
                """)
                rubric = cur.fetchall()
        finally:
            conn.close()
        # Convert the rubric rows to plain dicts for JSON serialization.
        rubric_list = [dict(row) for row in rubric]
        # dict.fromkeys de-duplicates while preserving order; .keys() keeps
        # the same view type the function has always returned.
        section_name_list = dict.fromkeys(row['section_name'] for row in rubric).keys()
        print(f"in get_section_name_and_rubric_list: {section_name_list}")
        return section_name_list, rubric_list
    except Exception as e:
        return {
            'status': 'error',
            'message': str(e)
        }
def sanitize_text(text):
    """Return *text* with every single-quote character removed.

    Single quotes are dropped (not replaced) to avoid string formatting
    issues downstream. Add further entries to the table below if other
    characters turn out to be problematic.
    """
    # Mapping a code point to None deletes that character.
    characters_to_drop = {ord("'"): None}
    return text.translate(characters_to_drop)
def _save_project_requirement():
    """Best-effort write of the current project details to the sessions table."""
    try:
        conn = get_db_connection()
        try:
            with conn.cursor() as cur:
                # Update project_requirement in sessions table
                cur.execute("""
                    UPDATE sessions
                    SET project_requirement = %s
                    WHERE session_id = %s
                """, (json.dumps(quotation_project.project_detail), quotation_project.session_id))
            conn.commit()
        finally:
            # Always release the connection, even when the UPDATE fails.
            conn.close()
    except Exception as e:
        # Persistence failures must not interrupt the conversation.
        print(f"Error updating session: {str(e)}")


def process_response(answer, history):
    """Record a user answer and produce the next question.

    Args:
        answer: The user's reply; it is converted with ``str()`` and
            sanitized before being stored on the shared project.
        history: Chat history as a list of ``(answer, question)`` tuples.
            Anything that is not a list is replaced by a new empty list.

    Returns:
        A tuple ``(history, next_question)``. If anything goes wrong the
        history is returned as-is together with a fixed error message.
    """
    try:
        # Convert history to list if it's not already
        if not isinstance(history, list):
            history = []
        # Sanitize the answer before processing
        sanitized_answer = sanitize_text(str(answer))
        # Add the user's answer to project details
        quotation_project.add_project_detail(sanitized_answer)
        # Update session in database if we have a session_id
        if quotation_project.session_id:
            _save_project_requirement()
        # Generate next question based on conversation stage
        if len(history) == 1:  # After first client information question
            next_question = quotation_project.generate_client_follow_up()
        elif len(history) == 2:  # After client follow-up
            next_question = quotation_project.generate_questions()
        else:  # Subsequent project requirements questions
            next_question = quotation_project.generate_follow_up()
        # Only well-formed (str, str) pairs are appended to the history
        if isinstance(answer, str) and isinstance(next_question, str):
            history.append((answer, next_question))
        return history, next_question
    except Exception as e:
        print(f"Error in process_response: {str(e)}")
        return history, "Error in generating follow up questions"
def calculate_mandays_and_costs(df, cost_per_manday=1500, max_billable_quantity=80):
    """Compute per-row mandays, total mandays and total cost for a quotation.

    Args:
        df: DataFrame with numeric ``quantity`` and ``mandays_per_unit``
            columns. It is modified in place: ``mandays_per_unit`` is zeroed
            for oversized quantities and a ``calculated_mandays`` column is
            added. The same object is returned.
        cost_per_manday: Price charged per manday (default 1500).
        max_billable_quantity: Rows whose quantity exceeds this value
            contribute zero mandays (default 80).

    Returns:
        ``(df, total_mandays, total_cost)`` on success, or
        ``(None, None, None)`` if the calculation fails.
    """
    try:
        # HACK: quantities above the threshold are treated as token counts
        # rather than unit counts, so they must not contribute mandays. The
        # threshold is an arbitrary number large enough to cover the chatflows.
        df.loc[df['quantity'] > max_billable_quantity, 'mandays_per_unit'] = 0
        # Calculate mandays and costs
        df['calculated_mandays'] = df['quantity'] * df['mandays_per_unit']
        total_mandays = df['calculated_mandays'].sum()
        total_cost = cost_per_manday * total_mandays
        return df, total_mandays, total_cost
    except Exception as e:
        print(f"Error calculating mandays and costs: {str(e)}")
        return None, None, None
def create_new_session():
    """Create a new session row linked to the 'Engage' base project.

    Inserts a row into ``sessions`` (start time = current timestamp) and a
    matching row into ``session_base_project``, committing both together.

    Returns:
        The new ``session_id``, or ``None`` if anything fails (the error is
        printed).
    """
    try:
        conn = get_db_connection()
        try:
            with conn.cursor() as cur:
                # Insert new session with start time
                cur.execute("""
                    INSERT INTO sessions (start_time)
                    VALUES (CURRENT_TIMESTAMP)
                    RETURNING session_id
                """)
                session_id = cur.fetchone()[0]
                # Insert session_base_project record for "Engage"
                cur.execute("""
                    INSERT INTO session_base_project (session_id, base_project_name)
                    VALUES (%s, 'Engage')
                """, (session_id,))
            conn.commit()
        finally:
            # Closing without a commit discards a half-finished transaction,
            # and the connection is no longer leaked on failure.
            conn.close()
        return session_id
    except Exception as e:
        print(f"Error creating new session: {str(e)}")
        return None
def start_chat():
    """Start a new chat: create a session, reset the project, load reference data.

    Returns:
        A tuple ``(first_question, initial_history, session_label)`` where
        ``initial_history`` holds the single entry
        ``(None, client_initial_question)``.
    """
    # Create new session and get session_id (None if creation failed)
    session_id = create_new_session()

    # The loaders report failure by returning an error dict. A two-key dict
    # would otherwise unpack silently into the strings 'status' and
    # 'message', so detect it and fall back to empty data instead.
    rubric_result = get_section_name_and_rubric_list()
    if isinstance(rubric_result, dict):
        print(f"Error loading rubric: {rubric_result.get('message')}")
        section_name_list, rubric_list = [], []
    else:
        section_name_list, rubric_list = rubric_result

    component_list = get_latest_components()
    if isinstance(component_list, dict):
        print(f"Error loading components: {component_list.get('message')}")
        component_list = []

    # Reset first, then attach the new session and reference data
    quotation_project.reset_project()
    quotation_project.session_id = session_id
    quotation_project.set_rubric(rubric_list)
    quotation_project.set_rubric_section_names(section_name_list)
    quotation_project.set_component_list(component_list)

    initial_history = [(None, client_initial_question)]
    return client_initial_question, initial_history, f"Current Session: {session_id}"
def get_project_state():
    """Summarise the shared project as ``(status_text, requirements_text)``.

    ``status_text`` lists the session id and whether rubric, components and
    requirements are loaded. ``requirements_text`` contains one numbered
    entry per stored requirement, or an empty string when there are none.
    """
    project = quotation_project
    status = "\n".join([
        f"Session ID: {project.session_id}",
        f"Rubric Loaded: {bool(project.rubric)}",
        f"Components Loaded: {bool(project.component_list)}",
        f"Requirements Loaded: {bool(project.project_detail)}",
    ])

    details = project.project_detail
    if not details:
        return status, ""

    print(f"\n\nrequirements : {type(details)}")
    # Each requirement is preceded by a separator line and a numbered heading.
    requirements_table = "".join(
        f"\n_____________\n#Requirement {number}\n {requirement}"
        for number, requirement in enumerate(list(details), start=1)
    )
    return status, requirements_table
def _parse_project_requirement(raw):
    """Turn a stored project_requirement value into an iterable of requirements."""
    if isinstance(raw, str):
        try:
            # Attempt to parse it as JSON
            return json.loads(raw)
        except json.JSONDecodeError:
            # Not JSON: treat the text as one requirement per line
            return raw.split('\n')
    return raw


def fetch_session(session_id):
    """Load a stored session into the shared project.

    Args:
        session_id: Identifier of the row to read from ``sessions``.

    Returns:
        A 4-tuple ``(status, requirements, message, session_label)``. On
        success the first two items come from ``get_project_state()``; on
        failure they are empty strings and ``session_label`` is
        ``no_active_session``.
    """
    try:
        # 1. Fetch session details; cursor and connection are always released
        conn = get_db_connection()
        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute("""
                    SELECT project_requirement, start_time
                    FROM sessions
                    WHERE session_id = %s
                """, (session_id,))
                session = cur.fetchone()
        finally:
            conn.close()
        print(session)

        if not session:
            return "", "", f"Session {session_id} not found",no_active_session

        # 2. Update quotation_project with session data
        quotation_project.session_id = session_id
        # NOTE(review): when the stored requirement is empty, the project's
        # existing details are left untouched - confirm this is intended.
        if session['project_requirement']:
            try:
                requirements = _parse_project_requirement(session['project_requirement'])
                # Clear existing details and set new ones
                quotation_project.project_detail = []
                for requirement in requirements:
                    # strip() removes any leading/trailing whitespace
                    quotation_project.add_project_detail(requirement.strip())
            except Exception as e:
                return "", "", f"Error processing project requirements in session {session_id}: {str(e)}", no_active_session

        # 3. Fetch and set rubric. The loader signals failure with a two-key
        # dict that would unpack silently into the strings 'status' and
        # 'message', so check for it explicitly and fall back to empty data.
        rubric_result = get_section_name_and_rubric_list()
        if isinstance(rubric_result, dict):
            print(f"Error loading rubric: {rubric_result.get('message')}")
            section_name_list, rubric_list = [], []
        else:
            section_name_list, rubric_list = rubric_result
        quotation_project.set_rubric(rubric_list)
        quotation_project.set_rubric_section_names(section_name_list)

        # 4. Fetch and set components (same error-dict convention)
        component_list = get_latest_components()
        if isinstance(component_list, dict):
            print(f"Error loading components: {component_list.get('message')}")
            component_list = []
        quotation_project.set_component_list(component_list)

        return (*get_project_state(), f"Successfully loaded session {session_id} with all data", f"Current Session: {session_id}")
    except Exception as e:
        return "", "", f"Error fetching session: {str(e)}",no_active_session
def insert_quotation(csv_string, total_price, total_mandays, note=None, task_breakdown_v3 = None, tier_level=1):
    """Insert a new quotation version for the current session.

    The version number is one more than the highest existing version for the
    session (starting at 1). Failures are printed, not raised.

    Args:
        csv_string: Quotation table as CSV text (for the v3 flow this is the
            quantity table).
        total_price: Total price; converted to ``float`` unless ``None``.
        total_mandays: Total mandays; converted to ``float`` unless ``None``.
        note: Accepted for compatibility; not used by this function.
        task_breakdown_v3: Optional task table appended to the stored details
            so tasks missing from the quantity table remain visible.
        tier_level: Accepted for compatibility; not used by this function.
    """
    try:
        total_price = float(total_price) if total_price is not None else None
        total_mandays = float(total_mandays) if total_mandays is not None else None
        structured_details = quotation_project.generated_prd
        # Project details as a JSON string. The v3 task table is appended to
        # the same column instead of being stored in a column of its own.
        details = json.dumps(quotation_project.project_detail)
        if task_breakdown_v3:
            details = f"{details} + {task_breakdown_v3}"

        conn = get_db_connection()
        try:
            with conn.cursor() as cur:
                # Get the next version number for the given session_id
                cur.execute("""
                    SELECT COALESCE(MAX(version), 0) + 1
                    FROM quotations
                    WHERE session_id = %s
                """, (quotation_project.session_id,))
                result = cur.fetchone()
                version = result[0] if result else 1  # Default to version 1 if no result
                # Insert new quotation
                cur.execute("""
                    INSERT INTO quotations (session_id, version, details, quotation_csv, total_price, total_mandays,structured_details)
                    VALUES (%s, %s, %s, %s, %s, %s,%s)
                """, (
                    quotation_project.session_id,
                    version,
                    details,
                    csv_string,
                    total_price,
                    total_mandays,
                    structured_details
                ))
            conn.commit()
        finally:
            # Release the connection even when a statement fails.
            conn.close()
        print("Successfully inserted quotation")
    except Exception as e:
        print(f"Error inserting quotation: {str(e)}")
def create_folder_and_save_csv(df, folder_name, file_name):
    """Ensure *folder_name* exists and write *df* into it as *file_name*.

    The folder is created even when there is nothing to write. Returns a
    short message naming the saved file, or "No data to save." when *df*
    is None.
    """
    os.makedirs(folder_name, exist_ok=True)
    if df is None:
        return "No data to save."
    destination = os.path.join(folder_name, file_name)
    df.to_csv(destination, index=False)
    return f"{file_name} saved to {destination}"
def create_folder():
    """Create (if needed) and return a folder named after the session and current time.

    The name has the form ``quotation_<session_id>_<YYYYmmddHHMMSS>`` and is
    relative to the current working directory.
    """
    stamp = f"{datetime.datetime.now():%Y%m%d%H%M%S}"
    folder_name = f"quotation_{quotation_project.session_id}_{stamp}"
    os.makedirs(folder_name, exist_ok=True)
    return folder_name