Sadeep Sachintha commited on
Commit ·
fb1fe87
1
Parent(s): a0b8ee5
feat: implement database schema, async session management, and bot command handlers for FlyRates
Browse files- bot/handlers.py +121 -4
- core/config.py +11 -0
- db/models.py +4 -4
- db/session.py +5 -0
bot/handlers.py
CHANGED
|
@@ -28,10 +28,7 @@ async def cmd_start(message: types.Message):
|
|
| 28 |
f"Welcome to FlyRates! 🌍💸\n\n"
|
| 29 |
f"I can track real-time exchange rates for you.\n"
|
| 30 |
f"Supported Currencies: {currencies_str}\n\n"
|
| 31 |
-
f"
|
| 32 |
-
f"/current <base> <target> - Get live rate (e.g., /current USD EUR)\n"
|
| 33 |
-
f"/subscribe <base> <target> <daily/hourly> - Get automated updates\n"
|
| 34 |
-
f"/threshold <base> <target> << or >> <value> - Get alerts (e.g., /threshold USD EUR < 0.90)"
|
| 35 |
)
|
| 36 |
|
| 37 |
@router.message(Command("current"))
|
|
@@ -110,3 +107,123 @@ async def cmd_threshold(message: types.Message):
|
|
| 110 |
await session.commit()
|
| 111 |
|
| 112 |
await message.answer(f"✅ Alert set! I will notify you when 1 {base} {condition} {value} {target}.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
f"Welcome to FlyRates! 🌍💸\n\n"
|
| 29 |
f"I can track real-time exchange rates for you.\n"
|
| 30 |
f"Supported Currencies: {currencies_str}\n\n"
|
| 31 |
+
f"Use /help to see all available commands!"
|
|
|
|
|
|
|
|
|
|
| 32 |
)
|
| 33 |
|
| 34 |
@router.message(Command("current"))
|
|
|
|
| 107 |
await session.commit()
|
| 108 |
|
| 109 |
await message.answer(f"✅ Alert set! I will notify you when 1 {base} {condition} {value} {target}.")
|
| 110 |
+
|
| 111 |
+
@router.message(Command("mysubs"))
|
| 112 |
+
async def cmd_mysubs(message: types.Message):
|
| 113 |
+
"""Lists all active subscriptions and thresholds for the user."""
|
| 114 |
+
from db.session import async_session
|
| 115 |
+
chat_id = message.chat.id
|
| 116 |
+
|
| 117 |
+
async with async_session() as session:
|
| 118 |
+
subs_result = await session.execute(select(Subscription).where(Subscription.chat_id == chat_id))
|
| 119 |
+
subs = subs_result.scalars().all()
|
| 120 |
+
|
| 121 |
+
thresholds_result = await session.execute(select(Threshold).where(Threshold.chat_id == chat_id))
|
| 122 |
+
thresholds = thresholds_result.scalars().all()
|
| 123 |
+
|
| 124 |
+
response = "📋 **Your Active Settings**\n\n"
|
| 125 |
+
|
| 126 |
+
if subs:
|
| 127 |
+
response += "🔔 **Subscriptions:**\n"
|
| 128 |
+
for sub in subs:
|
| 129 |
+
response += f"• {sub.base_currency} to {sub.target_currency} ({sub.frequency})\n"
|
| 130 |
+
else:
|
| 131 |
+
response += "🔔 **Subscriptions:** None\n"
|
| 132 |
+
|
| 133 |
+
response += "\n"
|
| 134 |
+
|
| 135 |
+
if thresholds:
|
| 136 |
+
response += "🚨 **Thresholds:**\n"
|
| 137 |
+
for th in thresholds:
|
| 138 |
+
status = " (Active)" if th.is_active else " (Triggered/Inactive)"
|
| 139 |
+
response += f"• {th.base_currency} to {th.target_currency} {th.condition} {th.target_value}{status}\n"
|
| 140 |
+
else:
|
| 141 |
+
response += "🚨 **Thresholds:** None\n"
|
| 142 |
+
|
| 143 |
+
await message.answer(response, parse_mode="Markdown")
|
| 144 |
+
|
| 145 |
+
@router.message(Command("unsubscribe"))
|
| 146 |
+
async def cmd_unsubscribe(message: types.Message):
|
| 147 |
+
"""Removes a subscription."""
|
| 148 |
+
args = message.text.split()[1:]
|
| 149 |
+
if len(args) != 2:
|
| 150 |
+
await message.answer("Usage: /unsubscribe <base> <target>\nExample: /unsubscribe USD EUR")
|
| 151 |
+
return
|
| 152 |
+
|
| 153 |
+
base, target = args[0].upper(), args[1].upper()
|
| 154 |
+
chat_id = message.chat.id
|
| 155 |
+
|
| 156 |
+
from db.session import async_session
|
| 157 |
+
async with async_session() as session:
|
| 158 |
+
result = await session.execute(select(Subscription).where(
|
| 159 |
+
Subscription.chat_id == chat_id,
|
| 160 |
+
Subscription.base_currency == base,
|
| 161 |
+
Subscription.target_currency == target
|
| 162 |
+
))
|
| 163 |
+
subs = result.scalars().all()
|
| 164 |
+
|
| 165 |
+
if not subs:
|
| 166 |
+
await message.answer(f"❌ You don't have a subscription for {base} to {target}.")
|
| 167 |
+
return
|
| 168 |
+
|
| 169 |
+
for sub in subs:
|
| 170 |
+
await session.delete(sub)
|
| 171 |
+
await session.commit()
|
| 172 |
+
|
| 173 |
+
await message.answer(f"✅ Unsubscribed from {base} to {target} updates.")
|
| 174 |
+
|
| 175 |
+
@router.message(Command("delthreshold"))
|
| 176 |
+
async def cmd_delthreshold(message: types.Message):
|
| 177 |
+
"""Removes a threshold alert."""
|
| 178 |
+
args = message.text.split()[1:]
|
| 179 |
+
if len(args) != 2:
|
| 180 |
+
await message.answer("Usage: /delthreshold <base> <target>\nExample: /delthreshold USD EUR")
|
| 181 |
+
return
|
| 182 |
+
|
| 183 |
+
base, target = args[0].upper(), args[1].upper()
|
| 184 |
+
chat_id = message.chat.id
|
| 185 |
+
|
| 186 |
+
from db.session import async_session
|
| 187 |
+
async with async_session() as session:
|
| 188 |
+
result = await session.execute(select(Threshold).where(
|
| 189 |
+
Threshold.chat_id == chat_id,
|
| 190 |
+
Threshold.base_currency == base,
|
| 191 |
+
Threshold.target_currency == target
|
| 192 |
+
))
|
| 193 |
+
thresholds = result.scalars().all()
|
| 194 |
+
|
| 195 |
+
if not thresholds:
|
| 196 |
+
await message.answer(f"❌ You don't have a threshold alert for {base} to {target}.")
|
| 197 |
+
return
|
| 198 |
+
|
| 199 |
+
for th in thresholds:
|
| 200 |
+
await session.delete(th)
|
| 201 |
+
await session.commit()
|
| 202 |
+
|
| 203 |
+
await message.answer(f"✅ Deleted threshold alerts for {base} to {target}.")
|
| 204 |
+
|
| 205 |
+
@router.message(Command("help"))
|
| 206 |
+
async def cmd_help(message: types.Message):
|
| 207 |
+
"""Shows detailed usage instructions."""
|
| 208 |
+
currencies_str = ", ".join(ALLOWED_CURRENCIES)
|
| 209 |
+
help_text = (
|
| 210 |
+
"📚 **FlyRates Bot Help** 📚\n\n"
|
| 211 |
+
"Here are the commands you can use:\n\n"
|
| 212 |
+
"**Basic Commands**\n"
|
| 213 |
+
"`/start` - Get welcome message\n"
|
| 214 |
+
"`/help` - Show this help menu\n\n"
|
| 215 |
+
"**Rates & Updates**\n"
|
| 216 |
+
"`/current <base> <target>` - Get the live rate\n"
|
| 217 |
+
"`/subscribe <base> <target> [daily/hourly]` - Automate updates\n"
|
| 218 |
+
"`/threshold <base> <target> <condition> <value>` - Get custom alerts (<, >, <=, >=)\n\n"
|
| 219 |
+
"**Management**\n"
|
| 220 |
+
"`/mysubs` - View all your active subscriptions and alerts\n"
|
| 221 |
+
"`/unsubscribe <base> <target>` - Remove a subscription\n"
|
| 222 |
+
"`/delthreshold <base> <target>` - Remove an alert\n\n"
|
| 223 |
+
f"**Supported Currencies:** {currencies_str}\n\n"
|
| 224 |
+
"**Examples:**\n"
|
| 225 |
+
"`/current USD EUR`\n"
|
| 226 |
+
"`/subscribe USD EUR hourly`\n"
|
| 227 |
+
"`/threshold USD EUR < 0.90`"
|
| 228 |
+
)
|
| 229 |
+
await message.answer(help_text, parse_mode="Markdown")
|
core/config.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
|
|
| 2 |
from typing import Optional
|
| 3 |
|
| 4 |
class Settings(BaseSettings):
|
|
@@ -8,6 +9,16 @@ class Settings(BaseSettings):
|
|
| 8 |
database_url: str = "sqlite+aiosqlite:///./flyrates.db"
|
| 9 |
log_level: str = "INFO"
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
| 12 |
|
| 13 |
settings = Settings()
|
|
|
|
| 1 |
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 2 |
+
from pydantic import field_validator
|
| 3 |
from typing import Optional
|
| 4 |
|
| 5 |
class Settings(BaseSettings):
|
|
|
|
| 9 |
database_url: str = "sqlite+aiosqlite:///./flyrates.db"
|
| 10 |
log_level: str = "INFO"
|
| 11 |
|
| 12 |
+
@field_validator("database_url", mode="before")
|
| 13 |
+
@classmethod
|
| 14 |
+
def assemble_db_connection(cls, v: str) -> str:
|
| 15 |
+
if isinstance(v, str):
|
| 16 |
+
if v.startswith("postgres://"):
|
| 17 |
+
return v.replace("postgres://", "postgresql+asyncpg://", 1)
|
| 18 |
+
if v.startswith("postgresql://") and not v.startswith("postgresql+asyncpg://"):
|
| 19 |
+
return v.replace("postgresql://", "postgresql+asyncpg://", 1)
|
| 20 |
+
return v
|
| 21 |
+
|
| 22 |
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
| 23 |
|
| 24 |
settings = Settings()
|
db/models.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from sqlalchemy import Column, Integer, String, Float, DateTime, Boolean, ForeignKey
|
| 2 |
from sqlalchemy.orm import declarative_base, relationship
|
| 3 |
from datetime import datetime, timezone
|
| 4 |
|
|
@@ -7,7 +7,7 @@ Base = declarative_base()
|
|
| 7 |
class User(Base):
|
| 8 |
__tablename__ = "users"
|
| 9 |
|
| 10 |
-
chat_id = Column(
|
| 11 |
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
| 12 |
|
| 13 |
subscriptions = relationship("Subscription", back_populates="user", cascade="all, delete-orphan")
|
|
@@ -17,7 +17,7 @@ class Subscription(Base):
|
|
| 17 |
__tablename__ = "subscriptions"
|
| 18 |
|
| 19 |
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 20 |
-
chat_id = Column(
|
| 21 |
base_currency = Column(String(3), nullable=False)
|
| 22 |
target_currency = Column(String(3), nullable=False)
|
| 23 |
# Frequency could be 'daily', 'hourly'
|
|
@@ -30,7 +30,7 @@ class Threshold(Base):
|
|
| 30 |
__tablename__ = "thresholds"
|
| 31 |
|
| 32 |
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 33 |
-
chat_id = Column(
|
| 34 |
base_currency = Column(String(3), nullable=False)
|
| 35 |
target_currency = Column(String(3), nullable=False)
|
| 36 |
condition = Column(String(5), nullable=False) # e.g., '<', '>', '<=', '>='
|
|
|
|
| 1 |
+
from sqlalchemy import Column, Integer, String, Float, DateTime, Boolean, ForeignKey, BigInteger
|
| 2 |
from sqlalchemy.orm import declarative_base, relationship
|
| 3 |
from datetime import datetime, timezone
|
| 4 |
|
|
|
|
| 7 |
class User(Base):
|
| 8 |
__tablename__ = "users"
|
| 9 |
|
| 10 |
+
chat_id = Column(BigInteger, primary_key=True, index=True)
|
| 11 |
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
| 12 |
|
| 13 |
subscriptions = relationship("Subscription", back_populates="user", cascade="all, delete-orphan")
|
|
|
|
| 17 |
__tablename__ = "subscriptions"
|
| 18 |
|
| 19 |
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 20 |
+
chat_id = Column(BigInteger, ForeignKey("users.chat_id"))
|
| 21 |
base_currency = Column(String(3), nullable=False)
|
| 22 |
target_currency = Column(String(3), nullable=False)
|
| 23 |
# Frequency could be 'daily', 'hourly'
|
|
|
|
| 30 |
__tablename__ = "thresholds"
|
| 31 |
|
| 32 |
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 33 |
+
chat_id = Column(BigInteger, ForeignKey("users.chat_id"))
|
| 34 |
base_currency = Column(String(3), nullable=False)
|
| 35 |
target_currency = Column(String(3), nullable=False)
|
| 36 |
condition = Column(String(5), nullable=False) # e.g., '<', '>', '<=', '>='
|
db/session.py
CHANGED
|
@@ -5,11 +5,16 @@ from db.models import Base
|
|
| 5 |
|
| 6 |
logger = logging.getLogger(__name__)
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
# Initialize the async engine
|
| 9 |
engine = create_async_engine(
|
| 10 |
settings.database_url,
|
| 11 |
echo=(settings.log_level == "DEBUG"),
|
| 12 |
future=True,
|
|
|
|
| 13 |
)
|
| 14 |
|
| 15 |
# Create an async session factory
|
|
|
|
| 5 |
|
| 6 |
logger = logging.getLogger(__name__)
|
| 7 |
|
| 8 |
+
connect_args = {}
|
| 9 |
+
if "supabase" in settings.database_url or "postgres" in settings.database_url:
|
| 10 |
+
connect_args["ssl"] = True
|
| 11 |
+
|
| 12 |
# Initialize the async engine
|
| 13 |
engine = create_async_engine(
|
| 14 |
settings.database_url,
|
| 15 |
echo=(settings.log_level == "DEBUG"),
|
| 16 |
future=True,
|
| 17 |
+
connect_args=connect_args
|
| 18 |
)
|
| 19 |
|
| 20 |
# Create an async session factory
|