# FastAPI service: selects psychological constructs for a user query via the
# OpenAI API, fits a regression on a predefined correlation matrix, and
# returns natural-language insights.
import logging
import os
import re

import numpy as np
import openai
import pandas as pd
import requests
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sklearn.linear_model import LinearRegression
# Load environment variables from the .env file (e.g. OPENAI_API_KEY).
load_dotenv()

# Configure the OpenAI SDK key; os.getenv returns None when the variable is
# unset, in which case API calls will fail with an auth error.
openai.api_key = os.getenv("OPENAI_API_KEY")

# Input data locations (paths baked into the container image).
predefined_constructs_file = "/app/predefined_constructs.txt"
full_matrix_path = "/app/no_na_matrix.csv"
# Log DEBUG and above both to a file and to the console.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[
        logging.FileHandler("app.log"),
        logging.StreamHandler(),
    ],
)
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
# FastAPI application
app = FastAPI()


# NOTE(review): the original paste had no route decorator, so this handler was
# never registered with the app; restored per the "Root endpoint" comment.
@app.get("/")
def read_root():
    """Health-check endpoint confirming the service is reachable."""
    return {"message": "API is up and running!"}
# Request schema
class InferenceRequest(BaseModel):
    """Request body for the inference endpoint."""

    # Free-text user query describing the analysis the caller wants.
    query: str
def read_predefined_constructs(file_path):
    """Return the full text of the predefined-constructs file.

    Args:
        file_path: Path to a UTF-8 text file of construct names.

    Returns:
        The file contents as a single string.
    """
    # Explicit encoding avoids platform-dependent default codecs.
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()
def call_openai_chat(model, messages):
    """POST a chat-completion request to the OpenAI REST API.

    Args:
        model: Model identifier, e.g. "gpt-4".
        messages: List of {"role": ..., "content": ...} dicts.

    Returns:
        The content string of the first choice in the response.

    Raises:
        requests.HTTPError: If the API responds with a non-2xx status.
    """
    url = "https://api.openai.com/v1/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai.api_key}",
    }
    payload = {"model": model, "messages": messages}
    # Without a timeout a hung connection would block the worker forever.
    response = requests.post(url, headers=headers, json=payload, timeout=60)
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]
def _normalize_label(label):
    """Normalize a construct label so file and matrix names compare equal.

    Lower-cases the text and replaces every run of non-alphanumeric characters
    with a single space, then strips the ends.
    """
    return re.sub(r"[^a-z0-9]+", " ", str(label).lower()).strip()


def clean_predefined_constructs_and_matrix(predefined_constructs_path, full_matrix_path):
    """Load and align the construct list with the correlation matrix.

    Args:
        predefined_constructs_path: Tab-separated text file whose first column
            holds construct names (a "variable" header row is skipped).
        full_matrix_path: CSV correlation matrix with construct labels as both
            index and columns.

    Returns:
        Tuple (predefined_constructs, full_matrix) where both sides use the
        same normalized labels and the matrix holds only known constructs.

    Raises:
        ValueError: If any predefined construct is absent from the matrix.
    """
    with open(predefined_constructs_path, "r") as f:
        predefined_constructs = [
            _normalize_label(line.split("\t")[0])
            for line in f
            if line.strip()
        ]
    # Drop the header row, if present.
    predefined_constructs = [c for c in predefined_constructs if c != "variable"]

    full_matrix = pd.read_csv(full_matrix_path, index_col=0)
    # Apply the SAME normalization to matrix labels.  The original code deleted
    # punctuation from matrix labels ("self-esteem" -> "selfesteem") but
    # replaced it with spaces in the construct list ("self esteem"), so
    # hyphenated constructs could never match.
    full_matrix.index = [_normalize_label(i) for i in full_matrix.index]
    full_matrix.columns = [_normalize_label(c) for c in full_matrix.columns]

    valid_constructs = set(predefined_constructs)
    invalid_rows = list(set(full_matrix.index) - valid_constructs)
    invalid_columns = list(set(full_matrix.columns) - valid_constructs)
    full_matrix = full_matrix.drop(index=invalid_rows, errors="ignore")
    full_matrix = full_matrix.drop(columns=invalid_columns, errors="ignore")

    missing_constructs = [
        c for c in predefined_constructs
        if c not in full_matrix.index or c not in full_matrix.columns
    ]
    if missing_constructs:
        raise ValueError(f"Missing constructs in the correlation matrix: {missing_constructs}")
    return predefined_constructs, full_matrix
def construct_analyzer(user_prompt, predefined_constructs):
    """Ask the chat model to pick 10 relevant constructs and a dependent one.

    Args:
        user_prompt: The user's free-text query.
        predefined_constructs: Cleaned list of allowed construct names.

    Returns:
        Tuple (constructs, dependent_variable) where constructs is a list that
        always includes the dependent variable.
    """
    predefined_constructs_prompt = "\n".join(predefined_constructs)
    prompt_text = (
        f"Here is the user's prompt: '{user_prompt}'.\n\n"
        "Your role is to identify 10 relevant constructs from the provided list. "
        "Pick one variable as the dependent variable. Format output as:\n"
        "var1|var2|...|var10;dependent_var\n\n"
        f"{predefined_constructs_prompt}"
    )
    messages = [
        {"role": "system", "content": "You are a construct analyzer."},
        {"role": "user", "content": prompt_text},
    ]
    constructs_with_dependent = call_openai_chat(model="gpt-4", messages=messages).strip()
    # Split only on the first ';' so a stray extra semicolon from the model
    # cannot break tuple unpacking; strip whitespace the model may add so the
    # membership check below compares clean names.
    constructs_list, dependent_variable = constructs_with_dependent.split(";", 1)
    dependent_variable = dependent_variable.strip()
    constructs = [c.strip() for c in constructs_list.split("|") if c.strip()]
    if dependent_variable not in constructs:
        constructs.append(dependent_variable)
    return constructs, dependent_variable
