context-thread-agent / src\notebook_downloader.py
mozzic's picture
Upload src\notebook_downloader.py with huggingface_hub
d788958 verified
raw
history blame
3.59 kB
"""
Notebook downloader for collecting sample notebooks
"""
import requests
from pathlib import Path
from typing import List
import time
import json
class NotebookDownloader:
    """Download a small, fixed set of sample notebooks from GitHub.

    Each notebook is fetched over HTTP with ``requests``, checked to be a
    JSON object containing a ``cells`` key, and written into ``output_dir``.
    """

    def __init__(self, output_dir: str):
        """Remember the target directory and create it if necessary.

        Args:
            output_dir: Directory the notebooks are written to. Missing
                parent directories are created as well.
        """
        self.output_dir = Path(output_dir)
        # parents=True so a nested path such as "data/notebooks" works even
        # when "data" does not exist yet.
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def download_all(self) -> List[str]:
        """Download notebooks from every configured source.

        A failing source is reported on stdout and skipped so that the
        remaining sources still run (best-effort collection).

        Returns:
            File names of the notebooks that were downloaded successfully.
        """
        downloaded = []

        # Download from predefined sources
        sources = [
            self._download_from_github,
        ]

        for source_func in sources:
            try:
                downloaded.extend(source_func())
            except Exception as e:
                # Deliberately broad: one broken source must not abort the run.
                print(f"Error downloading from source: {e}")

        return downloaded

    def _download_from_github(self) -> List[str]:
        """Download up to two known notebooks from each listed repository.

        Failures are reported on stdout per notebook / per repository and do
        not stop the remaining downloads.

        Returns:
            File names of the notebooks that were downloaded successfully.
        """
        repos = [
            "pandas-dev/pandas",
            "matplotlib/matplotlib",
            "scikit-learn/scikit-learn",
            "statsmodels/statsmodels",
        ]

        downloaded = []
        for repo in repos:
            try:
                print(f"Fetching from {repo}...")
                notebooks = self._search_github_notebooks(repo)

                for nb_url, nb_name in notebooks[:2]:  # Limit per repo
                    try:
                        self._download_notebook(nb_url, nb_name)
                        downloaded.append(nb_name)
                        time.sleep(1)  # Rate limiting
                    except Exception as e:
                        print(f"Failed to download {nb_name}: {e}")
            except Exception as e:
                print(f"Failed to fetch from {repo}: {e}")

        return downloaded

    def _search_github_notebooks(self, repo: str) -> List[tuple]:
        """Return the known ``(url, filename)`` pairs for ``repo``.

        This is a static lookup table, not a real search; no network access
        happens here.

        Args:
            repo: Repository in ``owner/name`` form.

        Returns:
            A list of ``(raw_url, local_filename)`` tuples, empty when the
            repository has no entry in the table.
        """
        # This is a simplified version - in practice, you'd use GitHub API
        # For now, return some known notebook URLs
        known_notebooks = {
            "pandas-dev/pandas": [
                ("https://raw.githubusercontent.com/pandas-dev/pandas/main/doc/source/user_guide/10min.ipynb", "pandas_10min.ipynb")
            ],
            "matplotlib/matplotlib": [
                ("https://raw.githubusercontent.com/matplotlib/matplotlib/main/tutorials/introductory/sample_plots.ipynb", "matplotlib_sample.ipynb")
            ],
            "scikit-learn/scikit-learn": [
                ("https://raw.githubusercontent.com/scikit-learn/scikit-learn/main/examples/linear_model/plot_ols.ipynb", "sklearn_ols.ipynb")
            ],
        }

        return known_notebooks.get(repo, [])

    @staticmethod
    def _validate_notebook(data):
        """Check that parsed JSON has the minimal shape of a notebook.

        Args:
            data: The value obtained from parsing the downloaded JSON.

        Returns:
            ``data`` unchanged when it is a dict containing a ``cells`` key.

        Raises:
            ValueError: If ``data`` is not a dict or has no ``cells`` key.
        """
        # Require a JSON object: a bare membership test would also accept a
        # list such as ["cells"] or the string "cells".
        if not isinstance(data, dict) or 'cells' not in data:
            raise ValueError("Not a valid notebook")
        return data

    def _download_notebook(self, url: str, filename: str):
        """Download a single notebook and store it in the output directory.

        Args:
            url: Address the notebook JSON is fetched from.
            filename: Name of the file created inside ``self.output_dir``.

        Raises:
            requests.HTTPError: If the server answers with an error status
                (via ``raise_for_status``).
            ValueError: If the body is not JSON ("Invalid notebook format")
                or is JSON without a ``cells`` key ("Not a valid notebook").
        """
        response = requests.get(url, timeout=10)
        response.raise_for_status()

        # Validate it's a notebook. Only the JSON decode is guarded, and only
        # for ValueError (which JSON decode errors derive from), so that
        # KeyboardInterrupt and unrelated bugs are not masked.
        try:
            data = response.json()
        except ValueError as exc:
            raise ValueError("Invalid notebook format") from exc
        self._validate_notebook(data)

        output_path = self.output_dir / filename
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=1)

        print(f"✓ Downloaded: {filename}")