Fred808 committed on
Commit
8acc639
·
verified ·
1 Parent(s): 01870a8

Upload 2 files

Browse files
Files changed (2) hide show
  1. api_server.py +398 -192
  2. requirements.txt +1 -0
api_server.py CHANGED
@@ -1,193 +1,399 @@
1
- from fastapi import FastAPI, HTTPException, BackgroundTasks
2
- from fastapi.responses import JSONResponse
3
- import asyncio
4
- import os
5
- from typing import Optional, Dict, Any
6
- from pydantic import BaseModel
7
- import download_channel # Import the existing downloader
8
-
9
- app = FastAPI(title="Telegram Channel Downloader API")
10
-
11
- # Track active downloads and their status
12
- active_downloads: Dict[str, Dict[str, Any]] = {}
13
-
14
- @app.on_event("startup")
15
- async def start_initial_download():
16
- """Start the download process automatically when the server starts"""
17
- task_id = "initial_download"
18
- # Start the download process with default settings
19
- asyncio.create_task(run_download(
20
- channel=None, # Use default from download_channel.py
21
- message_limit=None, # Use default
22
- task_id=task_id
23
- ))
24
- print(f"Started initial download task with ID: {task_id}")
25
-
26
- class DownloadRequest(BaseModel):
27
- channel: Optional[str] = None # If None, uses default from download_channel.py
28
- message_limit: Optional[int] = None
29
-
30
- class DownloadStatus(BaseModel):
31
- channel: str
32
- status: str # "running", "completed", "failed"
33
- message_count: int = 0
34
- downloaded: int = 0
35
- skipped: int = 0
36
- not_rar: int = 0
37
- error: Optional[str] = None
38
-
39
- async def run_download(channel: Optional[str], message_limit: Optional[int], task_id: str):
40
- """Background task to run the download"""
41
- try:
42
- # Override channel and message limit if provided
43
- if channel:
44
- download_channel.CHANNEL = channel
45
- if message_limit is not None:
46
- download_channel.MESSAGE_LIMIT = message_limit
47
-
48
- # Create a status tracker
49
- status = {
50
- "channel": download_channel.CHANNEL,
51
- "status": "running",
52
- "message_count": 0,
53
- "downloaded": 0,
54
- "skipped": 0,
55
- "not_rar": 0,
56
- "error": None
57
- }
58
- active_downloads[task_id] = status
59
-
60
- # Patch the download function to update our status
61
- original_download = download_channel.download_channel
62
-
63
- async def wrapped_download():
64
- nonlocal status
65
- try:
66
- # Initialize client and get entity
67
- client = download_channel.TelegramClient(
68
- download_channel.SESSION_FILE,
69
- download_channel.API_ID,
70
- download_channel.API_HASH
71
- )
72
-
73
- async with client:
74
- try:
75
- entity = await client.get_entity(download_channel.CHANNEL)
76
- except Exception as e:
77
- status["error"] = f"Failed to resolve channel: {str(e)}"
78
- status["status"] = "failed"
79
- return 1
80
-
81
- try:
82
- async for message in client.iter_messages(entity, limit=download_channel.MESSAGE_LIMIT or None):
83
- status["message_count"] += 1
84
-
85
- if not message.media:
86
- continue
87
-
88
- # Check if it's a RAR file
89
- is_rar = False
90
- if message.file:
91
- filename = getattr(message.file, 'name', '') or ''
92
- if filename:
93
- is_rar = filename.lower().endswith('.rar')
94
- else:
95
- mime_type = getattr(message.file, 'mime_type', '') or ''
96
- is_rar = 'rar' in mime_type.lower() if mime_type else False
97
-
98
- if not is_rar:
99
- status["not_rar"] += 1
100
- continue
101
-
102
- # Use the same filename logic
103
- if filename:
104
- suggested = f"{message.id}_{filename}"
105
- else:
106
- suggested = f"{message.id}.rar"
107
-
108
- out_path = os.path.join(download_channel.OUTPUT_DIR, suggested)
109
-
110
- if os.path.exists(out_path):
111
- status["skipped"] += 1
112
- continue
113
-
114
- try:
115
- await client.download_media(message, file=out_path)
116
- status["downloaded"] += 1
117
-
118
- # Upload to HF if token available
119
- if download_channel.HF_TOKEN:
120
- path_in_repo = f"files/{os.path.basename(out_path)}"
121
- ok = download_channel.upload_file_to_hf(
122
- out_path, path_in_repo, download_channel.HF_TOKEN
123
- )
124
- if not ok:
125
- print(f"Warning: failed to upload {out_path}")
126
-
127
- await asyncio.sleep(0.2) # Be polite
128
-
129
- except download_channel.errors.FloodWaitError as fw:
130
- wait = int(fw.seconds) if fw.seconds else 60
131
- print(f"Hit FloodWait: sleeping {wait}s")
132
- await asyncio.sleep(wait + 1)
133
- except Exception as e:
134
- print(f"Error downloading {message.id}: {e}")
135
-
136
- except Exception as e:
137
- status["error"] = str(e)
138
- status["status"] = "failed"
139
- return 1
140
-
141
- status["status"] = "completed"
142
- return 0
143
-
144
- except Exception as e:
145
- status["error"] = str(e)
146
- status["status"] = "failed"
147
- return 1
148
-
149
- await wrapped_download()
150
-
151
- except Exception as e:
152
- active_downloads[task_id] = {
153
- "channel": download_channel.CHANNEL,
154
- "status": "failed",
155
- "message_count": 0,
156
- "downloaded": 0,
157
- "skipped": 0,
158
- "not_rar": 0,
159
- "error": str(e)
160
- }
161
-
162
- @app.post("/download", response_model=Dict[str, str])
163
- async def start_download(request: DownloadRequest, background_tasks: BackgroundTasks):
164
- """Start a new download task"""
165
- task_id = f"download_{len(active_downloads) + 1}"
166
-
167
- # Schedule the download
168
- background_tasks.add_task(
169
- run_download,
170
- channel=request.channel,
171
- message_limit=request.message_limit,
172
- task_id=task_id
173
- )
174
-
175
- return {"task_id": task_id}
176
-
177
- @app.get("/status/{task_id}", response_model=DownloadStatus)
178
- async def get_status(task_id: str):
179
- """Get the status of a download task"""
180
- if task_id not in active_downloads:
181
- raise HTTPException(status_code=404, detail="Task not found")
182
- return active_downloads[task_id]
183
-
184
- @app.get("/active", response_model=Dict[str, DownloadStatus])
185
- async def list_active():
186
- """List all active or completed downloads"""
187
- return active_downloads
188
-
189
- if __name__ == "__main__":
190
- import uvicorn
191
- # Note: When running directly, this runs on 8000
192
- # For production, use: uvicorn api_server:app --host 0.0.0.0 --port 8000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  uvicorn.run(app, host="127.0.0.1", port=8000)
 
