Spaces:
Sleeping
Sleeping
| import asyncio | |
| from pyrogram import Client, filters | |
| import pyromod | |
| from pyrogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton | |
| import config | |
| from utils.logger import Logger | |
| from pathlib import Path | |
| from utils.mongo_indexer import ( | |
| index_channel_messages, | |
| index_channel_messages_with_iter, | |
| MongoNotConfiguredError, | |
| ) | |
| from utils.clients import get_client | |
# Module-level logger for Bot Mode.
logger = Logger(__name__)
# Markdown reply sent by start_handler for /start.
# NOTE(review): the emoji in this text look mis-encoded in this copy of the
# file — confirm against the repository before shipping.
START_CMD = """π **Welcome To TG Drive's Bot Mode**
You can use this bot to upload files to your TG Drive website directly instead of doing it from website.
π **Commands:**
/set_folder - Set folder for file uploads
/current_folder - Check current folder
π€ **How To Upload Files:** Send a file to this bot and it will be uploaded to your TG Drive website. You can also set a folder for file uploads using /set_folder command.
Read more about [TG Drive's Bot Mode](https://github.com/TechShreyash/TGDrive#tg-drives-bot-mode)
"""
# Pending /set_folder selections: cache_id -> {folder_id: (folder_path, folder_name)}.
# Filled by set_folder_handler, consumed (and deleted) by set_folder_callback.
SET_FOLDER_PATH_CACHE = {}
# Drive data object injected by start_bot_mode(); None until then.
DRIVE_DATA = None
# Bot Mode state (current upload folder) injected by start_bot_mode(); None until then.
BOT_MODE = None
# Working directory handed to Pyrogram (it keeps the bot's session file here).
session_cache_path = Path("./cache")
# Create the cache directory itself. The previous `.parent.mkdir()` only
# created ".", so a fresh checkout had no ./cache for the session file.
session_cache_path.mkdir(parents=True, exist_ok=True)

# Client for the main bot that powers Bot Mode; started by start_bot_mode().
main_bot = Client(
    name="main_bot",
    api_id=config.API_ID,
    api_hash=config.API_HASH,
    bot_token=config.MAIN_BOT_TOKEN,
    sleep_threshold=config.SLEEP_THRESHOLD,
    workdir=session_cache_path,
)
async def start_handler(client: Client, message: Message):
    """Answer /start by sending the Bot Mode welcome/help text back to the user."""
    greeting = START_CMD
    await message.reply_text(greeting)
# Serialises channel indexing: /index refuses to start a new job while this
# lock is held.
_index_lock: asyncio.Lock = asyncio.Lock()
def _is_channel(chat) -> bool:
    """Return True if *chat* looks like a Telegram channel.

    ``chat.type`` may be an enum (``ChatType.CHANNEL``) or a plain string
    depending on the Pyrogram fork, so the check runs on its lower-cased
    string form, with an ``is_channel`` attribute as a fallback.
    """
    if chat is None:
        return False
    type_str = str(getattr(chat, "type", None)).lower()
    return "channel" in type_str or bool(getattr(chat, "is_channel", False))


def _parse_message_link(text: str) -> tuple:
    """Split a ``https://t.me/...`` message link into ``(chat_ref, message_id)``.

    ``https://t.me/c/<internal_id>/<msg_id>`` yields the numeric chat ID
    ``-100<internal_id>``; ``https://t.me/<username>/<msg_id>`` yields the
    username string. A trailing ``?query`` or ``/`` is ignored.

    Raises:
        ValueError, IndexError: if the link does not end in ``<chat>/<msg_id>``.
    """
    path = text.strip().split("?", 1)[0].rstrip("/")
    parts = path.split("/")
    message_id = int(parts[-1])
    chat_part = parts[-2]
    chat_ref = int("-100" + chat_part) if chat_part.isnumeric() else chat_part
    return chat_ref, message_id


def _link_matches_chat(chat_ref, chat) -> bool:
    """Return True if a link's chat reference (numeric ID or username) is *chat*."""
    if isinstance(chat_ref, int):
        return chat_ref == chat.id
    username = getattr(chat, "username", None) or ""
    return bool(username) and chat_ref.lstrip("@").lower() == username.lower()


async def _ask(client, message, prompt_text):
    """Send *prompt_text*, wait up to 60 s for the same user's reply and return it.

    The prompt is deleted afterwards, also when waiting fails (e.g. timeout).
    """
    prompt = await message.reply_text(prompt_text)
    try:
        return await client.listen(
            chat_id=message.chat.id, user_id=message.from_user.id, timeout=60
        )
    finally:
        await prompt.delete()


