| from pydantic import BaseModel, Field | |
| from typing import Dict, Optional, Type | |
| from data_sources.coin_gecko import CoinGecko | |
| from src.data_sources.cryptocompare import CryptoCompare | |
| from langchain.tools.base import BaseTool | |
class CryptoCoinPrice(BaseModel):
    """Represents the prices of a coin in various currencies."""

    # The docstring above is kept verbatim: pydantic publishes a model's
    # docstring as the "description" of its generated schema.

    # Required mapping of str -> float; presumably currency code -> price
    # (e.g. {"usd": 1.0}) -- confirm against the data-source clients.
    prices: Dict[str, float] = Field(
        default=...,
        description="Prices in various currencies",
    )
class CryptoCoinPriceData(BaseModel):
    """Encapsulates both CoinGecko and CryptoCompare price data."""

    # The docstring above is kept verbatim: pydantic publishes a model's
    # docstring as the "description" of its generated schema.

    # Both fields are required and map a coin key to that coin's prices.
    # An empty dict is a valid value and is what the tool supplies when a
    # provider returned nothing for the requested coin.
    coingecko_price: Dict[str, CryptoCoinPrice] = Field(
        default=...,
        description="CoinGecko prices for various coins",
    )
    crypto_compare_price: Dict[str, CryptoCoinPrice] = Field(
        default=...,
        description="CryptoCompare prices for various coins",
    )
class PriceInput(BaseModel):
    # Argument schema for CryptoCoinPriceTool (wired in through `args_schema`).
    # Documented with comments rather than a docstring so the generated
    # schema stays exactly as it was.

    # Required: identifier of the coin to look up.
    coin_id: str = Field(
        default=...,
        description="The ID of the cryptocurrency coin to retrieve prices for",
    )
    # Optional: quote currency, "usd" when omitted.
    vs_currency: str = Field(
        default="usd",
        description="The currency to compare against",
    )
class CryptoCoinPriceOutput(BaseModel):
    # Thin envelope returned by CryptoCoinPriceTool.__call__.
    # Documented with comments rather than a docstring so the generated
    # schema stays exactly as it was.

    # Required: the combined per-provider price data.
    price_data: CryptoCoinPriceData
class CryptoCoinPriceTool(BaseTool):
    """LangChain tool that fetches one coin's price from CoinGecko and CryptoCompare.

    ``BaseTool`` is a pydantic model, so everything stored on the instance is
    declared as a field below and populated through ``super().__init__``;
    assigning undeclared attributes on an uninitialised model fails.
    """

    # Annotated so the overrides of the base-class fields are also accepted
    # by pydantic v2, which rejects non-annotated overrides.
    name: str = "CryptoCoinPriceTool"
    description: str = "Fetches price data for a given cryptocurrency coin from CoinGecko and CryptoCompare"
    args_schema: Type[BaseModel] = PriceInput
    return_direct: bool = True

    # Instance state. The client fields default to None only so that pydantic
    # has a default; __init__ always fills them in.
    # NOTE(review): relies on BaseTool's model config allowing arbitrary
    # (non-pydantic) field types -- confirm for the installed LangChain version.
    id: Optional[str] = None
    coingecko: Optional[CoinGecko] = None
    crypto_compare: Optional[CryptoCompare] = None

    def __init__(self, id: Optional[str] = None):
        """Create the tool and one client per price provider.

        Args:
            id: Optional identifier stored on the tool as ``self.id``.
        """
        super().__init__(
            id=id,
            coingecko=CoinGecko(),
            crypto_compare=CryptoCompare(),
        )

    @staticmethod
    def _collect(raw, lookup_key: str, coin_id: str, source: str) -> Dict[str, CryptoCoinPrice]:
        """Return ``{coin_id: CryptoCoinPrice}`` if ``lookup_key`` is in ``raw``, else ``{}`` after printing a warning."""
        if lookup_key in raw:
            return {coin_id: CryptoCoinPrice(prices=raw[lookup_key])}
        print(f"Warning: {source} data for {coin_id} not found.")
        return {}

    def _run(self, coin_id: str, vs_currency: str = "usd") -> CryptoCoinPriceData:
        """Query both providers for ``coin_id`` priced in ``vs_currency``.

        Args:
            coin_id: Coin identifier, passed to both providers as given.
            vs_currency: Quote currency; defaults to ``"usd"``.

        Returns:
            A ``CryptoCoinPriceData`` whose two mappings are keyed by
            ``coin_id``. A mapping is empty when that provider's response did
            not contain the coin; a warning is printed in that case.
        """
        coingecko_price_data = self.coingecko.get_coin_price(
            ids=[coin_id], vs_currencies=[vs_currency]
        )
        crypto_compare_price_data = self.crypto_compare.get_coin_price(
            ids=[coin_id], vs_currencies=[vs_currency]
        )

        # The CryptoCompare response is looked up by the upper-cased id
        # (presumably it keys results by ticker symbol -- verify against the
        # client), but the result is still stored under the caller's coin_id.
        return CryptoCoinPriceData(
            coingecko_price=self._collect(
                coingecko_price_data, coin_id, coin_id, "CoinGecko"
            ),
            crypto_compare_price=self._collect(
                crypto_compare_price_data, coin_id.upper(), coin_id, "CryptoCompare"
            ),
        )

    def __call__(self, inputs: PriceInput) -> CryptoCoinPriceOutput:
        """Run the lookup for a validated ``PriceInput`` and wrap the result.

        Calls ``_run`` directly with the input's ``coin_id`` and
        ``vs_currency``.
        """
        price_data = self._run(inputs.coin_id, inputs.vs_currency)
        return CryptoCoinPriceOutput(price_data=price_data)