Spaces:
Runtime error
Runtime error
| import time | |
| import uuid | |
| import numpy as np | |
| import pandas as pd | |
| from io import StringIO | |
| # from pathlib import Path | |
| from enum import Enum | |
| import psycopg2 | |
| from psycopg2.extras import RealDictCursor | |
| import gradio as gr | |
| import datetime | |
| from Project import * | |
| import json | |
| import os | |
| from dotenv import load_dotenv | |
| from langtrace_python_sdk import langtrace | |
| from state import state | |
| from prompt_configs import * | |
| import asyncio | |
| import re | |
| import base64 | |
# Session-indicator label returned by fetch_session whenever no session could be loaded.
no_active_session = "**Current Session**:None"
class SLEEP_TIME(Enum):
    """Named sleep intervals; each member's value is a duration in seconds."""
    ONE_AND_HALF_SEC = 1.5  # 1.5 seconds
    TWO_SEC = 1.8  # NOTE(review): despite the name this is 1.8 s; async_process_response uses it as the "thinking" update interval
    THREE_SEC = 3  # 3 seconds
# Populate os.environ from a .env file (python-dotenv) before reading any settings.
load_dotenv()
# Tracing is mandatory: fail fast at import time when the Langtrace key is absent.
api_key = os.getenv("LANGTRACE_API_KEY")
if api_key is None:
    raise ValueError("Environment variable 'LANGTRACE_API_KEY' is not set. Please set it in your .env file.")
langtrace.init(api_key=api_key)
# TODO: Now that the Page and Engage projects are being merged, pull the project type so the correct project is loaded.
def get_db_connection():
    """Open and return a new PostgreSQL connection.

    Connection parameters are read from the DB_NAME, DB_USER, DB_PASSWORD,
    DB_HOST and DB_PORT environment variables. Every call opens a fresh
    connection; the caller owns it and must close it.
    """
    env_mapping = (
        ('dbname', 'DB_NAME'),
        ('user', 'DB_USER'),
        ('password', 'DB_PASSWORD'),
        ('host', 'DB_HOST'),
        ('port', 'DB_PORT'),
    )
    connect_kwargs = {param: os.getenv(env_var) for param, env_var in env_mapping}
    return psycopg2.connect(**connect_kwargs)
def get_latest_components():
    """Fetch the component rows of the latest component version of every base project.

    For each ``base_project_name`` only the rows that carry its highest
    ``component_version`` are returned, ordered by project name.

    Returns:
        list: rows (RealDictCursor rows, i.e. dict-like) with keys
        base_project_name, module, submodule, unit_type, quantity and
        mandays_per_unit. On any failure a dict
        ``{'status': 'error', 'message': <str>}`` is returned instead; this
        shape is kept for backward compatibility, so callers must check for it.
    """
    conn = None
    try:
        conn = get_db_connection()
        # The cursor is closed by its context manager; the connection is closed
        # in `finally` so it is released even when the query raises.
        with conn.cursor(cursor_factory=RealDictCursor) as cur:
            cur.execute("""
                SELECT base_project_name, module, submodule, unit_type, quantity, mandays_per_unit
                FROM base_project_component pc
                WHERE (pc.base_project_name, pc.component_version) IN (
                    SELECT base_project_name, MAX(component_version)
                    FROM base_project_component
                    GROUP BY base_project_name
                )
                ORDER BY pc.base_project_name;
            """)
            return cur.fetchall()
    except Exception as e:
        return {
            'status': 'error',
            'message': str(e)
        }
    finally:
        if conn is not None:
            conn.close()
