File size: 4,366 Bytes
478dec6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1249d8b
478dec6
 
 
 
 
1249d8b
478dec6
 
1249d8b
478dec6
 
1249d8b
 
 
 
478dec6
1249d8b
478dec6
1249d8b
 
478dec6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56dd37e
 
 
 
 
 
 
 
478dec6
 
1249d8b
 
 
 
 
 
478dec6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os, sys

from fastapi import HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession

from externals.databases.pg_models import CVWeight
from externals.databases.pg_crud import (
    create_filter,
    get_filter_by_id,
    create_weight,
    get_weight_by_id
)
from utils.logger import get_logger

# Module-level logger used by AgenticWeightService below; the string is the
# name handed to utils.logger.get_logger.
logger = get_logger("weight agentic service")


class AgenticWeightService:
    """Service layer for creating and fetching CV weight records.

    Wraps the CRUD helpers imported from ``externals.databases.pg_crud``
    (``get_filter_by_id``, ``create_weight``, ``get_weight_by_id``) using the
    async database session supplied by the caller.
    """

    # Attributes of the filter ORM object that are identity/bookkeeping data
    # (or SQLAlchemy internals) rather than criteria that receive a weight.
    _EXCLUDED_FIELDS = frozenset(
        {"_sa_instance_state", "created_at", "criteria_id", "weight_id"}
    )

    # Slack allowed when comparing the summed weights against 1.0, so inputs
    # such as 0.1 + 0.2 + 0.7 (which sums to 1.0000000000000002 in binary
    # floating point) are not rejected as exceeding 1.
    _WEIGHT_SUM_TOLERANCE = 1e-9

    def __init__(self, db: AsyncSession, user):
        self.db = db
        self.user = user

    @classmethod
    def _weighted_field_names(cls, filter_obj) -> list:
        """Return the names of ``filter_obj``'s attributes that should be weighted."""
        # NOTE(review): an ORM instance's ``__dict__`` only holds attributes
        # that are currently loaded; expired/deferred columns would be skipped.
        # Confirm get_filter_by_id returns a fully loaded row.
        return [name for name in vars(filter_obj) if name not in cls._EXCLUDED_FIELDS]

    async def create_weight(self, weight: CVWeight) -> dict:
        """Create and persist a weight record for the criteria in ``weight``.

        The fields to weight are taken from the filter row identified by
        ``weight.criteria_id``. The weights are then resolved as follows:

        * If the provided weights sum to 0 (including when all are ``None``),
          every weighted field receives an equal share ``1 / n``.
        * If the sum is in (0, 1] (with a small float tolerance), the
          provided weights are kept and any ``None`` weight becomes 0.01.
        * Otherwise the request is rejected.

        Args:
            weight: Incoming weight values. Its ``criteria_id`` and
                ``weight_id`` are copied onto the created record.

        Returns:
            dict with the created record's ``criteria_id`` and ``weight_id``.

        Raises:
            HTTPException: 404 if no filter is found for ``weight.criteria_id``.
                500 if the weights sum outside the accepted range, or if any
                other error occurs while building or saving the record.
        """
        try:
            filter_obj = await get_filter_by_id(self.db, weight.criteria_id)
            if filter_obj is None:
                # Presumably the CRUD helper returns None for a missing row;
                # report that clearly instead of failing on attribute access.
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Create weight error: filter {weight.criteria_id} not found",
                )

            weighted_field_names = self._weighted_field_names(filter_obj)
            logger.debug(f"weighted_field_name: {weighted_field_names}")

            # Read each incoming weight once; None means "not provided".
            provided = {name: getattr(weight, name) for name in weighted_field_names}
            total_weight = sum(w for w in provided.values() if w is not None)
            logger.debug(f"total provided weight: {total_weight}")

            new_weight = CVWeight(
                criteria_id=weight.criteria_id,
                weight_id=weight.weight_id,
            )

            if total_weight == 0:
                # Nothing usable provided: spread the weight evenly. The guard
                # avoids ZeroDivisionError when the filter has no weightable fields.
                if weighted_field_names:
                    default_w = 1 / len(weighted_field_names)
                    for name in weighted_field_names:
                        setattr(new_weight, name, default_w)
            elif 0 < total_weight <= 1.0 + self._WEIGHT_SUM_TOLERANCE:
                for name, value in provided.items():
                    # A missing (None) weight falls back to 0.01; an explicit 0 is kept.
                    setattr(new_weight, name, value if value is not None else 0.01)
            else:
                # NOTE(review): this is a client input error, so 400/422 would be
                # more accurate; kept as 500 to preserve the existing API contract.
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=f"Create weight error: the weight value must be between 0 and 1, and the total weight must be less than or equal to 1. The total weight you provided is {total_weight}"
                )

            created_weight = await create_weight(self.db, new_weight)
            logger.info(f"Weight created: {created_weight.weight_id}")
            return {
                "criteria_id": created_weight.criteria_id,
                "weight_id": created_weight.weight_id,
            }
        except HTTPException as exc:
            # Already carries an accurate status and detail; re-raise it as-is
            # instead of wrapping it in a second "Create weight error" 500.
            logger.error(f"❌ create weight error: {exc.detail}")
            raise
        except Exception as E:
            # logger.exception records the full traceback (file and line included).
            logger.exception(f"❌ create weight error: {E}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Create weight error: {E}"
            ) from E

    async def get_weight_by_weight_id(self, weight_id: str) -> CVWeight:
        """Fetch a weight record by its ``weight_id``.

        Returns:
            The value returned by ``get_weight_by_id`` (presumably the CVWeight
            row, or None when absent; TODO confirm). If the lookup raises, the
            error is logged and an empty dict is returned instead.
        """
        try:
            return await get_weight_by_id(self.db, weight_id=weight_id)
        except Exception as E:
            logger.error(f"get weight by weight id error, {E}")
            # NOTE(review): {} contradicts the CVWeight annotation; kept because
            # existing callers may rely on this fallback value.
            return {}