triflix commited on
Commit
e272177
·
verified ·
1 Parent(s): 20064bf

Create app/utils.py

Browse files
Files changed (1) hide show
  1. app/utils.py +125 -0
app/utils.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/utils.py
2
+ import os
3
+ import json
4
+ import time
5
+ import asyncio
6
+ from pathlib import Path
7
+ from typing import Dict, Set
8
+
9
+ from fastapi import WebSocket
10
+ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
11
+ from cryptography.hazmat.backends import default_backend
12
+
13
+ # --- Constants ---
14
+ UPLOAD_ROOT = Path("/tmp/triflix_uploader")
15
+ # In a real app, load this securely from env/secrets management
16
+ try:
17
+ MASTER_KEY = bytes.fromhex(os.environ["TRIFLIX_MASTER_KEY"])
18
+ if len(MASTER_KEY) != 32:
19
+ raise ValueError("TRIFLIX_MASTER_KEY must be a 32-byte (64-char) hex string.")
20
+ except (KeyError, ValueError) as e:
21
+ print(f"WARNING: Invalid or missing TRIFLIX_MASTER_KEY. Using a temporary key for demo purposes. {e}")
22
+ MASTER_KEY = os.urandom(32)
23
+
24
+ # --- In-memory State (locks and WebSocket connections) ---
25
+ # These are reset if the server restarts. The on-disk state persists.
26
+ session_locks: Dict[str, asyncio.Lock] = {}
27
+ ws_connections: Dict[str, Set[WebSocket]] = {}
28
+
29
+ # --- Session Management ---
30
+ def get_session_dir(session_id: str) -> Path:
31
+ """Returns the directory for a given session."""
32
+ return UPLOAD_ROOT / session_id
33
+
34
+ def load_meta(session_id: str) -> dict:
35
+ """Loads metadata for a session from its meta.json file."""
36
+ meta_path = get_session_dir(session_id) / "meta.json"
37
+ if not meta_path.exists():
38
+ raise FileNotFoundError("Session metadata not found.")
39
+ return json.loads(meta_path.read_text())
40
+
41
+ def save_meta(session_id: str, meta: dict):
42
+ """Saves metadata for a session to its meta.json file."""
43
+ meta_path = get_session_dir(session_id) / "meta.json"
44
+ meta_path.parent.mkdir(parents=True, exist_ok=True)
45
+ meta_path.write_text(json.dumps(meta, indent=2))
46
+
47
+ async def broadcast_progress(session_id: str):
48
+ """Sends progress updates to all connected WebSocket clients for a session."""
49
+ if session_id not in ws_connections:
50
+ return
51
+
52
+ try:
53
+ meta = load_meta(session_id)
54
+ payload = {
55
+ "type": "progress",
56
+ "uploaded_bytes": meta.get("uploaded_bytes", 0),
57
+ "total_bytes": meta.get("total_bytes"),
58
+ "status": meta.get("status", "uploading"),
59
+ }
60
+ # Create a list of tasks to send messages concurrently
61
+ tasks = [ws.send_json(payload) for ws in ws_connections[session_id]]
62
+ await asyncio.gather(*tasks, return_exceptions=True) # Use gather to handle potential disconnects
63
+ except FileNotFoundError:
64
+ # If meta is gone, maybe the session was cleaned up. Nothing to broadcast.
65
+ pass
66
+
67
+ # --- Rate Limiting ---
68
+ rate_limit_store: Dict[str, list] = {}
69
+
70
+ def enforce_rate_limit(ip: str, limit: int = 10, per_seconds: int = 60) -> bool:
71
+ """Simple in-memory rate limiter. Returns True if allowed, False if denied."""
72
+ now = time.time()
73
+ timestamps = rate_limit_store.setdefault(ip, [])
74
+ # Remove timestamps older than the window
75
+ timestamps[:] = [t for t in timestamps if t > now - per_seconds]
76
+ if len(timestamps) >= limit:
77
+ return False
78
+ timestamps.append(now)
79
+ return True
80
+
81
+ # --- Streaming Encryption/Decryption ---
82
+ CHUNK_SIZE = 64 * 1024 # 64KB
83
+
84
+ def encrypt_file(source_path: Path, dest_path: Path):
85
+ """Encrypts a file using AES-GCM streaming."""
86
+ nonce = os.urandom(12)
87
+ cipher = Cipher(algorithms.AES(MASTER_KEY), modes.GCM(nonce), backend=default_backend()).encryptor()
88
+
89
+ with open(source_path, "rb") as fin, open(dest_path, "wb") as fout:
90
+ fout.write(nonce)
91
+ while True:
92
+ chunk = fin.read(CHUNK_SIZE)
93
+ if not chunk:
94
+ break
95
+ encrypted_chunk = cipher.update(chunk)
96
+ fout.write(encrypted_chunk)
97
+
98
+ fout.write(cipher.finalize())
99
+ fout.write(cipher.tag)
100
+
101
+ async def decrypt_stream_generator(encrypted_path: Path):
102
+ """A generator that yields decrypted chunks of a file."""
103
+ with open(encrypted_path, "rb") as f:
104
+ nonce = f.read(12)
105
+ # Read the rest of the file to get ciphertext and tag
106
+ f.seek(0, os.SEEK_END)
107
+ end_pos = f.tell()
108
+ f.seek(12) # Go back to after the nonce
109
+
110
+ ciphertext_with_tag = f.read()
111
+ ciphertext = ciphertext_with_tag[:-16]
112
+ tag = ciphertext_with_tag[-16:]
113
+
114
+ decryptor = Cipher(algorithms.AES(MASTER_KEY), modes.GCM(nonce, tag), backend=default_backend()).decryptor()
115
+
116
+ # We can process the ciphertext in chunks now
117
+ offset = 0
118
+ while offset < len(ciphertext):
119
+ chunk = ciphertext[offset:offset + CHUNK_SIZE]
120
+ decrypted_chunk = decryptor.update(chunk)
121
+ yield decrypted_chunk
122
+ offset += CHUNK_SIZE
123
+
124
+ # Finalize to check tag and get any remaining data
125
+ yield decryptor.finalize()