| from langchain.tools import BaseTool |
| from pydantic import BaseModel, Field |
| from typing import Optional, Type |
| from sqlalchemy.orm import Session |
| from app.db.models import Material |
|
|
class MaterialCostInput(BaseModel):
    """Input for material cost calculation"""
    # Both fields are required: an Ellipsis default (``default=...``) is
    # pydantic's marker for "no default value".
    material_name: str = Field(default=..., description="The name of the material")
    # No range constraint is declared here, so any float is accepted.
    quantity: float = Field(default=..., description="Quantity of material needed")
|
|
class MaterialCostTool(BaseTool):
    """Tool for calculating material cost based on material name and quantity

    Looks up a ``Material`` row by exact ``name`` match and multiplies its
    ``unit_price`` by the requested quantity. A missing material is reported
    back as an ``"Error: ..."`` string listing the known material names,
    rather than raised, so the calling agent can retry with a valid name.
    """
    name: str = "material_cost"
    description: str = "Calculate material cost by multiplying the unit price by the quantity for a specific material"
    args_schema: Type[BaseModel] = MaterialCostInput
    # Session used for every Material lookup. It is declared as a field so the
    # (pydantic-based) BaseTool accepts it as an attribute.
    db: Session

    def __init__(self, db: Session, **kwargs: object):
        """Initialize with database session

        Args:
            db: Session used to query ``Material`` rows.
            **kwargs: Optional extra ``BaseTool`` settings (e.g. ``callbacks``,
                ``return_direct``), forwarded unchanged to the base class.
        """
        # ``db`` is a required field, so it must go through the base
        # initializer. Calling ``super().__init__()`` without it fails field
        # validation before any later ``self.db = db`` could run.
        super().__init__(db=db, **kwargs)

    def _run(self, material_name: str, quantity: float) -> str:
        """Calculate material cost

        Args:
            material_name: Exact ``Material.name`` to look up.
            quantity: Amount of material, expressed in the material's ``unit``.

        Returns:
            A human-readable cost summary. If no material has that exact name,
            returns an ``"Error: ..."`` message listing the available names.
        """
        material = self.db.query(Material).filter(Material.name == material_name).first()

        if material is None:
            # Fetch only the name column instead of loading every Material row.
            names = [name for (name,) in self.db.query(Material.name).all()]
            return f"Error: No material found with name '{material_name}'. Available materials: {', '.join(names) or 'none'}"

        # A Numeric/DECIMAL price column comes back as decimal.Decimal, and
        # Decimal * float raises TypeError. Converting to float first works for
        # float, int and Decimal prices alike.
        cost = float(material.unit_price) * quantity

        return f"Material cost for {quantity} {material.unit} of {material_name} at {material.unit_price} {material.currency}/{material.unit}: {cost:.2f} {material.currency}"

    async def _arun(self, material_name: str, quantity: float) -> str:
        """Async implementation of material cost calculation

        Delegates directly to the synchronous ``_run``. The database queries
        therefore run on the event-loop thread and block it until they finish.
        """
        return self._run(material_name=material_name, quantity=quantity)