def get_section_name_and_rubric_list(base_project_name='Page'):
    """Fetch the latest rubric of a base project together with its section names.

    Only rows with the highest ``rubric_version`` for *base_project_name*
    (matched case-insensitively) are returned. They are ordered by section
    name and then by ``mvp`` priority: 'high', 'med', 'low', then anything else.

    Args:
        base_project_name: Base project whose rubric to load. Defaults to
            'Page', the only project this function previously supported.

    Returns:
        tuple: ``(section_name_list, rubric_list)``.
            - ``section_name_list`` is a ``dict_keys`` view of the distinct
              section names, in query order.
            - ``rubric_list`` is a list of plain dicts with keys section_name,
              criteria, initial_question, explanation, mvp and
              quantifiable_value.

        On any failure a dict ``{'status': 'error', 'message': <str>}`` is
        returned instead. Callers must check for it before unpacking.
    """
    conn = None
    try:
        conn = get_db_connection()
        with conn.cursor(cursor_factory=RealDictCursor) as cur:
            # Parameterized (not string-built) so the project name is safely escaped.
            cur.execute(
                """
                SELECT section_name, criteria, initial_question, explanation, mvp, quantifiable_value
                FROM base_project_rubric
                WHERE LOWER(base_project_name) = LOWER(%s)
                  AND rubric_version = (
                      SELECT MAX(rubric_version)
                      FROM base_project_rubric
                      WHERE LOWER(base_project_name) = LOWER(%s)
                  )
                ORDER BY section_name,
                    CASE mvp
                        WHEN 'high' THEN 1
                        WHEN 'med' THEN 2
                        WHEN 'low' THEN 3
                        ELSE 4
                    END;
                """,
                (base_project_name, base_project_name),
            )
            rubric = cur.fetchall()
        # Plain dicts so the rubric is JSON-serializable.
        rubric_list = [dict(row) for row in rubric]
        # dict.fromkeys de-duplicates while preserving first-seen order.
        section_name_list = dict.fromkeys(row['section_name'] for row in rubric_list).keys()
        print(f"in get_section_name_and_rubric_list: {section_name_list}")
        return section_name_list, rubric_list
    except Exception as e:
        return {
            'status': 'error',
            'message': str(e)
        }
    finally:
        if conn is not None:
            conn.close()
def update_session_project_requirements(session_id, project_detail):
    """Store a session's project requirements as JSON in ``sessions.project_requirement``.

    Args:
        session_id: ID of the session to update. When falsy (e.g. session
            creation failed) nothing is written.
        project_detail: JSON-serializable project requirements to store.

    Returns:
        None

    Raises:
        Exception: re-raised after printing if serialization or the database
            update fails. The connection is always closed.
    """
    if not session_id:
        return None
    conn = None
    try:
        conn = get_db_connection()
        with conn.cursor() as cur:
            cur.execute(
                """
                UPDATE sessions
                SET project_requirement = %s
                WHERE session_id = %s
                """,
                (json.dumps(project_detail), session_id),
            )
        conn.commit()
    except Exception as e:
        print(f"Error updating session: {str(e)}")
        raise
    finally:
        # Previously the connection leaked whenever execute/json.dumps raised.
        if conn is not None:
            conn.close()
def sanitize_text(text):
    """Return *text* with every single-quote character (') deleted.

    Quotes are removed outright rather than replaced, to avoid
    string-formatting problems later. Extend this function if other
    characters turn out to be problematic.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string.
    """
    return text.replace("'", '')
async def run_question_agent(quotation_project):
    """Ask the questioning agent which question generator to use, then run it.

    Args:
        quotation_project: Project object exposing ``async_execute_prompt``,
            ``get_project_detail``, ``_parse_json_response`` and a writable
            ``config`` attribute.

    Returns:
        On success, a string
        ``"## Project Configuration: <configuration_type>\\n\\n<question>"``.
        Returns None if the agent output cannot be parsed, no generator was
        selected, or the selected generator fails.
    """
    configuration_output = await quotation_project.async_execute_prompt(
        "questioning_agent",
        {"project_detail": quotation_project.get_project_detail()},
    )

    # Bound up front so a parse failure falls through to the "nothing selected"
    # branch instead of raising UnboundLocalError below.
    config = None
    selected_functions = []
    try:
        config = quotation_project._parse_json_response(configuration_output)
        quotation_project.config = config
        # Log the prompt that actually produced `configuration_output`
        # (previously the unrelated 'component_agent' prompt was logged).
        log_prompt(PROMPTS['questioning_agent'].step,
                   PROMPTS['questioning_agent'].description,
                   PROMPTS['questioning_agent'].prompt,
                   configuration_output)
        selected_functions = config[0]["selected_functions"]
        print(f"Selected {len(selected_functions)} component to generate")
    except Exception as e:
        print(f"Error in analyzing configuration: {e}")

    if not selected_functions:
        # TODO: decide how to recover when the agent selects no generator.
        print("No question generator selected.")
        return None

    # Only the first selected generator is executed.
    function_name = selected_functions[0]
    try:
        next_question = await quotation_project.async_execute_prompt(function_name)
        log_prompt(PROMPTS[function_name].step,
                   PROMPTS[function_name].description,
                   PROMPTS[function_name].prompt,
                   next_question)
        return f"## Project Configuration: {config[0]['configuration_type']}\n\n{next_question}"
    except Exception as e:
        print(f"Error executing {function_name}: {e}")
        return None