async def _send_index_confirmation(message, chat, chat_id, last_msg_id, skip):
    """Ask the user to confirm indexing with YES / CANCEL inline buttons.

    The YES button carries ``index#yes#<chat_id>#<last_msg_id>#<skip>``, which
    index_callback_handler parses.
    """
    buttons = [
        [InlineKeyboardButton("β YES", callback_data=f"index#yes#{chat_id}#{last_msg_id}#{skip}")],
        [InlineKeyboardButton("β CANCEL", callback_data="index#cancel")],
    ]
    await message.reply_text(
        f"π Do you want to index **{chat.title}** channel?\n\n"
        f"Last Message ID: `{last_msg_id}`\n"
        f"Skip Count: `{skip}`",
        reply_markup=InlineKeyboardMarkup(buttons),
    )


async def _find_last_message_id(message, indexing_client, chat_id):
    """Return the newest message ID of *chat_id*, or None after telling the user why not."""
    last_msg_id = None
    try:
        # Only the most recent message is needed to know where indexing ends.
        async for msg in indexing_client.get_chat_history(chat_id, limit=1):
            last_msg_id = msg.id
            logger.info(f"Found last message ID: {last_msg_id}")
            break
    except Exception as e:
        logger.warning(f"Could not find last message ID: {e}")
        await message.reply_text(
            f"β οΈ Could not automatically find the last message ID.\n\n"
            f"Error: {e}\n\n"
            "Please send a message link or forward a message instead."
        )
        return None
    if last_msg_id is None:
        await message.reply_text("β Could not find any messages in the channel.")
    return last_msg_id


async def _index_from_chat_argument(client, message, indexing_client, raw_chat_id):
    """Handle ``/index <channel>``: resolve the channel, then ask where indexing ends."""
    try:
        chat_id = int(raw_chat_id)
    except ValueError:
        chat_id = raw_chat_id  # a username

    try:
        chat = await indexing_client.get_chat(chat_id)
    except Exception as e:
        await message.reply_text(f"β Error accessing channel: {e}")
        return
    if not _is_channel(chat):
        await message.reply_text(f"β I can only index channels. Got type: {getattr(chat, 'type', None)}")
        return

    msg_input = await _ask(
        client,
        message,
        "π€ **Option 1:** Forward the last message from the channel, or send the last message link.\n"
        "π **Option 2:** Just send a number (skip count, default: 0) to auto-detect last message.\n\n"
        "_(Send a number or forward/link within 60 seconds)_",
    )
    text = (msg_input.text or "").strip()
    skip = 0

    if text.startswith("https://t.me"):
        try:
            link_chat, last_msg_id = _parse_message_link(text)
        except (ValueError, IndexError):
            await message.reply_text("β Invalid message link!")
            return
        # Compare with the resolved chat, so a username link still matches a
        # numeric ID given in the command (and vice versa).
        if not _link_matches_chat(link_chat, chat):
            await message.reply_text("β Message link is from a different channel!")
            return
    elif getattr(msg_input, "forward_origin", None):
        from pyrogram.types import MessageOriginChannel

        origin = msg_input.forward_origin
        if not isinstance(origin, MessageOriginChannel):
            await message.reply_text("β Please forward a message from a channel.")
            return
        # Compare numeric IDs rather than a username against the typed input.
        if origin.chat is None or origin.chat.id != chat.id:
            await message.reply_text("β Forwarded message is from a different channel!")
            return
        last_msg_id = origin.message_id
    elif text.isdigit():
        skip = int(text)
        last_msg_id = await _find_last_message_id(message, indexing_client, chat_id)
        if last_msg_id is None:
            return
    else:
        await message.reply_text("β Please send a number (skip count), message link, or forward a message.")
        return

    await _send_index_confirmation(message, chat, chat_id, last_msg_id, skip)


