Spaces:
Paused
Paused
Upload 16 files
Browse files- mediaflow_proxy/__init__.py +0 -0
- mediaflow_proxy/configs.py +14 -0
- mediaflow_proxy/const.py +24 -0
- mediaflow_proxy/drm/__init__.py +11 -0
- mediaflow_proxy/drm/decrypter.py +778 -0
- mediaflow_proxy/handlers.py +345 -0
- mediaflow_proxy/main.py +58 -0
- mediaflow_proxy/mpd_processor.py +210 -0
- mediaflow_proxy/routes.py +147 -0
- mediaflow_proxy/static/index.html +76 -0
- mediaflow_proxy/static/logo.png +0 -0
- mediaflow_proxy/utils/__init__.py +0 -0
- mediaflow_proxy/utils/cache_utils.py +60 -0
- mediaflow_proxy/utils/http_utils.py +355 -0
- mediaflow_proxy/utils/m3u8_processor.py +83 -0
- mediaflow_proxy/utils/mpd_utils.py +555 -0
mediaflow_proxy/__init__.py
ADDED
|
File without changes
|
mediaflow_proxy/configs.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    """Application configuration, loaded from environment variables or a .env file."""

    api_password: str  # The password for accessing the API endpoints.
    proxy_url: str | None = None  # The URL of the proxy server to route requests through.
    mpd_live_stream_delay: int = 30  # The delay in seconds for live MPD streams.

    class Config:
        # Pydantic settings source: read values from .env, ignore unknown keys.
        env_file = ".env"
        extra = "ignore"


# Singleton settings instance shared by the rest of the application.
settings = Settings()
|
mediaflow_proxy/const.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HTTP response headers that are forwarded from the upstream server to the client.
SUPPORTED_RESPONSE_HEADERS = [
    "accept-ranges",
    "content-type",
    "content-length",
    "content-range",
    "connection",
    "transfer-encoding",
    "last-modified",
    "etag",
    "cache-control",
    "expires",
]

# HTTP request headers that are forwarded from the client to the upstream server.
SUPPORTED_REQUEST_HEADERS = [
    "accept",
    "accept-encoding",
    "accept-language",
    "connection",
    "range",
    "if-range",
    "user-agent",
    "referer",
    "origin",
]
|
mediaflow_proxy/drm/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import tempfile


async def create_temp_file(suffix: str, content: bytes | None = None, prefix: str | None = None) -> tempfile.NamedTemporaryFile:
    """Create a closed, named temporary file on disk and return its handle.

    The file is created with ``delete=False`` so it survives ``close()``;
    callers remove it later via the attached ``delete_file`` callable.

    Args:
        suffix: Filename suffix (e.g. ".mp4").
        content: Optional bytes written to the file before it is closed.
        prefix: Optional filename prefix.

    Returns:
        The closed NamedTemporaryFile wrapper, augmented with a
        ``delete_file`` attribute that unlinks the file from disk.

    NOTE(review): declared ``async`` but performs only blocking file I/O
    and never awaits — presumably kept async for call-site uniformity.
    """
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix, prefix=prefix)
    # Attach a cleanup helper so callers can remove the on-disk file later.
    temp_file.delete_file = lambda: os.unlink(temp_file.name)
    if content:
        temp_file.write(content)
    temp_file.close()
    return temp_file