def save_prompt_to_db(new_prompt: str) -> None:
    """
    Store a prompt configuration in the ``config`` table as base64 text.

    The value is coerced to ``str``, stripped of surrounding whitespace,
    UTF-8 encoded and then base64 encoded. get_latest_prompt_from_db
    reverses this decoding.

    Args:
        new_prompt: The prompt text to save.

    Raises:
        Exception: re-raised after printing if encoding, connecting or the
            insert fails. The connection is always closed.
    """
    conn = None
    try:
        # Coerce non-string inputs through format(), as before.
        prompt_text = f"{new_prompt}"
        encoded_prompt = base64.b64encode(prompt_text.strip().encode('utf-8')).decode('utf-8')
        conn = get_db_connection()
        with conn.cursor() as cur:
            cur.execute("INSERT INTO config (config_body) VALUES (%s)", (encoded_prompt,))
        conn.commit()
    except Exception as e:
        print(f"Error saving prompt to database: {str(e)}")
        raise
    finally:
        # Previously the connection leaked whenever the insert raised.
        if conn is not None:
            conn.close()
def get_latest_prompt_from_db() -> "str | dict":
    """
    Return the newest stored prompt configuration, base64-decoded.

    Returns:
        str: The UTF-8 text of the ``config`` row with the highest id. The
            text is NOT parsed as JSON; callers parse it if they need to.
        dict: An empty dict when the ``config`` table has no rows.

    Raises:
        Exception: re-raised after printing if connecting, querying or
            decoding fails. The connection is always closed.
    """
    conn = None
    try:
        conn = get_db_connection()
        with conn.cursor() as cur:
            cur.execute("SELECT id, config_body FROM config ORDER BY id DESC LIMIT 1")
            result = cur.fetchone()
        if not result:
            print("No prompt configurations found in database")
            return {}
        config_id, encoded_prompt = result
        print(f"get_latest_prompt_from_db: config_id: {config_id}")
        return base64.b64decode(encoded_prompt).decode('utf-8')
    except Exception as e:
        print(f"Error retrieving prompt from database: {str(e)}")
        raise
    finally:
        if conn is not None:
            conn.close()
def process_response(answer, history):
    """Synchronously record *answer* and produce the next question.

    The stage is selected from the number of turns already in *history*:
    1 -> client follow-up, 2 -> gather project input, anything else ->
    review project input. The ``(answer, next_question)`` pair is appended
    only when both are strings.

    NOTE(review): async_process_response passes ``generate_client_follow_up``
    to ``asyncio.create_task``, which suggests the generators are coroutine
    functions. If so, this synchronous path returns an un-awaited coroutine.
    Confirm whether this function is still used.

    Returns:
        tuple: ``(history, next_question)``. On any error the message is
        replaced by a fixed error string.
    """
    try:
        history = history if isinstance(history, list) else []
        cleaned_answer = sanitize_text(str(answer))
        project = state.quotation_project
        project.add_project_detail(cleaned_answer)
        if project.session_id:
            update_session_project_requirements(project.session_id, project.project_detail)

        turns_so_far = len(history)
        if turns_so_far == 1:
            next_question = project.generate_client_follow_up()
        elif turns_so_far == 2:
            next_question = project.gather_project_input()
        else:
            next_question = project.review_project_input()

        if isinstance(answer, str) and isinstance(next_question, str):
            history.append((answer, next_question))
        return history, next_question
    except Exception as e:
        print(f"Error in process_response: {str(e)}")
        return history, "Error in generating follow up questions"
