# openalex_mapper / openalex_client.py
# Author: MaxNoichl
# Commit 716a752: "Add user OpenAlex key support and clearer limit errors"
import os
import random
import re
import time
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Iterable
from urllib.parse import parse_qs, urlparse
import requests
from config_loader import load_local_config
# OpenAlex API endpoint and per-request defaults.
DEFAULT_BASE_URL = "https://api.openalex.org"
DEFAULT_PER_PAGE = 100  # records per page when cursor-paging works
DEFAULT_TIMEOUT = 30  # per-request timeout in seconds
DEFAULT_RETRIES = 5  # attempts before a request is abandoned
# Fields requested from the works endpoint, keeping payloads small while
# covering everything downstream normalization relies on.
DEFAULT_SELECT_FIELDS = (
    "id",
    "title",
    "display_name",
    "doi",
    "publication_year",
    "abstract_inverted_index",
    "primary_location",
    "primary_topic",
    "referenced_works",
)
# Query parameters stripped from user-supplied OpenAlex URLs because the
# client manages them itself (auth, paging, sampling, field selection).
IGNORED_INPUT_PARAMS = {
    "api_key",
    "cursor",
    "page",
    "per-page",
    "per_page",
    "sample",
    "seed",
    "select",
}
# Bare OpenAlex identifier: one letter type prefix followed by digits, e.g. "W123".
OPENALEX_ID_RE = re.compile(r"^[A-Za-z]\d+$")
class OpenAlexRequestError(RuntimeError):
    """Raised when a request to the OpenAlex API ultimately fails.

    Carries the failed response's context so callers can distinguish,
    e.g., rate limiting from auth errors:

    Attributes:
        status_code: HTTP status code of the failed response, if any.
        payload: Decoded JSON body (or raw text) of the failed response.
        headers: Response headers, captured as a plain dict.
    """

    def __init__(self, message, *, status_code=None, payload=None, headers=None):
        super().__init__(message)
        self.status_code = status_code
        self.payload = payload
        self.headers = {} if headers is None else dict(headers)
@dataclass(frozen=True)
class FilterToken:
    """One ``key:value`` pair from an OpenAlex ``filter=`` query parameter."""

    # Filter attribute name, e.g. "primary_location.source.id".
    key: str
    # Raw filter value, e.g. "S12345" or a pipe-joined/quoted expression.
    value: str
@dataclass
class OpenAlexQuery:
    """A normalized OpenAlex query: an entity path plus its query parameters."""

    entity: str = "works"  # API path segment, e.g. "works"
    params: dict[str, str] = field(default_factory=dict)  # plain query params
    filter_tokens: list[FilterToken] = field(default_factory=list)  # filter= parts
    legacy_filters: list[str] = field(default_factory=list)  # deprecated keys seen

    def as_params(self, select_fields: Iterable[str] | None = None, extra_params: dict[str, str] | None = None):
        """Build the flat query-parameter dict for an HTTP request.

        Filter tokens are joined into one comma-separated ``filter`` value;
        ``select_fields`` and ``extra_params``, when given, override any
        same-named keys already present.
        """
        merged = {**self.params}
        if self.filter_tokens:
            pieces = [f"{tok.key}:{tok.value}" for tok in self.filter_tokens]
            merged["filter"] = ",".join(pieces)
        if select_fields:
            merged["select"] = ",".join(select_fields)
        if extra_params:
            for name, value in extra_params.items():
                merged[name] = str(value)
        return merged

    def without_params(self, *keys: str):
        """Return a copy of this query with the named top-level params removed."""
        dropped = frozenset(keys)
        remaining = {name: value for name, value in self.params.items() if name not in dropped}
        return OpenAlexQuery(
            entity=self.entity,
            params=remaining,
            filter_tokens=self.filter_tokens[:],
            legacy_filters=self.legacy_filters[:],
        )
