import logging
from typing import Any, Dict, Mapping, Optional, Type

from langchain.tools.base import BaseTool
from pydantic import BaseModel, Field

# NOTE(review): these two imports use different package roots
# ("data_sources" vs "src.data_sources"); depending on how the project is
# launched only one form is likely to resolve — confirm and unify.
from data_sources.coin_gecko import CoinGecko
from src.data_sources.cryptocompare import CryptoCompare

logger = logging.getLogger(__name__)


class CryptoCoinPrice(BaseModel):
    """Represents the prices of a coin in various currencies."""

    prices: Dict[str, float] = Field(..., description="Prices in various currencies")


class CryptoCoinPriceData(BaseModel):
    """Encapsulates both CoinGecko and CryptoCompare price data.

    Each mapping is keyed by the requested coin id and is left empty when the
    corresponding source returned no entry for that coin.
    """

    coingecko_price: Dict[str, CryptoCoinPrice] = Field(..., description="CoinGecko prices for various coins")
    crypto_compare_price: Dict[str, CryptoCoinPrice] = Field(..., description="CryptoCompare prices for various coins")


class PriceInput(BaseModel):
    """Arguments accepted by ``CryptoCoinPriceTool``."""

    coin_id: str = Field(..., description="The ID of the cryptocurrency coin to retrieve prices for")
    vs_currency: str = Field("usd", description="The currency to compare against")


class CryptoCoinPriceOutput(BaseModel):
    """Wrapper returned by ``CryptoCoinPriceTool.__call__``."""

    price_data: CryptoCoinPriceData


def _lookup_price(
    raw: Optional[Mapping[str, Any]],
    lookup_key: str,
    coin_id: str,
    source: str,
) -> Dict[str, CryptoCoinPrice]:
    """Wrap ``raw[lookup_key]`` as ``{coin_id: CryptoCoinPrice}``.

    Args:
        raw: Response returned by a data source's ``get_coin_price``; may be
            ``None``/empty.
        lookup_key: Key under which this source stores the coin's prices.
        coin_id: Coin id to use as the key of the returned mapping.
        source: Human-readable source name, used only in the warning log.

    Returns:
        A one-entry mapping, or ``{}`` (after logging a warning) when the
        source has no entry for ``lookup_key``.
    """
    # A falsy response (None or empty) is treated the same as a missing coin
    # instead of letting ``in None`` raise TypeError.
    if not raw or lookup_key not in raw:
        logger.warning("%s data for %s not found.", source, coin_id)
        return {}
    return {coin_id: CryptoCoinPrice(prices=raw[lookup_key])}


class CryptoCoinPriceTool(BaseTool):
    """LangChain tool that fetches a coin's price from CoinGecko and CryptoCompare."""

    # BaseTool is a pydantic model: overriding its fields requires annotations
    # (pydantic v2 rejects unannotated overrides of base-class fields).
    name: str = "CryptoCoinPriceTool"
    description: str = "Fetches price data for a given cryptocurrency coin from CoinGecko and CryptoCompare"
    args_schema: Type[BaseModel] = PriceInput
    return_direct: bool = True

    # Instance attributes must be declared as fields; assigning undeclared
    # attributes on a pydantic model raises at runtime.
    id: Optional[str] = None
    coingecko: Any = None
    crypto_compare: Any = None

    def __init__(self, id: Optional[str] = None, **kwargs: Any) -> None:
        """Create the tool.

        Args:
            id: Optional identifier stored on the instance.
            **kwargs: Forwarded to ``BaseTool``; ``coingecko`` /
                ``crypto_compare`` may be supplied to inject clients,
                otherwise default ``CoinGecko()`` / ``CryptoCompare()``
                instances are created.
        """
        # Initialise pydantic state first; without this the attribute
        # assignments below fail.
        super().__init__(id=id, **kwargs)
        if self.coingecko is None:
            self.coingecko = CoinGecko()
        if self.crypto_compare is None:
            self.crypto_compare = CryptoCompare()

    def _run(self, coin_id: str, vs_currency: str = "usd") -> CryptoCoinPriceData:
        """Query both sources for ``coin_id`` priced in ``vs_currency``.

        The CoinGecko response is looked up by ``coin_id`` as given; the
        CryptoCompare response by ``coin_id.upper()`` (presumably because it
        keys results by upper-case symbol — confirm against the client).
        A source with no entry contributes an empty mapping and logs a warning.
        """
        coingecko_raw = self.coingecko.get_coin_price(ids=[coin_id], vs_currencies=[vs_currency])
        crypto_compare_raw = self.crypto_compare.get_coin_price(ids=[coin_id], vs_currencies=[vs_currency])

        return CryptoCoinPriceData(
            coingecko_price=_lookup_price(coingecko_raw, coin_id, coin_id, "CoinGecko"),
            crypto_compare_price=_lookup_price(crypto_compare_raw, coin_id.upper(), coin_id, "CryptoCompare"),
        )

    def __call__(self, inputs: PriceInput) -> CryptoCoinPriceOutput:
        """Run the tool on a validated ``PriceInput`` and wrap the result."""
        price_data = self._run(inputs.coin_id, inputs.vs_currency)
        return CryptoCoinPriceOutput(price_data=price_data)