evalstate HF Staff committed on
Commit
ec45ad9
·
verified ·
1 Parent(s): 938e81f

add sanitization

Browse files
Files changed (1) hide show
  1. hf_papers_tool.py +252 -139
hf_papers_tool.py CHANGED
@@ -10,167 +10,280 @@ from urllib.parse import urlencode
10
  from urllib.request import Request, urlopen
11
 
12
# Default number of entries returned to the caller after local filtering.
DEFAULT_LIMIT = 20
# Socket timeout (seconds) applied to every HF API request.
DEFAULT_TIMEOUT_SEC = 30
# Largest page size the /api/daily_papers endpoint accepts.
MAX_API_LIMIT = 100
 
 
 
 
 
 
 
 
 
 
15
 
16
 
17
  def _load_token() -> str | None:
18
- # Check for request-scoped token first (when running as MCP server)
19
- try:
20
- from fast_agent.mcp.auth.context import request_bearer_token
21
 
22
- ctx_token = request_bearer_token.get()
23
- if ctx_token:
24
- return ctx_token
25
- except ImportError:
26
- pass
27
 
28
- return None
 
 
 
29
 
 
 
 
 
 
30
 
31
- def _normalize_date_param(value: str | None) -> str | None:
32
- if not value:
33
  return None
34
- return value.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
 
37
  def _build_url(params: dict[str, Any]) -> str:
38
- base = os.getenv("HF_ENDPOINT", "https://huggingface.co").rstrip("/")
39
- query = urlencode({k: v for k, v in params.items() if v is not None}, doseq=True)
40
- return f"{base}/api/daily_papers?{query}" if query else f"{base}/api/daily_papers"
41
 
42
 
43
def _request_json(url: str) -> list[dict[str, Any]]:
    """Fetch *url* via HTTP GET and return the decoded JSON payload.

    The payload must be a JSON list (the shape /api/daily_papers returns).

    Raises:
        RuntimeError: on HTTP errors, network failures, or an unexpected
            response shape.
    """
    headers = {"Accept": "application/json"}
    token = _load_token()
    if token:
        # Authenticated requests get the caller's rate limits / access.
        headers["Authorization"] = f"Bearer {token}"

    request = Request(url, headers=headers, method="GET")
    try:
        with urlopen(request, timeout=DEFAULT_TIMEOUT_SEC) as response:
            raw = response.read()
    except HTTPError as exc:
        # Include the response body so API error details are not lost.
        error_body = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HF API error {exc.code} for {url}: {error_body}") from exc
    except URLError as exc:
        raise RuntimeError(f"HF API request failed for {url}: {exc}") from exc

    payload = json.loads(raw)
    if not isinstance(payload, list):
        raise RuntimeError("Unexpected response shape from /api/daily_papers")
    return payload
63
 
64
 
65
def _extract_search_blob(item: dict[str, Any]) -> str:
    """Flatten one daily-papers entry into a single lowercase string for
    keyword matching.

    Combines titles, summaries, AI summary/keywords, author names, the
    paper id, project page, and GitHub repo fields; missing fields are
    skipped.
    """
    paper = item.get("paper") or {}
    authors = paper.get("authors") or []
    # Authors may contain non-dict entries; keep only well-formed ones.
    author_names = [a.get("name", "") for a in authors if isinstance(a, dict)]

    ai_keywords = paper.get("ai_keywords") or []
    if isinstance(ai_keywords, list):
        ai_keywords_text = " ".join(str(k) for k in ai_keywords)
    else:
        # Defensive: tolerate a scalar value for ai_keywords.
        ai_keywords_text = str(ai_keywords)

    parts = [
        item.get("title"),
        item.get("summary"),
        paper.get("title"),
        paper.get("summary"),
        paper.get("ai_summary"),
        ai_keywords_text,
        " ".join(author_names),
        paper.get("id"),
        paper.get("projectPage"),
        paper.get("githubRepo"),
    ]

    # Falsy parts (None, "") are dropped before joining.
    text = " ".join(str(part) for part in parts if part)
    return text.lower()