def log_chat_history(chat_history, filename="chat_logs.txt"):
    """
    Append a human-readable dump of *chat_history* to *filename*.

    Each call writes a "Chat History" banner followed by one numbered
    section per message: its role, its metadata (only when non-empty), and
    its content.

    Args:
        chat_history (list): Message dicts with 'role' and 'content' keys and
            an optional 'metadata' dict.
        filename (str): File to append to (default: chat_logs.txt).
    """
    divider = "-" * 50
    with open(filename, 'a', encoding='utf-8') as log_file:
        print("Chat History", file=log_file)
        print("=" * 50, end="\n\n", file=log_file)
        for number, entry in enumerate(chat_history, start=1):
            print(f"Message #{number}", file=log_file)
            print(f"Role: {entry['role'].upper()}", end="\n\n", file=log_file)
            metadata = entry.get('metadata')
            if metadata:
                print("Metadata:", file=log_file)
                for key, value in metadata.items():
                    print(f" • {key}: {value}", file=log_file)
                print(file=log_file)
            print("Content:", file=log_file)
            print(f"{entry['content']}", file=log_file)
            print(f"\n{divider}", end="\n\n", file=log_file)
async def clean_sample_answers(text):
    """Turn sample answers in generated question text into blank 'Answer:' slots.

    Two passes are applied:
      1. Every '*', '#' and '-' character is deleted. This removes markdown
         emphasis, headings and bullets, and also hyphens inside words.
      2. Each "Sample:", "Sample Answer:" or "Sample Answers:" marker
         (case-insensitive) is replaced by a newline, "Answer:" and another
         newline. The match includes the marker's inline content (a quoted
         string, or everything up to ')', '*' or end of line), surrounding
         whitespace and an optional bracket.

    Args:
        text (str): Generated question text. Falsy values are returned unchanged.

    Returns:
        str: The cleaned text, or the original *text* if cleaning raised.
    """
    try:
        if not text:
            return text
        print("Starting to clean sample answers...")
        stripped = text.translate(str.maketrans('', '', '*#-'))
        result = re.sub(
            r'(?i)\s*(?:\(|\*)?\s*Sample(?:\s*Answers?)?\s*:\s*(?:"[^"]*"|[^\)\*\n]*)\s*(?:\)|\*)?',
            '\nAnswer:\n',
            stripped,
            flags=re.MULTILINE,
        )
        print("Sample answers cleaned successfully.")
        return result
    except Exception as e:
        print(f"Error cleaning sample answers: {e}")
        return text
