ABAO77's picture
Upload app.py
878a9aa verified
import os
from typing import Any, Dict, List
from dotenv import load_dotenv
load_dotenv()
import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from model_predict_onnx import onnx_predictor
from user_weights import (get_all_users, get_user_metadata,
get_user_weights,
track_question_tags,
update_user_metadata,
update_user_weights,
update_weights_from_feedback,
update_weights_from_query)
from get_destinations import (get_destinations_list,get_question_vector,get_recent_tags)
from get_default_weight import feature_names, weights_bias_vector
from database import db
# Define request models
class WeightUpdateRequest(BaseModel):
    """Request body for ``POST /model/users/{user_id}/weights``.

    ``tag_indices`` and ``new_weights`` are parallel lists: entry ``i`` of
    ``new_weights`` is the value for position ``tag_indices[i]``. The endpoint
    rejects the request with HTTP 400 when their lengths differ. ``metadata``
    is passed to ``update_user_metadata`` only when the weight update succeeds.
    """

    tag_indices: List[int]  # positions to overwrite; presumably indices aligned with feature_names — TODO confirm
    new_weights: List[float]  # replacement value for the matching entry of tag_indices
    metadata: Dict[str, Any] = {}  # optional; pydantic copies field defaults per instance, so this {} is not shared
class FeedbackRequest(BaseModel):
    """Request body for ``POST /model/users/{user_id}/feedback``.

    Carries a star rating a user gave to one tag of one destination.
    """

    destination_id: int  # destination the feedback refers to
    tag_id: int  # presumably an index into feature_names — TODO confirm
    rating: int # 1-5 stars; the range is enforced by the endpoint, not by this model
# Every route below is mounted under the "/model" path prefix and grouped
# under the "Model" tag in the generated OpenAPI docs.
router = APIRouter(prefix="/model", tags=["Model"])
@router.get("/get_question_tags/{question}")
async def get_question_tags(question: str):
    """Predict tags for a free-text question with the ONNX model.

    Parameters:
        question (str): The question text, taken from the URL path.
    Returns:
        dict: ``{"question_tags": tags}``, where ``tags`` is the second value
        returned by ``onnx_predictor.predict``.
    """
    # NOTE(review): predict() is called without await inside an async handler;
    # if inference is CPU-heavy it will block the event loop — consider `def`.
    prediction = onnx_predictor.predict(question)
    sentence, tags = prediction
    # Echo the model's input and output to stdout for debugging.
    print("Sentence:", sentence)
    print("Predicted Tags:", tags)
    return {"question_tags": tags}
@router.get("/get_destinations_list/{question_tags}/{top_k}")
async def get_destinations_list_api(question_tags: str, top_k: int):
    """Return the ``top_k`` destinations matching a space-separated tag string.

    Parameters:
        question_tags (str): Tag string taken verbatim from the URL path and
            passed to ``get_question_vector``.
        top_k (int): Number of destinations to return; must be >= 1.
    Returns:
        dict: ``{"destinations_list": ...}`` as produced by
        ``get_destinations_list``.
    Raises:
        HTTPException: 400 if ``top_k`` is less than 1.
    """
    # Annotating top_k as int makes FastAPI reject non-numeric values with a
    # 422 instead of crashing with a 500 from int(). int() also still accepts
    # a numeric string from callers that invoke this function directly.
    k = int(top_k)
    if k < 1:
        raise HTTPException(status_code=400, detail="top_k must be a positive integer")
    question_vector = get_question_vector(question_tags)
    destinations_list = get_destinations_list(question_vector, k)
    print("destinations_list:", destinations_list)
    return {"destinations_list": destinations_list}
@router.get("/get_recommendation_destinations/{user_id}/{top_k}")
async def get_recommendation_destinations(user_id: str, top_k: int):
    """Recommend destinations for a user based on the tags of their recent queries.

    Parameters:
        user_id (str): The ID of the user.
        top_k (int): Number of destinations to return; must be >= 1.
    Returns:
        dict: ``destination_ids`` (from ``db.get_destination_ids``), the raw
        ``destinations_list``, and the ``recent_tags`` used to build the query.
    Raises:
        HTTPException: 400 if ``top_k`` is less than 1.
    """
    # Typed as int so FastAPI answers bad input with 422 rather than a 500 from
    # int(); validated before any lookups so a bad request does no work.
    k = int(top_k)
    if k < 1:
        raise HTTPException(status_code=400, detail="top_k must be a positive integer")
    recent_tags = get_recent_tags(user_id)
    # NOTE(review): assumes get_recent_tags returns an iterable of strings; an
    # empty result produces an empty query string — confirm how
    # get_question_vector handles that.
    question_tags = " ".join(recent_tags)
    question_vector = get_question_vector(question_tags)
    destinations_list = get_destinations_list(question_vector, k)
    destination_ids = db.get_destination_ids(destinations_list)
    print("destinations_list:", destinations_list)
    return {"destination_ids": destination_ids,"destinations_list":destinations_list,"recent_tags": recent_tags}