91
 
92
 
93
def _matches_query(item: dict[str, Any], query: str) -> bool:
    """Return True when every whitespace-separated token of *query* occurs
    in the item's searchable text (case-insensitive AND match).

    An empty or whitespace-only query matches everything.
    """
    tokens = [t for t in re.split(r"\s+", query.strip().lower()) if t]
    if not tokens:
        return True
    haystack = _extract_search_blob(item)
    return all(token in haystack for token in tokens)
 
 
 
 
 
 
 
 
 
99
 
100
 
101
def hf_papers_search(
    query: str | None = None,
    *,
    date: str | None = None,
    week: str | None = None,
    month: str | None = None,
    submitter: str | None = None,
    sort: str | None = None,
    limit: int | None = None,
    page: int | None = None,
    max_pages: int | None = None,
    api_limit: int | None = None,
) -> dict[str, Any]:
    """
    Search Hugging Face Daily Papers with optional local filtering.

    Args:
        query: Case-insensitive keyword search across title, summary, authors,
            AI summary/keywords, project page, repo link, and paper id.
        date: ISO date (YYYY-MM-DD).
        week: ISO week (YYYY-Www).
        month: ISO month (YYYY-MM).
        submitter: HF username of the submitter.
        sort: "publishedAt" or "trending".
        limit: Max results to return after filtering (default 20).
        page: Page index for the API (default 0).
        max_pages: Number of pages to fetch for local filtering (default 1).
        api_limit: Page size for the API (default 50, max 100).

    Returns:
        dict with query metadata and list of daily paper entries.
    """
    # Clamp numeric inputs to sane minimums instead of raising.
    resolved_limit = DEFAULT_LIMIT if limit is None else max(int(limit), 1)
    start_page = max(int(page or 0), 0)
    pages_to_fetch = max(int(max_pages or 1), 1)

    per_page = 50 if api_limit is None else max(int(api_limit), 1)
    per_page = min(per_page, MAX_API_LIMIT)

    # None values are dropped by _build_url, so unset filters never reach
    # the query string.
    params_base: dict[str, Any] = {
        "date": _normalize_date_param(date),
        "week": _normalize_date_param(week),
        "month": _normalize_date_param(month),
        "submitter": submitter.strip() if submitter else None,
        "sort": sort.strip() if sort else None,
        "limit": per_page,
    }

    results: list[dict[str, Any]] = []
    pages_fetched = 0
    for page_index in range(start_page, start_page + pages_to_fetch):
        params = {**params_base, "p": page_index}
        url = _build_url(params)
        payload = _request_json(url)
        pages_fetched += 1

        if query:
            filtered = [item for item in payload if _matches_query(item, query)]
        else:
            filtered = payload

        results.extend(filtered)
        # Early exit once enough matches have been collected.
        if len(results) >= resolved_limit:
            break

    return {
        "query": query,
        "params": {
            **{k: v for k, v in params_base.items() if v is not None},
            "page": start_page,
            # Pages actually fetched; early exit may make this < max_pages.
            "max_pages": pages_fetched,
            "api_limit": per_page,
        },
        "returned": min(len(results), resolved_limit),
        "data": results[:resolved_limit],
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  from urllib.request import Request, urlopen
11
 
12
# Default number of entries returned to the caller after local filtering.
DEFAULT_LIMIT = 20
# Upper bound on the HTTP timeout; HF_TIMEOUT_SEC may only lower it.
DEFAULT_TIMEOUT_SEC = 10
# Largest page size the /api/daily_papers endpoint accepts.
MAX_API_LIMIT = 100
# Most pages a single search is allowed to fetch.
MAX_PAGES = 10
# Budget for pages * per_page across one search call.
MAX_TOTAL_FETCH = 500
# Free-text queries are truncated to this many characters.
MAX_QUERY_LENGTH = 300
# Endpoint is fixed (no HF_ENDPOINT override) as part of input sanitization.
BASE_API_URL = "https://huggingface.co/api"

# Strict formats for date-like filters (ISO day / ISO week / ISO month).
DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$")
WEEK_RE = re.compile(r"^\d{4}-W\d{2}$")
MONTH_RE = re.compile(r"^\d{4}-\d{2}$")
# HF usernames: alphanumeric start, then up to 38 of [A-Za-z0-9._-].
SUBMITTER_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,38}$")
# The only sort orders the daily_papers API understands.
ALLOWED_SORTS = {"publishedAt", "trending"}
25
 
