Sadeep Sachintha commited on
Commit
fb1fe87
·
1 Parent(s): a0b8ee5

feat: implement database schema, async session management, and bot command handlers for FlyRates

Browse files
Files changed (4) hide show
  1. bot/handlers.py +121 -4
  2. core/config.py +11 -0
  3. db/models.py +4 -4
  4. db/session.py +5 -0
bot/handlers.py CHANGED
@@ -28,10 +28,7 @@ async def cmd_start(message: types.Message):
28
  f"Welcome to FlyRates! 🌍💸\n\n"
29
  f"I can track real-time exchange rates for you.\n"
30
  f"Supported Currencies: {currencies_str}\n\n"
31
- f"Commands:\n"
32
- f"/current <base> <target> - Get live rate (e.g., /current USD EUR)\n"
33
- f"/subscribe <base> <target> <daily/hourly> - Get automated updates\n"
34
- f"/threshold <base> <target> << or >> <value> - Get alerts (e.g., /threshold USD EUR < 0.90)"
35
  )
36
 
37
  @router.message(Command("current"))
@@ -110,3 +107,123 @@ async def cmd_threshold(message: types.Message):
110
  await session.commit()
111
 
112
  await message.answer(f"✅ Alert set! I will notify you when 1 {base} {condition} {value} {target}.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  f"Welcome to FlyRates! 🌍💸\n\n"
29
  f"I can track real-time exchange rates for you.\n"
30
  f"Supported Currencies: {currencies_str}\n\n"
31
+ f"Use /help to see all available commands!"
 
 
 
32
  )
33
 
34
  @router.message(Command("current"))
 
107
  await session.commit()
108
 
109
  await message.answer(f"✅ Alert set! I will notify you when 1 {base} {condition} {value} {target}.")
