# APIs / app.py
# hfariborzi's picture
# Update app.py
# 2766617 verified
import os
import logging
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import openai
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
import requests
# Read the .env file first so the lookups below can see its values.
load_dotenv()

# Key sent as the bearer token on OpenAI requests (None when the variable
# OPENAI_API_KEY is not defined).
openai.api_key = os.environ.get("OPENAI_API_KEY")

# Absolute paths of the two input data files.
predefined_constructs_file = "/app/predefined_constructs.txt"
full_matrix_path = "/app/no_na_matrix.csv"

# Root logger: DEBUG and above, written to app.log and to the console stream.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.FileHandler("app.log"), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

# The ASGI application object that the route decorators below attach to.
app = FastAPI()
# Root endpoint for health check
@app.get("/")
def read_root():
    """Return a fixed status payload so callers can confirm the API responds."""
    status_payload = dict(message="API is up and running!")
    return status_payload
# Request schema
# Body model of the POST /infer endpoint; the handler reads only `query`.
class InferenceRequest(BaseModel):
    # Free-text user question, passed to run_inference_pipeline as the query.
    query: str
# Functions for inference
def read_predefined_constructs(file_path):
    """Return the whole constructs file as a single string.

    The file is decoded as UTF-8 explicitly instead of with the platform's
    default encoding, so non-ASCII characters (such as the en dash handled
    elsewhere in this module) read the same on every host.

    Args:
        file_path: Path of the text file to read.

    Returns:
        The complete text of the file.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        return handle.read()
def call_openai_chat(model, messages, timeout=120):
    """POST a chat-completion request to the OpenAI REST API and return the reply text.

    Args:
        model: Model identifier placed in the request body.
        messages: Chat messages (dicts with "role" and "content") sent as the
            conversation.
        timeout: Seconds to wait for the server, forwarded to
            ``requests.post``. Without a timeout ``requests`` waits
            indefinitely, so a stalled connection would block the calling
            request handler forever.

    Returns:
        The ``content`` string of the first choice in the response.

    Raises:
        requests.exceptions.Timeout: If the server does not answer in time.
        requests.exceptions.HTTPError: If the API answers with an error status.
    """
    url = "https://api.openai.com/v1/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai.api_key}",
    }
    data = {"model": model, "messages": messages}
    response = requests.post(url, headers=headers, json=data, timeout=timeout)
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]
def clean_predefined_constructs_and_matrix(predefined_constructs_path, full_matrix_path):
    """Load the construct names and the correlation matrix and reconcile them.

    Construct names are taken from the first tab-separated field of every
    non-blank line, lower-cased; a line whose name is exactly "variable"
    (the header) is skipped. In each name '-', '–', '/', ',' and '.' become
    a space and parentheses are removed. Matrix row and column labels are
    stripped, lower-cased and reduced to the characters ``[a-z0-9 ]``; rows
    and columns whose label is not a known construct are dropped.

    NOTE: the two normalizations differ. A construct "self-efficacy" becomes
    "self efficacy", while the same matrix label becomes "selfefficacy", so
    such a name only matches when the matrix already uses the spaced form.

    Args:
        predefined_constructs_path: Path of the tab-separated constructs file
            (read as UTF-8).
        full_matrix_path: Path of the CSV correlation matrix whose first
            column holds the row labels.

    Returns:
        Tuple ``(predefined_constructs, full_matrix)``: the cleaned list of
        construct names and the filtered DataFrame.

    Raises:
        ValueError: If any construct has no matching row or column left in
            the matrix.
    """
    # One translation pass replaces the former chain of str.replace calls:
    # separators map to a space, parentheses map to nothing.
    punctuation_map = str.maketrans({
        "-": " ",
        "–": " ",
        "/": " ",
        ",": " ",
        ".": " ",
        "(": None,
        ")": None,
    })

    # Explicit UTF-8: the names may contain '–', which must not depend on the
    # host's default locale encoding.
    with open(predefined_constructs_path, "r", encoding="utf-8") as f:
        raw_names = [
            line.split('\t')[0].strip().lower()
            for line in f
            if line.strip()
        ]
    predefined_constructs = [
        name.translate(punctuation_map).strip()
        for name in raw_names
        if name != "variable"
    ]

    full_matrix = pd.read_csv(full_matrix_path, index_col=0)
    full_matrix.index = full_matrix.index.str.strip().str.lower().str.replace('[^a-z0-9 ]', '', regex=True)
    full_matrix.columns = full_matrix.columns.str.strip().str.lower().str.replace('[^a-z0-9 ]', '', regex=True)

    # Remove every row/column that does not correspond to a known construct.
    valid_constructs = set(predefined_constructs)
    invalid_rows = set(full_matrix.index) - valid_constructs
    invalid_columns = set(full_matrix.columns) - valid_constructs
    full_matrix = full_matrix.drop(
        index=list(invalid_rows),
        columns=list(invalid_columns),
        errors='ignore',
    )

    # Every construct must be present on both axes, otherwise later label
    # lookups on the matrix would fail.
    missing_constructs = [
        c for c in predefined_constructs
        if c not in full_matrix.index or c not in full_matrix.columns
    ]
    if missing_constructs:
        raise ValueError(f"Missing constructs in the correlation matrix: {missing_constructs}")
    return predefined_constructs, full_matrix
def construct_analyzer(user_prompt, predefined_constructs):
    """Ask the chat model to choose constructs and a dependent variable.

    The model is prompted to answer as ``var1|var2|...|var10;dependent_var``.
    The reply is validated and split; each name is stripped of surrounding
    whitespace, empty names and repeats are discarded, and the dependent
    variable is appended to the list when the model did not include it.

    Args:
        user_prompt: The user's free-text question.
        predefined_constructs: Iterable of construct names offered to the
            model, one per line of the prompt.

    Returns:
        Tuple ``(constructs, dependent_variable)``: a list of construct names
        (always containing the dependent variable) and the dependent
        variable's name.

    Raises:
        ValueError: If the reply does not contain exactly one ';' separator
            or names no dependent variable.
    """
    predefined_constructs_prompt = "\n".join(predefined_constructs)
    prompt_text = (
        f"Here is the user's prompt: '{user_prompt}'.\n\n"
        "Your role is to identify 10 relevant constructs from the provided list. "
        "Pick one variable as the dependent variable. Format output as:\n"
        "var1|var2|...|var10;dependent_var\n\n"
        f"{predefined_constructs_prompt}"
    )
    messages = [
        {"role": "system", "content": "You are a construct analyzer."},
        {"role": "user", "content": prompt_text},
    ]
    reply = call_openai_chat(model="gpt-4", messages=messages).strip()

    # Validate explicitly: a bare tuple-unpack of split(';') fails with an
    # opaque "values to unpack" message when the model ignores the format.
    if reply.count(';') != 1:
        raise ValueError(
            "Unexpected construct analyzer response; expected "
            f"'var1|var2|...;dependent_var' but got: {reply!r}"
        )
    constructs_part, _, dependent_part = reply.partition(';')

    dependent_variable = dependent_part.strip()
    if not dependent_variable:
        raise ValueError(
            f"Construct analyzer response names no dependent variable: {reply!r}"
        )

    # Models often pad the separators ("a | b"); padded or repeated names
    # would not match the matrix labels used downstream.
    constructs = list(dict.fromkeys(
        name.strip() for name in constructs_part.split('|') if name.strip()
    ))
    if dependent_variable not in constructs:
        constructs.append(dependent_variable)
    return constructs, dependent_variable
def regression_analyst(correlation_matrix, dependent_variable):
    """Fit a linear regression on data simulated from a correlation matrix.

    The dependent variable is matched against the matrix columns after
    stripping and lower-casing both sides. 1000 rows are then drawn from a
    zero-mean multivariate normal whose covariance is the given matrix, and
    a LinearRegression is fitted with every other column as a predictor.
    The draw uses NumPy's global random state, so results vary between calls.

    Args:
        correlation_matrix: Square DataFrame with string column labels.
        dependent_variable: Name of the column to predict.

    Returns:
        Tuple ``(independent_vars, beta_weights)``: the predictor column
        names and the fitted coefficients in the same order.

    Raises:
        KeyError: If no column matches the dependent variable; the message
            lists up to three close matches.
    """
    wanted = dependent_variable.strip().lower()
    all_columns = correlation_matrix.columns
    lookup_names = [label.strip().lower() for label in all_columns]

    # Guard clause: unknown target, report near misses to help the caller.
    if wanted not in lookup_names:
        from difflib import get_close_matches
        suggestions = get_close_matches(wanted, lookup_names, n=3, cutoff=0.6)
        raise KeyError(
            f"Dependent variable '{dependent_variable}' not found. Suggestions: {suggestions}"
        )

    # Recover the label exactly as it is spelled in the matrix.
    original_dependent = all_columns[lookup_names.index(wanted)]
    independent_vars = [label for label in all_columns if label != original_dependent]

    variable_count = len(correlation_matrix)
    samples = np.random.multivariate_normal(
        mean=np.zeros(variable_count),
        cov=correlation_matrix.to_numpy(),
        size=1000,
    )
    simulated = pd.DataFrame(samples, columns=all_columns)

    model = LinearRegression()
    model.fit(simulated[independent_vars], simulated[original_dependent])
    return independent_vars, model.coef_
def generate_inference(user_query, equation, independent_vars, dependent_variable, beta_weights):
    """Ask the chat model to interpret the regression results.

    Args:
        user_query: The user's original question, quoted in the prompt.
        equation: Text form of the regression equation.
        independent_vars: Predictor names, paired positionally with
            ``beta_weights``.
        dependent_variable: Accepted for signature compatibility; the prompt
            built here does not reference it directly.
        beta_weights: Coefficients, rounded to four decimals in the prompt.

    Returns:
        The model's reply with surrounding whitespace removed.
    """
    weight_lines = []
    for name, weight in zip(independent_vars, beta_weights):
        weight_lines.append(f"{name}: {round(weight, 4)}")
    beta_details = "\n".join(weight_lines)

    prompt_parts = [
        f"User query: '{user_query}'\n",
        f"Regression Equation: {equation}\n",
        f"Variables and Beta Weights:\n{beta_details}\n\n",
        "Provide actionable insights based on this analysis.",
    ]
    prompt = "".join(prompt_parts)

    system_message = {"role": "system", "content": "You are a skilled analyst interpreting regression results."}
    user_message = {"role": "user", "content": prompt}
    reply = call_openai_chat(model="gpt-4", messages=[system_message, user_message])
    return reply.strip()
def run_inference_pipeline(user_query, predefined_constructs, correlation_matrix):
    """Run the full query -> constructs -> regression -> interpretation flow.

    Args:
        user_query: The user's free-text question.
        predefined_constructs: Construct names offered to the model.
        correlation_matrix: DataFrame labelled by construct on both axes.

    Returns:
        Dict with "equation" (the regression equation as text) and
        "inference" (the model's interpretation of it).

    Raises:
        KeyError: If a selected construct is not a label of the matrix.
    """
    constructs, dependent_variable = construct_analyzer(user_query, predefined_constructs)

    # construct_analyzer returns the constructs as a list; calling
    # .split('|') on that list raises AttributeError. A pipe-delimited
    # string is still accepted for backward compatibility.
    if isinstance(constructs, str):
        constructs = constructs.split('|')

    # Strip padding and drop blanks/repeats so the labels match the matrix
    # and no row or column is selected twice.
    dependent_variable = dependent_variable.strip()
    constructs_list = list(dict.fromkeys(
        name.strip() for name in constructs if name.strip()
    ))
    if dependent_variable not in constructs_list:
        constructs_list.append(dependent_variable)

    correlation_matrix_filtered = correlation_matrix.loc[constructs_list, constructs_list]
    independent_vars, beta_weights = regression_analyst(correlation_matrix_filtered, dependent_variable)
    equation = f"{dependent_variable} = " + " + ".join(
        f"{round(beta, 4)}*{var}" for beta, var in zip(beta_weights, independent_vars)
    )
    inference = generate_inference(user_query, equation, independent_vars, dependent_variable, beta_weights)
    return {"equation": equation, "inference": inference}
# API endpoint
@app.post("/infer")
def infer(request: InferenceRequest):
    """Handle POST /infer: load the data files and run the inference pipeline.

    Args:
        request: Parsed request body; only ``request.query`` is used.

    Returns:
        The dict produced by ``run_inference_pipeline``.

    Raises:
        HTTPException: Status 500 carrying the error text, for any failure
            while loading the data or running the pipeline.
    """
    try:
        predefined_constructs, cleaned_matrix = clean_predefined_constructs_and_matrix(
            predefined_constructs_file, full_matrix_path
        )
        return run_inference_pipeline(request.query, predefined_constructs, cleaned_matrix)
    except Exception as e:
        # logger.exception keeps the same message but also records the
        # traceback, which logger.error(f"...") discarded.
        logger.exception("Error during inference: %s", e)
        # Chain the cause so the original error stays attached to the
        # HTTP-level exception.
        raise HTTPException(status_code=500, detail=str(e)) from e