# OpenAlex API client helpers for the openalex_mapper Hugging Face Space.
import os
import random
import re
import time
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Iterable
from urllib.parse import parse_qs, urlparse

import requests

from config_loader import load_local_config
# --- OpenAlex API defaults --------------------------------------------------

DEFAULT_BASE_URL = "https://api.openalex.org"

# Maximum page size the OpenAlex API serves per request.
DEFAULT_PER_PAGE = 100

# Seconds before an HTTP request is abandoned.
DEFAULT_TIMEOUT = 30

# Attempts made before a request is declared failed.
DEFAULT_RETRIES = 5

# Fields requested for every work record; keeps response payloads small.
DEFAULT_SELECT_FIELDS = (
    "id",
    "title",
    "display_name",
    "doi",
    "publication_year",
    "abstract_inverted_index",
    "primary_location",
    "primary_topic",
    "referenced_works",
)

# Query-string parameters the client manages itself and therefore strips
# from user-supplied URLs.
IGNORED_INPUT_PARAMS = {
    "api_key",
    "cursor",
    "page",
    "per-page",
    "per_page",
    "sample",
    "seed",
    "select",
}

# Bare OpenAlex IDs look like a single type letter followed by digits,
# e.g. "W2741809807".
OPENALEX_ID_RE = re.compile(r"^[A-Za-z]\d+$")
class OpenAlexRequestError(RuntimeError):
    """Raised when an OpenAlex API call fails after retries are exhausted.

    Attributes:
        status_code: HTTP status of the failing response, if one was received.
        payload: Decoded JSON body (or raw text) of the failing response.
        headers: Response headers, copied into a plain dict.
    """

    def __init__(self, message, *, status_code=None, payload=None, headers=None):
        super().__init__(message)
        self.headers = dict(headers or {})
        self.payload = payload
        self.status_code = status_code
@dataclass
class FilterToken:
    """A single ``key:value`` pair from an OpenAlex ``filter=`` parameter.

    The ``@dataclass`` decorator is required: the rest of the module
    constructs tokens positionally, e.g. ``FilterToken("doi", value)``.
    """

    key: str
    value: str