# Example usage of log_chat_history:
# chat_history = [{'role': 'user', ...}, {'role': 'assistant', ...}]
# log_chat_history(chat_history)
| # | |
# TODO: Ensure the conversation is routed to the correct next question.
async def async_process_response(answer, history):
    """Record the user's answer and stream the next question to the chat UI.

    This is an async generator yielding ``(message_or_messages, template)`` pairs:
      * a pending ``gr.ChatMessage`` titled with the selected prompt's description;
      * the same message again roughly every ``SLEEP_TIME.TWO_SEC`` seconds while
        the question is generated, with one more "thought" from
        ``prompt_config.thoughts`` appended each time;
      * the message marked ``"done"`` with its ``"duration"``;
      * finally ``[thinking_message, gr.ChatMessage(content=next_question)]``
        with the answer template produced by ``clean_sample_answers``.

    The generator is chosen by the number of project details after adding
    *answer*:
      * 1 -> client follow-up;
      * 2 -> questioning agent;
      * 5 -> general questions;
      * anything else -> further follow-up questions.

    On any error it yields ``(history, "Error in generating follow up questions")``.
    """
    print("[DEBUG] Entering async_process_response")
    # Bound before the try so `finally` can always inspect it.
    next_question_task = None
    try:
        if not isinstance(history, list):
            history = []
        # Best-effort logging: a malformed history entry must not kill the
        # response (previously this ran outside the try and could crash the
        # generator before anything was yielded).
        try:
            log_chat_history(history)
        except Exception as log_error:
            print(f"[DEBUG] Could not log chat history: {log_error}")

        sanitized_answer = sanitize_text(str(answer))
        if sanitized_answer:
            state.quotation_project.add_project_detail(sanitized_answer)

        start_time = time.time()
        project_detail_len = len(state.quotation_project.project_detail)

        # Pick the question generator and its prompt config for this stage.
        if project_detail_len == 1:
            function_to_run = state.quotation_project.generate_client_follow_up
            prompt_config = PROMPTS["generate_client_follow_up"]
        elif project_detail_len == 2:
            function_to_run = lambda: run_question_agent(state.quotation_project)
            prompt_config = PROMPTS["questioning_agent"]
        elif project_detail_len == 5:
            function_to_run = state.quotation_project.generate_general_questions
            prompt_config = PROMPTS["generate_general_questions"]
        else:
            # 0, 3, 4 and 6+ details all continue with further follow-ups.
            function_to_run = state.quotation_project.generate_further_follow_up_questions
            prompt_config = PROMPTS["generate_further_follow_up_questions"]

        response = gr.ChatMessage(
            content="",
            metadata={
                "title": f"_{prompt_config.description}_",
                "id": 0,
                "status": "pending"
            }
        )
        yield response, ""

        next_question_task = asyncio.create_task(function_to_run())

        # Persist the updated requirements while the question is generated.
        # NOTE(review): this is a synchronous DB call and blocks the event loop.
        if state.quotation_project.session_id:
            update_session_project_requirements(
                state.quotation_project.session_id,
                state.quotation_project.project_detail
            )

        # Stream "thinking" bullets until the generator finishes; an empty
        # thoughts list previously raised ZeroDivisionError on the modulo.
        thoughts = list(prompt_config.thoughts or [])
        accumulated_thoughts = ""
        thought_index = 0
        while not next_question_task.done():
            await asyncio.sleep(SLEEP_TIME.TWO_SEC.value)
            if thoughts:
                accumulated_thoughts += f"- {thoughts[thought_index % len(thoughts)]}\n\n"
                thought_index += 1
                response.content = accumulated_thoughts.strip()
            yield response, ""

        next_question = await next_question_task
        print(f"[DEBUG] Next question: {next_question}")
        user_input_template = await clean_sample_answers(next_question)

        response.metadata["status"] = "done"
        response.metadata["duration"] = time.time() - start_time
        yield response, ""

        response_list = [
            response,
            gr.ChatMessage(content=next_question)
        ]
        yield response_list, user_input_template
    except Exception as e:
        print(f"[DEBUG] Error in async_process_response: {str(e)}")
        print(f"[DEBUG] Error type: {type(e)}")
        yield history, "Error in generating follow up questions"
    finally:
        # Cancel a still-running generation when we failed or the consumer
        # closed the stream early (GeneratorExit is not an Exception, so only
        # `finally` sees it). Otherwise the task would keep running unobserved.
        if next_question_task is not None and not next_question_task.done():
            next_question_task.cancel()
| #TODO: Create calculate mandays for general and mvp | |
def calculate_mandays_and_costs(generated_results=None, day_rate=1500, days_per_month=30):
    """Sum man-days over generated results and derive total cost and duration.

    Each item is expected to look like
    ``{'function_name': name, 'result': {name: [record, ...]}}``. Items without
    a dict ``'result'`` containing their ``function_name`` are ignored.
    Records without ``'mandays'`` are ignored, and values that cannot be
    converted to float are printed and skipped.

    Args:
        generated_results: Iterable of result dicts. None or empty gives zero totals.
        day_rate: Cost per man-day (default 1500, the previously hard-coded rate).
        days_per_month: Man-days per month used for the duration estimate (default 30).

    Returns:
        tuple: ``(total_mandays, total_cost, estimated_months)``, or
        ``(None, None, None)`` if an unexpected error occurs.
    """
    try:
        total_mandays = 0
        for result in generated_results or ():
            if 'result' not in result:
                continue
            result_content = result['result']
            if not isinstance(result_content, dict):
                continue
            function_name = result.get('function_name', '')
            if function_name not in result_content:
                continue
            for record in result_content[function_name]:
                if 'mandays' not in record:
                    continue
                try:
                    total_mandays += float(record['mandays'])
                except (ValueError, TypeError):
                    print(f"Invalid mandays value in record: {record['mandays']}")
        total_cost = day_rate * total_mandays
        estimated_months = total_mandays / days_per_month
        return (total_mandays, total_cost, estimated_months)
    except Exception as e:
        print(f"Error calculating mandays and costs: {str(e)}")
        return tuple([None] * 3)