async def _index_from_forward_or_link(client, message, indexing_client):
    """Handle bare ``/index``: learn the channel from a forward or link, then ask for a skip count."""
    msg = await _ask(
        client,
        message,
        "π€ Forward the last message from the channel, or send the last message link.",
    )
    text = (msg.text or "").strip()

    if text.startswith("https://t.me"):
        try:
            chat_id, last_msg_id = _parse_message_link(text)
        except (ValueError, IndexError):
            await message.reply_text("β Invalid message link!")
            return
    elif getattr(msg, "forward_origin", None):
        from pyrogram.types import MessageOriginChannel

        origin = msg.forward_origin
        if not isinstance(origin, MessageOriginChannel):
            await message.reply_text("β Please forward a message from a channel.")
            return
        forward_chat = origin.chat
        if forward_chat is None:
            await message.reply_text("β Could not determine channel from forwarded message.")
            return
        if not _is_channel(forward_chat):
            await message.reply_text("β Please forward a message from a channel.")
            return
        last_msg_id = origin.message_id
        chat_id = forward_chat.username or forward_chat.id
    else:
        await message.reply_text("β Please forward a message from a channel or send a message link.")
        return

    try:
        chat = await indexing_client.get_chat(chat_id)
    except Exception as e:
        await message.reply_text(f"β Error: {e}")
        return
    if not _is_channel(chat):
        await message.reply_text(f"β I can only index channels. Got type: {getattr(chat, 'type', None)}")
        return

    skip_msg = await _ask(
        client,
        message,
        "π Send the number of messages to skip from the start (default: 0)",
    )
    # Non-text replies (stickers, media) have text None; treat them as invalid.
    skip_text = (skip_msg.text or "").strip()
    try:
        skip = int(skip_text) if skip_text else 0
        if skip < 0:
            raise ValueError(skip_text)
    except ValueError:
        await message.reply_text("β Invalid number. Using 0 as default.")
        skip = 0

    await _send_index_confirmation(message, chat, chat_id, last_msg_id, skip)


async def index_handler(client: Client, message: Message):
    """
    /index - Index channel messages to MongoDB

    Note: Requires STRING_SESSIONS (user accounts) - bots cannot index channels.

    Usage:
    - /index <channel_id> - Index from channel ID directly
    - /index - Then forward a message or send message link

    The command only ends with a YES / CANCEL confirmation prompt; the actual
    indexing is started by index_callback_handler.
    """
    if _index_lock.locked():
        await message.reply_text("β³ Wait until the previous indexing process completes.")
        return

    # A user-account client is required for reading channel history.
    try:
        indexing_client = get_client(premium_required=True)
    except RuntimeError:
        await message.reply_text(
            "β **Indexing requires STRING_SESSIONS (user accounts).**\n\n"
            "Bots cannot access full message history. Please configure `STRING_SESSIONS` "
            "in your environment variables to enable indexing."
        )
        return

    try:
        if len(message.command) > 1:
            await _index_from_chat_argument(
                client, message, indexing_client, message.command[1].strip()
            )
        else:
            await _index_from_forward_or_link(client, message, indexing_client)
    # NOTE(review): pyromod 1.x raises asyncio.TimeoutError when listen() times
    # out; newer pyromod releases raise their own ListenerTimeout, which lands
    # in the generic branch below — confirm the installed version.
    except asyncio.TimeoutError:
        await message.reply_text("β±οΈ Timeout. Please try again.")
    except Exception as e:
        await message.reply_text(f"β Error: {e}")
# Set by the CANCEL button and reset when a new job starts.
# NOTE(review): nothing in this module reads it; presumably the indexer (or
# other code) checks it — confirm before relying on cancellation.
_index_cancel = False

# Strong references to running indexing tasks. The event loop keeps only weak
# references, so an unreferenced task may be garbage-collected mid-run.
_index_tasks = set()


async def _run_index_job(indexing_client, chat_id, last_msg_id, skip, status_message):
    """Run one indexing job while holding ``_index_lock``.

    Holding the lock is what makes the ``_index_lock.locked()`` guards in
    index_handler and index_callback_handler effective. Failures are logged
    and shown on *status_message* instead of vanishing inside the task.
    NOTE(review): assumes index_channel_messages_with_iter does not itself
    acquire this module's ``_index_lock`` (asyncio.Lock is not re-entrant).
    """
    async with _index_lock:
        try:
            await index_channel_messages_with_iter(
                indexing_client, chat_id, last_msg_id, skip, status_message
            )
        except Exception as e:
            logger.error(f"Indexing failed: {e}")
            try:
                await status_message.edit(f"Indexing failed: {e}")
            except Exception as edit_error:
                logger.error(f"Could not report indexing failure: {edit_error}")


