Spaces:
Runtime error
Runtime error
File size: 7,233 Bytes
5e5dfdc 2766617 5e5dfdc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
import os
import logging
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import openai
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
import requests
# Load environment variables from the .env file
load_dotenv()

# Set the OpenAI API key (None if OPENAI_API_KEY is unset — requests made
# with a missing key will fail at call time, not here)
openai.api_key = os.getenv("OPENAI_API_KEY")

# File paths (container-absolute; presumably baked into the image — TODO confirm)
predefined_constructs_file = "/app/predefined_constructs.txt"
full_matrix_path = "/app/no_na_matrix.csv"

# Logging configuration: DEBUG level, mirrored to app.log and the console
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler()
    ]
)
# Module-level logger used by the /infer error handler below
logger = logging.getLogger(__name__)

# FastAPI application
app = FastAPI()
# Root endpoint for health check
@app.get("/")
def read_root():
    """Liveness probe: confirm the service is reachable."""
    status_payload = {"message": "API is up and running!"}
    return status_payload
# Request schema
class InferenceRequest(BaseModel):
    """Body schema for POST /infer."""
    query: str  # free-text analysis question from the client
# Functions for inference
def read_predefined_constructs(file_path):
    """Return the full text of the predefined-constructs file.

    Args:
        file_path: path to the constructs text file.

    Returns:
        The file contents as a single string.

    FIX: pin the encoding — the default for open() is platform-dependent,
    which can garble non-ASCII construct names (e.g. the en-dash that the
    cleaning step below explicitly handles).
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.read()
def call_openai_chat(model, messages):
    """POST a chat-completion request to the OpenAI REST API.

    Args:
        model: model identifier, e.g. "gpt-4".
        messages: list of {"role": ..., "content": ...} chat messages.

    Returns:
        The assistant message content of the first choice.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.Timeout: if the API does not answer within 60 seconds.
    """
    url = "https://api.openai.com/v1/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai.api_key}",
    }
    data = {"model": model, "messages": messages}
    # FIX: requests has no default timeout; without one a stalled connection
    # blocks this worker forever.
    response = requests.post(url, headers=headers, json=data, timeout=60)
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]
def _normalize_label(label):
    """Normalize one construct label for matching: lowercase, map
    separator punctuation to spaces, drop every other non-[a-z0-9 ]
    character, and collapse runs of whitespace."""
    label = label.strip().lower()
    # Separator-like punctuation becomes a space so e.g. "self-esteem"
    # and "self esteem" normalize to the same label.
    for sep in ('-', '–', '/', ',', '.'):
        label = label.replace(sep, ' ')
    # Everything else outside [a-z0-9 ] is dropped (parentheses, quotes, ...).
    label = ''.join(
        ch for ch in label
        if ch == ' ' or 'a' <= ch <= 'z' or '0' <= ch <= '9'
    )
    return ' '.join(label.split())


def clean_predefined_constructs_and_matrix(predefined_constructs_path, full_matrix_path):
    """Load the construct list and correlation matrix with *identical*
    label normalization, then drop matrix rows/columns not in the list.

    Args:
        predefined_constructs_path: TSV file; first column holds construct
            names, a header row named "variable" is skipped.
        full_matrix_path: CSV correlation matrix (labels in row 0 / col 0).

    Returns:
        (predefined_constructs, full_matrix): normalized label list and the
        filtered DataFrame.

    Raises:
        ValueError: if any predefined construct is absent from the matrix
        after filtering.

    BUG FIX: the constructs previously replaced '-', '–', '/', ... with
    spaces while the matrix labels went through a regex that *removed*
    non-alphanumerics, so "self-esteem" became "self esteem" on one side
    and "selfesteem" on the other and could never match. Both sides now
    share _normalize_label.
    """
    with open(predefined_constructs_path, "r", encoding="utf-8") as f:
        predefined_constructs = [
            _normalize_label(line.split('\t')[0])
            for line in f if line.strip()
        ]
    # Drop the header row and any label that normalized to nothing.
    predefined_constructs = [c for c in predefined_constructs if c and c != "variable"]

    full_matrix = pd.read_csv(full_matrix_path, index_col=0)
    full_matrix.index = [_normalize_label(label) for label in full_matrix.index]
    full_matrix.columns = [_normalize_label(label) for label in full_matrix.columns]

    valid_constructs = set(predefined_constructs)
    invalid_rows = list(set(full_matrix.index) - valid_constructs)
    invalid_columns = list(set(full_matrix.columns) - valid_constructs)
    full_matrix = full_matrix.drop(index=invalid_rows, errors='ignore')
    full_matrix = full_matrix.drop(columns=invalid_columns, errors='ignore')

    missing_constructs = [
        c for c in predefined_constructs
        if c not in full_matrix.index or c not in full_matrix.columns
    ]
    if missing_constructs:
        raise ValueError(f"Missing constructs in the correlation matrix: {missing_constructs}")
    return predefined_constructs, full_matrix
def construct_analyzer(user_prompt, predefined_constructs):
    """Ask the LLM to pick 10 relevant constructs and a dependent variable.

    Args:
        user_prompt: the user's free-text query.
        predefined_constructs: list of allowed construct names.

    Returns:
        (constructs, dependent_variable): list of construct names and the
        chosen dependent variable; the dependent variable is guaranteed to
        appear in the returned list.

    Raises:
        ValueError: if the model reply does not match 'v1|...|v10;dep'.
    """
    predefined_constructs_prompt = "\n".join(predefined_constructs)
    prompt_text = (
        f"Here is the user's prompt: '{user_prompt}'.\n\n"
        "Your role is to identify 10 relevant constructs from the provided list. "
        "Pick one variable as the dependent variable. Format output as:\n"
        "var1|var2|...|var10;dependent_var\n\n"
        f"{predefined_constructs_prompt}"
    )
    messages = [
        {"role": "system", "content": "You are a construct analyzer."},
        {"role": "user", "content": prompt_text},
    ]
    reply = call_openai_chat(model="gpt-4", messages=messages).strip()
    # FIX: validate the free-form LLM reply instead of letting a malformed
    # response surface as an opaque unpacking error, and strip stray
    # whitespace from each token so downstream matrix lookups match.
    constructs_part, sep, dependent_part = reply.partition(';')
    if not sep or not dependent_part.strip():
        raise ValueError(f"Unexpected construct-analyzer reply format: {reply!r}")
    constructs = [c.strip() for c in constructs_part.split('|') if c.strip()]
    dependent_variable = dependent_part.strip()
    if dependent_variable not in constructs:
        constructs.append(dependent_variable)
    return constructs, dependent_variable
def regression_analyst(correlation_matrix, dependent_variable):
normalized_dependent = dependent_variable.strip().lower()
normalized_columns = [col.strip().lower() for col in correlation_matrix.columns]
if normalized_dependent not in normalized_columns:
from difflib import get_close_matches
suggestions = get_close_matches(normalized_dependent, normalized_columns, n=3, cutoff=0.6)
raise KeyError(
f"Dependent variable '{dependent_variable}' not found. Suggestions: {suggestions}"
)
original_dependent = correlation_matrix.columns[normalized_columns.index(normalized_dependent)]
independent_vars = [var for var in correlation_matrix.columns if var != original_dependent]
synthetic_data = np.random.multivariate_normal(
mean=np.zeros(len(correlation_matrix)),
cov=correlation_matrix.to_numpy(),
size=1000
)
synthetic_data = pd.DataFrame(synthetic_data, columns=correlation_matrix.columns)
X = synthetic_data[independent_vars]
y = synthetic_data[original_dependent]
model = LinearRegression()
model.fit(X, y)
beta_weights = model.coef_
return independent_vars, beta_weights
def generate_inference(user_query, equation, independent_vars, dependent_variable, beta_weights):
    """Ask the LLM to turn the fitted regression into actionable insights.

    Formats the equation and per-variable beta weights into a prompt and
    returns the model's stripped reply.
    """
    weight_lines = []
    for var, beta in zip(independent_vars, beta_weights):
        weight_lines.append(f"{var}: {round(beta, 4)}")
    beta_details = "\n".join(weight_lines)
    prompt = (
        f"User query: '{user_query}'\n"
        f"Regression Equation: {equation}\n"
        f"Variables and Beta Weights:\n{beta_details}\n\n"
        "Provide actionable insights based on this analysis."
    )
    chat_messages = [
        {"role": "system", "content": "You are a skilled analyst interpreting regression results."},
        {"role": "user", "content": prompt},
    ]
    reply = call_openai_chat(model="gpt-4", messages=chat_messages)
    return reply.strip()
def run_inference_pipeline(user_query, predefined_constructs, correlation_matrix):
    """End-to-end pipeline: select constructs, regress, interpret.

    Args:
        user_query: the user's free-text question.
        predefined_constructs: normalized construct names.
        correlation_matrix: cleaned correlation DataFrame.

    Returns:
        {"equation": <regression equation string>, "inference": <LLM text>}.
    """
    # BUG FIX: construct_analyzer returns a *list* of constructs (with the
    # dependent variable already appended); the old code called .split('|')
    # on that list, raising AttributeError on every request.
    constructs_list, dependent_variable = construct_analyzer(user_query, predefined_constructs)
    if dependent_variable not in constructs_list:
        constructs_list.append(dependent_variable)
    correlation_matrix_filtered = correlation_matrix.loc[constructs_list, constructs_list]
    independent_vars, beta_weights = regression_analyst(correlation_matrix_filtered, dependent_variable)
    equation = f"{dependent_variable} = " + " + ".join(
        f"{round(beta, 4)}*{var}" for beta, var in zip(beta_weights, independent_vars)
    )
    inference = generate_inference(user_query, equation, independent_vars, dependent_variable, beta_weights)
    return {"equation": equation, "inference": inference}
# API endpoint
@app.post("/infer")
def infer(request: InferenceRequest):
    """Run the full inference pipeline for a user query.

    Returns {"equation": ..., "inference": ...}; any pipeline failure is
    logged with its traceback and surfaced as HTTP 500.
    """
    try:
        predefined_constructs, cleaned_matrix = clean_predefined_constructs_and_matrix(
            predefined_constructs_file, full_matrix_path
        )
        return run_inference_pipeline(request.query, predefined_constructs, cleaned_matrix)
    except HTTPException:
        # Let deliberate HTTP errors pass through without re-wrapping.
        raise
    except Exception as e:
        # FIX: logger.exception records the traceback (logger.error with an
        # f-string discarded it), which is essential for diagnosing 500s.
        logger.exception("Error during inference: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
|