manpreet88 committed on
Commit
3ac8e96
·
1 Parent(s): 8227c1a

Delete rag_pipeline.py

Browse files
Files changed (1) hide show
  1. rag_pipeline.py +0 -780
rag_pipeline.py DELETED
@@ -1,780 +0,0 @@
1
- # src/rag_pipeline.py
2
- # -*- coding: utf-8 -*-
3
- """
4
- Polymer RAG pipeline (robust edition)
5
-
6
- Features:
7
- - Fetch OA PDFs from OpenAlex + arXiv + Europe PMC (no API keys required).
8
- - Parallel downloads with retries/backoff; de-dup via SHA256; manifest.jsonl to resume.
9
- - Rich metadata attached to saved PDFs.
10
- - BM25 + Vector ensemble via local RRF fusion.
11
- - Embeddings: "sentence-transformers/all-mpnet-base-v2" (default) or "intfloat/e5-large-v2"
12
- with correct query/passage prefixing handled for you.
13
- - Vector store: Chroma (default) or FAISS (optional).
14
- """
15
- from __future__ import annotations
16
- import os
17
- import re
18
- import time
19
- import json
20
- import hashlib
21
- import pathlib
22
- import tempfile
23
- from typing import List, Optional, Dict, Any, Union
24
- from concurrent.futures import ThreadPoolExecutor, as_completed
25
-
26
- import requests
27
- from tqdm import tqdm
28
-
29
- # LangChain / community (expect these installed)
30
- from langchain_community.vectorstores import Chroma
31
- from langchain_community.embeddings import HuggingFaceEmbeddings
32
- from langchain_text_splitters import RecursiveCharacterTextSplitter
33
- from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
34
- from langchain_community.retrievers import BM25Retriever
35
-
36
# --------------------------------------------------------------------------------------
# Config
# --------------------------------------------------------------------------------------

# Public, key-free search endpoints for the three open-access sources.
ARXIV_SEARCH_URL = "http://export.arxiv.org/api/query"
OPENALEX_WORKS_URL = "https://api.openalex.org/works"
EPMC_SEARCH_URL = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"

# Default Chroma persistence directory and scratch folder for downloaded PDFs.
DEFAULT_PERSIST_DIR = "chroma_polymer_db"
DEFAULT_TMP_DOWNLOAD_DIR = os.path.join(tempfile.gettempdir(), "polymer_rag_pdfs")
# One JSON record per line; used to resume runs and re-attach PDF metadata later.
MANIFEST_NAME = "manifest.jsonl"

# default set of polymer-related keywords (expandable)
POLYMER_KEYWORDS = [
    "polymer", "macromolecule", "macromolecular", "polymeric",
    "polymer informatics", "polymer chemistry", "polymer physics",
    "PSMILES", "pSMILES", "BigSMILES", "polymer SMILES", "polymer sequence",
    "foundation model", "self-supervised", "masked language model", "transformer",
    "polymer electrolyte", "polymer morphology", "generative model polymer",
]

# polite defaults
# Sent in the User-Agent (and OpenAlex mailto) so API providers can contact us.
DEFAULT_MAILTO = "your_email@example.com"  # replace if you like
59
-
60
- # --------------------------------------------------------------------------------------
61
- # Utility helpers (filenames, hashing, manifest)
62
- # --------------------------------------------------------------------------------------
63
-
64
-
65
- def _sha256_bytes(data: bytes) -> str:
66
- return hashlib.sha256(data).hexdigest()
67
-
68
-
69
- def _safe_filename(name: str) -> str:
70
- name = str(name or "").strip().replace("/", "_").replace("\\", "_")
71
- name = re.sub(r"[^a-zA-Z0-9._ -]+", "_", name)
72
- return name[:200]
73
-
74
-
75
- def _is_probably_pdf(raw: bytes, content_type: str = "") -> bool:
76
- if not raw:
77
- return False
78
- if raw[:4] == b"%PDF":
79
- return True
80
- return "pdf" in (content_type or "").lower()
81
-
82
-
83
- def _ensure_dir(path: str) -> None:
84
- os.makedirs(path, exist_ok=True)
85
-
86
-
87
def _append_manifest(out_dir: str, record: Dict[str, Any]) -> None:
    """Best-effort append of *record* as one JSON line to out_dir/MANIFEST_NAME."""
    try:
        _ensure_dir(out_dir)
        manifest_path = os.path.join(out_dir, MANIFEST_NAME)
        line = json.dumps(record, ensure_ascii=False)
        with open(manifest_path, "a", encoding="utf-8") as fh:
            fh.write(line + "\n")
    except Exception:
        # The manifest only enables resume/metadata re-attachment; it must
        # never make a download fail.
        pass