| class OpenAlexQuery: | |
| entity: str = "works" | |
| params: dict[str, str] = field(default_factory=dict) | |
| filter_tokens: list[FilterToken] = field(default_factory=list) | |
| legacy_filters: list[str] = field(default_factory=list) | |
| def as_params(self, select_fields: Iterable[str] | None = None, extra_params: dict[str, str] | None = None): | |
| params = dict(self.params) | |
| if self.filter_tokens: | |
| params["filter"] = ",".join(f"{token.key}:{token.value}" for token in self.filter_tokens) | |
| if select_fields: | |
| params["select"] = ",".join(select_fields) | |
| if extra_params: | |
| params.update({key: str(value) for key, value in extra_params.items()}) | |
| return params | |
| def without_params(self, *keys: str): | |
| keys_to_remove = set(keys) | |
| return OpenAlexQuery( | |
| entity=self.entity, | |
| params={key: value for key, value in self.params.items() if key not in keys_to_remove}, | |
| filter_tokens=list(self.filter_tokens), | |
| legacy_filters=list(self.legacy_filters), | |
| ) | |
| def _normalize_url_input(url: str): | |
| url = url.strip() | |
| if "://" not in url and url.startswith(("openalex.org/", "api.openalex.org/")): | |
| return f"https://{url}" | |
| return url | |
| def _split_filter_string(filter_value: str): | |
| tokens = [] | |
| current = [] | |
| quote = None | |
| for char in filter_value: | |
| if char in {"'", '"'}: | |
| if quote == char: | |
| quote = None | |
| elif quote is None: | |
| quote = char | |
| current.append(char) | |
| continue | |
| if char == "," and quote is None: | |
| token = "".join(current).strip() | |
| if token: | |
| tokens.append(token) | |
| current = [] | |
| continue | |
| current.append(char) | |
| token = "".join(current).strip() | |
| if token: | |
| tokens.append(token) | |
| return tokens | |
| def _normalize_filter_token(key: str, value: str): | |
| legacy_key = None | |
| if key == "default.search": | |
| return "__search__", value, key | |
| if key == "host_venue.id": | |
| return "primary_location.source.id", value, key | |
| if key.startswith("host_venue."): | |
| return key.replace("host_venue.", "primary_location.source.", 1), value, key | |
| if key == "alternate_host_venues.id": | |
| return "locations.source.id", value, key | |
| if key.startswith("alternate_host_venues."): | |
| return key.replace("alternate_host_venues.", "locations.source.", 1), value, key | |
| if key.startswith("x_concepts"): | |
| return key.replace("x_concepts", "concepts", 1), value, key | |
| return key, value, legacy_key | |
def normalize_openalex_url(url: str):
    """Parse an OpenAlex URL (API or website form) into an OpenAlexQuery.

    Filter tokens are rewritten to the modern schema, a ``default.search``
    filter is promoted to the top-level ``search`` parameter (an explicit
    ``search=`` query param takes precedence), client-managed parameters
    are dropped, and ``per-page`` is normalized to ``per_page``.
    """
    parsed = urlparse(_normalize_url_input(url))
    segments = [segment for segment in parsed.path.split("/") if segment]
    entity = segments[0] if segments else "works"
    raw_params = parse_qs(parsed.query, keep_blank_values=True)

    params = {}
    filter_tokens = []
    legacy_filters = []
    search_value = raw_params.get("search", [None])[0]

    for raw_filter in raw_params.get("filter", []):
        for token in _split_filter_string(raw_filter):
            raw_key, sep, raw_value = token.partition(":")
            if not sep:
                continue  # malformed token with no key/value separator
            key, value, legacy_key = _normalize_filter_token(raw_key.strip(), raw_value.strip())
            if legacy_key:
                legacy_filters.append(legacy_key)
            if key == "__search__":
                # Promote to the top-level search param unless one already
                # exists; otherwise keep it as an explicit filter token.
                if search_value is None:
                    search_value = value
                else:
                    filter_tokens.append(FilterToken("default.search", value))
            else:
                filter_tokens.append(FilterToken(key, value))

    if search_value:
        params["search"] = search_value
    for key, values in raw_params.items():
        if not values or key in {"filter", "search"} or key in IGNORED_INPUT_PARAMS:
            continue
        params["per_page" if key == "per-page" else key] = values[0]

    return OpenAlexQuery(
        entity=entity,
        params=params,
        filter_tokens=filter_tokens,
        legacy_filters=legacy_filters,
    )
def _normalize_openalex_id(entity_id: str):
    """Strip an openalex.org URL prefix and upper-case the ID's type letter.

    Inputs that do not look like a bare OpenAlex ID (one letter followed by
    digits) are returned unchanged apart from whitespace stripping.
    """
    candidate = entity_id.strip()
    if candidate.startswith("https://openalex.org/"):
        candidate = candidate.rstrip("/").split("/")[-1]
    if not OPENALEX_ID_RE.match(candidate):
        return candidate
    return candidate[0].upper() + candidate[1:]