|
mediaflow_proxy/drm/decrypter.py
ADDED
|
@@ -0,0 +1,778 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import struct
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
from Crypto.Cipher import AES
|
| 6 |
+
from collections import namedtuple
|
| 7 |
+
import array
|
| 8 |
+
|
| 9 |
+
# Per-sample CENC auxiliary info: encryption flag, initialization vector,
# and the list of (clear_bytes, encrypted_bytes) subsample pairs.
CENCSampleAuxiliaryDataFormat = namedtuple("CENCSampleAuxiliaryDataFormat", ["is_encrypted", "iv", "sub_samples"])
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MP4Atom:
|
| 13 |
+
"""
|
| 14 |
+
Represents an MP4 atom, which is a basic unit of data in an MP4 file.
|
| 15 |
+
Each atom contains a header (size and type) and data.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
__slots__ = ("atom_type", "size", "data")
|
| 19 |
+
|
| 20 |
+
def __init__(self, atom_type: bytes, size: int, data: memoryview | bytearray):
|
| 21 |
+
"""
|
| 22 |
+
Initializes an MP4Atom instance.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
atom_type (bytes): The type of the atom.
|
| 26 |
+
size (int): The size of the atom.
|
| 27 |
+
data (memoryview | bytearray): The data contained in the atom.
|
| 28 |
+
"""
|
| 29 |
+
self.atom_type = atom_type
|
| 30 |
+
self.size = size
|
| 31 |
+
self.data = data
|
| 32 |
+
|
| 33 |
+
def __repr__(self):
|
| 34 |
+
return f"<MP4Atom type={self.atom_type}, size={self.size}>"
|
| 35 |
+
|
| 36 |
+
def pack(self):
|
| 37 |
+
"""
|
| 38 |
+
Packs the atom into binary data.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
bytes: Packed binary data with size, type, and data.
|
| 42 |
+
"""
|
| 43 |
+
return struct.pack(">I", self.size) + self.atom_type + self.data
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class MP4Parser:
    """
    Sequential reader for MP4 box ("atom") structures over a memoryview.

    Maintains a cursor (``position``) into ``data`` and decodes the standard
    8-byte box header (32-bit big-endian size + 4-byte type), including the
    64-bit "largesize" extension used when size == 1.
    """

    def __init__(self, data: memoryview):
        """
        Initializes an MP4Parser instance.

        Args:
            data (memoryview): The binary data of the MP4 file (or one atom's payload).
        """
        self.data = data
        self.position = 0

    def read_atom(self) -> MP4Atom | None:
        """
        Reads the next atom at the cursor and advances past it.

        Returns:
            MP4Atom | None: MP4Atom object, or None if no more atoms are
            available or the declared size would run past the buffer.
        """
        pos = self.position
        if pos + 8 > len(self.data):
            return None

        size, atom_type = struct.unpack_from(">I4s", self.data, pos)
        pos += 8

        if size == 1:
            # 64-bit "largesize" variant: the real size follows the 8-byte header.
            if pos + 8 > len(self.data):
                return None
            size = struct.unpack_from(">Q", self.data, pos)[0]
            pos += 8

        # Reject corrupt sizes (smaller than the header) or truncated payloads.
        if size < 8 or pos + size - 8 > len(self.data):
            return None

        # Payload excludes the 8-byte (size, type) header.
        # NOTE(review): in the largesize branch 16 header bytes were consumed,
        # yet the payload length is still computed as size - 8 — looks like an
        # 8-byte overshoot for 64-bit boxes; confirm against real inputs.
        atom_data = self.data[pos : pos + size - 8]
        self.position = pos + size - 8
        return MP4Atom(atom_type, size, atom_data)

    def list_atoms(self) -> list[MP4Atom]:
        """
        Lists all top-level atoms without disturbing the caller's cursor.

        Returns:
            list[MP4Atom]: List of MP4Atom objects.
        """
        atoms = []
        # Save/restore the cursor so scanning is side-effect free.
        original_position = self.position
        self.position = 0
        while self.position + 8 <= len(self.data):
            atom = self.read_atom()
            if not atom:
                break
            atoms.append(atom)
        self.position = original_position
        return atoms

    def _read_atom_at(self, pos: int, end: int) -> MP4Atom | None:
        # Cursor-free variant of read_atom: decode the atom at ``pos``,
        # bounded by ``end``, without touching ``self.position``.
        if pos + 8 > end:
            return None

        size, atom_type = struct.unpack_from(">I4s", self.data, pos)
        pos += 8

        if size == 1:
            if pos + 8 > end:
                return None
            size = struct.unpack_from(">Q", self.data, pos)[0]
            pos += 8

        if size < 8 or pos + size - 8 > end:
            return None

        atom_data = self.data[pos : pos + size - 8]
        return MP4Atom(atom_type, size, atom_data)

    def print_atoms_structure(self, indent: int = 0):
        """
        Prints the structure of all top-level atoms (debugging aid).

        Args:
            indent (int): The indentation level for printing.
        """
        pos = 0
        end = len(self.data)
        while pos + 8 <= end:
            atom = self._read_atom_at(pos, end)
            if not atom:
                break
            self.print_single_atom_structure(atom, pos, indent)
            pos += atom.size

    def print_single_atom_structure(self, atom: MP4Atom, parent_position: int, indent: int):
        """
        Prints one atom's type and size, then recurses into its children.

        Args:
            atom (MP4Atom): The atom to print.
            parent_position (int): The position of the parent atom in ``self.data``.
            indent (int): The indentation level for printing.
        """
        try:
            atom_type = atom.atom_type.decode("utf-8")
        except UnicodeDecodeError:
            # Non-ASCII type tags (e.g. the 0xA9 "©xyz" family) print as repr.
            atom_type = repr(atom.atom_type)
        print(" " * indent + f"Type: {atom_type}, Size: {atom.size}")

        child_pos = 0
        child_end = len(atom.data)
        # NOTE(review): children are re-read from self.data at offsets relative
        # to parent_position, but the recursion passes the same parent_position
        # down unchanged — grandchild offsets look wrong for nested containers;
        # verify with a deeply nested file before relying on this output.
        while child_pos + 8 <= child_end:
            child_atom = self._read_atom_at(parent_position + 8 + child_pos, parent_position + 8 + child_end)
            if not child_atom:
                break
            self.print_single_atom_structure(child_atom, parent_position, indent + 2)
            child_pos += child_atom.size
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class MP4Decrypter:
|
| 167 |
+
"""
|
| 168 |
+
Class to handle the decryption of CENC encrypted MP4 segments.
|
| 169 |
+
|
| 170 |
+
Attributes:
|
| 171 |
+
key_map (dict[bytes, bytes]): Mapping of track IDs to decryption keys.
|
| 172 |
+
current_key (bytes | None): Current decryption key.
|
| 173 |
+
trun_sample_sizes (array.array): Array of sample sizes from the 'trun' box.
|
| 174 |
+
current_sample_info (list): List of sample information from the 'senc' box.
|
| 175 |
+
encryption_overhead (int): Total size of encryption-related boxes.
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
    def __init__(self, key_map: dict[bytes, bytes]):
        """
        Initializes the MP4Decrypter with a key map.

        Args:
            key_map (dict[bytes, bytes]): Mapping of 4-byte big-endian track
                IDs to AES decryption keys.
        """
        self.key_map = key_map
        # Key selected for the track fragment currently being processed.
        self.current_key = None
        # Per-sample sizes harvested from the most recent 'trun' box.
        self.trun_sample_sizes = array.array("I")
        # Per-sample encryption info (IV + subsamples) from the latest 'senc' box.
        self.current_sample_info = []
        # Total bytes of stripped senc/saiz/saio boxes; used to fix up offsets.
        self.encryption_overhead = 0
|
| 190 |
+
|
| 191 |
+
def decrypt_segment(self, combined_segment: bytes) -> bytes:
|
| 192 |
+
"""
|
| 193 |
+
Decrypts a combined MP4 segment.
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
combined_segment (bytes): Combined initialization and media segment.
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
bytes: Decrypted segment content.
|
| 200 |
+
"""
|
| 201 |
+
data = memoryview(combined_segment)
|
| 202 |
+
parser = MP4Parser(data)
|
| 203 |
+
atoms = parser.list_atoms()
|
| 204 |
+
|
| 205 |
+
atom_process_order = [b"moov", b"moof", b"sidx", b"mdat"]
|
| 206 |
+
|
| 207 |
+
processed_atoms = {}
|
| 208 |
+
for atom_type in atom_process_order:
|
| 209 |
+
if atom := next((a for a in atoms if a.atom_type == atom_type), None):
|
| 210 |
+
processed_atoms[atom_type] = self._process_atom(atom_type, atom)
|
| 211 |
+
|
| 212 |
+
result = bytearray()
|
| 213 |
+
for atom in atoms:
|
| 214 |
+
if atom.atom_type in processed_atoms:
|
| 215 |
+
processed_atom = processed_atoms[atom.atom_type]
|
| 216 |
+
result.extend(processed_atom.pack())
|
| 217 |
+
else:
|
| 218 |
+
result.extend(atom.pack())
|
| 219 |
+
|
| 220 |
+
return bytes(result)
|
| 221 |
+
|
| 222 |
+
def _process_atom(self, atom_type: bytes, atom: MP4Atom) -> MP4Atom:
|
| 223 |
+
"""
|
| 224 |
+
Processes an MP4 atom based on its type.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
atom_type (bytes): Type of the atom.
|
| 228 |
+
atom (MP4Atom): The atom to process.
|
| 229 |
+
|
| 230 |
+
Returns:
|
| 231 |
+
MP4Atom: Processed atom.
|
| 232 |
+
"""
|
| 233 |
+
match atom_type:
|
| 234 |
+
case b"moov":
|
| 235 |
+
return self._process_moov(atom)
|
| 236 |
+
case b"moof":
|
| 237 |
+
return self._process_moof(atom)
|
| 238 |
+
case b"sidx":
|
| 239 |
+
return self._process_sidx(atom)
|
| 240 |
+
case b"mdat":
|
| 241 |
+
return self._decrypt_mdat(atom)
|
| 242 |
+
case _:
|
| 243 |
+
return atom
|
| 244 |
+
|
| 245 |
+
def _process_moov(self, moov: MP4Atom) -> MP4Atom:
|
| 246 |
+
"""
|
| 247 |
+
Processes the 'moov' (Movie) atom, which contains metadata about the entire presentation.
|
| 248 |
+
This includes information about tracks, media data, and other movie-level metadata.
|
| 249 |
+
|
| 250 |
+
Args:
|
| 251 |
+
moov (MP4Atom): The 'moov' atom to process.
|
| 252 |
+
|
| 253 |
+
Returns:
|
| 254 |
+
MP4Atom: Processed 'moov' atom with updated track information.
|
| 255 |
+
"""
|
| 256 |
+
parser = MP4Parser(moov.data)
|
| 257 |
+
new_moov_data = bytearray()
|
| 258 |
+
|
| 259 |
+
for atom in iter(parser.read_atom, None):
|
| 260 |
+
if atom.atom_type == b"trak":
|
| 261 |
+
new_trak = self._process_trak(atom)
|
| 262 |
+
new_moov_data.extend(new_trak.pack())
|
| 263 |
+
elif atom.atom_type != b"pssh":
|
| 264 |
+
# Skip PSSH boxes as they are not needed in the decrypted output
|
| 265 |
+
new_moov_data.extend(atom.pack())
|
| 266 |
+
|
| 267 |
+
return MP4Atom(b"moov", len(new_moov_data) + 8, new_moov_data)
|
| 268 |
+
|
| 269 |
+
def _process_moof(self, moof: MP4Atom) -> MP4Atom:
|
| 270 |
+
"""
|
| 271 |
+
Processes the 'moov' (Movie) atom, which contains metadata about the entire presentation.
|
| 272 |
+
This includes information about tracks, media data, and other movie-level metadata.
|
| 273 |
+
|
| 274 |
+
Args:
|
| 275 |
+
moov (MP4Atom): The 'moov' atom to process.
|
| 276 |
+
|
| 277 |
+
Returns:
|
| 278 |
+
MP4Atom: Processed 'moov' atom with updated track information.
|
| 279 |
+
"""
|
| 280 |
+
parser = MP4Parser(moof.data)
|
| 281 |
+
new_moof_data = bytearray()
|
| 282 |
+
|
| 283 |
+
for atom in iter(parser.read_atom, None):
|
| 284 |
+
if atom.atom_type == b"traf":
|
| 285 |
+
new_traf = self._process_traf(atom)
|
| 286 |
+
new_moof_data.extend(new_traf.pack())
|
| 287 |
+
else:
|
| 288 |
+
new_moof_data.extend(atom.pack())
|
| 289 |
+
|
| 290 |
+
return MP4Atom(b"moof", len(new_moof_data) + 8, new_moof_data)
|
| 291 |
+
|
| 292 |
+
    def _process_traf(self, traf: MP4Atom) -> MP4Atom:
        """
        Processes the 'traf' (Track Fragment) atom, which contains information about a track fragment.

        Side effects: selects ``self.current_key`` from the tfhd track ID,
        stores parsed 'senc' data in ``self.current_sample_info``, and records
        ``self.encryption_overhead`` for later offset fix-ups. The returned
        'traf' omits the senc/saiz/saio encryption boxes.

        Args:
            traf (MP4Atom): The 'traf' atom to process.

        Returns:
            MP4Atom: Processed 'traf' atom with encryption boxes stripped.
        """
        parser = MP4Parser(traf.data)
        new_traf_data = bytearray()
        tfhd = None
        sample_count = 0
        sample_info = []

        atoms = parser.list_atoms()

        # calculate encryption_overhead earlier to avoid dependency on trun
        self.encryption_overhead = sum(a.size for a in atoms if a.atom_type in {b"senc", b"saiz", b"saio"})

        for atom in atoms:
            if atom.atom_type == b"tfhd":
                tfhd = atom
                new_traf_data.extend(atom.pack())
            elif atom.atom_type == b"trun":
                # _process_trun records sample sizes; _modify_trun fixes data_offset.
                sample_count = self._process_trun(atom)
                new_trun = self._modify_trun(atom)
                new_traf_data.extend(new_trun.pack())
            elif atom.atom_type == b"senc":
                # Parse senc but don't include it in the new decrypted traf data and similarly don't include saiz and saio
                sample_info = self._parse_senc(atom, sample_count)
            elif atom.atom_type not in {b"saiz", b"saio"}:
                new_traf_data.extend(atom.pack())

        if tfhd:
            # track_ID sits at offset 4 of the tfhd payload (after version/flags).
            tfhd_track_id = struct.unpack_from(">I", tfhd.data, 4)[0]
            self.current_key = self._get_key_for_track(tfhd_track_id)
            self.current_sample_info = sample_info

        return MP4Atom(b"traf", len(new_traf_data) + 8, new_traf_data)
|
| 334 |
+
|
| 335 |
+
    def _decrypt_mdat(self, mdat: MP4Atom) -> MP4Atom:
        """
        Decrypts the 'mdat' (Media Data) atom, which contains the actual media data (audio, video, etc.).

        Relies on state set by the preceding 'traf' processing:
        ``self.current_key``, ``self.current_sample_info`` (from 'senc'),
        and ``self.trun_sample_sizes`` (from 'trun').

        Args:
            mdat (MP4Atom): The 'mdat' atom to decrypt.

        Returns:
            MP4Atom: Decrypted 'mdat' atom (the original atom when no
            decryption state is available).
        """
        if not self.current_key or not self.current_sample_info:
            return mdat  # Return original mdat if we don't have decryption info

        decrypted_samples = bytearray()
        mdat_data = mdat.data
        position = 0

        for i, info in enumerate(self.current_sample_info):
            if position >= len(mdat_data):
                break  # No more data to process

            # Fall back to "rest of the mdat" when trun carried no size for this sample.
            sample_size = self.trun_sample_sizes[i] if i < len(self.trun_sample_sizes) else len(mdat_data) - position
            sample = mdat_data[position : position + sample_size]
            position += sample_size
            decrypted_sample = self._process_sample(sample, info, self.current_key)
            decrypted_samples.extend(decrypted_sample)

        return MP4Atom(b"mdat", len(decrypted_samples) + 8, decrypted_samples)
|
| 364 |
+
|
| 365 |
+
    def _parse_senc(self, senc: MP4Atom, sample_count: int) -> list[CENCSampleAuxiliaryDataFormat]:
        """
        Parses the 'senc' (Sample Encryption) atom, which contains encryption information for samples.

        This includes initialization vectors (IVs) and sub-sample encryption data.

        Args:
            senc (MP4Atom): The 'senc' atom to parse.
            sample_count (int): The number of samples (used when the box's
                version does not carry its own count).

        Returns:
            list[CENCSampleAuxiliaryDataFormat]: List of sample auxiliary data formats with encryption information.
        """
        data = memoryview(senc.data)
        version_flags = struct.unpack_from(">I", data, 0)[0]
        # Top byte is the box version; low 24 bits are the flags.
        version, flags = version_flags >> 24, version_flags & 0xFFFFFF
        position = 4

        if version == 0:
            # Version 0 carries its own sample count; otherwise trust the caller's.
            sample_count = struct.unpack_from(">I", data, position)[0]
            position += 4

        sample_info = []
        for _ in range(sample_count):
            if position + 8 > len(data):
                break

            # NOTE(review): assumes 8-byte per-sample IVs; CENC also allows
            # 16-byte IVs (Per_Sample_IV_Size from tenc/saiz) — confirm inputs.
            iv = data[position : position + 8].tobytes()
            position += 8

            sub_samples = []
            if flags & 0x000002 and position + 2 <= len(data):  # Check if subsample information is present
                subsample_count = struct.unpack_from(">H", data, position)[0]
                position += 2

                for _ in range(subsample_count):
                    if position + 6 <= len(data):
                        # Each entry: 16-bit clear byte count + 32-bit encrypted byte count.
                        clear_bytes, encrypted_bytes = struct.unpack_from(">HI", data, position)
                        position += 6
                        sub_samples.append((clear_bytes, encrypted_bytes))
                    else:
                        break

            sample_info.append(CENCSampleAuxiliaryDataFormat(True, iv, sub_samples))

        return sample_info
|
| 410 |
+
|
| 411 |
+
def _get_key_for_track(self, track_id: int) -> bytes:
|
| 412 |
+
"""
|
| 413 |
+
Retrieves the decryption key for a given track ID from the key map.
|
| 414 |
+
|
| 415 |
+
Args:
|
| 416 |
+
track_id (int): The track ID.
|
| 417 |
+
|
| 418 |
+
Returns:
|
| 419 |
+
bytes: The decryption key for the specified track ID.
|
| 420 |
+
"""
|
| 421 |
+
if len(self.key_map) == 1:
|
| 422 |
+
return next(iter(self.key_map.values()))
|
| 423 |
+
key = self.key_map.get(track_id.pack(4, "big"))
|
| 424 |
+
if not key:
|
| 425 |
+
raise ValueError(f"No key found for track ID {track_id}")
|
| 426 |
+
return key
|
| 427 |
+
|
| 428 |
+
    @staticmethod
    def _process_sample(
        sample: memoryview, sample_info: CENCSampleAuxiliaryDataFormat, key: bytes
    ) -> memoryview | bytearray | bytes:
        """
        Processes and decrypts a sample using the provided sample information and decryption key.

        Uses AES-CTR with the per-sample IV as the initial counter block.
        For subsample encryption, the clear runs are copied through and the
        encrypted runs continue a single CTR keystream across the sample.

        Args:
            sample (memoryview): The sample data.
            sample_info (CENCSampleAuxiliaryDataFormat): The sample auxiliary data format with encryption information.
            key (bytes): The decryption key.

        Returns:
            memoryview | bytearray | bytes: The decrypted sample.
        """
        if not sample_info.is_encrypted:
            return sample

        # pad IV to 16 bytes
        iv = sample_info.iv + b"\x00" * (16 - len(sample_info.iv))
        # nonce=b"" so the full 16-byte counter block comes from initial_value.
        cipher = AES.new(key, AES.MODE_CTR, initial_value=iv, nonce=b"")

        if not sample_info.sub_samples:
            # If there are no sub_samples, decrypt the entire sample
            return cipher.decrypt(sample)

        result = bytearray()
        offset = 0
        for clear_bytes, encrypted_bytes in sample_info.sub_samples:
            # Clear portion passes through untouched...
            result.extend(sample[offset : offset + clear_bytes])
            offset += clear_bytes
            # ...then the encrypted run continues the same CTR keystream.
            result.extend(cipher.decrypt(sample[offset : offset + encrypted_bytes]))
            offset += encrypted_bytes

        # If there's any remaining data, treat it as encrypted
        if offset < len(sample):
            result.extend(cipher.decrypt(sample[offset:]))

        return result
|
| 468 |
+
|
| 469 |
+
    def _process_trun(self, trun: MP4Atom) -> int:
        """
        Processes the 'trun' (Track Fragment Run) atom, which contains information about the samples in a track fragment.

        Walks the per-sample records, collecting sample sizes into
        ``self.trun_sample_sizes`` (0 when the size-present flag is unset).

        Args:
            trun (MP4Atom): The 'trun' atom to process.

        Returns:
            int: The number of samples in the 'trun' atom.
        """
        # First word is version<<24 | flags; only low flag bits are tested
        # below, so the version byte in trun_flags is harmless here.
        trun_flags, sample_count = struct.unpack_from(">II", trun.data, 0)
        data_offset = 8

        if trun_flags & 0x000001:  # data-offset-present flag
            data_offset += 4
        if trun_flags & 0x000004:  # first-sample-flags-present flag
            data_offset += 4

        # Reset sizes for this run; _decrypt_mdat consumes them by index.
        self.trun_sample_sizes = array.array("I")

        for _ in range(sample_count):
            if trun_flags & 0x000100:  # sample-duration-present flag
                data_offset += 4
            if trun_flags & 0x000200:  # sample-size-present flag
                sample_size = struct.unpack_from(">I", trun.data, data_offset)[0]
                self.trun_sample_sizes.append(sample_size)
                data_offset += 4
            else:
                self.trun_sample_sizes.append(0)  # Using 0 instead of None for uniformity in the array
            if trun_flags & 0x000400:  # sample-flags-present flag
                data_offset += 4
            if trun_flags & 0x000800:  # sample-composition-time-offsets-present flag
                data_offset += 4

        return sample_count
|
| 505 |
+
|
| 506 |
+
def _modify_trun(self, trun: MP4Atom) -> MP4Atom:
|
| 507 |
+
"""
|
| 508 |
+
Modifies the 'trun' (Track Fragment Run) atom to update the data offset.
|
| 509 |
+
This is necessary to account for the encryption overhead.
|
| 510 |
+
|
| 511 |
+
Args:
|
| 512 |
+
trun (MP4Atom): The 'trun' atom to modify.
|
| 513 |
+
|
| 514 |
+
Returns:
|
| 515 |
+
MP4Atom: Modified 'trun' atom with updated data offset.
|
| 516 |
+
"""
|
| 517 |
+
trun_data = bytearray(trun.data)
|
| 518 |
+
current_flags = struct.unpack_from(">I", trun_data, 0)[0] & 0xFFFFFF
|
| 519 |
+
|
| 520 |
+
# If the data-offset-present flag is set, update the data offset to account for encryption overhead
|
| 521 |
+
if current_flags & 0x000001:
|
| 522 |
+
current_data_offset = struct.unpack_from(">i", trun_data, 8)[0]
|
| 523 |
+
struct.pack_into(">i", trun_data, 8, current_data_offset - self.encryption_overhead)
|
| 524 |
+
|
| 525 |
+
return MP4Atom(b"trun", len(trun_data) + 8, trun_data)
|
| 526 |
+
|
| 527 |
+
    def _process_sidx(self, sidx: MP4Atom) -> MP4Atom:
        """
        Processes the 'sidx' (Segment Index) atom, which contains indexing information for media segments.

        Shrinks the first reference entry's referenced_size to account for
        the stripped encryption boxes so the index stays consistent.

        Args:
            sidx (MP4Atom): The 'sidx' atom to process.

        Returns:
            MP4Atom: Processed 'sidx' atom with updated segment references.
        """
        sidx_data = bytearray(sidx.data)

        # First reference entry: top bit = reference_type, low 31 bits = referenced_size.
        # NOTE(review): offset 32 matches a version-1 sidx (64-bit
        # earliest_presentation_time / first_offset); a version-0 box would
        # have its first entry at offset 24 — confirm inputs are version 1.
        current_size = struct.unpack_from(">I", sidx_data, 32)[0]
        reference_type = current_size >> 31
        current_referenced_size = current_size & 0x7FFFFFFF

        # Remove encryption overhead from referenced size
        new_referenced_size = current_referenced_size - self.encryption_overhead
        new_size = (reference_type << 31) | new_referenced_size
        struct.pack_into(">I", sidx_data, 32, new_size)

        return MP4Atom(b"sidx", len(sidx_data) + 8, sidx_data)
|
| 550 |
+
|
| 551 |
+
def _process_trak(self, trak: MP4Atom) -> MP4Atom:
|
| 552 |
+
"""
|
| 553 |
+
Processes the 'trak' (Track) atom, which contains information about a single track in the movie.
|
| 554 |
+
This includes track header, media information, and other track-level metadata.
|
| 555 |
+
|
| 556 |
+
Args:
|
| 557 |
+
trak (MP4Atom): The 'trak' atom to process.
|
| 558 |
+
|
| 559 |
+
Returns:
|
| 560 |
+
MP4Atom: Processed 'trak' atom with updated track information.
|
| 561 |
+
"""
|
| 562 |
+
parser = MP4Parser(trak.data)
|
| 563 |
+
new_trak_data = bytearray()
|
| 564 |
+
|
| 565 |
+
for atom in iter(parser.read_atom, None):
|
| 566 |
+
if atom.atom_type == b"mdia":
|
| 567 |
+
new_mdia = self._process_mdia(atom)
|
| 568 |
+
new_trak_data.extend(new_mdia.pack())
|
| 569 |
+
else:
|
| 570 |
+
new_trak_data.extend(atom.pack())
|
| 571 |
+
|
| 572 |
+
return MP4Atom(b"trak", len(new_trak_data) + 8, new_trak_data)
|
| 573 |
+
|
| 574 |
+
def _process_mdia(self, mdia: MP4Atom) -> MP4Atom:
|
| 575 |
+
"""
|
| 576 |
+
Processes the 'mdia' (Media) atom, which contains media information for a track.
|
| 577 |
+
This includes media header, handler reference, and media information container.
|
| 578 |
+
|
| 579 |
+
Args:
|
| 580 |
+
mdia (MP4Atom): The 'mdia' atom to process.
|
| 581 |
+
|
| 582 |
+
Returns:
|
| 583 |
+
MP4Atom: Processed 'mdia' atom with updated media information.
|
| 584 |
+
"""
|
| 585 |
+
parser = MP4Parser(mdia.data)
|
| 586 |
+
new_mdia_data = bytearray()
|
| 587 |
+
|
| 588 |
+
for atom in iter(parser.read_atom, None):
|
| 589 |
+
if atom.atom_type == b"minf":
|
| 590 |
+
new_minf = self._process_minf(atom)
|
| 591 |
+
new_mdia_data.extend(new_minf.pack())
|
| 592 |
+
else:
|
| 593 |
+
new_mdia_data.extend(atom.pack())
|
| 594 |
+
|
| 595 |
+
return MP4Atom(b"mdia", len(new_mdia_data) + 8, new_mdia_data)
|
| 596 |
+
|
| 597 |
+
def _process_minf(self, minf: MP4Atom) -> MP4Atom:
|
| 598 |
+
"""
|
| 599 |
+
Processes the 'minf' (Media Information) atom, which contains information about the media data in a track.
|
| 600 |
+
This includes data information, sample table, and other media-level metadata.
|
| 601 |
+
|
| 602 |
+
Args:
|
| 603 |
+
minf (MP4Atom): The 'minf' atom to process.
|
| 604 |
+
|
| 605 |
+
Returns:
|
| 606 |
+
MP4Atom: Processed 'minf' atom with updated media information.
|
| 607 |
+
"""
|
| 608 |
+
parser = MP4Parser(minf.data)
|
| 609 |
+
new_minf_data = bytearray()
|
| 610 |
+
|
| 611 |
+
for atom in iter(parser.read_atom, None):
|
| 612 |
+
if atom.atom_type == b"stbl":
|
| 613 |
+
new_stbl = self._process_stbl(atom)
|
| 614 |
+
new_minf_data.extend(new_stbl.pack())
|
| 615 |
+
else:
|
| 616 |
+
new_minf_data.extend(atom.pack())
|
| 617 |
+
|
| 618 |
+
return MP4Atom(b"minf", len(new_minf_data) + 8, new_minf_data)
|
| 619 |
+
|
| 620 |
+
def _process_stbl(self, stbl: MP4Atom) -> MP4Atom:
    """
    Rebuild the 'stbl' (Sample Table) atom of a track.

    The 'stsd' child (sample descriptions) is rewritten via _process_stsd;
    all other children (sample sizes, chunk offsets, timing tables, ...)
    are copied through untouched.

    Args:
        stbl (MP4Atom): The original 'stbl' atom.

    Returns:
        MP4Atom: A new 'stbl' atom with its 'stsd' child processed.
    """
    reader = MP4Parser(stbl.data)
    out = bytearray()

    while True:
        child = reader.read_atom()
        if child is None:
            break
        processed = self._process_stsd(child) if child.atom_type == b"stsd" else child
        out.extend(processed.pack())

    return MP4Atom(b"stbl", len(out) + 8, out)
|
| 642 |
+
|
| 643 |
+
def _process_stsd(self, stsd: MP4Atom) -> MP4Atom:
    """
    Rebuild the 'stsd' (Sample Description) atom.

    The 8-byte preamble (version/flags + entry_count) is preserved verbatim;
    each sample entry that follows is rewritten via _process_sample_entry so
    that encryption wrappers are removed and codec types restored.

    Args:
        stsd (MP4Atom): The original 'stsd' atom.

    Returns:
        MP4Atom: A new 'stsd' atom with processed sample entries.
    """
    reader = MP4Parser(stsd.data)
    # entry_count lives right after the 4-byte version/flags field.
    (num_entries,) = struct.unpack_from(">I", reader.data, 4)

    rebuilt = bytearray(stsd.data[:8])  # keep version/flags + entry_count as-is
    reader.position = 8  # skip past version_flags and entry_count

    for _ in range(num_entries):
        entry = reader.read_atom()
        if not entry:
            break
        rebuilt.extend(self._process_sample_entry(entry).pack())

    return MP4Atom(b"stsd", len(rebuilt) + 8, rebuilt)
|
| 669 |
+
|
| 670 |
+
def _process_sample_entry(self, entry: MP4Atom) -> MP4Atom:
    """
    Processes a sample entry atom, which contains information about a specific type of sample.

    Copies the entry's fixed-size leading fields verbatim, drops the
    encryption-related child atoms ('sinf', 'schi', 'tenc', 'schm'), and
    restores the original codec fourcc (taken from 'sinf'/'frma') as the
    entry's atom type — turning e.g. an 'encv' entry back into 'avc1'.

    Args:
        entry (MP4Atom): The sample entry atom to process.

    Returns:
        MP4Atom: Processed sample entry atom with updated information.
    """
    # Determine the size of fixed fields based on sample entry type.
    # NOTE(review): these sizes are described as including the 8-byte
    # size/type header, yet they are applied to entry.data — presumably
    # MP4Parser strips the header but keeps it counted here; confirm
    # against MP4Parser.read_atom.
    if entry.atom_type in {b"mp4a", b"enca"}:
        fixed_size = 28  # 8 bytes for size, type and reserved, 20 bytes for fixed fields in Audio Sample Entry.
    elif entry.atom_type in {b"mp4v", b"encv", b"avc1", b"hev1", b"hvc1"}:
        fixed_size = 78  # 8 bytes for size, type and reserved, 70 bytes for fixed fields in Video Sample Entry.
    else:
        fixed_size = 16  # 8 bytes for size, type and reserved, 8 bytes for fixed fields in other Sample Entries.

    # Keep the fixed fields unchanged; only child atoms are filtered below.
    new_entry_data = bytearray(entry.data[:fixed_size])
    parser = MP4Parser(entry.data[fixed_size:])
    codec_format = None  # original fourcc recovered from 'sinf'/'frma', if any

    for atom in iter(parser.read_atom, None):
        if atom.atom_type in {b"sinf", b"schi", b"tenc", b"schm"}:
            if atom.atom_type == b"sinf":
                # 'sinf' records the pre-encryption format; remember it so the
                # rebuilt entry advertises the real codec.
                codec_format = self._extract_codec_format(atom)
            continue  # Skip encryption-related atoms
        new_entry_data.extend(atom.pack())

    # Replace the atom type with the extracted codec format (fall back to the
    # original type when no 'frma' box was present).
    new_type = codec_format if codec_format else entry.atom_type
    return MP4Atom(new_type, len(new_entry_data) + 8, new_entry_data)
|
| 703 |
+
|
| 704 |
+
def _extract_codec_format(self, sinf: MP4Atom) -> bytes | None:
    """
    Extract the original codec fourcc from a 'sinf' (Protection Scheme
    Information) atom by locating its 'frma' (original format) child.

    Args:
        sinf (MP4Atom): The 'sinf' atom to search.

    Returns:
        bytes | None: The 'frma' payload (original codec format), or None
        when no 'frma' child exists.
    """
    reader = MP4Parser(sinf.data)
    return next(
        (child.data for child in iter(reader.read_atom, None) if child.atom_type == b"frma"),
        None,
    )
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
def decrypt_segment(init_segment: bytes, segment_content: bytes, key_id: str, key: str) -> bytes:
    """
    Decrypt a CENC-encrypted MP4 segment.

    Args:
        init_segment (bytes): Initialization segment data (may be empty when
            segment_content already contains it).
        segment_content (bytes): Encrypted segment content.
        key_id (str): Key ID in hexadecimal format.
        key (str): Key in hexadecimal format.

    Returns:
        bytes: The decrypted segment data.
    """
    decrypter = MP4Decrypter({bytes.fromhex(key_id): bytes.fromhex(key)})
    return decrypter.decrypt_segment(init_segment + segment_content)
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
def cli():
    """
    Command line interface for decrypting a CENC encrypted MP4 segment.

    Reads the input files named by the module-level `args` namespace
    (populated by argparse in the `__main__` guard), decrypts them with
    `decrypt_segment`, and writes the result to `args.output`. Exits with
    status 1 on missing inputs or any decryption error.
    """
    # Empty init segment is valid when a combined segment file is supplied.
    init_segment = b""

    if args.init and args.segment:
        with open(args.init, "rb") as f:
            init_segment = f.read()
        with open(args.segment, "rb") as f:
            segment_content = f.read()
    elif args.combined_segment:
        # Combined file already contains init + media data concatenated.
        with open(args.combined_segment, "rb") as f:
            segment_content = f.read()
    else:
        print("Usage: python mp4decrypt.py --help")
        sys.exit(1)

    try:
        decrypted_segment = decrypt_segment(init_segment, segment_content, args.key_id, args.key)
        print(f"Decrypted content size is {len(decrypted_segment)} bytes")
        with open(args.output, "wb") as f:
            f.write(decrypted_segment)
        print(f"Decrypted segment written to {args.output}")
    except Exception as e:
        # Broad catch is acceptable at this CLI boundary: report and exit non-zero.
        print(f"Error: {e}")
        sys.exit(1)
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
if __name__ == "__main__":
    # Build the CLI: either --init + --segment, or a single --combined_segment.
    arg_parser = argparse.ArgumentParser(description="Decrypts a MP4 init and media segment using CENC encryption.")
    arg_parser.add_argument("--init", help="Path to the init segment file", required=False)
    arg_parser.add_argument("--segment", help="Path to the media segment file", required=False)
    arg_parser.add_argument(
        "--combined_segment", help="Path to the combined init and media segment file", required=False
    )
    arg_parser.add_argument("--key_id", help="Key ID in hexadecimal format", required=True)
    arg_parser.add_argument("--key", help="Key in hexadecimal format", required=True)
    arg_parser.add_argument("--output", help="Path to the output file", required=True)
    # `args` is read as a module-level global by cli().
    args = arg_parser.parse_args()
    cli()
|
mediaflow_proxy/handlers.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
import httpx
|
| 5 |
+
from fastapi import Request, Response, HTTPException
|
| 6 |
+
from pydantic import HttpUrl
|
| 7 |
+
from starlette.background import BackgroundTask
|
| 8 |
+
|
| 9 |
+
from .configs import settings
|
| 10 |
+
from .const import SUPPORTED_RESPONSE_HEADERS
|
| 11 |
+
from .mpd_processor import process_manifest, process_playlist, process_segment
|
| 12 |
+
from .utils.cache_utils import get_cached_mpd, get_cached_init_segment
|
| 13 |
+
from .utils.http_utils import (
|
| 14 |
+
Streamer,
|
| 15 |
+
DownloadError,
|
| 16 |
+
download_file_with_retry,
|
| 17 |
+
request_with_retry,
|
| 18 |
+
EnhancedStreamingResponse,
|
| 19 |
+
)
|
| 20 |
+
from .utils.m3u8_processor import M3U8Processor
|
| 21 |
+
from .utils.mpd_utils import pad_base64
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
async def handle_hls_stream_proxy(
    request: Request, destination: str, headers: dict, key_url: HttpUrl = None, verify_ssl: bool = True
):
    """
    Handles the HLS stream proxy request, fetching and processing the m3u8 playlist or streaming the content.

    Args:
        request (Request): The incoming HTTP request.
        destination (str): The destination URL to fetch the content from.
        headers (dict): The headers to include in the request.
        key_url (str, optional): The HLS Key URL to replace the original key URL. Defaults to None.
        verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.

    Returns:
        Response: The HTTP response with the processed m3u8 playlist or streamed content.
    """
    # Dedicated client per request; closed either by the streaming response's
    # background task (success path) or explicitly in the except branches.
    client = httpx.AsyncClient(
        follow_redirects=True,
        timeout=httpx.Timeout(30.0),
        limits=httpx.Limits(max_keepalive_connections=10, max_connections=20),
        proxy=settings.proxy_url,
        verify=verify_ssl,
    )
    streamer = Streamer(client)
    try:
        # Fast path: URL extension already identifies an HLS playlist.
        if destination.endswith((".m3u", ".m3u8")):
            return await fetch_and_process_m3u8(streamer, destination, headers, request, key_url)

        # Otherwise probe with a HEAD-style request and sniff the content type.
        response = await streamer.head(destination, headers)
        if "mpegurl" in response.headers.get("content-type", "").lower():
            return await fetch_and_process_m3u8(streamer, destination, headers, request, key_url)

        # Default to a ranged request so upstream servers that require Range work.
        headers.update({"range": headers.get("range", "bytes=0-")})
        # clean up the headers to only include the necessary headers and remove acl headers
        response_headers = {k: v for k, v in response.headers.multi_items() if k in SUPPORTED_RESPONSE_HEADERS}

        # Content is re-streamed chunked, so advertise chunked transfer-encoding
        # (appending to any encoding the upstream already declared).
        if transfer_encoding := response_headers.get("transfer-encoding"):
            if "chunked" not in transfer_encoding:
                transfer_encoding += ", chunked"
        else:
            transfer_encoding = "chunked"
        response_headers["transfer-encoding"] = transfer_encoding

        return EnhancedStreamingResponse(
            streamer.stream_content(destination, headers),
            status_code=response.status_code,
            headers=response_headers,
            # Ensures the httpx client is closed once streaming finishes.
            background=BackgroundTask(streamer.close),
        )
    except httpx.HTTPStatusError as e:
        await client.aclose()
        logger.error(f"Upstream service error while handling request: {e}")
        return Response(status_code=e.response.status_code, content=f"Upstream service error: {e}")
    except DownloadError as e:
        await client.aclose()
        logger.error(f"Error downloading {destination}: {e}")
        return Response(status_code=e.status_code, content=str(e))
    except Exception as e:
        await client.aclose()
        logger.error(f"Internal server error while handling request: {e}")
        # NOTE(review): 502 is used here for generic internal failures (not 500),
        # matching handle_stream_request — presumably deliberate; confirm.
        return Response(status_code=502, content=f"Internal server error: {e}")
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
async def proxy_stream(method: str, video_url: str, headers: dict, verify_ssl: bool = True):
    """
    Proxy a stream request for the given video URL.

    Thin public wrapper that delegates entirely to handle_stream_request.

    Args:
        method (str): The HTTP method (e.g., GET, HEAD).
        video_url (str): The URL of the video to stream.
        headers (dict): The headers to include in the request.
        verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.

    Returns:
        Response: The HTTP response with the streamed content.
    """
    return await handle_stream_request(
        method=method,
        video_url=video_url,
        headers=headers,
        verify_ssl=verify_ssl,
    )
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
async def handle_stream_request(method: str, video_url: str, headers: dict, verify_ssl: bool = True):
    """
    Handles the stream request, fetching the content from the video URL and streaming it.

    Args:
        method (str): The HTTP method (e.g., GET, HEAD).
        video_url (str): The URL of the video to stream.
        headers (dict): The headers to include in the request.
        verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.

    Returns:
        Response: The HTTP response with the streamed content (or headers only for HEAD).
    """
    # Dedicated client per request; closed by the background task on the GET
    # path, explicitly on the HEAD path and in every error branch.
    client = httpx.AsyncClient(
        follow_redirects=True,
        timeout=httpx.Timeout(30.0),
        limits=httpx.Limits(max_keepalive_connections=10, max_connections=20),
        proxy=settings.proxy_url,
        verify=verify_ssl,
    )
    streamer = Streamer(client)
    try:
        response = await streamer.head(video_url, headers)
        # clean up the headers to only include the necessary headers and remove acl headers
        response_headers = {k: v for k, v in response.headers.multi_items() if k in SUPPORTED_RESPONSE_HEADERS}
        # Content is re-streamed chunked, so advertise chunked transfer-encoding
        # (appending to any encoding the upstream already declared).
        if transfer_encoding := response_headers.get("transfer-encoding"):
            if "chunked" not in transfer_encoding:
                transfer_encoding += ", chunked"
        else:
            transfer_encoding = "chunked"
        response_headers["transfer-encoding"] = transfer_encoding

        if method == "HEAD":
            # No body will be streamed: release the client now.
            await streamer.close()
            return Response(headers=response_headers, status_code=response.status_code)
        else:
            return EnhancedStreamingResponse(
                streamer.stream_content(video_url, headers),
                headers=response_headers,
                status_code=response.status_code,
                background=BackgroundTask(streamer.close),
            )
    except httpx.HTTPStatusError as e:
        await client.aclose()
        logger.error(f"Upstream service error while handling {method} request: {e}")
        return Response(status_code=e.response.status_code, content=f"Upstream service error: {e}")
    except DownloadError as e:
        await client.aclose()
        logger.error(f"Error downloading {video_url}: {e}")
        return Response(status_code=e.status_code, content=str(e))
    except Exception as e:
        await client.aclose()
        logger.error(f"Internal server error while handling {method} request: {e}")
        # NOTE(review): generic failures map to 502, mirroring handle_hls_stream_proxy.
        return Response(status_code=502, content=f"Internal server error: {e}")
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
async def fetch_and_process_m3u8(
    streamer: Streamer, url: str, headers: dict, request: Request, key_url: HttpUrl = None
):
    """
    Fetches and processes the m3u8 playlist, rewriting its URLs to point back
    through this proxy.

    Args:
        streamer (Streamer): The HTTP client to use for streaming.
        url (str): The URL of the m3u8 playlist.
        headers (dict): The headers to include in the request.
        request (Request): The incoming HTTP request.
        key_url (HttpUrl, optional): The HLS Key URL to replace the original key URL. Defaults to None.

    Returns:
        Response: The HTTP response with the processed m3u8 playlist.
    """
    try:
        content = await streamer.get_text(url, headers)
        processor = M3U8Processor(request, key_url)
        # Use the final (post-redirect) URL as the base for resolving relative URIs.
        processed_content = await processor.process_m3u8(content, str(streamer.response.url))
        return Response(
            content=processed_content,
            media_type="application/vnd.apple.mpegurl",
            headers={
                "Content-Disposition": "inline",
                # Playlists are generated content: range requests make no sense.
                "Accept-Ranges": "none",
            },
        )
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error while fetching m3u8: {e}")
        return Response(status_code=e.response.status_code, content=str(e))
    except DownloadError as e:
        logger.error(f"Error downloading m3u8: {url}")
        return Response(status_code=502, content=str(e))
    except Exception as e:
        logger.exception(f"Unexpected error while processing m3u8: {e}")
        return Response(status_code=502, content=str(e))
    finally:
        # The playlist is fully materialized above, so the client can always
        # be released here, success or failure.
        await streamer.close()
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
async def handle_drm_key_data(key_id, key, drm_info):
    """
    Resolve the DRM key ID and key, falling back to the MPD's DRM info.

    Args:
        key_id (str | None): The DRM key ID supplied by the caller, if any.
        key (str | None): The DRM key supplied by the caller, if any.
        drm_info (dict | None): The DRM information parsed from the MPD manifest.

    Returns:
        tuple: The resolved (key_id, key), or (None, None) for content that is
        explicitly not DRM protected.

    Raises:
        HTTPException: 400 when only a license URL is available (unsupported),
            or when the key material can neither be found nor was provided.
    """
    if drm_info and not drm_info.get("isDrmProtected"):
        # Explicitly unprotected content needs no keys at all.
        return None, None

    if not key_id or not key:
        # Fix: drm_info may be None (or empty) when the manifest carried no DRM
        # data; membership tests on None would raise TypeError instead of a
        # clean 400 response.
        drm_info = drm_info or {}
        if "keyId" in drm_info and "key" in drm_info:
            key_id = drm_info["keyId"]
            key = drm_info["key"]
        elif "laUrl" in drm_info and "keyId" in drm_info:
            raise HTTPException(status_code=400, detail="LA URL is not supported yet")
        else:
            raise HTTPException(
                status_code=400, detail="Unable to determine key_id and key, and they were not provided"
            )

    return key_id, key
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
async def get_manifest(
    request: Request, mpd_url: str, headers: dict, key_id: str = None, key: str = None, verify_ssl: bool = True
):
    """
    Retrieves and processes the MPD manifest, converting it to an HLS manifest.

    Args:
        request (Request): The incoming HTTP request.
        mpd_url (str): The URL of the MPD manifest.
        headers (dict): The headers to include in the request.
        key_id (str, optional): The DRM key ID (hex or base64url). Defaults to None.
        key (str, optional): The DRM key (hex or base64url). Defaults to None.
        verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.

    Returns:
        Response: The HTTP response with the HLS manifest.

    Raises:
        HTTPException: If the MPD download fails.
    """
    try:
        # DRM info only needs to be parsed out of the MPD when the caller did
        # not supply key material itself.
        mpd_dict = await get_cached_mpd(
            mpd_url, headers=headers, parse_drm=not key_id and not key, verify_ssl=verify_ssl
        )
    except DownloadError as e:
        raise HTTPException(status_code=e.status_code, detail=f"Failed to download MPD: {e.message}")
    drm_info = mpd_dict.get("drmInfo", {})

    if drm_info and not drm_info.get("isDrmProtected"):
        # For non-DRM protected MPD, we still create an HLS manifest
        return await process_manifest(request, mpd_dict, None, None)

    key_id, key = await handle_drm_key_data(key_id, key, drm_info)

    # check if the provided key_id and key are valid
    # Anything that is not 32 hex chars (16 bytes) is assumed to be
    # base64url-encoded and is normalized to hex.
    if key_id and len(key_id) != 32:
        key_id = base64.urlsafe_b64decode(pad_base64(key_id)).hex()
    if key and len(key) != 32:
        key = base64.urlsafe_b64decode(pad_base64(key)).hex()

    return await process_manifest(request, mpd_dict, key_id, key)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
async def get_playlist(
    request: Request,
    mpd_url: str,
    profile_id: str,
    headers: dict,
    key_id: str = None,
    key: str = None,
    verify_ssl: bool = True,
):
    """
    Fetch the (cached) MPD manifest and render an HLS media playlist for one profile.

    Args:
        request (Request): The incoming HTTP request.
        mpd_url (str): The URL of the MPD manifest.
        profile_id (str): The profile ID to generate the playlist for.
        headers (dict): The headers to include in the request.
        key_id (str, optional): The DRM key ID. Defaults to None.
        key (str, optional): The DRM key. Defaults to None.
        verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.

    Returns:
        Response: The HTTP response with the HLS playlist.
    """
    # DRM parsing is only needed when the caller provided no key material.
    needs_drm_parsing = not (key_id or key)
    mpd_dict = await get_cached_mpd(
        mpd_url,
        headers=headers,
        parse_drm=needs_drm_parsing,
        parse_segment_profile_id=profile_id,
        verify_ssl=verify_ssl,
    )
    return await process_playlist(request, mpd_dict, profile_id)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
async def get_segment(
    init_url: str,
    segment_url: str,
    mimetype: str,
    headers: dict,
    key_id: str = None,
    key: str = None,
    verify_ssl: bool = True,
):
    """
    Download a media segment (plus its cached init segment) and hand both to
    the segment processor, which decrypts when key material is present.

    Args:
        init_url (str): The URL of the initialization segment.
        segment_url (str): The URL of the media segment.
        mimetype (str): The MIME type of the segment.
        headers (dict): The headers to include in the request.
        key_id (str, optional): The DRM key ID. Defaults to None.
        key (str, optional): The DRM key. Defaults to None.
        verify_ssl (bool, optional): Whether to verify the SSL certificate of the destination. Defaults to True.

    Returns:
        Response: The HTTP response with the processed segment.

    Raises:
        HTTPException: If either download fails.
    """
    try:
        # Init segments rarely change, so they are served from cache;
        # the media segment itself is always fetched fresh.
        init_data = await get_cached_init_segment(init_url, headers, verify_ssl)
        media_data = await download_file_with_retry(segment_url, headers, verify_ssl=verify_ssl)
    except DownloadError as e:
        raise HTTPException(status_code=e.status_code, detail=f"Failed to download segment: {e.message}")
    return await process_segment(init_data, media_data, mimetype, key_id, key)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
async def get_public_ip():
    """
    Look up the public IP address of the MediaFlow proxy via api.ipify.org.

    Returns:
        dict: The parsed JSON payload containing the public IP address.
    """
    ipify_response = await request_with_retry("GET", "https://api.ipify.org?format=json", {})
    return ipify_response.json()
|
mediaflow_proxy/main.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from importlib import resources
|
| 3 |
+
|
| 4 |
+
from fastapi import FastAPI, Depends, Security, HTTPException
|
| 5 |
+
from fastapi.security import APIKeyQuery, APIKeyHeader
|
| 6 |
+
from starlette.responses import RedirectResponse
|
| 7 |
+
from starlette.staticfiles import StaticFiles
|
| 8 |
+
|
| 9 |
+
from mediaflow_proxy.configs import settings
|
| 10 |
+
from mediaflow_proxy.routes import proxy_router
|
| 11 |
+
|
| 12 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
| 13 |
+
app = FastAPI()
|
| 14 |
+
api_password_query = APIKeyQuery(name="api_password", auto_error=False)
|
| 15 |
+
api_password_header = APIKeyHeader(name="api_password", auto_error=False)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
async def verify_api_key(api_key: str = Security(api_password_query), api_key_alt: str = Security(api_password_header)):
    """
    Verifies the API key for the request, accepted either as the
    `api_password` query parameter or the `api_password` header.

    Args:
        api_key (str): The API key from the query parameter (None if absent).
        api_key_alt (str): The API key from the request header (None if absent).

    Raises:
        HTTPException: 403 if neither credential matches the configured password.
    """
    import secrets  # local import keeps the module's import block untouched

    expected = settings.api_password.encode()
    # Fix: use a constant-time comparison instead of `==` so response timing
    # does not leak information about the password (timing-attack hardening).
    for candidate in (api_key, api_key_alt):
        if candidate is not None and secrets.compare_digest(candidate.encode(), expected):
            return

    raise HTTPException(status_code=403, detail="Could not validate credentials")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@app.get("/health")
async def health_check():
    """Liveness probe endpoint: always reports healthy (no authentication required)."""
    return {"status": "healthy"}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@app.get("/favicon.ico")
async def get_favicon():
    """Redirect favicon requests to the logo served from the static mount."""
    return RedirectResponse(url="/logo.png")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# All /proxy endpoints require the API password (query param or header).
app.include_router(proxy_router, prefix="/proxy", tags=["proxy"], dependencies=[Depends(verify_api_key)])

# Serve the bundled static UI (index.html, logo.png) at the site root.
# Mounted last so the API routes above take precedence.
static_path = resources.files("mediaflow_proxy").joinpath("static")
app.mount("/", StaticFiles(directory=str(static_path), html=True), name="static")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def run():
    """Start a local development server (loopback only) on port 8888."""
    # Imported lazily so uvicorn is only required when running standalone.
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8888)