1
import asyncio
import json
import os
import time
from enum import Enum
from typing import Any, Dict, List, Optional

from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from rich.console import Console, Group
from rich.live import Live
from rich.progress import (
    BarColumn,
    DownloadColumn,
    Progress,
    SpinnerColumn,
    TextColumn,
    TimeElapsedColumn,
    TransferSpeedColumn,
)
from rich.table import Table

import download_channel
22
+
23
# Shared rich console; the log lines of this module are printed through it.
console = Console()

app = FastAPI(title="Telegram Channel Downloader API")

# Status dict of every download task seen by this process, keyed by task ID.
active_downloads: Dict[str, Dict[str, Any]] = {}
30
+
31
class FileStatus(str, Enum):
    """Lifecycle stage of a single channel file.

    The ``str`` mix-in makes every member compare equal to its plain string
    value, so members can be written to and read back from JSON as strings.
    """

    PENDING = "pending"          # discovered during a scan, not fetched yet
    DOWNLOADING = "downloading"  # transfer currently in progress
    DOWNLOADED = "downloaded"    # transfer finished successfully
    FAILED = "failed"            # transfer or upload gave up with an error
36
+
37
class ChannelFile(BaseModel):
    """One downloadable file found in the channel, plus its transfer outcome."""

    # ID of the Telegram message that carries the file.
    message_id: int
    # Local file name (base name only, no directory).
    filename: str
    # Current lifecycle stage of the file.
    status: FileStatus
    # Size recorded at scan time, when one was available.
    size: Optional[int] = None
    # Seconds measured from download start until the file was marked downloaded.
    download_time: Optional[float] = None
    # Failure description, set when the file ends up FAILED.
    error: Optional[str] = None
    # Path inside the dataset repo after a successful upload.
    upload_path: Optional[str] = None
