PutInPutout committed on
Commit
8b523eb
·
verified ·
1 Parent(s): bda22f0

Upload 16 files

Browse files
mediaflow_proxy/__init__.py ADDED
File without changes
mediaflow_proxy/configs.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    """Application configuration loaded from environment variables / a .env file."""

    api_password: str  # The password for accessing the API endpoints.
    proxy_url: str | None = None  # The URL of the proxy server to route requests through.
    mpd_live_stream_delay: int = 30  # The delay in seconds for live MPD streams.

    class Config:
        # Pydantic settings configuration: read a local .env file when present
        # and silently ignore environment variables that are not declared above.
        env_file = ".env"
        extra = "ignore"


# Shared singleton settings instance imported by the rest of the package.
settings = Settings()
mediaflow_proxy/const.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Upstream response headers that the proxy forwards back to the client;
# anything not listed here is dropped from the proxied response.
SUPPORTED_RESPONSE_HEADERS = (
    "accept-ranges content-type content-length content-range connection "
    "transfer-encoding last-modified etag cache-control expires"
).split()

# Client request headers that the proxy forwards to the upstream server.
SUPPORTED_REQUEST_HEADERS = (
    "accept accept-encoding accept-language connection range if-range "
    "user-agent referer origin"
).split()
mediaflow_proxy/drm/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+
4
+
5
+ async def create_temp_file(suffix: str, content: bytes = None, prefix: str = None) -> tempfile.NamedTemporaryFile:
6
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix, prefix=prefix)
7
+ temp_file.delete_file = lambda: os.unlink(temp_file.name)
8
+ if content:
9
+ temp_file.write(content)
10
+ temp_file.close()
11
+ return temp_file
mediaflow_proxy/drm/decrypter.py ADDED
@@ -0,0 +1,778 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import struct
3
+ import sys
4
+
5
+ from Crypto.Cipher import AES
6
+ from collections import namedtuple
7
+ import array
8
+
9
# Per-sample CENC encryption metadata parsed from the 'senc' box:
# is_encrypted (bool), iv (bytes), and sub_samples — a list of
# (clear_bytes, encrypted_bytes) pairs describing sub-sample encryption.
CENCSampleAuxiliaryDataFormat = namedtuple("CENCSampleAuxiliaryDataFormat", ["is_encrypted", "iv", "sub_samples"])
10
+
11
+
12
+ class MP4Atom:
13
+ """
14
+ Represents an MP4 atom, which is a basic unit of data in an MP4 file.
15
+ Each atom contains a header (size and type) and data.
16
+ """
17
+
18
+ __slots__ = ("atom_type", "size", "data")
19
+
20
+ def __init__(self, atom_type: bytes, size: int, data: memoryview | bytearray):
21
+ """
22
+ Initializes an MP4Atom instance.
23
+
24
+ Args:
25
+ atom_type (bytes): The type of the atom.
26
+ size (int): The size of the atom.
27
+ data (memoryview | bytearray): The data contained in the atom.
28
+ """
29
+ self.atom_type = atom_type
30
+ self.size = size
31
+ self.data = data
32
+
33
+ def __repr__(self):
34
+ return f"<MP4Atom type={self.atom_type}, size={self.size}>"
35
+
36
+ def pack(self):
37
+ """
38
+ Packs the atom into binary data.
39
+
40
+ Returns:
41
+ bytes: Packed binary data with size, type, and data.
42
+ """
43
+ return struct.pack(">I", self.size) + self.atom_type + self.data
44
+
45
+
46
class MP4Parser:
    """
    Parses MP4 data to extract atoms and their structure.
    """

    def __init__(self, data: memoryview):
        """
        Initializes an MP4Parser instance.

        Args:
            data (memoryview): The binary data of the MP4 file.
        """
        self.data = data
        self.position = 0  # Current read offset into `data`.

    def read_atom(self) -> MP4Atom | None:
        """
        Reads the next atom from the data, advancing the parser position.

        Returns:
            MP4Atom | None: MP4Atom object or None if no complete atom remains.
        """
        pos = self.position
        if pos + 8 > len(self.data):
            return None

        size, atom_type = struct.unpack_from(">I4s", self.data, pos)
        pos += 8
        header_size = 8

        if size == 1:
            # 64-bit "largesize" box (ISO/IEC 14496-12): the real size follows
            # the type field, so the full header is 16 bytes, not 8.
            if pos + 8 > len(self.data):
                return None
            size = struct.unpack_from(">Q", self.data, pos)[0]
            pos += 8
            header_size = 16

        # `size` counts the header, so the payload is size - header_size bytes.
        if size < header_size or pos + size - header_size > len(self.data):
            return None

        atom_data = self.data[pos : pos + size - header_size]
        self.position = pos + size - header_size
        return MP4Atom(atom_type, size, atom_data)

    def list_atoms(self) -> list[MP4Atom]:
        """
        Lists all atoms in the data without disturbing the current position.

        Returns:
            list[MP4Atom]: List of MP4Atom objects.
        """
        atoms = []
        original_position = self.position
        self.position = 0
        while self.position + 8 <= len(self.data):
            atom = self.read_atom()
            if not atom:
                break
            atoms.append(atom)
        self.position = original_position  # Restore caller's position.
        return atoms

    def _read_atom_at(self, pos: int, end: int) -> MP4Atom | None:
        """Read a single atom at an absolute offset without moving `self.position`."""
        if pos + 8 > end:
            return None

        size, atom_type = struct.unpack_from(">I4s", self.data, pos)
        pos += 8
        header_size = 8

        if size == 1:
            # 64-bit "largesize" header is 16 bytes total; see read_atom.
            if pos + 8 > end:
                return None
            size = struct.unpack_from(">Q", self.data, pos)[0]
            pos += 8
            header_size = 16

        if size < header_size or pos + size - header_size > end:
            return None

        atom_data = self.data[pos : pos + size - header_size]
        return MP4Atom(atom_type, size, atom_data)

    def print_atoms_structure(self, indent: int = 0):
        """
        Prints the structure of all atoms in the data (debug helper).

        Args:
            indent (int): The indentation level for printing.
        """
        pos = 0
        end = len(self.data)
        while pos + 8 <= end:
            atom = self._read_atom_at(pos, end)
            if not atom:
                break
            self.print_single_atom_structure(atom, pos, indent)
            pos += atom.size

    def print_single_atom_structure(self, atom: MP4Atom, parent_position: int, indent: int):
        """
        Prints the structure of a single atom and recurses into children.

        NOTE(review): child offsets assume an 8-byte parent header and reuse
        `parent_position` for grandchildren — adequate for a debug printer but
        not offset-exact for deeply nested or largesize boxes; verify before
        relying on the printed tree.

        Args:
            atom (MP4Atom): The atom to print.
            parent_position (int): The position of the parent atom.
            indent (int): The indentation level for printing.
        """
        try:
            atom_type = atom.atom_type.decode("utf-8")
        except UnicodeDecodeError:
            atom_type = repr(atom.atom_type)
        print(" " * indent + f"Type: {atom_type}, Size: {atom.size}")

        child_pos = 0
        child_end = len(atom.data)
        while child_pos + 8 <= child_end:
            child_atom = self._read_atom_at(parent_position + 8 + child_pos, parent_position + 8 + child_end)
            if not child_atom:
                break
            self.print_single_atom_structure(child_atom, parent_position, indent + 2)
            child_pos += child_atom.size