def _normalize_url_input(url: str):
url = url.strip()
if "://" not in url and url.startswith(("openalex.org/", "api.openalex.org/")):
return f"https://{url}"
return url
def _split_filter_string(filter_value: str):
tokens = []
current = []
quote = None
for char in filter_value:
if char in {"'", '"'}:
if quote == char:
quote = None
elif quote is None:
quote = char
current.append(char)
continue
if char == "," and quote is None:
token = "".join(current).strip()
if token:
tokens.append(token)
current = []
continue
current.append(char)
token = "".join(current).strip()
if token:
tokens.append(token)
return tokens
def _normalize_filter_token(key: str, value: str):
legacy_key = None
if key == "default.search":
return "__search__", value, key
if key == "host_venue.id":
return "primary_location.source.id", value, key
if key.startswith("host_venue."):
return key.replace("host_venue.", "primary_location.source.", 1), value, key
if key == "alternate_host_venues.id":
return "locations.source.id", value, key
if key.startswith("alternate_host_venues."):
return key.replace("alternate_host_venues.", "locations.source.", 1), value, key
if key.startswith("x_concepts"):
return key.replace("x_concepts", "concepts", 1), value, key
return key, value, legacy_key
def normalize_openalex_url(url: str):
    """Parse a user-supplied OpenAlex URL into an :class:`OpenAlexQuery`.

    Extracts the entity from the URL path (defaulting to ``"works"``),
    splits the ``filter=`` parameter into tokens, modernizes legacy filter
    keys, promotes a ``default.search`` filter to a top-level ``search=``
    parameter when none is present, and drops client-managed parameters
    (auth, paging, sampling). Deprecated keys that were encountered are
    recorded on ``legacy_filters`` so callers can surface them.
    """
    url = _normalize_url_input(url)
    parsed_url = urlparse(url)
    path_parts = [part for part in parsed_url.path.split("/") if part]
    entity = path_parts[0] if path_parts else "works"
    # keep_blank_values so parameters like "&sample=" still show up and can
    # be filtered out below rather than silently vanishing.
    query_params = parse_qs(parsed_url.query, keep_blank_values=True)
    params = {}
    legacy_filters = []
    filter_tokens = []
    search_value = query_params.get("search", [None])[0]
    for raw_filter in query_params.get("filter", []):
        for token in _split_filter_string(raw_filter):
            key, sep, value = token.partition(":")
            if not sep:
                # Not a key:value pair; ignore malformed tokens.
                continue
            key = key.strip()
            value = value.strip()
            key, value, legacy_key = _normalize_filter_token(key, value)
            if legacy_key:
                legacy_filters.append(legacy_key)
            if key == "__search__":
                # Promote the first default.search filter to the top-level
                # search param; keep any later ones as explicit filters.
                if search_value is None:
                    search_value = value
                else:
                    filter_tokens.append(FilterToken("default.search", value))
                continue
            filter_tokens.append(FilterToken(key, value))
    if search_value:
        params["search"] = search_value
    # Carry over remaining params, except those the client manages itself;
    # only the first value of any repeated parameter is kept.
    for key, values in query_params.items():
        if not values or key in {"filter", "search"} or key in IGNORED_INPUT_PARAMS:
            continue
        normalized_key = "per_page" if key == "per-page" else key
        params[normalized_key] = values[0]
    return OpenAlexQuery(
        entity=entity,
        params=params,
        filter_tokens=filter_tokens,
        legacy_filters=legacy_filters,
    )
def _normalize_openalex_id(entity_id: str):
    """Reduce a full ``https://openalex.org/...`` URL to a bare ID and
    upper-case the type prefix of bare IDs (e.g. ``"w123"`` -> ``"W123"``)."""
    candidate = entity_id.strip()
    if candidate.startswith("https://openalex.org/"):
        # Keep only the final path segment, tolerating a trailing slash.
        candidate = candidate.rstrip("/").rsplit("/", 1)[-1]
    if OPENALEX_ID_RE.match(candidate):
        return candidate[0].upper() + candidate[1:]
    return candidate
