yoursdvniel committed on
Commit
2e9353d
·
verified ·
1 Parent(s): 643ab3a

Added new endpoint for Income/Expenses prediction (to be updated along the way)

Browse files
Files changed (1) hide show
  1. main.py +85 -100
main.py CHANGED
@@ -4,152 +4,137 @@ import os
4
  import io
5
  from flask import Flask, request, jsonify
6
  from flask_cors import CORS, cross_origin
7
- import pandas as pd
8
  import firebase_admin
9
- from firebase_admin import credentials, firestore, auth
10
- import requests
11
- import pandas as pd
12
- from datetime import datetime
13
- import os
14
- from pandasai.llm import GoogleGemini
15
- from pandasai import SmartDataframe, SmartDatalake
16
-
17
- from pandasai.responses.response_parser import ResponseParser
18
- import matplotlib.pyplot as plt
19
- from wordcloud import WordCloud
20
- import random
21
  from langchain.prompts import PromptTemplate
22
  from langchain.chains import LLMChain
23
- from dotenv import load_dotenv
24
- import json
25
-
26
-
27
- from dotenv import load_dotenv
28
 
29
  load_dotenv()
30
 
31
-
32
-
33
-
34
  app = Flask(__name__)
35
  cors = CORS(app)
36
 
37
- class FlaskResponse(ResponseParser):
38
- def __init__(self, context) -> None:
39
- super().__init__(context)
40
-
41
- def format_dataframe(self, result):
42
- return result['value'].to_html()
43
-
44
- def format_plot(self, result):
45
- # Save the plot using savefig
46
- try:
47
-
48
- img_path = result['value']
49
-
50
-
51
- except ValueError:
52
- img_path = str(result['value'])
53
- print("value error!", img_path)
54
-
55
- print("response_class_path:", img_path)
56
- return img_path
57
-
58
- def format_other(self, result):
59
- return str(result['value'])
60
-
61
- gemini_api_key = os.environ['Gemini']
62
-
63
- @app.route("/", methods=["GET"])
64
- def home():
65
-
66
- return "Hello Qx!"
67
-
68
-
69
-
70
- llm = GoogleGemini(api_key=gemini_api_key)
71
-
72
- llm2 = ChatGoogleGenerativeAI(model='gemini-1.5-flash-001', temperature=0.1)
73
-
74
  # Initialize Firebase app
75
  if not firebase_admin._apps:
76
-
77
  cred = credentials.Certificate("quant-app-99d09-firebase-adminsdk-6prb1-37f34e1c91.json")
78
  firebase_admin.initialize_app(cred)
79
 
80
  db = firestore.client()
81
 
 
 
 
82
 
 
 
83
 
 
 
 
 
 
 
 
 
84
 
 
 
 
 
 
 
 
85
  @app.route("/predict", methods=["POST"])
86
  @cross_origin()
87
  def bot():
88
-
89
-
90
  user_id = request.json.get("user_id")
91
  user_question = request.json.get("user_question")
92
- load_dotenv()
93
-
94
- inventory_ref = db.collection("system_users").document(user_id).collection('inventory')
95
 
 
96
  tasks_ref = db.collection("system_users").document(user_id).collection('tasks')
97
-
98
  transactions_ref = db.collection("system_users").document(user_id).collection('transactions')
99
 
100
- inventory_list = []
101
- for doc in inventory_ref.stream():
102
- a = doc.to_dict()
103
- inventory_list.append(a)
104
-
105
- tasks_list = []
106
- for doc in tasks_ref.stream():
107
- a = doc.to_dict()
108
- tasks_list.append(a)
109
-
110
- transactions_list = []
111
- for doc in transactions_ref.stream():
112
- a = doc.to_dict()
113
- transactions_list.append(a)
114
-
115
  inventory_df = pd.DataFrame(inventory_list)
116
  transactions_df = pd.DataFrame(transactions_list)
117
  tasks_df = pd.DataFrame(tasks_list)
118
 
119
- lake = SmartDatalake([inventory_df, transactions_df, tasks_df], config={"llm":llm, "response_parser":FlaskResponse, "enable_cache": False, "save_logs":False})
120
  response = lake.chat(user_question)
121
  print(user_question)
122
 
123
- resp = str(response)
124
-
125
- return jsonify(resp)
126
 
 
127
  @app.route("/mrec", methods=["POST"])
128
  @cross_origin()
129
  def marketing_rec():
130
  user_id = request.json.get("user_id")
131
 
132
  transactions_ref = db.collection("system_users").document(user_id).collection('transactions')
133
- transactions_list = []
134
- for doc in transactions_ref.stream():
135
- a = doc.to_dict()
136
- transactions_list.append(a)
137
 
138
  transactions_df = pd.DataFrame(transactions_list)
139
- # Set up a prompt template
140
- prompt = PromptTemplate.from_template('You are a business analyst.In the fewest words possible, write a brief analysis and some very brief marketing tips suitable for a small business with this transactions data {data_frame}')
141
-
142
- # Create a chain that utilizes both the LLM and the prompt template
143
- chain = LLMChain(llm=llm2, prompt=prompt, verbose=True)
144
- data_frame = transactions_df
145
- response = chain.invoke(input=data_frame)
146
  print(response)
147
 
148
- resp = str(response['text'])
149
-
150
- return jsonify(resp)
151
-
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
  if __name__ == "__main__":
155
- app.run(debug=True,host="0.0.0.0", port=7860)
 
4
import io
import os
from datetime import datetime

import firebase_admin
import matplotlib.pyplot as plt
import pandas as pd
from dotenv import load_dotenv
from fbprophet import Prophet
from firebase_admin import credentials, firestore
from flask import Flask, request, jsonify
from flask_cors import CORS, cross_origin
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from pandasai import SmartDatalake
from pandasai.responses.response_parser import ResponseParser
from statsmodels.tsa.holtwinters import ExponentialSmoothing
 