def calculate_mvp_mandays_and_costs(generated_mvp_results, day_rate=1500, days_per_month=30):
    """Sum MVP man-days over all sections of the generated MVP results.

    Each item is expected to look like
    ``{'result': {section_name: [record, ...], ...}}``. Sections whose value
    is not a list, and items without a dict ``'result'``, are ignored. A
    record's ``'mandays'`` defaults to 0. A record whose value cannot be
    converted to float is printed and skipped; previously a single bad
    record zeroed the whole estimate.

    Args:
        generated_mvp_results: Iterable of MVP result dicts. None or empty gives zero totals.
        day_rate: Cost per man-day (default 1500, the previously hard-coded rate).
        days_per_month: Man-days per month used for the duration estimate (default 30).

    Returns:
        tuple: ``(total_mvp_mandays, total_mvp_cost, estimated_mvp_months)``,
        or ``(0, 0, 0)`` if an unexpected error occurs.
    """
    try:
        total_mvp_mandays = 0
        for result in generated_mvp_results or ():
            if 'result' not in result:
                continue
            result_content = result['result']
            if not isinstance(result_content, dict):
                continue
            for section_name, section_data in result_content.items():
                if not isinstance(section_data, list):
                    continue
                for record in section_data:
                    try:
                        total_mvp_mandays += float(record.get('mandays', 0))
                    except (AttributeError, ValueError, TypeError):
                        print(f"Invalid MVP mandays value in section {section_name!r}: {record!r}")
        total_mvp_cost = day_rate * total_mvp_mandays
        estimated_mvp_months = total_mvp_mandays / days_per_month
        return (total_mvp_mandays, total_mvp_cost, estimated_mvp_months)
    except Exception as e:
        print(f"Error calculating MVP mandays and costs: {str(e)}")
        return 0, 0, 0
def create_new_session(base_project_name='Page'):
    """Create a new session row, link it to a base project, and return its id.

    Both inserts run before the single ``commit``. If either fails, nothing
    is committed and the connection is closed.

    Args:
        base_project_name: Base project to associate with the session
            (default 'Page', the previously hard-coded value).

    Returns:
        The new ``session_id``, or None if anything failed (the error is printed).
    """
    conn = None
    try:
        conn = get_db_connection()
        with conn.cursor() as cur:
            cur.execute("""
                INSERT INTO sessions (start_time)
                VALUES (CURRENT_TIMESTAMP)
                RETURNING session_id
            """)
            session_id = cur.fetchone()[0]
            cur.execute(
                """
                INSERT INTO session_base_project (session_id, base_project_name)
                VALUES (%s, %s)
                """,
                (session_id, base_project_name),
            )
        conn.commit()
        return session_id
    except Exception as e:
        print(f"Error creating new session: {str(e)}")
        return None
    finally:
        if conn is not None:
            conn.close()
def start_chat():
    """Start a new conversation and return its opening prompt.

    Creates a new DB session, resets the shared quotation project and binds
    the new session id to it.

    Returns:
        tuple[str, str]: ``(initial_prompt, "Current Session: <session_id>")``.
        The id is ``None`` if the session could not be created.
    """
    session_id = create_new_session()
    # Clear any previous conversation before binding the new session.
    state.quotation_project.reset_project()
    state.quotation_project.session_id = session_id
    initial_prompt = PROMPTS["client_initial_question"].prompt
    # The original also called get_project_state() here, but its results were
    # never used; the call is dropped.
    return initial_prompt, f"Current Session: {session_id}"