164
+
165
+
166
+ class MP4Decrypter:
167
+ """
168
+ Class to handle the decryption of CENC encrypted MP4 segments.
169
+
170
+ Attributes:
171
+ key_map (dict[bytes, bytes]): Mapping of track IDs to decryption keys.
172
+ current_key (bytes | None): Current decryption key.
173
+ trun_sample_sizes (array.array): Array of sample sizes from the 'trun' box.
174
+ current_sample_info (list): List of sample information from the 'senc' box.
175
+ encryption_overhead (int): Total size of encryption-related boxes.
176
+ """
177
+
178
+ def __init__(self, key_map: dict[bytes, bytes]):
179
+ """
180
+ Initializes the MP4Decrypter with a key map.
181
+
182
+ Args:
183
+ key_map (dict[bytes, bytes]): Mapping of track IDs to decryption keys.
184
+ """
185
+ self.key_map = key_map
186
+ self.current_key = None
187
+ self.trun_sample_sizes = array.array("I")
188
+ self.current_sample_info = []
189
+ self.encryption_overhead = 0
190
+
191
+ def decrypt_segment(self, combined_segment: bytes) -> bytes:
192
+ """
193
+ Decrypts a combined MP4 segment.
194
+
195
+ Args:
196
+ combined_segment (bytes): Combined initialization and media segment.
197
+
198
+ Returns:
199
+ bytes: Decrypted segment content.
200
+ """
201
+ data = memoryview(combined_segment)
202
+ parser = MP4Parser(data)
203
+ atoms = parser.list_atoms()
204
+
205
+ atom_process_order = [b"moov", b"moof", b"sidx", b"mdat"]
206
+
207
+ processed_atoms = {}
208
+ for atom_type in atom_process_order:
209
+ if atom := next((a for a in atoms if a.atom_type == atom_type), None):
210
+ processed_atoms[atom_type] = self._process_atom(atom_type, atom)
211
+
212
+ result = bytearray()
213
+ for atom in atoms:
214
+ if atom.atom_type in processed_atoms:
215
+ processed_atom = processed_atoms[atom.atom_type]
216
+ result.extend(processed_atom.pack())
217
+ else:
218
+ result.extend(atom.pack())
219
+
220
+ return bytes(result)
221
+
222
+ def _process_atom(self, atom_type: bytes, atom: MP4Atom) -> MP4Atom:
223
+ """
224
+ Processes an MP4 atom based on its type.
225
+
226
+ Args:
227
+ atom_type (bytes): Type of the atom.
228
+ atom (MP4Atom): The atom to process.
229
+
230
+ Returns:
231
+ MP4Atom: Processed atom.
232
+ """
233
+ match atom_type:
234
+ case b"moov":
235
+ return self._process_moov(atom)
236
+ case b"moof":
237
+ return self._process_moof(atom)
238
+ case b"sidx":
239
+ return self._process_sidx(atom)
240
+ case b"mdat":
241
+ return self._decrypt_mdat(atom)
242
+ case _:
243
+ return atom
244
+
245
+ def _process_moov(self, moov: MP4Atom) -> MP4Atom:
246
+ """
247
+ Processes the 'moov' (Movie) atom, which contains metadata about the entire presentation.
248
+ This includes information about tracks, media data, and other movie-level metadata.
249
+
250
+ Args:
251
+ moov (MP4Atom): The 'moov' atom to process.
252
+
253
+ Returns:
254
+ MP4Atom: Processed 'moov' atom with updated track information.
255
+ """
256
+ parser = MP4Parser(moov.data)
257
+ new_moov_data = bytearray()
258
+
259
+ for atom in iter(parser.read_atom, None):
260
+ if atom.atom_type == b"trak":
261
+ new_trak = self._process_trak(atom)
262
+ new_moov_data.extend(new_trak.pack())
263
+ elif atom.atom_type != b"pssh":
264
+ # Skip PSSH boxes as they are not needed in the decrypted output
265
+ new_moov_data.extend(atom.pack())
266
+
267
+ return MP4Atom(b"moov", len(new_moov_data) + 8, new_moov_data)
268
+
269
+ def _process_moof(self, moof: MP4Atom) -> MP4Atom:
270
+ """
271
+ Processes the 'moov' (Movie) atom, which contains metadata about the entire presentation.
272
+ This includes information about tracks, media data, and other movie-level metadata.
273
+
274
+ Args:
275
+ moov (MP4Atom): The 'moov' atom to process.
276
+
277
+ Returns:
278
+ MP4Atom: Processed 'moov' atom with updated track information.
279
+ """
280
+ parser = MP4Parser(moof.data)
281
+ new_moof_data = bytearray()
282
+
283
+ for atom in iter(parser.read_atom, None):
284
+ if atom.atom_type == b"traf":
285
+ new_traf = self._process_traf(atom)
286
+ new_moof_data.extend(new_traf.pack())
287
+ else:
288
+ new_moof_data.extend(atom.pack())
289
+
290
+ return MP4Atom(b"moof", len(new_moof_data) + 8, new_moof_data)
291
+
292
+ def _process_traf(self, traf: MP4Atom) -> MP4Atom:
293
+ """
294
+ Processes the 'traf' (Track Fragment) atom, which contains information about a track fragment.
295
+ This includes sample information, sample encryption data, and other track-level metadata.
296
+
297
+ Args:
298
+ traf (MP4Atom): The 'traf' atom to process.
299
+
300
+ Returns:
301
+ MP4Atom: Processed 'traf' atom with updated sample information.
302
+ """
303
+ parser = MP4Parser(traf.data)
304
+ new_traf_data = bytearray()
305
+ tfhd = None
306
+ sample_count = 0
307
+ sample_info = []
308
+
309
+ atoms = parser.list_atoms()
310
+
311
+ # calculate encryption_overhead earlier to avoid dependency on trun
312
+ self.encryption_overhead = sum(a.size for a in atoms if a.atom_type in {b"senc", b"saiz", b"saio"})
313
+
314
+ for atom in atoms:
315
+ if atom.atom_type == b"tfhd":
316
+ tfhd = atom
317
+ new_traf_data.extend(atom.pack())
318
+ elif atom.atom_type == b"trun":
319
+ sample_count = self._process_trun(atom)
320
+ new_trun = self._modify_trun(atom)
321
+ new_traf_data.extend(new_trun.pack())
322
+ elif atom.atom_type == b"senc":
323
+ # Parse senc but don't include it in the new decrypted traf data and similarly don't include saiz and saio
324
+ sample_info = self._parse_senc(atom, sample_count)
325
+ elif atom.atom_type not in {b"saiz", b"saio"}:
326
+ new_traf_data.extend(atom.pack())
327
+
328
+ if tfhd:
329
+ tfhd_track_id = struct.unpack_from(">I", tfhd.data, 4)[0]
330
+ self.current_key = self._get_key_for_track(tfhd_track_id)
331
+ self.current_sample_info = sample_info
332
+
333
+ return MP4Atom(b"traf", len(new_traf_data) + 8, new_traf_data)
334
+
335
+ def _decrypt_mdat(self, mdat: MP4Atom) -> MP4Atom:
336
+ """
337
+ Decrypts the 'mdat' (Media Data) atom, which contains the actual media data (audio, video, etc.).
338
+ The decryption is performed using the current decryption key and sample information.
339
+
340
+ Args:
341
+ mdat (MP4Atom): The 'mdat' atom to decrypt.
342
+
343
+ Returns:
344
+ MP4Atom: Decrypted 'mdat' atom with decrypted media data.
345
+ """
346
+ if not self.current_key or not self.current_sample_info:
347
+ return mdat # Return original mdat if we don't have decryption info
348
+
349
+ decrypted_samples = bytearray()
350
+ mdat_data = mdat.data
351
+ position = 0
352
+
353
+ for i, info in enumerate(self.current_sample_info):
354
+ if position >= len(mdat_data):
355
+ break # No more data to process
356
+
357
+ sample_size = self.trun_sample_sizes[i] if i < len(self.trun_sample_sizes) else len(mdat_data) - position
358
+ sample = mdat_data[position : position + sample_size]
359
+ position += sample_size
360
+ decrypted_sample = self._process_sample(sample, info, self.current_key)
361
+ decrypted_samples.extend(decrypted_sample)
362
+
363
+ return MP4Atom(b"mdat", len(decrypted_samples) + 8, decrypted_samples)
364
+
365
+ def _parse_senc(self, senc: MP4Atom, sample_count: int) -> list[CENCSampleAuxiliaryDataFormat]:
366
+ """
367
+ Parses the 'senc' (Sample Encryption) atom, which contains encryption information for samples.
368
+ This includes initialization vectors (IVs) and sub-sample encryption data.
369
+
370
+ Args:
371
+ senc (MP4Atom): The 'senc' atom to parse.
372
+ sample_count (int): The number of samples.
373
+
374
+ Returns:
375
+ list[CENCSampleAuxiliaryDataFormat]: List of sample auxiliary data formats with encryption information.
376
+ """
377
+ data = memoryview(senc.data)
378
+ version_flags = struct.unpack_from(">I", data, 0)[0]
379
+ version, flags = version_flags >> 24, version_flags & 0xFFFFFF
380
+ position = 4
381
+
382
+ if version == 0:
383
+ sample_count = struct.unpack_from(">I", data, position)[0]
384
+ position += 4
385
+
386
+ sample_info = []
387
+ for _ in range(sample_count):
388
+ if position + 8 > len(data):
389
+ break
390
+
391
+ iv = data[position : position + 8].tobytes()
392
+ position += 8
393
+
394
+ sub_samples = []
395
+ if flags & 0x000002 and position + 2 <= len(data): # Check if subsample information is present
396
+ subsample_count = struct.unpack_from(">H", data, position)[0]
397
+ position += 2
398
+
399
+ for _ in range(subsample_count):
400
+ if position + 6 <= len(data):
401
+ clear_bytes, encrypted_bytes = struct.unpack_from(">HI", data, position)
402
+ position += 6
403
+ sub_samples.append((clear_bytes, encrypted_bytes))
404
+ else:
405
+ break
406
+
407
+ sample_info.append(CENCSampleAuxiliaryDataFormat(True, iv, sub_samples))
408
+
409
+ return sample_info
410
+
411
+ def _get_key_for_track(self, track_id: int) -> bytes:
412
+ """
413
+ Retrieves the decryption key for a given track ID from the key map.
414
+
415
+ Args:
416
+ track_id (int): The track ID.
417
+
418
+ Returns:
419
+ bytes: The decryption key for the specified track ID.
420
+ """
421
+ if len(self.key_map) == 1:
422
+ return next(iter(self.key_map.values()))
423
+ key = self.key_map.get(track_id.pack(4, "big"))
424
+ if not key:
425
+ raise ValueError(f"No key found for track ID {track_id}")
426
+ return key
427
+
428
+ @staticmethod
429
+ def _process_sample(
430
+ sample: memoryview, sample_info: CENCSampleAuxiliaryDataFormat, key: bytes
431
+ ) -> memoryview | bytearray | bytes:
432
+ """
433
+ Processes and decrypts a sample using the provided sample information and decryption key.
434
+ This includes handling sub-sample encryption if present.
435
+
436
+ Args:
437
+ sample (memoryview): The sample data.
438
+ sample_info (CENCSampleAuxiliaryDataFormat): The sample auxiliary data format with encryption information.
439
+ key (bytes): The decryption key.
440
+
441
+ Returns:
442
+ memoryview | bytearray | bytes: The decrypted sample.
443
+ """
444
+ if not sample_info.is_encrypted:
445
+ return sample
446
+
447
+ # pad IV to 16 bytes
448
+ iv = sample_info.iv + b"\x00" * (16 - len(sample_info.iv))
449
+ cipher = AES.new(key, AES.MODE_CTR, initial_value=iv, nonce=b"")
450
+
451
+ if not sample_info.sub_samples:
452
+ # If there are no sub_samples, decrypt the entire sample
453
+ return cipher.decrypt(sample)
454
+
455
+ result = bytearray()
456
+ offset = 0
457
+ for clear_bytes, encrypted_bytes in sample_info.sub_samples:
458
+ result.extend(sample[offset : offset + clear_bytes])
459
+ offset += clear_bytes
460
+ result.extend(cipher.decrypt(sample[offset : offset + encrypted_bytes]))
461
+ offset += encrypted_bytes
462
+
463
+ # If there's any remaining data, treat it as encrypted
464
+ if offset < len(sample):
465
+ result.extend(cipher.decrypt(sample[offset:]))
466
+
467
+ return result
468
+
469
+ def _process_trun(self, trun: MP4Atom) -> int:
470
+ """
471
+ Processes the 'trun' (Track Fragment Run) atom, which contains information about the samples in a track fragment.
472
+ This includes sample sizes, durations, flags, and composition time offsets.
473
+
474
+ Args:
475
+ trun (MP4Atom): The 'trun' atom to process.
476
+
477
+ Returns:
478
+ int: The number of samples in the 'trun' atom.
479
+ """
480
+ trun_flags, sample_count = struct.unpack_from(">II", trun.data, 0)
481
+ data_offset = 8
482
+
483
+ if trun_flags & 0x000001:
484
+ data_offset += 4
485
+ if trun_flags & 0x000004:
486
+ data_offset += 4
487
+
488
+ self.trun_sample_sizes = array.array("I")
489
+
490
+ for _ in range(sample_count):
491
+ if trun_flags & 0x000100: # sample-duration-present flag
492
+ data_offset += 4
493
+ if trun_flags & 0x000200: # sample-size-present flag
494
+ sample_size = struct.unpack_from(">I", trun.data, data_offset)[0]
495
+ self.trun_sample_sizes.append(sample_size)
496
+ data_offset += 4
497
+ else:
498
+ self.trun_sample_sizes.append(0) # Using 0 instead of None for uniformity in the array
499
+ if trun_flags & 0x000400: # sample-flags-present flag
500
+ data_offset += 4
501
+ if trun_flags & 0x000800: # sample-composition-time-offsets-present flag
502
+ data_offset += 4
503
+
504
+ return sample_count
505
+
506
+ def _modify_trun(self, trun: MP4Atom) -> MP4Atom:
507
+ """
508
+ Modifies the 'trun' (Track Fragment Run) atom to update the data offset.
509
+ This is necessary to account for the encryption overhead.
510
+
511
+ Args:
512
+ trun (MP4Atom): The 'trun' atom to modify.
513
+
514
+ Returns:
515
+ MP4Atom: Modified 'trun' atom with updated data offset.
516
+ """
517
+ trun_data = bytearray(trun.data)
518
+ current_flags = struct.unpack_from(">I", trun_data, 0)[0] & 0xFFFFFF
519
+
520
+ # If the data-offset-present flag is set, update the data offset to account for encryption overhead
521
+ if current_flags & 0x000001:
522
+ current_data_offset = struct.unpack_from(">i", trun_data, 8)[0]
523
+ struct.pack_into(">i", trun_data, 8, current_data_offset - self.encryption_overhead)
524
+
525
+ return MP4Atom(b"trun", len(trun_data) + 8, trun_data)
526
+
527
+ def _process_sidx(self, sidx: MP4Atom) -> MP4Atom:
528
+ """
529
+ Processes the 'sidx' (Segment Index) atom, which contains indexing information for media segments.
530
+ This includes references to media segments and their durations.
531
+
532
+ Args:
533
+ sidx (MP4Atom): The 'sidx' atom to process.
534
+
535
+ Returns:
536
+ MP4Atom: Processed 'sidx' atom with updated segment references.
537
+ """
538
+ sidx_data = bytearray(sidx.data)
539
+
540
+ current_size = struct.unpack_from(">I", sidx_data, 32)[0]
541
+ reference_type = current_size >> 31
542
+ current_referenced_size = current_size & 0x7FFFFFFF
543
+
544
+ # Remove encryption overhead from referenced size
545
+ new_referenced_size = current_referenced_size - self.encryption_overhead
546
+ new_size = (reference_type << 31) | new_referenced_size
547
+ struct.pack_into(">I", sidx_data, 32, new_size)
548
+
549
+ return MP4Atom(b"sidx", len(sidx_data) + 8, sidx_data)
550
+
551
+ def _process_trak(self, trak: MP4Atom) -> MP4Atom:
552
+ """
553
+ Processes the 'trak' (Track) atom, which contains information about a single track in the movie.
554
+ This includes track header, media information, and other track-level metadata.
555
+
556
+ Args:
557
+ trak (MP4Atom): The 'trak' atom to process.
558
+
559
+ Returns:
560
+ MP4Atom: Processed 'trak' atom with updated track information.
561
+ """
562
+ parser = MP4Parser(trak.data)
563
+ new_trak_data = bytearray()
564
+
565
+ for atom in iter(parser.read_atom, None):
566
+ if atom.atom_type == b"mdia":
567
+ new_mdia = self._process_mdia(atom)
568
+ new_trak_data.extend(new_mdia.pack())
569
+ else:
570
+ new_trak_data.extend(atom.pack())
571
+
572
+ return MP4Atom(b"trak", len(new_trak_data) + 8, new_trak_data)
573
+
574
+ def _process_mdia(self, mdia: MP4Atom) -> MP4Atom:
575
+ """
576
+ Processes the 'mdia' (Media) atom, which contains media information for a track.
577
+ This includes media header, handler reference, and media information container.
578
+
579
+ Args:
580
+ mdia (MP4Atom): The 'mdia' atom to process.
581
+
582
+ Returns:
583
+ MP4Atom: Processed 'mdia' atom with updated media information.
584
+ """
585
+ parser = MP4Parser(mdia.data)
586
+ new_mdia_data = bytearray()
587
+
588
+ for atom in iter(parser.read_atom, None):
589
+ if atom.atom_type == b"minf":
590
+ new_minf = self._process_minf(atom)
591
+ new_mdia_data.extend(new_minf.pack())
592
+ else:
593
+ new_mdia_data.extend(atom.pack())
594
+
595
+ return MP4Atom(b"mdia", len(new_mdia_data) + 8, new_mdia_data)
596
+
597
+ def _process_minf(self, minf: MP4Atom) -> MP4Atom:
598
+ """
599
+ Processes the 'minf' (Media Information) atom, which contains information about the media data in a track.
600
+ This includes data information, sample table, and other media-level metadata.
601
+
602
+ Args:
603
+ minf (MP4Atom): The 'minf' atom to process.
604
+
605
+ Returns:
606
+ MP4Atom: Processed 'minf' atom with updated media information.
607
+ """
608
+ parser = MP4Parser(minf.data)
609
+ new_minf_data = bytearray()
610
+
611
+ for atom in iter(parser.read_atom, None):
612
+ if atom.atom_type == b"stbl":
613
+ new_stbl = self._process_stbl(atom)
614
+ new_minf_data.extend(new_stbl.pack())
615
+ else:
616
+ new_minf_data.extend(atom.pack())
617
+
618
+ return MP4Atom(b"minf", len(new_minf_data) + 8, new_minf_data)
619
+
620
+ def _process_stbl(self, stbl: MP4Atom) -> MP4Atom:
621
+ """
622
+ Processes the 'stbl' (Sample Table) atom, which contains information about the samples in a track.
623
+ This includes sample descriptions, sample sizes, sample times, and other sample-level metadata.
624
+
625
+ Args:
626
+ stbl (MP4Atom): The 'stbl' atom to process.
627
+
628
+ Returns:
629
+ MP4Atom: Processed 'stbl' atom with updated sample information.
630
+ """
631
+ parser = MP4Parser(stbl.data)
632
+ new_stbl_data = bytearray()
633
+
634
+ for atom in iter(parser.read_atom, None):
635
+ if atom.atom_type == b"stsd":
636
+ new_stsd = self._process_stsd(atom)
637
+ new_stbl_data.extend(new_stsd.pack())
638
+ else:
639
+ new_stbl_data.extend(atom.pack())
640
+
641
+ return MP4Atom(b"stbl", len(new_stbl_data) + 8, new_stbl_data)
642
+
643
+ def _process_stsd(self, stsd: MP4Atom) -> MP4Atom:
644
+ """
645
+ Processes the 'stsd' (Sample Description) atom, which contains descriptions of the sample entries in a track.
646
+ This includes codec information, sample entry details, and other sample description metadata.
647
+
648
+ Args:
649
+ stsd (MP4Atom): The 'stsd' atom to process.
650
+
651
+ Returns:
652
+ MP4Atom: Processed 'stsd' atom with updated sample descriptions.
653
+ """
654
+ parser = MP4Parser(stsd.data)
655
+ entry_count = struct.unpack_from(">I", parser.data, 4)[0]
656
+ new_stsd_data = bytearray(stsd.data[:8])
657
+
658
+ parser.position = 8 # Move past version_flags and entry_count
659
+
660
+ for _ in range(entry_count):
661
+ sample_entry = parser.read_atom()
662
+ if not sample_entry:
663
+ break
664
+
665
+ processed_entry = self._process_sample_entry(sample_entry)
666
+ new_stsd_data.extend(processed_entry.pack())
667
+
668
+ return MP4Atom(b"stsd", len(new_stsd_data) + 8, new_stsd_data)
669
+
670
+ def _process_sample_entry(self, entry: MP4Atom) -> MP4Atom:
671
+ """
672
+ Processes a sample entry atom, which contains information about a specific type of sample.
673
+ This includes codec-specific information and other sample entry details.
674
+
675
+ Args:
676
+ entry (MP4Atom): The sample entry atom to process.
677
+
678
+ Returns:
679
+ MP4Atom: Processed sample entry atom with updated information.
680
+ """
681
+ # Determine the size of fixed fields based on sample entry type
682
+ if entry.atom_type in {b"mp4a", b"enca"}:
683
+ fixed_size = 28 # 8 bytes for size, type and reserved, 20 bytes for fixed fields in Audio Sample Entry.
684
+ elif entry.atom_type in {b"mp4v", b"encv", b"avc1", b"hev1", b"hvc1"}:
685
+ fixed_size = 78 # 8 bytes for size, type and reserved, 70 bytes for fixed fields in Video Sample Entry.
686
+ else:
687
+ fixed_size = 16 # 8 bytes for size, type and reserved, 8 bytes for fixed fields in other Sample Entries.
688
+
689
+ new_entry_data = bytearray(entry.data[:fixed_size])
690
+ parser = MP4Parser(entry.data[fixed_size:])
691
+ codec_format = None
692
+
693
+ for atom in iter(parser.read_atom, None):
694
+ if atom.atom_type in {b"sinf", b"schi", b"tenc", b"schm"}:
695
+ if atom.atom_type == b"sinf":
696
+ codec_format = self._extract_codec_format(atom)
697
+ continue # Skip encryption-related atoms
698
+ new_entry_data.extend(atom.pack())
699
+
700
+ # Replace the atom type with the extracted codec format
701
+ new_type = codec_format if codec_format else entry.atom_type
702
+ return MP4Atom(new_type, len(new_entry_data) + 8, new_entry_data)
703
+
704
+ def _extract_codec_format(self, sinf: MP4Atom) -> bytes | None:
705
+ """
706
+ Extracts the codec format from the 'sinf' (Protection Scheme Information) atom.
707
+ This includes information about the original format of the protected content.
708
+
709
+ Args:
710
+ sinf (MP4Atom): The 'sinf' atom to extract from.
711
+
712
+ Returns:
713
+ bytes | None: The codec format or None if not found.
714
+ """
715
+ parser = MP4Parser(sinf.data)
716
+ for atom in iter(parser.read_atom, None):
717
+ if atom.atom_type == b"frma":
718
+ return atom.data
719
+ return None
720
+
721
+
722
def decrypt_segment(init_segment: bytes, segment_content: bytes, key_id: str, key: str) -> bytes:
    """
    Decrypts a CENC encrypted MP4 media segment.

    Args:
        init_segment (bytes): Initialization segment data.
        segment_content (bytes): Encrypted segment content.
        key_id (str): Key ID in hexadecimal format.
        key (str): Key in hexadecimal format.

    Returns:
        bytes: The decrypted segment content.
    """
    # Build a single-entry key map (KID -> key) and decrypt init + media as one blob.
    decrypter = MP4Decrypter({bytes.fromhex(key_id): bytes.fromhex(key)})
    return decrypter.decrypt_segment(init_segment + segment_content)
