Almaatla commited on
Commit
68269b4
·
verified ·
1 Parent(s): 507fec6

Upload 2 files

Browse files

support upload of png, md and txt

Files changed (2) hide show
  1. app/main.py +362 -176
  2. app/models.py +46 -34
app/main.py CHANGED
@@ -1,176 +1,362 @@
1
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Header, HTTPException
2
- from fastapi.responses import HTMLResponse, JSONResponse
3
- from fastapi.staticfiles import StaticFiles
4
- from fastapi.encoders import jsonable_encoder
5
-
6
- from uuid import uuid4
7
- from datetime import datetime, timezone, timedelta
8
- from typing import Optional
9
- import os
10
- import asyncio
11
-
12
- from .models import Message, PostRequest, HistoryPayload, NewPostPayload, ErrorPayload
13
- from .storage import messages, connected_readers, storage_lock
14
- from .rate_limit import (
15
- is_ip_blocked,
16
- record_auth_failure,
17
- record_auth_success,
18
- get_client_ip,
19
- )
20
- from .logger_config import logger
21
-
22
- POSTER_KEY = os.getenv("POSTER_KEY", "")
23
- READER_KEY = os.getenv("READER_KEY", "")
24
-
25
- app = FastAPI()
26
-
27
- # Static HTML (root)
28
- app.mount("/static", StaticFiles(directory="app/static"), name="static")
29
-
30
-
31
- @app.get("/", response_class=HTMLResponse)
32
- async def root():
33
- with open("app/static/index.html", "r", encoding="utf-8") as f:
34
- return f.read()
35
-
36
-
37
- @app.get("/health")
38
- async def health():
39
- async with storage_lock:
40
- message_count = len(messages)
41
- connected = len(connected_readers)
42
- return {
43
- "status": "healthy",
44
- "message_count": message_count,
45
- "connected_readers": connected,
46
- }
47
-
48
-
49
- @app.post("/api/v1/posts")
50
- async def create_post(
51
- request: Request,
52
- payload: PostRequest,
53
- x_api_key: Optional[str] = Header(default=None, alias="X-API-Key"),
54
- ):
55
- client_ip = get_client_ip(request)
56
-
57
- # 1) IP block check
58
- if is_ip_blocked(client_ip):
59
- logger.warning(
60
- f"{datetime.utcnow().isoformat()} - AUTH_FAILURE - IP: {client_ip} - Reason: blocked_ip - Endpoint: /api/v1/posts"
61
- )
62
- return JSONResponse(
63
- status_code=429,
64
- content={"error": "Too many failed authentication attempts"},
65
- )
66
-
67
- # 2) API key validation
68
- if not x_api_key or x_api_key != POSTER_KEY:
69
- record_auth_failure(client_ip)
70
- logger.warning(
71
- f"{datetime.utcnow().isoformat()} - AUTH_FAILURE - IP: {client_ip} - Reason: invalid_poster_key - Endpoint: /api/v1/posts"
72
- )
73
- raise HTTPException(status_code=401, detail="Invalid API key")
74
-
75
- # 3) Defensive length check (pydantic also enforces this)
76
- if len(payload.content) > 1000:
77
- raise HTTPException(
78
- status_code=413, detail="Content exceeds 1000 character limit"
79
- )
80
-
81
- now = datetime.now(timezone.utc)
82
- msg = Message(
83
- id=str(uuid4()),
84
- poster_id=payload.poster_id,
85
- content=payload.content,
86
- timestamp=now,
87
- category=payload.category,
88
- metadata=payload.metadata or {},
89
- )
90
-
91
- async with storage_lock:
92
- messages.append(msg)
93
-
94
- # Retention: expire >48h (50-message maxlen still takes precedence)
95
- cutoff = now - timedelta(hours=48)
96
- tmp = [m for m in messages if m.timestamp >= cutoff]
97
- messages.clear()
98
- for m in tmp:
99
- messages.append(m)
100
-
101
- # Broadcast to all connected readers (JSON-safe)
102
- payload_out = jsonable_encoder(NewPostPayload(type="new_post", message=msg))
103
- stale = []
104
- for ws in connected_readers:
105
- try:
106
- await ws.send_json(payload_out)
107
- except Exception:
108
- stale.append(ws)
109
- for ws in stale:
110
- connected_readers.discard(ws)
111
-
112
- logger.info(
113
- f"{datetime.utcnow().isoformat()} - AUTH_SUCCESS - IP: {client_ip} - UserType: poster - Action: post - PosterID: {payload.poster_id}"
114
- )
115
- record_auth_success(client_ip)
116
-
117
- return JSONResponse(
118
- status_code=201,
119
- content={
120
- "message_id": msg.id,
121
- "status": "accepted",
122
- "timestamp": msg.timestamp.isoformat(),
123
- },
124
- )
125
-
126
-
127
- @app.websocket("/ws")
128
- async def websocket_endpoint(websocket: WebSocket):
129
- client_ip = websocket.client.host if websocket.client else "unknown"
130
-
131
- # 1) IP block check
132
- if is_ip_blocked(client_ip):
133
- logger.warning(
134
- f"{datetime.utcnow().isoformat()} - AUTH_FAILURE - IP: {client_ip} - Reason: blocked_ip - Endpoint: /ws"
135
- )
136
- await websocket.close(code=1008)
137
- return
138
-
139
- # 2) API key validation (query param)
140
- api_key = websocket.query_params.get("api_key")
141
- if not api_key or api_key != READER_KEY:
142
- record_auth_failure(client_ip)
143
- logger.warning(
144
- f"{datetime.utcnow().isoformat()} - AUTH_FAILURE - IP: {client_ip} - Reason: invalid_reader_key - Endpoint: /ws"
145
- )
146
- await websocket.accept()
147
- await websocket.send_json(
148
- jsonable_encoder(ErrorPayload(type="error", error="Invalid API key"))
149
- )
150
- await websocket.close(code=1008)
151
- return
152
-
153
- await websocket.accept()
154
- logger.info(
155
- f"{datetime.utcnow().isoformat()} - AUTH_SUCCESS - IP: {client_ip} - UserType: reader - Action: connect"
156
- )
157
- record_auth_success(client_ip)
158
-
159
- # Register + send initial history (JSON-safe)
160
- async with storage_lock:
161
- connected_readers.add(websocket)
162
- history_payload = HistoryPayload(type="history", messages=list(messages))
163
- await websocket.send_json(jsonable_encoder(history_payload))
164
-
165
- try:
166
- while True:
167
- # Keepalive: wait for any client message; if none, send ping
168
- try:
169
- await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
170
- except asyncio.TimeoutError:
171
- await websocket.send_json({"type": "ping"})
172
- except WebSocketDisconnect:
173
- pass
174
- finally:
175
- async with storage_lock:
176
- connected_readers.discard(websocket)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Header, HTTPException, File, UploadFile, Form
2
+ from fastapi.responses import HTMLResponse, JSONResponse
3
+ from fastapi.staticfiles import StaticFiles
4
+ from fastapi.encoders import jsonable_encoder
5
+
6
+ from uuid import uuid4
7
+ from datetime import datetime, timezone, timedelta
8
+ from typing import Optional
9
+ import os
10
+ import asyncio
11
+ import base64
12
+ import json
13
+
14
+ from .models import Message, PostRequest, HistoryPayload, NewPostPayload, ErrorPayload, FileContentPayload
15
+ from .storage import messages, connected_readers, storage_lock
16
+ from .rate_limit import (
17
+ is_ip_blocked,
18
+ record_auth_failure,
19
+ record_auth_success,
20
+ get_client_ip,
21
+ )
22
+ from .logger_config import logger
23
+
24
+ POSTER_KEY = os.getenv("POSTER_KEY", "")
25
+ READER_KEY = os.getenv("READER_KEY", "")
26
+
27
+ app = FastAPI()
28
+
29
+ # Static HTML (root)
30
+ app.mount("/static", StaticFiles(directory="app/static"), name="static")
31
+
32
+
33
+ @app.get("/", response_class=HTMLResponse)
34
+ async def root():
35
+ with open("app/static/index.html", "r", encoding="utf-8") as f:
36
+ return f.read()
37
+
38
+
39
+ @app.get("/health")
40
+ async def health():
41
+ async with storage_lock:
42
+ message_count = len(messages)
43
+ connected = len(connected_readers)
44
+ return {
45
+ "status": "healthy",
46
+ "message_count": message_count,
47
+ "connected_readers": connected,
48
+ }
49
+
50
+
51
+ @app.post("/api/v1/posts")
52
+ async def create_post(
53
+ request: Request,
54
+ payload: PostRequest,
55
+ x_api_key: Optional[str] = Header(default=None, alias="X-API-Key"),
56
+ ):
57
+ client_ip = get_client_ip(request)
58
+
59
+ # 1) IP block check
60
+ if is_ip_blocked(client_ip):
61
+ logger.warning(
62
+ f"{datetime.utcnow().isoformat()} - AUTH_FAILURE - IP: {client_ip} - Reason: blocked_ip - Endpoint: /api/v1/posts"
63
+ )
64
+ return JSONResponse(
65
+ status_code=429,
66
+ content={"error": "Too many failed authentication attempts"},
67
+ )
68
+
69
+ # 2) API key validation
70
+ if not x_api_key or x_api_key != POSTER_KEY:
71
+ record_auth_failure(client_ip)
72
+ logger.warning(
73
+ f"{datetime.utcnow().isoformat()} - AUTH_FAILURE - IP: {client_ip} - Reason: invalid_poster_key - Endpoint: /api/v1/posts"
74
+ )
75
+ raise HTTPException(status_code=401, detail="Invalid API key")
76
+
77
+ # 3) Defensive length check (pydantic also enforces this)
78
+ if len(payload.content) > 1000:
79
+ raise HTTPException(
80
+ status_code=413, detail="Content exceeds 1000 character limit"
81
+ )
82
+
83
+ now = datetime.now(timezone.utc)
84
+ msg = Message(
85
+ id=str(uuid4()),
86
+ poster_id=payload.poster_id,
87
+ content=payload.content,
88
+ timestamp=now,
89
+ category=payload.category,
90
+ metadata=payload.metadata or {},
91
+ )
92
+
93
+ async with storage_lock:
94
+ messages.append(msg)
95
+
96
+ # Retention: expire >48h (50-message maxlen still takes precedence)
97
+ cutoff = now - timedelta(hours=48)
98
+ tmp = [m for m in messages if m.timestamp >= cutoff]
99
+ messages.clear()
100
+ for m in tmp:
101
+ messages.append(m)
102
+
103
+ # Broadcast to all connected readers (JSON-safe)
104
+ payload_out = jsonable_encoder(NewPostPayload(type="new_post", message=msg))
105
+ stale = []
106
+ for ws in connected_readers:
107
+ try:
108
+ await ws.send_json(payload_out)
109
+ except Exception:
110
+ stale.append(ws)
111
+ for ws in stale:
112
+ connected_readers.discard(ws)
113
+
114
+ logger.info(
115
+ f"{datetime.utcnow().isoformat()} - AUTH_SUCCESS - IP: {client_ip} - UserType: poster - Action: post - PosterID: {payload.poster_id}"
116
+ )
117
+ record_auth_success(client_ip)
118
+
119
+ return JSONResponse(
120
+ status_code=201,
121
+ content={
122
+ "message_id": msg.id,
123
+ "status": "accepted",
124
+ "timestamp": msg.timestamp.isoformat(),
125
+ },
126
+ )
127
+
128
+
129
+ @app.post("/api/v1/upload")
130
+ async def upload_file(
131
+ request: Request,
132
+ poster_id: str = Form(...),
133
+ category: Optional[str] = Form(default=None),
134
+ file: UploadFile = File(...),
135
+ x_api_key: Optional[str] = Header(default=None, alias="X-API-Key"),
136
+ ):
137
+ client_ip = get_client_ip(request)
138
+
139
+ # 1) IP block check
140
+ if is_ip_blocked(client_ip):
141
+ logger.warning(
142
+ f"{datetime.utcnow().isoformat()} - AUTH_FAILURE - IP: {client_ip} - Reason: blocked_ip - Endpoint: /api/v1/upload"
143
+ )
144
+ return JSONResponse(
145
+ status_code=429,
146
+ content={"error": "Too many failed authentication attempts"},
147
+ )
148
+
149
+ # 2) API key validation
150
+ if not x_api_key or x_api_key != POSTER_KEY:
151
+ record_auth_failure(client_ip)
152
+ logger.warning(
153
+ f"{datetime.utcnow().isoformat()} - AUTH_FAILURE - IP: {client_ip} - Reason: invalid_poster_key - Endpoint: /api/v1/upload"
154
+ )
155
+ raise HTTPException(status_code=401, detail="Invalid API key")
156
+
157
+ # 3) Validate file extension and MIME type
158
+ filename = file.filename or "unknown"
159
+ file_ext = os.path.splitext(filename)[1].lower()
160
+
161
+ allowed_extensions = {".png", ".md", ".txt"}
162
+ if file_ext not in allowed_extensions:
163
+ raise HTTPException(status_code=400, detail="Only PNG, MD, and TXT files are allowed")
164
+
165
+ # MIME type validation
166
+ mime_type = file.content_type or ""
167
+ valid_mime_types = {
168
+ ".png": ["image/png"],
169
+ ".md": ["text/markdown", "text/plain", "application/octet-stream"],
170
+ ".txt": ["text/plain", "application/octet-stream"],
171
+ }
172
+
173
+ if mime_type and mime_type not in valid_mime_types.get(file_ext, []):
174
+ # Allow if no content-type provided, otherwise validate
175
+ pass # Be lenient with MIME types as they can vary
176
+
177
+ # 4) Read file content and validate size
178
+ content = await file.read()
179
+ file_size = len(content)
180
+
181
+ if file_ext == ".png":
182
+ if file_size > 2 * 1024 * 1024: # 2MB
183
+ raise HTTPException(status_code=413, detail="PNG file exceeds 2MB limit")
184
+ else: # .md or .txt
185
+ if file_size > 100 * 1024: # 100KB
186
+ raise HTTPException(status_code=413, detail="Text file exceeds 100KB limit")
187
+
188
+ # 5) Process file based on type
189
+ now = datetime.now(timezone.utc)
190
+ msg_id = str(uuid4())
191
+
192
+ if file_ext == ".png":
193
+ # Store PNG as base64 data URL
194
+ base64_data = base64.b64encode(content).decode("utf-8")
195
+ file_url = f"data:image/png;base64,{base64_data}"
196
+ msg = Message(
197
+ id=msg_id,
198
+ poster_id=poster_id,
199
+ content=filename,
200
+ timestamp=now,
201
+ category=category,
202
+ metadata={},
203
+ message_type="png",
204
+ file_url=file_url,
205
+ title=filename,
206
+ )
207
+ elif file_ext == ".md":
208
+ # Extract first line as title, store full content
209
+ text_content = content.decode("utf-8", errors="replace")
210
+ first_line = text_content.split("\n")[0].strip() or filename
211
+ msg = Message(
212
+ id=msg_id,
213
+ poster_id=poster_id,
214
+ content=filename,
215
+ timestamp=now,
216
+ category=category,
217
+ metadata={},
218
+ message_type="md",
219
+ title=first_line,
220
+ file_content=text_content,
221
+ )
222
+ else: # .txt
223
+ # Extract first line as title, store full content
224
+ text_content = content.decode("utf-8", errors="replace")
225
+ first_line = text_content.split("\n")[0].strip() or filename
226
+ msg = Message(
227
+ id=msg_id,
228
+ poster_id=poster_id,
229
+ content=filename,
230
+ timestamp=now,
231
+ category=category,
232
+ metadata={},
233
+ message_type="txt",
234
+ title=first_line,
235
+ file_content=text_content,
236
+ )
237
+
238
+ # 6) Store message and broadcast (same as create_post)
239
+ async with storage_lock:
240
+ messages.append(msg)
241
+
242
+ # Retention: expire >48h (50-message maxlen still takes precedence)
243
+ cutoff = now - timedelta(hours=48)
244
+ tmp = [m for m in messages if m.timestamp >= cutoff]
245
+ messages.clear()
246
+ for m in tmp:
247
+ messages.append(m)
248
+
249
+ # Broadcast to all connected readers (JSON-safe)
250
+ payload_out = jsonable_encoder(NewPostPayload(type="new_post", message=msg))
251
+ stale = []
252
+ for ws in connected_readers:
253
+ try:
254
+ await ws.send_json(payload_out)
255
+ except Exception:
256
+ stale.append(ws)
257
+ for ws in stale:
258
+ connected_readers.discard(ws)
259
+
260
+ logger.info(
261
+ f"{datetime.utcnow().isoformat()} - AUTH_SUCCESS - IP: {client_ip} - UserType: poster - Action: upload - PosterID: {poster_id}"
262
+ )
263
+ record_auth_success(client_ip)
264
+
265
+ return JSONResponse(
266
+ status_code=201,
267
+ content={
268
+ "message_id": msg.id,
269
+ "status": "accepted",
270
+ "timestamp": msg.timestamp.isoformat(),
271
+ },
272
+ )
273
+
274
+
275
+ @app.websocket("/ws")
276
+ async def websocket_endpoint(websocket: WebSocket):
277
+ client_ip = websocket.client.host if websocket.client else "unknown"
278
+
279
+ # 1) IP block check
280
+ if is_ip_blocked(client_ip):
281
+ logger.warning(
282
+ f"{datetime.utcnow().isoformat()} - AUTH_FAILURE - IP: {client_ip} - Reason: blocked_ip - Endpoint: /ws"
283
+ )
284
+ await websocket.close(code=1008)
285
+ return
286
+
287
+ # 2) API key validation (query param)
288
+ api_key = websocket.query_params.get("api_key")
289
+ if not api_key or api_key != READER_KEY:
290
+ record_auth_failure(client_ip)
291
+ logger.warning(
292
+ f"{datetime.utcnow().isoformat()} - AUTH_FAILURE - IP: {client_ip} - Reason: invalid_reader_key - Endpoint: /ws"
293
+ )
294
+ await websocket.accept()
295
+ await websocket.send_json(
296
+ jsonable_encoder(ErrorPayload(type="error", error="Invalid API key"))
297
+ )
298
+ await websocket.close(code=1008)
299
+ return
300
+
301
+ await websocket.accept()
302
+ logger.info(
303
+ f"{datetime.utcnow().isoformat()} - AUTH_SUCCESS - IP: {client_ip} - UserType: reader - Action: connect"
304
+ )
305
+ record_auth_success(client_ip)
306
+
307
+ # Register + send initial history (JSON-safe)
308
+ async with storage_lock:
309
+ connected_readers.add(websocket)
310
+ history_payload = HistoryPayload(type="history", messages=list(messages))
311
+ await websocket.send_json(jsonable_encoder(history_payload))
312
+
313
+ try:
314
+ while True:
315
+ # Keepalive: wait for client message; if none, send ping
316
+ try:
317
+ data = await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
318
+ # Handle file content request
319
+ try:
320
+ msg = json.loads(data)
321
+ if msg.get("type") == "get_file_content":
322
+ message_id = msg.get("message_id")
323
+ async with storage_lock:
324
+ target_msg = None
325
+ for m in messages:
326
+ if m.id == message_id:
327
+ target_msg = m
328
+ break
329
+
330
+ if target_msg is None:
331
+ await websocket.send_json(
332
+ jsonable_encoder(ErrorPayload(
333
+ type="error",
334
+ error="Message not found"
335
+ ))
336
+ )
337
+ elif target_msg.message_type not in ("md", "txt"):
338
+ await websocket.send_json(
339
+ jsonable_encoder(ErrorPayload(
340
+ type="error",
341
+ error="Invalid message type"
342
+ ))
343
+ )
344
+ else:
345
+ await websocket.send_json(
346
+ jsonable_encoder(FileContentPayload(
347
+ type="file_content",
348
+ message_id=message_id,
349
+ content=target_msg.file_content or "",
350
+ content_type=target_msg.message_type
351
+ ))
352
+ )
353
+ except json.JSONDecodeError:
354
+ # Ignore non-JSON messages (keepalive behavior)
355
+ pass
356
+ except asyncio.TimeoutError:
357
+ await websocket.send_json({"type": "ping"})
358
+ except WebSocketDisconnect:
359
+ pass
360
+ finally:
361
+ async with storage_lock:
362
+ connected_readers.discard(websocket)
app/models.py CHANGED
@@ -1,34 +1,46 @@
1
- from pydantic import BaseModel, Field
2
- from datetime import datetime
3
- from typing import Optional, Dict, Any, List
4
-
5
-
6
- class PostRequest(BaseModel):
7
- poster_id: str = Field(..., min_length=1)
8
- content: str = Field(..., min_length=1, max_length=1000)
9
- category: Optional[str] = None
10
- metadata: Optional[Dict[str, Any]] = None
11
-
12
-
13
- class Message(BaseModel):
14
- id: str
15
- poster_id: str
16
- content: str
17
- timestamp: datetime
18
- category: Optional[str] = None
19
- metadata: Dict[str, Any] = {}
20
-
21
-
22
- class HistoryPayload(BaseModel):
23
- type: str = "history"
24
- messages: List[Message]
25
-
26
-
27
- class NewPostPayload(BaseModel):
28
- type: str = "new_post"
29
- message: Message
30
-
31
-
32
- class ErrorPayload(BaseModel):
33
- type: str = "error"
34
- error: str
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from datetime import datetime
3
+ from typing import Optional, Dict, Any, List, Literal
4
+
5
+
6
+ class PostRequest(BaseModel):
7
+ poster_id: str = Field(..., min_length=1)
8
+ content: str = Field(..., min_length=1, max_length=1000)
9
+ category: Optional[str] = None
10
+ metadata: Optional[Dict[str, Any]] = None
11
+
12
+
13
+ class Message(BaseModel):
14
+ id: str
15
+ poster_id: str
16
+ content: str
17
+ timestamp: datetime
18
+ category: Optional[str] = None
19
+ metadata: Dict[str, Any] = {}
20
+ # File-related fields
21
+ message_type: Literal["text", "png", "md", "txt"] = "text"
22
+ file_url: Optional[str] = None
23
+ title: Optional[str] = None
24
+ file_content: Optional[str] = None # For MD/TXT - stored but not shown in list
25
+
26
+
27
+ class HistoryPayload(BaseModel):
28
+ type: str = "history"
29
+ messages: List[Message]
30
+
31
+
32
+ class NewPostPayload(BaseModel):
33
+ type: str = "new_post"
34
+ message: Message
35
+
36
+
37
+ class ErrorPayload(BaseModel):
38
+ type: str = "error"
39
+ error: str
40
+
41
+
42
+ class FileContentPayload(BaseModel):
43
+ type: str = "file_content"
44
+ message_id: str
45
+ content: str
46
+ content_type: Literal["md", "txt"]