45
+
46
class DownloadState(BaseModel):
    """Persistent record of what has been scanned and downloaded for a channel.

    Instances are serialised to the state file and restored on the next run.
    """

    # Channel this state belongs to.
    channel: str
    # Highest message ID seen by a scan so far; None before the first scan.
    last_scanned_id: Optional[int] = None
    # Every downloadable file discovered so far. A factory is used so the
    # default is built per instance rather than shared.
    files: List[ChannelFile] = Field(default_factory=list)
    # message_id of the file being downloaded right now, if any.
    current_download: Optional[int] = None
    # Unix timestamp of the last state change. A plain `= time.time()` default
    # would be evaluated once at import time and reused for every instance.
    last_updated: float = Field(default_factory=time.time)
52
+
53
class DownloadRequest(BaseModel):
    """Request body of ``POST /download``; both fields are optional overrides."""

    # Channel to download from; when omitted the download_channel default is kept.
    channel: Optional[str] = None
    # Maximum number of messages to scan; when omitted the default is kept.
    message_limit: Optional[int] = None
56
+
57
class DownloadStatus(BaseModel):
    """Response schema describing the progress of one download task."""

    # Channel the task works on.
    channel: str
    # Free-form task state string, e.g. "running", "completed" or "failed".
    status: str
    # Number of files known to the task's state.
    message_count: int = 0
    # Number of files downloaded so far.
    downloaded: int = 0
    # File name currently being transferred, if any.
    downloading: Optional[str] = None
    # Error description when the task failed.
    error: Optional[str] = None
64
+
65
def download_state_from_hf(token: str) -> DownloadState:
    """Load the persisted download state from the HF dataset, or start a new one.

    Args:
        token: Hugging Face access token. When empty, nothing is fetched.

    Returns:
        The ``DownloadState`` parsed from the dataset's state file. A fresh,
        empty state for ``download_channel.CHANNEL`` is returned instead when
        there is no token, the file cannot be fetched, or it cannot be parsed.
    """
    if not token:
        return DownloadState(channel=download_channel.CHANNEL)

    try:
        local_path = download_channel.hf_hub_download(
            repo_id=download_channel.HF_REPO_ID,
            filename=download_channel.STATE_FILE,
            repo_type="dataset",
            token=token,
        )
    except Exception as e:
        # Best effort: a missing file (first run) or a fetch error both mean
        # we start from scratch rather than abort the download task.
        console.print(f"[yellow]No existing state found, creating new:[/yellow] {str(e)}")
        return DownloadState(channel=download_channel.CHANNEL)

    try:
        with open(local_path, "r", encoding="utf-8") as f:
            return DownloadState(**json.load(f))
    except Exception as e:
        # A state file that exists but is unreadable is reported separately so
        # it is not mistaken for "no state yet".
        console.print(f"[red]Existing state file is unreadable, creating new:[/red] {e}")
        return DownloadState(channel=download_channel.CHANNEL)
83
+
84
async def clean_downloaded_file(file_path: str):
    """Delete a local file once it is no longer needed.

    Failures are reported on the console and otherwise ignored.
    """
    try:
        os.remove(file_path)
    except Exception as exc:
        console.print(f"[yellow]Warning:[/yellow] Could not clean up {file_path}: {exc}")
        return
    try:
        console.print(f"[blue]Cleaned up:[/blue] {os.path.basename(file_path)}")
    except Exception as exc:
        console.print(f"[yellow]Warning:[/yellow] Could not clean up {file_path}: {exc}")
91
+
92
async def update_and_upload_state(state: DownloadState, token: str) -> bool:
    """Stamp the state with the current time, save it locally and upload it.

    Args:
        state: State object to persist; its ``last_updated`` field is overwritten.
        token: Hugging Face token passed through to the upload helper.

    Returns:
        True when the upload helper reports success, False when it reports
        failure or when saving/uploading raised an exception.
    """
    state.last_updated = time.time()
    try:
        # `model_dump` is the pydantic v2 spelling (`.dict()` is deprecated
        # there); fall back to `.dict()` so a v1 install keeps working.
        if hasattr(state, "model_dump"):
            payload = state.model_dump(mode="json")
        else:
            payload = state.dict()

        # Save state locally first, then push that file to the dataset.
        with open(download_channel.STATE_FILE, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)

        return bool(
            download_channel.upload_file_to_hf(
                download_channel.STATE_FILE,
                download_channel.STATE_FILE,
                token,
            )
        )
    except Exception as e:
        console.print(f"[red]Failed to update state:[/red] {e}")
        return False