if __name__ == "__main__":
    run()
|
mediaflow_proxy/mpd_processor.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
import time
|
| 4 |
+
from datetime import datetime, timezone, timedelta
|
| 5 |
+
|
| 6 |
+
from fastapi import Request, Response, HTTPException
|
| 7 |
+
|
| 8 |
+
from mediaflow_proxy.configs import settings
|
| 9 |
+
from mediaflow_proxy.drm.decrypter import decrypt_segment
|
| 10 |
+
from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url, get_original_scheme
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
async def process_manifest(request: Request, mpd_dict: dict, key_id: str = None, key: str = None) -> Response:
    """
    Convert an MPD manifest into an HLS master manifest response.

    Args:
        request (Request): The incoming HTTP request.
        mpd_dict (dict): The MPD manifest data.
        key_id (str, optional): The DRM key ID. Defaults to None.
        key (str, optional): The DRM key. Defaults to None.

    Returns:
        Response: The HLS manifest as an HTTP response.
    """
    manifest = build_hls(mpd_dict, request, key_id, key)
    return Response(content=manifest, media_type="application/vnd.apple.mpegurl")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
async def process_playlist(request: Request, mpd_dict: dict, profile_id: str) -> Response:
    """
    Convert an MPD manifest into an HLS media playlist for one profile.

    Args:
        request (Request): The incoming HTTP request.
        mpd_dict (dict): The MPD manifest data.
        profile_id (str): The profile ID to generate the playlist for.

    Returns:
        Response: The HLS playlist as an HTTP response.

    Raises:
        HTTPException: 404 when no profile with the given ID exists in the MPD.
    """
    selected = [profile for profile in mpd_dict["profiles"] if profile["id"] == profile_id]
    if not selected:
        raise HTTPException(status_code=404, detail="Profile not found")

    playlist = build_hls_playlist(mpd_dict, selected, request)
    return Response(content=playlist, media_type="application/vnd.apple.mpegurl")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
