ABAO77 commited on
Commit
9173a54
·
verified ·
1 Parent(s): cfb3c58

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +88 -40
  2. database.py +20 -5
  3. get_destinations.py +22 -5
app.py CHANGED
@@ -1,22 +1,28 @@
1
  import os
2
  from typing import Any, Dict, List
3
  from dotenv import load_dotenv
 
4
  load_dotenv()
5
  import uvicorn
6
  from fastapi import APIRouter, FastAPI, HTTPException
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from pydantic import BaseModel
9
  from model_predict_onnx import onnx_predictor
10
- from user_weights import (get_all_users, get_user_metadata,
11
- get_user_weights,
12
- track_question_tags,
13
- update_user_metadata,
14
- update_user_weights,
15
- update_weights_from_feedback,
16
- update_weights_from_query)
17
- from get_destinations import (get_destinations_list,get_question_vector)
 
 
 
18
  from get_default_weight import feature_names, weights_bias_vector
19
  from database import db
 
 
20
 
21
  # Define request models
22
  class WeightUpdateRequest(BaseModel):
@@ -24,13 +30,16 @@ class WeightUpdateRequest(BaseModel):
24
  new_weights: List[float]
25
  metadata: Dict[str, Any] = {}
26
 
 
27
  class FeedbackRequest(BaseModel):
28
  destination_id: int
29
  tag_id: int
30
  rating: int # 1-5 stars
31
 
 
32
  router = APIRouter(prefix="/model", tags=["Model"])
33
 
 
34
  @router.get("/get_question_tags/{question}")
35
  async def get_question_tags(question: str):
36
  # Get the prediction
@@ -41,8 +50,9 @@ async def get_question_tags(question: str):
41
  print("Predicted Tags:", predicted_tags)
42
  return {"question_tags": predicted_tags}
43
 
 
44
  @router.get("/get_destinations_list/{question_tags}/{top_k}")
45
- async def get_destinations_list_api(question_tags: str, top_k:str):
46
  # Get the prediction
47
  question_vector = get_question_vector(question_tags)
48
  destinations_list = get_destinations_list(question_vector, int(top_k))
@@ -50,6 +60,22 @@ async def get_destinations_list_api(question_tags: str, top_k:str):
50
  print("destinations_list:", destinations_list)
51
  return {"destinations_list": destinations_list}
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  @router.get("/get_destinations_list_by_question/{question}/{top_k}")
54
  async def get_destinations_list_api(question: str, top_k: str):
55
  # Get the prediction
@@ -66,16 +92,17 @@ async def get_destinations_list_api(question: str, top_k: str):
66
  print("destinations_list:", destinations_list)
67
  return {"destinations_list": destinations_list}
68
 
 
69
  @router.get("/get_destinations_list_by_question/{question}/{top_k}/{user_id}")
70
  def get_destinations_list_with_user_api(question: str, top_k: str, user_id: str):
71
  """
72
  Get a list of destinations based on a question and user-specific weights.
73
-
74
  Parameters:
75
  question (str): The question to get destinations for.
76
  top_k (str): The number of destinations to return.
77
  user_id (str): The ID of the user.
78
-
79
  Returns:
80
  dict: A dictionary containing the list of destinations.
81
  """
@@ -85,13 +112,15 @@ def get_destinations_list_with_user_api(question: str, top_k: str, user_id: str)
85
  # Print the sentence and its predicted tags
86
  print("Sentence:", original_sentence)
87
  print("Predicted Tags:", question_tags)
88
-
89
  # Track the question tags for the user
90
  track_question_tags(user_id, question_tags)
91
-
92
  # Update weights based on query tags
93
- update_weights_from_query(user_id, question_tags, feature_names, weights_bias_vector)
94
-
 
 
95
  # Get the prediction
96
  question_tags_str = " ".join(question_tags)
97
  question_vector = get_question_vector(question_tags_str)
