Spaces:
Sleeping
Sleeping
| import contextlib | |
| from typing import Literal, Tuple, List | |
| import httpx | |
| import nbformat | |
| from nbformat import NotebookNode, ValidationError | |
| from nbconvert import HTMLExporter | |
| from starlette.applications import Starlette | |
| from starlette.exceptions import HTTPException | |
| from starlette.responses import FileResponse, JSONResponse, HTMLResponse | |
| from starlette.requests import Request | |
| from starlette.routing import Route | |
| from nbconvert.preprocessors import Preprocessor | |
| import re | |
| from traitlets.config import Config | |
| from huggingface_hub import model_info, dataset_info | |
| from huggingface_hub.utils import RepositoryNotFoundError | |
| from functools import lru_cache | |
# Candidate Hub repo ids of the form "<owner>/<name>":
#   * owner: 3-32 ASCII letters, digits or hyphens
#   * name:  3-64 word characters, hyphens, dots or underscores
# The id (capture group 1) must be preceded by a non-word character and
# followed by a character that is neither a word character nor "/".
hub_id_regex = re.compile(
    r"[^\w]([a-zA-Z\d-]{3,32}\/[\w\-._]{3,64})[^\w/]"
)
| # TODO possibly make async but might be tricky to call inside PreProcessor | |
def check_hub_item(hub_id_match):
    """Look up a candidate repo id on the Hugging Face Hub.

    The id is tried as a model first and, if that is not found, as a dataset.

    Args:
        hub_id_match: Candidate repo id string (e.g. ``"owner/name"``).

    Returns:
        ``(hub_id_match, "model")`` or ``(hub_id_match, "dataset")`` for the
        first lookup that succeeds, or ``None`` when both lookups raise
        ``RepositoryNotFoundError``. Any other exception propagates.
    """
    lookups = ((model_info, "model"), (dataset_info, "dataset"))
    for fetch_info, repo_type in lookups:
        try:
            fetch_info(hub_id_match)
        except RepositoryNotFoundError:
            # Not a repo of this type; try the next one.
            continue
        return hub_id_match, repo_type
    return None
| # async def check_repo_exists(regex_hub_id_match: str) -> Optional[Tuple[str, str]]: | |
| # r = await client.get(f"https://huggingface.co/api/models/{regex_hub_id_match}") | |
| # if r.status_code == 200: | |
| # return regex_hub_id_match, 'model' | |
| # r = await client.get(f"https://huggingface.co/api/datasets/{regex_hub_id_match}") | |
| # if r.status_code == 200: | |
| # return regex_hub_id_match, 'dataset' | |
class HubIDCell(Preprocessor):
    """nbconvert preprocessor that records Hub repo ids found in code cells.

    Results are accumulated in ``resources["model_matches"]`` and
    ``resources["dataset_matches"]``, both sets of repo id strings.
    """

    def preprocess_cell(self, cell, resources, index):
        """Inspect one cell and record a Hub repo id it references, if any.

        Only the first candidate id in a code cell's source is inspected.
        Candidates already recorded as a model or dataset are not looked up
        again.

        Args:
            cell: Notebook cell; ``cell["cell_type"]`` and ``cell["source"]``
                are read.
            resources: Shared nbconvert resources mapping, updated in place.
            index: Position of the cell in the notebook (unused).

        Returns:
            The unmodified ``cell`` and the updated ``resources``.
        """
        # Create the result sets for every cell type so the keys exist even
        # when the notebook contains no code cells.
        model_matches = resources.setdefault("model_matches", set())
        dataset_matches = resources.setdefault("dataset_matches", set())

        if cell["cell_type"] != "code":
            return cell, resources

        match = hub_id_regex.search(cell["source"])
        if match is None:
            return cell, resources

        hub_id_match = match.group(1)
        # Skip the (network) lookup for ids that were already resolved.
        if hub_id_match in model_matches or hub_id_match in dataset_matches:
            return cell, resources

        hub_check = check_hub_item(hub_id_match)
        if hub_check:
            hub_id, repo_item_type = hub_check
            if repo_item_type == "model":
                model_matches.add(hub_id)
            elif repo_item_type == "dataset":
                dataset_matches.add(hub_id)
        return cell, resources
# Shared async HTTP client used to download notebooks from remote URLs.
client = httpx.AsyncClient()

# nbconvert configuration: run HubIDCell over every cell during HTML export
# so that detected Hub repo ids end up in the export resources.
c = Config()
c.HTMLExporter.preprocessors = [HubIDCell]

# Module-level exporter reused by every conversion.
html_exporter = HTMLExporter(config=c)
async def homepage(_):
    """Serve the static landing page; the request argument is ignored."""
    index_page = FileResponse("static/index.html")
    return index_page
async def healthz(_):
    """Health-check endpoint: always answers ``{"success": true}``."""
    payload = {"success": True}
    return JSONResponse(payload)