736
+
737
+
738
def cli(parsed_args=None):
    """
    Command line interface for decrypting a CENC encrypted MP4 segment.

    Args:
        parsed_args: Optional pre-parsed argparse namespace (with ``init``,
            ``segment``, ``combined_segment``, ``key_id``, ``key``, ``output``).
            When omitted, falls back to the module-level ``args`` created under
            ``__main__`` (kept for backward compatibility with the original
            global-based flow).
    """
    # Prefer an explicit namespace; fall back to the module-level `args` global.
    ns = parsed_args if parsed_args is not None else args
    init_segment = b""

    if ns.init and ns.segment:
        with open(ns.init, "rb") as f:
            init_segment = f.read()
        with open(ns.segment, "rb") as f:
            segment_content = f.read()
    elif ns.combined_segment:
        with open(ns.combined_segment, "rb") as f:
            segment_content = f.read()
    else:
        # Neither separate nor combined segments were supplied.
        print("Usage: python mp4decrypt.py --help")
        sys.exit(1)

    try:
        decrypted_segment = decrypt_segment(init_segment, segment_content, ns.key_id, ns.key)
        print(f"Decrypted content size is {len(decrypted_segment)} bytes")
        with open(ns.output, "wb") as f:
            f.write(decrypted_segment)
        print(f"Decrypted segment written to {ns.output}")
    except Exception as e:
        # Surface the failure to the user and exit non-zero for scripting.
        print(f"Error: {e}")
        sys.exit(1)
765
+
766
+
767
if __name__ == "__main__":
    # Build the command-line parser: either --init + --segment, or a single
    # --combined_segment, plus the key material and an output path.
    arg_parser = argparse.ArgumentParser(description="Decrypts a MP4 init and media segment using CENC encryption.")
    arg_parser.add_argument("--init", help="Path to the init segment file", required=False)
    arg_parser.add_argument("--segment", help="Path to the media segment file", required=False)
    arg_parser.add_argument(
        "--combined_segment", help="Path to the combined init and media segment file", required=False
    )
    arg_parser.add_argument("--key_id", help="Key ID in hexadecimal format", required=True)
    arg_parser.add_argument("--key", help="Key in hexadecimal format", required=True)
    arg_parser.add_argument("--output", help="Path to the output file", required=True)
    # Parsed namespace is intentionally module-level: cli() reads it as a global.
    args = arg_parser.parse_args()
    cli()
mediaflow_proxy/handlers.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import logging
3
+
4
+ import httpx
5
+ from fastapi import Request, Response, HTTPException
6
+ from pydantic import HttpUrl
7
+ from starlette.background import BackgroundTask
8
+
9
+ from .configs import settings
10
+ from .const import SUPPORTED_RESPONSE_HEADERS
11
+ from .mpd_processor import process_manifest, process_playlist, process_segment
12
+ from .utils.cache_utils import get_cached_mpd, get_cached_init_segment
13
+ from .utils.http_utils import (
14
+ Streamer,
15
+ DownloadError,
16
+ download_file_with_retry,
17
+ request_with_retry,
18
+ EnhancedStreamingResponse,
19
+ )
20
+ from .utils.m3u8_processor import M3U8Processor
21
+ from .utils.mpd_utils import pad_base64
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
async def handle_hls_stream_proxy(
    request: Request, destination: str, headers: dict, key_url: HttpUrl = None, verify_ssl: bool = True
):
    """
    Handles the HLS stream proxy request, fetching and processing the m3u8 playlist or streaming the content.

    Args:
        request (Request): The incoming HTTP request.
        destination (str): The destination URL to fetch the content from.
        headers (dict): The headers to include in the request.
        key_url (str, optional): The HLS Key URL to replace the original key URL. Defaults to None.
        verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.

    Returns:
        Response: The HTTP response with the processed m3u8 playlist or streamed content.
    """
    # A dedicated client per request: it is closed by the BackgroundTask on
    # the streaming path, inside fetch_and_process_m3u8 on the playlist path,
    # or explicitly in each error branch below.
    client = httpx.AsyncClient(
        follow_redirects=True,
        timeout=httpx.Timeout(30.0),
        limits=httpx.Limits(max_keepalive_connections=10, max_connections=20),
        proxy=settings.proxy_url,
        verify=verify_ssl,
    )
    streamer = Streamer(client)
    try:
        # Fast path: an obvious playlist extension skips the HEAD probe.
        if destination.endswith((".m3u", ".m3u8")):
            return await fetch_and_process_m3u8(streamer, destination, headers, request, key_url)

        response = await streamer.head(destination, headers)
        # Content-type sniffing catches playlists served without a playlist extension.
        if "mpegurl" in response.headers.get("content-type", "").lower():
            return await fetch_and_process_m3u8(streamer, destination, headers, request, key_url)

        # Default to a full-range request if the client did not send one.
        headers.update({"range": headers.get("range", "bytes=0-")})
        # clean up the headers to only include the necessary headers and remove acl headers
        response_headers = {k: v for k, v in response.headers.multi_items() if k in SUPPORTED_RESPONSE_HEADERS}

        # Content is re-chunked while proxying, so advertise chunked encoding
        # (appending to any upstream value rather than overwriting it).
        if transfer_encoding := response_headers.get("transfer-encoding"):
            if "chunked" not in transfer_encoding:
                transfer_encoding += ", chunked"
        else:
            transfer_encoding = "chunked"
        response_headers["transfer-encoding"] = transfer_encoding

        return EnhancedStreamingResponse(
            streamer.stream_content(destination, headers),
            status_code=response.status_code,
            headers=response_headers,
            background=BackgroundTask(streamer.close),
        )
    except httpx.HTTPStatusError as e:
        await client.aclose()
        logger.error(f"Upstream service error while handling request: {e}")
        return Response(status_code=e.response.status_code, content=f"Upstream service error: {e}")
    except DownloadError as e:
        await client.aclose()
        logger.error(f"Error downloading {destination}: {e}")
        return Response(status_code=e.status_code, content=str(e))
    except Exception as e:
        # NOTE(review): generic failures are reported as 502 (bad gateway),
        # presumably because most stem from the upstream fetch — confirm.
        await client.aclose()
        logger.error(f"Internal server error while handling request: {e}")
        return Response(status_code=502, content=f"Internal server error: {e}")
87
+
88
+
89
async def proxy_stream(method: str, video_url: str, headers: dict, verify_ssl: bool = True):
    """
    Forward a proxied stream request to the generic stream handler.

    Args:
        method (str): HTTP method (e.g. GET, HEAD).
        video_url (str): URL of the media to stream.
        headers (dict): Headers to forward upstream.
        verify_ssl (bool, optional): Verify the destination's SSL certificate. Defaults to True.

    Returns:
        Response: The streamed-content response produced by the handler.
    """
    return await handle_stream_request(method, video_url, headers, verify_ssl)
103
+
104
+
105
async def handle_stream_request(method: str, video_url: str, headers: dict, verify_ssl: bool = True):
    """
    Handles the stream request, fetching the content from the video URL and streaming it.

    Args:
        method (str): The HTTP method (e.g., GET, HEAD).
        video_url (str): The URL of the video to stream.
        headers (dict): The headers to include in the request.
        verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.

    Returns:
        Response: The HTTP response with the streamed content.
    """
    # Dedicated client per request; closed by the BackgroundTask on the
    # streaming path or explicitly on the HEAD/error paths.
    client = httpx.AsyncClient(
        follow_redirects=True,
        timeout=httpx.Timeout(30.0),
        limits=httpx.Limits(max_keepalive_connections=10, max_connections=20),
        proxy=settings.proxy_url,
        verify=verify_ssl,
    )
    streamer = Streamer(client)
    try:
        # Probe upstream first to mirror its status and headers.
        response = await streamer.head(video_url, headers)
        # clean up the headers to only include the necessary headers and remove acl headers
        response_headers = {k: v for k, v in response.headers.multi_items() if k in SUPPORTED_RESPONSE_HEADERS}
        # Content is re-chunked while proxying, so advertise chunked encoding.
        if transfer_encoding := response_headers.get("transfer-encoding"):
            if "chunked" not in transfer_encoding:
                transfer_encoding += ", chunked"
        else:
            transfer_encoding = "chunked"
        response_headers["transfer-encoding"] = transfer_encoding

        if method == "HEAD":
            # No body to stream: return upstream metadata only and release the client.
            await streamer.close()
            return Response(headers=response_headers, status_code=response.status_code)
        else:
            return EnhancedStreamingResponse(
                streamer.stream_content(video_url, headers),
                headers=response_headers,
                status_code=response.status_code,
                background=BackgroundTask(streamer.close),
            )
    except httpx.HTTPStatusError as e:
        await client.aclose()
        logger.error(f"Upstream service error while handling {method} request: {e}")
        return Response(status_code=e.response.status_code, content=f"Upstream service error: {e}")
    except DownloadError as e:
        await client.aclose()
        logger.error(f"Error downloading {video_url}: {e}")
        return Response(status_code=e.status_code, content=str(e))
    except Exception as e:
        await client.aclose()
        logger.error(f"Internal server error while handling {method} request: {e}")
        return Response(status_code=502, content=f"Internal server error: {e}")