@@ -100,31 +129,34 @@ def get_destinations_list_with_user_api(question: str, top_k: str, user_id: str)
100
  print("destinations_list:", destinations_list)
101
  return {"destinations_list": destinations_list}
102
 
 
103
  @router.get("/users")
104
  def get_users():
105
  """
106
  Get a list of all users.
107
-
108
  Returns:
109
  dict: A dictionary containing the list of users.
110
  """
111
  users = get_all_users()
112
  return {"users": users}
113
 
 
114
  @router.get("/users/{user_id}")
115
  def get_user(user_id: str):
116
  """
117
  Get the metadata for a user.
118
-
119
  Parameters:
120
  user_id (str): The ID of the user.
121
-
122
  Returns:
123
  dict: A dictionary containing the user's metadata.
124
  """
125
  metadata = get_user_metadata(user_id)
126
  return {"metadata": metadata}
127
 
 
128
  @router.get("/users/{user_id}/weights")
129
  def get_user_weights_api(user_id: str):
130
  """
@@ -141,86 +173,100 @@ def get_user_weights_api(user_id: str):
141
  weights_list = weights.tolist() if weights is not None else None
142
  return {"user_id": user_id, "weights": weights_list}
143
 
 
144
  @router.post("/users/{user_id}/weights")
145
  def update_user_weights_api(user_id: str, request: WeightUpdateRequest):
146
  """
147
  Update the weights for a user.
148
-
149
  Parameters:
150
  user_id (str): The ID of the user.
151
  request (WeightUpdateRequest): The request containing the tag indices, new weights, and metadata.
152
-
153
  Returns:
154
  dict: A dictionary indicating whether the update was successful.
155
  """
156
  # Validate the request
157
  if len(request.tag_indices) != len(request.new_weights):
158
- raise HTTPException(status_code=400, detail="Tag indices and new weights must have the same length")
159
-
 
 
 
160
  # Update the weights
161
- success = update_user_weights(user_id, request.tag_indices, request.new_weights, weights_bias_vector)
162
-
 
 
163
  # Update the metadata
164
  if success and request.metadata:
165
  update_user_metadata(user_id, request.metadata)
166
-
167
  return {"success": success}
168
 
 
169
  @router.post("/users/{user_id}/feedback")
170
  def record_user_feedback(user_id: str, request: FeedbackRequest):
171
  """
172
  Record user feedback on a specific tag for a specific destination.
173
-
174
  Parameters:
175
  user_id (str): The ID of the user.
176
  request (FeedbackRequest): The request containing the destination ID, tag ID, and rating.
177
-
178
  Returns:
179
  dict: A dictionary indicating whether the feedback was recorded successfully.
180
  """
181
  # Validate the request
182
  if request.rating < 1 or request.rating > 5:
183
  raise HTTPException(status_code=400, detail="Rating must be between 1 and 5")
184
-
185
  # Update weights based on feedback
186
  success = update_weights_from_feedback(
187
- user_id,
188
- request.destination_id,
189
- request.tag_id,
190
- request.rating,
191
- weights_bias_vector
192
  )
193
-
194
  return {"success": success}
195
 
 
196
  @router.get("/tags")
197
  def get_tags():
198
  """
199
  Get a list of all tags.
200
-
201
  Returns:
202
  dict: A dictionary containing the list of tags.
203
  """
204
- return {"tags": feature_names.tolist()}
 
205
 
206
  app = FastAPI(docs_url="/")
207
  app.add_middleware(
208
  CORSMiddleware,
209
- allow_origins=['*'],
210
  allow_credentials=True,
211
- allow_methods=['*'],
212
- allow_headers=['*'],
213
- expose_headers=['*',]
 
 
214
  )
215
 
216
  app.include_router(router)
217
 
 
218
  @app.on_event("startup")
219
  def startup_event():
220
  """
221
  Connect to the database when the API starts.