94
-
95
-
96
def _load_manifest(out_dir: str) -> Dict[str, Dict[str, Any]]:
    """Read out_dir/MANIFEST_NAME (JSONL) into a {path: record} map.

    Malformed lines are skipped; any I/O failure yields whatever was parsed
    so far (possibly empty) — loading is strictly best-effort.
    """
    records: Dict[str, Dict[str, Any]] = {}
    try:
        manifest_path = os.path.join(out_dir, MANIFEST_NAME)
        if not os.path.exists(manifest_path):
            return records
        with open(manifest_path, "r", encoding="utf-8") as fh:
            for raw_line in fh:
                try:
                    entry = json.loads(raw_line)
                    path = entry.get("path")
                except Exception:
                    continue
                if path:
                    records[path] = entry
    except Exception:
        # Unreadable manifest simply means no resume information.
        pass
    return records
114
-
115
-
116
- # --------------------------------------------------------------------------------------
117
- # Downloading PDFs (single + parallel with retry)
118
- # --------------------------------------------------------------------------------------
119
-
120
-
121
def download_pdf(url: str, out_dir: str, suggested_name: Optional[str] = None,
                 timeout: int = 60, meta: Optional[Dict[str, Any]] = None) -> Optional[str]:
    """
    Download a PDF and return local file path, or None on failure.
    Deduplicates by SHA256 content hash. Writes manifest record if meta provided.

    Non-PDF responses (HTML landing pages, empty bodies) are rejected via
    _is_probably_pdf. Any exception — network, HTTP, filesystem — results in
    None; callers retry through _retry.
    """
    try:
        headers = {"User-Agent": f"polymer-rag/1.0 (+{DEFAULT_MAILTO})"}
        # NOTE(review): stream=True, but r.content below reads the full body
        # into memory anyway — fine for papers, not for very large files.
        with requests.get(url, headers=headers, timeout=timeout, stream=True, allow_redirects=True) as r:
            r.raise_for_status()
            content_type = r.headers.get("Content-Type", "")
            raw = r.content
            if not raw or not _is_probably_pdf(raw, content_type):
                return None

            sha = _sha256_bytes(raw)
            _ensure_dir(out_dir)

            # dedup by saved files having hash prefix
            # Saved filenames start with sha[:16], so a glob hit means this
            # exact content was already downloaded (possibly from another URL).
            existing = list(pathlib.Path(out_dir).glob(f"*{sha[:16]}*.pdf"))
            if existing:
                path = str(existing[0])
                if meta:
                    rec = dict(meta)  # copy so the caller's dict is not mutated
                    rec.update({"sha256": sha, "path": path})
                    _append_manifest(out_dir, rec)
                return path

            base = suggested_name or pathlib.Path(url).name or "paper.pdf"
            base = _safe_filename(base)
            if not base.lower().endswith(".pdf"):
                base += ".pdf"
            fname = f"{sha[:16]}_{base}"  # hash prefix enables the dedup glob above
            fpath = os.path.join(out_dir, fname)
            with open(fpath, "wb") as f:
                f.write(raw)

            if meta:
                rec = dict(meta)
                rec.update({"sha256": sha, "path": fpath})
                _append_manifest(out_dir, rec)
            return fpath
    except Exception:
        # Treat any failure as "no PDF"; the parallel downloader retries.
        return None
165
-
166
-
167
- def _retry(fn, *args, _retries=3, _sleep=0.6, **kwargs):
168
- for i in range(_retries):
169
- out = fn(*args, **kwargs)
170
- if out:
171
- return out
172
- time.sleep(_sleep * (2 ** i))
173
- return None
174
-
175
-
176
def _download_one(entry: Union[str, Dict[str, Any]], out_dir: str):
    """Download one entry: either a bare URL string or a {url, name, meta} dict."""
    if not isinstance(entry, dict):
        return download_pdf(entry, out_dir)
    return download_pdf(
        entry["url"],
        out_dir,
        suggested_name=entry.get("name"),
        meta=entry.get("meta"),
    )
180
-
181
-
182
def parallel_download_pdfs(entries: List[Union[str, Dict[str, Any]]], out_dir: str, max_workers: int = 12) -> List[str]:
    """Download *entries* concurrently (each with retries via _retry) and
    return the local paths of the downloads that succeeded."""
    _ensure_dir(out_dir)
    saved: List[str] = []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(_retry, _download_one, entry, out_dir) for entry in entries]
        progress = tqdm(as_completed(pending), total=len(pending), desc="Downloading PDFs (parallel)")
        for future in progress:
            path = future.result()
            if path:
                saved.append(path)
    return saved