108
+
109
async def process_message(message, state: DownloadState, client) -> Optional[str]:
    """Decide whether a message carries a RAR file and where it would be saved.

    ``state`` and ``client`` are accepted but not used here. Nothing is
    downloaded by this function.

    Returns:
        The target path inside ``download_channel.OUTPUT_DIR`` when the message
        has media that looks like a RAR archive, otherwise None.
    """
    if not message.media:
        return None

    filename = ""
    looks_like_rar = False
    if message.file:
        filename = getattr(message.file, 'name', '') or ''
        if filename:
            # A named file is judged by its extension only.
            looks_like_rar = filename.lower().endswith('.rar')
        else:
            # Unnamed files fall back to the MIME type.
            mime = getattr(message.file, 'mime_type', '') or ''
            looks_like_rar = bool(mime) and 'rar' in mime.lower()

    if not looks_like_rar:
        return None

    # Prefix with the message ID; without a name, use the ID alone.
    saved_name = f"{message.id}_{filename}" if filename else f"{message.id}.rar"
    return os.path.join(download_channel.OUTPUT_DIR, saved_name)
135
+
136
# Attempts per file before a Telegram FloodWait is recorded as a failure.
_MAX_FLOOD_RETRIES = 3


def _count_downloaded(state: DownloadState) -> int:
    """Return how many files in the state are marked DOWNLOADED."""
    return sum(1 for f in state.files if f.status == FileStatus.DOWNLOADED)


async def _persist_state(state: DownloadState) -> None:
    """Upload the state file when an HF token is configured; no-op otherwise."""
    if download_channel.HF_TOKEN:
        await update_and_upload_state(state, download_channel.HF_TOKEN)


async def _scan_channel(client, entity, state: DownloadState) -> int:
    """Add every not-yet-known RAR message to the state as PENDING; return messages seen."""
    known_ids = {f.message_id for f in state.files}
    newest_id = state.last_scanned_id
    scanned = 0

    async for message in client.iter_messages(
        entity, limit=download_channel.MESSAGE_LIMIT or None
    ):
        scanned += 1
        if newest_id is None or message.id > newest_id:
            newest_id = message.id

        # Skip messages that are already tracked in the state.
        if message.id in known_ids:
            continue

        out_path = await process_message(message, state, client)
        if not out_path:
            continue

        # Prefer the size on message.file; message.media is kept as the
        # original fallback. NOTE(review): presumably Telethon exposes the byte
        # size on message.file rather than on message.media - confirm.
        size = (
            getattr(message.file, "size", None)
            or getattr(message.media, "size", 0)
            or 0
        )
        state.files.append(
            ChannelFile(
                message_id=message.id,
                filename=os.path.basename(out_path),
                status=FileStatus.PENDING,
                size=size,
            )
        )
        known_ids.add(message.id)

    state.last_scanned_id = newest_id
    await _persist_state(state)
    return scanned


async def _download_file(client, entity, state: DownloadState, file_info: ChannelFile, progress) -> bool:
    """Download one file, upload it when a token is set, and record the outcome.

    Returns True when the file ended up DOWNLOADED, False when it was marked
    FAILED. Exceptions other than a retried FloodWait propagate to the caller.
    """
    file_info.status = FileStatus.DOWNLOADING
    state.current_download = file_info.message_id
    await _persist_state(state)

    message = await client.get_messages(entity, ids=file_info.message_id)
    if not message or not message.media:
        file_info.status = FileStatus.FAILED
        file_info.error = "Message not found or no media"
        return False

    out_path = os.path.join(download_channel.OUTPUT_DIR, file_info.filename)
    file_task = progress.add_task(
        "download",
        total=file_info.size or 100,
        filename=file_info.filename,
    )

    def progress_callback(current, total):
        # Take the total reported by the transfer when there is one, so the
        # bar is right even if no size was recorded at scan time.
        progress.update(file_task, completed=current, total=total or None)

    start_time = time.time()
    try:
        for attempt in range(1, _MAX_FLOOD_RETRIES + 1):
            try:
                await client.download_media(
                    message,
                    file=out_path,
                    progress_callback=progress_callback,
                )
                break
            except download_channel.errors.FloodWaitError as fw:
                if attempt == _MAX_FLOOD_RETRIES:
                    raise
                wait = int(fw.seconds) if fw.seconds else 60
                console.print(f"[yellow]FloodWait:[/yellow] Sleeping {wait}s")
                await asyncio.sleep(wait + 1)
                # Loop around to retry this same file.
    finally:
        # Drop the per-file bar so finished files do not pile up in the display.
        progress.remove_task(file_task)

    if download_channel.HF_TOKEN:
        console.print(f"[yellow]Uploading to HF:[/yellow] {file_info.filename}")
        path_in_repo = f"files/{file_info.filename}"
        # NOTE(review): this helper is called without await, so presumably it
        # is synchronous and blocks the event loop while uploading - confirm.
        ok = download_channel.upload_file_to_hf(
            out_path,
            path_in_repo,
            download_channel.HF_TOKEN,
        )
        if not ok:
            console.print(f"[red]Upload failed:[/red] {file_info.filename}")
            file_info.error = "Upload to dataset failed"
            file_info.status = FileStatus.FAILED
            return False
        console.print(f"[green]Uploaded:[/green] {file_info.filename}")
        # The local copy is only removed after a successful upload.
        await clean_downloaded_file(out_path)
        file_info.upload_path = path_in_repo

    file_info.status = FileStatus.DOWNLOADED
    file_info.download_time = time.time() - start_time
    return True