222
  """
223
  db.connect()
 
 
224
 
225
  @app.on_event("shutdown")
226
  def shutdown_event():
@@ -228,6 +274,8 @@ def shutdown_event():
228
  Close the database connection when the API shuts down.
229
  """
230
  db.close()
 
 
231
 
232
  if __name__ == "__main__":
233
  uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7880)))
 
1
  import os
2
  from typing import Any, Dict, List
3
  from dotenv import load_dotenv
4
+
5
  load_dotenv()
6
  import uvicorn
7
  from fastapi import APIRouter, FastAPI, HTTPException
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from pydantic import BaseModel
10
  from model_predict_onnx import onnx_predictor
11
+ from user_weights import (
12
+ get_all_users,
13
+ get_user_metadata,
14
+ get_user_weights,
15
+ track_question_tags,
16
+ update_user_metadata,
17
+ update_user_weights,
18
+ update_weights_from_feedback,
19
+ update_weights_from_query,
20
+ )
21
+ from get_destinations import get_destinations_list, get_question_vector, get_recent_tags
22
  from get_default_weight import feature_names, weights_bias_vector
23
  from database import db
24
+ from loguru import logger
25
+
26
 
27
  # Define request models
28
  class WeightUpdateRequest(BaseModel):
 
30
  new_weights: List[float]
31
  metadata: Dict[str, Any] = {}
32
 
33
+
34
  class FeedbackRequest(BaseModel):
35
  destination_id: int
36
  tag_id: int
37
  rating: int # 1-5 stars
38
 
39
+
40
  router = APIRouter(prefix="/model", tags=["Model"])
41
 
42
+
43
  @router.get("/get_question_tags/{question}")
44
  async def get_question_tags(question: str):
45
  # Get the prediction
 
50
  print("Predicted Tags:", predicted_tags)
51
  return {"question_tags": predicted_tags}
52
 
53
+
54
  @router.get("/get_destinations_list/{question_tags}/{top_k}")
55
+ async def get_destinations_list_api(question_tags: str, top_k: str):
56
  # Get the prediction
57
  question_vector = get_question_vector(question_tags)
58
  destinations_list = get_destinations_list(question_vector, int(top_k))
 
60
  print("destinations_list:", destinations_list)
61
  return {"destinations_list": destinations_list}
62
 
63
+
64
+ @router.get("/get_recommendation_destinations/{user_id}/{top_k}")
65
+ async def get_recommendation_destinations(user_id: str, top_k: str):
66
+ # Get the prediction
67
+ recent_tags = get_recent_tags(user_id)
68
+ question_tags = " ".join(recent_tags)
69
+ question_vector = get_question_vector(question_tags)
70
+ destinations_list = get_destinations_list(question_vector, int(top_k), user_id)
71
+ destination_ids = db.get_destination_ids(destinations_list)
72
+ return {
73
+ "destination_ids": destination_ids,
74
+ "destinations_list:": destinations_list,
75
+ "recent_tags": recent_tags,
76
+ }
77
+
78
+
79
  @router.get("/get_destinations_list_by_question/{question}/{top_k}")
80
  async def get_destinations_list_api(question: str, top_k: str):
81
  # Get the prediction
 
92
  print("destinations_list:", destinations_list)
93
  return {"destinations_list": destinations_list}
94
 
95
+
96
  @router.get("/get_destinations_list_by_question/{question}/{top_k}/{user_id}")
97
  def get_destinations_list_with_user_api(question: str, top_k: str, user_id: str):
98
  """
99
  Get a list of destinations based on a question and user-specific weights.
100
+
101
  Parameters:
102
  question (str): The question to get destinations for.
103
  top_k (str): The number of destinations to return.
104
  user_id (str): The ID of the user.
105
+
106
  Returns:
107
  dict: A dictionary containing the list of destinations.
108
  """
 
112
  # Print the sentence and its predicted tags
113
  print("Sentence:", original_sentence)
114
  print("Predicted Tags:", question_tags)
115
+
116
  # Track the question tags for the user
