# beta-tgdrive / utils/bot_mode.py
# Last commit 07fad73 (dragxd): "Fix deprecated forward_from_chat and add missing get_client import"
import asyncio
from pyrogram import Client, filters
import pyromod
from pyrogram.types import Message, InlineKeyboardMarkup, InlineKeyboardButton
import config
from utils.logger import Logger
from pathlib import Path
from utils.mongo_indexer import (
index_channel_messages,
index_channel_messages_with_iter,
MongoNotConfiguredError,
)
from utils.clients import get_client
logger = Logger(__name__)

START_CMD = """🚀 **Welcome To TG Drive's Bot Mode**
You can use this bot to upload files to your TG Drive website directly instead of doing it from website.
🗄 **Commands:**
/set_folder - Set folder for file uploads
/current_folder - Check current folder
📤 **How To Upload Files:** Send a file to this bot and it will be uploaded to your TG Drive website. You can also set a folder for file uploads using /set_folder command.
Read more about [TG Drive's Bot Mode](https://github.com/TechShreyash/TGDrive#tg-drives-bot-mode)
"""

# Pending /set_folder prompts: {prompt_id: {folder_id: (folder_path, folder_name)}}.
# An entry is removed once the admin picks a folder from that prompt.
SET_FOLDER_PATH_CACHE = {}
# Drive state and bot-mode settings, injected by start_bot_mode().
DRIVE_DATA = None
BOT_MODE = None

# Working directory in which Pyrogram keeps the main bot's session file.
# The directory itself must exist; creating only its parent (".") is not enough.
session_cache_path = Path("./cache")
session_cache_path.mkdir(parents=True, exist_ok=True)

main_bot = Client(
    name="main_bot",
    api_id=config.API_ID,
    api_hash=config.API_HASH,
    bot_token=config.MAIN_BOT_TOKEN,
    sleep_threshold=config.SLEEP_THRESHOLD,
    workdir=session_cache_path,
)
@main_bot.on_message(
    filters.command(["start", "help"])
    & filters.private
    & filters.user(config.TELEGRAM_ADMIN_IDS),
)
async def start_handler(client: Client, message: Message):
    """Answer /start and /help from an admin with the bot-mode usage text."""
    usage_text = START_CMD
    await message.reply_text(usage_text)
# Guards against concurrent channel indexing: /index checks
# `_index_lock.locked()` and refuses to start while the lock is held.
_index_lock: asyncio.Lock = asyncio.Lock()
def _is_channel(chat) -> bool:
    """Return True when *chat* is a Telegram channel.

    ``chat.type`` may be a ``ChatType`` enum or a plain string depending on the
    Pyrogram flavour, so its lower-cased string form is searched for
    "channel"; a truthy ``is_channel`` attribute is accepted as well.
    """
    if chat is None:
        return False
    type_str = str(getattr(chat, "type", None)).lower()
    return "channel" in type_str or bool(getattr(chat, "is_channel", False))


def _parse_message_link(text):
    """Parse a ``https://t.me/...`` message link.

    Accepts public links (``t.me/<username>/<msg_id>``) and private links
    (``t.me/c/<internal_id>[/<topic_id>]/<msg_id>``). Query strings such as
    ``?single``, fragments and trailing slashes are ignored.

    Returns:
        ``(chat_ref, msg_id)`` where ``chat_ref`` is the username (str) or the
        ``-100``-prefixed channel id (int), or None if the link is malformed.
    """
    url = text.strip().split("?", 1)[0].split("#", 1)[0].rstrip("/")
    path = url.split("/")[3:]  # drop "https:", "" and "t.me"
    if len(path) < 2 or not path[-1].isdecimal():
        return None
    msg_id = int(path[-1])
    if path[0] == "c":
        # Private link: the channel id follows "c"; an optional topic id may
        # sit between it and the message id.
        if len(path) < 3 or not path[1].isdecimal():
            return None
        return int("-100" + path[1]), msg_id
    return path[0], msg_id