async def run_download(channel: Optional[str], message_limit: Optional[int], task_id: str) -> int:
    """Background task: scan the channel, then download every pending RAR file.

    The task's status dict is registered in ``active_downloads`` before any
    work starts and is updated in place, so the status endpoints always see
    the task and its final outcome.

    Args:
        channel: Channel override; when falsy ``download_channel.CHANNEL`` is kept.
        message_limit: Scan limit override; when None the module default is kept.
        task_id: Key under which the status is stored in ``active_downloads``.

    Returns:
        0 when the session completed (including "nothing to download"),
        1 when the channel could not be resolved or a fatal error occurred.
    """
    # Override the module-level defaults when the caller supplied values.
    if channel:
        download_channel.CHANNEL = channel
    if message_limit is not None:
        download_channel.MESSAGE_LIMIT = message_limit

    # Register the status first so that even an early failure is visible.
    status = {
        "channel": str(download_channel.CHANNEL),
        "status": "running",
        "message_count": 0,
        "downloaded": 0,
        "downloading": None,
        "error": None,
    }
    active_downloads[task_id] = status

    try:
        # Get or create download state
        state = download_state_from_hf(download_channel.HF_TOKEN)
        status["channel"] = state.channel
        status["message_count"] = len(state.files)
        status["downloaded"] = _count_downloaded(state)

        # Per-file progress bars.
        progress = Progress(
            SpinnerColumn(),
            TextColumn("[bold blue]{task.fields[filename]}", justify="right"),
            BarColumn(bar_width=40),
            "[progress.percentage]{task.percentage:>3.1f}%",
            "",
            DownloadColumn(),
            "•",
            TransferSpeedColumn(),
            "•",
            TimeElapsedColumn(),
        )

        # One bar counting finished files out of all pending files.
        overall_progress = Progress(
            TextColumn("[bold yellow]{task.description}", justify="right"),
            BarColumn(bar_width=40),
            "[progress.percentage]{task.percentage:>3.1f}%",
            "",
            TextColumn("[bold green]{task.fields[stats]}"),
        )

        client = download_channel.TelegramClient(
            download_channel.SESSION_FILE,
            download_channel.API_ID,
            download_channel.API_HASH,
        )

        async with client:
            try:
                entity = await client.get_entity(download_channel.CHANNEL)
            except Exception as e:
                console.print(f"[red]Failed to resolve channel:[/red] {e}")
                status["status"] = "failed"
                status["error"] = f"Failed to resolve channel: {str(e)}"
                return 1

            console.print(f"[green]Starting download from:[/green] {entity.title if hasattr(entity, 'title') else download_channel.CHANNEL}")

            # First, scan for new messages and update state. A scan error is
            # reported but does not stop already-known pending files.
            try:
                scan_count = await _scan_channel(client, entity, state)
                console.print(f"[green]Channel scan complete:[/green] Found {scan_count} messages")
            except Exception as e:
                console.print(f"[red]Error during channel scan:[/red] {e}")
            status["message_count"] = len(state.files)

            # Now process pending downloads
            pending_files = [f for f in state.files if f.status == FileStatus.PENDING]
            total_pending = len(pending_files)

            if total_pending == 0:
                console.print("[green]No new files to download![/green]")
                status["status"] = "completed"
                return 0

            console.print(f"[green]Starting downloads:[/green] {total_pending} files pending")

            # Only one Live display may be active at a time on rich 13.x, so
            # both progress tables are grouped into a single one that shares
            # the logging console.
            with Live(Group(progress, overall_progress), console=console):
                overall_task = overall_progress.add_task(
                    f"Channel: {download_channel.CHANNEL}",
                    total=total_pending,
                    stats=f"Pending: {total_pending}",
                )

                for file_info in pending_files:
                    status["downloading"] = file_info.filename
                    try:
                        succeeded = await _download_file(
                            client, entity, state, file_info, progress
                        )
                    except Exception as e:
                        console.print(f"[red]Error:[/red] {str(e)}")
                        file_info.status = FileStatus.FAILED
                        file_info.error = str(e)
                        succeeded = False

                    # Clear current download and persist the outcome.
                    state.current_download = None
                    await _persist_state(state)

                    if succeeded:
                        status["downloaded"] += 1
                    overall_progress.update(
                        overall_task,
                        advance=1,
                        stats=f"Downloaded: {_count_downloaded(state)}",
                    )
                    await asyncio.sleep(0.2)  # Be polite

            console.print("[green]Download session completed![/green]")
            status["status"] = "completed"
            status["downloading"] = None
            return 0

    except Exception as e:
        console.print(f"[red]Fatal error:[/red] {str(e)}")
        status["status"] = "failed"
        status["error"] = str(e)
        status["downloading"] = None
        return 1
