Spaces:
Running
Running
| import dataclasses | |
| import hashlib | |
| import io | |
| import tempfile | |
| from etils import epath | |
| import magic | |
| import pandas as pd | |
| import requests | |
| from .names import find_unique_name | |
| from .path import get_resource_path | |
| from .state import FileObject | |
| from .state import FileSet | |
# Identifiers for the two kinds of resources handled by this module.
FILE_OBJECT = "FileObject"
FILE_SET = "FileSet"

# Every supported resource kind, in display order.
RESOURCE_TYPES = [
    FILE_OBJECT,
    FILE_SET,
]
@dataclasses.dataclass
class FileType:
    """Describes one file format a resource can be encoded in.

    The dataclass decorator generates the keyword `__init__` (and `__eq__`)
    that the rest of the module relies on when it builds and compares
    `FileType(name=..., encoding_format=..., extensions=...)` instances.
    """

    # Short human-readable label, e.g. "CSV".
    name: str
    # MIME type string, e.g. "text/csv".
    encoding_format: str
    # File extensions written without the leading dot, e.g. ["csv"].
    extensions: list[str]
class FileTypes:
    """Namespace holding the file formats known to this module."""

    CSV = FileType(name="CSV", encoding_format="text/csv", extensions=["csv"])
    EXCEL = FileType(
        name="Excel",
        encoding_format="application/vnd.ms-excel",
        extensions=["xls", "xlsx", "xlsm"],
    )
    GZIP = FileType(name="GZIP", encoding_format="application/gzip", extensions=["gz"])
    # JPEG images use the jpeg/jpg extensions; "json" belongs to the JSON entry.
    JPEG = FileType(
        name="JPEG", encoding_format="image/jpeg", extensions=["jpeg", "jpg"]
    )
    JSON = FileType(
        name="JSON", encoding_format="application/json", extensions=["json"]
    )
    JSONL = FileType(
        name="JSON-Lines",
        encoding_format="application/jsonl+json",
        extensions=["jsonl"],
    )
    PARQUET = FileType(
        name="Parquet",
        encoding_format="application/vnd.apache.parquet",
        extensions=["parquet"],
    )
    TAR = FileType(
        name="Archive (TAR)",
        encoding_format="application/x-tar",
        extensions=["tar"],
    )
    TSV = FileType(
        name="TSV", encoding_format="text/tab-separated-values", extensions=["tsv"]
    )
    TXT = FileType(
        name="Text",
        encoding_format="text/plain",
        extensions=["txt"],
    )
    ZIP = FileType(
        name="ZIP",
        encoding_format="application/zip",
        extensions=["zip"],
    )
def _full_name(file_type: FileType):
    """Builds the display label `<name> (<encoding format>)` for a file type."""
    label = file_type.name
    mime = file_type.encoding_format
    return f"{label} ({mime})"
# Display label -> file type. Insertion order matters: `code_to_index` returns
# positions within this mapping.
FILE_TYPES: dict[str, FileType] = {
    _full_name(known_type): known_type
    for known_type in (
        FileTypes.CSV,
        FileTypes.EXCEL,
        FileTypes.GZIP,
        FileTypes.JPEG,
        FileTypes.JSON,
        FileTypes.JSONL,
        FileTypes.PARQUET,
        FileTypes.TAR,
        FileTypes.TSV,
        FileTypes.TXT,
        FileTypes.ZIP,
    )
}

# Encoding format (MIME string) -> file type, derived from FILE_TYPES.
ENCODING_FORMATS: dict[str, FileType] = {
    known_type.encoding_format: known_type for known_type in FILE_TYPES.values()
}
def name_to_code(file_type_name: str) -> str | None:
    """Maps a display name to its encoding format.

    The names are the keys of `FILE_TYPES`, i.e. the full labels such as
    "Text (text/plain)", which maps to "text/plain".

    Args:
        file_type_name: Display label to look up.

    Returns:
        The encoding format of the matching file type, or None when the label
        is unknown.
    """
    # FILE_TYPES is already keyed by the display label, so look it up directly
    # instead of scanning every entry.
    file_type = FILE_TYPES.get(file_type_name)
    return None if file_type is None else file_type.encoding_format
def code_to_index(encoding_format: str) -> int | None:
    """Returns the position of the encoding format within `FILE_TYPES`.

    For example "text/csv" => 0, because CSV is the first registered type.
    Returns None when no registered type uses `encoding_format`.
    """
    positions = (
        position
        for position, candidate in enumerate(FILE_TYPES.values())
        if candidate.encoding_format == encoding_format
    )
    # The first match wins, exactly like an early return from a loop.
    return next(positions, None)
def _sha256(content: bytes):
    """Returns the hexadecimal SHA-256 digest of `content`."""
    hasher = hashlib.sha256()
    hasher.update(content)
    return hasher.hexdigest()
def hash_file_path(url: str) -> epath.Path:
    """Derives a deterministic location in the temp directory from `url`.

    The same URL always yields the same path, so the result can be used as a
    cache key on disk.
    """
    digest = _sha256(url.encode())
    return epath.Path(tempfile.gettempdir()) / f"croissant-editor-{digest}"