159
+
160
+
161
async def fetch_and_process_m3u8(
    streamer: Streamer, url: str, headers: dict, request: Request, key_url: HttpUrl = None
):
    """
    Fetches and processes the m3u8 playlist, converting it to an HLS playlist.

    Args:
        streamer (Streamer): The HTTP client to use for streaming.
        url (str): The URL of the m3u8 playlist.
        headers (dict): The headers to include in the request.
        request (Request): The incoming HTTP request.
        key_url (HttpUrl, optional): The HLS Key URL to replace the original key URL. Defaults to None.

    Returns:
        Response: The HTTP response with the processed m3u8 playlist.
    """
    try:
        content = await streamer.get_text(url, headers)
        processor = M3U8Processor(request, key_url)
        # Resolve relative segment URLs against the final (post-redirect) URL.
        processed_content = await processor.process_m3u8(content, str(streamer.response.url))
        return Response(
            content=processed_content,
            media_type="application/vnd.apple.mpegurl",
            headers={
                # Playlists are small and fully materialized; disable range requests.
                "Content-Disposition": "inline",
                "Accept-Ranges": "none",
            },
        )
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error while fetching m3u8: {e}")
        return Response(status_code=e.response.status_code, content=str(e))
    except DownloadError as e:
        logger.error(f"Error downloading m3u8: {url}")
        return Response(status_code=502, content=str(e))
    except Exception as e:
        logger.exception(f"Unexpected error while processing m3u8: {e}")
        return Response(status_code=502, content=str(e))
    finally:
        # Unlike the streaming paths, the playlist response is complete here,
        # so the streamer (and its client) is always closed before returning.
        await streamer.close()
200
+
201
+
202
async def handle_drm_key_data(key_id, key, drm_info):
    """
    Handles the DRM key data, retrieving the key ID and key from the DRM info if not provided.

    Args:
        key_id (str): The DRM key ID (may be None/empty).
        key (str): The DRM key (may be None/empty).
        drm_info (dict | None): The DRM information from the MPD manifest.

    Returns:
        tuple: The key ID and key, or (None, None) for non-DRM-protected content.

    Raises:
        HTTPException: If the key material cannot be determined from the arguments
            or the DRM info.
    """
    # Normalize: a None drm_info previously raised TypeError on the
    # membership tests below ('"keyId" in None').
    drm_info = drm_info or {}

    if drm_info and not drm_info.get("isDrmProtected"):
        # Explicitly marked as not protected: no key material needed.
        return None, None

    if not key_id or not key:
        if "keyId" in drm_info and "key" in drm_info:
            key_id = drm_info["keyId"]
            key = drm_info["key"]
        elif "laUrl" in drm_info and "keyId" in drm_info:
            raise HTTPException(status_code=400, detail="LA URL is not supported yet")
        else:
            raise HTTPException(
                status_code=400, detail="Unable to determine key_id and key, and they were not provided"
            )

    return key_id, key
229
+
230
+
231
async def get_manifest(
    request: Request, mpd_url: str, headers: dict, key_id: str = None, key: str = None, verify_ssl: bool = True
):
    """
    Retrieves and processes the MPD manifest, converting it to an HLS manifest.

    Args:
        request (Request): The incoming HTTP request.
        mpd_url (str): The URL of the MPD manifest.
        headers (dict): The headers to include in the request.
        key_id (str, optional): The DRM key ID. Defaults to None.
        key (str, optional): The DRM key. Defaults to None.
        verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.

    Returns:
        Response: The HTTP response with the HLS manifest.

    Raises:
        HTTPException: If the MPD manifest cannot be downloaded.
    """
    try:
        # DRM info only needs parsing when the caller did not supply keys.
        mpd_dict = await get_cached_mpd(
            mpd_url, headers=headers, parse_drm=not key_id and not key, verify_ssl=verify_ssl
        )
    except DownloadError as e:
        raise HTTPException(status_code=e.status_code, detail=f"Failed to download MPD: {e.message}")
    drm_info = mpd_dict.get("drmInfo", {})

    if drm_info and not drm_info.get("isDrmProtected"):
        # For non-DRM protected MPD, we still create an HLS manifest
        return await process_manifest(request, mpd_dict, None, None)

    key_id, key = await handle_drm_key_data(key_id, key, drm_info)

    # check if the provided key_id and key are valid
    # A 32-char value is assumed to already be hex; anything else is treated
    # as (url-safe) base64 and converted to hex.
    if key_id and len(key_id) != 32:
        key_id = base64.urlsafe_b64decode(pad_base64(key_id)).hex()
    if key and len(key) != 32:
        key = base64.urlsafe_b64decode(pad_base64(key)).hex()

    return await process_manifest(request, mpd_dict, key_id, key)
269
+
270
+
271
async def get_playlist(
    request: Request,
    mpd_url: str,
    profile_id: str,
    headers: dict,
    key_id: str = None,
    key: str = None,
    verify_ssl: bool = True,
):
    """
    Retrieves and processes the MPD manifest, converting it to an HLS playlist for a specific profile.

    Args:
        request (Request): The incoming HTTP request.
        mpd_url (str): The URL of the MPD manifest.
        profile_id (str): The profile ID to generate the playlist for.
        headers (dict): The headers to include in the request.
        key_id (str, optional): The DRM key ID. Defaults to None.
        key (str, optional): The DRM key. Defaults to None.
        verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.

    Returns:
        Response: The HTTP response with the HLS playlist.

    Raises:
        HTTPException: If the MPD manifest cannot be downloaded.
    """
    try:
        mpd_dict = await get_cached_mpd(
            mpd_url,
            headers=headers,
            parse_drm=not key_id and not key,
            parse_segment_profile_id=profile_id,
            verify_ssl=verify_ssl,
        )
    except DownloadError as e:
        # Consistent with get_manifest/get_segment: surface upstream download
        # failures as a client-visible HTTP error instead of a bare 500.
        raise HTTPException(status_code=e.status_code, detail=f"Failed to download MPD: {e.message}")
    return await process_playlist(request, mpd_dict, profile_id)
303
+
304
+
305
async def get_segment(
    init_url: str,
    segment_url: str,
    mimetype: str,
    headers: dict,
    key_id: str = None,
    key: str = None,
    verify_ssl: bool = True,
):
    """
    Retrieves and processes a media segment, decrypting it if necessary.

    Args:
        init_url (str): The URL of the initialization segment.
        segment_url (str): The URL of the media segment.
        mimetype (str): The MIME type of the segment.
        headers (dict): The headers to include in the request.
        key_id (str, optional): The DRM key ID. Defaults to None.
        key (str, optional): The DRM key. Defaults to None.
        verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.

    Returns:
        Response: The HTTP response with the processed segment.

    Raises:
        HTTPException: If the init or media segment cannot be downloaded.
    """
    try:
        # The init segment is cached because it is shared by every media
        # segment of a given profile.
        init_content = await get_cached_init_segment(init_url, headers, verify_ssl)
        segment_content = await download_file_with_retry(segment_url, headers, verify_ssl=verify_ssl)
    except DownloadError as e:
        raise HTTPException(status_code=e.status_code, detail=f"Failed to download segment: {e.message}")
    return await process_segment(init_content, segment_content, mimetype, key_id, key)
335
+
336
+
337
async def get_public_ip():
    """
    Look up the public IP address the MediaFlow proxy egresses from,
    using the ipify web API.

    Returns:
        dict: JSON payload of the form {"ip": "xxx.xxx.xxx.xxx"}.
    """
    response = await request_with_retry("GET", "https://api.ipify.org?format=json", {})
    return response.json()
mediaflow_proxy/main.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from importlib import resources
3
+
4
+ from fastapi import FastAPI, Depends, Security, HTTPException
5
+ from fastapi.security import APIKeyQuery, APIKeyHeader
6
+ from starlette.responses import RedirectResponse
7
+ from starlette.staticfiles import StaticFiles
8
+
9
+ from mediaflow_proxy.configs import settings
10
+ from mediaflow_proxy.routes import proxy_router
11
+
12
# Root logger configuration for the whole service.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
app = FastAPI()
# The API password may be supplied either as a query parameter or a header,
# both named "api_password"; auto_error=False so each is optional per-request
# (verify_api_key accepts whichever one matches).
api_password_query = APIKeyQuery(name="api_password", auto_error=False)
api_password_header = APIKeyHeader(name="api_password", auto_error=False)
16
+
17
+
18
async def verify_api_key(api_key: str = Security(api_password_query), api_key_alt: str = Security(api_password_header)):
    """
    Verifies the API password supplied via query parameter or header.

    Args:
        api_key (str): Password from the `api_password` query parameter (may be None).
        api_key_alt (str): Password from the `api_password` header (may be None).

    Raises:
        HTTPException: 403 if neither credential matches the configured password.
    """
    import hmac  # local import: keeps the module's import block untouched

    expected = settings.api_password
    for candidate in (api_key, api_key_alt):
        # hmac.compare_digest gives a constant-time comparison, avoiding a
        # timing side-channel on the password check.
        if candidate is not None and hmac.compare_digest(candidate, expected):
            return

    raise HTTPException(status_code=403, detail="Could not validate credentials")
33
+
34
+
35
+ @app.get("/health")
36
+ async def health_check():
37
+ return {"status": "healthy"}
38
+
39
+
40
+ @app.get("/favicon.ico")
41
+ async def get_favicon():
42
+ return RedirectResponse(url="/logo.png")
43
+
44
+
45
# All /proxy endpoints require a valid API password.
app.include_router(proxy_router, prefix="/proxy", tags=["proxy"], dependencies=[Depends(verify_api_key)])

# Serve the bundled static landing page (index.html, logo) at the root path.
static_path = resources.files("mediaflow_proxy").joinpath("static")
app.mount("/", StaticFiles(directory=str(static_path), html=True), name="static")
49
+
50
+
51
def run():
    # Imported lazily so that merely importing this module (e.g. under an
    # external ASGI server) does not require uvicorn at import time.
    import uvicorn

    # NOTE(review): binds to localhost only — presumably deployments front
    # this with a reverse proxy or container port mapping; confirm before
    # relying on external access.
    uvicorn.run(app, host="127.0.0.1", port=8888)


if __name__ == "__main__":
    run()
mediaflow_proxy/mpd_processor.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ import time
4
+ from datetime import datetime, timezone, timedelta
5
+
6
+ from fastapi import Request, Response, HTTPException
7
+
8
+ from mediaflow_proxy.configs import settings
9
+ from mediaflow_proxy.drm.decrypter import decrypt_segment
10
+ from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url, get_original_scheme
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
async def process_manifest(request: Request, mpd_dict: dict, key_id: str = None, key: str = None) -> Response:
    """
    Turn a parsed MPD manifest into an HLS master-manifest HTTP response.

    Args:
        request (Request): The incoming HTTP request.
        mpd_dict (dict): The parsed MPD manifest data.
        key_id (str, optional): DRM key ID to embed in playlist URLs. Defaults to None.
        key (str, optional): DRM key to embed in playlist URLs. Defaults to None.

    Returns:
        Response: The HLS master manifest.
    """
    manifest = build_hls(mpd_dict, request, key_id, key)
    return Response(content=manifest, media_type="application/vnd.apple.mpegurl")
30
+
31
+
32
async def process_playlist(request: Request, mpd_dict: dict, profile_id: str) -> Response:
    """
    Turn a parsed MPD manifest into an HLS media playlist for one profile.

    Args:
        request (Request): The incoming HTTP request.
        mpd_dict (dict): The parsed MPD manifest data.
        profile_id (str): The profile ID to generate the playlist for.

    Returns:
        Response: The HLS media playlist.

    Raises:
        HTTPException: 404 if no profile with the given ID exists in the manifest.
    """
    selected = [profile for profile in mpd_dict["profiles"] if profile["id"] == profile_id]
    if not selected:
        raise HTTPException(status_code=404, detail="Profile not found")

    playlist = build_hls_playlist(mpd_dict, selected, request)
    return Response(content=playlist, media_type="application/vnd.apple.mpegurl")
53
+
54
+
55
async def process_segment(
    init_content: bytes,
    segment_content: bytes,
    mimetype: str,
    key_id: str = None,
    key: str = None,
) -> Response:
    """
    Produce a playable segment response, decrypting CENC content when key
    material is supplied.

    Args:
        init_content (bytes): The initialization segment content.
        segment_content (bytes): The media segment content.
        mimetype (str): The MIME type of the segment.
        key_id (str, optional): The DRM key ID. Defaults to None.
        key (str, optional): The DRM key. Defaults to None.

    Returns:
        Response: The (possibly decrypted) segment.
    """
    if not (key_id and key):
        # Clear content: simply prepend the init segment to the media segment.
        payload = init_content + segment_content
    else:
        # DRM-protected content: decrypt and log how long it took.
        started = time.time()
        payload = decrypt_segment(init_content, segment_content, key_id, key)
        logger.info(f"Decryption of {mimetype} segment took {time.time() - started:.4f} seconds")

    return Response(content=payload, media_type=mimetype)