async def index_callback_handler(client: Client, callback_query):
    """Handle the YES / CANCEL buttons of the /index confirmation prompt.

    Callback data is ``index#cancel`` or
    ``index#yes#<chat_id>#<last_msg_id>#<skip>``. A confirmed job runs as a
    background task so this handler returns immediately. Every path answers
    the callback query exactly once.
    """
    global _index_cancel

    data = (callback_query.data or "").split("#")
    ident = data[1] if len(data) >= 2 else None

    if ident == "cancel":
        _index_cancel = True
        await callback_query.message.edit("π Trying to cancel indexing...")
        await callback_query.answer("Cancellation requested")
        return

    if ident != "yes" or len(data) < 5:
        await callback_query.answer("Invalid callback data")
        return
    try:
        last_msg_id = int(data[3])
        skip = int(data[4])
    except ValueError:
        await callback_query.answer("Invalid callback data")
        return
    try:
        chat_id = int(data[2])
    except ValueError:
        chat_id = data[2]  # a channel username

    if _index_lock.locked():
        await callback_query.answer(
            "β³ Wait until the previous indexing process completes.", show_alert=True
        )
        return

    # Resolve the user-account client before answering, so the query is
    # answered only once on this failure path.
    try:
        indexing_client = get_client(premium_required=True)
    except RuntimeError:
        await callback_query.message.edit(
            "β **Indexing requires STRING_SESSIONS (user accounts).**\n\n"
            "Bots cannot access full message history. Please configure `STRING_SESSIONS` "
            "in your environment variables to enable indexing."
        )
        await callback_query.answer("STRING_SESSIONS required")
        return

    await callback_query.message.edit("π Starting indexing...")
    await callback_query.answer("Indexing started")

    # A cancel request from an earlier run must not abort this one.
    _index_cancel = False
    task = asyncio.create_task(
        _run_index_job(indexing_client, chat_id, last_msg_id, skip, callback_query.message)
    )
    _index_tasks.add(task)
    task.add_done_callback(_index_tasks.discard)
async def set_folder_handler(client: Client, message: Message):
    """Interactively choose the folder that Bot Mode uploads go to.

    Keeps asking for a folder name until the drive search finds at least one
    folder, then replies with one inline button per matching folder. Button
    data ``set_folder_<cache_id>_<folder_id>`` is resolved by
    set_folder_callback through ``SET_FOLDER_PATH_CACHE``. Replying
    ``/cancel`` or not replying within 60 seconds aborts.
    """
    global SET_FOLDER_PATH_CACHE, DRIVE_DATA
    while True:
        try:
            answer = await client.ask(
                message.chat.id,
                "Send the folder name where you want to upload files\n\n/cancel to cancel",
                timeout=60,
                filters=filters.text,
            )
        # NOTE(review): newer pyromod releases raise ListenerTimeout instead of
        # asyncio.TimeoutError — confirm against the installed version.
        except asyncio.TimeoutError:
            await message.reply_text("Timeout\n\nUse /set_folder to set folder again")
            return

        if answer.text.strip().lower() == "/cancel":
            await message.reply_text("Cancelled")
            return

        folder_name = answer.text.strip()
        search_result = DRIVE_DATA.search_file_folder(folder_name)
        # The search also returns files; only folders are valid upload targets.
        folders = {
            item.id: item for item in search_result.values() if item.type == "folder"
        }
        if folders:
            break
        await message.reply_text(f"No Folder found with name {folder_name}")

    # Choose an ID above every pending entry. len()+1 could reuse the ID of a
    # still-pending request once an older entry had been consumed and deleted.
    folder_cache_id = max(SET_FOLDER_PATH_CACHE, default=0) + 1
    folder_cache = {}
    buttons = []
    for folder in folders.values():
        path = folder.path.strip("/")
        folder_path = "/" + ("/" + path + "/" + folder.id).strip("/")
        folder_cache[folder.id] = (folder_path, folder.name)
        buttons.append(
            [
                InlineKeyboardButton(
                    folder.name,
                    callback_data=f"set_folder_{folder_cache_id}_{folder.id}",
                )
            ]
        )
    SET_FOLDER_PATH_CACHE[folder_cache_id] = folder_cache

    await message.reply_text(
        "Select the folder where you want to upload files",
        reply_markup=InlineKeyboardMarkup(buttons),
    )