358
+
359
# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references, so a task nobody else holds could be garbage-collected mid-run.
_background_tasks: set = set()


@app.on_event("startup")
async def start_initial_download():
    """Start the download process automatically when the server starts.

    The download runs as a background asyncio task under the fixed task ID
    ``initial_download``, using the defaults from download_channel.py.
    """
    task_id = "initial_download"
    task = asyncio.create_task(run_download(
        channel=None,  # Use default from download_channel.py
        message_limit=None,  # Use default
        task_id=task_id,
    ))
    # Hold the task until it finishes, then let it go.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    console.print(f"[green]Started initial download task:[/green] {task_id}")
370
+
371
@app.post("/download", response_model=Dict[str, str])
async def start_download(request: DownloadRequest, background_tasks: BackgroundTasks):
    """Start a new download task and return its ID.

    The ID is reserved in ``active_downloads`` immediately with a "queued"
    placeholder. Without this, two requests arriving before the first
    background task registers itself would get the same ID, and
    ``/status/{task_id}`` would return 404 until the task starts.

    Returns:
        ``{"task_id": <id>}`` for the scheduled task.
    """
    sequence = len(active_downloads) + 1
    task_id = f"download_{sequence}"
    # Guard against an ID that is already taken.
    while task_id in active_downloads:
        sequence += 1
        task_id = f"download_{sequence}"

    active_downloads[task_id] = {
        "channel": request.channel or str(download_channel.CHANNEL),
        "status": "queued",
        "message_count": 0,
        "downloaded": 0,
        "downloading": None,
        "error": None,
    }

    background_tasks.add_task(
        run_download,
        channel=request.channel,
        message_limit=request.message_limit,
        task_id=task_id,
    )

    return {"task_id": task_id}
384
+
385
@app.get("/status/{task_id}", response_model=DownloadStatus)
async def get_status(task_id: str):
    """Return the status recorded for ``task_id``, or 404 when it is unknown."""
    task_status = active_downloads.get(task_id) if task_id in active_downloads else None
    if task_id in active_downloads:
        return task_status
    raise HTTPException(status_code=404, detail="Task not found")
391
+
392
@app.get("/active", response_model=Dict[str, DownloadStatus])
async def list_active():
    """Return the status of every task in ``active_downloads``, keyed by task ID.

    Entries are kept after a task ends, so finished tasks are listed too.
    """
    all_tasks = active_downloads
    return all_tasks
396
+
397
if __name__ == "__main__":
    # uvicorn is only needed when this module is executed directly, so it is
    # imported here rather than at the top of the file.
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)
requirements.txt CHANGED
@@ -3,3 +3,4 @@ huggingface_hub>=0.17.0
3
  fastapi>=0.104.0
4
  uvicorn>=0.24.0
5
  pydantic>=2.4.2
 
 
3
  fastapi>=0.104.0
4
  uvicorn>=0.24.0
5
  pydantic>=2.4.2
6
+ rich>=13.6.0