85
+
86
+
87
def build_hls(mpd_dict: dict, request: Request, key_id: str = None, key: str = None) -> str:
    """
    Builds an HLS manifest from the MPD manifest.

    Args:
        mpd_dict (dict): The MPD manifest data.
        request (Request): The incoming HTTP request.
        key_id (str, optional): The DRM key ID. Defaults to None.
        key (str, optional): The DRM key. Defaults to None.

    Returns:
        str: The HLS manifest as a string.
    """
    hls = ["#EXTM3U", "#EXT-X-VERSION:6"]
    # Propagate the caller's query params (api_password, forwarded headers, ...)
    # into every generated playlist URL.
    query_params = dict(request.query_params)

    video_profiles = {}
    audio_profiles = {}

    # Get the base URL for the playlist_endpoint endpoint
    proxy_url = request.url_for("playlist_endpoint")
    # Preserve the scheme the client used (relevant behind TLS-terminating proxies).
    proxy_url = str(proxy_url.replace(scheme=get_original_scheme(request)))

    for profile in mpd_dict["profiles"]:
        query_params.update({"profile_id": profile["id"], "key_id": key_id or "", "key": key or ""})
        playlist_url = encode_mediaflow_proxy_url(
            proxy_url,
            query_params=query_params,
        )

        if "video" in profile["mimeType"]:
            video_profiles[profile["id"]] = (profile, playlist_url)
        elif "audio" in profile["mimeType"]:
            audio_profiles[profile["id"]] = (profile, playlist_url)

    # Add audio streams
    for i, (profile, playlist_url) in enumerate(audio_profiles.values()):
        is_default = "YES" if i == 0 else "NO"  # Set the first audio track as default
        hls.append(
            f'#EXT-X-MEDIA:TYPE=AUDIO,GROUP-ID="audio",NAME="{profile["id"]}",DEFAULT={is_default},AUTOSELECT={is_default},LANGUAGE="{profile.get("lang", "und")}",URI="{playlist_url}"'
        )

    # Add video streams
    for profile, playlist_url in video_profiles.values():
        hls.append(
            f'#EXT-X-STREAM-INF:BANDWIDTH={profile["bandwidth"]},RESOLUTION={profile["width"]}x{profile["height"]},CODECS="{profile["codecs"]}",FRAME-RATE={profile["frameRate"]},AUDIO="audio"'
        )
        hls.append(playlist_url)

    return "\n".join(hls)
137
+
138
+
139
def build_hls_playlist(mpd_dict: dict, profiles: list[dict], request: Request) -> str:
    """
    Builds an HLS playlist from the MPD manifest for specific profiles.

    Args:
        mpd_dict (dict): The MPD manifest data.
        profiles (list[dict]): The profiles to include in the playlist.
        request (Request): The incoming HTTP request.

    Returns:
        str: The HLS playlist as a string.
    """
    hls = ["#EXTM3U", "#EXT-X-VERSION:6"]

    added_segments = 0
    current_time = datetime.now(timezone.utc)
    # Live playlists are shifted back by a configurable delay so that only
    # segments fully available upstream are listed.
    live_stream_delay = timedelta(seconds=settings.mpd_live_stream_delay)
    target_end_time = current_time - live_stream_delay

    proxy_url = request.url_for("segment_endpoint")
    proxy_url = str(proxy_url.replace(scheme=get_original_scheme(request)))

    for index, profile in enumerate(profiles):
        segments = profile["segments"]
        if not segments:
            logger.warning(f"No segments found for profile {profile['id']}")
            continue

        # Add headers for only the first profile
        if index == 0:
            sequence = segments[0]["number"]
            extinf_values = [f["extinf"] for f in segments if "extinf" in f]
            # TARGETDURATION must cover the longest segment; fall back to 3s.
            target_duration = math.ceil(max(extinf_values)) if extinf_values else 3
            hls.extend(
                [
                    f"#EXT-X-TARGETDURATION:{target_duration}",
                    f"#EXT-X-MEDIA-SEQUENCE:{sequence}",
                ]
            )
            if mpd_dict["isLive"]:
                hls.append("#EXT-X-PLAYLIST-TYPE:EVENT")
            else:
                hls.append("#EXT-X-PLAYLIST-TYPE:VOD")

        init_url = profile["initUrl"]

        query_params = dict(request.query_params)
        # profile_id/d belong to the playlist request, not the segment request.
        query_params.pop("profile_id", None)
        query_params.pop("d", None)

        for segment in segments:
            if mpd_dict["isLive"]:
                # Skip segments that are newer than the delayed live edge.
                if segment["end_time"] > target_end_time:
                    continue
                hls.append(f"#EXT-X-PROGRAM-DATE-TIME:{segment['program_date_time']}")
            hls.append(f'#EXTINF:{segment["extinf"]:.3f},')
            query_params.update(
                {"init_url": init_url, "segment_url": segment["media"], "mime_type": profile["mimeType"]}
            )
            hls.append(
                encode_mediaflow_proxy_url(
                    proxy_url,
                    query_params=query_params,
                )
            )
            added_segments += 1

    if not mpd_dict["isLive"]:
        hls.append("#EXT-X-ENDLIST")

    logger.info(f"Added {added_segments} segments to HLS playlist")
    return "\n".join(hls)
mediaflow_proxy/routes.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import Request, Depends, APIRouter
2
+ from pydantic import HttpUrl
3
+
4
+ from .handlers import handle_hls_stream_proxy, proxy_stream, get_manifest, get_playlist, get_segment, get_public_ip
5
+ from .utils.http_utils import get_proxy_headers
6
+
7
+ proxy_router = APIRouter()
8
+
9
+
10
+ @proxy_router.head("/hls")
11
+ @proxy_router.get("/hls")
12
+ async def hls_stream_proxy(
13
+ request: Request,
14
+ d: HttpUrl,
15
+ headers: dict = Depends(get_proxy_headers),
16
+ key_url: HttpUrl | None = None,
17
+ verify_ssl: bool = False,
18
+ ):
19
+ """
20
+ Proxify HLS stream requests, fetching and processing the m3u8 playlist or streaming the content.
21
+
22
+ Args:
23
+ request (Request): The incoming HTTP request.
24
+ d (HttpUrl): The destination URL to fetch the content from.
25
+ key_url (HttpUrl, optional): The HLS Key URL to replace the original key URL. Defaults to None. (Useful for bypassing some sneaky protection)
26
+ headers (dict): The headers to include in the request.
27
+ verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
28
+
29
+ Returns:
30
+ Response: The HTTP response with the processed m3u8 playlist or streamed content.
31
+ """
32
+ destination = str(d)
33
+ return await handle_hls_stream_proxy(request, destination, headers, key_url, verify_ssl)
34
+
35
+
36
+ @proxy_router.head("/stream")
37
+ @proxy_router.get("/stream")
38
+ async def proxy_stream_endpoint(
39
+ request: Request, d: HttpUrl, headers: dict = Depends(get_proxy_headers), verify_ssl: bool = False
40
+ ):
41
+ """
42
+ Proxies stream requests to the given video URL.
43
+
44
+ Args:
45
+ request (Request): The incoming HTTP request.
46
+ d (HttpUrl): The URL of the video to stream.
47
+ headers (dict): The headers to include in the request.
48
+ verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
49
+
50
+ Returns:
51
+ Response: The HTTP response with the streamed content.
52
+ """
53
+ headers.update({"range": headers.get("range", "bytes=0-")})
54
+ return await proxy_stream(request.method, str(d), headers, verify_ssl)
55
+
56
+
57
+ @proxy_router.get("/mpd/manifest")
58
+ async def manifest_endpoint(
59
+ request: Request,
60
+ d: HttpUrl,
61
+ headers: dict = Depends(get_proxy_headers),
62
+ key_id: str = None,
63
+ key: str = None,
64
+ verify_ssl: bool = False,
65
+ ):
66
+ """
67
+ Retrieves and processes the MPD manifest, converting it to an HLS manifest.
68
+
69
+ Args:
70
+ request (Request): The incoming HTTP request.
71
+ d (HttpUrl): The URL of the MPD manifest.
72
+ headers (dict): The headers to include in the request.
73
+ key_id (str, optional): The DRM key ID. Defaults to None.
74
+ key (str, optional): The DRM key. Defaults to None.
75
+ verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
76
+
77
+ Returns:
78
+ Response: The HTTP response with the HLS manifest.
79
+ """
80
+ return await get_manifest(request, str(d), headers, key_id, key, verify_ssl)
81
+
82
+
83
+ @proxy_router.get("/mpd/playlist")
84
+ async def playlist_endpoint(
85
+ request: Request,
86
+ d: HttpUrl,
87
+ profile_id: str,
88
+ headers: dict = Depends(get_proxy_headers),
89
+ key_id: str = None,
90
+ key: str = None,
91
+ verify_ssl: bool = False,
92
+ ):
93
+ """
94
+ Retrieves and processes the MPD manifest, converting it to an HLS playlist for a specific profile.
95
+
96
+ Args:
97
+ request (Request): The incoming HTTP request.
98
+ d (HttpUrl): The URL of the MPD manifest.
99
+ profile_id (str): The profile ID to generate the playlist for.
100
+ headers (dict): The headers to include in the request.
101
+ key_id (str, optional): The DRM key ID. Defaults to None.
102
+ key (str, optional): The DRM key. Defaults to None.
103
+ verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
104
+
105
+ Returns:
106
+ Response: The HTTP response with the HLS playlist.
107
+ """
108
+ return await get_playlist(request, str(d), profile_id, headers, key_id, key, verify_ssl)
109
+
110
+
111
+ @proxy_router.get("/mpd/segment")
112
+ async def segment_endpoint(
113
+ init_url: HttpUrl,
114
+ segment_url: HttpUrl,
115
+ mime_type: str,
116
+ headers: dict = Depends(get_proxy_headers),
117
+ key_id: str = None,
118
+ key: str = None,
119
+ verify_ssl: bool = False,
120
+ ):
121
+ """
122
+ Retrieves and processes a media segment, decrypting it if necessary.
123
+
124
+ Args:
125
+ init_url (HttpUrl): The URL of the initialization segment.
126
+ segment_url (HttpUrl): The URL of the media segment.
127
+ mime_type (str): The MIME type of the segment.
128
+ headers (dict): The headers to include in the request.
129
+ key_id (str, optional): The DRM key ID. Defaults to None.
130
+ key (str, optional): The DRM key. Defaults to None.
131
+ verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to False.
132
+
133
+ Returns:
134
+ Response: The HTTP response with the processed segment.
135
+ """
136
+ return await get_segment(str(init_url), str(segment_url), mime_type, headers, key_id, key, verify_ssl)
137
+
138
+
139
+ @proxy_router.get("/ip")
140
+ async def get_mediaflow_proxy_public_ip():
141
+ """
142
+ Retrieves the public IP address of the MediaFlow proxy server.
143
+
144
+ Returns:
145
+ Response: The HTTP response with the public IP address in the form of a JSON object. {"ip": "xxx.xxx.xxx.xxx"}
146
+ """
147
+ return await get_public_ip()
mediaflow_proxy/static/index.html ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>MediaFlow Proxy</title>
7
+ <link rel="icon" href="/logo.png" type="image/x-icon">
8
+ <style>
9
+ body {
10
+ font-family: Arial, sans-serif;
11
+ line-height: 1.6;
12
+ color: #333;
13
+ max-width: 800px;
14
+ margin: 0 auto;
15
+ padding: 20px;
16
+ background-color: #f9f9f9;
17
+ }
18
+
19
+ header {
20
+ background-color: #90aacc;
21
+ color: #fff;
22
+ padding: 10px 0;
23
+ text-align: center;
24
+ }
25
+
26
+ header img {
27
+ width: 200px;
28
+ height: 200px;
29
+ vertical-align: middle;
30
+ border-radius: 15px;
31
+ }
32
+
33
+ header h1 {
34
+ display: inline;
35
+ margin-left: 20px;
36
+ font-size: 36px;
37
+ }
38
+
39
+ .feature {
40
+ background-color: #f4f4f4;
41
+ border-left: 4px solid #3498db;
42
+ padding: 10px;
43
+ margin-bottom: 10px;
44
+ }
45
+
46
+ a {
47
+ color: #3498db;
48
+ }
49
+ </style>
50
+ </head>
51
+ <body>
52
+ <header>
53
+ <img src="/logo.png" alt="MediaFlow Proxy Logo">
54
+ <h1>MediaFlow Proxy</h1>
55
+ </header>
56
+ <p>A high-performance proxy server for streaming media, supporting HTTP(S), HLS, and MPEG-DASH with real-time DRM decryption.</p>
57
+
58
+ <h2>Key Features</h2>
59
+ <div class="feature">Convert MPEG-DASH streams (DRM-protected and non-protected) to HLS</div>
60
+ <div class="feature">Support for Clear Key DRM-protected MPD DASH streams</div>
61
+ <div class="feature">Handle both live and video-on-demand (VOD) DASH streams</div>
62
+ <div class="feature">Proxy HTTP/HTTPS links with custom headers</div>
63
+ <div class="feature">Proxy and modify HLS (M3U8) streams in real-time with custom headers and key URL modifications for bypassing some sneaky restrictions.</div>
64
+ <div class="feature">Protect against unauthorized access and network bandwidth abuses</div>
65
+
66
+ <h2>Getting Started</h2>
67
+ <p>Visit the <a href="https://github.com/mhdzumair/mediaflow-proxy">GitHub repository</a> for installation instructions and documentation.</p>
68
+
69
+ <h2>Premium Hosted Service</h2>
70
+ <p>For a hassle-free experience, check out <a href="https://store.elfhosted.com/product/mediaflow-proxy">premium hosted service on ElfHosted</a>.</p>
71
+
72
+ <h2>API Documentation</h2>
73
+ <p>Explore the <a href="/docs">Swagger UI</a> for comprehensive details about the API endpoints and their usage.</p>
74
+
75
+ </body>
76
+ </html>
mediaflow_proxy/static/logo.png ADDED
mediaflow_proxy/utils/__init__.py ADDED
File without changes
mediaflow_proxy/utils/cache_utils.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import logging
3
+
4
+ from cachetools import TTLCache
5
+
6
+ from .http_utils import download_file_with_retry
7
+ from .mpd_utils import parse_mpd, parse_mpd_dict
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ # cache dictionary
12
+ mpd_cache = TTLCache(maxsize=100, ttl=300) # 5 minutes default TTL
13
+ init_segment_cache = TTLCache(maxsize=100, ttl=3600) # 1 hour default TTL
14
+
15
+
16
async def get_cached_mpd(
    mpd_url: str, headers: dict, parse_drm: bool, parse_segment_profile_id: str | None = None, verify_ssl: bool = True
) -> dict:
    """
    Fetch, parse and cache an MPD manifest.

    A cache hit is only used while its stored ``expires`` timestamp is still in
    the future; the raw (unparsed) manifest is cached so it can be re-parsed
    with different ``parse_drm`` / profile options on later calls.

    Args:
        mpd_url (str): The URL of the MPD manifest.
        headers (dict): The headers to include in the request.
        parse_drm (bool): Whether to parse DRM information.
        parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.
        verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.

    Returns:
        dict: The parsed MPD manifest data.
    """
    now = datetime.datetime.now(datetime.UTC)
    entry = mpd_cache.get(mpd_url)
    if entry is not None and entry["expires"] > now:
        logger.info(f"Using cached MPD for {mpd_url}")
        return parse_mpd_dict(entry["mpd"], mpd_url, parse_drm, parse_segment_profile_id)

    raw_mpd = parse_mpd(await download_file_with_retry(mpd_url, headers, verify_ssl=verify_ssl))
    parsed = parse_mpd_dict(raw_mpd, mpd_url, parse_drm, parse_segment_profile_id)
    # Live manifests advertise how often they should be refreshed; static ones
    # fall back to a 5-minute lifetime.
    expires_at = datetime.datetime.now(datetime.UTC) + datetime.timedelta(
        seconds=parsed.get("minimumUpdatePeriod", 300)
    )
    mpd_cache[mpd_url] = {"mpd": raw_mpd, "expires": expires_at}
    return parsed
