"""FastAPI service that turns a free-text query into a regression "inference".

Pipeline (see ``run_inference_pipeline``):
  1. ask an OpenAI chat model to pick relevant constructs and a dependent
     variable from a predefined list,
  2. slice the stored correlation matrix down to those constructs,
  3. fit a linear regression on synthetic data drawn from that matrix,
  4. ask the chat model to interpret the resulting equation.
"""

import logging
import os
from difflib import get_close_matches
from typing import List, Sequence, Tuple

import numpy as np
import openai
import pandas as pd
import requests
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sklearn.linear_model import LinearRegression

# Load environment variables from the .env file
load_dotenv()

# Set the OpenAI API key
openai.api_key = os.getenv("OPENAI_API_KEY")

# File paths
predefined_constructs_file = "/app/predefined_constructs.txt"
full_matrix_path = "/app/no_na_matrix.csv"

# Logging configuration
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

# FastAPI application
app = FastAPI()

# Single-pass character clean-up applied to construct names: these characters
# become a space, parentheses are removed outright.
_CONSTRUCT_CLEANUP_TABLE = str.maketrans(
    {"-": " ", "–": " ", "(": "", ")": "", "/": " ", ",": " ", ".": " "}
)


# Root endpoint for health check
@app.get("/")
def read_root():
    """Return a static payload so callers can verify the API is reachable."""
    return {"message": "API is up and running!"}


# Request schema
class InferenceRequest(BaseModel):
    """Body of ``POST /infer``."""

    # Free-text question forwarded to the construct analyzer.
    query: str


# Functions for inference
def read_predefined_constructs(file_path):
    """Return the full text content of the file at *file_path*."""
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()


def call_openai_chat(model, messages, timeout=60):
    """POST a chat-completion request and return the first choice's text.

    Args:
        model: Model name sent in the request payload.
        messages: List of ``{"role": ..., "content": ...}`` dicts.
        timeout: Seconds passed to ``requests.post``; without it a stalled
            connection would block the calling worker indefinitely.

    Returns:
        The ``content`` string of the first returned choice.

    Raises:
        requests.HTTPError: If the API answers with an error status.
        requests.Timeout: If no response arrives within *timeout*.
    """
    url = "https://api.openai.com/v1/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai.api_key}",
    }
    data = {"model": model, "messages": messages}
    response = requests.post(url, headers=headers, json=data, timeout=timeout)
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]


def clean_predefined_constructs_and_matrix(predefined_constructs_path, full_matrix_path):
    """Load the construct list and the matrix, and align their labels.

    The first tab-separated field of every non-blank line of the constructs
    file is taken as a construct name (the literal header ``variable`` is
    skipped). Matrix row/column labels are lower-cased and stripped of every
    character outside ``[a-z0-9 ]``; rows and columns whose label is not a
    known construct are dropped.

    Returns:
        ``(predefined_constructs, full_matrix)`` - the cleaned list of names
        and the filtered ``DataFrame``.

    Raises:
        ValueError: If any predefined construct is missing from the matrix
            rows or columns after cleaning.
    """
    with open(predefined_constructs_path, "r", encoding="utf-8") as f:
        raw_names = [
            line.split("\t")[0].strip().lower()
            for line in f
            if line.strip()
        ]
    predefined_constructs = [
        name.translate(_CONSTRUCT_CLEANUP_TABLE).strip()
        for name in raw_names
        if name != "variable"
    ]

    full_matrix = pd.read_csv(full_matrix_path, index_col=0)
    full_matrix.index = (
        full_matrix.index.str.strip().str.lower()
        .str.replace(r"[^a-z0-9 ]", "", regex=True)
    )
    full_matrix.columns = (
        full_matrix.columns.str.strip().str.lower()
        .str.replace(r"[^a-z0-9 ]", "", regex=True)
    )

    valid_constructs = set(predefined_constructs)
    invalid_rows = sorted(set(full_matrix.index) - valid_constructs)
    invalid_columns = sorted(set(full_matrix.columns) - valid_constructs)
    full_matrix = full_matrix.drop(index=invalid_rows, errors="ignore")
    full_matrix = full_matrix.drop(columns=invalid_columns, errors="ignore")

    missing_constructs = [
        c for c in predefined_constructs
        if c not in full_matrix.index or c not in full_matrix.columns
    ]
    if missing_constructs:
        raise ValueError(f"Missing constructs in the correlation matrix: {missing_constructs}")
    return predefined_constructs, full_matrix