192
-
193
-
194
- # --------------------------------------------------------------------------------------
195
- # arXiv helper (robust)
196
- # --------------------------------------------------------------------------------------
197
-
198
-
199
- def _arxiv_query_from_keywords(keywords: List[str]) -> str:
200
- kw = [k.replace('"', '') for k in keywords]
201
- terms = " OR ".join([f'ti:"{k}"' for k in kw] + [f'abs:"{k}"' for k in kw])
202
- cats = "(cat:cond-mat.mtrl-sci OR cat:cond-mat.soft OR cat:physics.chem-ph OR cat:cs.LG OR cat:stat.ML)"
203
- return f"({terms}) AND {cats}"
204
-
205
-
206
def fetch_arxiv_pdf_urls(keywords: List[str], max_results: int = 200) -> List[str]:
    """
    Extract explicit /pdf/ links and fallback to building from <id> entries.

    Queries the arXiv Atom API (newest first) and scrapes PDF URLs out of the
    raw XML with regexes. Returns an order-preserving, de-duplicated list;
    any network/API failure yields an empty list.
    """
    query = _arxiv_query_from_keywords(keywords)
    params = {
        "search_query": query,
        "start": 0,
        "max_results": max_results,
        "sortBy": "submittedDate",  # newest submissions first
        "sortOrder": "descending",
    }
    headers = {"User-Agent": f"polymer-rag/1.0 (+{DEFAULT_MAILTO})"}
    try:
        resp = requests.get(ARXIV_SEARCH_URL, params=params, headers=headers, timeout=60)
        resp.raise_for_status()
        xml = resp.text
    except Exception:
        # Treat failures as "no results"; callers handle empty lists.
        return []

    pdfs = []
    seen = set()  # preserves first-seen order while de-duplicating
    # explicit /pdf/ hrefs
    for p in re.findall(r'href="(https?://arxiv\.org/pdf/[^"]+)"', xml):
        if p not in seen:
            pdfs.append(p); seen.add(p)
    # fallback: build from <id> entries
    # NOTE(review): [^/]+ keeps only the last path segment, so old-style IDs
    # containing a slash (e.g. "cond-mat/0703470") lose their category prefix
    # and produce a broken /pdf/ URL — confirm whether those still appear.
    for aid in re.findall(r'<id>(https?://arxiv\.org/abs/[^<]+)</id>', xml):
        m = re.search(r'arxiv\.org\/abs\/([^/]+)(?:/v\d+)?$', aid)
        if m:
            identifier = m.group(1)
            pdf = f"https://arxiv.org/pdf/{identifier}.pdf"
            if pdf not in seen:
                pdfs.append(pdf); seen.add(pdf)
    return pdfs
241
-
242
-
243
def fetch_arxiv_pdfs(keywords: List[str], out_dir: str, max_results: int = 200, polite_delay: float = 0.25) -> List[str]:
    """Search arXiv for *keywords* and download the matching PDFs into *out_dir*;
    returns the saved local paths."""
    entries = []
    for url in fetch_arxiv_pdf_urls(keywords, max_results=max_results):
        entries.append({
            "url": url,
            "name": url.rstrip("/").split("/")[-1],
            "meta": {"source": "arxiv", "url": url},
        })
    paths = parallel_download_pdfs(entries, out_dir, max_workers=8)
    # small pause to stay polite toward the arXiv API
    time.sleep(polite_delay)
    return paths
250
-
251
-
252
- # --------------------------------------------------------------------------------------
253
- # OpenAlex (robust, fallback strategies)
254
- # --------------------------------------------------------------------------------------
255
-
256
-
257
- def _openalex_build_search_query(keywords: List[str]) -> str:
258
- return " ".join(sorted(set(keywords), key=str.lower))
259
-
260
-
261
def _openalex_fetch_works_try(search: str, filter_str: str, per_page: int, page: int, mailto: Optional[str]) -> Dict[str, Any]:
    """Issue one OpenAlex /works request and return the parsed JSON payload.

    Raises on HTTP/network errors so the caller can fall back to a looser query.
    """
    contact = mailto or DEFAULT_MAILTO
    params: Dict[str, Any] = {
        "search": search,
        # Both spellings are sent; OpenAlex documents "per-page", the
        # underscore variant is kept as a harmless safety net.
        "per-page": per_page,
        "per_page": per_page,
        "page": page,
        "sort": "publication_date:desc",
    }
    if filter_str:
        params["filter"] = filter_str
    if mailto:
        params["mailto"] = mailto
    response = requests.get(
        OPENALEX_WORKS_URL,
        params=params,
        headers={"User-Agent": f"polymer-rag/1.0 (+{contact})"},
        timeout=60,
    )
    response.raise_for_status()
    return response.json()