async def process_segment(
    init_content: bytes,
    segment_content: bytes,
    mimetype: str,
    key_id: str = None,
    key: str = None,
) -> Response:
    """
    Assemble (and, for DRM-protected content, decrypt) a media segment.

    Args:
        init_content (bytes): The initialization segment content.
        segment_content (bytes): The media segment content.
        mimetype (str): The MIME type of the segment.
        key_id (str, optional): The DRM key ID. Defaults to None.
        key (str, optional): The DRM key. Defaults to None.

    Returns:
        Response: The (decrypted) segment as an HTTP response.
    """
    if key_id and key:
        # DRM-protected: decrypt using the provided Clear Key pair.
        started_at = time.time()
        content = decrypt_segment(init_content, segment_content, key_id, key)
        logger.info(f"Decryption of {mimetype} segment took {time.time() - started_at:.4f} seconds")
    else:
        # Unprotected content only needs the init segment prepended.
        content = init_content + segment_content

    return Response(content=content, media_type=mimetype)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def build_hls(mpd_dict: dict, request: Request, key_id: str = None, key: str = None) -> str:
    """
    Build an HLS master manifest string from a parsed MPD manifest.

    Args:
        mpd_dict (dict): The parsed MPD manifest data.
        request (Request): The incoming HTTP request.
        key_id (str, optional): The DRM key ID. Defaults to None.
        key (str, optional): The DRM key. Defaults to None.

    Returns:
        str: The HLS master manifest as a string.
    """
    lines = ["#EXTM3U", "#EXT-X-VERSION:6"]
    query_params = dict(request.query_params)

    video_profiles = {}
    audio_profiles = {}

    # Resolve the playlist endpoint URL, preserving the client-facing scheme
    # (the app may sit behind a TLS-terminating reverse proxy).
    playlist_base = request.url_for("playlist_endpoint")
    playlist_base = str(playlist_base.replace(scheme=get_original_scheme(request)))

    for profile in mpd_dict["profiles"]:
        query_params.update({"profile_id": profile["id"], "key_id": key_id or "", "key": key or ""})
        playlist_url = encode_mediaflow_proxy_url(playlist_base, query_params=query_params)

        mime = profile["mimeType"]
        if "video" in mime:
            video_profiles[profile["id"]] = (profile, playlist_url)
        elif "audio" in mime:
            audio_profiles[profile["id"]] = (profile, playlist_url)

    # Audio renditions: the first one is marked as the default track.
    for index, (profile, playlist_url) in enumerate(audio_profiles.values()):
        default_flag = "YES" if index == 0 else "NO"
        lines.append(
            f'#EXT-X-MEDIA:TYPE=AUDIO,GROUP-ID="audio",NAME="{profile["id"]}",DEFAULT={default_flag},AUTOSELECT={default_flag},LANGUAGE="{profile.get("lang", "und")}",URI="{playlist_url}"'
        )

    # Video variant streams, each referencing the shared "audio" group.
    for profile, playlist_url in video_profiles.values():
        lines.append(
            f'#EXT-X-STREAM-INF:BANDWIDTH={profile["bandwidth"]},RESOLUTION={profile["width"]}x{profile["height"]},CODECS="{profile["codecs"]}",FRAME-RATE={profile["frameRate"]},AUDIO="audio"'
        )
        lines.append(playlist_url)

    return "\n".join(lines)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def build_hls_playlist(mpd_dict: dict, profiles: list[dict], request: Request) -> str:
    """
    Build an HLS media playlist string from a parsed MPD manifest for the given profiles.

    Args:
        mpd_dict (dict): The parsed MPD manifest data.
        profiles (list[dict]): The profiles to include in the playlist.
        request (Request): The incoming HTTP request.

    Returns:
        str: The HLS playlist as a string.
    """
    lines = ["#EXTM3U", "#EXT-X-VERSION:6"]

    added_segments = 0
    # Live playlists are delayed: only segments that ended before this
    # cut-off are exposed to the player.
    target_end_time = datetime.now(timezone.utc) - timedelta(seconds=settings.mpd_live_stream_delay)

    segment_base = request.url_for("segment_endpoint")
    segment_base = str(segment_base.replace(scheme=get_original_scheme(request)))

    for index, profile in enumerate(profiles):
        segments = profile["segments"]
        if not segments:
            logger.warning(f"No segments found for profile {profile['id']}")
            continue

        # Playlist-level headers are emitted from the first profile only.
        if index == 0:
            extinf_values = [segment["extinf"] for segment in segments if "extinf" in segment]
            target_duration = math.ceil(max(extinf_values)) if extinf_values else 3
            lines.append(f"#EXT-X-TARGETDURATION:{target_duration}")
            lines.append(f"#EXT-X-MEDIA-SEQUENCE:{segments[0]['number']}")
            lines.append("#EXT-X-PLAYLIST-TYPE:EVENT" if mpd_dict["isLive"] else "#EXT-X-PLAYLIST-TYPE:VOD")

        init_url = profile["initUrl"]

        query_params = dict(request.query_params)
        query_params.pop("profile_id", None)
        query_params.pop("d", None)

        for segment in segments:
            if mpd_dict["isLive"]:
                if segment["end_time"] > target_end_time:
                    # Segment is still inside the live-delay window; skip it.
                    continue
                lines.append(f"#EXT-X-PROGRAM-DATE-TIME:{segment['program_date_time']}")
            lines.append(f'#EXTINF:{segment["extinf"]:.3f},')
            query_params.update(
                {"init_url": init_url, "segment_url": segment["media"], "mime_type": profile["mimeType"]}
            )
            lines.append(
                encode_mediaflow_proxy_url(
                    segment_base,
                    query_params=query_params,
                )
            )
            added_segments += 1

    if not mpd_dict["isLive"]:
        lines.append("#EXT-X-ENDLIST")

    logger.info(f"Added {added_segments} segments to HLS playlist")
    return "\n".join(lines)