def _same_chat(chat_ref, chat) -> bool:
    """Return True when *chat_ref* (int id or username str) refers to *chat*."""
    if isinstance(chat_ref, int):
        return chat_ref == chat.id
    username = getattr(chat, "username", None) or ""
    return bool(username) and chat_ref.lstrip("@").lower() == username.lower()


def _channel_forward(msg):
    """Return ``(origin_chat, message_id)`` if *msg* was forwarded from a channel.

    Returns None when *msg* is not a forward or its origin is not a channel.
    ``origin_chat`` may itself be None; callers check for that.
    """
    from pyrogram.types import MessageOriginChannel

    origin = getattr(msg, "forward_origin", None)
    if not isinstance(origin, MessageOriginChannel):
        return None
    return origin.chat, origin.message_id


def _parse_skip(text):
    """Parse a skip count: blank -> 0, non-negative int -> int, else None."""
    if text is None:
        return None
    text = text.strip()
    if not text:
        return 0
    try:
        value = int(text)
    except ValueError:
        return None
    return value if value >= 0 else None


async def _wait_for_reply(client, message, timeout=60):
    """Wait for the next message from the /index sender in the same chat.

    Timeouts are normalised to ``asyncio.TimeoutError``: pyromod >= 2 raises
    its own ``ListenerTimeout`` (or returns None when configured not to
    raise), while pyromod 1.x raises ``asyncio.TimeoutError`` directly.
    """
    try:
        from pyromod.exceptions import ListenerTimeout
    except ImportError:
        ListenerTimeout = asyncio.TimeoutError
    try:
        reply = await client.listen(
            chat_id=message.chat.id, user_id=message.from_user.id, timeout=timeout
        )
    except ListenerTimeout as exc:
        raise asyncio.TimeoutError from exc
    if reply is None:
        raise asyncio.TimeoutError
    return reply


async def _latest_message_id(indexing_client, chat_id):
    """Return the id of the most recent message in *chat_id*, or None if empty."""
    async for msg in indexing_client.get_chat_history(chat_id, limit=1):
        logger.info(f"Found last message ID: {msg.id}")
        return msg.id
    return None


async def _confirm_indexing(message, chat, chat_ref, last_msg_id, skip):
    """Ask the admin to confirm; the YES button carries all the callback needs."""
    buttons = [
        [InlineKeyboardButton("✅ YES", callback_data=f"index#yes#{chat_ref}#{last_msg_id}#{skip}")],
        [InlineKeyboardButton("❌ CANCEL", callback_data="index#cancel")],
    ]
    await message.reply_text(
        f"📋 Do you want to index **{chat.title}** channel?\n\n"
        f"Last Message ID: `{last_msg_id}`\n"
        f"Skip Count: `{skip}`",
        reply_markup=InlineKeyboardMarkup(buttons),
    )


