Spaces:
Sleeping
Sleeping
| import os, sys | |
| from fastapi import HTTPException, status | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from externals.databases.pg_models import CVWeight | |
| from externals.databases.pg_crud import ( | |
| create_filter, | |
| get_filter_by_id, | |
| create_weight, | |
| get_weight_by_id | |
| ) | |
| from utils.logger import get_logger | |
# Module-level logger used by AgenticWeightService for info and error reporting.
| logger = get_logger("weight agentic service") | |
class AgenticWeightService:
    """Service layer for creating and reading ``CVWeight`` rows.

    Wraps the ``pg_crud`` filter/weight helpers with input validation and
    HTTP-oriented error reporting.
    """

    # Attributes of the criteria row that are bookkeeping, not weightable fields.
    _EXCLUDED_FIELDS = frozenset(
        {"_sa_instance_state", "created_at", "criteria_id", "weight_id"}
    )
    # Weight stored for a field left as None while other fields are weighted.
    _MISSING_FIELD_WEIGHT = 0.01
    # Slack for float rounding so weights that add up to 1 are not rejected
    # because their binary sum lands a hair above 1.0.
    _TOTAL_TOLERANCE = 1e-9

    def __init__(self, db: AsyncSession, user):
        """Store the database session and the acting user.

        Args:
            db: Async session passed to every ``pg_crud`` call.
            user: Caller identity; kept on the instance, not read by the
                methods of this class.
        """
        self.db = db
        self.user = user

    async def create_weight(self, weight: CVWeight) -> dict:
        """Validate the requested weights and persist a new ``CVWeight`` row.

        The weightable field names are the attributes of the criteria row
        referenced by ``weight.criteria_id`` (minus bookkeeping attributes).
        The stored values are then chosen as follows:

        * provided weights sum to 0 (or none provided): every field gets an
          equal share ``1 / number_of_fields``;
        * provided weights sum to a value in ``(0, 1]``: they are stored as
          given, and fields left as ``None`` get 0.01;
        * anything else: the request is rejected.

        Args:
            weight: Requested weights. Its ``criteria_id`` and ``weight_id``
                are copied onto the stored row.

        Returns:
            ``{"criteria_id": ..., "weight_id": ...}`` taken from the row
            returned by the persistence helper.

        Raises:
            HTTPException: 404 when the criteria lookup returns ``None``;
                400 when the total weight is negative or greater than 1;
                500 for any other failure.
        """
        try:
            criteria = await get_filter_by_id(self.db, weight.criteria_id)
            if criteria is None:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=(
                        "Create weight error: criteria "
                        f"{weight.criteria_id} not found"
                    ),
                )

            # NOTE(review): this reads the instance __dict__, so only
            # attributes currently loaded on the row are seen — confirm the
            # lookup helper always returns a fully loaded object.
            field_names = [
                name for name in vars(criteria)
                if name not in self._EXCLUDED_FIELDS
            ]
            logger.debug(f"weighted_field_name: {field_names}")
            if not field_names:
                # Without this guard the equal-share branch divides by zero.
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=(
                        "Create weight error: the criteria has no "
                        "weightable fields"
                    ),
                )

            new_weight = CVWeight(
                criteria_id=weight.criteria_id,
                weight_id=weight.weight_id,
            )
            resolved = self._resolve_weights(weight, field_names)
            for field_name, value in resolved.items():
                setattr(new_weight, field_name, value)

            # Inside a method body this name resolves to the module-level
            # ``create_weight`` helper, not to this method.
            created_weight = await create_weight(self.db, new_weight)
            logger.info(f"Weight created: {created_weight.weight_id}")
            return {
                "criteria_id": created_weight.criteria_id,
                "weight_id": created_weight.weight_id,
            }
        except HTTPException:
            # Deliberate HTTP errors already carry the right status and
            # detail; let them through instead of re-wrapping them as a 500.
            raise
        except Exception as exc:
            # logger.exception records the traceback alongside the message.
            logger.exception(f"❌ create weight error: {exc}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Create weight error: {exc}",
            ) from exc

    @classmethod
    def _resolve_weights(cls, weight, field_names) -> dict:
        """Map each field name to the weight value that should be stored.

        Args:
            weight: Object carrying the requested weight per field name.
            field_names: Non-empty list of weightable field names.

        Returns:
            Dict of field name to the value to store.

        Raises:
            HTTPException: 400 when the provided total is negative or
                greater than 1 (beyond float-rounding tolerance).
        """
        requested = {name: getattr(weight, name) for name in field_names}
        total_weight = sum(
            value for value in requested.values() if value is not None
        )
        logger.debug(f"all_weight total: {total_weight}")

        if total_weight == 0:
            # Nothing was weighted: spread the weight evenly over all fields.
            return dict.fromkeys(field_names, 1 / len(field_names))

        if 0 < total_weight <= 1.0 + cls._TOTAL_TOLERANCE:
            # Only None is replaced by the default; an explicit 0 is kept.
            # NOTE(review): filling None fields can push the stored total
            # slightly above 1; values are not re-normalized — confirm that
            # is intended.
            return {
                name: cls._MISSING_FIELD_WEIGHT if value is None else value
                for name, value in requested.items()
            }

        # Invalid caller input is a client error, hence 400 rather than 500.
        # NOTE(review): individual values are not range-checked, only the
        # total is.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Create weight error: the weight value must be between 0 and 1, and the total weight must be less than or equal to 1. The total weight you provided is {total_weight}",
        )

    async def get_weight_by_weight_id(self, weight_id: str) -> CVWeight:
        """Fetch a stored weight by its id.

        Args:
            weight_id: Identifier passed to the ``get_weight_by_id`` helper.

        Returns:
            Whatever ``get_weight_by_id`` returns. If the lookup raises, the
            error is logged and an empty dict ``{}`` is returned instead, so
            callers must be prepared for that fallback.
        """
        try:
            return await get_weight_by_id(self.db, weight_id=weight_id)
        except Exception as exc:
            # Best-effort read: keep the original empty-dict fallback, but
            # record the traceback so failures are diagnosable.
            logger.exception(f"get weight by weight id error, {exc}")
            return {}