117
  track_question_tags(user_id, question_tags)
118
+
119
  # Update weights based on query tags
120
+ update_weights_from_query(
121
+ user_id, question_tags, feature_names, weights_bias_vector
122
+ )
123
+
124
  # Get the prediction
125
  question_tags_str = " ".join(question_tags)
126
  question_vector = get_question_vector(question_tags_str)
 
129
  print("destinations_list:", destinations_list)
130
  return {"destinations_list": destinations_list}
131
 
132
+
133
  @router.get("/users")
134
  def get_users():
135
  """
136
  Get a list of all users.
137
+
138
  Returns:
139
  dict: A dictionary containing the list of users.
140
  """
141
  users = get_all_users()
142
  return {"users": users}
143
 
144
+
145
  @router.get("/users/{user_id}")
146
  def get_user(user_id: str):
147
  """
148
  Get the metadata for a user.
149
+
150
  Parameters:
151
  user_id (str): The ID of the user.
152
+
153
  Returns:
154
  dict: A dictionary containing the user's metadata.
155
  """
156
  metadata = get_user_metadata(user_id)
157
  return {"metadata": metadata}
158
 
159
+
160
  @router.get("/users/{user_id}/weights")
161
  def get_user_weights_api(user_id: str):
162
  """
 
173
  weights_list = weights.tolist() if weights is not None else None
174
  return {"user_id": user_id, "weights": weights_list}
175
 
176
+
177
  @router.post("/users/{user_id}/weights")
178
  def update_user_weights_api(user_id: str, request: WeightUpdateRequest):
179
  """
180
  Update the weights for a user.
181
+
182
  Parameters:
183
  user_id (str): The ID of the user.
184
  request (WeightUpdateRequest): The request containing the tag indices, new weights, and metadata.
185
+
186
  Returns:
187
  dict: A dictionary indicating whether the update was successful.
188
  """
189
  # Validate the request
190
  if len(request.tag_indices) != len(request.new_weights):
191
+ raise HTTPException(
192
+ status_code=400,
193
+ detail="Tag indices and new weights must have the same length",
194
+ )
195
+
196
  # Update the weights
197
+ success = update_user_weights(
198
+ user_id, request.tag_indices, request.new_weights, weights_bias_vector
199
+ )
200
+
201
  # Update the metadata
202
  if success and request.metadata:
203
  update_user_metadata(user_id, request.metadata)
204
+
205
  return {"success": success}
206
 
207
+
208
  @router.post("/users/{user_id}/feedback")
209
  def record_user_feedback(user_id: str, request: FeedbackRequest):
210
  """
211
  Record user feedback on a specific tag for a specific destination.
212
+
213
  Parameters:
214
  user_id (str): The ID of the user.
215
  request (FeedbackRequest): The request containing the destination ID, tag ID, and rating.
216
+
217
  Returns:
218
  dict: A dictionary indicating whether the feedback was recorded successfully.
219
  """
220
  # Validate the request
221
  if request.rating < 1 or request.rating > 5:
222
  raise HTTPException(status_code=400, detail="Rating must be between 1 and 5")
223
+
224
  # Update weights based on feedback
225
  success = update_weights_from_feedback(
226
+ user_id,
227
+ request.destination_id,
228
+ request.tag_id,
229
+ request.rating,
230
+ weights_bias_vector,
231
  )
232
+
233
  return {"success": success}
234
 
235
+
236
  @router.get("/tags")
237
  def get_tags():
238
  """
239
  Get a list of all tags.
240
+
241
  Returns:
242
  dict: A dictionary containing the list of tags.
243
  """
244
+ return {"tags": [tag.upper() for tag in feature_names.tolist()]}
245
+
246
 
247
  app = FastAPI(docs_url="/")
248
  app.add_middleware(
249
  CORSMiddleware,
250
+ allow_origins=["*"],
251
  allow_credentials=True,
252
+ allow_methods=["*"],
253
+ allow_headers=["*"],
254
+ expose_headers=[
255
+ "*",
256
+ ],
257
  )
258
 
259
  app.include_router(router)
260
 
261
+
262
  @app.on_event("startup")
263
  def startup_event():
264
  """