async def _index_from_argument(client, message, indexing_client, raw_chat_id):
    """``/index <channel>`` flow: the channel is known, ask for the end point.

    The admin then sends a message link or forward from that channel (its
    message id becomes the last id to index), or a bare number that is used
    as the skip count while the last message id is detected automatically.
    """
    try:
        chat_id = int(raw_chat_id)
    except ValueError:
        chat_id = raw_chat_id  # username

    # Only the lookup is guarded here; timeouts and other errors from the
    # conversation below propagate to index_handler's handlers.
    try:
        chat = await indexing_client.get_chat(chat_id)
    except Exception as e:
        await message.reply_text(f"❌ Error accessing channel: {e}")
        return
    if not _is_channel(chat):
        await message.reply_text(
            f"❌ I can only index channels. Got type: {getattr(chat, 'type', None)}"
        )
        return

    prompt = await message.reply_text(
        "📤 **Option 1:** Forward the last message from the channel, or send the last message link.\n"
        "📊 **Option 2:** Just send a number (skip count, default: 0) to auto-detect last message.\n\n"
        "_(Send a number or forward/link within 60 seconds)_"
    )
    reply = await _wait_for_reply(client, message)
    await prompt.delete()

    text = (reply.text or "").strip()
    skip = 0
    if text.startswith("https://t.me"):
        parsed = _parse_message_link(text)
        if parsed is None:
            await message.reply_text("❌ Invalid message link!")
            return
        link_chat, last_msg_id = parsed
        if not _same_chat(link_chat, chat):
            await message.reply_text("❌ Message link is from a different channel!")
            return
    elif getattr(reply, "forward_origin", None):
        forwarded = _channel_forward(reply)
        if forwarded is None:
            await message.reply_text("❌ Please forward a message from a channel.")
            return
        origin_chat, last_msg_id = forwarded
        # Compare resolved numeric ids: the admin may have named the channel
        # by id while the forward exposes its username (or vice versa).
        if origin_chat is None or origin_chat.id != chat.id:
            await message.reply_text("❌ Forwarded message is from a different channel!")
            return
    elif text.isdecimal():
        skip = int(text)
        try:
            last_msg_id = await _latest_message_id(indexing_client, chat_id)
        except Exception as e:
            logger.warning(f"Could not find last message ID: {e}")
            await message.reply_text(
                f"⚠️ Could not automatically find the last message ID.\n\n"
                f"Error: {e}\n\n"
                "Please send a message link or forward a message instead."
            )
            return
        if last_msg_id is None:
            await message.reply_text("❌ Could not find any messages in the channel.")
            return
    else:
        await message.reply_text(
            "❌ Please send a number (skip count), message link, or forward a message."
        )
        return

    await _confirm_indexing(message, chat, chat_id, last_msg_id, skip)


async def _index_from_reply(client, message, indexing_client):
    """Bare ``/index`` flow: learn the channel from a link/forward, then ask for skip."""
    prompt = await message.reply_text(
        "📤 Forward the last message from the channel, or send the last message link."
    )
    reply = await _wait_for_reply(client, message)
    await prompt.delete()

    text = (reply.text or "").strip()
    if text.startswith("https://t.me"):
        parsed = _parse_message_link(text)
        if parsed is None:
            await message.reply_text("❌ Invalid message link!")
            return
        chat_id, last_msg_id = parsed
    elif getattr(reply, "forward_origin", None):
        forwarded = _channel_forward(reply)
        if forwarded is None:
            await message.reply_text("❌ Please forward a message from a channel.")
            return
        origin_chat, last_msg_id = forwarded
        if origin_chat is None:
            await message.reply_text("❌ Could not determine channel from forwarded message.")
            return
        if not _is_channel(origin_chat):
            await message.reply_text("❌ Please forward a message from a channel.")
            return
        # Prefer the username when there is one; resolving by username does
        # not depend on the user session already knowing the numeric peer.
        chat_id = origin_chat.username or origin_chat.id
    else:
        await message.reply_text(
            "❌ Please forward a message from a channel or send a message link."
        )
        return

    try:
        chat = await indexing_client.get_chat(chat_id)
    except Exception as e:
        await message.reply_text(f"❌ Error: {e}")
        return
    if not _is_channel(chat):
        await message.reply_text(
            f"❌ I can only index channels. Got type: {getattr(chat, 'type', None)}"
        )
        return

    prompt = await message.reply_text(
        "📊 Send the number of messages to skip from the start (default: 0)"
    )
    skip_reply = await _wait_for_reply(client, message)
    await prompt.delete()
    # Non-text replies (text is None), garbage and negative numbers fall back to 0.
    skip = _parse_skip(skip_reply.text)
    if skip is None:
        await message.reply_text("❌ Invalid number. Using 0 as default.")
        skip = 0

    await _confirm_indexing(message, chat, chat_id, last_msg_id, skip)