277
-
278
-
279
def _openalex_fetch_works(keywords: List[str], max_results: int = 2000, per_page: int = 200, mailto: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Try multiple query forms:
    - combined-space query
    - OR-joined query
    - single-keyword fallback
    Also retries with relaxed filters if needed.

    Returns at most *max_results* raw OpenAlex work records, taken entirely
    from the FIRST attempt that yields any results.
    """
    kws = sorted(set(keywords or []), key=str.lower)
    # prepare query forms
    combined = " ".join(kws)
    or_query = " OR ".join(kws)
    singles = kws or ["polymer"]  # guarantees at least one fallback attempt

    # Attempts ordered strictest (OA + English) to loosest (no filter at all).
    attempts = [
        {"q": combined, "filter": "is_oa:true,language:en"},
        {"q": or_query, "filter": "is_oa:true,language:en"},
        {"q": or_query, "filter": "is_oa:true"},
        {"q": or_query, "filter": ""},  # no filters
    ]
    # append single-key fallback attempts
    for s in singles[:3]:
        attempts.append({"q": s, "filter": ""})

    works: List[Dict[str, Any]] = []
    for attempt in attempts:
        search = attempt["q"]
        filter_str = attempt["filter"]
        page = 1
        # iterate pages until we hit max_results, an error, or a short page
        while len(works) < max_results:
            try:
                data = _openalex_fetch_works_try(search, filter_str, per_page, page, mailto or DEFAULT_MAILTO)
            except Exception as e:
                print(f"[WARN] OpenAlex request failed for search='{search}' filter='{filter_str}': {e}")
                break
            results = data.get("results", [])
            print(f"[DEBUG] OpenAlex (search='{search[:120]}...' filter='{filter_str}') page={page} got {len(results)} results (total so far {len(works)})")
            if page == 1 and results:
                print("[DEBUG] sample result keys:", list(results[0].keys()))
            if not results:
                break
            works.extend(results)
            if len(results) < per_page:
                break  # a short page means this was the last one
            page += 1
            time.sleep(0.12)  # polite delay between pages
        if len(works) >= max_results:
            break
        if works:
            # First attempt that produced anything wins; looser forms skipped.
            break
    # cap to max_results
    return works[:max_results]
332
-
333
-
334
- def _openalex_extract_pdf_entries(works: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
335
- """
336
- Extract candidate PDF URLs and name hints from OpenAlex works.
337
- Returns entries like {"url": pdf_url, "name": name, "meta": {...}}
338
- """
339
- out = []
340
- seen_urls = set()
341
- for w in works:
342
- pdf = ""
343
- # best_oa_location
344
- best = w.get("best_oa_location") or {}
345
- if isinstance(best, dict):
346
- pdf = best.get("pdf_url") or best.get("url_for_pdf") or best.get("url") or ""
347
- # primary_location
348
- if not pdf:
349
- pl = w.get("primary_location") or {}
350
- if isinstance(pl, dict):
351
- pdf = pl.get("pdf_url") or pl.get("url_for_pdf") or pl.get("landing_page_url") or ""
352
- # open_access fallback
353
- if not pdf:
354
- oa = w.get("open_access") or {}
355
- if isinstance(oa, dict):
356
- pdf = oa.get("oa_url") or oa.get("oa_url_for_pdf") or ""
357
- if not pdf:
358
- continue
359
- if pdf in seen_urls:
360
- continue
361
- seen_urls.add(pdf)
362
- title = (w.get("title") or w.get("display_name") or "").strip()
363
- year = w.get("publication_year") or w.get("publication_date") or ""
364
- venue = ""
365
- pl = w.get("primary_location") or {}
366
- if isinstance(pl, dict):
367
- venue = (pl.get("source") or {}).get("display_name") or ""
368
- if not venue:
369
- venue = ((w.get("host_venue") or {}).get("display_name") or "").strip()
370
- name = " - ".join([s for s in [title, venue, str(year or "")] if s])
371
- meta = {"title": title, "year": year, "venue": venue, "source": "openalex"}
372
- out.append({"url": pdf, "name": name, "meta": meta})
373
- return out
374
-
375
-
376
def fetch_openalex_pdfs(keywords: List[str], out_dir: str, max_results: int = 2000, per_page: int = 200, mailto: Optional[str] = None) -> List[str]:
    """Fetch OpenAlex works for *keywords*, extract OA PDF links, and download
    them into *out_dir*; returns the saved local paths."""
    works = _openalex_fetch_works(keywords, max_results=max_results, per_page=per_page, mailto=mailto)
    if not works:
        print("[INFO] OpenAlex returned no works for given queries/filters.")
        return []
    entries = _openalex_extract_pdf_entries(works)
    if not entries:
        print("[INFO] OpenAlex works found, but no PDF links extracted.")
        return []
    print(f"[INFO] OpenAlex: {len(entries)} candidate PDF URLs extracted (will attempt download).")
    return parallel_download_pdfs(entries, out_dir, max_workers=16)
388
-
389
-
390
- # --------------------------------------------------------------------------------------
391
- # Europe PMC fetching (additional OA source)
392
- # --------------------------------------------------------------------------------------
393
-
394
-
395
- def _epmc_query_from_keywords(keywords: List[str]) -> str:
396
- # build a simple AND/OR query that Europe PMC understands; keep compact
397
- q = " OR ".join([f'"{k}"' for k in keywords])
398
- return q
399
-
400
-
401
- def _epmc_extract_pdf_entries_from_results(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
402
- out = []
403
- seen = set()
404
- for r in results:
405
- # Europe PMC 'fullTextUrlList' or 'fullTextUrl'
406
- ftl = r.get("fullTextUrlList") or {}
407
- urls = []
408
- # fullTextUrlList -> fullTextUrl is list of dicts with 'url' and 'documentStyle'
409
- if isinstance(ftl, dict):
410
- for ful in (ftl.get("fullTextUrl") or []):
411
- if isinstance(ful, dict):
412
- u = ful.get("url") or ""
413
- if u:
414
- urls.append(u)
415
- # direct 'fullTextUrl' string
416
- if not urls:
417
- fu = r.get("fullTextUrl")
418
- if isinstance(fu, str) and fu:
419
- urls.append(fu)
420
- # also check 'doi' -> build DOI resolver landing page (not direct PDF) - skip for now
421
- for u in urls:
422
- if not u:
423
- continue
424
- if u in seen:
425
- continue
426
- seen.add(u)
427
- title = (r.get("title") or "").strip()
428
- year = r.get("firstPublicationDate") or r.get("pubYear") or ""
429
- name = " - ".join([s for s in [title, str(year or "")] if s])
430
- out.append({"url": u, "name": name, "meta": {"title": title, "year": year, "source": "epmc"}})
431
- return out
432
-
433
-
434
def fetch_epmc_pdfs(keywords: List[str], out_dir: str, max_results: int = 1000, page_size: int = 25) -> List[str]:
    """
    Query Europe PMC and extract fullTextUrlList entries. Europe PMC often contains links to
    PMC fulltext pages, publisher pages, or direct PDFs. We attempt all and let download_pdf filter for PDFs.

    max_results bounds the number of search RESULTS examined, not the number
    of PDFs saved. Returns local paths of successfully downloaded PDFs.
    """
    q = _epmc_query_from_keywords(keywords)
    params = {
        "query": q,
        "format": "json",
        "pageSize": page_size,
        "sort": "FIRST_PDATE_D desc",  # newest publications first
    }
    headers = {"User-Agent": f"polymer-rag/1.0 (+{DEFAULT_MAILTO})"}
    saved = []
    cursor = 1         # 1-based page counter
    total_fetched = 0  # results examined so far (not PDFs saved)
    # NOTE(review): pagination uses the "page" parameter; newer Europe PMC API
    # revisions page with cursorMark — confirm "page" is still honored.
    while total_fetched < max_results:
        params["page"] = cursor
        try:
            resp = requests.get(EPMC_SEARCH_URL, params=params, headers=headers, timeout=30)
            resp.raise_for_status()
            data = resp.json()
        except Exception as e:
            print(f"[WARN] Europe PMC request failed: {e}")
            break
        results = data.get("resultList", {}).get("result", [])
        if not results:
            break  # ran out of results before reaching max_results
        entries = _epmc_extract_pdf_entries_from_results(results)
        if not entries:
            # This page had no usable full-text URLs; advance and keep looking.
            cursor += 1
            total_fetched += len(results)
            time.sleep(0.2)
            continue
        paths = parallel_download_pdfs(entries, out_dir, max_workers=8)
        saved.extend(paths)
        total_fetched += len(results)
        cursor += 1
        time.sleep(0.2)  # polite delay between pages
    return saved
474
-
475
-
476
- # --------------------------------------------------------------------------------------
477
- # Embeddings: Smart wrapper for E5 prefixing
478
- # --------------------------------------------------------------------------------------
479
-
480
-
481
class SmartHFEmbeddings(HuggingFaceEmbeddings):
    """
    HuggingFaceEmbeddings that transparently applies E5-style prefixes.

    E5 models (e.g. "intfloat/e5-large-v2") are trained with "query: " /
    "passage: " prefixes on inputs; other models (e.g. all-mpnet-base-v2)
    are used as-is.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-mpnet-base-v2", **kwargs):
        # Only forward to the base class. Do NOT set ad-hoc attributes here:
        # HuggingFaceEmbeddings is a pydantic model, and assigning an
        # undeclared field like self._use_e5 raises a ValueError under
        # pydantic v1 ("object has no field").
        super().__init__(model_name=model_name, **kwargs)

    @property
    def _use_e5(self) -> bool:
        # Derive the flag from the declared pydantic field instead of
        # storing extra instance state.
        return "e5" in (self.model_name or "").lower()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed corpus chunks, adding the "passage: " prefix for E5 models."""
        if self._use_e5:
            texts = [f"passage: {t}" for t in texts]
        return super().embed_documents(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a search query, adding the "query: " prefix for E5 models."""
        if self._use_e5:
            text = f"query: {text}"
        return super().embed_query(text)
495
-
496
-
497
- # --------------------------------------------------------------------------------------
498
- # Local ensemble (RRF)
499
- # --------------------------------------------------------------------------------------
500
-
501
-
502
class SimpleEnsembleRetriever:
    """
    Fuse several retrievers with weighted Reciprocal Rank Fusion (RRF).

    Each retriever's ranked list contributes weight * 1/(rrf_k + rank + 1)
    to a document's score; results are returned best-first, truncated to *k*.
    Documents are de-duplicated across retrievers by (source, page, text hash).
    """

    def __init__(self, retrievers, weights=None, k: int = 6, rrf_k: int = 60):
        assert retrievers, "At least one retriever required"
        self.retrievers = retrievers
        self.weights = weights or [1.0] * len(retrievers)
        assert len(self.weights) == len(self.retrievers)
        self.k = k
        self.rrf_k = rrf_k

    def _run_retriever(self, retriever, query: str):
        """Invoke *retriever* through whichever interface it exposes."""
        if hasattr(retriever, "get_relevant_documents"):
            return retriever.get_relevant_documents(query)
        if hasattr(retriever, "invoke"):
            return retriever.invoke(query)
        if callable(retriever):
            return retriever(query)
        if hasattr(retriever, "_get_relevant_documents"):
            # Some LangChain versions only expose the private hook, with or
            # without a run_manager argument.
            try:
                return retriever._get_relevant_documents(query, run_manager=None)
            except TypeError:
                try:
                    return retriever._get_relevant_documents(query)
                except TypeError:
                    pass
        raise TypeError(f"Unsupported retriever interface: {type(retriever)}")

    @staticmethod
    def _doc_key(doc) -> str:
        """Cross-retriever identity key: source | page | hash of leading text."""
        meta = getattr(doc, "metadata", {}) or {}
        text = (getattr(doc, "page_content", "") or "")[:500]
        return f"{meta.get('source', '')}|{meta.get('page', '')}|{hash(text)}"

    def get_relevant_documents(self, query: str):
        """Run every retriever on *query* and return the top-k RRF-fused docs."""
        scores: Dict[int, float] = {}
        docs_by_idx: Dict[int, Any] = {}
        idx_by_key: Dict[str, int] = {}

        for weight, retriever in zip(self.weights, self.retrievers):
            ranked = self._run_retriever(retriever, query) or []
            for rank, doc in enumerate(ranked):
                key = self._doc_key(doc)
                idx = idx_by_key.get(key)
                if idx is None:
                    idx = len(idx_by_key)
                    idx_by_key[key] = idx
                    docs_by_idx[idx] = doc
                scores[idx] = scores.get(idx, 0.0) + weight * (1.0 / (self.rrf_k + rank + 1))

        best_first = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
        return [docs_by_idx[idx] for idx, _ in best_first[: self.k]]
558
-
559
-
560
- # --------------------------------------------------------------------------------------
561
- # Builder: load PDFs, chunk, index (Chroma / FAISS)
562
- # --------------------------------------------------------------------------------------
563
-
564
-
565
- def _attach_extra_metadata_from_manifest(docs: List[Any], manifest: Dict[str, Dict[str, Any]]) -> None:
566
- for d in docs:
567
- src_path = d.metadata.get("source", "") # some loaders store source path in metadata
568
- if not src_path:
569
- continue
570
- rec = manifest.get(src_path)
571
- if not rec:
572
- # try basename match
573
- for k, v in manifest.items():
574
- if os.path.basename(k) == os.path.basename(src_path):
575
- rec = v
576
- break
577
- if rec:
578
- for k in ("title", "year", "venue", "url", "source"):
579
- if k in rec:
580
- d.metadata[k] = rec[k]
581
-
582
-
583
def _split_and_build_retriever(
    documents_dir: str,
    persist_dir: Optional[str] = None,
    k: int = 6,
    embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
    vector_backend: str = "chroma",
    min_chunk_chars: int = 200,
):
    """
    Load every PDF under *documents_dir*, chunk the text, and build a
    BM25 + vector ensemble retriever (RRF-fused, top-*k*).

    persist_dir: optional Chroma persistence directory (not used for FAISS).
    min_chunk_chars: chunks shorter than this are dropped as noise.
    Raises RuntimeError when no PDFs are found, ValueError for an unknown
    vector_backend.
    """
    print(f"🗂️ Loading PDFs from: {documents_dir}")
    loader = DirectoryLoader(documents_dir, glob="**/*.pdf", loader_cls=PyPDFLoader, show_progress=True, use_multithreading=True)
    docs = loader.load()
    if not docs:
        raise RuntimeError("No PDF documents found to index.")
    # Re-attach title/year/venue/url recorded in the manifest at download time.
    manifest = _load_manifest(documents_dir)
    _attach_extra_metadata_from_manifest(docs, manifest)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1600, chunk_overlap=250, length_function=len, separators=["\n\n", "\n", " ", ""])
    documents = text_splitter.split_documents(docs)
    # Drop tiny fragments (page furniture, stray headers) below min_chunk_chars.
    documents = [d for d in documents if len(d.page_content or "") >= min_chunk_chars]
    bm25_retriever = BM25Retriever.from_documents(documents)
    bm25_retriever.k = max(k, 8)  # give BM25 extra depth before fusion
    print(f"🔤 Using embeddings model: {embedding_model}")
    embeddings = SmartHFEmbeddings(model_name=embedding_model)
    if vector_backend.lower() == "chroma":
        if persist_dir:
            print(f"💾 Using Chroma persist_dir={persist_dir}")
            vector_store = Chroma.from_documents(documents, embeddings, persist_directory=persist_dir)
            try:
                # Older Chroma releases need an explicit persist(); newer ones
                # auto-persist and may not expose the method.
                vector_store.persist()
            except Exception:
                pass
        else:
            vector_store = Chroma.from_documents(documents, embeddings)
    elif vector_backend.lower() == "faiss":
        try:
            from langchain_community.vectorstores import FAISS
        except Exception as e:
            raise RuntimeError("FAISS requested but not available; pip install faiss-cpu") from e
        vector_store = FAISS.from_documents(documents, embeddings)
    else:
        raise ValueError("vector_backend must be 'chroma' or 'faiss'")
    vector_retriever = vector_store.as_retriever(search_kwargs={"k": k})
    # Slight tilt toward the vector retriever (0.55 vs 0.45) in the RRF fusion.
    ensemble = SimpleEnsembleRetriever(retrievers=[bm25_retriever, vector_retriever], weights=[0.45, 0.55], k=k)
    print("✅ RAG KB ready (BM25 + Vector ensemble).")
    return ensemble
