|
|
from langchain.tools import BaseTool |
|
|
from pydantic import BaseModel, Field |
|
|
from typing import Optional, Type |
|
|
from sqlalchemy.orm import Session |
|
|
from app.db.models import MarginProfile |
|
|
|
|
|
class MarginInput(BaseModel):
    """Input for margin calculation"""

    # Both fields are required (``...`` as the default marks them mandatory).
    # The ``description`` strings are runtime data attached to each field, so
    # they are kept verbatim.

    # Base amount the margin is applied to.
    total_cost: float = Field(
        ...,
        description="The total cost to apply margin to",
    )

    # Name of the margin profile to use.
    profile_name: str = Field(
        ...,
        description="The margin profile name to use",
    )
|
|
|
|
|
class MarginTool(BaseTool):
    """Tool that applies a named margin profile to a total cost.

    The tool looks up a ``MarginProfile`` row by ``profile_name`` through the
    supplied database session and marks the cost up by that profile's
    ``margin_percentage``. The stored value is treated as a fraction: it is
    added to 1 to form the multiplier and multiplied by 100 only for display.

    Results and lookup failures are both reported as plain strings rather
    than raised.
    """

    name: str = "apply_margin"
    description: str = "Apply a margin to a total cost based on a margin profile"
    args_schema: Type[BaseModel] = MarginInput
    # Database session used for every MarginProfile lookup.
    db: Session

    def __init__(self, db: Session):
        """Initialize with database session.

        Args:
            db: Session used to query ``MarginProfile`` rows.
        """
        # ``db`` is a declared, required field of this model, so it has to go
        # through the base initializer; calling ``super().__init__()`` with no
        # arguments fails validation before the attribute can be assigned.
        super().__init__(db=db)

    def _run(self, total_cost: float, profile_name: str) -> str:
        """Apply margin to total cost.

        Args:
            total_cost: Amount the margin is applied to.
            profile_name: Value matched against ``MarginProfile.profile_name``.

        Returns:
            A sentence with the applied percentage, the final price and the
            margin amount, or a message starting with ``"Error:"`` when the
            profile does not exist or has no margin percentage.
        """
        margin_profile = (
            self.db.query(MarginProfile)
            .filter(MarginProfile.profile_name == profile_name)
            .first()
        )

        if not margin_profile:
            # List the known profile names so the caller can retry with one.
            available = ", ".join(
                mp.profile_name for mp in self.db.query(MarginProfile).all()
            )
            return (
                f"Error: No margin profile found with name '{profile_name}'. "
                f"Available profiles: {available}"
            )

        raw_rate = margin_profile.margin_percentage
        if raw_rate is None:
            # Without a rate the arithmetic below would raise TypeError.
            return (
                f"Error: Margin profile '{profile_name}' has no margin "
                f"percentage configured."
            )

        # NOTE(review): the column type is not visible here. Converting once
        # keeps the float arithmetic valid if the ORM returns Decimal values.
        rate = float(raw_rate)

        final_price = total_cost * (1 + rate)
        margin_amount = final_price - total_cost
        margin_percentage = rate * 100

        return (
            f"Applied {margin_percentage:.1f}% margin to {total_cost:.2f}: "
            f"Final price is {final_price:.2f} "
            f"(margin amount: {margin_amount:.2f})"
        )

    async def _arun(self, total_cost: float, profile_name: str) -> str:
        """Async implementation of margin application.

        Delegates directly to ``_run``; the database query executes inline in
        the calling coroutine rather than being offloaded.
        """
        return self._run(total_cost=total_cost, profile_name=profile_name)