|
mediaflow_proxy/routes.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import Request, Depends, APIRouter
|
| 2 |
+
from pydantic import HttpUrl
|
| 3 |
+
|
| 4 |
+
from .handlers import handle_hls_stream_proxy, proxy_stream, get_manifest, get_playlist, get_segment, get_public_ip
|
| 5 |
+
from .utils.http_utils import get_proxy_headers
|
| 6 |
+
|
| 7 |
+
proxy_router = APIRouter()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@proxy_router.head("/hls")
@proxy_router.get("/hls")
async def hls_stream_proxy(
    request: Request,
    d: HttpUrl,
    headers: dict = Depends(get_proxy_headers),
    key_url: HttpUrl | None = None,
    verify_ssl: bool = False,
):
    """
    Proxy an HLS stream: fetch and rewrite the m3u8 playlist, or stream the media content.

    Args:
        request (Request): The incoming HTTP request.
        d (HttpUrl): The destination URL to fetch the content from.
        headers (dict): The headers to include in the request.
        key_url (HttpUrl, optional): Replacement HLS key URL (useful for bypassing some sneaky protection). Defaults to None.
        verify_ssl (bool, optional): Whether to verify the destination's SSL certificate. Defaults to False.

    Returns:
        Response: The processed m3u8 playlist or streamed content.
    """
    return await handle_hls_stream_proxy(request, str(d), headers, key_url, verify_ssl)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@proxy_router.head("/stream")
@proxy_router.get("/stream")
async def proxy_stream_endpoint(
    request: Request, d: HttpUrl, headers: dict = Depends(get_proxy_headers), verify_ssl: bool = False
):
    """
    Proxy a plain HTTP(S) stream request to the given video URL.

    Args:
        request (Request): The incoming HTTP request.
        d (HttpUrl): The URL of the video to stream.
        headers (dict): The headers to include in the request.
        verify_ssl (bool, optional): Whether to verify the destination's SSL certificate. Defaults to False.

    Returns:
        Response: The streamed content.
    """
    # Default to a full-range request when the client did not send one.
    headers.setdefault("range", "bytes=0-")
    return await proxy_stream(request.method, str(d), headers, verify_ssl)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@proxy_router.get("/mpd/manifest")
async def manifest_endpoint(
    request: Request,
    d: HttpUrl,
    headers: dict = Depends(get_proxy_headers),
    key_id: str = None,
    key: str = None,
    verify_ssl: bool = False,
):
    """
    Fetch an MPD manifest and return it converted to an HLS master manifest.

    Args:
        request (Request): The incoming HTTP request.
        d (HttpUrl): The URL of the MPD manifest.
        headers (dict): The headers to include in the request.
        key_id (str, optional): The DRM key ID. Defaults to None.
        key (str, optional): The DRM key. Defaults to None.
        verify_ssl (bool, optional): Whether to verify the destination's SSL certificate. Defaults to False.

    Returns:
        Response: The HLS manifest.
    """
    return await get_manifest(request, str(d), headers, key_id, key, verify_ssl)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@proxy_router.get("/mpd/playlist")
async def playlist_endpoint(
    request: Request,
    d: HttpUrl,
    profile_id: str,
    headers: dict = Depends(get_proxy_headers),
    key_id: str = None,
    key: str = None,
    verify_ssl: bool = False,
):
    """
    Fetch an MPD manifest and return an HLS media playlist for one profile.

    Args:
        request (Request): The incoming HTTP request.
        d (HttpUrl): The URL of the MPD manifest.
        profile_id (str): The profile ID to generate the playlist for.
        headers (dict): The headers to include in the request.
        key_id (str, optional): The DRM key ID. Defaults to None.
        key (str, optional): The DRM key. Defaults to None.
        verify_ssl (bool, optional): Whether to verify the destination's SSL certificate. Defaults to False.

    Returns:
        Response: The HLS playlist.
    """
    return await get_playlist(request, str(d), profile_id, headers, key_id, key, verify_ssl)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
@proxy_router.get("/mpd/segment")
async def segment_endpoint(
    init_url: HttpUrl,
    segment_url: HttpUrl,
    mime_type: str,
    headers: dict = Depends(get_proxy_headers),
    key_id: str = None,
    key: str = None,
    verify_ssl: bool = False,
):
    """
    Fetch a media segment (plus its init segment) and return it, decrypted if DRM keys are given.

    Args:
        init_url (HttpUrl): The URL of the initialization segment.
        segment_url (HttpUrl): The URL of the media segment.
        mime_type (str): The MIME type of the segment.
        headers (dict): The headers to include in the request.
        key_id (str, optional): The DRM key ID. Defaults to None.
        key (str, optional): The DRM key. Defaults to None.
        verify_ssl (bool, optional): Whether to verify the destination's SSL certificate. Defaults to False.

    Returns:
        Response: The processed segment.
    """
    return await get_segment(str(init_url), str(segment_url), mime_type, headers, key_id, key, verify_ssl)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@proxy_router.get("/ip")
async def get_mediaflow_proxy_public_ip():
    """
    Return the public IP address of the MediaFlow proxy server.

    Returns:
        Response: JSON object of the form {"ip": "xxx.xxx.xxx.xxx"}.
    """
    return await get_public_ip()
|
mediaflow_proxy/static/index.html
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>MediaFlow Proxy</title>
    <!-- Fix: the favicon is a PNG, so declare image/png (was image/x-icon). -->
    <link rel="icon" href="/logo.png" type="image/png">
    <style>
        body {
            font-family: Arial, sans-serif;
            line-height: 1.6;
            color: #333;
            max-width: 800px;
            margin: 0 auto;
            padding: 20px;
            background-color: #f9f9f9;
        }

        header {
            background-color: #90aacc;
            color: #fff;
            padding: 10px 0;
            text-align: center;
        }

        header img {
            width: 200px;
            height: 200px;
            vertical-align: middle;
            border-radius: 15px;
        }

        header h1 {
            display: inline;
            margin-left: 20px;
            font-size: 36px;
        }

        .feature {
            background-color: #f4f4f4;
            border-left: 4px solid #3498db;
            padding: 10px;
            margin-bottom: 10px;
        }

        a {
            color: #3498db;
        }
    </style>
</head>
<body>
    <header>
        <img src="/logo.png" alt="MediaFlow Proxy Logo">
        <h1>MediaFlow Proxy</h1>
    </header>
    <p>A high-performance proxy server for streaming media, supporting HTTP(S), HLS, and MPEG-DASH with real-time DRM decryption.</p>

    <h2>Key Features</h2>
    <div class="feature">Convert MPEG-DASH streams (DRM-protected and non-protected) to HLS</div>
    <div class="feature">Support for Clear Key DRM-protected MPD DASH streams</div>
    <div class="feature">Handle both live and video-on-demand (VOD) DASH streams</div>
    <div class="feature">Proxy HTTP/HTTPS links with custom headers</div>
    <div class="feature">Proxy and modify HLS (M3U8) streams in real-time with custom headers and key URL modifications for bypassing some sneaky restrictions.</div>
    <div class="feature">Protect against unauthorized access and network bandwidth abuses</div>

    <h2>Getting Started</h2>
    <p>Visit the <a href="https://github.com/mhdzumair/mediaflow-proxy">GitHub repository</a> for installation instructions and documentation.</p>

    <h2>Premium Hosted Service</h2>
    <p>For a hassle-free experience, check out <a href="https://store.elfhosted.com/product/mediaflow-proxy">premium hosted service on ElfHosted</a>.</p>

    <h2>API Documentation</h2>
    <p>Explore the <a href="/docs">Swagger UI</a> for comprehensive details about the API endpoints and their usage.</p>