def convert(
    s: str, theme: Literal["light", "dark"], debug_info: str
) -> Tuple[str, List[str], List[str]]:
    """Render a notebook JSON string to HTML and link detected Hub repo ids.

    Args:
        s: Notebook file content (JSON text).
        theme: Theme assigned to the HTML exporter before rendering.
        debug_info: Extra context appended to the printed error logs only.

    Returns:
        A tuple ``(html, model_matches, dataset_matches)`` where ``html`` is
        the rendered notebook with every detected repo id wrapped in a link
        to huggingface.co, and the two lists hold the detected model and
        dataset ids in sorted order.

    Raises:
        HTTPException: With status 400 when ``s`` is not JSON or does not
            validate as a notebook.
    """
    # Capture potential validation error:
    try:
        notebook_node: NotebookNode = nbformat.reads(
            s,
            as_version=nbformat.current_nbformat,
        )
    except nbformat.reader.NotJSONError:
        print(400, f"Notebook is not JSON. {debug_info}")
        raise HTTPException(400, "Notebook is not JSON.")
    except ValidationError as e:
        print(
            400,
            f"Notebook is invalid according to nbformat: {e}. {debug_info}",
        )
        raise HTTPException(
            400,
            f"Notebook is invalid according to nbformat: {e}.",
        )
    print(f"Input: nbformat v{notebook_node.nbformat}.{notebook_node.nbformat_minor}")

    # NOTE: this mutates the shared module-level exporter.
    html_exporter.theme = theme
    body, metadata = html_exporter.from_notebook_node(notebook_node)
    metadata = dict(metadata)

    # The match sets may be absent (e.g. no cell was preprocessed), so fall
    # back to empty collections instead of raising KeyError.
    model_matches = sorted(metadata.get("model_matches", ()))
    dataset_matches = sorted(metadata.get("dataset_matches", ()))

    # TODO(customize or simplify template?)
    # TODO(also check source code for jupyter/nbviewer)
    links = {}
    for dataset_match in dataset_matches:
        # Dataset pages live under /datasets/ on the Hub.
        links[dataset_match] = (
            f"""<a href="https://huggingface.co/datasets/{dataset_match}">{dataset_match} </a>"""
        )
    for model_match in model_matches:
        print(f"updating {model_match}")
        links[model_match] = (
            f"""<a href="https://huggingface.co/{model_match}">{model_match} </a>"""
        )

    if links:
        # Substitute in a single pass, longest id first, so that an id that
        # is a substring of another id cannot corrupt an inserted link.
        by_length = sorted(links, key=len, reverse=True)
        id_pattern = re.compile("|".join(re.escape(hub_id) for hub_id in by_length))
        body = id_pattern.sub(lambda found: links[found.group(0)], body)

    return body, model_matches, dataset_matches
async def convert_from_url(req: Request):
    """Fetch a notebook from the ``url`` query parameter and convert it.

    Query parameters:
        url: Location of the notebook file (required).
        theme: ``"dark"`` selects the dark theme; anything else means light.

    Returns:
        A JSON response with keys ``html``, ``model_matches`` and
        ``dataset_matches``.

    Raises:
        HTTPException: With status 400 when ``url`` is missing, the remote
            file cannot be fetched, the remote server does not answer 200,
            or the notebook is rejected by ``convert``.
    """
    url = req.query_params.get("url")
    theme = "dark" if req.query_params.get("theme") == "dark" else "light"
    if not url:
        raise HTTPException(400, "Param url is missing")
    print("\n===", url)
    try:
        r = await client.get(
            url,
            follow_redirects=True,
            # httpx no follow redirect by default
        )
    except (httpx.RequestError, httpx.InvalidURL) as e:
        # Connection failures, timeouts, unsupported schemes and malformed
        # URLs are caused by the supplied url: report them as a client error
        # rather than letting them surface as a 500.
        print(400, f"Could not fetch remote file: {e!r}. url={url}")
        raise HTTPException(
            400, f"Could not fetch remote file ({type(e).__name__})"
        ) from e
    if r.status_code != 200:
        raise HTTPException(
            400, f"Got an error {r.status_code} when fetching remote file"
        )
    html_text, model_matches, dataset_matches = convert(
        r.text, theme=theme, debug_info=f"url={url}"
    )
    return JSONResponse(
        content={
            "html": html_text,
            "model_matches": list(model_matches),
            "dataset_matches": list(dataset_matches),
        }
    )
async def convert_from_upload(req: Request):
    """Convert a notebook sent as the raw request body to HTML.

    Query parameters:
        theme: ``"dark"`` selects the dark theme; anything else means light.

    Returns:
        An HTML response containing the rendered notebook.

    Raises:
        HTTPException: With status 400 when the body is not valid UTF-8 or
            the notebook is rejected by ``convert``.
    """
    theme = "dark" if req.query_params.get("theme") == "dark" else "light"
    raw_body = await req.body()
    try:
        s = raw_body.decode("utf-8")
    except UnicodeDecodeError as e:
        raise HTTPException(400, "Uploaded notebook is not valid UTF-8.") from e
    # convert() returns (html, model_matches, dataset_matches); only the HTML
    # string can be used as the content of an HTMLResponse.
    html_text, _model_matches, _dataset_matches = convert(
        s, theme=theme, debug_info=f"upload_from={req.headers.get('user-agent')}"
    )
    return HTMLResponse(content=html_text)
# ASGI application wiring the handlers above to their URL paths.
app = Starlette(
    routes=[
        # Static landing page.
        Route("/", homepage),
        # Health check.
        Route("/healthz", healthz),
        # Convert a notebook fetched from the ``url`` query parameter.
        Route("/convert", convert_from_url),
        # Convert a notebook sent in the request body (POST only).
        Route("/upload", convert_from_upload, methods=["POST"]),
    ],
    debug=False,
)