265
  Connect to the database when the API starts.
266
  """
267
  db.connect()
268
+ logger.info("Database connected")
269
+
270
 
271
  @app.on_event("shutdown")
272
  def shutdown_event():
 
274
  Close the database connection when the API shuts down.
275
  """
276
  db.close()
277
+ logger.info("Database closed")
278
+
279
 
280
  if __name__ == "__main__":
281
  uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7880)))
database.py CHANGED
@@ -10,10 +10,7 @@ from dotenv import load_dotenv
10
  from bson import ObjectId
11
  from loguru import logger
12
 
13
- load_dotenv()
14
-
15
- print(os.environ.get("MONGODB_URL"))
16
-
17
 
18
  class Database:
19
  def __init__(self):
@@ -27,7 +24,7 @@ class Database:
27
  try:
28
  # Connect to MongoDB
29
  self.client = MongoClient(os.environ.get("MONGODB_URL"))
30
- self.db = self.client[os.environ.get("DB_NAME", "triventure")]
31
 
32
  # Create users collection if it doesn't exist
33
  if "user" not in self.db.list_collection_names():
@@ -287,7 +284,25 @@ class Database:
287
  except Exception as e:
288
  logger.error(f"Error getting all users: {e}")
289
  return []
 
 
 
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
  # Create a singleton instance
293
  db = Database()
 
 
10
  from bson import ObjectId
11
  from loguru import logger
12
 
13
+ load_dotenv(override=True)
 
 
 
14
 
15
  class Database:
16
  def __init__(self):
 
24
  try:
25
  # Connect to MongoDB
26
  self.client = MongoClient(os.environ.get("MONGODB_URL"))
27
+ self.db = self.client[os.environ.get("DB_NAME", "scheduling")]
28
 
29
  # Create users collection if it doesn't exist
30
  if "user" not in self.db.list_collection_names():
 
284
  except Exception as e:
285
  logger.error(f"Error getting all users: {e}")
286
  return []
287
+ def get_destination_ids(self,destination_names):
288
+ """
289
+ Get destination IDs from the database.
290
 
291
+ Parameters:
292
+ destination_names (list): A list of destination names.
293
+
294
+ Returns:
295
+ list: A list of destination IDs.
296
+ """
297
+ try:
298
+ # Get destination IDs from the database
299
+ results = self.db.destination.find({"name": {"$in": destination_names}})
300
+ destination_ids = [str(result["_id"]) for result in results]
301
+ return destination_ids
302
+ except Exception as e:
303
+ print(f"Error getting destination IDs: {e}")
304
+ return []
305
 
306
  # Create a singleton instance
307
  db = Database()
308
+
get_destinations.py CHANGED
@@ -3,19 +3,22 @@ import numpy as np
3
  from config import vectorizer
4
  from get_default_weight import destinations, weights_bias_vector
5
  from user_weights import get_user_weights
 
6
 
7
 
8
  def get_des_accumulation(question_vector, weights_bias_vector):
9
  accumulation = 0
10
  for index in range(len(weights_bias_vector)):
11
- if question_vector[index]==1:
12
  accumulation += weights_bias_vector[index]
13
-
14
  return accumulation
15
 
 
16
  def get_destinations_list(question_vector, top_k, user_id=None):
17
  des = destinations
18
  des = des[1:].reset_index(drop=True)
 
