app / ai_functions.py
lemdaddy's picture
first commit
9c400b9
raw
history blame
2.64 kB
import logging
from typing import Any, Dict, Optional, Type

from pydantic import BaseModel, Field
from langchain.tools.base import BaseTool

from data_sources.coin_gecko import CoinGecko
from src.data_sources.cryptocompare import CryptoCompare
class CryptoCoinPrice(BaseModel):
    """Represents the prices of a coin in various currencies."""

    # Required (no default). Keys are presumably currency codes (e.g. "usd")
    # and values the quoted price in that currency -- confirm against the
    # data-source clients that populate it.
    prices: Dict[str, float] = Field(
        default=...,
        description="Prices in various currencies",
    )
class CryptoCoinPriceData(BaseModel):
    """Encapsulates both CoinGecko and CryptoCompare price data."""

    # Both fields are required. Each maps a coin identifier (str) to the
    # CryptoCoinPrice reported by that particular source; a source with no
    # data for a coin may simply contribute an empty mapping.
    coingecko_price: Dict[str, CryptoCoinPrice] = Field(
        default=...,
        description="CoinGecko prices for various coins",
    )
    crypto_compare_price: Dict[str, CryptoCoinPrice] = Field(
        default=...,
        description="CryptoCompare prices for various coins",
    )
class PriceInput(BaseModel):
    # Arguments for a single price lookup.
    # coin_id is mandatory; it is the identifier sent to the price sources.
    coin_id: str = Field(
        default=...,
        description="The ID of the cryptocurrency coin to retrieve prices for",
    )
    # vs_currency is optional and falls back to "usd" when omitted.
    vs_currency: str = Field(
        default="usd",
        description="The currency to compare against",
    )
class CryptoCoinPriceOutput(BaseModel):
    # Envelope returned by the tool's __call__: a single required field
    # carrying the combined per-source price data.
    price_data: CryptoCoinPriceData = Field(default=...)
class CryptoCoinPriceTool(BaseTool):
    """LangChain tool that looks up a coin's price on CoinGecko and CryptoCompare.

    Both sources are queried for ``coin_id`` against ``vs_currency``. Whatever
    each source reports for the coin is wrapped in ``CryptoCoinPrice`` and the
    two results are bundled into a single ``CryptoCoinPriceData``. A source
    with no entry for the coin contributes an empty mapping and a warning is
    logged.
    """

    # BaseTool is a pydantic model, so overriding its fields requires type
    # annotations; un-annotated class attributes are rejected/ignored.
    name: str = "CryptoCoinPriceTool"
    description: str = (
        "Fetches price data for a given cryptocurrency coin "
        "from CoinGecko and CryptoCompare"
    )
    args_schema: Type[BaseModel] = PriceInput
    return_direct: bool = True

    # Instance state must be declared as fields: assigning undeclared
    # attributes on a pydantic model raises at runtime.
    id: Optional[str] = None
    coingecko: Any = None
    crypto_compare: Any = None

    def __init__(self, id: Optional[str] = None, **kwargs: Any):
        """Create the tool, building default data-source clients if needed.

        Args:
            id: Optional identifier stored on the instance as ``self.id``.
            **kwargs: Forwarded to ``BaseTool``. ``coingecko`` and
                ``crypto_compare`` may be supplied here to inject clients;
                when absent (or None), a new ``CoinGecko()`` /
                ``CryptoCompare()`` is created.
        """
        if kwargs.get("coingecko") is None:
            kwargs["coingecko"] = CoinGecko()
        if kwargs.get("crypto_compare") is None:
            kwargs["crypto_compare"] = CryptoCompare()
        # pydantic models must be initialised through super().__init__ so
        # field validation and internal bookkeeping happen.
        super().__init__(id=id, **kwargs)

    def _run(self, coin_id: str, vs_currency: str = "usd") -> CryptoCoinPriceData:
        """Fetch the price of ``coin_id`` from both sources.

        Args:
            coin_id: Coin identifier passed unchanged to both clients. The
                CoinGecko response is looked up under ``coin_id`` and the
                CryptoCompare response under ``coin_id.upper()``.
            vs_currency: Currency to quote against; defaults to "usd".

        Returns:
            CryptoCoinPriceData whose two mappings are keyed by ``coin_id``
            exactly as given, and are empty for a source that returned no
            data for the coin.
        """
        logger = logging.getLogger(__name__)

        coingecko_price_data = self.coingecko.get_coin_price(
            ids=[coin_id], vs_currencies=[vs_currency]
        )
        crypto_compare_price_data = self.crypto_compare.get_coin_price(
            ids=[coin_id], vs_currencies=[vs_currency]
        )

        coingecko_price: Dict[str, CryptoCoinPrice] = {}
        crypto_compare_price: Dict[str, CryptoCoinPrice] = {}

        # `or {}` tolerates a client returning None instead of a mapping;
        # a None entry is treated the same as a missing one.
        coingecko_prices = (coingecko_price_data or {}).get(coin_id)
        if coingecko_prices is not None:
            coingecko_price[coin_id] = CryptoCoinPrice(prices=coingecko_prices)
        else:
            logger.warning("CoinGecko data for %s not found.", coin_id)

        # The CryptoCompare response is looked up by the upper-cased id, but
        # the result is stored under the caller's original coin_id so both
        # mappings share the same key.
        crypto_compare_prices = (crypto_compare_price_data or {}).get(coin_id.upper())
        if crypto_compare_prices is not None:
            crypto_compare_price[coin_id] = CryptoCoinPrice(prices=crypto_compare_prices)
        else:
            logger.warning("CryptoCompare data for %s not found.", coin_id)

        return CryptoCoinPriceData(
            coingecko_price=coingecko_price,
            crypto_compare_price=crypto_compare_price,
        )

    def __call__(self, inputs: PriceInput) -> CryptoCoinPriceOutput:
        """Run the lookup for a validated ``PriceInput`` and wrap the result.

        Args:
            inputs: Parsed tool arguments (``coin_id`` and ``vs_currency``).

        Returns:
            CryptoCoinPriceOutput holding the combined per-source price data.
        """
        price_data = self._run(inputs.coin_id, inputs.vs_currency)
        return CryptoCoinPriceOutput(price_data=price_data)