@main_bot.on_message(
    filters.command("index")
    & filters.private
    & filters.user(config.TELEGRAM_ADMIN_IDS),
)
async def index_handler(client: Client, message: Message):
    """
    /index - Index channel messages to MongoDB

    Note: Requires STRING_SESSIONS (user accounts) - bots cannot index channels.

    Usage:
    - /index <channel_id> - Index from channel ID directly
    - /index - Then forward a message or send message link

    Both flows end with a YES/CANCEL prompt handled by index_callback_handler.
    """
    if _index_lock.locked():
        await message.reply_text("⏳ Wait until the previous indexing process completes.")
        return

    # A user-session client is required to read channel history.
    try:
        indexing_client = get_client(premium_required=True)
    except RuntimeError:
        await message.reply_text(
            "❌ **Indexing requires STRING_SESSIONS (user accounts).**\n\n"
            "Bots cannot access full message history. Please configure `STRING_SESSIONS` "
            "in your environment variables to enable indexing."
        )
        return

    try:
        if len(message.command) > 1:
            await _index_from_argument(
                client, message, indexing_client, message.command[1].strip()
            )
        else:
            await _index_from_reply(client, message, indexing_client)
    except asyncio.TimeoutError:
        await message.reply_text("⏱️ Timeout. Please try again.")
    except Exception as e:
        await message.reply_text(f"❌ Error: {e}")
# Strong references to running indexing jobs. The event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected
# before it finishes.
_index_tasks = set()


async def _run_indexing(indexing_client, chat_id, last_msg_id, skip, status_message):
    """Run one indexing job while holding ``_index_lock``.

    Holding the lock makes /index (which checks ``_index_lock.locked()``)
    refuse new requests until this job ends, and serialises jobs confirmed
    back-to-back. Failures are logged and reported on *status_message*
    instead of disappearing inside the background task.
    """
    async with _index_lock:
        try:
            await index_channel_messages_with_iter(
                indexing_client, chat_id, last_msg_id, skip, status_message
            )
        except Exception as e:
            logger.error(f"Indexing of {chat_id} failed: {e}")
            try:
                await status_message.edit(f"❌ Indexing failed: {e}")
            except Exception as edit_error:
                logger.warning(f"Could not report indexing failure: {edit_error}")


@main_bot.on_callback_query(
    filters.regex(r"^index#") & filters.user(config.TELEGRAM_ADMIN_IDS)
)
async def index_callback_handler(client: Client, callback_query):
    """Handle the YES / CANCEL buttons of the /index confirmation prompt.

    Callback data is either ``index#cancel`` or
    ``index#yes#<chat_id>#<last_msg_id>#<skip>``, as built by index_handler.
    Restricted to admins like every other handler in this module.
    """
    global _index_cancel
    data = callback_query.data.split("#")
    if len(data) < 2:
        await callback_query.answer("Invalid callback data")
        return

    ident = data[1]
    if ident == "cancel":
        # NOTE(review): nothing in this module reads _index_cancel; confirm the
        # indexer observes it, otherwise this only edits the prompt.
        _index_cancel = True
        await callback_query.message.edit("🛑 Trying to cancel indexing...")
        await callback_query.answer("Cancellation requested")
        return

    if ident != "yes" or len(data) < 5:
        await callback_query.answer("Invalid callback data")
        return

    try:
        last_msg_id = int(data[3])
        skip = int(data[4])
    except ValueError:
        await callback_query.answer("Invalid callback data")
        return
    try:
        chat_id = int(data[2])
    except ValueError:
        chat_id = data[2]  # username

    # Resolve the user-session client before answering: Telegram accepts only
    # one answer per callback query, so the failure path must answer first.
    try:
        indexing_client = get_client(premium_required=True)
    except RuntimeError:
        await callback_query.message.edit(
            "❌ **Indexing requires STRING_SESSIONS (user accounts).**\n\n"
            "Bots cannot access full message history. Please configure `STRING_SESSIONS` "
            "in your environment variables to enable indexing."
        )
        await callback_query.answer("STRING_SESSIONS required")
        return

    if _index_lock.locked():
        await callback_query.answer(
            "⏳ Wait until the previous indexing process completes.", show_alert=True
        )
        return

    await callback_query.message.edit("🔄 Starting indexing...")
    await callback_query.answer("Indexing started")

    # Run indexing in the background, keeping a reference until it finishes.
    task = asyncio.create_task(
        _run_indexing(indexing_client, chat_id, last_msg_id, skip, callback_query.message)
    )
    _index_tasks.add(task)
    task.add_done_callback(_index_tasks.discard)
