# Spaces:
# Sleeping
# Sleeping
import asyncio
import json
import os
import time
from enum import Enum
from typing import Any, Dict, List, Optional

from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from rich.console import Console
from rich.live import Live
from rich.progress import (
    BarColumn,
    DownloadColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
    TransferSpeedColumn,
)
from rich.table import Table

import download_channel
# Initialize rich console for pretty logging.
console = Console()

# FastAPI application; handlers defined below serve the downloader API.
app = FastAPI(title="Telegram Channel Downloader API")

# Track active downloads and their status, keyed by task_id.
# Values are plain dicts with the keys built in run_download
# ("channel", "status", "message_count", "downloaded", "downloading", "error").
active_downloads: Dict[str, Dict[str, Any]] = {}
class FileStatus(str, Enum):
    """Lifecycle of a channel file; the str mixin keeps values JSON-friendly."""
    PENDING = "pending"          # discovered during scan, not yet downloaded
    DOWNLOADING = "downloading"  # transfer in progress
    DOWNLOADED = "downloaded"    # fetched (and uploaded when HF_TOKEN is set)
    FAILED = "failed"            # download or upload failed; see ChannelFile.error
class ChannelFile(BaseModel):
    """Metadata for a single RAR attachment discovered in the channel."""
    message_id: int                       # Telegram message id the file is attached to
    filename: str                         # local file name (basename only)
    status: FileStatus                    # lifecycle stage, see FileStatus
    size: Optional[int] = None            # size in bytes, when Telegram reported one
    download_time: Optional[float] = None # seconds spent downloading
    error: Optional[str] = None           # failure reason when status == FAILED
    upload_path: Optional[str] = None     # path inside the HF dataset after upload
class DownloadState(BaseModel):
    """Persistent record of a scan/download session, mirrored to the HF dataset
    so an interrupted run can resume where it left off."""
    channel: str
    last_scanned_id: Optional[int] = None      # highest message id seen so far
    # default_factory gives each instance its own fresh list.
    files: List[ChannelFile] = Field(default_factory=list)
    current_download: Optional[int] = None  # message_id of current download
    # BUG FIX: `last_updated: float = time.time()` evaluated time.time() once
    # at import, so every new state shared the same stale timestamp;
    # default_factory re-evaluates the clock per instance.
    last_updated: float = Field(default_factory=time.time)
class DownloadRequest(BaseModel):
    """Body of a download-start request; None fields fall back to the
    module-level defaults configured in download_channel."""
    channel: Optional[str] = None        # channel override
    message_limit: Optional[int] = None  # max messages to scan
class DownloadStatus(BaseModel):
    """API-facing snapshot of a download task's progress.

    NOTE(review): run_download builds plain dicts with these keys rather than
    instantiating this model — confirm it is meant as documentation of that
    shape (or wire it in as the response model).
    """
    channel: str
    status: str                        # "running" | "completed" | "failed" (values set in run_download)
    message_count: int = 0             # files known in state
    downloaded: int = 0                # files completed this session + before
    downloading: Optional[str] = None  # filename currently being downloaded
    error: Optional[str] = None        # fatal error text, if any
def create_hf_dataset(token: str) -> bool:
    """Create the Hugging Face dataset repo if needed and seed it with an
    empty state file.

    Args:
        token: Hugging Face API token.

    Returns:
        True when the repo exists or was created (even if the seed-state
        upload itself failed — it will be re-uploaded later); False when repo
        creation failed or huggingface_hub is not installed.
    """
    try:
        # BUG FIX: only import what is used. The original also imported
        # RepoNotFoundError from the package root, which is absent in current
        # huggingface_hub releases, so the whole function died with an
        # ImportError before doing any work.
        from huggingface_hub import create_repo
        try:
            # Idempotent: exist_ok=True makes this a no-op for an existing repo.
            create_repo(
                repo_id=download_channel.HF_REPO_ID,
                token=token,
                repo_type="dataset",
                exist_ok=True
            )
            console.print(f"[green]Created or verified dataset:[/green] {download_channel.HF_REPO_ID}")
            # Seed the dataset with an empty state file for this channel.
            initial_state = DownloadState(channel=download_channel.CHANNEL)
            with open(download_channel.STATE_FILE, "w", encoding="utf-8") as f:
                json.dump(initial_state.dict(), f, indent=2, ensure_ascii=False)
            # Best-effort upload of the initial state; success is only logged.
            if download_channel.upload_file_to_hf(
                download_channel.STATE_FILE,
                download_channel.STATE_FILE,
                token
            ):
                console.print("[green]Initialized dataset with empty state file[/green]")
            return True
        except Exception as e:
            console.print(f"[red]Failed to create dataset:[/red] {str(e)}")
            return False
    except ImportError:
        console.print("[red]huggingface_hub not properly installed[/red]")
        return False