def get_project_state():
    """Summarise the shared quotation project for display.

    Returns:
        tuple[str, str]: ``(status, requirements_table)``.
            - ``status`` has four lines: the session id and whether the
              rubric, components and requirements are loaded.
            - ``requirements_table`` holds one "#Requirement N" section per
              project detail, or "" when there are none.
    """
    project = state.quotation_project
    status = "\n".join((
        f"Session ID: {project.session_id}",
        f"Rubric Loaded: {bool(project.rubric)}",
        f"Components Loaded: {bool(project.component_list)}",
        f"Requirements Loaded: {bool(project.project_detail)}",
    ))

    details = project.project_detail
    if not details:
        return status, ""

    print(f"\n\nrequirements : {type(details)}")
    requirements_table = "".join(
        f"\n_____________\n#Requirement {number}\n {item}"
        for number, item in enumerate(details, start=1)
    )
    return status, requirements_table
def fetch_session(session_id):
    """Load a stored session into the shared quotation project state.

    Reads the session row, replaces the project details with the stored
    requirements, then loads the rubric, section names, prompt configuration
    and latest components.

    Args:
        session_id: Primary key of the ``sessions`` row to load.

    Returns:
        tuple: ``(status, requirements_table, message, session_label)``.
        On success the first two values come from ``get_project_state()``
        and ``session_label`` is ``"Current Session: <id>"``. On any failure
        they are empty strings and ``session_label`` is ``no_active_session``.
    """
    try:
        # 1. Fetch the session row; the connection is released even if the query fails.
        conn = get_db_connection()
        try:
            with conn.cursor(cursor_factory=RealDictCursor) as cur:
                cur.execute(
                    """
                    SELECT project_requirement, start_time
                    FROM sessions
                    WHERE session_id = %s
                    """,
                    (session_id,),
                )
                session = cur.fetchone()
        finally:
            conn.close()
        print(session)

        if not session:
            return "", "", f"Session {session_id} not found", no_active_session

        # 2. Bind the session and restore its stored requirements.
        state.quotation_project.session_id = session_id
        raw_requirements = session['project_requirement']
        if raw_requirements:
            try:
                if isinstance(raw_requirements, str):
                    try:
                        requirements = json.loads(raw_requirements)
                    except json.JSONDecodeError:
                        # Not JSON: treat it as newline-separated plain text.
                        requirements = raw_requirements.split('\n')
                else:
                    requirements = raw_requirements
                if isinstance(requirements, str):
                    # A JSON string scalar would otherwise be iterated character by character.
                    requirements = [requirements]
                state.quotation_project.project_detail = []
                for requirement in requirements:
                    state.quotation_project.add_project_detail(requirement.strip())
            except Exception as e:
                return "", "", f"Error processing project requirements in session {session_id}: {str(e)}", no_active_session

        # 3. Rubric and section names. The helper returns an error dict on
        # failure, which must not be unpacked as a (names, rubric) pair.
        rubric_result = get_section_name_and_rubric_list()
        if isinstance(rubric_result, dict):
            return "", "", f"Error loading rubric for session {session_id}: {rubric_result.get('message')}", no_active_session
        section_name_list, rubric_list = rubric_result
        state.quotation_project.set_rubric(rubric_list)
        state.quotation_project.set_rubric_section_names(section_name_list)

        print("in fetch_session: loading config from db")
        load_msg = state.quotation_project.load_config_from_db()

        # 4. Components (same error-dict convention as the rubric helper).
        component_list = get_latest_components()
        if isinstance(component_list, dict):
            return "", "", f"Error loading components for session {session_id}: {component_list.get('message')}", no_active_session
        state.quotation_project.set_component_list(component_list)

        return (*get_project_state(), f"Successfully loaded session {session_id} with all data, {load_msg}", f"Current Session: {session_id}")
    except Exception as e:
        return "", "", f"Error fetching session: {str(e)}", no_active_session