26
 
27
  def _load_token() -> str | None:
28
+ # Check for request-scoped token first (when running as MCP server)
29
+ try:
30
+ from fast_agent.mcp.auth.context import request_bearer_token
31
 
32
+ ctx_token = request_bearer_token.get()
33
+ if ctx_token:
34
+ return ctx_token
35
+ except ImportError:
36
+ pass
37
 
38
+ # Fall back to HF_TOKEN environment variable
39
+ token = os.getenv("HF_TOKEN")
40
+ if token:
41
+ return token
42
 
43
+ # Fall back to cached huggingface token file
44
+ token_path = Path.home() / ".cache" / "huggingface" / "token"
45
+ if token_path.exists():
46
+ token_value = token_path.read_text(encoding="utf-8").strip()
47
+ return token_value or None
48
 
 
 
49
  return None
50
+
51
+
52
+ def _max_results_from_env() -> int:
53
+ raw = os.getenv("HF_MAX_RESULTS")
54
+ if not raw:
55
+ return DEFAULT_LIMIT
56
+ try:
57
+ value = int(raw)
58
+ except ValueError:
59
+ return DEFAULT_LIMIT
60
+ return value if value > 0 else DEFAULT_LIMIT
61
+
62
+
63
def _timeout_from_env() -> int:
    """HTTP timeout in seconds; HF_TIMEOUT_SEC may shorten (never extend) it.

    Unset, non-integer, or non-positive values yield DEFAULT_TIMEOUT_SEC.
    """
    raw = os.getenv("HF_TIMEOUT_SEC")
    if raw:
        try:
            seconds = int(raw)
        except ValueError:
            return DEFAULT_TIMEOUT_SEC
        if seconds > 0:
            # Cap at the module default so the env var cannot raise the limit.
            return min(seconds, DEFAULT_TIMEOUT_SEC)
    return DEFAULT_TIMEOUT_SEC
74
+
75
+
76
+ def _coerce_int(name: str, value: int | None, *, default: int) -> int:
77
+ if value is None:
78
+ return default
79
+ try:
80
+ resolved = int(value)
81
+ except (TypeError, ValueError) as exc:
82
+ raise ValueError(f"{name} must be an integer.") from exc
83
+ return resolved
84
+
85
+
86
+ def _normalize_date_param(name: str, value: str | None, pattern: re.Pattern[str]) -> str | None:
87
+ if not value:
88
+ return None
89
+ cleaned = value.strip()
90
+ if not cleaned:
91
+ return None
92
+ if not pattern.match(cleaned):
93
+ raise ValueError(f"{name} must match {pattern.pattern}.")
94
+ return cleaned
95
+
96
+
97
+ def _normalize_submitter(value: str | None) -> str | None:
98
+ if not value:
99
+ return None
100
+ cleaned = value.strip()
101
+ if not cleaned:
102
+ return None
103
+ if not SUBMITTER_RE.match(cleaned):
104
+ raise ValueError("submitter must be a valid HF username.")
105
+ return cleaned
106
+
107
+
108
+ def _normalize_sort(value: str | None) -> str | None:
109
+ if not value:
110
+ return None
111
+ cleaned = value.strip()
112
+ if cleaned not in ALLOWED_SORTS:
113
+ allowed = ", ".join(sorted(ALLOWED_SORTS))
114
+ raise ValueError(f"sort must be one of: {allowed}.")
115
+ return cleaned
116
+
117
+
118
+ def _normalize_query(value: str | None) -> str | None:
119
+ if value is None:
120
+ return None
121
+ cleaned = value.strip()
122
+ if not cleaned:
123
+ return None
124
+ return cleaned[:MAX_QUERY_LENGTH]
125
 