19
  """
20
  This function calculates the accumulated scores for each destination based on the given question vector and weights vector.
21
  It then selects the top 5 destinations with the highest scores and returns their names.
@@ -29,24 +32,28 @@ def get_destinations_list(question_vector, top_k, user_id=None):
29
  weights_vector = weights_bias_vector
30
  if user_id is not None:
31
  weights_vector = get_user_weights(user_id, weights_bias_vector)
32
-
 
33
  accumulation_dict = {}
34
  for index in range(len(weights_vector)):
35
  accumulation = get_des_accumulation(question_vector[0], weights_vector[index])
36
  accumulation_dict[str(index)] = accumulation
37
-
38
  top_keys = sorted(accumulation_dict, key=accumulation_dict.get, reverse=True)
39
  print(f"Top keys: {top_keys}")
40
  scores = [accumulation_dict[key] for key in top_keys]
 
41
  q1_score = np.percentile(scores, 25)
 
42
  destinations_list = []
43
  for key in top_keys:
44
  if accumulation_dict[key] > q1_score:
45
  destinations_list.append(des["name"][int(key)])
46
  print(f"{des['name'][int(key)]}: {accumulation_dict[key]}")
47
-
48
  return destinations_list[:top_k]
49
 
 
50
  def get_question_vector(question_tags):
51
  """
52
  Generate a question vector based on the given list of question tags.
@@ -63,4 +70,14 @@ def get_question_vector(question_tags):
63
  """
64
  question_tags = [question_tags]
65
  question_vector = vectorizer.transform(question_tags).toarray()
 
 
66
  return question_vector
 
 
 
 
 
 
 
 
 
3
  from config import vectorizer
4
  from get_default_weight import destinations, weights_bias_vector
5
  from user_weights import get_user_weights
6
+ from database import db
7
 
8
 
9
  def get_des_accumulation(question_vector, weights_bias_vector):
10
  accumulation = 0
11
  for index in range(len(weights_bias_vector)):
12
+ if question_vector[index] == 1:
13
  accumulation += weights_bias_vector[index]
14
+
15
  return accumulation
16
 
17
+
18
  def get_destinations_list(question_vector, top_k, user_id=None):
19
  des = destinations
20
  des = des[1:].reset_index(drop=True)
21
+ # print("DES:", des)
22
  """
23
  This function calculates the accumulated scores for each destination based on the given question vector and weights vector.
24
  It then selects the top 5 destinations with the highest scores and returns their names.
 
32
  weights_vector = weights_bias_vector
33
  if user_id is not None:
34
  weights_vector = get_user_weights(user_id, weights_bias_vector)
35
+ print("weights_bias_vector:", weights_vector)
36
+
37
  accumulation_dict = {}
38
  for index in range(len(weights_vector)):
39
  accumulation = get_des_accumulation(question_vector[0], weights_vector[index])
40
  accumulation_dict[str(index)] = accumulation
41
+ print("accumulation_dict:", accumulation_dict)
42
  top_keys = sorted(accumulation_dict, key=accumulation_dict.get, reverse=True)
43
  print(f"Top keys: {top_keys}")
44
  scores = [accumulation_dict[key] for key in top_keys]
45
+ print("scores:", scores)
46
  q1_score = np.percentile(scores, 25)
47
+ print("q1_score:", q1_score)
48
  destinations_list = []
49
  for key in top_keys:
50
  if accumulation_dict[key] > q1_score:
51
  destinations_list.append(des["name"][int(key)])
52
  print(f"{des['name'][int(key)]}: {accumulation_dict[key]}")
53
+
54
  return destinations_list[:top_k]
55
 
56
+
57
  def get_question_vector(question_tags):
58
  """
59
  Generate a question vector based on the given list of question tags.
 
70
  """
71
  question_tags = [question_tags]
72
  question_vector = vectorizer.transform(question_tags).toarray()
73
+ print("question_tags:", question_tags)
74
+ print("question_vector:", question_vector)
75
  return question_vector
76
+
77
+
78
+ def get_recent_tags(user_id):
79
+ recent_tags = db.get_user_metadata(user_id).get("recent_tags", [])
80
+ if recent_tags:
81
+ return recent_tags[-1].get("tags", [])
82
+ else:
83
+ return {}