| def _clean_api_key(api_key): | |
| if api_key is None: | |
| return None | |
| api_key = str(api_key).strip() | |
| return api_key or None | |
class OpenAlexClient:
    """HTTP client for the OpenAlex REST API.

    Adds on top of a bare ``requests.Session``:
      * exponential-backoff retries for transient 5xx responses,
      * an API key injected into every request when available,
      * an in-memory cache for single-entity lookups,
      * cursor-based pagination and sampling helpers for work listings.
    """

    def __init__(
        self,
        api_key=None,
        base_url=DEFAULT_BASE_URL,
        timeout=DEFAULT_TIMEOUT,
        max_retries=DEFAULT_RETRIES,
    ):
        self.api_key = _clean_api_key(api_key)
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.max_retries = max_retries
        self.session = requests.Session()
        self.session.headers.update(
            {
                "User-Agent": "OpenAlexMapper/1.0 (+https://huggingface.co/spaces/m7n/openalex_mapper)",
                "Accept": "application/json",
            }
        )
        # Cache of single-entity lookups, keyed by (entity, id, select fields).
        self._entity_cache = {}

    # BUG FIX: `from_env` takes `cls` but had no @classmethod decorator, so
    # `OpenAlexClient.from_env(require_api_key=...)` raised a TypeError.
    @classmethod
    def from_env(cls, require_api_key=False):
        """Build a client from the OPENALEX_API_KEY environment variable.

        ``load_local_config()`` runs first so a local config file can
        populate the environment during development.

        Raises:
            RuntimeError: if ``require_api_key`` is True and no key is set.
        """
        load_local_config()
        api_key = _clean_api_key(os.environ.get("OPENALEX_API_KEY"))
        if require_api_key and not api_key:
            raise RuntimeError(
                "OPENALEX_API_KEY is required. Set it as a Hugging Face Space secret or in openalex_config.local.json."
            )
        return cls(api_key=api_key or None)

    def with_api_key(self, api_key):
        """Return a new client identical to this one but using ``api_key``."""
        return OpenAlexClient(
            api_key=api_key,
            base_url=self.base_url,
            timeout=self.timeout,
            max_retries=self.max_retries,
        )

    def _request_json(self, path, params=None):
        """GET ``path`` relative to the base URL and return decoded JSON.

        Retries transient 5xx responses (honoring a numeric ``Retry-After``
        header) and connection errors with exponential backoff; gives up
        immediately on 401/403/429, where retrying cannot help.

        Raises:
            OpenAlexRequestError: after retries are exhausted or on a
                non-retryable HTTP error.
        """
        params = dict(params or {})
        if self.api_key:
            params["api_key"] = self.api_key
        url = f"{self.base_url}/{path.lstrip('/')}"
        last_error = None
        for attempt in range(self.max_retries):
            response = None
            try:
                response = self.session.get(url, params=params, timeout=self.timeout)
                if response.status_code in {500, 502, 503, 504}:
                    retry_after = response.headers.get("Retry-After")
                    try:
                        wait_time = float(retry_after) if retry_after else float(2 ** attempt)
                    except ValueError:
                        # Retry-After may be an HTTP-date rather than a
                        # number of seconds; fall back to plain backoff.
                        wait_time = float(2 ** attempt)
                    time.sleep(wait_time)
                    continue
                response.raise_for_status()
                return response.json()
            except requests.RequestException as exc:
                last_error = exc
                if response is not None and response.status_code in {401, 403, 429}:
                    break  # bad key or rate limit: retrying cannot help
                if attempt == self.max_retries - 1:
                    break
                time.sleep(2 ** attempt)
        if response is not None:
            try:
                payload = response.json()
            except ValueError:
                payload = response.text
            raise OpenAlexRequestError(
                f"OpenAlex request failed for {url}: {payload}",
                status_code=response.status_code,
                payload=payload,
                headers=response.headers,
            ) from last_error
        raise OpenAlexRequestError(f"OpenAlex request failed for {url}: {last_error}") from last_error

    def get_entity(self, entity, entity_id, select_fields=None):
        """Fetch a single entity by ID, memoized for the client's lifetime."""
        normalized_id = _normalize_openalex_id(entity_id)
        cache_key = (entity, normalized_id, tuple(select_fields or ()))
        if cache_key in self._entity_cache:
            return self._entity_cache[cache_key]
        payload = self._request_json(
            f"{entity}/{normalized_id}",
            params={"select": ",".join(select_fields)} if select_fields else None,
        )
        self._entity_cache[cache_key] = payload
        return payload

    def count(self, query):
        """Return the total number of results matching ``query``."""
        payload = self._request_json(
            query.entity,
            params=query.as_params(select_fields=("id",), extra_params={"per_page": 1}),
        )
        return int(payload.get("meta", {}).get("count") or 0)

    def _normalize_work_record(self, record):
        """Fill in the fields downstream code expects on every work record."""
        normalized = dict(record)
        # A single space keeps titleless works truthy for display code.
        normalized["title"] = normalized.get("title") or normalized.get("display_name") or " "
        normalized.setdefault("abstract_inverted_index", None)
        normalized.setdefault("primary_location", None)
        normalized.setdefault("primary_topic", None)
        normalized.setdefault("referenced_works", [])
        return normalized

    def iter_works(self, query, limit=None, extra_params=None, per_page=DEFAULT_PER_PAGE):
        """Yield normalized work records for ``query`` via cursor paging.

        Args:
            query: OpenAlexQuery to execute.
            limit: Stop after this many records (None for all).
            extra_params: Extra request parameters merged into the query.
            per_page: Page size; shrunk on the final page when ``limit`` is set.
        """
        params = query.as_params(select_fields=DEFAULT_SELECT_FIELDS, extra_params=extra_params)
        params["cursor"] = params.get("cursor", "*")
        fetched = 0
        while True:
            current_per_page = per_page
            if limit is not None:
                remaining = limit - fetched
                if remaining <= 0:
                    break
                current_per_page = min(current_per_page, remaining)
            params["per_page"] = current_per_page
            payload = self._request_json(query.entity, params=params)
            results = payload.get("results", [])
            if not results:
                break
            for record in results:
                yield self._normalize_work_record(record)
                fetched += 1
                if limit is not None and fetched >= limit:
                    return
            next_cursor = payload.get("meta", {}).get("next_cursor")
            if next_cursor is None:
                break
            params["cursor"] = next_cursor

    def fetch_works(self, query, limit=None):
        """Materialize ``iter_works`` into a list."""
        return list(self.iter_works(query, limit=limit))

    def fetch_sampled_works(self, query, sample_size, seed):
        """Fetch a random sample of works matching ``query``.

        Uses the API's native ``sample``/``seed`` parameters for sizes up to
        10,000 (the server-side cap used here); larger samples fall back to
        streaming the full result set through a local reservoir sample.
        """
        sampling_query = query.without_params("sort")  # sort conflicts with sampling
        if sample_size <= 10000:
            return list(
                self.iter_works(
                    sampling_query,
                    limit=sample_size,
                    extra_params={"sample": sample_size, "seed": seed},
                )
            )
        return self.reservoir_sample_works(sampling_query, sample_size, seed)

    def reservoir_sample_works(self, query, sample_size, seed):
        """Uniformly sample ``sample_size`` works via reservoir sampling
        (Algorithm R), streaming the full result set exactly once."""
        rng = random.Random(seed)
        reservoir = []
        for index, record in enumerate(self.iter_works(query)):
            if index < sample_size:
                reservoir.append(record)
                continue
            sample_index = rng.randint(0, index)
            if sample_index < sample_size:
                reservoir[sample_index] = record
        return reservoir

    def fetch_records_from_dois(self, doi_list, block_size=50):
        """Fetch work records for a list of DOIs, ``block_size`` per request.

        Non-string and blank entries are skipped; each batch's DOIs are
        OR-ed together (``|``) in a single ``doi:`` filter.
        """
        all_records = []
        clean_dois = [doi.strip() for doi in doi_list if isinstance(doi, str) and doi.strip()]
        for start in range(0, len(clean_dois), block_size):
            sublist = clean_dois[start : start + block_size]
            doi_filter = "|".join(sublist)
            query = OpenAlexQuery(
                entity="works",
                filter_tokens=[FilterToken("doi", doi_filter)],
            )
            all_records.extend(self.fetch_works(query, limit=len(sublist)))
        return all_records
def get_openalex_client(require_api_key=False):
    """Module-level convenience wrapper around ``OpenAlexClient.from_env``."""
    client = OpenAlexClient.from_env(require_api_key=require_api_key)
    return client