</body>
</html>
|
mediaflow_proxy/static/logo.png
ADDED
|
mediaflow_proxy/utils/__init__.py
ADDED
|
File without changes
|
mediaflow_proxy/utils/cache_utils.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
from cachetools import TTLCache
|
| 5 |
+
|
| 6 |
+
from .http_utils import download_file_with_retry
|
| 7 |
+
from .mpd_utils import parse_mpd, parse_mpd_dict
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
# cache dictionary
|
| 12 |
+
mpd_cache = TTLCache(maxsize=100, ttl=300) # 5 minutes default TTL
|
| 13 |
+
init_segment_cache = TTLCache(maxsize=100, ttl=3600) # 1 hour default TTL
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
async def get_cached_mpd(
    mpd_url: str, headers: dict, parse_drm: bool, parse_segment_profile_id: str | None = None, verify_ssl: bool = True
) -> dict:
    """
    Fetch, parse, and cache an MPD manifest.

    The raw manifest is cached; it is re-parsed on every call because parsing
    depends on the requested DRM/profile options, but the download is skipped
    while the cached entry is still fresh.

    Args:
        mpd_url (str): The URL of the MPD manifest.
        headers (dict): The headers to include in the request.
        parse_drm (bool): Whether to parse DRM information.
        parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.
        verify_ssl (bool, optional): Whether to verify the destination's SSL certificate. Defaults to True.

    Returns:
        dict: The parsed MPD manifest data.
    """
    now = datetime.datetime.now(datetime.UTC)
    cached = mpd_cache.get(mpd_url)
    if cached is not None and cached["expires"] > now:
        logger.info(f"Using cached MPD for {mpd_url}")
        return parse_mpd_dict(cached["mpd"], mpd_url, parse_drm, parse_segment_profile_id)

    raw_mpd = parse_mpd(await download_file_with_retry(mpd_url, headers, verify_ssl=verify_ssl))
    parsed = parse_mpd_dict(raw_mpd, mpd_url, parse_drm, parse_segment_profile_id)
    # Honor the manifest's own refresh interval, falling back to 5 minutes.
    expiration_time = datetime.datetime.now(datetime.UTC) + datetime.timedelta(
        seconds=parsed.get("minimumUpdatePeriod", 300)
    )
    mpd_cache[mpd_url] = {"mpd": raw_mpd, "expires": expiration_time}
    return parsed
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
async def get_cached_init_segment(init_url: str, headers: dict, verify_ssl: bool = True) -> bytes:
    """
    Fetch and cache an initialization segment.

    Args:
        init_url (str): The URL of the initialization segment.
        headers (dict): The headers to include in the request.
        verify_ssl (bool, optional): Whether to verify the destination's SSL certificate. Defaults to True.

    Returns:
        bytes: The initialization segment content.
    """
    cached = init_segment_cache.get(init_url)
    if cached is None:
        cached = await download_file_with_retry(init_url, headers, verify_ssl=verify_ssl)
        init_segment_cache[init_url] = cached
    return cached