def download_file(url: str, file_path: epath.Path, timeout: float | None = 60.0):
    """Downloads the file locally to `file_path`.

    Args:
        url: Location to fetch with an HTTP GET request.
        file_path: Destination the downloaded bytes are copied to.
        timeout: Seconds forwarded to `requests.get` so a stalled server cannot
            block the caller indefinitely. Pass None to wait without limit.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within `timeout`.
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with tempfile.TemporaryDirectory() as scratch_dir:
            # Stream into a scratch file first; `file_path` is only written once
            # the whole response body has been received.
            scratch_file = epath.Path(scratch_dir) / "file"
            with scratch_file.open("wb") as sink:
                for chunk in response.iter_content(chunk_size=8192):
                    sink.write(chunk)
            # Copy after the scratch file is closed (flushed) and before the
            # temporary directory is removed.
            scratch_file.copy(file_path)
def get_dataframe(file_type: FileType, file: io.BytesIO | epath.Path) -> pd.DataFrame:
    """Loads `file` into a DataFrame using the reader matching `file_type`.

    Raises:
        NotImplementedError: If no reader is available for `file_type`.
    """
    # Ordered (file type, reader) pairs; candidates are compared with `==`
    # one after the other and the first match is used.
    readers = (
        (FileTypes.CSV, lambda source: pd.read_csv(source)),
        (FileTypes.EXCEL, lambda source: pd.read_excel(source)),
        (FileTypes.JSON, lambda source: pd.read_json(source)),
        (FileTypes.JSONL, lambda source: pd.read_json(source, lines=True)),
        (FileTypes.PARQUET, lambda source: pd.read_parquet(source)),
        (FileTypes.TSV, lambda source: pd.read_csv(source, sep="\t")),
    )
    for candidate, read in readers:
        if file_type == candidate:
            return read(file).infer_objects()
    raise NotImplementedError(
        f"File type {file_type} is not supported. Please, open an issue on GitHub:"
        " https://github.com/mlcommons/croissant/issues/new"
    )
def _guess_mime_type(path: epath.Path) -> str:
    """Guesses the most specific MIME type for the file at `path`."""
    detected = magic.from_file(path, mime=True)
    if detected != "text/plain":
        return detected
    # A CSV/TSV may be classified as plain text (for example when the file is
    # not terminated by a newline), so fall back on the file extension.
    by_extension = {
        ".csv": "text/csv",
        ".tsv": "text/tab-separated-values",
    }
    return by_extension.get(path.suffix, detected)
def guess_file_type(path: epath.Path) -> FileType | None:
    """Returns the known file type for `path`, or None if its MIME type is unknown."""
    return ENCODING_FORMATS.get(_guess_mime_type(path))
def file_from_url(url: str, names: set[str], folder: epath.Path) -> FileObject:
    """Builds a `FileObject` for a remote file, downloading it if not cached.

    The download location is derived from the URL, so a file that already
    exists at that location is reused instead of being fetched again.
    """
    local_path = hash_file_path(url)
    if not local_path.exists():
        download_file(url, local_path)
    with local_path.open("rb") as stream:
        digest = _sha256(stream.read())
    detected_type = guess_file_type(local_path)
    frame = get_dataframe(detected_type, local_path)
    # The last path segment of the URL seeds the resource name.
    basename = url.rsplit("/", 1)[-1]
    unique_name = find_unique_name(names, basename)
    return FileObject(
        id=unique_name,
        name=unique_name,
        description="",
        content_url=url,
        encoding_format=detected_type.encoding_format,
        sha256=digest,
        df=frame,
        folder=folder,
    )
def file_from_upload(
    file: io.BytesIO, names: set[str], folder: epath.Path
) -> FileObject:
    """Stores an uploaded file under `data/` and builds its `FileObject`.

    NOTE(review): `file` must expose a `name` attribute in addition to the
    `io.BytesIO` interface - presumably an upload widget object; confirm
    against the caller.
    """
    payload = file.getvalue()
    relative_url = f"data/{file.name}"
    digest = _sha256(payload)
    target = get_resource_path(relative_url)
    with target.open("wb") as sink:
        sink.write(payload)
    # The type is detected from the copy on disk; the frame is read from the
    # in-memory upload.
    detected_type = guess_file_type(target)
    frame = get_dataframe(detected_type, file)
    unique_name = find_unique_name(names, file.name)
    return FileObject(
        id=unique_name,
        name=unique_name,
        description="",
        content_url=relative_url,
        encoding_format=detected_type.encoding_format,
        sha256=digest,
        df=frame,
        folder=folder,
    )
def file_from_form(
    type: str, names: set[str], folder: epath.Path
) -> FileObject | FileSet:
    """Creates an empty resource of the requested kind with a unique name.

    Raises:
        ValueError: If `type` is neither FILE_OBJECT nor FILE_SET.
    """
    if type == FILE_OBJECT:
        unique_name = find_unique_name(names, "file_object")
        return FileObject(id=unique_name, name=unique_name, folder=folder)
    if type == FILE_SET:
        unique_name = find_unique_name(names, "file_set")
        return FileSet(id=unique_name, name=unique_name)
    raise ValueError("type has to be one of FILE_OBJECT, FILE_SET")
def is_url(file: FileObject) -> bool:
    """Tells whether the file's `content_url` starts with "http".

    Always returns a real bool: a missing or empty `content_url` yields False
    rather than leaking None or "" to the caller.
    """
    content_url = file.content_url
    return bool(content_url) and content_url.startswith("http")
def trigger_download(file: FileObject):
    """Loads the file's content into `file.df`, downloading it first if needed.

    Remote files are fetched to the location derived from their URL unless
    something already exists there; other files are read from their resource
    path.
    """
    if not is_url(file):
        local_path = get_resource_path(file.content_url)
    else:
        local_path = hash_file_path(file.content_url)
        if not local_path.exists():
            download_file(file.content_url, local_path)
    file.df = get_dataframe(guess_file_type(local_path), local_path)