627
-
628
-
629
- # --------------------------------------------------------------------------------------
630
- # High-level fetch builder that uses multiple sources and targets a large total
631
- # --------------------------------------------------------------------------------------
632
-
633
-
634
def build_retriever_from_web(
    polymer_keywords: Optional[List[str]] = None,
    max_openalex: int = 3000,
    max_arxiv: int = 1000,
    max_epmc: int = 1000,
    max_total_pdfs: int = 5000,
    from_year: int = 2010,
    extra_pdf_urls: Optional[List[str]] = None,
    persist_dir: str = DEFAULT_PERSIST_DIR,
    tmp_download_dir: str = DEFAULT_TMP_DOWNLOAD_DIR,
    k: int = 6,
    embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
    vector_backend: str = "chroma",
    mailto: Optional[str] = None,
):
    """
    Fetch OA polymer PDFs from OpenAlex, arXiv and Europe PMC (plus optional
    extra URLs), then build a BM25 + vector ensemble retriever over them.

    Each source is best-effort: a failure is logged and the next source runs.
    If fewer than *max_total_pdfs* PDFs were gathered, looser single-keyword
    OpenAlex searches are attempted. Raises RuntimeError if nothing at all
    was downloaded.

    NOTE(review): *from_year* is accepted for interface compatibility but is
    not currently passed to any fetcher, so no year filtering happens.
    """
    polymer_keywords = sorted(set(polymer_keywords or POLYMER_KEYWORDS), key=str.lower)
    print("📡 Fetching polymer PDFs from OpenAlex, arXiv, Europe PMC and extras...")
    _ensure_dir(tmp_download_dir)
    all_paths: List[str] = []
    seen_paths = set()  # O(1) dedup instead of scanning all_paths per insert

    def _add_paths(paths: List[str]) -> None:
        # Append in order, dropping paths we already collected.
        for p in paths:
            if p not in seen_paths:
                seen_paths.add(p)
                all_paths.append(p)

    # 1) OpenAlex (largest coverage) - fetch works then extract PDF links
    try:
        _add_paths(fetch_openalex_pdfs(polymer_keywords, out_dir=tmp_download_dir, max_results=max_openalex, per_page=200, mailto=mailto))
    except Exception as e:
        print(f"[WARN] OpenAlex fetch error: {e}")

    # 2) arXiv (good specialized coverage)
    try:
        _add_paths(fetch_arxiv_pdfs(polymer_keywords, out_dir=tmp_download_dir, max_results=max_arxiv))
    except Exception as e:
        print(f"[WARN] arXiv fetch error: {e}")

    # 3) Europe PMC
    try:
        _add_paths(fetch_epmc_pdfs(polymer_keywords, out_dir=tmp_download_dir, max_results=max_epmc))
    except Exception as e:
        print(f"[WARN] Europe PMC fetch error: {e}")

    # 4) Extra URLs
    if extra_pdf_urls:
        extra_entries = [{"url": u, "name": None, "meta": {"url": u, "source": "extra"}} for u in extra_pdf_urls]
        _add_paths(parallel_download_pdfs(extra_entries, tmp_download_dir, max_workers=8))

    # If not enough, attempt incremental fallback: single-keyword OpenAlex searches.
    print(f"🔎 Initial fetched PDFs: {len(all_paths)}")
    if len(all_paths) < max_total_pdfs:
        print("[INFO] Not enough PDFs yet — attempting additional looser searches (OR-joined single-key fallbacks).")
        for kw in polymer_keywords:
            if len(all_paths) >= max_total_pdfs:
                break
            try:
                _add_paths(fetch_openalex_pdfs([kw], out_dir=tmp_download_dir, max_results=200, per_page=200, mailto=mailto))
                time.sleep(0.12)  # stay polite between fallback queries
            except Exception:
                continue

    total = len(all_paths)
    print(f"✅ Downloaded {total} PDFs (OpenAlex/arXiv/EuropePMC/extra).")
    if total == 0:
        raise RuntimeError("No PDFs fetched. Adjust keywords or add extra_pdf_urls.")

    print("🧠 Building knowledge base from downloaded PDFs...")
    return _split_and_build_retriever(documents_dir=tmp_download_dir, persist_dir=persist_dir, k=k, embedding_model=embedding_model, vector_backend=vector_backend)