43
+
44
+
45
async def get_cached_init_segment(init_url: str, headers: dict, verify_ssl: bool = True) -> bytes:
    """
    Fetch an initialization segment, serving repeat requests from the TTL cache.

    Args:
        init_url (str): The URL of the initialization segment.
        headers (dict): The headers to include in the request.
        verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.

    Returns:
        bytes: The initialization segment content.
    """
    segment = init_segment_cache.get(init_url)
    if segment is None:
        segment = await download_file_with_retry(init_url, headers, verify_ssl=verify_ssl)
        init_segment_cache[init_url] = segment
    return segment
mediaflow_proxy/utils/http_utils.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import typing
3
+ from functools import partial
4
+ from urllib import parse
5
+
6
+ import anyio
7
+ import httpx
8
+ import tenacity
9
+ from fastapi import Response
10
+ from starlette.background import BackgroundTask
11
+ from starlette.concurrency import iterate_in_threadpool
12
+ from starlette.requests import Request
13
+ from starlette.types import Receive, Send, Scope
14
+ from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
15
+
16
+ from mediaflow_proxy.configs import settings
17
+ from mediaflow_proxy.const import SUPPORTED_REQUEST_HEADERS
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class DownloadError(Exception):
    """Raised when an upstream download fails; carries an HTTP-like status code."""

    def __init__(self, status_code, message):
        # Store both pieces so callers can branch on the status without parsing str(e).
        super().__init__(message)
        self.status_code = status_code
        self.message = message
27
+
28
+
29
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(DownloadError),
)
async def fetch_with_retry(client, method, url, headers, follow_redirects=True, **kwargs):
    """
    Fetches a URL, retrying (up to 3 attempts, exponential backoff) on DownloadError.

    Args:
        client (httpx.AsyncClient): The HTTP client to use for the request.
        method (str): The HTTP method to use (e.g., GET, POST).
        url (str): The URL to fetch.
        headers (dict): The headers to include in the request.
        follow_redirects (bool, optional): Whether to follow redirects. Defaults to True.
        **kwargs: Additional arguments to pass to the request.

    Returns:
        httpx.Response: The HTTP response.

    Raises:
        DownloadError: On timeout or non-2xx status (triggers a retry).
        tenacity.RetryError: When all retry attempts are exhausted.
    """
    try:
        response = await client.request(method, url, headers=headers, follow_redirects=follow_redirects, **kwargs)
        response.raise_for_status()
        return response
    except httpx.TimeoutException as e:
        logger.warning(f"Timeout while downloading {url}")
        # NOTE(review): 409 is used here as a generic retryable marker; 504 would be the
        # conventional status for an upstream timeout — confirm callers don't match on 409 before changing.
        raise DownloadError(409, f"Timeout while downloading {url}") from e
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error {e.response.status_code} while downloading {url}")
        raise DownloadError(
            e.response.status_code, f"HTTP error {e.response.status_code} while downloading {url}"
        ) from e
    except Exception as e:
        logger.error(f"Error downloading {url}: {e}")
        raise
68
+
69
+
70
class Streamer:
    """Thin wrapper over an httpx.AsyncClient for proxy streaming and one-shot requests."""

    def __init__(self, client):
        """
        Initializes the Streamer with an HTTP client.

        Args:
            client (httpx.AsyncClient): The HTTP client to use for streaming.
        """
        self.client = client
        # Last response produced by this streamer; kept so close() can release it.
        self.response = None

    async def stream_content(self, url: str, headers: dict):
        """
        Streams content from a URL.

        Args:
            url (str): The URL to stream content from.
            headers (dict): The headers to include in the request.

        Yields:
            bytes: Chunks of the streamed content.

        Raises:
            httpx.HTTPStatusError: If the upstream responds with an error status.
        """
        # The response is bound to self.response so close() can clean it up even
        # if the consumer abandons this generator mid-stream.
        async with self.client.stream("GET", url, headers=headers, follow_redirects=True) as self.response:
            self.response.raise_for_status()
            # aiter_raw() yields bytes exactly as received (no content decoding),
            # so any Content-Encoding passes through to the downstream client untouched.
            async for chunk in self.response.aiter_raw():
                yield chunk

    async def head(self, url: str, headers: dict):
        """
        Sends a HEAD request to a URL.

        Args:
            url (str): The URL to send the HEAD request to.
            headers (dict): The headers to include in the request.

        Returns:
            httpx.Response: The HTTP response.

        Raises:
            DownloadError: The underlying error from the final failed retry attempt.
        """
        try:
            self.response = await fetch_with_retry(self.client, "HEAD", url, headers)
        except tenacity.RetryError as e:
            # Unwrap the RetryError and surface the exception from the last attempt.
            raise e.last_attempt.result()
        return self.response

    async def get_text(self, url: str, headers: dict):
        """
        Sends a GET request to a URL and returns the response text.

        Args:
            url (str): The URL to send the GET request to.
            headers (dict): The headers to include in the request.

        Returns:
            str: The response text.

        Raises:
            DownloadError: The underlying error from the final failed retry attempt.
        """
        try:
            self.response = await fetch_with_retry(self.client, "GET", url, headers)
        except tenacity.RetryError as e:
            # Unwrap the RetryError and surface the exception from the last attempt.
            raise e.last_attempt.result()
        return self.response.text

    async def close(self):
        """
        Closes the HTTP client and response.
        """
        # Close the in-flight response first (if any), then the client itself.
        if self.response:
            await self.response.aclose()
        await self.client.aclose()
138
+
139
+
140
async def download_file_with_retry(url: str, headers: dict, timeout: float = 10.0, verify_ssl: bool = True):
    """
    Downloads a file with retry logic, using a short-lived client per call.

    Args:
        url (str): The URL of the file to download.
        headers (dict): The headers to include in the request.
        timeout (float, optional): The request timeout. Defaults to 10.0.
        verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.

    Returns:
        bytes: The downloaded file content.

    Raises:
        DownloadError: If the download fails after retries.
    """
    async with httpx.AsyncClient(
        follow_redirects=True, timeout=timeout, proxy=settings.proxy_url, verify=verify_ssl
    ) as client:
        try:
            response = await fetch_with_retry(client, "GET", url, headers)
        except DownloadError as e:
            logger.error(f"Failed to download file: {e}")
            raise e
        except tenacity.RetryError as e:
            # All attempts exhausted: normalize to a DownloadError with a 502.
            raise DownloadError(502, f"Failed to download file: {e.last_attempt.result()}")
        return response.content
167
+
168
+
169
async def request_with_retry(
    method: str, url: str, headers: dict, timeout: float = 10.0, verify_ssl: bool = True, **kwargs
):
    """
    Sends an HTTP request with retry logic.

    Args:
        method (str): The HTTP method to use (e.g., GET, POST).
        url (str): The URL to send the request to.
        headers (dict): The headers to include in the request.
        timeout (float, optional): The request timeout. Defaults to 10.0.
        verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination.
            Defaults to True (added for consistency with download_file_with_retry; backward-compatible).
        **kwargs: Additional arguments to pass to the request.

    Returns:
        httpx.Response: The HTTP response.

    Raises:
        DownloadError: If the request fails after retries.
    """
    async with httpx.AsyncClient(
        follow_redirects=True, timeout=timeout, proxy=settings.proxy_url, verify=verify_ssl
    ) as client:
        try:
            response = await fetch_with_retry(client, method, url, headers, **kwargs)
            return response
        except DownloadError as e:
            logger.error(f"Failed to download file: {e}")
            raise
        except tenacity.RetryError as e:
            # Previously RetryError leaked to callers, contradicting the documented contract;
            # convert it to DownloadError like download_file_with_retry does.
            raise DownloadError(502, f"Failed to download file: {e.last_attempt.result()}")
193
+
194
+
195
+ def encode_mediaflow_proxy_url(
196
+ mediaflow_proxy_url: str,
197
+ endpoint: str | None = None,
198
+ destination_url: str | None = None,
199
+ query_params: dict | None = None,
200
+ request_headers: dict | None = None,
201
+ ) -> str:
202
+ """
203
+ Encodes a MediaFlow proxy URL with query parameters and headers.
204
+
205
+ Args:
206
+ mediaflow_proxy_url (str): The base MediaFlow proxy URL.
207
+ endpoint (str, optional): The endpoint to append to the base URL. Defaults to None.
208
+ destination_url (str, optional): The destination URL to include in the query parameters. Defaults to None.
209
+ query_params (dict, optional): Additional query parameters to include. Defaults to None.
210
+ request_headers (dict, optional): Headers to include as query parameters. Defaults to None.
211
+
212
+ Returns:
213
+ str: The encoded MediaFlow proxy URL.
214
+ """
215
+ query_params = query_params or {}
216
+ if destination_url is not None:
217
+ query_params["d"] = destination_url
218
+
219
+ # Add headers if provided
220
+ if request_headers:
221
+ query_params.update(
222
+ {key if key.startswith("h_") else f"h_{key}": value for key, value in request_headers.items()}
223
+ )
224
+ # Encode the query parameters
225
+ encoded_params = parse.urlencode(query_params, quote_via=parse.quote)
226
+
227
+ # Construct the full URL
228
+ if endpoint is None:
229
+ return f"{mediaflow_proxy_url}?{encoded_params}"
230
+
231
+ base_url = parse.urljoin(mediaflow_proxy_url, endpoint)
232
+ return f"{base_url}?{encoded_params}"
233
+
234
+
235
def get_original_scheme(request: Request) -> str:
    """
    Determines the original scheme (http or https) of the request.

    X-Forwarded-Proto wins outright; otherwise any of the common HTTPS
    indicators (the request's own scheme or proxy-set headers) yields "https".

    Args:
        request (Request): The incoming HTTP request.

    Returns:
        str: The original scheme ('http' or 'https')
    """
    # X-Forwarded-Proto is authoritative when a reverse proxy sets it.
    forwarded_proto = request.headers.get("X-Forwarded-Proto")
    if forwarded_proto:
        return forwarded_proto

    # Fixed: X-Forwarded-Ssl was previously checked twice; the indicators are
    # now consolidated into a single condition with identical outcome.
    if (
        request.url.scheme == "https"
        or request.headers.get("X-Forwarded-Ssl") == "on"
        or request.headers.get("X-Forwarded-Protocol") == "https"
        or request.headers.get("X-Url-Scheme") == "https"
    ):
        return "https"

    # Default to http if no indicators of https are found
    return "http"
264
+
265
+
266
def get_proxy_headers(request: Request) -> dict:
    """
    Builds the header set to forward upstream.

    Starts from the incoming request headers (whitelisted ones only), then lets
    ``h_``-prefixed query parameters add or override entries.

    Args:
        request (Request): The incoming HTTP request.

    Returns:
        dict: A dictionary of proxy headers.
    """
    headers = {name: value for name, value in request.headers.items() if name in SUPPORTED_REQUEST_HEADERS}
    for name, value in request.query_params.items():
        if name.startswith("h_"):
            # Strip the "h_" prefix; header names are normalized to lowercase.
            headers[name[2:].lower()] = value
    return headers