def _canonical_name(name, lookup):
    """Map a model-returned name onto the predefined spelling when it matches."""
    stripped = name.strip()
    return lookup.get(stripped.lower(), stripped)


def construct_analyzer(user_prompt, predefined_constructs):
    """Ask the chat model for relevant constructs and a dependent variable.

    The model is instructed to answer as ``var1|var2|...|var10;dependent_var``.
    Returned names are whitespace-stripped and matched case-insensitively
    against *predefined_constructs* so they can be used directly as matrix
    labels; names that match nothing are returned stripped but otherwise
    unchanged.

    Returns:
        ``(constructs, dependent_variable)`` where *constructs* is a list of
        unique names that always contains *dependent_variable*.

    Raises:
        ValueError: If the reply does not contain exactly one ``;`` separator
            or the dependent variable is empty.
    """
    predefined_constructs_prompt = "\n".join(predefined_constructs)
    prompt_text = (
        f"Here is the user's prompt: '{user_prompt}'.\n\n"
        "Your role is to identify 10 relevant constructs from the provided list. "
        "Pick one variable as the dependent variable. Format output as:\n"
        "var1|var2|...|var10;dependent_var\n\n"
        f"{predefined_constructs_prompt}"
    )
    messages = [
        {"role": "system", "content": "You are a construct analyzer."},
        {"role": "user", "content": prompt_text},
    ]
    constructs_with_dependent = call_openai_chat(model="gpt-4", messages=messages).strip()

    parts = constructs_with_dependent.split(";")
    if len(parts) != 2:
        raise ValueError(
            "Unexpected construct analyzer response (expected "
            f"'var1|var2|...;dependent_var'): {constructs_with_dependent!r}"
        )
    constructs_list, dependent_raw = parts

    lookup = {c.strip().lower(): c for c in predefined_constructs}
    dependent_variable = _canonical_name(dependent_raw, lookup)
    if not dependent_variable:
        raise ValueError(
            f"Construct analyzer returned no dependent variable: {constructs_with_dependent!r}"
        )

    # dict.fromkeys removes duplicates while keeping the model's ordering.
    constructs = list(dict.fromkeys(
        _canonical_name(item, lookup)
        for item in constructs_list.split("|")
        if item.strip()
    ))
    if dependent_variable not in constructs:
        constructs.append(dependent_variable)
    return constructs, dependent_variable


def regression_analyst(correlation_matrix, dependent_variable, n_samples=1000):
    """Estimate regression weights from a correlation matrix.

    Draws *n_samples* rows from a zero-mean multivariate normal whose
    covariance is *correlation_matrix*, then fits ordinary least squares of
    the dependent column on all remaining columns. Because the data is
    sampled, results vary between calls unless NumPy's global seed is fixed.

    Args:
        correlation_matrix: Square ``DataFrame`` whose columns name the
            variables.
        dependent_variable: Column to predict (matched after strip/lower).
        n_samples: Number of synthetic observations to draw.

    Returns:
        ``(independent_vars, beta_weights)`` - column names and the matching
        fitted coefficients.

    Raises:
        KeyError: If the dependent variable is not a column; the message
            lists close matches.
    """
    normalized_dependent = dependent_variable.strip().lower()
    normalized_columns = [col.strip().lower() for col in correlation_matrix.columns]
    if normalized_dependent not in normalized_columns:
        suggestions = get_close_matches(normalized_dependent, normalized_columns, n=3, cutoff=0.6)
        raise KeyError(
            f"Dependent variable '{dependent_variable}' not found. Suggestions: {suggestions}"
        )

    original_dependent = correlation_matrix.columns[normalized_columns.index(normalized_dependent)]
    independent_vars = [var for var in correlation_matrix.columns if var != original_dependent]

    synthetic_data = np.random.multivariate_normal(
        mean=np.zeros(len(correlation_matrix)),
        cov=correlation_matrix.to_numpy(),
        size=n_samples,
    )
    synthetic_data = pd.DataFrame(synthetic_data, columns=correlation_matrix.columns)

    X = synthetic_data[independent_vars]
    y = synthetic_data[original_dependent]
    model = LinearRegression()
    model.fit(X, y)
    beta_weights = model.coef_
    return independent_vars, beta_weights