716
-
717
-
718
- # --------------------------------------------------------------------------------------
719
- # Local builder from existing folder
720
- # --------------------------------------------------------------------------------------
721
-
722
-
723
def build_retriever(
    papers_path: str,
    persist_dir: Optional[str] = DEFAULT_PERSIST_DIR,
    k: int = 6,
    embedding_model: str = "sentence-transformers/all-mpnet-base-v2",
    vector_backend: str = "chroma",
):
    """Build the BM25 + vector ensemble retriever from PDFs already on disk
    under *papers_path*."""
    print("📚 Building RAG knowledge base from local PDFs...")
    return _split_and_build_retriever(
        documents_dir=papers_path,
        persist_dir=persist_dir,
        k=k,
        embedding_model=embedding_model,
        vector_backend=vector_backend,
    )
732
-
733
-
734
- # --------------------------------------------------------------------------------------
735
- # Convenience wrapper
736
- # --------------------------------------------------------------------------------------
737
-
738
-
739
def build_retriever_polymer_foundation_models(
    persist_dir: str = DEFAULT_PERSIST_DIR,
    k: int = 6,
    from_year: int = 2015,
    vector_backend: str = "chroma",
):
    """Convenience wrapper: web-built retriever with keywords expanded toward
    polymer foundation-model / representation-learning literature."""
    extra_terms = [
        "BigSMILES", "PSMILES", "polymer SMILES", "polymer language model",
        "foundation model polymer", "masked language model polymer",
        "self-supervised polymer", "generative polymer",
        "Perceiver polymer", "Performer polymer",
        "polymer sequence modeling", "representation learning polymer",
    ]
    fm_kw = list(set(POLYMER_KEYWORDS + extra_terms))
    return build_retriever_from_web(
        polymer_keywords=fm_kw,
        max_openalex=4000,
        max_arxiv=800,
        max_epmc=800,
        max_total_pdfs=5000,
        from_year=from_year,
        persist_dir=persist_dir,
        k=k,
        vector_backend=vector_backend,
    )