@main_bot.on_message(
    filters.command("set_folder")
    & filters.private
    & filters.user(config.TELEGRAM_ADMIN_IDS),
)
async def set_folder_handler(client: Client, message: Message):
    """Interactively choose the drive folder that bot-mode uploads go to.

    Keeps asking for a folder name until at least one folder matches. The
    admin can stop early with /cancel, or by staying silent for 60 seconds.
    Then replies with one button per matching folder. The choices are stored
    in ``SET_FOLDER_PATH_CACHE`` under a fresh prompt id for
    set_folder_callback to consume.
    """
    global SET_FOLDER_PATH_CACHE, DRIVE_DATA
    # pyromod >= 2 raises its own ListenerTimeout from ask(); pyromod 1.x
    # raises asyncio.TimeoutError. Catch whichever applies.
    try:
        from pyromod.exceptions import ListenerTimeout
    except ImportError:
        ListenerTimeout = asyncio.TimeoutError

    while True:
        try:
            reply = await client.ask(
                message.chat.id,
                "Send the folder name where you want to upload files\n\n/cancel to cancel",
                timeout=60,
                filters=filters.text,
            )
        except (asyncio.TimeoutError, ListenerTimeout):
            reply = None
        if reply is None:
            # None is also what pyromod returns on timeout when told not to raise.
            await message.reply_text("Timeout\n\nUse /set_folder to set folder again")
            return
        if reply.text.strip().lower() == "/cancel":
            await message.reply_text("Cancelled")
            return

        folder_name = reply.text.strip()
        search_result = DRIVE_DATA.search_file_folder(folder_name)
        # Only folders are valid upload targets; keep them, de-duplicated by id.
        folders = {
            item.id: item for item in search_result.values() if item.type == "folder"
        }
        if folders:
            break
        await message.reply_text(f"No Folder found with name {folder_name}")

    # Larger than every pending prompt id. len()+1 could reuse the id of a
    # still-pending prompt once an earlier entry had been consumed.
    folder_cache_id = max(SET_FOLDER_PATH_CACHE, default=0) + 1
    folder_cache = {}
    buttons = []
    for folder in folders.values():
        path = folder.path.strip("/")
        folder_path = "/" + ("/" + path + "/" + folder.id).strip("/")
        folder_cache[folder.id] = (folder_path, folder.name)
        buttons.append(
            [
                InlineKeyboardButton(
                    folder.name,
                    callback_data=f"set_folder_{folder_cache_id}_{folder.id}",
                )
            ]
        )
    SET_FOLDER_PATH_CACHE[folder_cache_id] = folder_cache

    await message.reply_text(
        "Select the folder where you want to upload files",
        reply_markup=InlineKeyboardMarkup(buttons),
    )
@main_bot.on_callback_query(
    filters.user(config.TELEGRAM_ADMIN_IDS) & filters.regex(r"set_folder_")
)
async def set_folder_callback(client: Client, callback_query):
    """Apply the folder chosen from a /set_folder prompt.

    Callback data is ``set_folder_<prompt_id>_<folder_id>``. The prompt's
    choices are looked up in ``SET_FOLDER_PATH_CACHE`` and consumed, and the
    chosen folder is handed to ``BOT_MODE.set_folder``.
    """
    global SET_FOLDER_PATH_CACHE, BOT_MODE
    # maxsplit=3 keeps any underscores inside the folder id intact.
    parts = callback_query.data.split("_", 3)
    try:
        folder_cache_id = int(parts[2])
        folder_id = parts[3]
    except (IndexError, ValueError):
        await callback_query.answer("Invalid callback data")
        return

    choices = SET_FOLDER_PATH_CACHE.get(folder_cache_id) or {}
    choice = choices.get(folder_id)
    if choice is None:
        await callback_query.answer("Request Expired, Send /set_folder again")
        await callback_query.message.delete()
        return

    folder_path, name = choice
    # Consume the prompt so its remaining buttons report "expired".
    SET_FOLDER_PATH_CACHE.pop(folder_cache_id, None)
    BOT_MODE.set_folder(folder_path, name)

    await callback_query.answer(f"Folder Set Successfully To : {name}")
    await callback_query.message.edit(
        f"Folder Set Successfully To : {name}\n\nNow you can send / forward files to me and it will be uploaded to this folder."
    )