126
 
127
def _build_url(params: dict[str, Any]) -> str:
    """Build the daily_papers endpoint URL, dropping None-valued parameters.

    The host is pinned to BASE_API_URL (no environment override) as part
    of input sanitization.
    """
    present = {key: val for key, val in params.items() if val is not None}
    url = f"{BASE_API_URL}/daily_papers"
    if present:
        url = f"{url}?{urlencode(present, doseq=True)}"
    return url
 
130
 
131
 
132
def _request_json(url: str) -> list[dict[str, Any]]:
    """GET *url* and return the decoded JSON payload, which must be a list.

    Raises:
        RuntimeError: on HTTP errors, network failures, or an unexpected
            response shape.
    """
    request_headers = {"Accept": "application/json"}
    bearer = _load_token()
    if bearer:
        # Authenticated requests get the caller's rate limits / access.
        request_headers["Authorization"] = f"Bearer {bearer}"

    req = Request(url, headers=request_headers, method="GET")
    try:
        with urlopen(req, timeout=_timeout_from_env()) as resp:
            body = resp.read()
    except HTTPError as exc:
        # Surface the server's error payload to keep failures diagnosable.
        error_body = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HF API error {exc.code} for {url}: {error_body}") from exc
    except URLError as exc:
        raise RuntimeError(f"HF API request failed for {url}: {exc}") from exc

    decoded = json.loads(body)
    if isinstance(decoded, list):
        return decoded
    raise RuntimeError("Unexpected response shape from /api/daily_papers")
152
 
153
 
154
  def _extract_search_blob(item: dict[str, Any]) -> str:
155
+ paper = item.get("paper") or {}
156
+ authors = paper.get("authors") or []
157
+ author_names = [a.get("name", "") for a in authors if isinstance(a, dict)]
158
+
159
+ ai_keywords = paper.get("ai_keywords") or []
160
+ if isinstance(ai_keywords, list):
161
+ ai_keywords_text = " ".join(str(k) for k in ai_keywords)
162
+ else:
163
+ ai_keywords_text = str(ai_keywords)
164
+
165
+ parts = [
166
+ item.get("title"),
167
+ item.get("summary"),
168
+ paper.get("title"),
169
+ paper.get("summary"),
170
+ paper.get("ai_summary"),
171
+ ai_keywords_text,
172
+ " ".join(author_names),
173
+ paper.get("id"),
174
+ paper.get("projectPage"),
175
+ paper.get("githubRepo"),
176
+ ]
177
+
178
+ text = " ".join(str(part) for part in parts if part)
179
+ return text.lower()
180
 
181
 
182
  def _matches_query(item: dict[str, Any], query: str) -> bool:
