| |
|
|
| import asyncio |
| import logging |
| from typing import AsyncGenerator, Dict, Union |
| from pyrogram import Client, utils, raw |
| from pyrogram.session import Session, Auth |
| from pyrogram.errors import AuthBytesInvalid |
| from pyrogram.file_id import FileId, FileType, ThumbnailSource |
| from WebStreamer.bot import work_loads |
| from .file_properties import get_file_ids |
|
|
|
|
class ByteStreamer:
    def __init__(self, client: Client):
        """A custom class that holds the cache of a specific client and class functions.

        attributes:
            client: the client that the cache is for.
            cached_file_ids: a dict mapping DB IDs to cached FileId objects.
            clean_timer: seconds between cache purges (30 minutes).

        functions:
            get_file_properties: returns (cached) properties for a media file.
            generate_file_properties: builds and caches the FileId for a media file.
            generate_media_session: returns the media session for the DC that contains the media file.
            get_location: builds the raw input-location object for a FileId.
            yield_file: yields a file from Telegram servers for streaming.
            clean_cache: background loop that periodically clears the cache.

        This is a modified version of the <https://github.com/eyaadh/megadlbot_oss/blob/master/mega/telegram/utils/custom_download.py>
        Thanks to Eyaadh <https://github.com/eyaadh>
        """
        self.clean_timer = 30 * 60
        self.client: Client = client
        self.cached_file_ids: Dict[str, FileId] = {}
        # Keep a strong reference to the background task: asyncio only holds
        # weak references to tasks, so an un-referenced task may be
        # garbage-collected before it finishes.
        self._cache_cleaner_task = asyncio.create_task(self.clean_cache())

    async def get_file_properties(self, db_id: str, multi_clients) -> FileId:
        """Return the properties of a media file as a FileId instance.

        Cached results are returned when available; otherwise the properties
        are generated from the DB ID and cached for subsequent calls.
        """
        if db_id not in self.cached_file_ids:
            logging.debug("Before Calling generate_file_properties")
            await self.generate_file_properties(db_id, multi_clients)
            logging.debug("Cached file properties for file with ID %s", db_id)
        return self.cached_file_ids[db_id]

    async def generate_file_properties(self, db_id: str, multi_clients) -> FileId:
        """Generate, cache, and return the FileId for the media identified by db_id."""
        logging.debug("Before calling get_file_ids")
        file_id = await get_file_ids(self.client, db_id, multi_clients)
        logging.debug("Generated file ID and Unique ID for file with ID %s", db_id)
        self.cached_file_ids[db_id] = file_id
        logging.debug("Cached media file with ID %s", db_id)
        return self.cached_file_ids[db_id]

    async def generate_media_session(self, client: Client, file_id: FileId) -> Session:
        """Return the media session for the DC that contains the media file.

        This is required for getting the bytes from Telegram servers.
        Delegates to the client's own (cached) session factory.
        """
        return await client.get_session(file_id.dc_id, is_media=True)

    @staticmethod
    async def get_location(file_id: FileId) -> Union[raw.types.InputPhotoFileLocation,
                                                     raw.types.InputDocumentFileLocation,
                                                     raw.types.InputPeerPhotoFileLocation,]:
        """Build the raw input file location for the media file.

        Chat photos need a peer object (user, basic chat, or channel —
        distinguished by the sign of chat_id and the access hash); photos and
        documents use their own dedicated location types.
        """
        file_type = file_id.file_type

        if file_type == FileType.CHAT_PHOTO:
            if file_id.chat_id > 0:
                # Positive chat_id -> a user's profile photo.
                peer = raw.types.InputPeerUser(
                    user_id=file_id.chat_id, access_hash=file_id.chat_access_hash
                )
            else:
                if file_id.chat_access_hash == 0:
                    # Basic groups have no access hash; chat_id is stored negated.
                    peer = raw.types.InputPeerChat(chat_id=-file_id.chat_id)
                else:
                    peer = raw.types.InputPeerChannel(
                        channel_id=utils.get_channel_id(file_id.chat_id),
                        access_hash=file_id.chat_access_hash,
                    )

            location = raw.types.InputPeerPhotoFileLocation(
                peer=peer,
                photo_id=file_id.media_id,
                big=file_id.thumbnail_source == ThumbnailSource.CHAT_PHOTO_BIG
            )
        elif file_type == FileType.PHOTO:
            location = raw.types.InputPhotoFileLocation(
                id=file_id.media_id,
                access_hash=file_id.access_hash,
                file_reference=file_id.file_reference,
                thumb_size=file_id.thumbnail_size,
            )
        else:
            location = raw.types.InputDocumentFileLocation(
                id=file_id.media_id,
                access_hash=file_id.access_hash,
                file_reference=file_id.file_reference,
                thumb_size=file_id.thumbnail_size,
            )
        return location

    async def yield_file(
        self,
        file_id: FileId,
        index: int,
        offset: int,
        first_part_cut: int,
        last_part_cut: int,
        part_count: int,
        chunk_size: int,
    ) -> AsyncGenerator[bytes, None]:
        """Yield the bytes of the media file, prefetching chunks concurrently
        across all available clients.

        Parameters:
            file_id: properties of the file to stream.
            index: work-load slot of the primary client (used for accounting).
            offset: byte offset of the first chunk.
            first_part_cut / last_part_cut: slice bounds applied to the first
                and last chunks so the yielded stream matches the requested range.
            part_count: total number of chunks to fetch.
            chunk_size: size of each chunk request.
        """
        work_loads[index] += 1
        logging.debug("Starting to yield file with client %s.", index)

        # Imported here (not at module level) to avoid a circular import
        # with WebStreamer.bot.
        from WebStreamer.bot import multi_clients
        clients = list(multi_clients.values())

        sessions = []
        for c in clients:
            sessions.append(await self.generate_media_session(c, file_id))

        location = await self.get_location(file_id)

        tasks: Dict[int, asyncio.Task] = {}
        next_part = 1
        try:
            # Keep at least 4 requests in flight, or one per session if more.
            concurrency = max(len(sessions), 4)
            fetch_part = 1
            current_offset = offset

            async def fetch(session, off: int) -> bytes:
                """Fetch one chunk; return b'' on any failure so the consumer
                ends the stream cleanly instead of crashing mid-transfer."""
                try:
                    r = await session.invoke(
                        raw.functions.upload.GetFile(
                            location=location, offset=off, limit=chunk_size
                        )
                    )
                    if isinstance(r, raw.types.upload.File):
                        return r.bytes
                except Exception as e:
                    logging.error("Chunk fetch error: %s", e)
                return b""

            while next_part <= part_count:
                # Top up the prefetch window, round-robining across sessions.
                while fetch_part <= part_count and len(tasks) < concurrency:
                    session = sessions[(fetch_part - 1) % len(sessions)]
                    tasks[fetch_part] = asyncio.create_task(fetch(session, current_offset))
                    fetch_part += 1
                    current_offset += chunk_size

                chunk = await tasks.pop(next_part)
                if not chunk:
                    # Empty chunk means a failed or exhausted read; stop streaming.
                    break

                if part_count == 1:
                    yield chunk[first_part_cut:last_part_cut]
                elif next_part == 1:
                    yield chunk[first_part_cut:]
                elif next_part == part_count:
                    yield chunk[:last_part_cut]
                else:
                    yield chunk

                next_part += 1

        except (TimeoutError, AttributeError):
            pass
        finally:
            # Cancel any still-pending prefetch tasks so they don't leak when
            # the stream ends early (short read, error, or client disconnect).
            for task in tasks.values():
                task.cancel()
            logging.debug("Finished yielding file with %s parts.", next_part - 1)
            work_loads[index] -= 1

    async def clean_cache(self) -> None:
        """Background loop: periodically clear the cache to reduce memory usage."""
        while True:
            await asyncio.sleep(self.clean_timer)
            self.cached_file_ids.clear()
            logging.debug("Cleaned the cache")
|
|