@router.get("/get_destinations_list_by_question/{question}/{top_k}")
async def get_destinations_list_api(question: str, top_k: int):
    """Predict tags for a free-text question and return the closest destinations.

    NOTE(review): this definition reuses the module-level name
    ``get_destinations_list_api`` already bound by the
    ``/get_destinations_list/{question_tags}/{top_k}`` handler. Both routes
    remain registered, because ``@router.get`` registers each function when it
    is defined. However, the earlier function is no longer reachable by name
    from this module. The name is kept for compatibility; consider renaming.

    Parameters:
        question (str): The question text, taken from the URL path.
        top_k (int): Number of destinations to return; must be >= 1.
    Returns:
        dict: ``{"destinations_list": ...}`` as produced by
        ``get_destinations_list``.
    Raises:
        HTTPException: 400 if ``top_k`` is less than 1.
    """
    # Validate before running model inference; the int annotation lets FastAPI
    # reject non-numeric values with 422 instead of a 500 from int().
    k = int(top_k)
    if k < 1:
        raise HTTPException(status_code=400, detail="top_k must be a positive integer")
    original_sentence, question_tags = onnx_predictor.predict(question)
    # Print the sentence and its predicted tags
    print("Sentence:", original_sentence)
    print("Predicted Tags:", question_tags)
    # get_question_vector takes the tags as one space-separated string.
    question_tags_str = " ".join(question_tags)
    question_vector = get_question_vector(question_tags_str)
    destinations_list = get_destinations_list(question_vector, k)
    print("destinations_list:", destinations_list)
    return {"destinations_list": destinations_list}
@router.get("/get_destinations_list_by_question/{question}/{top_k}/{user_id}")
def get_destinations_list_with_user_api(question: str, top_k: int, user_id: str):
    """
    Get a list of destinations based on a question and user-specific weights.

    Besides returning results, this records the predicted tags for the user
    (``track_question_tags``) and adjusts the user's weights
    (``update_weights_from_query``) before ranking destinations.

    Parameters:
        question (str): The question to get destinations for.
        top_k (int): The number of destinations to return; must be >= 1.
        user_id (str): The ID of the user.
    Returns:
        dict: A dictionary containing the list of destinations.
    Raises:
        HTTPException: 400 if top_k is less than 1. This is checked before any
            per-user state is touched.
    """
    # Validate first. Previously an invalid top_k only failed at int() *after*
    # the user's tags and weights had already been updated, leaving state
    # changed by a request that then errored. The int annotation also makes
    # FastAPI reject non-numeric values with 422 up front.
    k = int(top_k)
    if k < 1:
        raise HTTPException(status_code=400, detail="top_k must be a positive integer")
    # Get the prediction
    original_sentence, question_tags = onnx_predictor.predict(question)
    # Print the sentence and its predicted tags
    print("Sentence:", original_sentence)
    print("Predicted Tags:", question_tags)
    # Track the question tags for the user
    track_question_tags(user_id, question_tags)
    # Update weights based on query tags
    update_weights_from_query(user_id, question_tags, feature_names, weights_bias_vector)
    # get_question_vector takes the tags as one space-separated string.
    question_tags_str = " ".join(question_tags)
    question_vector = get_question_vector(question_tags_str)
    destinations_list = get_destinations_list(question_vector, k, user_id)
    print("destinations_list:", destinations_list)
    return {"destinations_list": destinations_list}
@router.get("/users")
def get_users():
    """List every known user.

    Returns:
        dict: ``{"users": ...}`` holding whatever ``get_all_users`` returns.
    """
    return {"users": get_all_users()}
