|
|
import asyncio
import json
import os
import time
from enum import Enum
from typing import Any, Dict, List, Optional

from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from rich.console import Console
from rich.live import Live
from rich.progress import (
    BarColumn,
    DownloadColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
    TransferSpeedColumn,
)
from rich.table import Table

import download_channel
|
|
|
|
|
|
|
|
# Rich console shared by every log line this module prints.
console = Console()

_API_TITLE = "Telegram Channel Downloader API"

# The HTTP application; the title is shown in the generated OpenAPI docs.
app = FastAPI(title=_API_TITLE)

# Status dictionaries of download tasks, keyed by task id.  Nothing in this
# module removes entries, so finished tasks stay visible to the endpoints.
active_downloads: Dict[str, Dict[str, Any]] = dict()
|
|
|
|
|
class FileStatus(str, Enum):
    """Lifecycle state of a single file discovered in the channel.

    Mixing in ``str`` makes each member behave as its plain string value
    (e.g. when the state is dumped to JSON).
    """

    # Recorded during a channel scan; not downloaded yet.
    PENDING = "pending"

    # Set while the file is being fetched from Telegram.
    DOWNLOADING = "downloading"

    # Download finished (and, when an HF token is configured, upload too).
    DOWNLOADED = "downloaded"

    # Download or upload failed; the reason is stored on the file's ``error``.
    FAILED = "failed"
|
|
|
|
|
class ChannelFile(BaseModel):
    """One RAR attachment found in the channel, with its download outcome."""

    # Telegram message id that carries the attachment.
    message_id: int

    # Local file name; also used as ``files/<filename>`` in the dataset.
    filename: str

    # Current lifecycle state.
    status: FileStatus

    # Attachment size recorded at scan time (presumably bytes; 0/None if unknown).
    size: Optional[int] = None

    # Seconds from download start until the file was marked downloaded
    # (includes the upload when an HF token is configured).
    download_time: Optional[float] = None

    # Failure reason when ``status`` is FAILED.
    error: Optional[str] = None

    # Path of the uploaded copy inside the HF dataset.
    upload_path: Optional[str] = None
|
|
|
|
|
class DownloadState(BaseModel):
    """Download progress for one channel, persisted as the dataset's state file."""

    # Channel identifier this state belongs to.
    channel: str

    # Highest message id seen by the last channel scan (None before any scan).
    last_scanned_id: Optional[int] = None

    # Every RAR attachment discovered so far, with its per-file status.
    files: List[ChannelFile] = Field(default_factory=list)

    # Message id of the file being processed right now, if any.
    current_download: Optional[int] = None

    # Unix timestamp of the last change.  A factory is required: a plain
    # ``time.time()`` default is evaluated once at class definition, giving
    # every instance the same import-time timestamp.
    last_updated: float = Field(default_factory=time.time)
|
|
|
|
|
class DownloadRequest(BaseModel):
    """Body of ``POST /download``; omitted fields keep the current settings."""

    # Channel to scan; when given it overrides ``download_channel.CHANNEL``.
    channel: Optional[str] = None

    # Maximum number of messages to scan; when given it overrides
    # ``download_channel.MESSAGE_LIMIT``.
    message_limit: Optional[int] = None
|
|
|
|
|
class DownloadStatus(BaseModel):
    """Progress snapshot of a download task, as served by the status endpoints."""

    # Channel recorded in the loaded download state.
    channel: str

    # Task state string, e.g. "running", "completed" or "failed".
    status: str

    # Number of files tracked in the state (not the number of messages scanned).
    message_count: int = 0

    # Number of files in DOWNLOADED status.
    downloaded: int = 0

    # File name currently being downloaded, if any.
    downloading: Optional[str] = None

    # Error message when the task failed.
    error: Optional[str] = None
