Spaces:
Sleeping
Sleeping
File size: 4,366 Bytes
478dec6 1249d8b 478dec6 1249d8b 478dec6 1249d8b 478dec6 1249d8b 478dec6 1249d8b 478dec6 1249d8b 478dec6 56dd37e 478dec6 1249d8b 478dec6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 | import os, sys
from fastapi import HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from externals.databases.pg_models import CVWeight
from externals.databases.pg_crud import (
create_filter,
get_filter_by_id,
create_weight,
get_weight_by_id
)
from utils.logger import get_logger
# Module-level logger used by AgenticWeightService for its info/error messages.
logger = get_logger("weight agentic service")
class AgenticWeightService:
    """Create and fetch CV scoring weights that are attached to an existing filter.

    A stored weight row mirrors its filter: it holds one weight value per
    weightable attribute of the filter identified by ``criteria_id``.
    """

    # Attributes that are bookkeeping rather than weightable criteria
    # (``_sa_instance_state`` is SQLAlchemy's per-instance state object).
    _EXCLUDED_FIELDS = frozenset(
        {"_sa_instance_state", "created_at", "criteria_id", "weight_id"}
    )
    # Weight given to a field the caller left as ``None`` when the supplied
    # weights already sum to a value in (0, 1].
    _MISSING_WEIGHT_DEFAULT = 0.01
    # Slack on the "total <= 1" check so that float rounding in the sum cannot
    # push an intended total of exactly 1.0 just over the limit.
    _TOTAL_TOLERANCE = 1e-9

    def __init__(self, db: AsyncSession, user):
        """Bind the service to a database session.

        Args:
            db: Async SQLAlchemy session passed to the CRUD helpers.
            user: Stored on the instance; not used by the methods in this class.
        """
        self.db = db
        self.user = user

    @classmethod
    def _weighted_field_names(cls, criteria_filter) -> list:
        """Return the filter attribute names that should each receive a weight."""
        # NOTE(review): ``vars()`` / ``__dict__`` only contains attributes that are
        # currently loaded on the ORM instance; an expired or deferred column
        # would be silently skipped here. Confirm this is acceptable.
        return [
            name
            for name in vars(criteria_filter)
            if name not in cls._EXCLUDED_FIELDS
        ]

    async def create_weight(self, weight: CVWeight) -> dict:
        """Validate ``weight`` against its filter and persist a completed copy.

        The weightable fields are the attributes of the filter identified by
        ``weight.criteria_id``, minus the bookkeeping columns in
        ``_EXCLUDED_FIELDS``. A new ``CVWeight`` is built with one value per
        such field:

        * If the supplied (non-``None``) weights sum to 0, every field gets an
          equal share ``1 / n_fields``.
        * If they sum to a value in (0, 1], supplied values are kept and fields
          left as ``None`` get ``_MISSING_WEIGHT_DEFAULT``.
        * Otherwise the request is rejected.

        Args:
            weight: Incoming weights. Its ``criteria_id`` and ``weight_id`` are
                copied onto the stored row.

        Returns:
            ``{"criteria_id": ..., "weight_id": ...}`` of the created row.

        Raises:
            HTTPException: 404 if the filter is not found. 400 if a supplied
                weight is outside [0, 1] or the total exceeds 1. 500 if the
                filter has no weightable fields, or on any other failure.
        """
        try:
            criteria_filter = await get_filter_by_id(self.db, weight.criteria_id)
            # Assumes get_filter_by_id returns None for an unknown id (confirm in
            # pg_crud). Previously this surfaced as an AttributeError wrapped in a 500.
            if criteria_filter is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Create weight error: filter {weight.criteria_id} not found",
                )

            field_names = self._weighted_field_names(criteria_filter)
            logger.debug(f"weighted_field_name: {field_names}")
            if not field_names:
                # Nothing to weight; the equal-share branch would divide by zero.
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=(
                        f"Create weight error: filter {weight.criteria_id} "
                        "has no weightable fields"
                    ),
                )

            requested = {name: getattr(weight, name) for name in field_names}
            provided = {name: w for name, w in requested.items() if w is not None}

            # Validate each value, not just the total: e.g. [-0.5, 1.0] sums to
            # 0.5 but is still invalid.
            out_of_range = {
                name: w for name, w in provided.items() if not 0 <= w <= 1
            }
            if out_of_range:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=(
                        "Create weight error: each weight value must be between "
                        f"0 and 1. Out-of-range weights: {out_of_range}"
                    ),
                )

            total_weight = sum(provided.values())
            logger.debug(f"all_weight: {list(provided.values())}")

            new_weight = CVWeight(
                criteria_id=weight.criteria_id,
                weight_id=weight.weight_id,
            )
            if total_weight == 0:
                equal_share = 1 / len(field_names)
                for name in field_names:
                    setattr(new_weight, name, equal_share)
            elif total_weight <= 1.0 + self._TOTAL_TOLERANCE:
                # Only ``None`` is replaced by the default; an explicit 0 is kept.
                # NOTE(review): defaults are added after the total check, so the
                # stored weights may sum to slightly more than 1.
                for name in field_names:
                    setattr(
                        new_weight,
                        name,
                        provided.get(name, self._MISSING_WEIGHT_DEFAULT),
                    )
            else:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Create weight error: the weight value must be between 0 and 1, and the total weight must be less than or equal to 1. The total weight you provided is {total_weight}",
                )

            # Module-level CRUD helper (imported from pg_crud), not this method.
            created_weight = await create_weight(self.db, new_weight)
            logger.info(f"Weight created: {created_weight.weight_id}")
            return {
                "criteria_id": created_weight.criteria_id,
                "weight_id": created_weight.weight_id,
            }
        except HTTPException:
            # Already carries the intended status and detail; re-wrapping it
            # would turn every client error into a 500 with a mangled message.
            raise
        except Exception as exc:
            tb = exc.__traceback__
            location = f"{os.path.basename(tb.tb_frame.f_code.co_filename)}:{tb.tb_lineno}"
            logger.error(
                f"❌ create weight error: {exc} ({type(exc).__name__} at {location})"
            )
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Create weight error: {exc}",
            ) from exc

    async def get_weight_by_weight_id(self, weight_id: str) -> CVWeight:
        """Fetch a stored weight by its ``weight_id``.

        Returns:
            Whatever ``get_weight_by_id`` returns (presumably the ``CVWeight``
            row, or ``None`` when absent; confirm in pg_crud). This is best-effort:
            if the lookup raises, the error is logged and an empty ``dict`` is
            returned instead of propagating the exception.
        """
        try:
            return await get_weight_by_id(self.db, weight_id=weight_id)
        except Exception as exc:
            logger.error(f"get weight by weight id error, {exc}")
            return {}