def download_state_from_hf(token: str) -> DownloadState:
    """Fetch the persisted download state from the HF dataset.

    Falls back to a brand-new state — attempting to create the dataset on the
    way — when the token is missing or the remote state cannot be fetched or
    parsed.
    """
    if not token:
        # No credentials: work purely from a fresh local state.
        return DownloadState(channel=download_channel.CHANNEL)

    try:
        # Pull the existing state file down and rebuild the model from it.
        cached_path = download_channel.hf_hub_download(
            repo_id=download_channel.HF_REPO_ID,
            filename=download_channel.STATE_FILE,
            repo_type="dataset",
            token=token
        )
        with open(cached_path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        return DownloadState(**payload)
    except Exception as exc:
        # Missing repo, missing file, or corrupt JSON all land here.
        console.print(f"[yellow]No existing state found, creating new dataset:[/yellow] {str(exc)}")
        if create_hf_dataset(token):
            console.print("[green]Dataset created successfully![/green]")
        else:
            console.print("[red]Failed to create dataset, using local state only[/red]")
        return DownloadState(channel=download_channel.CHANNEL)
async def clean_downloaded_file(file_path: str):
    """Best-effort removal of a local file once it has been uploaded.

    Failures are logged as warnings, never raised.
    """
    try:
        os.remove(file_path)
    except Exception as exc:
        console.print(f"[yellow]Warning:[/yellow] Could not clean up {file_path}: {exc}")
    else:
        console.print(f"[blue]Cleaned up:[/blue] {os.path.basename(file_path)}")
async def update_and_upload_state(state: DownloadState, token: str) -> bool:
    """Refresh the state timestamp, persist it locally, then push it to HF.

    Returns:
        True only when the dataset upload succeeded; False on any error
        (which is logged, never raised).
    """
    state.last_updated = time.time()
    state_path = download_channel.STATE_FILE
    try:
        # Local write first so there is always an on-disk copy.
        with open(state_path, "w", encoding="utf-8") as handle:
            json.dump(state.dict(), handle, indent=2, ensure_ascii=False)
        # Mirror to the dataset under the same relative path.
        return download_channel.upload_file_to_hf(state_path, state_path, token)
    except Exception as exc:
        console.print(f"[red]Failed to update state:[/red] {exc}")
        return False
async def process_message(message, state: DownloadState, client) -> Optional[str]:
    """Decide whether *message* carries a RAR attachment worth downloading.

    Args:
        message: Telethon message object.
        state: current download state (unused here; kept for interface parity).
        client: Telegram client (unused here; kept for interface parity).

    Returns:
        The local output path the file should be saved to, or None when the
        message has no media or is not a RAR archive.
    """
    if not message.media:
        return None

    # Detect RAR either by filename extension or, failing that, by MIME type.
    is_rar = False
    filename = ""
    if message.file:
        filename = getattr(message.file, 'name', '') or ''
        if filename:
            is_rar = filename.lower().endswith('.rar')
        else:
            mime_type = getattr(message.file, 'mime_type', '') or ''
            is_rar = 'rar' in mime_type.lower() if mime_type else False
    if not is_rar:
        return None

    # BUG FIX: the known-filename branch was garbled ("{message.id}_(unknown)"),
    # discarding the real name exactly when it was available. Keep the original
    # filename, prefixed with the message id for uniqueness; fall back to
    # "<id>.rar" only when no name was reported.
    if filename:
        suggested = f"{message.id}_{filename}"
    else:
        suggested = f"{message.id}.rar"
    return os.path.join(download_channel.OUTPUT_DIR, suggested)
async def run_download(channel: Optional[str], message_limit: Optional[int], task_id: str):
    """Background task: scan the channel, record new RAR files in the shared
    state, then download each pending file and mirror it to the HF dataset.

    Args:
        channel: channel override; falls back to download_channel.CHANNEL.
        message_limit: max messages to scan; falls back to
            download_channel.MESSAGE_LIMIT.
        task_id: key under which progress is published in active_downloads.
    """
    try:
        # Override module-level config if explicit values were provided.
        if channel:
            download_channel.CHANNEL = channel
        if message_limit is not None:
            download_channel.MESSAGE_LIMIT = message_limit

        # Get or create download state (resumes a previous session if any).
        state = download_state_from_hf(download_channel.HF_TOKEN)

        # Initialize status for API consumers polling active_downloads.
        status = {
            "channel": state.channel,
            "status": "running",
            "message_count": len(state.files),
            "downloaded": len([f for f in state.files if f.status == FileStatus.DOWNLOADED]),
            "downloading": None,
            "error": None
        }
        active_downloads[task_id] = status

        # Per-file progress bar (spinner, percentage, bytes, speed, elapsed).
        progress = Progress(
            SpinnerColumn(),
            TextColumn("[bold blue]{task.fields[filename]}", justify="right"),
            BarColumn(bar_width=40),
            "[progress.percentage]{task.percentage:>3.1f}%",
            "•",
            DownloadColumn(),
            "•",
            TransferSpeedColumn(),
            "•",
            TimeElapsedColumn(),
        )
        # Session-wide progress bar across all pending files.
        overall_progress = Progress(
            TextColumn("[bold yellow]{task.description}", justify="right"),
            BarColumn(bar_width=40),
            "[progress.percentage]{task.percentage:>3.1f}%",
            "•",
            TextColumn("[bold green]{task.fields[stats]}")
        )

        # Initialize the Telegram client from download_channel configuration.
        client = download_channel.TelegramClient(
            download_channel.SESSION_FILE,
            download_channel.API_ID,
            download_channel.API_HASH
        )
        async with client:
            try:
                entity = await client.get_entity(download_channel.CHANNEL)
            except Exception as e:
                console.print(f"[red]Failed to resolve channel:[/red] {e}")
                # NOTE(review): this path returns 1 (others return 0) but the
                # value is unused by callers, and status stays "running" —
                # confirm whether it should be flipped to "failed" here.
                return 1
            console.print(f"[green]Starting download from:[/green] {entity.title if hasattr(entity, 'title') else download_channel.CHANNEL}")

            # Phase 1: scan for new messages and register RAR files as PENDING.
            scan_count = 0
            last_message_id = state.last_scanned_id
            try:
                async for message in client.iter_messages(entity, limit=download_channel.MESSAGE_LIMIT or None):
                    scan_count += 1
                    # Track the highest message id seen so far.
                    if last_message_id is None or message.id > last_message_id:
                        last_message_id = message.id
                    # Skip if we already know about this message.
                    if any(f.message_id == message.id for f in state.files):
                        continue
                    # Check if it's a downloadable (RAR) file.
                    out_path = await process_message(message, state, client)
                    if out_path:
                        # Add to state as pending.
                        file_info = ChannelFile(
                            message_id=message.id,
                            filename=os.path.basename(out_path),
                            status=FileStatus.PENDING,
                            size=getattr(message.media, 'size', 0) or 0
                        )
                        state.files.append(file_info)
                # Persist scan results before starting any downloads.
                state.last_scanned_id = last_message_id
                if download_channel.HF_TOKEN:
                    await update_and_upload_state(state, download_channel.HF_TOKEN)
                console.print(f"[green]Channel scan complete:[/green] Found {scan_count} messages")
            except Exception as e:
                # Best-effort: a failed scan still lets previously-known
                # pending files download below.
                console.print(f"[red]Error during channel scan:[/red] {e}")

            # Phase 2: process pending downloads.
            pending_files = [f for f in state.files if f.status == FileStatus.PENDING]
            total_pending = len(pending_files)
            if total_pending == 0:
                console.print("[green]No new files to download![/green]")
                return 0
            console.print(f"[green]Starting downloads:[/green] {total_pending} files pending")

            # NOTE(review): rich normally allows only ONE active Live display
            # at a time; two simultaneous Live contexts likely raise LiveError
            # — confirm against the rich version in use.
            with Live(progress) as live_progress, Live(overall_progress) as live_overall:
                overall_task = overall_progress.add_task(
                    f"Channel: {download_channel.CHANNEL}",
                    total=total_pending,
                    stats=f"Pending: {total_pending}"
                )
                for file_info in pending_files:
                    try:
                        # Mark as downloading and persist, so a crash leaves a
                        # resumable marker in the dataset.
                        file_info.status = FileStatus.DOWNLOADING
                        state.current_download = file_info.message_id
                        if download_channel.HF_TOKEN:
                            await update_and_upload_state(state, download_channel.HF_TOKEN)
                        # Surface the current filename to API pollers.
                        status["downloading"] = file_info.filename
                        # Re-fetch the message to get a live media reference.
                        message = await client.get_messages(entity, ids=file_info.message_id)
                        if not message or not message.media:
                            file_info.status = FileStatus.FAILED
                            file_info.error = "Message not found or no media"
                            continue
                        out_path = os.path.join(download_channel.OUTPUT_DIR, file_info.filename)
                        # total falls back to 100 when the size is unknown.
                        file_task = progress.add_task(
                            "download",
                            total=file_info.size or 100,
                            filename=file_info.filename
                        )
                        # Download with progress reporting.
                        start_time = time.time()
                        try:
                            async def progress_callback(current, total):
                                # Telethon calls this with (bytes_done, bytes_total).
                                progress.update(file_task, completed=current)
                                overall_stats = f"Downloaded: {len([f for f in state.files if f.status == FileStatus.DOWNLOADED])}"
                                overall_progress.update(overall_task, completed=current/total*100, stats=overall_stats)
                            await client.download_media(
                                message,
                                file=out_path,
                                progress_callback=progress_callback
                            )
                            # Mirror into the HF dataset, then delete the local
                            # copy to save disk space.
                            if download_channel.HF_TOKEN:
                                console.print(f"[yellow]Uploading to HF:[/yellow] {file_info.filename}")
                                path_in_repo = f"files/{file_info.filename}"
                                ok = download_channel.upload_file_to_hf(
                                    out_path,
                                    path_in_repo,
                                    download_channel.HF_TOKEN
                                )
                                if ok:
                                    console.print(f"[green]Uploaded:[/green] {file_info.filename}")
                                    # Clean up local file.
                                    await clean_downloaded_file(out_path)
                                    file_info.upload_path = path_in_repo
                                else:
                                    console.print(f"[red]Upload failed:[/red] {file_info.filename}")
                                    file_info.error = "Upload to dataset failed"
                                    file_info.status = FileStatus.FAILED
                                    continue
                            # Mark as completed in state.
                            file_info.status = FileStatus.DOWNLOADED
                            file_info.download_time = time.time() - start_time
                            # Persist progress after each file.
                            if download_channel.HF_TOKEN:
                                await update_and_upload_state(state, download_channel.HF_TOKEN)
                            # Update API-facing counters.
                            status["downloaded"] += 1
                            await asyncio.sleep(0.2)  # Be polite to Telegram's API
                        except download_channel.errors.FloodWaitError as fw:
                            wait = int(fw.seconds) if fw.seconds else 60
                            console.print(f"[yellow]FloodWait:[/yellow] Sleeping {wait}s")
                            await asyncio.sleep(wait + 1)
                            # NOTE(review): despite the original "Retry this
                            # file" comment, `continue` advances to the NEXT
                            # pending file — this one stays DOWNLOADING and is
                            # skipped for the session; confirm intent.
                            continue
                        except Exception as e:
                            console.print(f"[red]Error:[/red] {str(e)}")
                            file_info.status = FileStatus.FAILED
                            file_info.error = str(e)
                            if download_channel.HF_TOKEN:
                                await update_and_upload_state(state, download_channel.HF_TOKEN)
                    except Exception as e:
                        # Catch-all so one bad file never aborts the session.
                        console.print(f"[red]Fatal error processing {file_info.filename}:[/red] {str(e)}")
                        continue

            # Clear the in-progress marker and persist the final state.
            state.current_download = None
            if download_channel.HF_TOKEN:
                await update_and_upload_state(state, download_channel.HF_TOKEN)
            console.print("[green]Download session completed![/green]")
            status["status"] = "completed"
            status["downloading"] = None
    except Exception as e:
        console.print(f"[red]Fatal error:[/red] {str(e)}")
        # status may not exist yet if the failure happened before it was built.
        if "status" in locals():
            status["status"] = "failed"
            status["error"] = str(e)
    return 0
@app.on_event("startup")
async def start_initial_download():
    """Kick off the download pipeline automatically when the server starts.

    BUG FIX: this coroutine was never invoked anywhere in the file despite
    its "when the server starts" contract; registering it as a FastAPI
    startup handler makes the auto-start actually happen.
    """
    task_id = "initial_download"
    # Refuse to start without credentials: state and file uploads need them.
    if not download_channel.HF_TOKEN:
        console.print("[red]ERROR: HF_TOKEN not set. Please set your Hugging Face token.[/red]")
        return
    # Create dataset structure if needed.
    console.print("[yellow]Checking Hugging Face dataset...[/yellow]")
    try:
        state = download_state_from_hf(download_channel.HF_TOKEN)
        console.print(f"[green]Using channel:[/green] {state.channel}")
        # Ensure the local download directory exists before any media download.
        os.makedirs(download_channel.OUTPUT_DIR, exist_ok=True)
        # Fire-and-forget: the download runs concurrently with the API server.
        asyncio.create_task(run_download(
            channel=None,  # Use default from download_channel.py
            message_limit=None,  # Use default
            task_id=task_id
        ))
        console.print(f"[green]Started initial download task:[/green] {task_id}")
    except Exception as e:
        console.print(f"[red]Failed to initialize:[/red] {str(e)}")
@app.post("/download")
async def start_download(request: DownloadRequest, background_tasks: BackgroundTasks):
    """Start a new download task in the background.

    BUG FIX: the handler was defined but never routed; without a decorator
    the "Telegram Channel Downloader API" exposed no way to start a download.

    NOTE(review): task ids derived from len(active_downloads) can collide if
    entries are ever removed and are racy under concurrent requests — confirm
    whether ids must be globally unique.

    Returns:
        {"task_id": <id>} for polling via the status endpoint.
    """
    task_id = f"download_{len(active_downloads) + 1}"
    background_tasks.add_task(
        run_download,
        channel=request.channel,
        message_limit=request.message_limit,
        task_id=task_id
    )
    return {"task_id": task_id}
@app.get("/status/{task_id}")
async def get_status(task_id: str):
    """Return the status dict of a download task, or 404 when unknown.

    BUG FIX: added the missing route decorator — the handler was defined
    but never reachable through the FastAPI app.
    """
    if task_id not in active_downloads:
        raise HTTPException(status_code=404, detail="Task not found")
    return active_downloads[task_id]
@app.get("/downloads")
async def list_active():
    """List every known download task (running, completed, or failed).

    BUG FIX: added the missing route decorator — the handler was defined
    but never reachable through the FastAPI app.
    """
    return active_downloads
if __name__ == "__main__":
    # Local entry point: serve the API with uvicorn.
    import uvicorn

    # NOTE(review): binds to loopback only; container/Spaces deployments
    # usually need host="0.0.0.0" — confirm the target environment.
    uvicorn.run(app, host="127.0.0.1", port=8000)