async def set_folder_callback(client: Client, callback_query: Message):
    """Apply the folder picked from a /set_folder inline button.

    ``callback_query`` is the button press (a callback query, despite the
    ``Message`` annotation). Its data has the form
    ``set_folder_<cache_id>_<folder_id>``. Each cached selection list is
    single-use and is removed once a folder is chosen. Unknown or already-used
    data is reported as expired instead of raising.
    """
    global SET_FOLDER_PATH_CACHE, BOT_MODE
    # maxsplit=3 keeps a folder ID that itself contains "_" in one piece.
    parts = callback_query.data.split("_", 3)
    entry = None
    folder_cache_id = None
    if len(parts) == 4 and parts[2].isdigit():
        folder_cache_id = int(parts[2])
        entry = SET_FOLDER_PATH_CACHE.get(folder_cache_id, {}).get(parts[3])

    if entry is None:
        await callback_query.answer("Request Expired, Send /set_folder again")
        await callback_query.message.delete()
        return

    folder_path, name = entry
    SET_FOLDER_PATH_CACHE.pop(folder_cache_id, None)
    BOT_MODE.set_folder(folder_path, name)
    await callback_query.answer(f"Folder Set Successfully To : {name}")
    await callback_query.message.edit(
        f"Folder Set Successfully To : {name}\n\nNow you can send / forward files to me and it will be uploaded to this folder."
    )
async def current_folder_handler(client: Client, message: Message):
    """Tell the user which folder new uploads are currently stored in."""
    global BOT_MODE
    folder_label = BOT_MODE.current_folder_name
    await message.reply_text(f"Current Folder: {folder_label}")
# Media attributes checked, in this priority order, on a message.
_MEDIA_KINDS = ("document", "video", "audio", "photo", "sticker")


def _find_media(msg):
    """Return ``(kind, media)`` for the first media attribute set on *msg*, else ``(None, None)``."""
    for kind in _MEDIA_KINDS:
        media = getattr(msg, kind, None)
        if media:
            return kind, media
    return None, None


# Handling when any file is sent to the bot
async def file_handler(client: Client, message: Message):
    """Copy a media message into the storage channel and register it in the drive.

    The copied message's media supplies the name and size recorded via
    ``DRIVE_DATA.new_file`` under ``BOT_MODE.current_folder``. Messages with no
    supported media are rejected before anything is copied.
    """
    global BOT_MODE, DRIVE_DATA
    if _find_media(message)[1] is None:
        await message.reply_text(
            "Unsupported message. Send a document, video, audio, photo or sticker."
        )
        return

    copied_message = await message.copy(config.STORAGE_CHANNEL)
    kind, file = _find_media(copied_message)
    if file is None:
        # Defensive: the copy is expected to keep the original media.
        await message.reply_text("Could not read the media of the copied message.")
        return

    # Pyrogram's Photo type has no file_name, and other media may report None;
    # generate a name so new_file always receives a usable string.
    file_name = getattr(file, "file_name", None)
    if not file_name:
        import mimetypes

        extension = mimetypes.guess_extension(getattr(file, "mime_type", None) or "")
        if extension is None:
            extension = ".jpg" if kind == "photo" else ""
        file_name = f"{kind}_{copied_message.id}{extension}"

    DRIVE_DATA.new_file(
        BOT_MODE.current_folder,
        file_name,
        copied_message.id,
        file.file_size,
    )

    await message.reply_text(
        f"""β File Uploaded Successfully To Your TG Drive Website
**File Name:** {file_name}
**Folder:** {BOT_MODE.current_folder_name}
"""
    )
async def start_bot_mode(d, b):
    """Install the drive / bot-mode objects and bring the main bot online.

    ``d`` becomes ``DRIVE_DATA`` and ``b`` becomes ``BOT_MODE``. Startup
    failures are only logged: if the bot cannot start, Bot Mode stays disabled
    and the function returns; if the announcement to the storage channel
    fails, the bot keeps running.
    """
    global DRIVE_DATA, BOT_MODE
    DRIVE_DATA, BOT_MODE = d, b

    logger.info("Starting Main Bot")
    try:
        await main_bot.start()
    except Exception as err:
        # e.g. network problems — carry on without Bot Mode.
        logger.error(f"Failed to start Main Bot (Bot Mode will be disabled): {err}")
        return

    try:
        await main_bot.send_message(
            config.STORAGE_CHANNEL,
            "Main Bot Started -> TG Drive's Bot Mode Enabled",
        )
        for line in ("Main Bot Started", "TG Drive's Bot Mode Enabled"):
            logger.info(line)
    except Exception as err:
        logger.error(f"Main Bot started but failed to send startup message: {err}")