| from crewai.tools import BaseTool |
| from pydantic import BaseModel, Field |
| from typing import List, Dict, Optional, Any |
| import os |
| import json |
| import torch |
| import faiss |
| import numpy as np |
| from pathlib import Path |
| from .specter2_embedder import embed_texts_specter2 |
|
|
| |
class PubmedQueryInput(BaseModel):
    """Input schema for the PubMed retrieval tool."""

    # Required (``...``); the description is exposed to the calling agent
    # through the tool's argument schema.
    caption: str = Field(
        ...,
        description="Radiology caption or findings text used as the PubMed search query.",
    )
|
|
| |
class PubmedRetrievalTool(BaseTool):
    """CrewAI tool that finds PubMed abstracts similar to a radiology caption.

    The caption is embedded with ``embed_texts_specter2`` and searched against
    a FAISS index read from disk. Each hit's position is mapped to a record in
    a JSON Lines file of abstracts.

    Settings are read from ``metadata``:

    * ``DATA_DIR`` -- directory containing ``text_faiss.bin`` and
      ``raw_abstracts.jsonl``. Defaults to the ``data`` directory three
      levels above this file.
    * ``TOP_K`` -- maximum number of citations to return (default 3).
    """

    name: str = "pubmed_retrieval_tool"
    description: str = (
        "Retrieves the most relevant PubMed articles for a given radiology caption."
    )
    args_schema: type = PubmedQueryInput
    # Tool settings (DATA_DIR, TOP_K). default_factory gives each instance its
    # own dict instead of relying on a mutable class-level literal.
    metadata: Dict[str, Any] = Field(default_factory=dict)

    def _run(self, caption: Optional[str] = None, **kwargs) -> str:
        """Return the PubMed articles most similar to *caption*.

        Args:
            caption: Radiology caption used as the search query.
            **kwargs: Extra arguments from the calling framework; ignored.

        Returns:
            Up to ``TOP_K`` citations (PMID, score, title, abstract) joined by
            ``---`` separators, a notice when the search yields no usable
            hits, or a message starting with ``"Error"``. Failures are
            reported in the returned text rather than raised, so the calling
            agent always receives a string.
        """
        query = str(caption).strip() if caption else ""
        if not query:
            return "Error: No caption provided. Unable to search PubMed."

        base_dir = Path(__file__).parent.parent.parent
        data_dir = self.metadata.get("DATA_DIR", str(base_dir / "data"))

        try:
            top_k = int(self.metadata.get("TOP_K", 3))
            if top_k < 1:
                return f"Error: TOP_K must be a positive integer, got {top_k}"

            index_path = os.path.join(data_dir, "text_faiss.bin")
            metadata_path = os.path.join(data_dir, "raw_abstracts.jsonl")
            if not os.path.exists(index_path):
                return f"Error: FAISS index not found at {index_path}"
            if not os.path.exists(metadata_path):
                return f"Error: Metadata file not found at {metadata_path}"

            index = faiss.read_index(index_path)
            records = self._load_records(metadata_path)

            # FAISS search expects a C-contiguous float32 matrix shaped
            # (n_queries, dim); exactly one query is sent here.
            # NOTE(review): assumes the embedder returns one array-like vector
            # per input text -- confirm against specter2_embedder.
            query_vec = np.ascontiguousarray(
                embed_texts_specter2([query]), dtype=np.float32
            ).reshape(1, -1)
            scores, indices = index.search(query_vec, top_k)

            citations = self._format_citations(scores[0], indices[0], records)
        except Exception as e:
            # Deliberate catch-all at the tool boundary: the agent receives
            # the error text instead of the exception aborting its run.
            return f"Error during PubMed search: {e}"

        if not citations:
            return "No relevant PubMed articles found for the given caption."
        return "\n---\n".join(citations)

    @staticmethod
    def _load_records(metadata_path: str) -> List[Dict[str, Any]]:
        """Read one JSON record per line from *metadata_path*, skipping blank lines."""
        with open(metadata_path, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    @staticmethod
    def _field(entry: Dict[str, Any], key: str, fallback: str) -> str:
        """Return ``entry[key]`` as stripped text, or *fallback* if missing or null."""
        value = entry.get(key)
        return fallback if value is None else str(value).strip()

    @classmethod
    def _format_citations(cls, scores, indices, records: List[Dict[str, Any]]) -> List[str]:
        """Format one query's FAISS hits as consecutively numbered citations.

        Args:
            scores: Scores for one query row, as returned by ``index.search``.
                Whether higher means "closer" depends on the index metric.
            indices: Record positions for the same row.
            records: Article records aligned with the FAISS index order.

        Raises:
            IndexError: If FAISS returns an id with no matching record, which
                means the index and the JSONL file are out of sync.
        """
        citations: List[str] = []
        for score, idx in zip(scores, indices):
            idx = int(idx)
            if idx < 0:
                # FAISS pads with -1 when fewer than k results exist;
                # records[-1] would silently cite an unrelated article.
                continue
            if idx >= len(records):
                raise IndexError(
                    f"FAISS returned id {idx} but only {len(records)} "
                    f"metadata records were loaded"
                )
            entry = records[idx]
            citations.append(
                f"Citation {len(citations) + 1}:\n"
                f"PMID: {cls._field(entry, 'pmid', 'Unknown')}\n"
                f"Similarity Score: {float(score):.3f}\n"
                f"Title: {cls._field(entry, 'title', 'Untitled')}\n"
                f"Abstract: {cls._field(entry, 'abstract', 'No abstract available.')}\n"
            )
        return citations