from typing import Optional, Type

from langchain.tools import BaseTool
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

from app.db.models import MarginProfile


class MarginInput(BaseModel):
    """Input for margin calculation."""

    total_cost: float = Field(..., description="The total cost to apply margin to")
    profile_name: str = Field(..., description="The margin profile name to use")


class MarginTool(BaseTool):
    """Tool for applying a margin to a total cost.

    The margin is looked up by profile name in the ``MarginProfile`` table
    through the SQLAlchemy session supplied at construction time.
    """

    name: str = "apply_margin"
    description: str = "Apply a margin to a total cost based on a margin profile"
    args_schema: Type[BaseModel] = MarginInput
    # Database session used for every MarginProfile lookup.
    db: Session

    def __init__(self, db: Session):
        """Initialize the tool with a database session.

        Args:
            db: SQLAlchemy session used to query margin profiles.
        """
        # ``db`` is a required model field, so it has to be handed to the
        # base-class initializer; calling ``super().__init__()`` without it
        # fails field validation before the attribute can be assigned.
        super().__init__(db=db)

    def _run(self, total_cost: float, profile_name: str) -> str:
        """Apply the named profile's margin to ``total_cost``.

        Args:
            total_cost: The cost the margin is applied to.
            profile_name: Name of the margin profile to look up.

        Returns:
            A message with the applied percentage, the final price and the
            margin amount, or an error message listing the available profile
            names when no profile matches ``profile_name``.
        """
        # Find the margin profile by name.
        margin_profile = (
            self.db.query(MarginProfile)
            .filter(MarginProfile.profile_name == profile_name)
            .first()
        )

        if margin_profile is None:
            # List the known profile names so the caller can retry with one.
            available = ", ".join(
                mp.profile_name for mp in self.db.query(MarginProfile).all()
            )
            return (
                f"Error: No margin profile found with name '{profile_name}'. "
                f"Available profiles: {available}"
            )

        # ``margin_percentage`` is used as a fraction here (it is added to 1
        # for the multiplier and scaled by 100 only for display).
        final_price = total_cost * (1 + margin_profile.margin_percentage)
        margin_amount = final_price - total_cost
        margin_percentage = margin_profile.margin_percentage * 100

        return (
            f"Applied {margin_percentage:.1f}% margin to {total_cost:.2f}: "
            f"Final price is {final_price:.2f} (margin amount: {margin_amount:.2f})"
        )

    async def _arun(self, total_cost: float, profile_name: str) -> str:
        """Async entry point; delegates directly to the synchronous ``_run``.

        The database query therefore executes inline in the calling coroutine
        rather than being awaited.
        """
        return self._run(total_cost=total_cost, profile_name=profile_name)