Almaatla commited on
Commit
489b626
·
verified ·
1 Parent(s): 662cd58

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +21 -29
app/main.py CHANGED
@@ -1,10 +1,11 @@
1
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Header, HTTPException
2
  from fastapi.responses import HTMLResponse, JSONResponse
3
- from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.staticfiles import StaticFiles
 
 
5
  from uuid import uuid4
6
  from datetime import datetime, timezone, timedelta
7
- from typing import Optional, Dict, Any, List
8
  import os
9
  import asyncio
10
 
@@ -53,7 +54,7 @@ async def create_post(
53
  ):
54
  client_ip = get_client_ip(request)
55
 
56
- # 1. IP block check
57
  if is_ip_blocked(client_ip):
58
  logger.warning(
59
  f"{datetime.utcnow().isoformat()} - AUTH_FAILURE - IP: {client_ip} - Reason: blocked_ip - Endpoint: /api/v1/posts"
@@ -63,7 +64,7 @@ async def create_post(
63
  content={"error": "Too many failed authentication attempts"},
64
  )
65
 
66
- # 2. API key validation
67
  if not x_api_key or x_api_key != POSTER_KEY:
68
  record_auth_failure(client_ip)
69
  logger.warning(
@@ -71,13 +72,12 @@ async def create_post(
71
  )
72
  raise HTTPException(status_code=401, detail="Invalid API key")
73
 
74
- # 3. Validate content length (pydantic already enforces max_length, but do a clear error)
75
  if len(payload.content) > 1000:
76
  raise HTTPException(
77
  status_code=413, detail="Content exceeds 1000 character limit"
78
  )
79
 
80
- # 4. Create message
81
  now = datetime.now(timezone.utc)
82
  msg = Message(
83
  id=str(uuid4()),
@@ -88,25 +88,25 @@ async def create_post(
88
  metadata=payload.metadata or {},
89
  )
90
 
91
- # Insert + prune by age
92
  async with storage_lock:
93
  messages.append(msg)
 
 
94
  cutoff = now - timedelta(hours=48)
95
- # Age-based pruning (after maxlen already applied)
96
  tmp = [m for m in messages if m.timestamp >= cutoff]
97
  messages.clear()
98
  for m in tmp:
99
  messages.append(m)
100
 
101
- # Broadcast to readers
102
- payload_out = NewPostPayload(type="new_post", message=msg).dict()
103
- stale_connections = []
104
  for ws in connected_readers:
105
  try:
106
  await ws.send_json(payload_out)
107
  except Exception:
108
- stale_connections.append(ws)
109
- for ws in stale_connections:
110
  connected_readers.discard(ws)
111
 
112
  logger.info(
@@ -126,10 +126,9 @@ async def create_post(
126
 
127
  @app.websocket("/ws")
128
  async def websocket_endpoint(websocket: WebSocket):
129
- # We need to manually pull query params, including api_key
130
  client_ip = websocket.client.host if websocket.client else "unknown"
131
 
132
- # 1. IP block check
133
  if is_ip_blocked(client_ip):
134
  logger.warning(
135
  f"{datetime.utcnow().isoformat()} - AUTH_FAILURE - IP: {client_ip} - Reason: blocked_ip - Endpoint: /ws"
@@ -137,16 +136,16 @@ async def websocket_endpoint(websocket: WebSocket):
137
  await websocket.close(code=1008)
138
  return
139
 
 
140
  api_key = websocket.query_params.get("api_key")
141
  if not api_key or api_key != READER_KEY:
142
  record_auth_failure(client_ip)
143
  logger.warning(
144
  f"{datetime.utcnow().isoformat()} - AUTH_FAILURE - IP: {client_ip} - Reason: invalid_reader_key - Endpoint: /ws"
145
  )
146
- # Send error payload once; then close.
147
  await websocket.accept()
148
  await websocket.send_json(
149
- ErrorPayload(type="error", error="Invalid API key").dict()
150
  )
151
  await websocket.close(code=1008)
152
  return
@@ -157,28 +156,21 @@ async def websocket_endpoint(websocket: WebSocket):
157
  )
158
  record_auth_success(client_ip)
159
 
160
- # Register connection
161
  async with storage_lock:
162
  connected_readers.add(websocket)
163
- # Send history
164
- history_payload = HistoryPayload(
165
- type="history",
166
- messages=list(messages),
167
- )
168
- await websocket.send_json(history_payload.dict())
169
 
170
  try:
171
  while True:
172
- # Keep the connection alive and support ping/pong at transport layer
173
- # We don't expect client messages; just read and ignore or timeout.
174
  try:
175
  await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
176
  except asyncio.TimeoutError:
177
- # Send ping as a small text; client can ignore it.
178
  await websocket.send_json({"type": "ping"})
179
  except WebSocketDisconnect:
180
  pass
181
  finally:
182
  async with storage_lock:
183
- if websocket in connected_readers:
184
- connected_readers.discard(websocket)
 
1
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Header, HTTPException
2
  from fastapi.responses import HTMLResponse, JSONResponse
 
3
  from fastapi.staticfiles import StaticFiles
4
+ from fastapi.encoders import jsonable_encoder
5
+
6
  from uuid import uuid4
7
  from datetime import datetime, timezone, timedelta
8
+ from typing import Optional
9
  import os
10
  import asyncio
11
 
 
54
  ):
55
  client_ip = get_client_ip(request)
56
 
57
+ # 1) IP block check
58
  if is_ip_blocked(client_ip):
59
  logger.warning(
60
  f"{datetime.utcnow().isoformat()} - AUTH_FAILURE - IP: {client_ip} - Reason: blocked_ip - Endpoint: /api/v1/posts"
 
64
  content={"error": "Too many failed authentication attempts"},
65
  )
66
 
67
+ # 2) API key validation
68
  if not x_api_key or x_api_key != POSTER_KEY:
69
  record_auth_failure(client_ip)
70
  logger.warning(
 
72
  )
73
  raise HTTPException(status_code=401, detail="Invalid API key")
74
 
75
+ # 3) Defensive length check (pydantic also enforces this)
76
  if len(payload.content) > 1000:
77
  raise HTTPException(
78
  status_code=413, detail="Content exceeds 1000 character limit"
79
  )
80
 
 
81
  now = datetime.now(timezone.utc)
82
  msg = Message(
83
  id=str(uuid4()),
 
88
  metadata=payload.metadata or {},
89
  )
90
 
 
91
  async with storage_lock:
92
  messages.append(msg)
93
+
94
+ # Retention: expire >48h (50-message maxlen still takes precedence)
95
  cutoff = now - timedelta(hours=48)
 
96
  tmp = [m for m in messages if m.timestamp >= cutoff]
97
  messages.clear()
98
  for m in tmp:
99
  messages.append(m)
100
 
101
+ # Broadcast to all connected readers (JSON-safe)
102
+ payload_out = jsonable_encoder(NewPostPayload(type="new_post", message=msg))
103
+ stale = []
104
  for ws in connected_readers:
105
  try:
106
  await ws.send_json(payload_out)
107
  except Exception:
108
+ stale.append(ws)
109
+ for ws in stale:
110
  connected_readers.discard(ws)
111
 
112
  logger.info(
 
126
 
127
  @app.websocket("/ws")
128
  async def websocket_endpoint(websocket: WebSocket):
 
129
  client_ip = websocket.client.host if websocket.client else "unknown"
130
 
131
+ # 1) IP block check
132
  if is_ip_blocked(client_ip):
133
  logger.warning(
134
  f"{datetime.utcnow().isoformat()} - AUTH_FAILURE - IP: {client_ip} - Reason: blocked_ip - Endpoint: /ws"
 
136
  await websocket.close(code=1008)
137
  return
138
 
139
+ # 2) API key validation (query param)
140
  api_key = websocket.query_params.get("api_key")
141
  if not api_key or api_key != READER_KEY:
142
  record_auth_failure(client_ip)
143
  logger.warning(
144
  f"{datetime.utcnow().isoformat()} - AUTH_FAILURE - IP: {client_ip} - Reason: invalid_reader_key - Endpoint: /ws"
145
  )
 
146
  await websocket.accept()
147
  await websocket.send_json(
148
+ jsonable_encoder(ErrorPayload(type="error", error="Invalid API key"))
149
  )
150
  await websocket.close(code=1008)
151
  return
 
156
  )
157
  record_auth_success(client_ip)
158
 
159
+ # Register + send initial history (JSON-safe)
160
  async with storage_lock:
161
  connected_readers.add(websocket)
162
+ history_payload = HistoryPayload(type="history", messages=list(messages))
163
+ await websocket.send_json(jsonable_encoder(history_payload))
 
 
 
 
164
 
165
  try:
166
  while True:
167
+ # Keepalive: wait for any client message; if none, send ping
 
168
  try:
169
  await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
170
  except asyncio.TimeoutError:
 
171
  await websocket.send_json({"type": "ping"})
172
  except WebSocketDisconnect:
173
  pass
174
  finally:
175
  async with storage_lock:
176
+ connected_readers.discard(websocket)