|
|
|
|
|
def create_hf_dataset(token: str) -> bool:
    """Create the Hugging Face dataset repo if needed and seed an empty state file.

    Args:
        token: Hugging Face access token (presumably with write access to
            ``download_channel.HF_REPO_ID`` -- repo creation needs it).

    Returns:
        True once the dataset repo exists (created or already present), even
        if uploading the initial state file failed (a warning is printed).
        False if ``huggingface_hub`` cannot be imported or repo creation /
        state seeding raised.
    """
    try:
        # Import only what is used; the dataset creation needs nothing else.
        from huggingface_hub import create_repo
    except ImportError:
        console.print("[red]huggingface_hub not properly installed[/red]")
        return False

    try:
        # exist_ok=True turns this into a no-op when the dataset already exists.
        create_repo(
            repo_id=download_channel.HF_REPO_ID,
            token=token,
            repo_type="dataset",
            exist_ok=True,
        )
        console.print(f"[green]Created or verified dataset:[/green] {download_channel.HF_REPO_ID}")

        # Seed the repo with an empty state for the configured channel.
        initial_state = DownloadState(channel=download_channel.CHANNEL)
        with open(download_channel.STATE_FILE, "w", encoding="utf-8") as f:
            json.dump(initial_state.dict(), f, indent=2, ensure_ascii=False)

        if download_channel.upload_file_to_hf(
            download_channel.STATE_FILE,
            download_channel.STATE_FILE,
            token,
        ):
            console.print("[green]Initialized dataset with empty state file[/green]")
        else:
            console.print("[yellow]Dataset ready, but the initial state file could not be uploaded[/yellow]")
        return True
    except Exception as e:
        console.print(f"[red]Failed to create dataset:[/red] {str(e)}")
        return False