@router.get("/users/{user_id}")
def get_user(user_id: str):
    """Fetch the stored metadata for one user.

    Parameters:
        user_id (str): The ID of the user.
    Returns:
        dict: ``{"metadata": ...}`` holding whatever ``get_user_metadata``
        returns for this user.
    """
    return {"metadata": get_user_metadata(user_id)}
@router.get("/users/{user_id}/weights")
def get_user_weights_api(user_id: str):
    """Return a user's weight vector in JSON-friendly form.

    Parameters:
        user_id (str): The ID of the user.
    Returns:
        dict: ``{"user_id": user_id, "weights": list_or_None}``. ``weights``
        is ``None`` when ``get_user_weights`` returns ``None``.
    """
    weights = get_user_weights(user_id, weights_bias_vector)
    if weights is None:
        serializable = None
    else:
        # The array is not JSON-serializable as-is; emit a plain list instead.
        serializable = weights.tolist()
    return {"user_id": user_id, "weights": serializable}
@router.post("/users/{user_id}/weights")
def update_user_weights_api(user_id: str, request: WeightUpdateRequest):
    """Overwrite selected weights for a user, then optionally store metadata.

    Parameters:
        user_id (str): The ID of the user.
        request (WeightUpdateRequest): Parallel lists of tag indices and new
            weights, plus optional metadata.
    Returns:
        dict: ``{"success": bool_like}``, the result of ``update_user_weights``.
    Raises:
        HTTPException: 400 when the two lists differ in length.
    """
    indices, values = request.tag_indices, request.new_weights
    if len(indices) != len(values):
        raise HTTPException(status_code=400, detail="Tag indices and new weights must have the same length")
    updated = update_user_weights(user_id, indices, values, weights_bias_vector)
    # Metadata is persisted only when the weight update reported success and
    # the caller actually supplied some.
    if updated and request.metadata:
        update_user_metadata(user_id, request.metadata)
    return {"success": updated}
@router.post("/users/{user_id}/feedback")
def record_user_feedback(user_id: str, request: FeedbackRequest):
    """Apply a user's 1-5 star rating of a destination's tag to their weights.

    Parameters:
        user_id (str): The ID of the user.
        request (FeedbackRequest): Destination ID, tag ID and star rating.
    Returns:
        dict: ``{"success": ...}``, the result of
        ``update_weights_from_feedback``.
    Raises:
        HTTPException: 400 when the rating is outside 1..5.
    """
    stars = request.rating
    if not 1 <= stars <= 5:
        raise HTTPException(status_code=400, detail="Rating must be between 1 and 5")
    recorded = update_weights_from_feedback(
        user_id,
        request.destination_id,
        request.tag_id,
        stars,
        weights_bias_vector,
    )
    return {"success": recorded}
@router.get("/tags")
def get_tags():
    """Return every tag name known to the model as ``{"tags": [...]}``."""
    tag_names = feature_names.tolist()
    return {"tags": tag_names}
# The interactive Swagger UI is served at the root path instead of "/docs".
app = FastAPI(docs_url="/")
# NOTE(review): allow_origins=['*'] is combined with allow_credentials=True.
# Browsers refuse a literal "*" origin on credentialed requests. Depending on
# the Starlette version, the middleware may instead echo back the request's
# Origin, which would let any site make credentialed calls to this API.
# Confirm this is intended; prefer an explicit list of trusted origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
    expose_headers=['*',]
)
# Mount all "/model/..." routes defined above.
app.include_router(router)
@app.on_event("startup")
def startup_event():
    """
    Connect to the database when the API starts.

    Opens the shared ``db`` object imported from ``database``; route handlers
    such as ``get_recommendation_destinations`` use it via
    ``db.get_destination_ids``.
    NOTE(review): newer FastAPI releases deprecate ``on_event`` in favour of a
    ``lifespan`` handler. Migrate once the pinned FastAPI version supports it.
    """
    db.connect()
@app.on_event("shutdown")
def shutdown_event():
    """
    Close the database connection when the API shuts down.

    Counterpart of the startup hook that calls ``db.connect()``.
    NOTE(review): like the startup hook, this uses the ``on_event`` API that
    newer FastAPI releases deprecate in favour of ``lifespan``.
    """
    db.close()
if __name__ == "__main__":
    # Serve on all interfaces; the PORT environment variable overrides 7880.
    port = int(os.environ.get("PORT", 7880))
    uvicorn.run(app, host="0.0.0.0", port=port)