def log_prompt_execution(step_name, sub_step_name, prompt_text):
    """Record one prompt run in the ``prompts`` table.

    The row is tied to ``state.quotation_project.session_id``, keyed by a
    freshly generated UUID4 string and timestamped with the local time.
    No duplicate check is performed.

    Returns:
        str | None: The new prompt_id, or None if the insert failed (the
        error is printed, not raised).
    """
    created_at = datetime.datetime.now()
    session_id = state.quotation_project.session_id
    prompt_id = str(uuid.uuid4())
    insert_sql = (
        "\n"
        "INSERT INTO prompts (session_id, prompt_id, step_name, sub_step_name, prompt, created_at)\n"
        "VALUES (%s, %s, %s, %s, %s, %s)\n"
    )
    row_values = (session_id, prompt_id, step_name, sub_step_name, prompt_text, created_at)

    connection = None
    try:
        connection = get_db_connection()
        cursor = connection.cursor()
        cursor.execute(insert_sql, row_values)
        connection.commit()
    except Exception as e:
        print(f"Error logging prompt execution: {str(e)}")
        return None
    else:
        return prompt_id
    finally:
        if connection:
            connection.close()
def log_prompt_execution_output(prompt_id, output):
    """Store a prompt's output in the ``outputs`` table under a new UUID4 id.

    Args:
        prompt_id: Id returned by log_prompt_execution. Stored as given.
        output: The output value to store.

    Returns:
        str | None: The new output_id, or None if the insert failed (the
        error is printed, not raised).
    """
    created_at = datetime.datetime.now()
    output_id = str(uuid.uuid4())
    # Bound before the try so `finally` is safe when connecting fails; before,
    # an UnboundLocalError replaced the intended `return None`.
    conn = None
    try:
        conn = get_db_connection()
        with conn.cursor() as cur:
            cur.execute(
                """
                INSERT INTO outputs (output_id, prompt_id, output, created_at)
                VALUES (%s, %s, %s, %s)
                """,
                (output_id, prompt_id, output, created_at),
            )
        conn.commit()
        return output_id
    except Exception as e:
        print(f"Error logging prompt execution output: {str(e)}")
        return None
    finally:
        if conn is not None:
            conn.close()
def update_prompt_execution_output(output_id, output):
    """Overwrite the stored output, and refresh ``created_at``, for *output_id*.

    NumPy integer and floating scalars are first converted to built-in
    ``int``/``float``. Presumably this is because the DB driver cannot adapt
    NumPy scalar types (confirm if changing).

    Returns:
        None. Errors are printed, not raised.
    """
    created_at = datetime.datetime.now()
    # np.floating also covers float32/float16; previously only float64
    # (a float subclass) was converted.
    if isinstance(output, np.integer):
        output = int(output)
    elif isinstance(output, (np.floating, float)):
        output = float(output)

    # Bound before the try so `finally` is safe when connecting fails.
    conn = None
    try:
        conn = get_db_connection()
        with conn.cursor() as cur:
            cur.execute(
                """
                UPDATE outputs
                SET output = %s, created_at = %s
                WHERE output_id = %s
                """,
                (output, created_at, output_id),
            )
        conn.commit()
        print(f"Successfully updated output for {output_id}")
    except Exception as e:
        print(f"Error updating prompt execution output: {str(e)}")
        return None
    finally:
        if conn is not None:
            conn.close()
def log_prompt(prompt_name: str, prompt_description: str, prompt: str, output: object) -> "str | None":
    """Record a prompt execution and its output. Never raises.

    Args:
        prompt_name: Stored as the prompt's step name.
        prompt_description: Stored as the prompt's sub-step name.
        prompt: The prompt text that was sent.
        output: The output to store against the prompt.

    Returns:
        The output_id of the stored output row, or None if either insert
        failed (errors are printed).
    """
    try:
        prompt_id = log_prompt_execution(prompt_name, prompt_description, prompt)
        if prompt_id is None:
            # The prompt row was not stored (its error was already printed);
            # an output row without it would be orphaned.
            print(f"Error logging {prompt_name} generation: prompt row was not stored")
            return None
        output_id = log_prompt_execution_output(prompt_id, output)
        print(f"Succesfully logged {prompt_name}: as {output_id}")
        return output_id
    except Exception as e:
        print(f"Error logging {prompt_name} generation: {str(e)}")
        return None