def generate_inference(user_query, equation, independent_vars, dependent_variable, beta_weights):
    """Ask the chat model to interpret the fitted regression for the user.

    *dependent_variable* is accepted for interface compatibility; it is not
    used directly because it is already part of *equation*.

    Returns:
        The model's reply with surrounding whitespace removed.
    """
    beta_details = "\n".join(
        f"{var}: {round(beta, 4)}" for var, beta in zip(independent_vars, beta_weights)
    )
    prompt = (
        f"User query: '{user_query}'\n"
        f"Regression Equation: {equation}\n"
        f"Variables and Beta Weights:\n{beta_details}\n\n"
        "Provide actionable insights based on this analysis."
    )
    messages = [
        {"role": "system", "content": "You are a skilled analyst interpreting regression results."},
        {"role": "user", "content": prompt},
    ]
    return call_openai_chat(model="gpt-4", messages=messages).strip()


def run_inference_pipeline(user_query, predefined_constructs, correlation_matrix):
    """Run construct selection, regression and interpretation for one query.

    Returns:
        ``{"equation": <str>, "inference": <str>}``.

    Raises:
        KeyError: If the chosen dependent variable is not in the matrix.
        ValueError: If no independent construct from the model's answer is
            present in the matrix.
    """
    # construct_analyzer already returns a list; the previous code called
    # .split('|') on it, which raised AttributeError on every request.
    constructs, dependent_variable = construct_analyzer(user_query, predefined_constructs)
    constructs_list = list(dict.fromkeys(constructs))
    if dependent_variable not in constructs_list:
        constructs_list.append(dependent_variable)

    known_rows = set(correlation_matrix.index)
    known_columns = set(correlation_matrix.columns)
    if dependent_variable not in known_rows or dependent_variable not in known_columns:
        raise KeyError(
            f"Dependent variable '{dependent_variable}' not found in the correlation matrix."
        )

    # The model may return names that are not in the matrix; .loc would raise
    # on the first one, so drop them and continue with the valid remainder.
    unknown = [
        c for c in constructs_list
        if c not in known_rows or c not in known_columns
    ]
    if unknown:
        logger.warning("Ignoring constructs absent from the correlation matrix: %s", unknown)
        constructs_list = [c for c in constructs_list if c not in unknown]
    if len(constructs_list) < 2:
        raise ValueError(
            "No valid independent constructs were returned for the regression."
        )

    correlation_matrix_filtered = correlation_matrix.loc[constructs_list, constructs_list]
    independent_vars, beta_weights = regression_analyst(correlation_matrix_filtered, dependent_variable)
    equation = f"{dependent_variable} = " + " + ".join(
        f"{round(beta, 4)}*{var}" for beta, var in zip(beta_weights, independent_vars)
    )
    inference = generate_inference(
        user_query, equation, independent_vars, dependent_variable, beta_weights
    )
    return {"equation": equation, "inference": inference}


# API endpoint
@app.post("/infer")
def infer(request: InferenceRequest):
    """Handle ``POST /infer``: load data, run the pipeline, return its result.

    Any failure is logged with its traceback and reported to the client as
    HTTP 500 with the exception text as ``detail``.
    """
    # NOTE(review): the constructs file and matrix are re-read and re-cleaned
    # on every request; consider loading them once at startup if they are
    # static - confirm with the deployment before changing.
    try:
        predefined_constructs, cleaned_matrix = clean_predefined_constructs_and_matrix(
            predefined_constructs_file, full_matrix_path
        )
        results = run_inference_pipeline(request.query, predefined_constructs, cleaned_matrix)
        return results
    except Exception as e:
        logger.exception("Error during inference: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e