183
+ tokens = [t for t in re.split(r"\s+", query.strip().lower()) if t]
184
+ if not tokens:
185
+ return True
186
+ haystack = _extract_search_blob(item)
187
+ return all(token in haystack for token in tokens)
188
+
189
+
190
def _clamp_total_fetch(pages: int, per_page: int) -> tuple[int, int]:
    """Bound the total fetch budget (pages * per_page) by MAX_TOTAL_FETCH.

    Prefers shrinking the page count; only shrinks per_page when a single
    page already exceeds the budget.
    """
    if pages * per_page <= MAX_TOTAL_FETCH:
        return pages, per_page
    if per_page > MAX_TOTAL_FETCH:
        # One page alone blows the budget: cap the page size instead.
        return 1, MAX_TOTAL_FETCH
    affordable_pages = max(MAX_TOTAL_FETCH // per_page, 1)
    return min(pages, affordable_pages), per_page
197
 
198
 
199
def hf_papers_search(
    query: str | None = None,
    *,
    date: str | None = None,
    week: str | None = None,
    month: str | None = None,
    submitter: str | None = None,
    sort: str | None = None,
    limit: int | None = None,
    page: int | None = None,
    max_pages: int | None = None,
    api_limit: int | None = None,
) -> dict[str, Any]:
    """
    Search Hugging Face Daily Papers with optional local filtering.

    Args:
        query: Case-insensitive keyword search across title, summary, authors,
            AI summary/keywords, project page, repo link, and paper id.
        date: ISO date (YYYY-MM-DD).
        week: ISO week (YYYY-Www).
        month: ISO month (YYYY-MM).
        submitter: HF username of the submitter.
        sort: "publishedAt" or "trending".
        limit: Max results to return after filtering (default 20).
        page: Page index for the API (default 0).
        max_pages: Number of pages to fetch for local filtering (default 1).
        api_limit: Page size for the API (default 50, max 100).

    Returns:
        dict with query metadata and list of daily paper entries.

    Raises:
        ValueError: on out-of-range or malformed parameters.
        RuntimeError: on API/network failures (from _request_json).
    """
    # Sanitize numeric inputs; invalid values raise ValueError rather than
    # being silently clamped.
    resolved_limit = _coerce_int("limit", limit, default=_max_results_from_env())
    if resolved_limit < 1:
        raise ValueError("limit must be >= 1.")

    start_page = _coerce_int("page", page, default=0)
    if start_page < 0:
        raise ValueError("page must be >= 0.")

    pages_to_fetch = _coerce_int("max_pages", max_pages, default=1)
    if pages_to_fetch < 1:
        raise ValueError("max_pages must be >= 1.")
    pages_to_fetch = min(pages_to_fetch, MAX_PAGES)

    per_page = _coerce_int("api_limit", api_limit, default=50)
    if per_page < 1:
        raise ValueError("api_limit must be >= 1.")
    per_page = min(per_page, MAX_API_LIMIT)

    # Keep the total fetch volume within MAX_TOTAL_FETCH entries.
    pages_to_fetch, per_page = _clamp_total_fetch(pages_to_fetch, per_page)

    normalized_query = _normalize_query(query)

    # String params are validated against strict patterns; None values are
    # dropped by _build_url so unset filters never reach the query string.
    params_base: dict[str, Any] = {
        "date": _normalize_date_param("date", date, DATE_RE),
        "week": _normalize_date_param("week", week, WEEK_RE),
        "month": _normalize_date_param("month", month, MONTH_RE),
        "submitter": _normalize_submitter(submitter),
        "sort": _normalize_sort(sort),
        "limit": per_page,
    }

    results: list[dict[str, Any]] = []
    pages_fetched = 0
    for page_index in range(start_page, start_page + pages_to_fetch):
        params = {**params_base, "p": page_index}
        url = _build_url(params)
        payload = _request_json(url)
        pages_fetched += 1

        if normalized_query:
            filtered = [item for item in payload if _matches_query(item, normalized_query)]
        else:
            filtered = payload

        results.extend(filtered)
        # Early exit once enough matches have been collected.
        if len(results) >= resolved_limit:
            break

    return {
        "query": normalized_query,
        "params": {
            **{k: v for k, v in params_base.items() if v is not None},
            "page": start_page,
            # Pages actually fetched; early exit may make this < max_pages.
            "max_pages": pages_fetched,
            "api_limit": per_page,
        },
        "returned": min(len(results), resolved_limit),
        "data": results[:resolved_limit],
    }