279
+
280
+
281
class EnhancedStreamingResponse(Response):
    """
    Streaming response that tolerates client disconnects.

    Unlike a plain streaming response, send failures caused by the client going
    away (ConnectionResetError / anyio.BrokenResourceError) are logged and
    swallowed instead of propagating as server errors, and a dedicated task
    watches for the ASGI disconnect event to cancel the stream early.
    """

    # Async iterable that produces the response body chunks.
    body_iterator: typing.AsyncIterable[typing.Any]

    def __init__(
        self,
        content: typing.Union[typing.AsyncIterable[typing.Any], typing.Iterable[typing.Any]],
        status_code: int = 200,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
    ) -> None:
        """Wrap *content* (sync or async iterable) as the streaming body."""
        if isinstance(content, typing.AsyncIterable):
            self.body_iterator = content
        else:
            # Sync iterables are driven in a worker thread so iteration does not block the event loop.
            self.body_iterator = iterate_in_threadpool(content)
        self.status_code = status_code
        # Keep the class-level default media type when none is supplied.
        self.media_type = self.media_type if media_type is None else media_type
        self.background = background
        self.init_headers(headers)

    @staticmethod
    async def listen_for_disconnect(receive: Receive) -> None:
        """Consume ASGI receive events until the client disconnects."""
        try:
            while True:
                message = await receive()
                if message["type"] == "http.disconnect":
                    logger.debug("Client disconnected")
                    break
        except Exception as e:
            logger.error(f"Error in listen_for_disconnect: {str(e)}")

    async def stream_response(self, send: Send) -> None:
        """Send the response start message, then forward body chunks to the client."""
        try:
            await send(
                {
                    "type": "http.response.start",
                    "status": self.status_code,
                    "headers": self.raw_headers,
                }
            )
            async for chunk in self.body_iterator:
                # Encode str chunks with the response charset; bytes/memoryview pass through.
                if not isinstance(chunk, (bytes, memoryview)):
                    chunk = chunk.encode(self.charset)
                try:
                    await send({"type": "http.response.body", "body": chunk, "more_body": True})
                except (ConnectionResetError, anyio.BrokenResourceError):
                    # Client went away mid-stream: stop quietly instead of erroring out.
                    logger.info("Client disconnected during streaming")
                    return

            # Terminate the body with an empty final message.
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        except Exception as e:
            logger.error(f"Error in stream_response: {str(e)}")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point: stream the body while concurrently watching for disconnects."""
        async with anyio.create_task_group() as task_group:

            async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
                try:
                    await func()
                except ExceptionGroup as e:
                    # Re-raise only when the group holds something other than cancellations.
                    if not any(isinstance(exc, anyio.get_cancelled_exc_class()) for exc in e.exceptions):
                        logger.exception("Error in streaming task")
                        raise
                except Exception as e:
                    if not isinstance(e, anyio.get_cancelled_exc_class()):
                        logger.exception("Error in streaming task")
                        raise
                finally:
                    # Whichever task finishes first cancels the other
                    # (streaming vs. the disconnect watcher).
                    task_group.cancel_scope.cancel()

            task_group.start_soon(wrap, partial(self.stream_response, send))
            await wrap(partial(self.listen_for_disconnect, receive))

        if self.background is not None:
            await self.background()
mediaflow_proxy/utils/m3u8_processor.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from urllib import parse
3
+
4
+ from pydantic import HttpUrl
5
+
6
+ from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url, get_original_scheme
7
+
8
+
9
class M3U8Processor:
    """Rewrites m3u8 playlists so every media/key URL points back through the proxy."""

    def __init__(self, request, key_url: HttpUrl = None):
        """
        Initializes the M3U8Processor with the request and URL prefix.

        Args:
            request (Request): The incoming HTTP request.
            key_url (HttpUrl, optional): The URL of the key server. Defaults to None.
        """
        self.request = request
        self.key_url = key_url
        # Proxy endpoint, rebuilt with the scheme the client originally used.
        self.mediaflow_proxy_url = str(request.url_for("hls_stream_proxy").replace(scheme=get_original_scheme(request)))

    async def process_m3u8(self, content: str, base_url: str) -> str:
        """
        Processes the m3u8 content, proxying URLs and handling key lines.

        Args:
            content (str): The m3u8 content to process.
            base_url (str): The base URL to resolve relative URLs.

        Returns:
            str: The processed m3u8 content.
        """
        output = []
        for line in content.splitlines():
            if "URI=" in line:
                output.append(await self.process_key_line(line, base_url))
            elif not line.startswith("#") and line.strip():
                # Non-comment, non-empty line: a segment/playlist URL to proxy.
                output.append(await self.proxy_url(line, base_url))
            else:
                output.append(line)
        return "\n".join(output)

    async def process_key_line(self, line: str, base_url: str) -> str:
        """
        Processes a key line in the m3u8 content, proxying the URI.

        Args:
            line (str): The key line to process.
            base_url (str): The base URL to resolve relative URLs.

        Returns:
            str: The processed key line.
        """
        match = re.search(r'URI="([^"]+)"', line)
        if not match:
            return line
        original_uri = match.group(1)
        parsed_uri = parse.urlparse(original_uri)
        if self.key_url:
            # Redirect key fetches to the configured key server, keeping the path intact.
            parsed_uri = parsed_uri._replace(scheme=self.key_url.scheme, netloc=self.key_url.host)
        proxied_uri = await self.proxy_url(parsed_uri.geturl(), base_url)
        return line.replace(f'URI="{original_uri}"', f'URI="{proxied_uri}"')

    async def proxy_url(self, url: str, base_url: str) -> str:
        """
        Proxies a URL, encoding it with the MediaFlow proxy URL.

        Args:
            url (str): The URL to proxy.
            base_url (str): The base URL to resolve relative URLs.

        Returns:
            str: The proxied URL.
        """
        absolute_url = parse.urljoin(base_url, url)
        # Forward the current request's query parameters (auth, headers, ...) to the proxied URL.
        return encode_mediaflow_proxy_url(
            self.mediaflow_proxy_url,
            "",
            absolute_url,
            query_params=dict(self.request.query_params),
        )
mediaflow_proxy/utils/mpd_utils.py ADDED
@@ -0,0 +1,555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import math
3
+ import re
4
+ from datetime import datetime, timedelta, timezone
5
+ from typing import List, Dict
6
+ from urllib.parse import urljoin
7
+
8
+ import xmltodict
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
def parse_mpd(mpd_content: str | bytes) -> dict:
    """
    Parses the MPD content into a dictionary.

    Args:
        mpd_content (str | bytes): The raw MPD (XML) document.

    Returns:
        dict: The parsed MPD content as a dictionary, in xmltodict layout
        (XML attributes keyed as "@name", text nodes as "#text").
    """
    return xmltodict.parse(mpd_content)
24
+
25
+
26
def parse_mpd_dict(
    mpd_dict: dict, mpd_url: str, parse_drm: bool = True, parse_segment_profile_id: str | None = None
) -> dict:
    """
    Parses the MPD dictionary and extracts relevant information.

    Args:
        mpd_dict (dict): The MPD content as a dictionary (xmltodict layout).
        mpd_url (str): The URL of the MPD manifest.
        parse_drm (bool, optional): Whether to parse DRM information. Defaults to True.
        parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.

    Returns:
        dict: The parsed MPD information including profiles and DRM info.

    This function processes the MPD dictionary to extract profiles, DRM information, and other relevant data.
    It handles both live ("dynamic") and static MPD manifests.
    """
    profiles = []
    parsed_dict = {}
    # Base URL for resolving relative segment paths: the manifest URL minus its last component.
    source = "/".join(mpd_url.split("/")[:-1])

    mpd = mpd_dict["MPD"]
    is_live = mpd.get("@type", "static").lower() == "dynamic"
    parsed_dict["isLive"] = is_live

    media_presentation_duration = mpd.get("@mediaPresentationDuration")

    # Parse additional MPD attributes for live streams
    if is_live:
        parsed_dict["minimumUpdatePeriod"] = parse_duration(mpd.get("@minimumUpdatePeriod", "PT0S"))
        parsed_dict["timeShiftBufferDepth"] = parse_duration(mpd.get("@timeShiftBufferDepth", "PT2M"))
        parsed_dict["availabilityStartTime"] = datetime.fromisoformat(
            mpd["@availabilityStartTime"].replace("Z", "+00:00")
        )
        # Fixed: @publishTime is optional; the old code called fromisoformat("") and
        # raised ValueError when it was absent. Fall back to "now" instead.
        publish_time = mpd.get("@publishTime")
        parsed_dict["publishTime"] = (
            datetime.fromisoformat(publish_time.replace("Z", "+00:00"))
            if publish_time
            else datetime.now(timezone.utc)
        )

    periods = mpd["Period"]
    periods = periods if isinstance(periods, list) else [periods]

    for period in periods:
        parsed_dict["PeriodStart"] = parse_duration(period.get("@start", "PT0S"))
        # Fixed: xmltodict yields a plain dict (not a list) when a Period contains a
        # single AdaptationSet; iterating that dict produced string keys and crashed.
        adaptation_sets = period["AdaptationSet"]
        adaptation_sets = adaptation_sets if isinstance(adaptation_sets, list) else [adaptation_sets]
        for adaptation in adaptation_sets:
            representations = adaptation["Representation"]
            representations = representations if isinstance(representations, list) else [representations]

            for representation in representations:
                profile = parse_representation(
                    parsed_dict,
                    representation,
                    adaptation,
                    source,
                    media_presentation_duration,
                    parse_segment_profile_id,
                )
                if profile:
                    profiles.append(profile)
    parsed_dict["profiles"] = profiles

    parsed_dict["drmInfo"] = extract_drm_info(periods, mpd_url) if parse_drm else {}

    return parsed_dict
93
+
94
+
95
def pad_base64(encoded_key_id):
    """
    Pads a base64 encoded key ID to make its length a multiple of 4.

    Fixed: the previous formula ``"=" * (4 - len % 4)`` appended FOUR '=' when the
    input length was already a multiple of 4, which base64.b64decode rejects.
    ``-len % 4`` yields 0 in that case and the correct 1–3 otherwise.

    Args:
        encoded_key_id (str): The base64 encoded key ID.

    Returns:
        str: The padded base64 encoded key ID.
    """
    return encoded_key_id + "=" * (-len(encoded_key_id) % 4)
106
+
107
+
108
def extract_drm_info(periods: List[Dict], mpd_url: str) -> Dict:
    """
    Collects DRM information from every ContentProtection element in the MPD.

    Both AdaptationSet-level and Representation-level ContentProtection nodes are
    inspected; a relative license URL is made absolute against the manifest URL.

    Args:
        periods (List[Dict]): The list of periods in the MPD.
        mpd_url (str): The URL of the MPD manifest.

    Returns:
        Dict: The extracted DRM information (ClearKey / Widevine / PlayReady).
    """
    drm_info = {"isDrmProtected": False}

    for period in periods:
        adaptation_sets: list[dict] | dict = period.get("AdaptationSet", [])
        adaptation_sets = adaptation_sets if isinstance(adaptation_sets, list) else [adaptation_sets]

        for adaptation_set in adaptation_sets:
            # AdaptationSet-level protection first, then each Representation.
            process_content_protection(adaptation_set.get("ContentProtection", []), drm_info)

            representations: list[dict] | dict = adaptation_set.get("Representation", [])
            representations = representations if isinstance(representations, list) else [representations]
            for representation in representations:
                process_content_protection(representation.get("ContentProtection", []), drm_info)

    # A relative license acquisition URL is resolved against the manifest location.
    la_url = drm_info.get("laUrl")
    if la_url is not None and not la_url.startswith(("http://", "https://")):
        drm_info["laUrl"] = urljoin(mpd_url, la_url)

    return drm_info
146
+
147
+
148
+ def process_content_protection(content_protection: list[dict] | dict, drm_info: dict):
149
+ """
150
+ Processes the ContentProtection elements to extract DRM information.
151
+
152
+ Args:
153
+ content_protection (list[dict] | dict): The ContentProtection elements.
154
+ drm_info (dict): The dictionary to store DRM information.
155
+
156
+ This function updates the drm_info dictionary with DRM system information found in the ContentProtection elements.
157
+ """
158
+ if not isinstance(content_protection, list):
159
+ content_protection = [content_protection]
160
+
161
+ for protection in content_protection:
162
+ drm_info["isDrmProtected"] = True
163
+ scheme_id_uri = protection.get("@schemeIdUri", "").lower()
164
+
165
+ if "clearkey" in scheme_id_uri:
166
+ drm_info["drmSystem"] = "clearkey"
167
+ if "clearkey:Laurl" in protection:
168
+ la_url = protection["clearkey:Laurl"].get("#text")
169
+ if la_url and "laUrl" not in drm_info:
170
+ drm_info["laUrl"] = la_url
171
+
172
+ elif "widevine" in scheme_id_uri or "edef8ba9-79d6-4ace-a3c8-27dcd51d21ed" in scheme_id_uri:
173
+ drm_info["drmSystem"] = "widevine"
174
+ pssh = protection.get("cenc:pssh", {}).get("#text")
175
+ if pssh:
176
+ drm_info["pssh"] = pssh
177
+
178
+ elif "playready" in scheme_id_uri or "9a04f079-9840-4286-ab92-e65be0885f95" in scheme_id_uri:
179
+ drm_info["drmSystem"] = "playready"
180
+
181
+ if "@cenc:default_KID" in protection:
182
+ key_id = protection["@cenc:default_KID"].replace("-", "")
183
+ if "keyId" not in drm_info:
184
+ drm_info["keyId"] = key_id
185
+
186
+ if "ms:laurl" in protection:
187
+ la_url = protection["ms:laurl"].get("@licenseUrl")
188
+ if la_url and "laUrl" not in drm_info:
189
+ drm_info["laUrl"] = la_url
190
+
191
+ return drm_info
192
+
193
+
194
def parse_representation(
    parsed_dict: dict,
    representation: dict,
    adaptation: dict,
    source: str,
    media_presentation_duration: str,
    parse_segment_profile_id: str | None,
) -> dict | None:
    """
    Builds a profile dict for one Representation, optionally with its segment list.

    Args:
        parsed_dict (dict): The parsed MPD data accumulated so far.
        representation (dict): The representation data.
        adaptation (dict): The adaptation set data.
        source (str): The base URL for relative segment paths.
        media_presentation_duration (str): The media presentation duration.
        parse_segment_profile_id (str, optional): Only this profile gets its segments parsed. Defaults to None.

    Returns:
        dict | None: The parsed profile information, or None for non-audio/video tracks.
    """
    # Missing @mimeType: guess video vs. audio from the codec string.
    mime_type = _get_key(adaptation, representation, "@mimeType") or (
        "video/mp4" if "avc" in representation["@codecs"] else "audio/mp4"
    )
    if "video" not in mime_type and "audio" not in mime_type:
        return None

    profile = {
        "id": representation.get("@id") or adaptation.get("@id"),
        "mimeType": mime_type,
        "lang": representation.get("@lang") or adaptation.get("@lang"),
        "codecs": representation.get("@codecs") or adaptation.get("@codecs"),
        "bandwidth": int(representation.get("@bandwidth") or adaptation.get("@bandwidth")),
        "startWithSAP": (_get_key(adaptation, representation, "@startWithSAP") or "1") == "1",
        "mediaPresentationDuration": media_presentation_duration,
    }

    if "audio" in mime_type:
        profile["audioSamplingRate"] = representation.get("@audioSamplingRate") or adaptation.get("@audioSamplingRate")
        profile["channels"] = representation.get("AudioChannelConfiguration", {}).get("@value", "2")
    else:
        profile["width"] = int(representation["@width"])
        profile["height"] = int(representation["@height"])
        # Frame rate may be "num/den" or a bare number; normalize to a fraction.
        rate = representation.get("@frameRate") or adaptation.get("@maxFrameRate") or "30000/1001"
        if "/" not in rate:
            rate = f"{rate}/1"
        numerator, denominator = rate.split("/")
        profile["frameRate"] = round(int(numerator) / int(denominator), 3)
        profile["sar"] = representation.get("@sar", "1:1")

    # Segment parsing is expensive; only do it for the requested profile.
    if parse_segment_profile_id is None or profile["id"] != parse_segment_profile_id:
        return profile

    template = adaptation.get("SegmentTemplate") or representation.get("SegmentTemplate")
    if template:
        profile["segments"] = parse_segment_template(parsed_dict, template, profile, source)
    else:
        profile["segments"] = parse_segment_base(representation, source)

    return profile
253
+
254
+
255
+ def _get_key(adaptation: dict, representation: dict, key: str) -> str | None:
256
+ """
257
+ Retrieves a key from the representation or adaptation set.
258
+
259
+ Args:
260
+ adaptation (dict): The adaptation set data.
261
+ representation (dict): The representation data.
262
+ key (str): The key to retrieve.
263
+
264
+ Returns:
265
+ str | None: The value of the key or None if not found.
266
+ """
267
+ return representation.get(key, adaptation.get(key, None))
268
+
269
+
270
def parse_segment_template(parsed_dict: dict, item: dict, profile: dict, source: str) -> List[Dict]:
    """
    Parses a SegmentTemplate element: resolves the init URL and expands segments.

    Args:
        parsed_dict (dict): The parsed MPD data.
        item (dict): The SegmentTemplate element.
        profile (dict): The profile being filled (receives "initUrl" as a side effect).
        source (str): The base URL for relative paths.

    Returns:
        List[Dict]: The list of parsed segments.
    """
    timescale = int(item.get("@timescale", 1))

    if "@initialization" in item:
        # Substitute the template variables and absolutize relative paths.
        init_url = (
            item["@initialization"]
            .replace("$RepresentationID$", profile["id"])
            .replace("$Bandwidth$", str(profile["bandwidth"]))
        )
        if not init_url.startswith("http"):
            init_url = f"{source}/{init_url}"
        profile["initUrl"] = init_url

    segments: List[Dict] = []
    # SegmentTimeline takes precedence over a fixed @duration.
    if "SegmentTimeline" in item:
        segments.extend(parse_segment_timeline(parsed_dict, item, profile, source, timescale))
    elif "@duration" in item:
        segments.extend(parse_segment_duration(parsed_dict, item, profile, source, timescale))

    return segments
302
+
303
+
304
def parse_segment_timeline(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]:
    """
    Expands a SegmentTimeline element into concrete segment descriptors.

    Args:
        parsed_dict (dict): The parsed MPD data (supplies availabilityStartTime / PeriodStart).
        item (dict): The SegmentTemplate element containing the timeline.
        profile (dict): The profile information.
        source (str): The base URL for relative paths.
        timescale (int): The timescale for the segments.

    Returns:
        List[Dict]: The list of parsed segments.
    """
    raw_entries = item["SegmentTimeline"]["S"]
    timelines = raw_entries if isinstance(raw_entries, list) else [raw_entries]

    # Anchor wall-clock segment times at the period start.
    period_start = parsed_dict["availabilityStartTime"] + timedelta(seconds=parsed_dict.get("PeriodStart", 0))
    presentation_time_offset = int(item.get("@presentationTimeOffset", 0))
    start_number = int(item.get("@startNumber", 1))

    expanded = preprocess_timeline(timelines, start_number, period_start, presentation_time_offset, timescale)
    return [create_segment_data(entry, item, profile, source, timescale) for entry in expanded]
329
+
330
+
331
def preprocess_timeline(
    timelines: List[Dict], start_number: int, period_start: datetime, presentation_time_offset: int, timescale: int
) -> List[Dict]:
    """
    Expand SegmentTimeline <S> entries (with repeats) into concrete segments.

    Each output entry carries the segment number, absolute start/end times,
    duration in timescale ticks, and the raw tick time (``@t`` accumulation).

    Args:
        timelines (List[Dict]): The list of timeline <S> elements.
        start_number (int): The starting segment number.
        period_start (datetime): The absolute start time of the period.
        presentation_time_offset (int): The presentation time offset, in ticks.
        timescale (int): Ticks per second.

    Returns:
        List[Dict]: The list of expanded timeline segments.
    """
    entries: List[Dict] = []
    cursor = 0
    number = start_number
    for element in timelines:
        duration = int(element["@d"])
        # "@t" resets the clock; otherwise continue from where the last entry ended.
        segment_time = int(element.get("@t", cursor))
        # NOTE(review): "@r" == -1 ("repeat until period end") is not expanded here —
        # it yields zero iterations below; confirm upstream manifests never emit it.
        repeats = int(element.get("@r", 0)) + 1
        for _ in range(repeats):
            begins = period_start + timedelta(seconds=(segment_time - presentation_time_offset) / timescale)
            entries.append(
                {
                    "number": number,
                    "start_time": begins,
                    "end_time": begins + timedelta(seconds=duration / timescale),
                    "duration": duration,
                    "time": segment_time,
                }
            )
            segment_time += duration
            number += 1
        cursor = segment_time
    return entries
372
+
373
+
374
def parse_segment_duration(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]:
    """
    Parse a duration-based SegmentTemplate (no SegmentTimeline) into segments.

    Used for both static (VOD) and live MPD manifests; the live branch derives
    the window from wall-clock time, the VOD branch from the total duration.

    Args:
        parsed_dict (dict): The parsed MPD data (``isLive`` selects the branch).
        item (dict): The segment template data containing ``@duration``.
        profile (dict): The profile information.
        source (str): The source URL used to resolve relative media paths.
        timescale (int): Ticks per second.

    Returns:
        List[Dict]: The list of parsed segments.
    """
    duration_ticks = int(item["@duration"])
    first_number = int(item.get("@startNumber", 1))

    if parsed_dict["isLive"]:
        raw_segments = generate_live_segments(parsed_dict, duration_ticks / timescale, first_number)
    else:
        raw_segments = generate_vod_segments(profile, duration_ticks, timescale, first_number)

    return [create_segment_data(raw, item, profile, source, timescale) for raw in raw_segments]
399
+
400
+
401
def generate_live_segments(parsed_dict: dict, segment_duration_sec: float, start_number: int) -> List[Dict]:
    """
    Generate the sliding window of segments for a live MPD manifest.

    The window size comes from ``timeShiftBufferDepth`` (default 60 s); the
    earliest number is derived from wall-clock elapsed time since
    ``availabilityStartTime``, clamped so it never precedes ``start_number``.

    Args:
        parsed_dict (dict): The parsed MPD data.
        segment_duration_sec (float): The segment duration in seconds.
        start_number (int): The starting segment number.

    Returns:
        List[Dict]: The list of generated live segments.
    """
    availability_start = parsed_dict["availabilityStartTime"]
    buffer_depth = timedelta(seconds=parsed_dict.get("timeShiftBufferDepth", 60))
    window_size = math.ceil(buffer_depth.total_seconds() / segment_duration_sec)

    elapsed = (datetime.now(tz=timezone.utc) - availability_start).total_seconds()
    live_edge_number = start_number + math.floor(elapsed / segment_duration_sec)
    first_number = max(live_edge_number - window_size, start_number)

    segments = []
    for number in range(first_number, first_number + window_size):
        offset = (number - start_number) * segment_duration_sec
        segments.append(
            {
                "number": number,
                "start_time": availability_start + timedelta(seconds=offset),
                "duration": segment_duration_sec,
            }
        )
    return segments
433
+
434
+
435
def generate_vod_segments(profile: dict, duration: int, timescale: int, start_number: int) -> List[Dict]:
    """
    Generate the full segment list for a static (VOD) MPD manifest.

    Args:
        profile (dict): The profile information; ``mediaPresentationDuration``
            may be a number of seconds or an ISO 8601 duration string.
        duration (int): The per-segment duration in timescale ticks.
        timescale (int): Ticks per second.
        start_number (int): The starting segment number.

    Returns:
        List[Dict]: The list of generated VOD segments.
    """
    total_seconds = profile.get("mediaPresentationDuration") or 0
    if isinstance(total_seconds, str):
        # ISO 8601 durations are converted to seconds before use.
        total_seconds = parse_duration(total_seconds)

    count = math.ceil(total_seconds * timescale / duration)
    seconds_per_segment = duration / timescale
    return [{"number": start_number + offset, "duration": seconds_per_segment} for offset in range(count)]
455
+
456
+
457
+ def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, timescale: int | None = None) -> Dict:
458
+ """
459
+ Creates segment data based on the segment information. This includes the segment URL and metadata.
460
+
461
+ Args:
462
+ segment (Dict): The segment information.
463
+ item (dict): The segment template data.
464
+ profile (dict): The profile information.
465
+ source (str): The source URL.
466
+ timescale (int, optional): The timescale for the segments. Defaults to None.
467
+
468
+ Returns:
469
+ Dict: The created segment data.
470
+ """
471
+ media_template = item["@media"]
472
+ media = media_template.replace("$RepresentationID$", profile["id"])
473
+ media = media.replace("$Number%04d$", f"{segment['number']:04d}")
474
+ media = media.replace("$Number$", str(segment["number"]))
475
+ media = media.replace("$Bandwidth$", str(profile["bandwidth"]))
476
+
477
+ if "time" in segment and timescale is not None:
478
+ media = media.replace("$Time$", str(int(segment["time"] * timescale)))
479
+
480
+ if not media.startswith("http"):
481
+ media = f"{source}/{media}"
482
+
483
+ segment_data = {
484
+ "type": "segment",
485
+ "media": media,
486
+ "number": segment["number"],
487
+ }
488
+
489
+ if "start_time" in segment and "end_time" in segment:
490
+ segment_data.update(
491
+ {
492
+ "start_time": segment["start_time"],
493
+ "end_time": segment["end_time"],
494
+ "extinf": (segment["end_time"] - segment["start_time"]).total_seconds(),
495
+ "program_date_time": segment["start_time"].isoformat() + "Z",
496
+ }
497
+ )
498
+ elif "start_time" in segment and "duration" in segment:
499
+ duration = segment["duration"]
500
+ segment_data.update(
501
+ {
502
+ "start_time": segment["start_time"],
503
+ "end_time": segment["start_time"] + timedelta(seconds=duration),
504
+ "extinf": duration,
505
+ "program_date_time": segment["start_time"].isoformat() + "Z",
506
+ }
507
+ )
508
+ elif "duration" in segment:
509
+ segment_data["extinf"] = segment["duration"]
510
+
511
+ return segment_data
512
+
513
+
514
def parse_segment_base(representation: dict, source: str) -> List[Dict]:
    """
    Parses segment base information and extracts segment data. This is used for single-segment representations.

    Args:
        representation (dict): The representation data (SegmentBase + BaseURL).
        source (str): The source URL used to resolve a relative BaseURL.

    Returns:
        List[Dict]: The list of parsed segments (a single ranged segment).
    """
    segment = representation["SegmentBase"]
    # The index range bounds the byte span requested up-front.
    start, end = map(int, segment["@indexRange"].split("-"))
    if "Initialization" in segment:
        # When an explicit initialization range exists, the span starts there instead.
        start, _ = map(int, segment["Initialization"]["@range"].split("-"))

    media = representation["BaseURL"]
    # BaseURL may already be absolute; only prefix relative URLs
    # (consistent with the other segment builders in this module).
    if not media.startswith("http"):
        media = f"{source}/{media}"

    return [
        {
            "type": "segment",
            "range": f"{start}-{end}",
            "media": media,
        }
    ]
537
+
538
+
539
def parse_duration(duration_str: str) -> float:
    """
    Parses an ISO 8601 duration string into seconds.

    Supports years, months, weeks, days, hours, minutes and (fractional)
    seconds. Years and months use the calendar-free approximations of
    365 and 30 days. The whole string must be a valid duration: trailing
    garbage and time components without the 'T' designator are rejected
    (previously they were silently accepted).

    Args:
        duration_str (str): The duration string to parse (e.g. "PT1H30M").

    Returns:
        float: The parsed duration in seconds.

    Raises:
        ValueError: If the string is not a valid ISO 8601 duration.
    """
    pattern = re.compile(
        r"P(?:(\d+)Y)?(?:(\d+)M)?(?:(\d+)W)?(?:(\d+)D)?"
        r"(?:T(?:(\d+)H)?(?:(\d+)M)?(?:(\d+(?:\.\d+)?)S)?)?"
    )
    match = pattern.fullmatch(duration_str)
    if not match:
        raise ValueError(f"Invalid duration format: {duration_str}")

    years, months, weeks, days, hours, minutes, seconds = (float(g) if g else 0.0 for g in match.groups())
    return (
        years * 365 * 24 * 3600
        + months * 30 * 24 * 3600
        + weeks * 7 * 24 * 3600
        + days * 24 * 3600
        + hours * 3600
        + minutes * 60
        + seconds
    )