davidtran999 committed on
Commit
db9ca2f
·
verified ·
1 Parent(s): 3f040cc

Upload backend/hue_portal/chatbot/download_progress.py with huggingface_hub

Browse files
backend/hue_portal/chatbot/download_progress.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Download progress tracker for Hugging Face models.
3
+ Tracks real-time download progress in bytes.
4
+ """
5
+ import threading
6
+ import time
7
+ from typing import Dict, Optional
8
+ from dataclasses import dataclass, field
9
+
10
+
11
@dataclass
class DownloadProgress:
    """Byte-level progress record for one file being downloaded."""

    # Name identifying the file this record belongs to.
    filename: str
    # Expected size in bytes; stays 0 until a size has been recorded.
    total_bytes: int = 0
    # Bytes received so far.
    downloaded_bytes: int = 0
    # time.time() timestamps; None until the owner of the record sets them.
    started_at: Optional[float] = None
    completed_at: Optional[float] = None
    # Most recently computed transfer rate.
    speed_bytes_per_sec: float = 0.0

    @property
    def percentage(self) -> float:
        """Return progress as a percentage capped at 100.0 (0.0 if size is 0)."""
        if self.total_bytes == 0:
            return 0.0
        fraction = self.downloaded_bytes / self.total_bytes
        return min(100.0, fraction * 100.0)

    @property
    def is_complete(self) -> bool:
        """Return True once a positive total has been fully downloaded."""
        return 0 < self.total_bytes <= self.downloaded_bytes

    @property
    def elapsed_time(self) -> float:
        """Return seconds since the start, measured up to completion or now."""
        if self.started_at is None:
            return 0.0
        # A finished download stops the clock; an active one uses the wall clock.
        finished_at = self.completed_at if self.completed_at else time.time()
        return finished_at - self.started_at
40
+
41
+
42
@dataclass
class ModelDownloadProgress:
    """Aggregate download progress for all files belonging to one model."""

    # Identifier of the model being tracked (e.g. a Hugging Face repo id).
    model_path: str
    # Per-file progress records keyed by filename.
    files: Dict[str, DownloadProgress] = field(default_factory=dict)
    # time.time() of the first recorded update; None until then.
    started_at: Optional[float] = None
    # time.time() at which the whole model was marked complete, if ever.
    completed_at: Optional[float] = None

    def update_file(self, filename: str, downloaded: int, total: int) -> None:
        """Record the latest byte counts for ``filename``.

        Args:
            filename: Key identifying the file within this model.
            downloaded: Cumulative bytes received so far. ``None`` is
                treated as 0.
            total: Expected size in bytes. ``None`` or 0 means the size is
                not known yet.
        """
        # Progress sources may report an unknown size as None; comparing None
        # with an int would raise TypeError after state was already mutated.
        downloaded = downloaded or 0
        total = total or 0

        now = time.time()
        file_progress = self.files.get(filename)
        if file_progress is None:
            file_progress = DownloadProgress(filename=filename, started_at=now)
            self.files[filename] = file_progress
        # Start the model-level clock on the first update, even when the file
        # record was registered by someone else beforehand.
        if self.started_at is None:
            self.started_at = now

        file_progress.downloaded_bytes = downloaded
        file_progress.total_bytes = total

        # Average speed since this file started downloading.
        if file_progress.started_at:
            elapsed = now - file_progress.started_at
            if elapsed > 0:
                file_progress.speed_bytes_per_sec = downloaded / elapsed

        # Stamp completion once so later repeated updates do not keep moving
        # the timestamp (which would inflate the file's elapsed time).
        if total > 0 and downloaded >= total:
            if file_progress.completed_at is None:
                file_progress.completed_at = now

    def complete_file(self, filename: str) -> None:
        """Stamp ``filename`` with a completion time; unknown names are ignored."""
        file_progress = self.files.get(filename)
        if file_progress is not None:
            file_progress.completed_at = time.time()

    @property
    def total_bytes(self) -> int:
        """Return the sum of the expected sizes of all tracked files."""
        return sum(f.total_bytes for f in self.files.values())

    @property
    def downloaded_bytes(self) -> int:
        """Return the sum of the bytes received across all tracked files."""
        return sum(f.downloaded_bytes for f in self.files.values())

    @property
    def percentage(self) -> float:
        """Return overall progress as a percentage capped at 100.0."""
        total = self.total_bytes
        if total == 0:
            # No sizes known yet: fall back to the share of completed files.
            if not self.files:
                return 0.0
            completed = sum(1 for f in self.files.values() if f.is_complete)
            return (completed / len(self.files)) * 100.0
        return min(100.0, (self.downloaded_bytes / total) * 100.0)

    @property
    def is_complete(self) -> bool:
        """Return True when at least one file is tracked and all are complete."""
        if not self.files:
            return False
        return all(f.is_complete for f in self.files.values())

    @property
    def speed_bytes_per_sec(self) -> float:
        """Return the sum of the per-file speeds of all started files."""
        return sum(f.speed_bytes_per_sec for f in self.files.values() if f.started_at)

    @property
    def elapsed_time(self) -> float:
        """Return seconds since the first update, up to completion or now."""
        if self.started_at is None:
            return 0.0
        end_time = self.completed_at or time.time()
        return end_time - self.started_at

    def to_dict(self) -> Dict:
        """Return a JSON-serializable snapshot of the model and its files."""
        # The "mb" fields divide by 1024 * 1024 bytes.
        bytes_per_mb = 1024 * 1024
        speed = self.speed_bytes_per_sec
        return {
            "model_path": self.model_path,
            "total_bytes": self.total_bytes,
            "downloaded_bytes": self.downloaded_bytes,
            "percentage": round(self.percentage, 2),
            "speed_bytes_per_sec": round(speed, 2),
            "speed_mb_per_sec": round(speed / bytes_per_mb, 2),
            "elapsed_time": round(self.elapsed_time, 2),
            "is_complete": self.is_complete,
            "files_count": len(self.files),
            "files_completed": sum(1 for f in self.files.values() if f.is_complete),
            "files": {
                name: {
                    "filename": f.filename,
                    "total_bytes": f.total_bytes,
                    "downloaded_bytes": f.downloaded_bytes,
                    "percentage": round(f.percentage, 2),
                    "speed_mb_per_sec": round(f.speed_bytes_per_sec / bytes_per_mb, 2),
                    "is_complete": f.is_complete,
                }
                for name, f in self.files.items()
            },
        }
147
+
148
+
149
class ProgressTracker:
    """Thread-safe registry of download progress keyed by model path.

    Every method that reads or mutates a model's progress through this class
    holds the same lock, so a snapshot taken by ``get_all`` or
    ``get_model_progress`` cannot observe a ``files`` dict that another
    thread is resizing via ``update``.

    Objects returned by ``get`` / ``get_or_create`` are the live records;
    code that mutates them directly is not protected by the lock.
    """

    def __init__(self):
        self._progress: Dict[str, ModelDownloadProgress] = {}
        # Re-entrant so locked methods can call each other (update ->
        # get_or_create) without deadlocking.
        self._lock = threading.RLock()

    def get_or_create(self, model_path: str) -> ModelDownloadProgress:
        """Return the progress record for ``model_path``, creating it if absent."""
        with self._lock:
            progress = self._progress.get(model_path)
            if progress is None:
                progress = ModelDownloadProgress(model_path=model_path)
                self._progress[model_path] = progress
            return progress

    def get(self, model_path: str) -> Optional[ModelDownloadProgress]:
        """Return the progress record for ``model_path`` or None if unknown."""
        with self._lock:
            return self._progress.get(model_path)

    def update(self, model_path: str, filename: str, downloaded: int, total: int):
        """Record byte counts for one file, creating the model record if needed."""
        # Hold the lock across the mutation, not just the lookup, so readers
        # never iterate the files dict while it changes size.
        with self._lock:
            self.get_or_create(model_path).update_file(filename, downloaded, total)

    def complete_file(self, model_path: str, filename: str):
        """Mark a file as complete; unknown models are ignored."""
        with self._lock:
            progress = self._progress.get(model_path)
            if progress is not None:
                progress.complete_file(filename)

    def complete_model(self, model_path: str):
        """Stamp the whole model as complete; unknown models are ignored."""
        with self._lock:
            progress = self._progress.get(model_path)
            if progress is not None:
                progress.completed_at = time.time()

    def get_all(self) -> Dict[str, Dict]:
        """Return a ``{model_path: snapshot_dict}`` mapping for every model."""
        with self._lock:
            return {
                path: prog.to_dict()
                for path, prog in self._progress.items()
            }

    def get_model_progress(self, model_path: str) -> Optional[Dict]:
        """Return the snapshot dict for one model, or None if it is unknown."""
        with self._lock:
            progress = self._progress.get(model_path)
            if progress is None:
                return None
            return progress.to_dict()
199
+
200
+
201
# Module-wide tracker instance handed out by get_progress_tracker().
_global_tracker = ProgressTracker()


def get_progress_tracker() -> ProgressTracker:
    """Return the module-level ProgressTracker instance."""
    shared_tracker = _global_tracker
    return shared_tracker
208
+
209
+
210
def create_progress_callback(model_path: str):
    """
    Create a progress callback for huggingface_hub downloads.

    Returns a class (not an instance) exposing a small tqdm-like interface:
    ``update``, ``set_description``, ``close`` and context-manager support.
    Each instance keeps its own cumulative byte counter and forwards it to
    the global progress tracker under ``model_path``.

    Usage:
        from huggingface_hub import snapshot_download
        callback = create_progress_callback("Qwen/Qwen2.5-32B-Instruct")
        snapshot_download(repo_id=model_path, resume_download=True,
                          tqdm_class=callback)

    NOTE(review): the returned class is not a ``tqdm`` subclass, so callers
    that rely on other parts of the tqdm API may not accept it - confirm
    against the installed huggingface_hub version.
    """
    tracker = get_progress_tracker()

    class ProgressCallback:
        """Minimal tqdm-like progress bar that reports into the tracker."""

        def __init__(self, *args, **kwargs):
            # Keep the raw tqdm arguments for callers that inspect them.
            self.tqdm_args = args
            self.tqdm_kwargs = kwargs
            self.current_file = None
            # Keep our own counters: update() receives increments, while the
            # tracker expects the cumulative number of bytes.
            self.n = kwargs.get("initial") or 0
            self.total = kwargs.get("total") or 0
            # A description given at construction time names the file.
            initial_desc = kwargs.get("desc")
            if initial_desc:
                self.set_description(initial_desc)

        def __call__(self, *args, **kwargs):
            # Intentionally a no-op; progress arrives through update().
            pass

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            self.close()
            return False

        def update(self, n: int = 1):
            """Advance the counter by ``n`` and report the cumulative total."""
            self.n += n or 0
            if self.current_file:
                tracker.update(model_path, self.current_file, self.n, self.total or 0)

        def set_description(self, desc: str = None, refresh: bool = True):
            """Set description; its last whitespace-separated token is the filename."""
            # ``refresh`` is accepted for tqdm signature compatibility only.
            if desc:
                tokens = desc.split()
                # Whitespace-only text has no tokens; keep the previous name.
                if tokens:
                    self.current_file = tokens[-1]

        def close(self):
            """Close progress bar and mark the current file as complete."""
            if self.current_file:
                tracker.complete_file(model_path, self.current_file)

    return ProgressCallback
262
+
263
+
264
def create_hf_progress_callback(model_path: str):
    """
    Create a progress callback compatible with huggingface_hub.

    Returns a function that takes a tqdm-like bar object, reads its ``desc``,
    ``n`` and ``total`` attributes and forwards them to the global progress
    tracker under ``model_path``. The last whitespace-separated token of
    ``desc`` is used as the filename; it is remembered between calls so bars
    without a description keep reporting against the last seen file.
    """
    tracker = get_progress_tracker()
    current_file = None

    def progress_callback(tqdm_bar):
        """Forward the bar's current counters to the tracker."""
        nonlocal current_file

        desc = getattr(tqdm_bar, "desc", None)
        if desc:
            tokens = desc.split()
            # Whitespace-only descriptions carry no filename; keep the old one.
            if tokens:
                current_file = tokens[-1]

        if current_file:
            # Missing or None counters (e.g. unknown total size) count as 0 so
            # the tracker never has to compare None with an int.
            downloaded = getattr(tqdm_bar, "n", 0) or 0
            total = getattr(tqdm_bar, "total", 0) or 0
            # tracker.update creates the file record on first sight, so no
            # direct write into the model's files dict is needed here.
            tracker.update(model_path, current_file, downloaded, total)

    return progress_callback
291
+
292
+
293
+
294
+