def regression_analyst(correlation_matrix, dependent_variable, n_samples=1000):
    """Fit an OLS regression on synthetic data drawn from a correlation matrix.

    Args:
        correlation_matrix: Square pandas DataFrame of construct correlations.
        dependent_variable: Name of the outcome column (case/space tolerant).
        n_samples: Number of synthetic observations to draw (default 1000).

    Returns:
        Tuple (independent_vars, beta_weights): predictor column names and the
        corresponding OLS slope coefficients as a numpy array.

    Raises:
        KeyError: If the dependent variable is not a matrix column, with
            close-match suggestions in the message.
    """
    normalized_dependent = dependent_variable.strip().lower()
    normalized_columns = [col.strip().lower() for col in correlation_matrix.columns]
    if normalized_dependent not in normalized_columns:
        # Imported lazily: only needed on the error path.
        from difflib import get_close_matches
        suggestions = get_close_matches(normalized_dependent, normalized_columns, n=3, cutoff=0.6)
        raise KeyError(
            f"Dependent variable '{dependent_variable}' not found. Suggestions: {suggestions}"
        )
    original_dependent = correlation_matrix.columns[normalized_columns.index(normalized_dependent)]
    independent_vars = [var for var in correlation_matrix.columns if var != original_dependent]
    # Draw standardized synthetic data whose correlations approximate the
    # supplied matrix, then regress the dependent column on the others.
    synthetic_data = np.random.multivariate_normal(
        mean=np.zeros(correlation_matrix.shape[0]),
        cov=correlation_matrix.to_numpy(),
        size=n_samples,
    )
    synthetic_data = pd.DataFrame(synthetic_data, columns=correlation_matrix.columns)
    X = synthetic_data[independent_vars].to_numpy()
    y = synthetic_data[original_dependent].to_numpy()
    # Ordinary least squares with an explicit intercept column; the slopes are
    # mathematically identical to sklearn's LinearRegression().coef_ but need
    # no dependency beyond the already-imported numpy.
    design = np.column_stack([np.ones(len(X)), X])
    coefficients, *_ = np.linalg.lstsq(design, y, rcond=None)
    beta_weights = coefficients[1:]
    return independent_vars, beta_weights
def generate_inference(user_query, equation, independent_vars, dependent_variable, beta_weights):
    """Ask the chat model to interpret the fitted regression for the user."""
    weight_lines = []
    for name, weight in zip(independent_vars, beta_weights):
        weight_lines.append(f"{name}: {round(weight, 4)}")
    beta_details = "\n".join(weight_lines)
    prompt = (
        f"User query: '{user_query}'\n"
        f"Regression Equation: {equation}\n"
        f"Variables and Beta Weights:\n{beta_details}\n\n"
        "Provide actionable insights based on this analysis."
    )
    conversation = [
        {"role": "system", "content": "You are a skilled analyst interpreting regression results."},
        {"role": "user", "content": prompt},
    ]
    return call_openai_chat(model="gpt-4", messages=conversation).strip()
def run_inference_pipeline(user_query, predefined_constructs, correlation_matrix):
    """Full pipeline: pick constructs, fit a regression on the construct
    sub-matrix, and produce a natural-language inference.

    Returns:
        Dict with keys "equation" (the fitted regression as text) and
        "inference" (the model's interpretation).
    """
    # BUG FIX: construct_analyzer already returns a LIST (with the dependent
    # variable appended when missing); the original code called .split('|') on
    # that list, raising AttributeError at runtime.
    constructs_list, dependent_variable = construct_analyzer(user_query, predefined_constructs)
    if dependent_variable not in constructs_list:
        constructs_list.append(dependent_variable)
    correlation_matrix_filtered = correlation_matrix.loc[constructs_list, constructs_list]
    independent_vars, beta_weights = regression_analyst(correlation_matrix_filtered, dependent_variable)
    equation = f"{dependent_variable} = " + " + ".join(
        f"{round(beta, 4)}*{var}" for beta, var in zip(beta_weights, independent_vars)
    )
    inference = generate_inference(user_query, equation, independent_vars, dependent_variable, beta_weights)
    return {"equation": equation, "inference": inference}
# API endpoint
# NOTE(review): the original paste had no route decorator, so the endpoint was
# never registered with FastAPI; restored per the "API endpoint" comment.
@app.post("/infer")
def infer(request: InferenceRequest):
    """Run the full inference pipeline for a user query.

    Returns a dict with the regression equation and the model's interpretation;
    any failure is logged with its traceback and surfaced as an HTTP 500.
    """
    try:
        predefined_constructs, cleaned_matrix = clean_predefined_constructs_and_matrix(
            predefined_constructs_file, full_matrix_path
        )
        results = run_inference_pipeline(request.query, predefined_constructs, cleaned_matrix)
        return results
    except Exception as e:
        # logger.exception records the full traceback, not just the message.
        logger.exception("Error during inference: %s", e)
        raise HTTPException(status_code=500, detail=str(e))