Spaces:
Sleeping
Sleeping
credit
Browse files- app.py +9 -267
- auth_utils.py +72 -0
- dependencies.py +145 -0
- models.py +52 -1
- pytest.ini +3 -0
- requirements.txt +3 -0
- routers/auth.py +249 -0
- routers/blink.py +193 -0
- routers/general.py +29 -0
- schemas.py +15 -0
- tests/conftest.py +62 -0
- tests/test_integration.py +91 -0
app.py
CHANGED
|
@@ -6,20 +6,12 @@ decrypting it, and storing in SQLite database.
|
|
| 6 |
"""
|
| 7 |
import logging
|
| 8 |
from contextlib import asynccontextmanager
|
| 9 |
-
from
|
| 10 |
-
|
| 11 |
-
import ipaddress
|
| 12 |
-
import httpx
|
| 13 |
-
from fastapi import FastAPI, Query, Request, Depends, HTTPException, status
|
| 14 |
from fastapi.middleware.cors import CORSMiddleware
|
| 15 |
-
from fastapi.responses import JSONResponse
|
| 16 |
-
from fastapi.staticfiles import StaticFiles
|
| 17 |
-
from sqlalchemy.ext.asyncio import AsyncSession
|
| 18 |
-
from sqlalchemy import select, func
|
| 19 |
|
| 20 |
-
from database import
|
| 21 |
-
from
|
| 22 |
-
from encryption import decrypt_data, decrypt_multiple_blocks
|
| 23 |
|
| 24 |
# Configure logging
|
| 25 |
logging.basicConfig(
|
|
@@ -28,43 +20,6 @@ logging.basicConfig(
|
|
| 28 |
)
|
| 29 |
logger = logging.getLogger(__name__)
|
| 30 |
|
| 31 |
-
# User ID length constant
|
| 32 |
-
USER_ID_LENGTH = 20
|
| 33 |
-
|
| 34 |
-
# Geolocation API settings
|
| 35 |
-
GEOLOCATION_API_URL = "http://ip-api.com/json/{ip}?fields=status,country,regionName"
|
| 36 |
-
GEOLOCATION_TIMEOUT = 2.0 # seconds
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
async def get_geolocation(ip_address: str) -> Tuple[Optional[str], Optional[str]]:
|
| 40 |
-
"""
|
| 41 |
-
Get country and region for an IP address using ip-api.com.
|
| 42 |
-
|
| 43 |
-
Args:
|
| 44 |
-
ip_address: IPv4 or IPv6 address
|
| 45 |
-
|
| 46 |
-
Returns:
|
| 47 |
-
Tuple of (country, region) or (None, None) if lookup fails
|
| 48 |
-
"""
|
| 49 |
-
if not ip_address:
|
| 50 |
-
return None, None
|
| 51 |
-
|
| 52 |
-
# Skip geolocation for localhost/private IPs
|
| 53 |
-
if ip_address in ("127.0.0.1", "::1", "localhost") or ip_address.startswith(("192.168.", "10.", "172.")):
|
| 54 |
-
return None, None
|
| 55 |
-
|
| 56 |
-
try:
|
| 57 |
-
async with httpx.AsyncClient(timeout=GEOLOCATION_TIMEOUT) as client:
|
| 58 |
-
response = await client.get(GEOLOCATION_API_URL.format(ip=ip_address))
|
| 59 |
-
if response.status_code == 200:
|
| 60 |
-
data = response.json()
|
| 61 |
-
if data.get("status") == "success":
|
| 62 |
-
return data.get("country"), data.get("regionName")
|
| 63 |
-
except Exception as e:
|
| 64 |
-
logger.warning(f"Geolocation lookup failed for {ip_address}: {e}")
|
| 65 |
-
|
| 66 |
-
return None, None
|
| 67 |
-
|
| 68 |
|
| 69 |
@asynccontextmanager
|
| 70 |
async def lifespan(app: FastAPI):
|
|
@@ -96,223 +51,10 @@ app.add_middleware(
|
|
| 96 |
allow_headers=["*"],
|
| 97 |
)
|
| 98 |
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
Health check endpoint.
|
| 104 |
-
|
| 105 |
-
Returns:
|
| 106 |
-
Health status of the application
|
| 107 |
-
"""
|
| 108 |
-
return {"status": "healthy", "service": "url-blink-api"}
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
@app.get("/", response_class=HTMLResponse)
|
| 112 |
-
async def root():
|
| 113 |
-
"""
|
| 114 |
-
Serve the main HTML page.
|
| 115 |
-
"""
|
| 116 |
-
import os
|
| 117 |
-
template_path = os.path.join(os.path.dirname(__file__), "templates", "index.html")
|
| 118 |
-
with open(template_path, "r") as f:
|
| 119 |
-
return HTMLResponse(content=f.read())
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
@app.get("/api/data")
|
| 123 |
-
async def get_data(
|
| 124 |
-
page: int = Query(1, ge=1, description="Page number"),
|
| 125 |
-
limit: int = Query(100, ge=1, le=500, description="Items per page"),
|
| 126 |
-
db: AsyncSession = Depends(get_db)
|
| 127 |
-
):
|
| 128 |
-
"""
|
| 129 |
-
Get paginated blink data.
|
| 130 |
-
|
| 131 |
-
Args:
|
| 132 |
-
page: Page number (1-indexed)
|
| 133 |
-
limit: Number of items per page
|
| 134 |
-
db: Database session
|
| 135 |
-
|
| 136 |
-
Returns:
|
| 137 |
-
Paginated list of blink data records
|
| 138 |
-
"""
|
| 139 |
-
try:
|
| 140 |
-
offset = (page - 1) * limit
|
| 141 |
-
|
| 142 |
-
# Get total count
|
| 143 |
-
total_result = await db.execute(select(func.count(BlinkData.id)))
|
| 144 |
-
total = total_result.scalar() or 0
|
| 145 |
-
|
| 146 |
-
# Get unique users count
|
| 147 |
-
unique_result = await db.execute(select(func.count(func.distinct(BlinkData.user_id))))
|
| 148 |
-
unique_users = unique_result.scalar() or 0
|
| 149 |
-
|
| 150 |
-
# Get paginated items
|
| 151 |
-
query = select(BlinkData).order_by(BlinkData.id.desc()).offset(offset).limit(limit)
|
| 152 |
-
result = await db.execute(query)
|
| 153 |
-
items = result.scalars().all()
|
| 154 |
-
|
| 155 |
-
return {
|
| 156 |
-
"items": [
|
| 157 |
-
{
|
| 158 |
-
"id": item.id,
|
| 159 |
-
"user_id": item.user_id,
|
| 160 |
-
"refer_url": item.refer_url,
|
| 161 |
-
"ip_address": item.ip_address,
|
| 162 |
-
"ipv4_address": item.ipv4_address,
|
| 163 |
-
"ipv6_address": item.ipv6_address,
|
| 164 |
-
"country": item.country,
|
| 165 |
-
"region": item.region,
|
| 166 |
-
"json_data": item.json_data,
|
| 167 |
-
"created_at": item.created_at.isoformat() if item.created_at else None
|
| 168 |
-
}
|
| 169 |
-
for item in items
|
| 170 |
-
],
|
| 171 |
-
"total": total,
|
| 172 |
-
"unique_users": unique_users,
|
| 173 |
-
"page": page,
|
| 174 |
-
"limit": limit
|
| 175 |
-
}
|
| 176 |
-
except Exception as e:
|
| 177 |
-
logger.error(f"Error fetching data: {e}")
|
| 178 |
-
raise HTTPException(
|
| 179 |
-
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 180 |
-
detail="Error fetching data"
|
| 181 |
-
)
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
@app.get("/blink")
|
| 185 |
-
async def blink(
|
| 186 |
-
request: Request,
|
| 187 |
-
userid: str = Query(..., description="User ID (20 chars) + encrypted data"),
|
| 188 |
-
db: AsyncSession = Depends(get_db)
|
| 189 |
-
):
|
| 190 |
-
"""
|
| 191 |
-
Process blink request with encrypted user data.
|
| 192 |
-
|
| 193 |
-
The userid parameter format:
|
| 194 |
-
- First 20 characters: User ID
|
| 195 |
-
- Remaining characters: Base64 encoded encrypted data
|
| 196 |
-
|
| 197 |
-
Args:
|
| 198 |
-
request: FastAPI request object
|
| 199 |
-
userid: Combined user ID and encrypted data
|
| 200 |
-
db: Database session
|
| 201 |
-
|
| 202 |
-
Returns:
|
| 203 |
-
Success response with processing status
|
| 204 |
-
"""
|
| 205 |
-
try:
|
| 206 |
-
# Validate minimum length
|
| 207 |
-
if len(userid) < USER_ID_LENGTH:
|
| 208 |
-
raise HTTPException(
|
| 209 |
-
status_code=status.HTTP_400_BAD_REQUEST,
|
| 210 |
-
detail=f"Parameter 'userid' must be at least {USER_ID_LENGTH} characters"
|
| 211 |
-
)
|
| 212 |
-
|
| 213 |
-
# Extract user_id (first 20 characters)
|
| 214 |
-
user_id = userid[:USER_ID_LENGTH]
|
| 215 |
-
|
| 216 |
-
# Extract encrypted data (remaining characters)
|
| 217 |
-
encrypted_data = userid[USER_ID_LENGTH:]
|
| 218 |
-
|
| 219 |
-
if not encrypted_data:
|
| 220 |
-
logger.warning(f"No encrypted data received for user: {user_id}")
|
| 221 |
-
# Still store the record with empty json_data
|
| 222 |
-
decrypted_results = []
|
| 223 |
-
else:
|
| 224 |
-
# Try to decrypt - might be single or multiple blocks
|
| 225 |
-
try:
|
| 226 |
-
decrypted_results = decrypt_multiple_blocks(encrypted_data)
|
| 227 |
-
except Exception as e:
|
| 228 |
-
logger.error(f"Decryption failed for user {user_id}: {e}")
|
| 229 |
-
# Store with error information
|
| 230 |
-
decrypted_results = [{"error": str(e), "raw_encrypted": encrypted_data[:100]}]
|
| 231 |
-
|
| 232 |
-
# Get referer URL from headers (full URL, not just origin)
|
| 233 |
-
refer_url = request.headers.get("referer")
|
| 234 |
-
|
| 235 |
-
# Get client IP address
|
| 236 |
-
# Check X-Forwarded-For header first (for proxies/load balancers)
|
| 237 |
-
forwarded_for = request.headers.get("x-forwarded-for")
|
| 238 |
-
if forwarded_for:
|
| 239 |
-
# X-Forwarded-For can contain multiple IPs, take the first one
|
| 240 |
-
ip_address = forwarded_for.split(",")[0].strip()
|
| 241 |
-
else:
|
| 242 |
-
# Fall back to direct client IP
|
| 243 |
-
ip_address = request.client.host if request.client else None
|
| 244 |
-
|
| 245 |
-
# Get geolocation from IP address
|
| 246 |
-
country, region = await get_geolocation(ip_address)
|
| 247 |
-
|
| 248 |
-
# Determine IPv4 vs IPv6
|
| 249 |
-
ipv4_address = None
|
| 250 |
-
ipv6_address = None
|
| 251 |
-
|
| 252 |
-
if ip_address:
|
| 253 |
-
try:
|
| 254 |
-
ip_obj = ipaddress.ip_address(ip_address)
|
| 255 |
-
if isinstance(ip_obj, ipaddress.IPv4Address):
|
| 256 |
-
ipv4_address = ip_address
|
| 257 |
-
elif isinstance(ip_obj, ipaddress.IPv6Address):
|
| 258 |
-
ipv6_address = ip_address
|
| 259 |
-
except ValueError:
|
| 260 |
-
# Invalid IP address format, just keep it in ip_address
|
| 261 |
-
pass
|
| 262 |
-
|
| 263 |
-
# Store each decrypted result as separate records
|
| 264 |
-
records_created = 0
|
| 265 |
-
for json_data in decrypted_results:
|
| 266 |
-
blink_record = BlinkData(
|
| 267 |
-
user_id=user_id,
|
| 268 |
-
refer_url=refer_url,
|
| 269 |
-
ip_address=ip_address,
|
| 270 |
-
ipv4_address=ipv4_address,
|
| 271 |
-
ipv6_address=ipv6_address,
|
| 272 |
-
country=country,
|
| 273 |
-
region=region,
|
| 274 |
-
json_data=json_data
|
| 275 |
-
)
|
| 276 |
-
db.add(blink_record)
|
| 277 |
-
records_created += 1
|
| 278 |
-
|
| 279 |
-
# If no results but we have encrypted data, store a record with the raw data reference
|
| 280 |
-
if not decrypted_results and encrypted_data:
|
| 281 |
-
blink_record = BlinkData(
|
| 282 |
-
user_id=user_id,
|
| 283 |
-
refer_url=refer_url,
|
| 284 |
-
ip_address=ip_address,
|
| 285 |
-
ipv4_address=ipv4_address,
|
| 286 |
-
ipv6_address=ipv6_address,
|
| 287 |
-
country=country,
|
| 288 |
-
region=region,
|
| 289 |
-
json_data={"encrypted_length": len(encrypted_data)}
|
| 290 |
-
)
|
| 291 |
-
db.add(blink_record)
|
| 292 |
-
records_created = 1
|
| 293 |
-
|
| 294 |
-
await db.commit()
|
| 295 |
-
|
| 296 |
-
logger.info(f"Successfully processed blink for user: {user_id}, records: {records_created}")
|
| 297 |
-
|
| 298 |
-
return JSONResponse(
|
| 299 |
-
status_code=status.HTTP_200_OK,
|
| 300 |
-
content={
|
| 301 |
-
"status": "success",
|
| 302 |
-
"user_id": user_id,
|
| 303 |
-
"records_created": records_created
|
| 304 |
-
}
|
| 305 |
-
)
|
| 306 |
-
|
| 307 |
-
except HTTPException:
|
| 308 |
-
raise
|
| 309 |
-
except Exception as e:
|
| 310 |
-
logger.error(f"Error processing blink request: {e}")
|
| 311 |
-
await db.rollback()
|
| 312 |
-
raise HTTPException(
|
| 313 |
-
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 314 |
-
detail="Internal server error processing request"
|
| 315 |
-
)
|
| 316 |
|
| 317 |
|
| 318 |
@app.exception_handler(Exception)
|
|
@@ -322,7 +64,7 @@ async def global_exception_handler(request: Request, exc: Exception):
|
|
| 322 |
"""
|
| 323 |
logger.error(f"Unhandled exception: {exc}")
|
| 324 |
return JSONResponse(
|
| 325 |
-
status_code=
|
| 326 |
content={"detail": "Internal server error"}
|
| 327 |
)
|
| 328 |
|
|
|
|
| 6 |
"""
|
| 7 |
import logging
|
| 8 |
from contextlib import asynccontextmanager
|
| 9 |
+
from fastapi import FastAPI, Request
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
from fastapi.middleware.cors import CORSMiddleware
|
| 11 |
+
from fastapi.responses import JSONResponse
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
+
from database import init_db
|
| 14 |
+
from routers import auth, blink, general
|
|
|
|
| 15 |
|
| 16 |
# Configure logging
|
| 17 |
logging.basicConfig(
|
|
|
|
| 20 |
)
|
| 21 |
logger = logging.getLogger(__name__)
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
@asynccontextmanager
|
| 25 |
async def lifespan(app: FastAPI):
|
|
|
|
| 51 |
allow_headers=["*"],
|
| 52 |
)
|
| 53 |
|
| 54 |
+
# Include Routers
|
| 55 |
+
app.include_router(general.router)
|
| 56 |
+
app.include_router(auth.router)
|
| 57 |
+
app.include_router(blink.router)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
|
| 60 |
@app.exception_handler(Exception)
|
|
|
|
| 64 |
"""
|
| 65 |
logger.error(f"Unhandled exception: {exc}")
|
| 66 |
return JSONResponse(
|
| 67 |
+
status_code=500,
|
| 68 |
content={"detail": "Internal server error"}
|
| 69 |
)
|
| 70 |
|
auth_utils.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import secrets
|
| 2 |
+
import smtplib
|
| 3 |
+
import ssl
|
| 4 |
+
from email.mime.text import MIMEText
|
| 5 |
+
from email.mime.multipart import MIMEMultipart
|
| 6 |
+
import bcrypt
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
# Configure logging
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
# Email configuration
|
| 14 |
+
SMTP_SERVER = os.getenv("SMTP_SERVER", "127.0.0.1")
|
| 15 |
+
SMTP_PORT = int(os.getenv("SMTP_PORT", "1025"))
|
| 16 |
+
SMTP_USERNAME = os.getenv("SMTP_USERNAME", "sender@domain.com")
|
| 17 |
+
SMTP_PASSWORD = os.getenv("SMTP_PASSWORD", "yourpassword")
|
| 18 |
+
SMTP_SENDER = os.getenv("SMTP_SENDER", SMTP_USERNAME)
|
| 19 |
+
|
| 20 |
+
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
| 21 |
+
"""
|
| 22 |
+
Verify a password against a hash.
|
| 23 |
+
"""
|
| 24 |
+
if isinstance(hashed_password, str):
|
| 25 |
+
hashed_password = hashed_password.encode('utf-8')
|
| 26 |
+
return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password)
|
| 27 |
+
|
| 28 |
+
def get_password_hash(password: str) -> str:
|
| 29 |
+
"""
|
| 30 |
+
Hash a password using bcrypt.
|
| 31 |
+
"""
|
| 32 |
+
# rounds=12 as per spec
|
| 33 |
+
salt = bcrypt.gensalt(rounds=12)
|
| 34 |
+
hashed = bcrypt.hashpw(password.encode('utf-8'), salt)
|
| 35 |
+
return hashed.decode('utf-8')
|
| 36 |
+
|
| 37 |
+
def generate_secret_key() -> str:
|
| 38 |
+
"""
|
| 39 |
+
Generate a secure secret key starting with 'sk_'.
|
| 40 |
+
"""
|
| 41 |
+
return "sk_" + secrets.token_urlsafe(32)
|
| 42 |
+
|
| 43 |
+
def send_email(to_email: str, subject: str, body: str):
|
| 44 |
+
"""
|
| 45 |
+
Send an email using SMTP (configured for ProtonMail Bridge).
|
| 46 |
+
This function is blocking and should be run in a background task.
|
| 47 |
+
"""
|
| 48 |
+
try:
|
| 49 |
+
message = MIMEMultipart()
|
| 50 |
+
message["From"] = SMTP_SENDER
|
| 51 |
+
message["To"] = to_email
|
| 52 |
+
message["Subject"] = subject
|
| 53 |
+
|
| 54 |
+
message.attach(MIMEText(body, "plain"))
|
| 55 |
+
|
| 56 |
+
# Create a secure SSL context
|
| 57 |
+
context = ssl.create_default_context()
|
| 58 |
+
# Note: ProtonMail Bridge with "change smtp-security" to SSL usually uses self-signed certs or specific setup.
|
| 59 |
+
# If using localhost, we might need to bypass hostname check or trust the cert.
|
| 60 |
+
# For now, we'll try standard SSL context. If it fails on localhost self-signed, we might need check_hostname=False.
|
| 61 |
+
context.check_hostname = False
|
| 62 |
+
context.verify_mode = ssl.CERT_NONE
|
| 63 |
+
|
| 64 |
+
with smtplib.SMTP_SSL(SMTP_SERVER, SMTP_PORT, context=context) as server:
|
| 65 |
+
server.login(SMTP_USERNAME, SMTP_PASSWORD)
|
| 66 |
+
server.sendmail(SMTP_SENDER, to_email, message.as_string())
|
| 67 |
+
|
| 68 |
+
logger.info(f"Email sent successfully to {to_email}")
|
| 69 |
+
return True
|
| 70 |
+
except Exception as e:
|
| 71 |
+
logger.error(f"Failed to send email to {to_email}: {e}")
|
| 72 |
+
return False
|
dependencies.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from datetime import datetime, timedelta
|
| 3 |
+
from typing import Optional, Tuple
|
| 4 |
+
import ipaddress
|
| 5 |
+
import httpx
|
| 6 |
+
from fastapi import Request, Depends, HTTPException, status
|
| 7 |
+
from sqlalchemy import select, and_
|
| 8 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 9 |
+
|
| 10 |
+
from database import get_db
|
| 11 |
+
from models import User, RateLimit
|
| 12 |
+
from auth_utils import verify_password
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
# Geolocation API settings
|
| 17 |
+
GEOLOCATION_API_URL = "http://ip-api.com/json/{ip}?fields=status,country,regionName"
|
| 18 |
+
GEOLOCATION_TIMEOUT = 2.0 # seconds
|
| 19 |
+
|
| 20 |
+
async def check_rate_limit(
|
| 21 |
+
db: AsyncSession,
|
| 22 |
+
identifier: str,
|
| 23 |
+
endpoint: str,
|
| 24 |
+
limit: int,
|
| 25 |
+
window_minutes: int
|
| 26 |
+
) -> bool:
|
| 27 |
+
"""
|
| 28 |
+
Check if request is within rate limits.
|
| 29 |
+
Returns True if allowed, False if limit exceeded.
|
| 30 |
+
"""
|
| 31 |
+
now = datetime.utcnow()
|
| 32 |
+
window_start = now - timedelta(minutes=window_minutes)
|
| 33 |
+
|
| 34 |
+
# Check existing limit
|
| 35 |
+
query = select(RateLimit).where(
|
| 36 |
+
and_(
|
| 37 |
+
RateLimit.identifier == identifier,
|
| 38 |
+
RateLimit.endpoint == endpoint,
|
| 39 |
+
RateLimit.window_start >= window_start
|
| 40 |
+
)
|
| 41 |
+
)
|
| 42 |
+
result = await db.execute(query)
|
| 43 |
+
rate_limit = result.scalar_one_or_none()
|
| 44 |
+
|
| 45 |
+
if rate_limit:
|
| 46 |
+
if rate_limit.attempts >= limit:
|
| 47 |
+
return False
|
| 48 |
+
|
| 49 |
+
# Increment attempts
|
| 50 |
+
rate_limit.attempts += 1
|
| 51 |
+
await db.commit()
|
| 52 |
+
return True
|
| 53 |
+
else:
|
| 54 |
+
# Create new rate limit record
|
| 55 |
+
new_limit = RateLimit(
|
| 56 |
+
identifier=identifier,
|
| 57 |
+
endpoint=endpoint,
|
| 58 |
+
attempts=1,
|
| 59 |
+
window_start=now,
|
| 60 |
+
expires_at=now + timedelta(minutes=window_minutes)
|
| 61 |
+
)
|
| 62 |
+
db.add(new_limit)
|
| 63 |
+
await db.commit()
|
| 64 |
+
return True
|
| 65 |
+
|
| 66 |
+
async def verify_credits(
|
| 67 |
+
req: Request,
|
| 68 |
+
db: AsyncSession = Depends(get_db)
|
| 69 |
+
) -> User:
|
| 70 |
+
"""
|
| 71 |
+
Dependency to validate secret key and deduct credits.
|
| 72 |
+
"""
|
| 73 |
+
secret_key = req.headers.get("X-Secret-Key")
|
| 74 |
+
if not secret_key:
|
| 75 |
+
raise HTTPException(
|
| 76 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 77 |
+
detail="Missing X-Secret-Key header"
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
# Validate secret key format
|
| 81 |
+
if not secret_key.startswith("sk_"):
|
| 82 |
+
raise HTTPException(
|
| 83 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 84 |
+
detail="Invalid secret key format"
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# Find user
|
| 88 |
+
query = select(User).where(User.is_active == True)
|
| 89 |
+
result = await db.execute(query)
|
| 90 |
+
users = result.scalars().all()
|
| 91 |
+
|
| 92 |
+
valid_user = None
|
| 93 |
+
for user in users:
|
| 94 |
+
if verify_password(secret_key, user.secret_key_hash):
|
| 95 |
+
valid_user = user
|
| 96 |
+
break
|
| 97 |
+
|
| 98 |
+
if not valid_user:
|
| 99 |
+
raise HTTPException(
|
| 100 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 101 |
+
detail="Invalid secret key"
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# Check credits
|
| 105 |
+
if valid_user.credits <= 0:
|
| 106 |
+
raise HTTPException(
|
| 107 |
+
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
| 108 |
+
detail="Insufficient credits"
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
# Deduct credit
|
| 112 |
+
valid_user.credits -= 1
|
| 113 |
+
valid_user.last_used_at = datetime.utcnow()
|
| 114 |
+
await db.commit()
|
| 115 |
+
|
| 116 |
+
return valid_user
|
| 117 |
+
|
| 118 |
+
async def get_geolocation(ip_address: str) -> Tuple[Optional[str], Optional[str]]:
|
| 119 |
+
"""
|
| 120 |
+
Get country and region for an IP address using ip-api.com.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
ip_address: IPv4 or IPv6 address
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
Tuple of (country, region) or (None, None) if lookup fails
|
| 127 |
+
"""
|
| 128 |
+
if not ip_address:
|
| 129 |
+
return None, None
|
| 130 |
+
|
| 131 |
+
# Skip geolocation for localhost/private IPs
|
| 132 |
+
if ip_address in ("127.0.0.1", "::1", "localhost") or ip_address.startswith(("192.168.", "10.", "172.")):
|
| 133 |
+
return None, None
|
| 134 |
+
|
| 135 |
+
try:
|
| 136 |
+
async with httpx.AsyncClient(timeout=GEOLOCATION_TIMEOUT) as client:
|
| 137 |
+
response = await client.get(GEOLOCATION_API_URL.format(ip=ip_address))
|
| 138 |
+
if response.status_code == 200:
|
| 139 |
+
data = response.json()
|
| 140 |
+
if data.get("status") == "success":
|
| 141 |
+
return data.get("country"), data.get("regionName")
|
| 142 |
+
except Exception as e:
|
| 143 |
+
logger.warning(f"Geolocation lookup failed for {ip_address}: {e}")
|
| 144 |
+
|
| 145 |
+
return None, None
|
models.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
"""
|
| 2 |
SQLAlchemy models for the URL Blink application.
|
| 3 |
"""
|
| 4 |
-
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON
|
| 5 |
from sqlalchemy.sql import func
|
| 6 |
from database import Base
|
| 7 |
|
|
@@ -32,3 +32,54 @@ class BlinkData(Base):
|
|
| 32 |
|
| 33 |
def __repr__(self):
|
| 34 |
return f"<BlinkData(id={self.id}, user_id={self.user_id})>"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
SQLAlchemy models for the URL Blink application.
|
| 3 |
"""
|
| 4 |
+
from sqlalchemy import Column, Integer, String, Text, DateTime, JSON, Boolean
|
| 5 |
from sqlalchemy.sql import func
|
| 6 |
from database import Base
|
| 7 |
|
|
|
|
| 32 |
|
| 33 |
def __repr__(self):
|
| 34 |
return f"<BlinkData(id={self.id}, user_id={self.user_id})>"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class User(Base):
|
| 38 |
+
"""
|
| 39 |
+
User model for credit system.
|
| 40 |
+
"""
|
| 41 |
+
__tablename__ = "users"
|
| 42 |
+
|
| 43 |
+
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
|
| 44 |
+
user_id = Column(String(50), unique=True, index=True, nullable=False) # Backend generated UUID
|
| 45 |
+
temp_user_id = Column(String(50), index=True, nullable=True) # From frontend
|
| 46 |
+
email = Column(String(255), unique=True, index=True, nullable=False)
|
| 47 |
+
secret_key_hash = Column(String(255), nullable=False)
|
| 48 |
+
credits = Column(Integer, default=100)
|
| 49 |
+
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
| 50 |
+
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
|
| 51 |
+
last_used_at = Column(DateTime(timezone=True), nullable=True)
|
| 52 |
+
is_active = Column(Boolean, default=True)
|
| 53 |
+
|
| 54 |
+
def __repr__(self):
|
| 55 |
+
return f"<User(id={self.id}, email={self.email})>"
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class RateLimit(Base):
|
| 59 |
+
"""
|
| 60 |
+
Rate limit tracking table.
|
| 61 |
+
"""
|
| 62 |
+
__tablename__ = "rate_limits"
|
| 63 |
+
|
| 64 |
+
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
|
| 65 |
+
identifier = Column(String(255), index=True, nullable=False) # IP or email
|
| 66 |
+
endpoint = Column(String(255), index=True, nullable=False)
|
| 67 |
+
attempts = Column(Integer, default=0)
|
| 68 |
+
window_start = Column(DateTime(timezone=True), nullable=False)
|
| 69 |
+
expires_at = Column(DateTime(timezone=True), nullable=False)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class AuditLog(Base):
|
| 73 |
+
"""
|
| 74 |
+
Audit log for security events.
|
| 75 |
+
"""
|
| 76 |
+
__tablename__ = "audit_logs"
|
| 77 |
+
|
| 78 |
+
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
|
| 79 |
+
user_id = Column(String(50), nullable=True)
|
| 80 |
+
action = Column(String(50), nullable=False)
|
| 81 |
+
ip_address = Column(String(45), nullable=False)
|
| 82 |
+
user_agent = Column(String(255), nullable=True)
|
| 83 |
+
status = Column(String(20), nullable=False)
|
| 84 |
+
error_message = Column(Text, nullable=True)
|
| 85 |
+
timestamp = Column(DateTime(timezone=True), server_default=func.now())
|
pytest.ini
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[pytest]
|
| 2 |
+
asyncio_mode = auto
|
| 3 |
+
asyncio_default_fixture_loop_scope = function
|
requirements.txt
CHANGED
|
@@ -6,3 +6,6 @@ aiosqlite>=0.19.0
|
|
| 6 |
cryptography>=41.0.0
|
| 7 |
pydantic>=2.0.0
|
| 8 |
httpx>=0.25.0
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
cryptography>=41.0.0
|
| 7 |
pydantic>=2.0.0
|
| 8 |
httpx>=0.25.0
|
| 9 |
+
|
| 10 |
+
passlib[bcrypt]>=1.7.4
|
| 11 |
+
email-validator>=2.0.0
|
routers/auth.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException, status, Request, BackgroundTasks
|
| 2 |
+
from fastapi.responses import JSONResponse
|
| 3 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 4 |
+
from sqlalchemy import select
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
import uuid
|
| 7 |
+
|
| 8 |
+
from database import get_db
|
| 9 |
+
from models import User, AuditLog
|
| 10 |
+
from schemas import CheckRegistrationRequest, RegisterRequest, ValidateRequest, ResetRequest
|
| 11 |
+
from auth_utils import get_password_hash, verify_password, generate_secret_key, send_email
|
| 12 |
+
from dependencies import check_rate_limit
|
| 13 |
+
|
| 14 |
+
router = APIRouter(prefix="/auth", tags=["auth"])
|
| 15 |
+
|
| 16 |
+
@router.post("/check-registration")
|
| 17 |
+
async def check_registration(
|
| 18 |
+
request: CheckRegistrationRequest,
|
| 19 |
+
req: Request,
|
| 20 |
+
db: AsyncSession = Depends(get_db)
|
| 21 |
+
):
|
| 22 |
+
"""
|
| 23 |
+
Check if a temporary user_id has completed registration.
|
| 24 |
+
"""
|
| 25 |
+
# Rate Limit: 10 requests per minute per IP
|
| 26 |
+
ip = req.client.host
|
| 27 |
+
if not await check_rate_limit(db, ip, "/auth/check-registration", 10, 1):
|
| 28 |
+
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests")
|
| 29 |
+
|
| 30 |
+
query = select(User).where(User.temp_user_id == request.user_id)
|
| 31 |
+
result = await db.execute(query)
|
| 32 |
+
user = result.scalar_one_or_none()
|
| 33 |
+
|
| 34 |
+
return {"is_registered": user is not None}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@router.post("/register")
|
| 38 |
+
async def register(
|
| 39 |
+
request: RegisterRequest,
|
| 40 |
+
req: Request,
|
| 41 |
+
background_tasks: BackgroundTasks,
|
| 42 |
+
db: AsyncSession = Depends(get_db)
|
| 43 |
+
):
|
| 44 |
+
"""
|
| 45 |
+
Register new user, generate secret key, send email.
|
| 46 |
+
"""
|
| 47 |
+
# Rate Limit: 5 registrations per hour per IP
|
| 48 |
+
ip = req.client.host
|
| 49 |
+
if not await check_rate_limit(db, ip, "/auth/register", 5, 60):
|
| 50 |
+
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many registration attempts")
|
| 51 |
+
|
| 52 |
+
# Check Email Already Registered
|
| 53 |
+
query = select(User).where(User.email == request.email)
|
| 54 |
+
result = await db.execute(query)
|
| 55 |
+
if result.scalar_one_or_none():
|
| 56 |
+
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Email already registered")
|
| 57 |
+
|
| 58 |
+
# Check temp_user_id Already Registered
|
| 59 |
+
query = select(User).where(User.temp_user_id == request.user_id)
|
| 60 |
+
result = await db.execute(query)
|
| 61 |
+
if result.scalar_one_or_none():
|
| 62 |
+
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="User already registered")
|
| 63 |
+
|
| 64 |
+
# Generate Secret Key
|
| 65 |
+
secret_key = generate_secret_key()
|
| 66 |
+
secret_key_hash = get_password_hash(secret_key)
|
| 67 |
+
backend_user_id = "usr_" + str(uuid.uuid4())
|
| 68 |
+
|
| 69 |
+
# Create User
|
| 70 |
+
new_user = User(
|
| 71 |
+
user_id=backend_user_id,
|
| 72 |
+
temp_user_id=request.user_id,
|
| 73 |
+
email=request.email,
|
| 74 |
+
secret_key_hash=secret_key_hash,
|
| 75 |
+
credits=100
|
| 76 |
+
)
|
| 77 |
+
db.add(new_user)
|
| 78 |
+
|
| 79 |
+
# Log Audit
|
| 80 |
+
audit_log = AuditLog(
|
| 81 |
+
user_id=backend_user_id,
|
| 82 |
+
action="register",
|
| 83 |
+
ip_address=ip,
|
| 84 |
+
status="success"
|
| 85 |
+
)
|
| 86 |
+
db.add(audit_log)
|
| 87 |
+
|
| 88 |
+
await db.commit()
|
| 89 |
+
|
| 90 |
+
# Send Email (Async)
|
| 91 |
+
email_body = f"""Hello,
|
| 92 |
+
|
| 93 |
+
Your secret key is: {secret_key}
|
| 94 |
+
|
| 95 |
+
Please save this key securely.
|
| 96 |
+
You'll need it to access your credits.
|
| 97 |
+
|
| 98 |
+
If you lose this key, use the 'Forgot Key'
|
| 99 |
+
option with this email address.
|
| 100 |
+
|
| 101 |
+
Credits: 100
|
| 102 |
+
Valid from: {datetime.now().strftime('%Y-%m-%d')}
|
| 103 |
+
|
| 104 |
+
Do not share this key with anyone."""
|
| 105 |
+
|
| 106 |
+
background_tasks.add_task(
|
| 107 |
+
send_email,
|
| 108 |
+
request.email,
|
| 109 |
+
"Your Secret Key - Credit System",
|
| 110 |
+
email_body
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
return {"success": True, "message": "Registration successful. Check your email."}
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@router.post("/validate")
|
| 117 |
+
async def validate_key(
|
| 118 |
+
request: ValidateRequest,
|
| 119 |
+
req: Request,
|
| 120 |
+
db: AsyncSession = Depends(get_db)
|
| 121 |
+
):
|
| 122 |
+
"""
|
| 123 |
+
Validate secret key and return user info.
|
| 124 |
+
"""
|
| 125 |
+
# Rate Limit: 20 validations per hour per IP
|
| 126 |
+
ip = req.client.host
|
| 127 |
+
if not await check_rate_limit(db, ip, "/auth/validate", 20, 60):
|
| 128 |
+
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many validation attempts")
|
| 129 |
+
|
| 130 |
+
query = select(User).where(User.is_active == True)
|
| 131 |
+
result = await db.execute(query)
|
| 132 |
+
users = result.scalars().all()
|
| 133 |
+
|
| 134 |
+
valid_user = None
|
| 135 |
+
for user in users:
|
| 136 |
+
if verify_password(request.secret_key, user.secret_key_hash):
|
| 137 |
+
valid_user = user
|
| 138 |
+
break
|
| 139 |
+
|
| 140 |
+
if valid_user:
|
| 141 |
+
# Update last_used_at
|
| 142 |
+
valid_user.last_used_at = datetime.utcnow()
|
| 143 |
+
|
| 144 |
+
# Log Audit
|
| 145 |
+
audit_log = AuditLog(
|
| 146 |
+
user_id=valid_user.user_id,
|
| 147 |
+
action="validate",
|
| 148 |
+
ip_address=ip,
|
| 149 |
+
status="success"
|
| 150 |
+
)
|
| 151 |
+
db.add(audit_log)
|
| 152 |
+
await db.commit()
|
| 153 |
+
|
| 154 |
+
return {
|
| 155 |
+
"valid": True,
|
| 156 |
+
"user_id": valid_user.user_id,
|
| 157 |
+
"credits": valid_user.credits,
|
| 158 |
+
"message": "Valid key"
|
| 159 |
+
}
|
| 160 |
+
else:
|
| 161 |
+
# Log Audit
|
| 162 |
+
audit_log = AuditLog(
|
| 163 |
+
user_id=None,
|
| 164 |
+
action="validate",
|
| 165 |
+
ip_address=ip,
|
| 166 |
+
status="failed",
|
| 167 |
+
error_message="Invalid key"
|
| 168 |
+
)
|
| 169 |
+
db.add(audit_log)
|
| 170 |
+
await db.commit()
|
| 171 |
+
|
| 172 |
+
return JSONResponse(
|
| 173 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 174 |
+
content={"valid": False, "message": "Invalid secret key"}
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@router.post("/reset")
|
| 179 |
+
async def reset_key(
|
| 180 |
+
request: ResetRequest,
|
| 181 |
+
req: Request,
|
| 182 |
+
background_tasks: BackgroundTasks,
|
| 183 |
+
db: AsyncSession = Depends(get_db)
|
| 184 |
+
):
|
| 185 |
+
"""
|
| 186 |
+
Reset/recover secret key via email.
|
| 187 |
+
"""
|
| 188 |
+
# Rate Limit: 3 reset attempts per hour per email
|
| 189 |
+
if not await check_rate_limit(db, request.email, "/auth/reset", 3, 60):
|
| 190 |
+
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many reset attempts")
|
| 191 |
+
|
| 192 |
+
query = select(User).where(User.email == request.email)
|
| 193 |
+
result = await db.execute(query)
|
| 194 |
+
user = result.scalar_one_or_none()
|
| 195 |
+
|
| 196 |
+
ip = req.client.host
|
| 197 |
+
|
| 198 |
+
if user:
|
| 199 |
+
# Generate New Secret Key
|
| 200 |
+
new_secret_key = generate_secret_key()
|
| 201 |
+
new_secret_key_hash = get_password_hash(new_secret_key)
|
| 202 |
+
|
| 203 |
+
user.secret_key_hash = new_secret_key_hash
|
| 204 |
+
user.updated_at = datetime.utcnow()
|
| 205 |
+
|
| 206 |
+
# Log Audit
|
| 207 |
+
audit_log = AuditLog(
|
| 208 |
+
user_id=user.user_id,
|
| 209 |
+
action="reset",
|
| 210 |
+
ip_address=ip,
|
| 211 |
+
status="success"
|
| 212 |
+
)
|
| 213 |
+
db.add(audit_log)
|
| 214 |
+
await db.commit()
|
| 215 |
+
|
| 216 |
+
# Send Email
|
| 217 |
+
email_body = f"""Hello,
|
| 218 |
+
|
| 219 |
+
You requested a secret key reset.
|
| 220 |
+
|
| 221 |
+
Your NEW secret key is: {new_secret_key}
|
| 222 |
+
|
| 223 |
+
Your old secret key is now invalid.
|
| 224 |
+
|
| 225 |
+
If you didn't request this, please
|
| 226 |
+
contact support immediately.
|
| 227 |
+
|
| 228 |
+
Current Credits: {user.credits}"""
|
| 229 |
+
|
| 230 |
+
background_tasks.add_task(
|
| 231 |
+
send_email,
|
| 232 |
+
request.email,
|
| 233 |
+
"Your New Secret Key",
|
| 234 |
+
email_body
|
| 235 |
+
)
|
| 236 |
+
else:
|
| 237 |
+
# Log Audit (failed/not found)
|
| 238 |
+
audit_log = AuditLog(
|
| 239 |
+
user_id=None,
|
| 240 |
+
action="reset",
|
| 241 |
+
ip_address=ip,
|
| 242 |
+
status="failed",
|
| 243 |
+
error_message="Email not found"
|
| 244 |
+
)
|
| 245 |
+
db.add(audit_log)
|
| 246 |
+
await db.commit()
|
| 247 |
+
|
| 248 |
+
# Always return success
|
| 249 |
+
return {"success": True, "message": "If this email is registered, reset instructions have been sent."}
|
routers/blink.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException, status, Query, Request
|
| 2 |
+
from fastapi.responses import JSONResponse
|
| 3 |
+
from sqlalchemy.ext.asyncio import AsyncSession
|
| 4 |
+
from sqlalchemy import select, func
|
| 5 |
+
import ipaddress
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
from database import get_db
|
| 9 |
+
from models import BlinkData
|
| 10 |
+
from encryption import decrypt_multiple_blocks
|
| 11 |
+
from dependencies import get_geolocation
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
router = APIRouter()
|
| 16 |
+
|
| 17 |
+
# User ID length constant
|
| 18 |
+
USER_ID_LENGTH = 20
|
| 19 |
+
|
| 20 |
+
@router.get("/api/data")
|
| 21 |
+
async def get_data(
|
| 22 |
+
page: int = Query(1, ge=1, description="Page number"),
|
| 23 |
+
limit: int = Query(100, ge=1, le=500, description="Items per page"),
|
| 24 |
+
db: AsyncSession = Depends(get_db)
|
| 25 |
+
):
|
| 26 |
+
"""
|
| 27 |
+
Get paginated blink data.
|
| 28 |
+
"""
|
| 29 |
+
try:
|
| 30 |
+
offset = (page - 1) * limit
|
| 31 |
+
|
| 32 |
+
# Get total count
|
| 33 |
+
total_result = await db.execute(select(func.count(BlinkData.id)))
|
| 34 |
+
total = total_result.scalar() or 0
|
| 35 |
+
|
| 36 |
+
# Get unique users count
|
| 37 |
+
unique_result = await db.execute(select(func.count(func.distinct(BlinkData.user_id))))
|
| 38 |
+
unique_users = unique_result.scalar() or 0
|
| 39 |
+
|
| 40 |
+
# Get paginated items
|
| 41 |
+
query = select(BlinkData).order_by(BlinkData.id.desc()).offset(offset).limit(limit)
|
| 42 |
+
result = await db.execute(query)
|
| 43 |
+
items = result.scalars().all()
|
| 44 |
+
|
| 45 |
+
return {
|
| 46 |
+
"items": [
|
| 47 |
+
{
|
| 48 |
+
"id": item.id,
|
| 49 |
+
"user_id": item.user_id,
|
| 50 |
+
"refer_url": item.refer_url,
|
| 51 |
+
"ip_address": item.ip_address,
|
| 52 |
+
"ipv4_address": item.ipv4_address,
|
| 53 |
+
"ipv6_address": item.ipv6_address,
|
| 54 |
+
"country": item.country,
|
| 55 |
+
"region": item.region,
|
| 56 |
+
"json_data": item.json_data,
|
| 57 |
+
"created_at": item.created_at.isoformat() if item.created_at else None
|
| 58 |
+
}
|
| 59 |
+
for item in items
|
| 60 |
+
],
|
| 61 |
+
"total": total,
|
| 62 |
+
"unique_users": unique_users,
|
| 63 |
+
"page": page,
|
| 64 |
+
"limit": limit
|
| 65 |
+
}
|
| 66 |
+
except Exception as e:
|
| 67 |
+
logger.error(f"Error fetching data: {e}")
|
| 68 |
+
raise HTTPException(
|
| 69 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 70 |
+
detail="Error fetching data"
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@router.get("/blink")
|
| 75 |
+
async def blink(
|
| 76 |
+
request: Request,
|
| 77 |
+
userid: str = Query(..., description="User ID (20 chars) + encrypted data"),
|
| 78 |
+
db: AsyncSession = Depends(get_db)
|
| 79 |
+
):
|
| 80 |
+
"""
|
| 81 |
+
Process blink request with encrypted user data.
|
| 82 |
+
"""
|
| 83 |
+
try:
|
| 84 |
+
# Validate minimum length
|
| 85 |
+
if len(userid) < USER_ID_LENGTH:
|
| 86 |
+
raise HTTPException(
|
| 87 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 88 |
+
detail=f"Parameter 'userid' must be at least {USER_ID_LENGTH} characters"
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# Extract user_id (first 20 characters)
|
| 92 |
+
user_id = userid[:USER_ID_LENGTH]
|
| 93 |
+
|
| 94 |
+
# Extract encrypted data (remaining characters)
|
| 95 |
+
encrypted_data = userid[USER_ID_LENGTH:]
|
| 96 |
+
|
| 97 |
+
if not encrypted_data:
|
| 98 |
+
logger.warning(f"No encrypted data received for user: {user_id}")
|
| 99 |
+
# Still store the record with empty json_data
|
| 100 |
+
decrypted_results = []
|
| 101 |
+
else:
|
| 102 |
+
# Try to decrypt - might be single or multiple blocks
|
| 103 |
+
try:
|
| 104 |
+
decrypted_results = decrypt_multiple_blocks(encrypted_data)
|
| 105 |
+
except Exception as e:
|
| 106 |
+
logger.error(f"Decryption failed for user {user_id}: {e}")
|
| 107 |
+
# Store with error information
|
| 108 |
+
decrypted_results = [{"error": str(e), "raw_encrypted": encrypted_data[:100]}]
|
| 109 |
+
|
| 110 |
+
# Get referer URL from headers (full URL, not just origin)
|
| 111 |
+
refer_url = request.headers.get("referer")
|
| 112 |
+
|
| 113 |
+
# Get client IP address
|
| 114 |
+
# Check X-Forwarded-For header first (for proxies/load balancers)
|
| 115 |
+
forwarded_for = request.headers.get("x-forwarded-for")
|
| 116 |
+
if forwarded_for:
|
| 117 |
+
# X-Forwarded-For can contain multiple IPs, take the first one
|
| 118 |
+
ip_address = forwarded_for.split(",")[0].strip()
|
| 119 |
+
else:
|
| 120 |
+
# Fall back to direct client IP
|
| 121 |
+
ip_address = request.client.host if request.client else None
|
| 122 |
+
|
| 123 |
+
# Get geolocation from IP address
|
| 124 |
+
country, region = await get_geolocation(ip_address)
|
| 125 |
+
|
| 126 |
+
# Determine IPv4 vs IPv6
|
| 127 |
+
ipv4_address = None
|
| 128 |
+
ipv6_address = None
|
| 129 |
+
|
| 130 |
+
if ip_address:
|
| 131 |
+
try:
|
| 132 |
+
ip_obj = ipaddress.ip_address(ip_address)
|
| 133 |
+
if isinstance(ip_obj, ipaddress.IPv4Address):
|
| 134 |
+
ipv4_address = ip_address
|
| 135 |
+
elif isinstance(ip_obj, ipaddress.IPv6Address):
|
| 136 |
+
ipv6_address = ip_address
|
| 137 |
+
except ValueError:
|
| 138 |
+
# Invalid IP address format, just keep it in ip_address
|
| 139 |
+
pass
|
| 140 |
+
|
| 141 |
+
# Store each decrypted result as separate records
|
| 142 |
+
records_created = 0
|
| 143 |
+
for json_data in decrypted_results:
|
| 144 |
+
blink_record = BlinkData(
|
| 145 |
+
user_id=user_id,
|
| 146 |
+
refer_url=refer_url,
|
| 147 |
+
ip_address=ip_address,
|
| 148 |
+
ipv4_address=ipv4_address,
|
| 149 |
+
ipv6_address=ipv6_address,
|
| 150 |
+
country=country,
|
| 151 |
+
region=region,
|
| 152 |
+
json_data=json_data
|
| 153 |
+
)
|
| 154 |
+
db.add(blink_record)
|
| 155 |
+
records_created += 1
|
| 156 |
+
|
| 157 |
+
# If no results but we have encrypted data, store a record with the raw data reference
|
| 158 |
+
if not decrypted_results and encrypted_data:
|
| 159 |
+
blink_record = BlinkData(
|
| 160 |
+
user_id=user_id,
|
| 161 |
+
refer_url=refer_url,
|
| 162 |
+
ip_address=ip_address,
|
| 163 |
+
ipv4_address=ipv4_address,
|
| 164 |
+
ipv6_address=ipv6_address,
|
| 165 |
+
country=country,
|
| 166 |
+
region=region,
|
| 167 |
+
json_data={"encrypted_length": len(encrypted_data)}
|
| 168 |
+
)
|
| 169 |
+
db.add(blink_record)
|
| 170 |
+
records_created = 1
|
| 171 |
+
|
| 172 |
+
await db.commit()
|
| 173 |
+
|
| 174 |
+
logger.info(f"Successfully processed blink for user: {user_id}, records: {records_created}")
|
| 175 |
+
|
| 176 |
+
return JSONResponse(
|
| 177 |
+
status_code=status.HTTP_200_OK,
|
| 178 |
+
content={
|
| 179 |
+
"status": "success",
|
| 180 |
+
"user_id": user_id,
|
| 181 |
+
"records_created": records_created
|
| 182 |
+
}
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
except HTTPException:
|
| 186 |
+
raise
|
| 187 |
+
except Exception as e:
|
| 188 |
+
logger.error(f"Error processing blink request: {e}")
|
| 189 |
+
await db.rollback()
|
| 190 |
+
raise HTTPException(
|
| 191 |
+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
| 192 |
+
detail="Internal server error processing request"
|
| 193 |
+
)
|
routers/general.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter
|
| 2 |
+
from fastapi.responses import HTMLResponse
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
router = APIRouter()
|
| 6 |
+
|
| 7 |
+
@router.get("/health")
|
| 8 |
+
async def health_check():
|
| 9 |
+
"""
|
| 10 |
+
Health check endpoint.
|
| 11 |
+
"""
|
| 12 |
+
return {"status": "healthy", "service": "url-blink-api"}
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@router.get("/", response_class=HTMLResponse)
|
| 16 |
+
async def root():
|
| 17 |
+
"""
|
| 18 |
+
Serve the main HTML page.
|
| 19 |
+
"""
|
| 20 |
+
# Note: We need to go up one level to find templates since we are in routers/
|
| 21 |
+
# But actually, app.py is in root, so templates is in root.
|
| 22 |
+
# The file path logic here needs to be robust.
|
| 23 |
+
# os.path.dirname(__file__) is routers/
|
| 24 |
+
# so we need ../templates
|
| 25 |
+
|
| 26 |
+
base_dir = os.path.dirname(os.path.dirname(__file__))
|
| 27 |
+
template_path = os.path.join(base_dir, "templates", "index.html")
|
| 28 |
+
with open(template_path, "r") as f:
|
| 29 |
+
return HTMLResponse(content=f.read())
|
schemas.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, EmailStr, Field
|
| 2 |
+
|
| 3 |
+
# Pydantic Models
|
| 4 |
+
class CheckRegistrationRequest(BaseModel):
    """Request body for POST /auth/check-registration."""
    user_id: str = Field(..., min_length=1, description="Temporary user ID from frontend")

class RegisterRequest(BaseModel):
    """Request body for POST /auth/register."""
    user_id: str = Field(..., min_length=1, description="Temporary user ID from frontend")
    email: EmailStr = Field(..., description="User email address")

class ValidateRequest(BaseModel):
    """Request body for POST /auth/validate."""
    secret_key: str = Field(..., min_length=35, description="Secret key starting with sk_")

class ResetRequest(BaseModel):
    """Request body for POST /auth/reset."""
    email: EmailStr = Field(..., description="User email address")
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
from fastapi.testclient import TestClient
|
| 5 |
+
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
| 6 |
+
|
| 7 |
+
# Add parent directory to path to allow importing app
|
| 8 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
| 9 |
+
|
| 10 |
+
from app import app
|
| 11 |
+
from database import get_db, Base
|
| 12 |
+
|
| 13 |
+
# Use a file-based SQLite database for testing to ensure persistence
|
| 14 |
+
TEST_DATABASE_URL = "sqlite+aiosqlite:///./test_blink_data.db"
|
| 15 |
+
|
| 16 |
+
@pytest.fixture(scope="session")
|
| 17 |
+
def test_engine():
|
| 18 |
+
engine = create_async_engine(TEST_DATABASE_URL, echo=False)
|
| 19 |
+
yield engine
|
| 20 |
+
# Cleanup after session
|
| 21 |
+
if os.path.exists("./test_blink_data.db"):
|
| 22 |
+
os.remove("./test_blink_data.db")
|
| 23 |
+
|
| 24 |
+
@pytest.fixture(scope="function")
|
| 25 |
+
async def db_session(test_engine):
|
| 26 |
+
async with test_engine.begin() as conn:
|
| 27 |
+
await conn.run_sync(Base.metadata.create_all)
|
| 28 |
+
|
| 29 |
+
async_session = async_sessionmaker(
|
| 30 |
+
test_engine,
|
| 31 |
+
class_=AsyncSession,
|
| 32 |
+
expire_on_commit=False
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
async with async_session() as session:
|
| 36 |
+
yield session
|
| 37 |
+
|
| 38 |
+
# We don't drop tables here to allow persistence if needed,
|
| 39 |
+
# but for isolation we usually want to.
|
| 40 |
+
# However, the previous test relied on persistence for rate limiting.
|
| 41 |
+
# Let's keep it simple: we clear data manually if needed or rely on fresh DB per run (session scope engine).
|
| 42 |
+
# Actually, for rate limiting test to work across requests, we need persistence.
|
| 43 |
+
# But for isolation between tests, we want cleanup.
|
| 44 |
+
# The previous test_app.py had cleanup_db fixture. Let's replicate that logic in the test file or here.
|
| 45 |
+
|
| 46 |
+
# Let's just yield session here.
|
| 47 |
+
|
| 48 |
+
@pytest.fixture(scope="function")
|
| 49 |
+
def client(test_engine):
|
| 50 |
+
async def override_get_db():
|
| 51 |
+
async_session = async_sessionmaker(
|
| 52 |
+
test_engine,
|
| 53 |
+
class_=AsyncSession,
|
| 54 |
+
expire_on_commit=False
|
| 55 |
+
)
|
| 56 |
+
async with async_session() as session:
|
| 57 |
+
yield session
|
| 58 |
+
|
| 59 |
+
app.dependency_overrides[get_db] = override_get_db
|
| 60 |
+
with TestClient(app) as c:
|
| 61 |
+
yield c
|
| 62 |
+
app.dependency_overrides.clear()
|
tests/test_integration.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
from unittest.mock import patch
|
| 3 |
+
import os
|
| 4 |
+
from sqlalchemy import text
|
| 5 |
+
|
| 6 |
+
# Cleanup fixture
|
| 7 |
+
# NOTE(review): this fixture is currently a no-op placeholder — the engine
# in conftest keeps the DB file open, so per-test row cleanup is handled
# by the clear_tables fixture below instead.
@pytest.fixture(autouse=True)
def cleanup_db():
    if os.path.exists("./test_blink_data.db"):
        # We can't easily delete the file if it's open by the engine in conftest.
        # Instead, we should probably truncate tables.
        pass
    yield
    # Cleanup logic if needed
|
| 15 |
+
|
| 16 |
+
# We need a way to clear data between tests if we want isolation.
|
| 17 |
+
# Since we are using a file-based DB shared across the session (engine),
|
| 18 |
+
# we should truncate tables.
|
| 19 |
+
|
| 20 |
+
@pytest.fixture(autouse=True)
async def clear_tables(db_session):
    """Wipe all rows before each test so tests are isolated.

    The engine (and the SQLite file behind it) is session-scoped, so data
    would otherwise leak between tests; deleting rows is cheaper than
    dropping and recreating the schema.
    """
    # Fix: the original wrapped these statements in `async with
    # db_session.begin():` AND called commit() inside it — the context
    # manager already commits on exit, so the inner commit was redundant
    # and fought the outer transaction management. A plain execute+commit
    # is sufficient (the session autobegins a transaction).
    await db_session.execute(text("DELETE FROM users"))
    await db_session.execute(text("DELETE FROM rate_limits"))
    await db_session.execute(text("DELETE FROM audit_logs"))
    await db_session.execute(text("DELETE FROM blink_data"))
    await db_session.commit()
|
| 29 |
+
|
| 30 |
+
@patch("routers.auth.send_email")
|
| 31 |
+
def test_credit_system_flow(mock_send_email, client):
|
| 32 |
+
mock_send_email.return_value = True
|
| 33 |
+
|
| 34 |
+
# 1. Register
|
| 35 |
+
response = client.post("/auth/register", json={
|
| 36 |
+
"user_id": "test-user-1",
|
| 37 |
+
"email": "test@example.com"
|
| 38 |
+
})
|
| 39 |
+
assert response.status_code == 200
|
| 40 |
+
assert response.json()["success"] == True
|
| 41 |
+
|
| 42 |
+
# 2. Check registration
|
| 43 |
+
response = client.post("/auth/check-registration", json={"user_id": "test-user-1"})
|
| 44 |
+
assert response.status_code == 200
|
| 45 |
+
assert response.json()["is_registered"] == True
|
| 46 |
+
|
| 47 |
+
# 3. Validate with mocked key (we need to know the key)
|
| 48 |
+
# Since we can't easily get the key from the hashed DB, let's mock generate_secret_key
|
| 49 |
+
with patch("routers.auth.generate_secret_key", return_value="sk_test_key_1234567890123456789012345"):
|
| 50 |
+
# Register user 2
|
| 51 |
+
client.post("/auth/register", json={
|
| 52 |
+
"user_id": "test-user-2",
|
| 53 |
+
"email": "test2@example.com"
|
| 54 |
+
})
|
| 55 |
+
|
| 56 |
+
# Validate
|
| 57 |
+
response = client.post("/auth/validate", json={"secret_key": "sk_test_key_1234567890123456789012345"})
|
| 58 |
+
assert response.status_code == 200
|
| 59 |
+
assert response.json()["valid"] == True
|
| 60 |
+
assert response.json()["credits"] == 100
|
| 61 |
+
|
| 62 |
+
def test_blink_flow(client):
    """End-to-end /blink ingestion followed by retrieval via /api/data."""
    # The userid query param packs a fixed 20-char user id with the
    # encrypted payload appended directly after it.
    uid = "12345678901234567890"
    payload = "some_encrypted_data_base64"

    resp = client.get(f"/blink?userid={uid + payload}")
    assert resp.status_code == 200
    body = resp.json()
    assert body["status"] == "success"
    assert body["user_id"] == uid

    # The ingested record should now be visible through the listing API.
    listing = client.get("/api/data")
    assert listing.status_code == 200
    stored = listing.json()["items"]
    assert len(stored) > 0
    assert stored[0]["user_id"] == uid
|
| 81 |
+
|
| 82 |
+
@patch("routers.auth.send_email")
|
| 83 |
+
def test_rate_limiting(mock_send_email, client):
|
| 84 |
+
# 10 requests should succeed
|
| 85 |
+
for _ in range(10):
|
| 86 |
+
response = client.post("/auth/check-registration", json={"user_id": "rate-limit-test"})
|
| 87 |
+
assert response.status_code == 200
|
| 88 |
+
|
| 89 |
+
# 11th request should fail
|
| 90 |
+
response = client.post("/auth/check-registration", json={"user_id": "rate-limit-test"})
|
| 91 |
+
assert response.status_code == 429
|