|
mediaflow_proxy/utils/http_utils.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import typing
|
| 3 |
+
from functools import partial
|
| 4 |
+
from urllib import parse
|
| 5 |
+
|
| 6 |
+
import anyio
|
| 7 |
+
import httpx
|
| 8 |
+
import tenacity
|
| 9 |
+
from fastapi import Response
|
| 10 |
+
from starlette.background import BackgroundTask
|
| 11 |
+
from starlette.concurrency import iterate_in_threadpool
|
| 12 |
+
from starlette.requests import Request
|
| 13 |
+
from starlette.types import Receive, Send, Scope
|
| 14 |
+
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
| 15 |
+
|
| 16 |
+
from mediaflow_proxy.configs import settings
|
| 17 |
+
from mediaflow_proxy.const import SUPPORTED_REQUEST_HEADERS
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DownloadError(Exception):
    """Raised when a download fails; carries the associated HTTP status code."""

    def __init__(self, status_code, message):
        super().__init__(message)
        # HTTP status (or pseudo-status) describing the failure.
        self.status_code = status_code
        self.message = message
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(DownloadError),
)
async def fetch_with_retry(client, method, url, headers, follow_redirects=True, **kwargs):
    """
    Perform an HTTP request, retrying up to 3 times with exponential backoff on DownloadError.

    Args:
        client (httpx.AsyncClient): The HTTP client to use for the request.
        method (str): The HTTP method to use (e.g., GET, POST).
        url (str): The URL to fetch.
        headers (dict): The headers to include in the request.
        follow_redirects (bool, optional): Whether to follow redirects. Defaults to True.
        **kwargs: Additional arguments forwarded to the request.

    Returns:
        httpx.Response: The successful HTTP response.

    Raises:
        DownloadError: On timeout or a non-success status code.
    """
    try:
        response = await client.request(method, url, headers=headers, follow_redirects=follow_redirects, **kwargs)
        response.raise_for_status()
    except httpx.TimeoutException:
        # NOTE(review): 409 is an unusual code for a timeout — confirm whether 504 was intended.
        logger.warning(f"Timeout while downloading {url}")
        raise DownloadError(409, f"Timeout while downloading {url}")
    except httpx.HTTPStatusError as e:
        logger.error(f"HTTP error {e.response.status_code} while downloading {url}")
        raise DownloadError(e.response.status_code, f"HTTP error {e.response.status_code} while downloading {url}")
    except Exception as e:
        logger.error(f"Error downloading {url}: {e}")
        raise
    return response
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class Streamer:
    # Thin wrapper around an httpx.AsyncClient that remembers its last
    # response so close() can release both the response and the client.
    # NOTE(review): self.response is shared across calls — one Streamer
    # instance should not be used for concurrent requests.

    def __init__(self, client):
        """
        Initializes the Streamer with an HTTP client.

        Args:
            client (httpx.AsyncClient): The HTTP client to use for streaming.
        """
        self.client = client
        # Last response issued by this instance; None until the first request.
        self.response = None

    async def stream_content(self, url: str, headers: dict):
        """
        Streams content from a URL.

        Args:
            url (str): The URL to stream content from.
            headers (dict): The headers to include in the request.

        Yields:
            bytes: Chunks of the streamed content.
        """
        # Binding the streamed response to self.response lets close() abort
        # an in-flight stream if the generator is abandoned mid-iteration.
        async with self.client.stream("GET", url, headers=headers, follow_redirects=True) as self.response:
            self.response.raise_for_status()
            async for chunk in self.response.aiter_raw():
                yield chunk

    async def head(self, url: str, headers: dict):
        """
        Sends a HEAD request to a URL.

        Args:
            url (str): The URL to send the HEAD request to.
            headers (dict): The headers to include in the request.

        Returns:
            httpx.Response: The HTTP response.

        Raises:
            DownloadError: If all retry attempts fail.
        """
        try:
            self.response = await fetch_with_retry(self.client, "HEAD", url, headers)
        except tenacity.RetryError as e:
            # Unwrap the retry wrapper and re-raise the underlying error.
            raise e.last_attempt.result()
        return self.response

    async def get_text(self, url: str, headers: dict):
        """
        Sends a GET request to a URL and returns the response text.

        Args:
            url (str): The URL to send the GET request to.
            headers (dict): The headers to include in the request.

        Returns:
            str: The response text.

        Raises:
            DownloadError: If all retry attempts fail.
        """
        try:
            self.response = await fetch_with_retry(self.client, "GET", url, headers)
        except tenacity.RetryError as e:
            # Unwrap the retry wrapper and re-raise the underlying error.
            raise e.last_attempt.result()
        return self.response.text

    async def close(self):
        """
        Closes the HTTP client and response.
        """
        if self.response:
            await self.response.aclose()
        await self.client.aclose()
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
async def download_file_with_retry(url: str, headers: dict, timeout: float = 10.0, verify_ssl: bool = True):
    """
    Download a file in a single shot, with retry logic.

    Args:
        url (str): The URL of the file to download.
        headers (dict): The headers to include in the request.
        timeout (float, optional): The request timeout. Defaults to 10.0.
        verify_ssl (bool, optional): Whether to verify the destination's SSL certificate. Defaults to True.

    Returns:
        bytes: The downloaded file content.

    Raises:
        DownloadError: If the download fails after retries.
    """
    async with httpx.AsyncClient(
        follow_redirects=True, timeout=timeout, proxy=settings.proxy_url, verify=verify_ssl
    ) as client:
        try:
            return (await fetch_with_retry(client, "GET", url, headers)).content
        except DownloadError as e:
            logger.error(f"Failed to download file: {e}")
            raise e
        except tenacity.RetryError as e:
            # Normalize exhausted retries into a DownloadError with a 502.
            raise DownloadError(502, f"Failed to download file: {e.last_attempt.result()}")
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
async def request_with_retry(
    method: str, url: str, headers: dict, timeout: float = 10.0, verify_ssl: bool = True, **kwargs
):
    """
    Send an HTTP request with retry logic.

    Improvements over the previous version (both backward-compatible):
    - adds ``verify_ssl`` for consistency with ``download_file_with_retry``
      (default True preserves prior behavior);
    - normalizes exhausted retries (``tenacity.RetryError``) into a
      ``DownloadError``, matching ``download_file_with_retry``.

    Args:
        method (str): The HTTP method to use (e.g., GET, POST).
        url (str): The URL to send the request to.
        headers (dict): The headers to include in the request.
        timeout (float, optional): The request timeout. Defaults to 10.0.
        verify_ssl (bool, optional): Whether to verify the destination's SSL certificate. Defaults to True.
        **kwargs: Additional arguments to pass to the request.

    Returns:
        httpx.Response: The HTTP response.

    Raises:
        DownloadError: If the request fails after retries.
    """
    async with httpx.AsyncClient(
        follow_redirects=True, timeout=timeout, proxy=settings.proxy_url, verify=verify_ssl
    ) as client:
        try:
            return await fetch_with_retry(client, method, url, headers, **kwargs)
        except DownloadError as e:
            logger.error(f"Failed to download file: {e}")
            raise
        except tenacity.RetryError as e:
            raise DownloadError(502, f"Failed to download file: {e.last_attempt.result()}")
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def encode_mediaflow_proxy_url(
|
| 196 |
+
mediaflow_proxy_url: str,
|
| 197 |
+
endpoint: str | None = None,
|
| 198 |
+
destination_url: str | None = None,
|
| 199 |
+
query_params: dict | None = None,
|
| 200 |
+
request_headers: dict | None = None,
|
| 201 |
+
) -> str:
|
| 202 |
+
"""
|
| 203 |
+
Encodes a MediaFlow proxy URL with query parameters and headers.
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
mediaflow_proxy_url (str): The base MediaFlow proxy URL.
|
| 207 |
+
endpoint (str, optional): The endpoint to append to the base URL. Defaults to None.
|
| 208 |
+
destination_url (str, optional): The destination URL to include in the query parameters. Defaults to None.
|
| 209 |
+
query_params (dict, optional): Additional query parameters to include. Defaults to None.
|
| 210 |
+
request_headers (dict, optional): Headers to include as query parameters. Defaults to None.
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
str: The encoded MediaFlow proxy URL.
|
| 214 |
+
"""
|
| 215 |
+
query_params = query_params or {}
|
| 216 |
+
if destination_url is not None:
|
| 217 |
+
query_params["d"] = destination_url
|
| 218 |
+
|
| 219 |
+
# Add headers if provided
|
| 220 |
+
if request_headers:
|
| 221 |
+
query_params.update(
|
| 222 |
+
{key if key.startswith("h_") else f"h_{key}": value for key, value in request_headers.items()}
|
| 223 |
+
)
|
| 224 |
+
# Encode the query parameters
|
| 225 |
+
encoded_params = parse.urlencode(query_params, quote_via=parse.quote)
|
| 226 |
+
|
| 227 |
+
# Construct the full URL
|
| 228 |
+
if endpoint is None:
|
| 229 |
+
return f"{mediaflow_proxy_url}?{encoded_params}"
|
| 230 |
+
|
| 231 |
+
base_url = parse.urljoin(mediaflow_proxy_url, endpoint)
|
| 232 |
+
return f"{base_url}?{encoded_params}"
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def get_original_scheme(request: "Request") -> str:
    """
    Determines the original scheme (http or https) of the request.

    Args:
        request (Request): The incoming HTTP request.

    Returns:
        str: The original scheme ('http' or 'https')
    """
    # Trust an explicit X-Forwarded-Proto header from the reverse proxy first.
    forwarded_proto = request.headers.get("X-Forwarded-Proto")
    if forwarded_proto:
        return forwarded_proto

    # Direct TLS, or any of the common proxy headers that signal HTTPS.
    # (The previous implementation checked X-Forwarded-Ssl twice; deduplicated.)
    if (
        request.url.scheme == "https"
        or request.headers.get("X-Forwarded-Ssl") == "on"
        or request.headers.get("X-Forwarded-Protocol") == "https"
        or request.headers.get("X-Url-Scheme") == "https"
    ):
        return "https"

    # Default to http if no indicators of https are found
    return "http"
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def get_proxy_headers(request: Request) -> dict:
    """
    Extracts proxy headers from the request query parameters.

    Args:
        request (Request): The incoming HTTP request.

    Returns:
        dict: A dictionary of proxy headers.
    """
    # Start from the whitelisted incoming headers.
    headers = {name: value for name, value in request.headers.items() if name in SUPPORTED_REQUEST_HEADERS}
    # Query parameters prefixed with "h_" override/extend them.
    for param, value in request.query_params.items():
        if param.startswith("h_"):
            headers[param[2:].lower()] = value
    return headers
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class EnhancedStreamingResponse(Response):
    """Streaming response that tolerates client disconnects.

    Streams chunks from an (a)sync iterable while concurrently listening for
    an ``http.disconnect`` ASGI message; connection-reset errors raised while
    sending are logged and swallowed instead of propagating to the server.
    """

    # Async source of body chunks; assigned in __init__.
    body_iterator: typing.AsyncIterable[typing.Any]

    def __init__(
        self,
        content: typing.Union[typing.AsyncIterable[typing.Any], typing.Iterable[typing.Any]],
        status_code: int = 200,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
    ) -> None:
        if isinstance(content, typing.AsyncIterable):
            self.body_iterator = content
        else:
            # Sync iterables are consumed in a worker thread so the event
            # loop is not blocked while producing chunks.
            self.body_iterator = iterate_in_threadpool(content)
        self.status_code = status_code
        self.media_type = self.media_type if media_type is None else media_type
        self.background = background
        self.init_headers(headers)

    @staticmethod
    async def listen_for_disconnect(receive: Receive) -> None:
        """Block until the client disconnects (or receiving fails)."""
        try:
            while True:
                message = await receive()
                if message["type"] == "http.disconnect":
                    logger.debug("Client disconnected")
                    break
        except Exception as e:
            logger.error(f"Error in listen_for_disconnect: {str(e)}")

    async def stream_response(self, send: Send) -> None:
        """Send the response-start message followed by every body chunk."""
        try:
            await send(
                {
                    "type": "http.response.start",
                    "status": self.status_code,
                    "headers": self.raw_headers,
                }
            )
            async for chunk in self.body_iterator:
                # Non-bytes chunks (e.g. str) are encoded with the response charset.
                if not isinstance(chunk, (bytes, memoryview)):
                    chunk = chunk.encode(self.charset)
                try:
                    await send({"type": "http.response.body", "body": chunk, "more_body": True})
                except (ConnectionResetError, anyio.BrokenResourceError):
                    # Peer went away mid-stream: stop quietly, no final message.
                    logger.info("Client disconnected during streaming")
                    return

            # Empty body with more_body=False marks the end of the response.
            await send({"type": "http.response.body", "body": b"", "more_body": False})
        except Exception as e:
            logger.error(f"Error in stream_response: {str(e)}")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point: stream the body and watch for disconnects concurrently."""
        async with anyio.create_task_group() as task_group:

            async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
                try:
                    await func()
                except ExceptionGroup as e:
                    # Ignore groups that only contain cancellations; anything
                    # else is a real failure and is re-raised.
                    if not any(isinstance(exc, anyio.get_cancelled_exc_class()) for exc in e.exceptions):
                        logger.exception("Error in streaming task")
                        raise
                except Exception as e:
                    if not isinstance(e, anyio.get_cancelled_exc_class()):
                        logger.exception("Error in streaming task")
                        raise
                finally:
                    # Whichever task finishes first cancels its sibling.
                    task_group.cancel_scope.cancel()

            task_group.start_soon(wrap, partial(self.stream_response, send))
            await wrap(partial(self.listen_for_disconnect, receive))

        if self.background is not None:
            await self.background()
|
mediaflow_proxy/utils/m3u8_processor.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from urllib import parse
|
| 3 |
+
|
| 4 |
+
from pydantic import HttpUrl
|
| 5 |
+
|
| 6 |
+
from mediaflow_proxy.utils.http_utils import encode_mediaflow_proxy_url, get_original_scheme
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class M3U8Processor:
    def __init__(self, request, key_url: HttpUrl = None):
        """
        Initializes the M3U8Processor with the request and URL prefix.

        Args:
            request (Request): The incoming HTTP request.
            key_url (HttpUrl, optional): The URL of the key server. Defaults to None.
        """
        self.request = request
        self.key_url = key_url
        # Absolute proxy endpoint, rebuilt with the scheme the client
        # originally used (url_for may report plain "http" when running
        # behind a TLS-terminating reverse proxy).
        self.mediaflow_proxy_url = str(request.url_for("hls_stream_proxy").replace(scheme=get_original_scheme(request)))

    async def process_m3u8(self, content: str, base_url: str) -> str:
        """
        Processes the m3u8 content, proxying URLs and handling key lines.

        Args:
            content (str): The m3u8 content to process.
            base_url (str): The base URL to resolve relative URLs.

        Returns:
            str: The processed m3u8 content.
        """
        lines = content.splitlines()
        processed_lines = []
        for line in lines:
            # Tag lines carrying URI="..." (e.g. #EXT-X-KEY) get their URI proxied.
            if "URI=" in line:
                processed_lines.append(await self.process_key_line(line, base_url))
            # Non-comment, non-empty lines are segment/playlist URLs.
            elif not line.startswith("#") and line.strip():
                processed_lines.append(await self.proxy_url(line, base_url))
            else:
                processed_lines.append(line)
        return "\n".join(processed_lines)

    async def process_key_line(self, line: str, base_url: str) -> str:
        """
        Processes a key line in the m3u8 content, proxying the URI.

        Args:
            line (str): The key line to process.
            base_url (str): The base URL to resolve relative URLs.

        Returns:
            str: The processed key line.
        """
        uri_match = re.search(r'URI="([^"]+)"', line)
        if uri_match:
            original_uri = uri_match.group(1)
            uri = parse.urlparse(original_uri)
            if self.key_url:
                # NOTE(review): only scheme and host are substituted here — a
                # non-default port on key_url would be dropped; confirm key
                # servers always use the default port.
                uri = uri._replace(scheme=self.key_url.scheme, netloc=self.key_url.host)
            new_uri = await self.proxy_url(uri.geturl(), base_url)
            line = line.replace(f'URI="{original_uri}"', f'URI="{new_uri}"')
        return line

    async def proxy_url(self, url: str, base_url: str) -> str:
        """
        Proxies a URL, encoding it with the MediaFlow proxy URL.

        Args:
            url (str): The URL to proxy.
            base_url (str): The base URL to resolve relative URLs.

        Returns:
            str: The proxied URL.
        """
        # Resolve relative playlist entries against the manifest location.
        full_url = parse.urljoin(base_url, url)

        # Forward the current request's query parameters (password, h_* headers)
        # so nested requests keep the same credentials and header overrides.
        return encode_mediaflow_proxy_url(
            self.mediaflow_proxy_url,
            "",
            full_url,
            query_params=dict(self.request.query_params),
        )
|
mediaflow_proxy/utils/mpd_utils.py
ADDED
|
@@ -0,0 +1,555 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import math
|
| 3 |
+
import re
|
| 4 |
+
from datetime import datetime, timedelta, timezone
|
| 5 |
+
from typing import List, Dict
|
| 6 |
+
from urllib.parse import urljoin
|
| 7 |
+
|
| 8 |
+
import xmltodict
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def parse_mpd(mpd_content: str | bytes) -> dict:
    """
    Parses the MPD content into a dictionary.

    Args:
        mpd_content (str | bytes): The raw MPD XML to parse.

    Returns:
        dict: The parsed MPD content as a nested dictionary (via xmltodict).
    """
    parsed = xmltodict.parse(mpd_content)
    return parsed
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def parse_mpd_dict(
    mpd_dict: dict, mpd_url: str, parse_drm: bool = True, parse_segment_profile_id: str | None = None
) -> dict:
    """
    Parses the MPD dictionary and extracts relevant information.

    Args:
        mpd_dict (dict): The MPD content as a dictionary.
        mpd_url (str): The URL of the MPD manifest.
        parse_drm (bool, optional): Whether to parse DRM information. Defaults to True.
        parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.

    Returns:
        dict: The parsed MPD information including profiles and DRM info.

    This function processes the MPD dictionary to extract profiles, DRM information, and other relevant data.
    It handles both live and static MPD manifests.
    """
    profiles = []
    parsed_dict = {}
    # Base URL for resolving relative segment paths (manifest directory).
    source = "/".join(mpd_url.split("/")[:-1])

    is_live = mpd_dict["MPD"].get("@type", "static").lower() == "dynamic"
    parsed_dict["isLive"] = is_live

    media_presentation_duration = mpd_dict["MPD"].get("@mediaPresentationDuration")

    # Parse additional MPD attributes for live streams
    if is_live:
        parsed_dict["minimumUpdatePeriod"] = parse_duration(mpd_dict["MPD"].get("@minimumUpdatePeriod", "PT0S"))
        parsed_dict["timeShiftBufferDepth"] = parse_duration(mpd_dict["MPD"].get("@timeShiftBufferDepth", "PT2M"))
        # @availabilityStartTime is mandatory for dynamic manifests (KeyError if absent).
        parsed_dict["availabilityStartTime"] = datetime.fromisoformat(
            mpd_dict["MPD"]["@availabilityStartTime"].replace("Z", "+00:00")
        )
        # NOTE(review): fromisoformat("") raises ValueError when @publishTime
        # is missing — confirm live sources always provide it.
        parsed_dict["publishTime"] = datetime.fromisoformat(
            mpd_dict["MPD"].get("@publishTime", "").replace("Z", "+00:00")
        )

    periods = mpd_dict["MPD"]["Period"]
    periods = periods if isinstance(periods, list) else [periods]

    for period in periods:
        parsed_dict["PeriodStart"] = parse_duration(period.get("@start", "PT0S"))
        # Normalize AdaptationSet to a list: xmltodict yields a bare dict when
        # a period has exactly one AdaptationSet (previously this iterated the
        # dict's keys and crashed; extract_drm_info already normalizes this way).
        adaptation_sets = period["AdaptationSet"]
        adaptation_sets = adaptation_sets if isinstance(adaptation_sets, list) else [adaptation_sets]
        for adaptation in adaptation_sets:
            representations = adaptation["Representation"]
            representations = representations if isinstance(representations, list) else [representations]

            for representation in representations:
                profile = parse_representation(
                    parsed_dict,
                    representation,
                    adaptation,
                    source,
                    media_presentation_duration,
                    parse_segment_profile_id,
                )
                if profile:
                    profiles.append(profile)
    parsed_dict["profiles"] = profiles

    if parse_drm:
        drm_info = extract_drm_info(periods, mpd_url)
    else:
        drm_info = {}
    parsed_dict["drmInfo"] = drm_info

    return parsed_dict
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def pad_base64(encoded_key_id):
    """
    Pads a base64 encoded key ID so its length is a multiple of 4.

    Args:
        encoded_key_id (str): The base64 encoded key ID.

    Returns:
        str: The padded base64 encoded key ID (unchanged when already aligned).
    """
    # (-len) % 4 yields 0..3, so an already-aligned string gains no padding.
    # The previous expression (4 - len % 4) appended four "=" when len % 4 == 0.
    return encoded_key_id + "=" * (-len(encoded_key_id) % 4)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def extract_drm_info(periods: List[Dict], mpd_url: str) -> Dict:
    """
    Extracts DRM information from the MPD periods.

    Args:
        periods (List[Dict]): The list of periods in the MPD.
        mpd_url (str): The URL of the MPD manifest.

    Returns:
        Dict: The extracted DRM information.

    Walks every ContentProtection element (at AdaptationSet and Representation
    level) to collect DRM system info such as ClearKey, Widevine and PlayReady.
    """
    drm_info = {"isDrmProtected": False}

    for period in periods:
        raw_sets: list[dict] | dict = period.get("AdaptationSet", [])
        adaptation_sets = raw_sets if isinstance(raw_sets, list) else [raw_sets]

        for adaptation_set in adaptation_sets:
            # AdaptationSet-level protection descriptors.
            process_content_protection(adaptation_set.get("ContentProtection", []), drm_info)

            # Representation-level protection descriptors.
            raw_reps: list[dict] | dict = adaptation_set.get("Representation", [])
            representations = raw_reps if isinstance(raw_reps, list) else [raw_reps]
            for representation in representations:
                process_content_protection(representation.get("ContentProtection", []), drm_info)

    # Make a relative license acquisition URL absolute against the manifest.
    if "laUrl" in drm_info and not drm_info["laUrl"].startswith(("http://", "https://")):
        drm_info["laUrl"] = urljoin(mpd_url, drm_info["laUrl"])

    return drm_info
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def process_content_protection(content_protection: list[dict] | dict, drm_info: dict):
|
| 149 |
+
"""
|
| 150 |
+
Processes the ContentProtection elements to extract DRM information.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
content_protection (list[dict] | dict): The ContentProtection elements.
|
| 154 |
+
drm_info (dict): The dictionary to store DRM information.
|
| 155 |
+
|
| 156 |
+
This function updates the drm_info dictionary with DRM system information found in the ContentProtection elements.
|
| 157 |
+
"""
|
| 158 |
+
if not isinstance(content_protection, list):
|
| 159 |
+
content_protection = [content_protection]
|
| 160 |
+
|
| 161 |
+
for protection in content_protection:
|
| 162 |
+
drm_info["isDrmProtected"] = True
|
| 163 |
+
scheme_id_uri = protection.get("@schemeIdUri", "").lower()
|
| 164 |
+
|
| 165 |
+
if "clearkey" in scheme_id_uri:
|
| 166 |
+
drm_info["drmSystem"] = "clearkey"
|
| 167 |
+
if "clearkey:Laurl" in protection:
|
| 168 |
+
la_url = protection["clearkey:Laurl"].get("#text")
|
| 169 |
+
if la_url and "laUrl" not in drm_info:
|
| 170 |
+
drm_info["laUrl"] = la_url
|
| 171 |
+
|
| 172 |
+
elif "widevine" in scheme_id_uri or "edef8ba9-79d6-4ace-a3c8-27dcd51d21ed" in scheme_id_uri:
|
| 173 |
+
drm_info["drmSystem"] = "widevine"
|
| 174 |
+
pssh = protection.get("cenc:pssh", {}).get("#text")
|
| 175 |
+
if pssh:
|
| 176 |
+
drm_info["pssh"] = pssh
|
| 177 |
+
|
| 178 |
+
elif "playready" in scheme_id_uri or "9a04f079-9840-4286-ab92-e65be0885f95" in scheme_id_uri:
|
| 179 |
+
drm_info["drmSystem"] = "playready"
|
| 180 |
+
|
| 181 |
+
if "@cenc:default_KID" in protection:
|
| 182 |
+
key_id = protection["@cenc:default_KID"].replace("-", "")
|
| 183 |
+
if "keyId" not in drm_info:
|
| 184 |
+
drm_info["keyId"] = key_id
|
| 185 |
+
|
| 186 |
+
if "ms:laurl" in protection:
|
| 187 |
+
la_url = protection["ms:laurl"].get("@licenseUrl")
|
| 188 |
+
if la_url and "laUrl" not in drm_info:
|
| 189 |
+
drm_info["laUrl"] = la_url
|
| 190 |
+
|
| 191 |
+
return drm_info
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def parse_representation(
    parsed_dict: dict,
    representation: dict,
    adaptation: dict,
    source: str,
    media_presentation_duration: str,
    parse_segment_profile_id: str | None,
) -> dict | None:
    """
    Parses a representation and extracts profile information.

    Args:
        parsed_dict (dict): The parsed MPD data.
        representation (dict): The representation data.
        adaptation (dict): The adaptation set data.
        source (str): The source URL.
        media_presentation_duration (str): The media presentation duration.
        parse_segment_profile_id (str, optional): The profile ID to parse segments for. Defaults to None.

    Returns:
        dict | None: The parsed profile information or None if not applicable.
    """
    # Fall back to guessing the MIME type from the codec string when neither
    # the representation nor the adaptation set declares @mimeType.
    mime_type = _get_key(adaptation, representation, "@mimeType") or (
        "video/mp4" if "avc" in representation["@codecs"] else "audio/mp4"
    )
    # Only audio/video tracks are exposed; subtitles/images are skipped.
    if "video" not in mime_type and "audio" not in mime_type:
        return None

    profile = {
        "id": representation.get("@id") or adaptation.get("@id"),
        "mimeType": mime_type,
        "lang": representation.get("@lang") or adaptation.get("@lang"),
        "codecs": representation.get("@codecs") or adaptation.get("@codecs"),
        "bandwidth": int(representation.get("@bandwidth") or adaptation.get("@bandwidth")),
        "startWithSAP": (_get_key(adaptation, representation, "@startWithSAP") or "1") == "1",
        "mediaPresentationDuration": media_presentation_duration,
    }

    if "audio" in profile["mimeType"]:
        profile["audioSamplingRate"] = representation.get("@audioSamplingRate") or adaptation.get("@audioSamplingRate")
        # NOTE(review): assumes AudioChannelConfiguration is a single dict —
        # xmltodict yields a list when several are present; confirm against
        # real manifests.
        profile["channels"] = representation.get("AudioChannelConfiguration", {}).get("@value", "2")
    else:
        profile["width"] = int(representation["@width"])
        profile["height"] = int(representation["@height"])
        # Frame rate may be a fraction ("30000/1001") or a bare number.
        frame_rate = representation.get("@frameRate") or adaptation.get("@maxFrameRate") or "30000/1001"
        frame_rate = frame_rate if "/" in frame_rate else f"{frame_rate}/1"
        profile["frameRate"] = round(int(frame_rate.split("/")[0]) / int(frame_rate.split("/")[1]), 3)
        profile["sar"] = representation.get("@sar", "1:1")

    # Segment parsing is only done for the explicitly requested profile.
    if parse_segment_profile_id is None or profile["id"] != parse_segment_profile_id:
        return profile

    item = adaptation.get("SegmentTemplate") or representation.get("SegmentTemplate")
    if item:
        profile["segments"] = parse_segment_template(parsed_dict, item, profile, source)
    else:
        profile["segments"] = parse_segment_base(representation, source)

    return profile
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def _get_key(adaptation: dict, representation: dict, key: str) -> str | None:
|
| 256 |
+
"""
|
| 257 |
+
Retrieves a key from the representation or adaptation set.
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
adaptation (dict): The adaptation set data.
|
| 261 |
+
representation (dict): The representation data.
|
| 262 |
+
key (str): The key to retrieve.
|
| 263 |
+
|
| 264 |
+
Returns:
|
| 265 |
+
str | None: The value of the key or None if not found.
|
| 266 |
+
"""
|
| 267 |
+
return representation.get(key, adaptation.get(key, None))
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def parse_segment_template(parsed_dict: dict, item: dict, profile: dict, source: str) -> List[Dict]:
    """
    Parses a segment template and extracts segment information.

    Args:
        parsed_dict (dict): The parsed MPD data.
        item (dict): The segment template data.
        profile (dict): The profile information.
        source (str): The source URL.

    Returns:
        List[Dict]: The list of parsed segments (empty when neither a
        SegmentTimeline nor a @duration is present).
    """
    timescale = int(item.get("@timescale", 1))

    # Resolve the initialization segment URL, substituting template identifiers.
    if "@initialization" in item:
        init_url = (
            item["@initialization"]
            .replace("$RepresentationID$", profile["id"])
            .replace("$Bandwidth$", str(profile["bandwidth"]))
        )
        if not init_url.startswith("http"):
            init_url = f"{source}/{init_url}"
        profile["initUrl"] = init_url

    # Explicit timeline takes precedence over fixed-duration templates.
    if "SegmentTimeline" in item:
        return parse_segment_timeline(parsed_dict, item, profile, source, timescale)
    if "@duration" in item:
        return parse_segment_duration(parsed_dict, item, profile, source, timescale)
    return []
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def parse_segment_timeline(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]:
    """
    Parses a segment timeline and extracts segment information.

    Args:
        parsed_dict (dict): The parsed MPD data.
        item (dict): The segment timeline data.
        profile (dict): The profile information.
        source (str): The source URL.
        timescale (int): The timescale for the segments.

    Returns:
        List[Dict]: The list of parsed segments.
    """
    # xmltodict yields a bare dict when the timeline has a single <S> entry.
    timelines = item["SegmentTimeline"]["S"]
    timelines = timelines if isinstance(timelines, list) else [timelines]
    # Wall-clock start of the period: availabilityStartTime plus Period @start.
    # NOTE(review): requires availabilityStartTime, which is only populated
    # for live manifests — confirm this path is never hit for static MPDs.
    period_start = parsed_dict["availabilityStartTime"] + timedelta(seconds=parsed_dict.get("PeriodStart", 0))
    presentation_time_offset = int(item.get("@presentationTimeOffset", 0))
    start_number = int(item.get("@startNumber", 1))

    segments = [
        create_segment_data(timeline, item, profile, source, timescale)
        for timeline in preprocess_timeline(timelines, start_number, period_start, presentation_time_offset, timescale)
    ]
    return segments
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def preprocess_timeline(
    timelines: List[Dict], start_number: int, period_start: datetime, presentation_time_offset: int, timescale: int
) -> List[Dict]:
    """
    Preprocesses the segment timeline data.

    Expands each <S> entry (honoring its @r repeat count) into one record per
    segment with its number, media time and wall-clock start/end times.

    Args:
        timelines (List[Dict]): The list of timeline segments.
        start_number (int): The starting segment number.
        period_start (datetime): The start time of the period.
        presentation_time_offset (int): The presentation time offset.
        timescale (int): The timescale for the segments.

    Returns:
        List[Dict]: The list of preprocessed timeline segments.
    """
    result = []
    cursor = 0
    number = start_number
    for entry in timelines:
        repeats = int(entry.get("@r", 0)) + 1
        seg_duration = int(entry["@d"])
        # @t resets the media-time cursor; otherwise segments are contiguous.
        seg_time = int(entry.get("@t", cursor))

        for _ in range(repeats):
            begins = period_start + timedelta(seconds=(seg_time - presentation_time_offset) / timescale)
            result.append(
                {
                    "number": number,
                    "start_time": begins,
                    "end_time": begins + timedelta(seconds=seg_duration / timescale),
                    "duration": seg_duration,
                    "time": seg_time,
                }
            )
            seg_time += seg_duration
            number += 1

        cursor = seg_time

    return result
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def parse_segment_duration(parsed_dict: dict, item: dict, profile: dict, source: str, timescale: int) -> List[Dict]:
    """
    Parses segment duration and extracts segment information.
    Used for templates that declare a fixed @duration instead of a timeline.

    Args:
        parsed_dict (dict): The parsed MPD data.
        item (dict): The segment duration data.
        profile (dict): The profile information.
        source (str): The source URL.
        timescale (int): The timescale for the segments.

    Returns:
        List[Dict]: The list of parsed segments.
    """
    duration = int(item["@duration"])
    first_number = int(item.get("@startNumber", 1))
    seconds_per_segment = duration / timescale

    # Live streams derive segments from wall clock; VOD from total duration.
    if parsed_dict["isLive"]:
        raw_segments = generate_live_segments(parsed_dict, seconds_per_segment, first_number)
    else:
        raw_segments = generate_vod_segments(profile, duration, timescale, first_number)

    return [create_segment_data(raw, item, profile, source, timescale) for raw in raw_segments]
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def generate_live_segments(parsed_dict: dict, segment_duration_sec: float, start_number: int) -> List[Dict]:
    """
    Generates live segments based on the segment duration and start number.
    Used for live MPD manifests.

    Args:
        parsed_dict (dict): The parsed MPD data.
        segment_duration_sec (float): The segment duration in seconds.
        start_number (int): The starting segment number.

    Returns:
        List[Dict]: The list of generated live segments.
    """
    # Number of segments that fit in the time-shift buffer (default 60 s).
    buffer_depth = timedelta(seconds=parsed_dict.get("timeShiftBufferDepth", 60))
    segment_count = math.ceil(buffer_depth.total_seconds() / segment_duration_sec)

    # Newest segment number is derived from wall clock vs availabilityStartTime;
    # step back by the buffer size, never before start_number.
    now = datetime.now(tz=timezone.utc)
    elapsed = (now - parsed_dict["availabilityStartTime"]).total_seconds()
    newest_number = start_number + math.floor(elapsed / segment_duration_sec)
    first_number = max(newest_number - segment_count, start_number)

    segments = []
    for number in range(first_number, first_number + segment_count):
        offset = timedelta(seconds=(number - start_number) * segment_duration_sec)
        segments.append(
            {
                "number": number,
                "start_time": parsed_dict["availabilityStartTime"] + offset,
                "duration": segment_duration_sec,
            }
        )
    return segments
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def generate_vod_segments(profile: dict, duration: int, timescale: int, start_number: int) -> List[Dict]:
    """
    Generates VOD segments based on the segment duration and start number.
    Used for static MPD manifests.

    Args:
        profile (dict): The profile information.
        duration (int): The segment duration (in timescale units).
        timescale (int): The timescale for the segments.
        start_number (int): The starting segment number.

    Returns:
        List[Dict]: The list of generated VOD segments.
    """
    # The presentation duration may still be an ISO-8601 string at this point.
    total_duration = profile.get("mediaPresentationDuration") or 0
    if isinstance(total_duration, str):
        total_duration = parse_duration(total_duration)

    segment_count = math.ceil(total_duration * timescale / duration)

    segments = []
    for offset in range(segment_count):
        segments.append({"number": start_number + offset, "duration": duration / timescale})
    return segments
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def create_segment_data(segment: Dict, item: dict, profile: dict, source: str, timescale: int | None = None) -> Dict:
|
| 458 |
+
"""
|
| 459 |
+
Creates segment data based on the segment information. This includes the segment URL and metadata.
|
| 460 |
+
|
| 461 |
+
Args:
|
| 462 |
+
segment (Dict): The segment information.
|
| 463 |
+
item (dict): The segment template data.
|
| 464 |
+
profile (dict): The profile information.
|
| 465 |
+
source (str): The source URL.
|
| 466 |
+
timescale (int, optional): The timescale for the segments. Defaults to None.
|
| 467 |
+
|
| 468 |
+
Returns:
|
| 469 |
+
Dict: The created segment data.
|
| 470 |
+
"""
|
| 471 |
+
media_template = item["@media"]
|
| 472 |
+
media = media_template.replace("$RepresentationID$", profile["id"])
|
| 473 |
+
media = media.replace("$Number%04d$", f"{segment['number']:04d}")
|
| 474 |
+
media = media.replace("$Number$", str(segment["number"]))
|
| 475 |
+
media = media.replace("$Bandwidth$", str(profile["bandwidth"]))
|
| 476 |
+
|
| 477 |
+
if "time" in segment and timescale is not None:
|
| 478 |
+
media = media.replace("$Time$", str(int(segment["time"] * timescale)))
|
| 479 |
+
|
| 480 |
+
if not media.startswith("http"):
|
| 481 |
+
media = f"{source}/{media}"
|
| 482 |
+
|
| 483 |
+
segment_data = {
|
| 484 |
+
"type": "segment",
|
| 485 |
+
"media": media,
|
| 486 |
+
"number": segment["number"],
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
if "start_time" in segment and "end_time" in segment:
|
| 490 |
+
segment_data.update(
|
| 491 |
+
{
|
| 492 |
+
"start_time": segment["start_time"],
|
| 493 |
+
"end_time": segment["end_time"],
|
| 494 |
+
"extinf": (segment["end_time"] - segment["start_time"]).total_seconds(),
|
| 495 |
+
"program_date_time": segment["start_time"].isoformat() + "Z",
|
| 496 |
+
}
|
| 497 |
+
)
|
| 498 |
+
elif "start_time" in segment and "duration" in segment:
|
| 499 |
+
duration = segment["duration"]
|
| 500 |
+
segment_data.update(
|
| 501 |
+
{
|
| 502 |
+
"start_time": segment["start_time"],
|
| 503 |
+
"end_time": segment["start_time"] + timedelta(seconds=duration),
|
| 504 |
+
"extinf": duration,
|
| 505 |
+
"program_date_time": segment["start_time"].isoformat() + "Z",
|
| 506 |
+
}
|
| 507 |
+
)
|
| 508 |
+
elif "duration" in segment:
|
| 509 |
+
segment_data["extinf"] = segment["duration"]
|
| 510 |
+
|
| 511 |
+
return segment_data
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def parse_segment_base(representation: dict, source: str) -> List[Dict]:
    """
    Parses segment base information and extracts segment data. This is used for single-segment representations.

    Args:
        representation (dict): The representation data.
        source (str): The source URL.

    Returns:
        List[Dict]: The list of parsed segments (always a single entry).
    """
    base_info = representation["SegmentBase"]
    index_start, index_end = (int(part) for part in base_info["@indexRange"].split("-"))

    # When an initialization range is present, the byte range begins at its start instead.
    if "Initialization" in base_info:
        index_start, _ = (int(part) for part in base_info["Initialization"]["@range"].split("-"))

    media_url = f"{source}/{representation['BaseURL']}"
    return [{"type": "segment", "range": f"{index_start}-{index_end}", "media": media_url}]
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def parse_duration(duration_str: str) -> float:
    """
    Parses an ISO 8601 duration string into seconds.

    Supports the designator form (e.g. "PT1H30M", "P1DT12H") and week
    durations (e.g. "P2W"). Years and months use the fixed approximations
    of 365 and 30 days respectively.

    Args:
        duration_str (str): The duration string to parse.

    Returns:
        float: The parsed duration in seconds.

    Raises:
        ValueError: If the string is not a valid ISO 8601 duration.
    """
    pattern = re.compile(
        r"P(?:(\d+)Y)?(?:(\d+)M)?(?:(\d+)W)?(?:(\d+)D)?T?(?:(\d+)H)?(?:(\d+)M)?(?:(\d+(?:\.\d+)?)S)?"
    )
    # fullmatch (instead of match) rejects trailing garbage such as "P1Dxyz",
    # which a prefix match would silently accept as 1 day.
    match = pattern.fullmatch(duration_str)
    if not match:
        raise ValueError(f"Invalid duration format: {duration_str}")

    years, months, weeks, days, hours, minutes, seconds = (float(g) if g else 0 for g in match.groups())
    return (
        years * 365 * 24 * 3600
        + months * 30 * 24 * 3600
        + weeks * 7 * 24 * 3600
        + days * 24 * 3600
        + hours * 3600
        + minutes * 60
        + seconds
    )
|