18
 
19
  load_dotenv()
20
 
 
 
 
21
  app = Flask(__name__)
22
  cors = CORS(app)
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # Initialize Firebase app
25
  if not firebase_admin._apps:
 
26
  cred = credentials.Certificate("quant-app-99d09-firebase-adminsdk-6prb1-37f34e1c91.json")
27
  firebase_admin.initialize_app(cred)
28
 
29
  db = firestore.client()
30
 
31
+ class FlaskResponse(ResponseParser):
32
+ def __init__(self, context) -> None:
33
+ super().__init__(context)
34
 
35
+ def format_dataframe(self, result):
36
+ return result['value'].to_html()
37
 
38
+ def format_plot(self, result):
39
+ try:
40
+ img_path = result['value']
41
+ except ValueError:
42
+ img_path = str(result['value'])
43
+ print("ValueError:", img_path)
44
+ print("response_class_path:", img_path)
45
+ return img_path
46
 
47
+ def format_other(self, result):
48
+ return str(result['value'])
49
+
50
+ gemini_api_key = os.getenv('Gemini')
51
+ llm = ChatGoogleGenerativeAI(api_key=gemini_api_key, model='gemini-1.5-flash-001', temperature=0.1)
52
+
53
+ # Endpoint for handling questions to the bot using transaction data
54
  @app.route("/predict", methods=["POST"])
55
  @cross_origin()
56
  def bot():
 
 
57
  user_id = request.json.get("user_id")
58
  user_question = request.json.get("user_question")
 
 
 
59
 
60
+ inventory_ref = db.collection("system_users").document(user_id).collection('inventory')
61
  tasks_ref = db.collection("system_users").document(user_id).collection('tasks')
 
62
  transactions_ref = db.collection("system_users").document(user_id).collection('transactions')
63
 
64
+ inventory_list = [doc.to_dict() for doc in inventory_ref.stream()]
65
+ tasks_list = [doc.to_dict() for doc in tasks_ref.stream()]
66
+ transactions_list = [doc.to_dict() for doc in transactions_ref.stream()]
67
+
 
 
 
 
 
 
 
 
 
 
 
68
  inventory_df = pd.DataFrame(inventory_list)
69
  transactions_df = pd.DataFrame(transactions_list)
70
  tasks_df = pd.DataFrame(tasks_list)
71
 
72
+ lake = SmartDatalake([inventory_df, transactions_df, tasks_df], config={"llm": llm, "response_parser": FlaskResponse, "enable_cache": False, "save_logs": False})
73
  response = lake.chat(user_question)
74
  print(user_question)
75
 
76
+ return jsonify(str(response))
 
 
77
 
78
+ # Marketing recommendations endpoint
79
  @app.route("/mrec", methods=["POST"])
80
  @cross_origin()
81
  def marketing_rec():
82
  user_id = request.json.get("user_id")
83
 
84
  transactions_ref = db.collection("system_users").document(user_id).collection('transactions')
85
+ transactions_list = [doc.to_dict() for doc in transactions_ref.stream()]
 
 
 
86
 
87
  transactions_df = pd.DataFrame(transactions_list)
88
+ prompt = PromptTemplate.from_template('You are a business analyst. Write a brief analysis and marketing tips for a small business using this transactions data {data_frame}')
89
+ chain = LLMChain(llm=llm, prompt=prompt, verbose=True)
90
+
91
+ response = chain.invoke(input=transactions_df)
 
 
 
92
  print(response)
93
 
94
+ return jsonify(str(response['text']))
 
 
 
95
 
96
+ # Income/Expenses Prediction endpoint
97
+ @app.route("/predict_revenue", methods=["POST"])
98
+ @cross_origin()
99
+ def predict_revenue():
100
+ request_data = request.json
101
+ user_id = request_data.get("user_id")
102
+ interval = request_data.get("interval", 30)
103
+ transaction_type = request_data.get("transaction_type", "Income")
104
+
105
+ # Fetch transaction data based on user and transaction type
106
+ transactions_ref = db.collection("system_users").document(user_id).collection("transactions")
107
+ query = transactions_ref.where("transactionType", "==", transaction_type).stream()
108
+
109
+ data = []
110
+ for doc in query:
111
+ transaction = doc.to_dict()
112
+ data.append({"date": transaction["date"].to_date(), "amountDue": transaction["amountDue"]})
113
+
114
+ # Create DataFrame from transaction data
115
+ df = pd.DataFrame(data)
116
+ df = df.sort_values("date").set_index("date")
117
+ df = df.resample("D").sum().reset_index() # Resample daily to ensure regular intervals
118
+ df.columns = ["ds", "y"] # Rename columns for Prophet (ds: date, y: target)
119
+
120
+ # Check if there's enough data to train the model
121
+ if df.shape[0] < 10:
122
+ return jsonify({"error": "Not enough data for prediction"})
123
+
124
+ # Initialize and fit the Prophet model
125
+ model = Prophet(daily_seasonality=True, yearly_seasonality=True)
126
+ model.fit(df)
127
+
128
+ # dataframe for future predictions
129
+ future_dates = model.make_future_dataframe(periods=interval)
130
+ forecast = model.predict(future_dates)
131
+
132
+ # Extract the forecast for the requested interval
133
+ forecast_data = forecast[['ds', 'yhat']].tail(interval)
134
+ predictions = forecast_data['yhat'].tolist()
135
+
136
+ # Return predictions in JSON format
137
+ return jsonify({"predictedData": predictions})
138
 
139
  if __name__ == "__main__":
140
+ app.run(debug=True, host="0.0.0.0", port=7860)