|
|
|
|
|
def download_state_from_hf(token: str) -> DownloadState:
    """Load the persisted DownloadState from the HF dataset, or start a fresh one.

    Args:
        token: Hugging Face access token.  When falsy the Hub is not contacted
            and a fresh state is returned.

    Returns:
        The state parsed from ``download_channel.STATE_FILE`` in the dataset,
        or a new ``DownloadState`` for ``download_channel.CHANNEL`` when it
        cannot be obtained.

    Only a repository or state file that is definitively missing triggers
    ``create_hf_dataset``, which writes and uploads an empty state file.
    Other failures fall back to a fresh in-memory state without touching the
    dataset, so a transient error cannot overwrite the remote progress record.
    These failures are network/offline errors, server errors, and an
    unreadable or invalid state file.
    """
    if not token:
        return DownloadState(channel=download_channel.CHANNEL)

    try:
        from huggingface_hub.utils import (
            EntryNotFoundError,
            LocalEntryNotFoundError,
            RepositoryNotFoundError,
        )
    except ImportError:
        # The error classes are unavailable, so "missing" cannot be told apart
        # from "failed": treat every download error as a missing state.
        missing_errors: tuple = (Exception,)
        offline_errors: tuple = ()
    else:
        missing_errors = (EntryNotFoundError, RepositoryNotFoundError)
        # Raised when the Hub is unreachable and nothing is cached.  It is a
        # subclass of EntryNotFoundError but says nothing about the remote
        # file, so it must not trigger a re-initialisation.
        offline_errors = (LocalEntryNotFoundError,)

    try:
        local_path = download_channel.hf_hub_download(
            repo_id=download_channel.HF_REPO_ID,
            filename=download_channel.STATE_FILE,
            repo_type="dataset",
            token=token,
        )
    except Exception as e:
        if isinstance(e, missing_errors) and not isinstance(e, offline_errors):
            console.print(f"[yellow]No existing state found, creating new dataset:[/yellow] {str(e)}")
            if create_hf_dataset(token):
                console.print("[green]Dataset created successfully![/green]")
            else:
                console.print("[red]Failed to create dataset, using local state only[/red]")
        else:
            console.print(f"[yellow]Could not fetch state file, starting with a fresh local state:[/yellow] {e}")
        return DownloadState(channel=download_channel.CHANNEL)

    try:
        with open(local_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return DownloadState(**data)
    except (OSError, TypeError, ValueError) as e:
        # JSON decode errors and pydantic validation errors are ValueErrors;
        # TypeError covers a JSON document that is not an object.
        console.print(f"[yellow]State file could not be parsed, starting with a fresh local state:[/yellow] {e}")
        return DownloadState(channel=download_channel.CHANNEL)
|
|
|
|
|
async def clean_downloaded_file(file_path: str):
    """Delete *file_path* from disk once its upload succeeded.

    Any error is reported as a console warning and never propagated.
    """
    try:
        os.unlink(file_path)
        removed_name = os.path.basename(file_path)
        console.print(f"[blue]Cleaned up:[/blue] {removed_name}")
    except Exception as err:
        console.print(f"[yellow]Warning:[/yellow] Could not clean up {file_path}: {err}")
|
|
|
|
|
async def update_and_upload_state(state: DownloadState, token: str) -> bool:
    """Timestamp *state*, write it to STATE_FILE and push that file to the dataset.

    Returns:
        Whatever ``download_channel.upload_file_to_hf`` returns, or False when
        writing or uploading raised (the error is printed).
    """
    state.last_updated = time.time()
    try:
        state_path = download_channel.STATE_FILE
        with open(state_path, "w", encoding="utf-8") as handle:
            json.dump(state.dict(), handle, indent=2, ensure_ascii=False)

        # The same relative path is used locally and inside the repo.
        uploaded = download_channel.upload_file_to_hf(state_path, state_path, token)
        return uploaded
    except Exception as exc:
        console.print(f"[red]Failed to update state:[/red] {exc}")
        return False
|
|
|
|
|
async def process_message(message, state: DownloadState, client) -> Optional[str]:
    """Return the local output path for a RAR attachment in *message*, else None.

    A message qualifies when it has media with an attached file that is
    either named ``*.rar`` (case-insensitive) or, if unnamed, has a MIME type
    containing ``rar``.  Nothing is downloaded here.

    Args:
        message: Telegram message (assumed Telethon ``Message``; ``.media``,
            ``.file`` and ``.id`` are read).
        state: Unused; kept for interface compatibility.
        client: Unused; kept for interface compatibility.

    Returns:
        ``OUTPUT_DIR/<message id>_<file name>`` (path separators in the name
        replaced by ``_``), ``OUTPUT_DIR/<message id>.rar`` for unnamed files,
        or None when the message carries no RAR attachment.
    """
    if not message.media:
        return None
    attachment = message.file
    if not attachment:
        return None

    filename = getattr(attachment, "name", "") or ""
    if filename:
        is_rar = filename.lower().endswith(".rar")
    else:
        mime_type = getattr(attachment, "mime_type", "") or ""
        is_rar = "rar" in mime_type.lower()
    if not is_rar:
        return None

    if filename:
        # The name is chosen by the sender and may contain path separators.
        # Flatten them so the path stays inside OUTPUT_DIR and its basename
        # keeps the message-id prefix.
        safe_name = filename.replace("/", "_").replace("\\", "_")
        suggested = f"{message.id}_{safe_name}"
    else:
        suggested = f"{message.id}.rar"

    return os.path.join(download_channel.OUTPUT_DIR, suggested)
|
|
|
|
|
async def _persist_state(state: DownloadState) -> None:
    """Upload *state* via update_and_upload_state when an HF token is configured."""
    if download_channel.HF_TOKEN:
        await update_and_upload_state(state, download_channel.HF_TOKEN)


def _build_progress_displays():
    """Return ``(file_progress, overall_progress, layout)`` for the download loop.

    Both bars are placed in one ``Table.grid`` so that a single ``Live``
    renders them.  Starting two ``Live`` displays at the same time raises
    ``LiveError`` in many Rich versions.
    """
    file_progress = Progress(
        SpinnerColumn(),
        TextColumn("[bold blue]{task.fields[filename]}", justify="right"),
        BarColumn(bar_width=40),
        "[progress.percentage]{task.percentage:>3.1f}%",
        "•",
        DownloadColumn(),
        "•",
        TransferSpeedColumn(),
        "•",
        TimeElapsedColumn(),
    )
    overall_progress = Progress(
        TextColumn("[bold yellow]{task.description}", justify="right"),
        BarColumn(bar_width=40),
        "[progress.percentage]{task.percentage:>3.1f}%",
        "•",
        TextColumn("[bold green]{task.fields[stats]}"),
    )
    layout = Table.grid()
    layout.add_row(overall_progress)
    layout.add_row(file_progress)
    return file_progress, overall_progress, layout


async def _scan_channel(client, entity, state: DownloadState) -> int:
    """Register every not-yet-known RAR attachment in the channel as PENDING.

    Sets ``state.last_scanned_id`` to the highest message id seen and
    persists the state when an HF token is configured.

    Returns:
        The number of messages iterated.
    """
    known_ids = {f.message_id for f in state.files}
    scan_count = 0
    last_message_id = state.last_scanned_id

    async for message in client.iter_messages(entity, limit=download_channel.MESSAGE_LIMIT or None):
        scan_count += 1
        if last_message_id is None or message.id > last_message_id:
            last_message_id = message.id
        if message.id in known_ids:
            continue

        out_path = await process_message(message, state, client)
        if out_path:
            # For documents the size is on ``message.file`` (Telethon's File
            # wrapper); ``message.media`` itself typically has no ``size``.
            size = getattr(message.file, "size", None) or 0
            state.files.append(
                ChannelFile(
                    message_id=message.id,
                    filename=os.path.basename(out_path),
                    status=FileStatus.PENDING,
                    size=size,
                )
            )
            known_ids.add(message.id)

    state.last_scanned_id = last_message_id
    await _persist_state(state)
    return scan_count


async def _download_file(client, entity, state: DownloadState, file_info: ChannelFile, file_progress) -> None:
    """Download one file, upload it when an HF token is set, and record the outcome.

    The result is stored on ``file_info.status``, which ends as DOWNLOADED or
    FAILED.  Exceptions from Telegram, including FloodWaitError, propagate to
    the caller.
    """
    file_info.status = FileStatus.DOWNLOADING
    state.current_download = file_info.message_id
    await _persist_state(state)

    message = await client.get_messages(entity, ids=file_info.message_id)
    if not message or not message.media:
        file_info.status = FileStatus.FAILED
        file_info.error = "Message not found or no media"
        return

    out_path = os.path.join(download_channel.OUTPUT_DIR, file_info.filename)
    task = file_progress.add_task("download", total=file_info.size or 100, filename=file_info.filename)

    def on_progress(current, total):
        # Telethon reports the real total; it fixes the bar when the size
        # recorded at scan time was unknown.  A falsy total leaves it as is.
        file_progress.update(task, completed=current, total=total or None)

    start_time = time.time()
    try:
        downloaded = await client.download_media(message, file=out_path, progress_callback=on_progress)
    finally:
        # Drop the finished bar so the display does not grow with every file.
        file_progress.remove_task(task)

    if downloaded is None:
        file_info.status = FileStatus.FAILED
        file_info.error = "Download produced no file"
        return

    if download_channel.HF_TOKEN:
        console.print(f"[yellow]Uploading to HF:[/yellow] {file_info.filename}")
        path_in_repo = f"files/{file_info.filename}"
        ok = download_channel.upload_file_to_hf(out_path, path_in_repo, download_channel.HF_TOKEN)
        if not ok:
            console.print(f"[red]Upload failed:[/red] {file_info.filename}")
            file_info.error = "Upload to dataset failed"
            file_info.status = FileStatus.FAILED
            return
        console.print(f"[green]Uploaded:[/green] {file_info.filename}")
        # The dataset now holds the file, so the local copy can go.
        await clean_downloaded_file(out_path)
        file_info.upload_path = path_in_repo

    file_info.status = FileStatus.DOWNLOADED
    file_info.download_time = time.time() - start_time


async def _download_pending(
    client,
    entity,
    state: DownloadState,
    status: Dict[str, Any],
    pending_files: List[ChannelFile],
) -> None:
    """Download every file in *pending_files*, updating *status* and the persisted state."""
    file_progress, overall_progress, layout = _build_progress_displays()
    total_pending = len(pending_files)

    with Live(layout, console=console):
        overall_task = overall_progress.add_task(
            f"Channel: {download_channel.CHANNEL}",
            total=total_pending,
            stats=f"Pending: {total_pending}",
        )

        for file_info in pending_files:
            status["downloading"] = file_info.filename
            try:
                await _download_file(client, entity, state, file_info, file_progress)
            except download_channel.errors.FloodWaitError as fw:
                wait = int(fw.seconds) if fw.seconds else 60
                console.print(f"[yellow]FloodWait:[/yellow] Sleeping {wait}s")
                # Re-queue the file; leaving it DOWNLOADING would strand it.
                file_info.status = FileStatus.PENDING
                await asyncio.sleep(wait + 1)
            except Exception as e:
                console.print(f"[red]Error:[/red] {str(e)}")
                file_info.status = FileStatus.FAILED
                file_info.error = str(e)

            succeeded = file_info.status == FileStatus.DOWNLOADED
            if succeeded:
                status["downloaded"] += 1
            await _persist_state(state)

            # The overall bar counts processed files (its total is the file count).
            done = sum(1 for f in state.files if f.status == FileStatus.DOWNLOADED)
            overall_progress.update(overall_task, advance=1, stats=f"Downloaded: {done}")

            if succeeded:
                # Short pause between successful downloads.
                await asyncio.sleep(0.2)


async def run_download(channel: Optional[str], message_limit: Optional[int], task_id: str):
    """Background task: scan the channel for RAR files, download them and sync state to HF.

    Args:
        channel: When given, overrides ``download_channel.CHANNEL``.  This is a
            module-wide setting, shared with any other running task.
        message_limit: When not None, overrides ``download_channel.MESSAGE_LIMIT``.
        task_id: Key under which the status dict is published in ``active_downloads``.

    Returns:
        1 if the channel could not be resolved, otherwise 0.  Other errors are
        reported through the task's status entry and the console.
    """
    status: Optional[Dict[str, Any]] = None
    try:
        if channel:
            download_channel.CHANNEL = channel
        if message_limit is not None:
            download_channel.MESSAGE_LIMIT = message_limit

        # NOTE(review): the loaded state's channel may differ from the
        # requested one; both share one state file -- confirm this is intended.
        state = download_state_from_hf(download_channel.HF_TOKEN)

        status = {
            "channel": state.channel,
            "status": "running",
            "message_count": len(state.files),
            "downloaded": sum(1 for f in state.files if f.status == FileStatus.DOWNLOADED),
            "downloading": None,
            "error": None,
        }
        active_downloads[task_id] = status

        client = download_channel.TelegramClient(
            download_channel.SESSION_FILE,
            download_channel.API_ID,
            download_channel.API_HASH,
        )

        async with client:
            try:
                entity = await client.get_entity(download_channel.CHANNEL)
            except Exception as e:
                console.print(f"[red]Failed to resolve channel:[/red] {e}")
                status["status"] = "failed"
                status["error"] = str(e)
                return 1

            console.print(f"[green]Starting download from:[/green] {entity.title if hasattr(entity, 'title') else download_channel.CHANNEL}")

            try:
                scan_count = await _scan_channel(client, entity, state)
                console.print(f"[green]Channel scan complete:[/green] Found {scan_count} messages")
            except Exception as e:
                # Best effort: files recorded before the error are still downloaded.
                console.print(f"[red]Error during channel scan:[/red] {e}")
            status["message_count"] = len(state.files)

            # Entries still DOWNLOADING in a freshly loaded state were
            # presumably interrupted by an earlier run; resume them as well.
            pending_files = [
                f for f in state.files
                if f.status in (FileStatus.PENDING, FileStatus.DOWNLOADING)
            ]
            if not pending_files:
                console.print("[green]No new files to download![/green]")
                status["status"] = "completed"
                return 0

            console.print(f"[green]Starting downloads:[/green] {len(pending_files)} files pending")
            await _download_pending(client, entity, state, status, pending_files)

            state.current_download = None
            await _persist_state(state)

            console.print("[green]Download session completed![/green]")
            status["status"] = "completed"
            status["downloading"] = None

    except Exception as e:
        console.print(f"[red]Fatal error:[/red] {str(e)}")
        if status is not None:
            status["status"] = "failed"
            status["error"] = str(e)

    return 0
|
|
|
|
|
@app.on_event("startup")
async def start_initial_download():
    """Kick off a download for the configured channel when the server starts.

    Does nothing but log an error when ``download_channel.HF_TOKEN`` is unset.
    Otherwise it loads the dataset state, ensures ``OUTPUT_DIR`` exists and
    schedules ``run_download`` under the task id ``"initial_download"``.
    Setup errors are printed, not raised.
    """
    task_id = "initial_download"

    # Every download is mirrored to the HF dataset, so a token is mandatory.
    if not download_channel.HF_TOKEN:
        console.print("[red]ERROR: HF_TOKEN not set. Please set your Hugging Face token.[/red]")
        return

    console.print("[yellow]Checking Hugging Face dataset...[/yellow]")
    try:
        state = download_state_from_hf(download_channel.HF_TOKEN)
        console.print(f"[green]Using channel:[/green] {state.channel}")

        os.makedirs(download_channel.OUTPUT_DIR, exist_ok=True)

        # The event loop only keeps weak references to tasks; store a strong
        # reference on app.state so the download is not garbage-collected
        # mid-run.
        app.state.initial_download_task = asyncio.create_task(
            run_download(
                channel=None,
                message_limit=None,
                task_id=task_id,
            )
        )
        console.print(f"[green]Started initial download task:[/green] {task_id}")
    except Exception as e:
        console.print(f"[red]Failed to initialize:[/red] {str(e)}")
|
|
|
|
|
@app.post("/download", response_model=Dict[str, str])
async def start_download(request: DownloadRequest, background_tasks: BackgroundTasks):
    """Queue a download task and return its id immediately.

    A placeholder "pending" status is registered before returning.  This
    reserves the id: the next request sees a larger ``len(active_downloads)``
    and so gets a different id.  It also lets ``/status/{task_id}`` answer
    before the background task starts.  ``run_download`` later replaces the
    placeholder with its live status.
    """
    task_id = f"download_{len(active_downloads) + 1}"
    active_downloads[task_id] = {
        "channel": request.channel or str(download_channel.CHANNEL),
        "status": "pending",
        "message_count": 0,
        "downloaded": 0,
        "downloading": None,
        "error": None,
    }

    background_tasks.add_task(
        run_download,
        channel=request.channel,
        message_limit=request.message_limit,
        task_id=task_id,
    )

    return {"task_id": task_id}
|
|
|
|
|
@app.get("/status/{task_id}", response_model=DownloadStatus)
async def get_status(task_id: str):
    """Return the status entry recorded for *task_id*; respond 404 if it is unknown."""
    task_status = active_downloads.get(task_id)
    if task_status is None:
        raise HTTPException(status_code=404, detail="Task not found")
    return task_status
|
|
|
|
|
@app.get("/active", response_model=Dict[str, DownloadStatus])
async def list_active():
    """Return every registered task status keyed by task id, whatever its state."""
    snapshot = dict(active_downloads)
    return snapshot
|
|
|
|
|
if __name__ == "__main__":
    # Local entry point: serve the API on the loopback interface only.
    import uvicorn

    bind_host, bind_port = "127.0.0.1", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)