110
+
111
+ @router.message(Command("mysubs"))
112
+ async def cmd_mysubs(message: types.Message):
113
+ """Lists all active subscriptions and thresholds for the user."""
114
+ from db.session import async_session
115
+ chat_id = message.chat.id
116
+
117
+ async with async_session() as session:
118
+ subs_result = await session.execute(select(Subscription).where(Subscription.chat_id == chat_id))
119
+ subs = subs_result.scalars().all()
120
+
121
+ thresholds_result = await session.execute(select(Threshold).where(Threshold.chat_id == chat_id))
122
+ thresholds = thresholds_result.scalars().all()
123
+
124
+ response = "📋 **Your Active Settings**\n\n"
125
+
126
+ if subs:
127
+ response += "🔔 **Subscriptions:**\n"
128
+ for sub in subs:
129
+ response += f"• {sub.base_currency} to {sub.target_currency} ({sub.frequency})\n"
130
+ else:
131
+ response += "🔔 **Subscriptions:** None\n"
132
+
133
+ response += "\n"
134
+
135
+ if thresholds:
136
+ response += "🚨 **Thresholds:**\n"
137
+ for th in thresholds:
138
+ status = " (Active)" if th.is_active else " (Triggered/Inactive)"
139
+ response += f"• {th.base_currency} to {th.target_currency} {th.condition} {th.target_value}{status}\n"
140
+ else:
141
+ response += "🚨 **Thresholds:** None\n"
142
+
143
+ await message.answer(response, parse_mode="Markdown")
144
+
145
+ @router.message(Command("unsubscribe"))
146
+ async def cmd_unsubscribe(message: types.Message):
147
+ """Removes a subscription."""
148
+ args = message.text.split()[1:]
149
+ if len(args) != 2:
150
+ await message.answer("Usage: /unsubscribe <base> <target>\nExample: /unsubscribe USD EUR")
151
+ return
152
+
153
+ base, target = args[0].upper(), args[1].upper()
154
+ chat_id = message.chat.id
155
+
156
+ from db.session import async_session
157
+ async with async_session() as session:
158
+ result = await session.execute(select(Subscription).where(
159
+ Subscription.chat_id == chat_id,
160
+ Subscription.base_currency == base,
161
+ Subscription.target_currency == target
162
+ ))
163
+ subs = result.scalars().all()
164
+
165
+ if not subs:
166
+ await message.answer(f"❌ You don't have a subscription for {base} to {target}.")
167
+ return
168
+
169
+ for sub in subs:
170
+ await session.delete(sub)
171
+ await session.commit()
172
+
173
+ await message.answer(f"✅ Unsubscribed from {base} to {target} updates.")
174
+
175
+ @router.message(Command("delthreshold"))
176
+ async def cmd_delthreshold(message: types.Message):
177
+ """Removes a threshold alert."""
178
+ args = message.text.split()[1:]
179
+ if len(args) != 2:
180
+ await message.answer("Usage: /delthreshold <base> <target>\nExample: /delthreshold USD EUR")
181
+ return
182
+
183
+ base, target = args[0].upper(), args[1].upper()
184
+ chat_id = message.chat.id
185
+
186
+ from db.session import async_session
187
+ async with async_session() as session:
188
+ result = await session.execute(select(Threshold).where(
189
+ Threshold.chat_id == chat_id,
190
+ Threshold.base_currency == base,
191
+ Threshold.target_currency == target
192
+ ))
193
+ thresholds = result.scalars().all()
194
+
195
+ if not thresholds:
196
+ await message.answer(f"❌ You don't have a threshold alert for {base} to {target}.")
197
+ return
198
+
199
+ for th in thresholds:
200
+ await session.delete(th)
201
+ await session.commit()
202
+
203
+ await message.answer(f"✅ Deleted threshold alerts for {base} to {target}.")
204
+
205
+ @router.message(Command("help"))
206
+ async def cmd_help(message: types.Message):
207
+ """Shows detailed usage instructions."""
208
+ currencies_str = ", ".join(ALLOWED_CURRENCIES)
209
+ help_text = (
210
+ "📚 **FlyRates Bot Help** 📚\n\n"
211
+ "Here are the commands you can use:\n\n"
212
+ "**Basic Commands**\n"
213
+ "`/start` - Get welcome message\n"
214
+ "`/help` - Show this help menu\n\n"
215
+ "**Rates & Updates**\n"
216
+ "`/current <base> <target>` - Get the live rate\n"
217
+ "`/subscribe <base> <target> [daily/hourly]` - Automate updates\n"
218
+ "`/threshold <base> <target> <condition> <value>` - Get custom alerts (<, >, <=, >=)\n\n"
219
+ "**Management**\n"
220
+ "`/mysubs` - View all your active subscriptions and alerts\n"
221
+ "`/unsubscribe <base> <target>` - Remove a subscription\n"
222
+ "`/delthreshold <base> <target>` - Remove an alert\n\n"
223
+ f"**Supported Currencies:** {currencies_str}\n\n"
224
+ "**Examples:**\n"
225
+ "`/current USD EUR`\n"
226
+ "`/subscribe USD EUR hourly`\n"
227
+ "`/threshold USD EUR < 0.90`"
228
+ )
229
+ await message.answer(help_text, parse_mode="Markdown")
core/config.py CHANGED
@@ -1,4 +1,5 @@
1
  from pydantic_settings import BaseSettings, SettingsConfigDict
 
2
  from typing import Optional
3
 
4
  class Settings(BaseSettings):
@@ -8,6 +9,16 @@ class Settings(BaseSettings):
8
  database_url: str = "sqlite+aiosqlite:///./flyrates.db"
9
  log_level: str = "INFO"
10
 
 
 
 
 
 
 
 
 
 
 
11
  model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
12
 
13
  settings = Settings()
 
1
  from pydantic_settings import BaseSettings, SettingsConfigDict
2
+ from pydantic import field_validator
3
  from typing import Optional
4
 
5
  class Settings(BaseSettings):
 
9
  database_url: str = "sqlite+aiosqlite:///./flyrates.db"
10
  log_level: str = "INFO"
11
 