def _clean_api_key(api_key):
if api_key is None:
return None
api_key = str(api_key).strip()
return api_key or None
class OpenAlexClient:
    """Thin HTTP client for the OpenAlex REST API.

    Handles optional ``api_key`` authentication, retry with exponential
    backoff on transient server errors, cursor-based paging of works,
    per-client entity caching, DOI batch lookup, and random sampling.
    """

    def __init__(
        self,
        api_key=None,
        base_url=DEFAULT_BASE_URL,
        timeout=DEFAULT_TIMEOUT,
        max_retries=DEFAULT_RETRIES,
    ):
        self.api_key = _clean_api_key(api_key)
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.max_retries = max_retries
        self.session = requests.Session()
        self.session.headers.update(
            {
                "User-Agent": "OpenAlexMapper/1.0 (+https://huggingface.co/spaces/m7n/openalex_mapper)",
                "Accept": "application/json",
            }
        )
        # Cache for get_entity() lookups, keyed by
        # (entity, normalized_id, select_fields tuple).
        self._entity_cache = {}

    @classmethod
    def from_env(cls, require_api_key=False):
        """Build a client from the ``OPENALEX_API_KEY`` environment variable.

        Loads the local config file first so a key stored there becomes
        visible in the environment. Raises RuntimeError when
        ``require_api_key`` is set and no key is found.
        """
        load_local_config()
        api_key = _clean_api_key(os.environ.get("OPENALEX_API_KEY"))
        if require_api_key and not api_key:
            raise RuntimeError(
                "OPENALEX_API_KEY is required. Set it as a Hugging Face Space secret or in openalex_config.local.json."
            )
        return cls(api_key=api_key or None)

    def with_api_key(self, api_key):
        """Return a new client identical to this one but using ``api_key``."""
        return OpenAlexClient(
            api_key=api_key,
            base_url=self.base_url,
            timeout=self.timeout,
            max_retries=self.max_retries,
        )

    def _retry_wait_seconds(self, response, attempt):
        """Seconds to sleep before retrying a 5xx response.

        Uses the Retry-After header when it is a plain number of seconds;
        falls back to exponential backoff (2**attempt) otherwise.
        """
        retry_after = response.headers.get("Retry-After")
        if retry_after:
            try:
                return float(retry_after)
            except ValueError:
                # Retry-After may be an HTTP-date rather than delta-seconds
                # (RFC 9110); fall through to backoff instead of crashing.
                pass
        return float(2 ** attempt)

    def _request_json(self, path, params=None):
        """GET ``base_url/path`` and return the decoded JSON payload.

        Retries up to ``max_retries`` times on transient 5xx responses and
        network errors with exponential backoff; gives up immediately on
        401/403/429. Raises OpenAlexRequestError on final failure.
        """
        params = dict(params or {})
        if self.api_key:
            params["api_key"] = self.api_key
        url = f"{self.base_url}/{path.lstrip('/')}"
        last_error = None
        # Initialize outside the loop so the post-loop check is safe even
        # when max_retries is 0 and the loop body never runs.
        response = None
        for attempt in range(self.max_retries):
            response = None
            try:
                response = self.session.get(url, params=params, timeout=self.timeout)
                if response.status_code in {500, 502, 503, 504}:
                    time.sleep(self._retry_wait_seconds(response, attempt))
                    continue
                response.raise_for_status()
                return response.json()
            except requests.RequestException as exc:
                last_error = exc
                # Auth failures and rate limiting will not succeed on retry.
                if response is not None and response.status_code in {401, 403, 429}:
                    break
                if attempt == self.max_retries - 1:
                    break
                time.sleep(2 ** attempt)
        if response is not None:
            try:
                payload = response.json()
            except ValueError:
                payload = response.text
            raise OpenAlexRequestError(
                f"OpenAlex request failed for {url}: {payload}",
                status_code=response.status_code,
                payload=payload,
                headers=response.headers,
            ) from last_error
        raise OpenAlexRequestError(f"OpenAlex request failed for {url}: {last_error}") from last_error

    def get_entity(self, entity, entity_id, select_fields=None):
        """Fetch a single entity (e.g. one work or source), with caching."""
        normalized_id = _normalize_openalex_id(entity_id)
        cache_key = (entity, normalized_id, tuple(select_fields or ()))
        if cache_key in self._entity_cache:
            return self._entity_cache[cache_key]
        payload = self._request_json(
            f"{entity}/{normalized_id}",
            params={"select": ",".join(select_fields)} if select_fields else None,
        )
        self._entity_cache[cache_key] = payload
        return payload

    def count(self, query):
        """Return the total number of results matching ``query``."""
        # per_page=1 keeps the response tiny; only meta.count is used.
        payload = self._request_json(
            query.entity,
            params=query.as_params(select_fields=("id",), extra_params={"per_page": 1}),
        )
        return int(payload.get("meta", {}).get("count") or 0)

    def _normalize_work_record(self, record):
        """Return a copy of ``record`` with the fields downstream code relies
        on guaranteed present; title falls back to display_name, then " "."""
        normalized = dict(record)
        normalized["title"] = normalized.get("title") or normalized.get("display_name") or " "
        normalized.setdefault("abstract_inverted_index", None)
        normalized.setdefault("primary_location", None)
        normalized.setdefault("primary_topic", None)
        normalized.setdefault("referenced_works", [])
        return normalized

    def iter_works(self, query, limit=None, extra_params=None, per_page=DEFAULT_PER_PAGE):
        """Yield normalized work records for ``query`` using cursor paging.

        Stops after ``limit`` records when given; the final page request is
        shrunk so no more than ``limit`` records are fetched.
        """
        params = query.as_params(select_fields=DEFAULT_SELECT_FIELDS, extra_params=extra_params)
        params["cursor"] = params.get("cursor", "*")
        fetched = 0
        while True:
            current_per_page = per_page
            if limit is not None:
                remaining = limit - fetched
                if remaining <= 0:
                    break
                current_per_page = min(current_per_page, remaining)
            params["per_page"] = current_per_page
            payload = self._request_json(query.entity, params=params)
            results = payload.get("results", [])
            if not results:
                break
            for record in results:
                yield self._normalize_work_record(record)
                fetched += 1
                if limit is not None and fetched >= limit:
                    return
            next_cursor = payload.get("meta", {}).get("next_cursor")
            if next_cursor is None:
                break
            params["cursor"] = next_cursor

    def fetch_works(self, query, limit=None):
        """Materialize :meth:`iter_works` into a list."""
        return list(self.iter_works(query, limit=limit))

    def fetch_sampled_works(self, query, sample_size, seed):
        """Fetch a random sample of works matching ``query``.

        Uses the API's native sample/seed parameters when the sample fits
        OpenAlex's 10,000-record sampling cap; otherwise falls back to
        client-side reservoir sampling over the full result stream.
        ``sort`` is dropped because it conflicts with sampling.
        """
        sampling_query = query.without_params("sort")
        if sample_size <= 10000:
            return list(
                self.iter_works(
                    sampling_query,
                    limit=sample_size,
                    extra_params={"sample": sample_size, "seed": seed},
                )
            )
        return self.reservoir_sample_works(sampling_query, sample_size, seed)

    def reservoir_sample_works(self, query, sample_size, seed):
        """Uniformly sample ``sample_size`` works via reservoir sampling
        (Algorithm R); deterministic for a given ``seed``."""
        rng = random.Random(seed)
        reservoir = []
        for index, record in enumerate(self.iter_works(query)):
            if index < sample_size:
                reservoir.append(record)
                continue
            sample_index = rng.randint(0, index)
            if sample_index < sample_size:
                reservoir[sample_index] = record
        return reservoir

    def fetch_records_from_dois(self, doi_list, block_size=50):
        """Fetch works for a list of DOIs in pipe-joined batches of ``block_size``.

        Blank and non-string entries are skipped. Records come back in API
        order within each batch, not necessarily in input order.
        """
        all_records = []
        clean_dois = [doi.strip() for doi in doi_list if isinstance(doi, str) and doi.strip()]
        for start in range(0, len(clean_dois), block_size):
            sublist = clean_dois[start : start + block_size]
            doi_filter = "|".join(sublist)
            query = OpenAlexQuery(
                entity="works",
                filter_tokens=[FilterToken("doi", doi_filter)],
            )
            all_records.extend(self.fetch_works(query, limit=len(sublist)))
        return all_records
@lru_cache(maxsize=1)
def get_openalex_client(require_api_key=False):
    """Return a shared, lazily constructed OpenAlexClient built from the env.

    NOTE(review): with ``maxsize=1`` the cache keys on the argument, so calls
    that alternate ``require_api_key`` values evict and rebuild the client —
    confirm callers use a consistent value.
    """
    return OpenAlexClient.from_env(require_api_key=require_api_key)