@main_bot.on_message(
    filters.command("current_folder")
    & filters.private
    & filters.user(config.TELEGRAM_ADMIN_IDS),
)
async def current_folder_handler(client: Client, message: Message):
    """Tell the admin which folder bot-mode uploads currently go to."""
    folder_label = BOT_MODE.current_folder_name
    await message.reply_text(f"Current Folder: {folder_label}")
def _uploaded_file_name(media, kind, message_id):
    """Return a name for an uploaded file, inventing one when Telegram has none.

    Pyrogram ``Photo`` objects have no ``file_name`` attribute, and other media
    may carry ``file_name=None``. Such files are named
    ``<kind>_<message_id>`` plus an extension: ``.jpg`` for photos, otherwise
    one guessed from ``mime_type`` when possible.
    """
    file_name = getattr(media, "file_name", None)
    if file_name:
        return file_name
    if kind == "photo":
        extension = ".jpg"  # Telegram serves photos as JPEG
    else:
        import mimetypes

        mime_type = getattr(media, "mime_type", None)
        extension = (mimetypes.guess_extension(mime_type) if mime_type else None) or ""
    return f"{kind}_{message_id}{extension}"


# Handling when any file is sent to the bot
@main_bot.on_message(
    filters.private
    & filters.user(config.TELEGRAM_ADMIN_IDS)
    & (
        filters.document
        | filters.video
        | filters.audio
        | filters.photo
        | filters.sticker
    )
)
async def file_handler(client: Client, message: Message):
    """Copy an admin's file to the storage channel and register it in the drive.

    The copy's message id and size are recorded via ``DRIVE_DATA.new_file``
    under ``BOT_MODE.current_folder``.
    """
    global BOT_MODE, DRIVE_DATA
    copied_message = await message.copy(config.STORAGE_CHANNEL)

    kind, media = None, None
    for candidate in ("document", "video", "audio", "photo", "sticker"):
        media = getattr(copied_message, candidate, None)
        if media:
            kind = candidate
            break
    if not media:
        logger.warning(f"Copied message {copied_message.id} carries no supported media")
        await message.reply_text("❌ Unsupported file type.")
        return

    file_name = _uploaded_file_name(media, kind, copied_message.id)
    DRIVE_DATA.new_file(
        BOT_MODE.current_folder,
        file_name,
        copied_message.id,
        media.file_size,
    )

    await message.reply_text(
        f"""✅ File Uploaded Successfully To Your TG Drive Website
**File Name:** {file_name}
**Folder:** {BOT_MODE.current_folder_name}
"""
    )
async def start_bot_mode(d, b):
    """Install the shared drive/bot-mode state and start the main bot.

    Args:
        d: drive data object, stored as ``DRIVE_DATA`` for the handlers.
        b: bot-mode settings object, stored as ``BOT_MODE``.

    A failed bot start is logged and leaves bot mode disabled rather than
    raising. A failed startup notice to the storage channel is only logged.
    """
    global DRIVE_DATA, BOT_MODE
    DRIVE_DATA, BOT_MODE = d, b

    logger.info("Starting Main Bot")
    try:
        await main_bot.start()
    except Exception as start_error:
        # e.g. network trouble: keep the app running without bot mode.
        logger.error(f"Failed to start Main Bot (Bot Mode will be disabled): {start_error}")
        return

    startup_notice = "Main Bot Started -> TG Drive's Bot Mode Enabled"
    try:
        await main_bot.send_message(config.STORAGE_CHANNEL, startup_notice)
        logger.info("Main Bot Started")
        logger.info("TG Drive's Bot Mode Enabled")
    except Exception as notice_error:
        logger.error(f"Main Bot started but failed to send startup message: {notice_error}")