12
+ @field_validator("database_url", mode="before")
13
+ @classmethod
14
+ def assemble_db_connection(cls, v: str) -> str:
15
+ if isinstance(v, str):
16
+ if v.startswith("postgres://"):
17
+ return v.replace("postgres://", "postgresql+asyncpg://", 1)
18
+ if v.startswith("postgresql://") and not v.startswith("postgresql+asyncpg://"):
19
+ return v.replace("postgresql://", "postgresql+asyncpg://", 1)
20
+ return v
21
+
22
  model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
23
 
24
  settings = Settings()
db/models.py CHANGED
@@ -1,4 +1,4 @@
1
- from sqlalchemy import Column, Integer, String, Float, DateTime, Boolean, ForeignKey
2
  from sqlalchemy.orm import declarative_base, relationship
3
  from datetime import datetime, timezone
4
 
@@ -7,7 +7,7 @@ Base = declarative_base()
7
  class User(Base):
8
  __tablename__ = "users"
9
 
10
- chat_id = Column(Integer, primary_key=True, index=True)
11
  created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
12
 
13
  subscriptions = relationship("Subscription", back_populates="user", cascade="all, delete-orphan")
@@ -17,7 +17,7 @@ class Subscription(Base):
17
  __tablename__ = "subscriptions"
18
 
19
  id = Column(Integer, primary_key=True, autoincrement=True)
20
- chat_id = Column(Integer, ForeignKey("users.chat_id"))
21
  base_currency = Column(String(3), nullable=False)
22
  target_currency = Column(String(3), nullable=False)
23
  # Frequency could be 'daily', 'hourly'
@@ -30,7 +30,7 @@ class Threshold(Base):
30
  __tablename__ = "thresholds"
31
 
32
  id = Column(Integer, primary_key=True, autoincrement=True)
33
- chat_id = Column(Integer, ForeignKey("users.chat_id"))
34
  base_currency = Column(String(3), nullable=False)
35
  target_currency = Column(String(3), nullable=False)
36
  condition = Column(String(5), nullable=False) # e.g., '<', '>', '<=', '>='
 
1
+ from sqlalchemy import Column, Integer, String, Float, DateTime, Boolean, ForeignKey, BigInteger
2
  from sqlalchemy.orm import declarative_base, relationship
3
  from datetime import datetime, timezone
4
 
 
7
  class User(Base):
8
  __tablename__ = "users"
9
 
10
+ chat_id = Column(BigInteger, primary_key=True, index=True)
11
  created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
12
 
13
  subscriptions = relationship("Subscription", back_populates="user", cascade="all, delete-orphan")
 
17
  __tablename__ = "subscriptions"
18
 
19
  id = Column(Integer, primary_key=True, autoincrement=True)
20
+ chat_id = Column(BigInteger, ForeignKey("users.chat_id"))
21
  base_currency = Column(String(3), nullable=False)
22
  target_currency = Column(String(3), nullable=False)
23
  # Frequency could be 'daily', 'hourly'
 
30
  __tablename__ = "thresholds"
31
 
32
  id = Column(Integer, primary_key=True, autoincrement=True)
33
+ chat_id = Column(BigInteger, ForeignKey("users.chat_id"))
34
  base_currency = Column(String(3), nullable=False)
35
  target_currency = Column(String(3), nullable=False)
36
  condition = Column(String(5), nullable=False) # e.g., '<', '>', '<=', '>='
db/session.py CHANGED
@@ -5,11 +5,16 @@ from db.models import Base
5
 
6
  logger = logging.getLogger(__name__)
7
 
 
 
 
 
8
  # Initialize the async engine
9
  engine = create_async_engine(
10
  settings.database_url,
11
  echo=(settings.log_level == "DEBUG"),
12
  future=True,
 
13
  )
14
 
15
  # Create an async session factory
 
5
 
6
  logger = logging.getLogger(__name__)
7
 
8
+ connect_args = {}
9
+ if "supabase" in settings.database_url or "postgres" in settings.database_url:
10
+ connect_args["ssl"] = True
11
+
12
  # Initialize the async engine
13
  engine = create_async_engine(
14
  settings.database_url,
15
  echo=(settings.log_level == "DEBUG"),
16
  future=True,
17
+ connect_args=connect_args
18
  )
19
 
20
  # Create an async session factory