753
-
754
-
755
- # --------------------------------------------------------------------------------------
756
- # CLI smoke (example)
757
- # --------------------------------------------------------------------------------------
758
-
759
if __name__ == "__main__":
    # CLI smoke test: builds a KB from the web (network-heavy, can take a
    # long time) and runs one sample query against it.
    retriever = build_retriever_from_web(
        polymer_keywords=POLYMER_KEYWORDS,
        max_openalex=2000,
        max_arxiv=500,
        max_epmc=500,
        max_total_pdfs=1200,
        persist_dir="chroma_polymer_db_big",
        tmp_download_dir=DEFAULT_TMP_DOWNLOAD_DIR,
        k=6,
        embedding_model="intfloat/e5-large-v2",  # E5 prefixing handled by SmartHFEmbeddings
        vector_backend="chroma",
        mailto=DEFAULT_MAILTO,
    )
    print("🔎 Sample query:")
    docs = retriever.get_relevant_documents("PSMILES polymer electrolyte design")
    for i, d in enumerate(docs, 1):
        meta = d.metadata or {}
        # Prefer the manifest title, falling back to the PDF filename.
        title = meta.get("title") or os.path.basename(meta.get("source", "")) or "document"
        year = meta.get("year", "")
        src = meta.get("source", "unknown")
        print(f"[{i}] {title} ({year}) [{src}] :: {(d.page_content or '')[:200]}...")