# Extraction/biomed_annotator.py
# glitz-dev
# get_annotation with parameters of text and context added
# 4466fc1
import json
import re
import httpx
from typing import Optional, List, Literal, Any, Dict
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_exponential
# --- 1. Schemas ---
# Closed set of question categories accepted by GeneratedQuestion; the same
# names are spelled out in the JSON schema embedded in the user prompt.
QuestionCategory = Literal["Clinical", "Mechanism", "Evidence", "Methods", "Limitations", "NextStep"]


class GeneratedQuestion(BaseModel):
    """One question produced by the LLM, validated before being returned.

    generate_annotations() builds an instance from each entry of the model's
    "questions" list and returns it as a plain dict via ``model_dump()``;
    entries that fail validation are skipped there.
    """

    # Must be one of the QuestionCategory literals.
    category: QuestionCategory
    # The question text itself.
    question: str
    # Supporting quote; the prompt asks for a verbatim substring of the
    # context, but that is not checked here.
    evidence_quote: str
import os
from dotenv import load_dotenv
# Populate the process environment from a .env file (python-dotenv) at import
# time, so the os.getenv() calls made by Settings can see those values.
load_dotenv()
# --- 2. Configuration ---
def get_hf_token_from_cache() -> str:
    """Get the HuggingFace token stored locally (e.g. by ``huggingface-cli login``).

    Uses ``huggingface_hub.get_token()`` when the installed release provides it
    and falls back to the legacy ``HfFolder.get_token()`` otherwise.

    Returns:
        The token string, or "" when huggingface_hub is not installed, no token
        is found, or the lookup fails. Failures are reported with a debug print
        and never raised.
    """
    try:
        import huggingface_hub
    except ImportError:
        print("[DEBUG] huggingface_hub not installed, cannot get token from cache")
        return ""
    try:
        # Newer huggingface_hub releases no longer ship HfFolder, so importing
        # it directly would fail even though a token is available; prefer the
        # current public helper and only use HfFolder on older releases.
        lookup = getattr(huggingface_hub, "get_token", None)
        if lookup is None:
            lookup = huggingface_hub.HfFolder.get_token
        token = lookup()
    except Exception as e:
        print(f"[DEBUG] Could not get HF token from cache: {e}")
        return ""
    if token:
        print("[DEBUG] Found HuggingFace token from local cache")
        return token
    return ""
class Settings:
    """Runtime configuration, read from environment variables when instantiated.

    Integer fields are parsed with ``int()``, so a non-numeric value in the
    corresponding environment variable raises ``ValueError``.
    """

    def __init__(self):
        # LLM Provider: 'ollama', 'openai_compat', or 'huggingface'
        self.llm_provider: str = os.getenv("LLM_PROVIDER", "huggingface")

        # Ollama settings
        self.ollama_base_url: str = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
        self.ollama_model: str = os.getenv("OLLAMA_MODEL", "llama3.2")  # e.g. qwen2.5:3b-instruct, llama3.2
        self.ollama_timeout_s: int = int(os.getenv("OLLAMA_TIMEOUT_S", 300))

        # OpenAI Compat settings
        self.openai_compat_base_url: str = os.getenv("OPENAI_COMPAT_BASE_URL", "http://localhost:8080/v1")
        self.openai_compat_model: str = os.getenv("OPENAI_COMPAT_MODEL", "gpt-4o")
        # "not-needed" is a placeholder meaning "send no Authorization header".
        self.openai_compat_api_key: str = os.getenv("OPENAI_COMPAT_API_KEY", "not-needed")
        self.openai_compat_timeout_s: int = int(os.getenv("OPENAI_COMPAT_TIMEOUT_S", 120))

        # HuggingFace settings (router.huggingface.co, OpenAI-compatible API)
        self.hf_model: str = os.getenv("HF_MODEL", "microsoft/Phi-3-mini-4k-instruct")
        # Env var first; an unset or empty HF_API_KEY falls back to the local token cache.
        self.hf_api_key: str = os.getenv("HF_API_KEY", "") or get_hf_token_from_cache()
        # Router base URL and request timeout, configurable like the other providers.
        self.hf_base_url: str = os.getenv("HF_BASE_URL", "https://router.huggingface.co/v1")
        self.hf_timeout_s: int = int(os.getenv("HF_TIMEOUT_S", 120))

        # Gen Settings
        self.max_output_questions: int = int(os.getenv("MAX_OUTPUT_QUESTIONS", 6))
# Module-wide default configuration, built once at import time from the
# environment. generate_annotations() falls back to it when no `config` is given.
settings = Settings()


# --- 3. Prompts ---
# System message sent alongside every user prompt by generate_annotations():
# it confines the model to the supplied text and requires a verbatim
# evidence_quote for every question.
SYSTEM_PROMPT = (
    "You are a biomedical paper reading assistant. "
    "Only use the provided text. Do not add external facts. "
    "Every question MUST include an evidence_quote copied verbatim from the provided text."
)
def build_question_prompt(selected_text: str, context_text: str | None, section_title: str | None, page_start: int | None, page_end: int | None) -> str:
meta = []
if section_title:
meta.append(f"Section: {section_title}")
if page_start is not None:
meta.append(f"Pages: {page_start}-{page_end or page_start}")
meta_block = "\n".join(meta) if meta else "Section: Unknown"
context = (context_text or "").strip()
if not context:
context = selected_text.strip()
max_q = settings.max_output_questions
return f"""Task: Generate good questions from this paper excerpt.
Excerpt metadata:
{meta_block}
Highlighted text:
{selected_text.strip()}
Context (use this for grounding; do not go beyond it):
{context}
Output STRICT JSON with this schema:
{{
"questions": [
{{
"category": "Clinical|Mechanism|Evidence|Methods|Limitations|NextStep",
"question": "...",
"evidence_quote": "..."
}}
]
}}
Rules:
- Output {max_q} questions.
- Questions must be specific and actionable.
- evidence_quote MUST be a verbatim substring from the Context text.
"""
# --- 4. LLM Clients ---
class LLMError(RuntimeError):
    """Signals that an LLM client could not complete a generation request."""
class BaseLLM:
    """Common interface for the provider-specific LLM clients.

    Subclasses send a system/user prompt pair to their backend and hand back
    the model's raw text reply, which callers then try to parse as JSON.
    """

    def generate_json(self, system_prompt: str, user_prompt: str) -> str:
        """Return the raw model reply for the prompts; subclasses must override."""
        raise NotImplementedError()
class OllamaLLM(BaseLLM):
    """Client for an Ollama server's ``/api/generate`` endpoint in JSON mode."""

    def __init__(self, cfg: Settings):
        self.base_url = cfg.ollama_base_url.rstrip("/")
        self.model = cfg.ollama_model
        self.timeout = cfg.ollama_timeout_s

    # reraise=True: once attempts are exhausted tenacity re-raises the last
    # LLMError itself instead of wrapping it in tenacity.RetryError.
    @retry(
        stop=stop_after_attempt(2),
        wait=wait_exponential(multiplier=0.5, min=0.5, max=2),
        reraise=True,
    )
    def generate_json(self, system_prompt: str, user_prompt: str) -> str:
        """Send one non-streaming generate request and return the reply text.

        Returns:
            The stripped "response" field of Ollama's reply ("" if absent or null).

        Raises:
            LLMError: on timeout, HTTP error status, or an unparsable response
                (after the retry attempts are used up).
        """
        url = f"{self.base_url}/api/generate"
        payload = {
            "model": self.model,
            "prompt": user_prompt,
            "system": system_prompt,
            "format": "json",  # ask Ollama to constrain output to JSON
            "stream": False,
            "options": {"temperature": 0.4, "top_p": 0.9, "num_predict": 700},
        }
        print(f"[DEBUG] Ollama request to {url} with model={self.model}")
        try:
            with httpx.Client(timeout=self.timeout) as client:
                r = client.post(url, json=payload)
                print(f"[DEBUG] Ollama response status: {r.status_code}")
                if r.status_code != 200:
                    print(f"[DEBUG] Ollama error response: {r.text}")
                r.raise_for_status()
                data = r.json()
                return (data.get("response") or "").strip()
        except httpx.TimeoutException as e:
            print(f"[DEBUG] Ollama timeout: {e}")
            raise LLMError(f"Ollama generate timed out after {self.timeout}s: {e}") from e
        except Exception as e:
            print(f"[DEBUG] Ollama exception type={type(e).__name__}: {e}")
            raise LLMError(f"Ollama generate failed: {e}") from e
class OpenAICompatLLM(BaseLLM):
    """Client for any server exposing an OpenAI-style ``/chat/completions`` endpoint."""

    def __init__(self, cfg: Settings):
        self.base_url = cfg.openai_compat_base_url.rstrip("/")
        self.model = cfg.openai_compat_model
        self.api_key = cfg.openai_compat_api_key
        self.timeout = cfg.openai_compat_timeout_s

    # reraise=True: once attempts are exhausted tenacity re-raises the last
    # LLMError itself instead of wrapping it in tenacity.RetryError.
    @retry(
        stop=stop_after_attempt(2),
        wait=wait_exponential(multiplier=0.5, min=0.5, max=2),
        reraise=True,
    )
    def generate_json(self, system_prompt: str, user_prompt: str) -> str:
        """Send one chat completion request in JSON-object mode.

        Returns:
            The stripped content of the first choice's message ("" if null).

        Raises:
            LLMError: on HTTP errors or an unexpected response shape (after
                the retry attempts are used up).
        """
        url = f"{self.base_url}/chat/completions"
        headers = {"Content-Type": "application/json"}
        # "not-needed" is the Settings placeholder for servers without auth.
        if self.api_key and self.api_key != "not-needed":
            headers["Authorization"] = f"Bearer {self.api_key}"
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": 0.4,
            "top_p": 0.9,
            "max_tokens": 900,
            "response_format": {"type": "json_object"},
        }
        try:
            with httpx.Client(timeout=self.timeout) as client:
                r = client.post(url, headers=headers, json=payload)
                r.raise_for_status()
                data = r.json()
                return (data["choices"][0]["message"]["content"] or "").strip()
        except Exception as e:
            raise LLMError(f"OpenAI-compat generate failed: {e}") from e
class HuggingFaceLLM(BaseLLM):
    """HuggingFace LLM using router.huggingface.co (OpenAI-compatible API format)"""

    # Used when the config object does not provide (or leaves empty) the
    # optional ``hf_base_url`` / ``hf_timeout_s`` settings.
    DEFAULT_BASE_URL = "https://router.huggingface.co/v1"
    DEFAULT_TIMEOUT_S = 120

    def __init__(self, cfg: Settings):
        self.model = cfg.hf_model
        self.api_key = cfg.hf_api_key
        # getattr keeps this working with config objects lacking these fields.
        self.base_url = (getattr(cfg, "hf_base_url", None) or self.DEFAULT_BASE_URL).rstrip("/")
        self.timeout = getattr(cfg, "hf_timeout_s", None) or self.DEFAULT_TIMEOUT_S

    # reraise=True: once attempts are exhausted tenacity re-raises the last
    # LLMError itself instead of wrapping it in tenacity.RetryError.
    @retry(
        stop=stop_after_attempt(2),
        wait=wait_exponential(multiplier=0.5, min=0.5, max=2),
        reraise=True,
    )
    def generate_json(self, system_prompt: str, user_prompt: str) -> str:
        """Send one chat completion request to the HuggingFace router.

        Returns:
            The stripped content of the first choice's message, or "" when the
            reply has no choices or the content is null.

        Raises:
            LLMError: on HTTP errors or an unexpected response shape (after
                the retry attempts are used up).
        """
        url = f"{self.base_url}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        print(f"[DEBUG] HuggingFace request to model: {self.model}")
        print(f"[DEBUG] API key present: {bool(self.api_key and self.api_key != 'your_huggingface_api_key_here')}")
        # OpenAI-compatible chat format (works with HuggingFace models)
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "max_tokens": 800,
            "temperature": 0.4,
        }
        try:
            with httpx.Client(timeout=self.timeout) as client:
                r = client.post(url, headers=headers, json=payload)
                print(f"[DEBUG] HuggingFace response status: {r.status_code}")
                if r.status_code != 200:
                    print(f"[DEBUG] HuggingFace error response: {r.text}")
                r.raise_for_status()
                data = r.json()
                choices = data.get("choices") if isinstance(data, dict) else None
                if not choices:
                    return ""
                # Content can be null (e.g. tool-call replies); treat as empty text.
                content = choices[0]["message"]["content"]
                return (content or "").strip()
        except Exception as e:
            print(f"[DEBUG] HuggingFace exception: {type(e).__name__}: {e}")
            raise LLMError(f"HuggingFace generate failed: {e}") from e
def get_llm(cfg: Settings) -> BaseLLM:
provider = (cfg.llm_provider or "").lower().strip()
if provider == "ollama":
return OllamaLLM(cfg)
if provider == "openai_compat":
return OpenAICompatLLM(cfg)
if provider == "huggingface":
return HuggingFaceLLM(cfg)
raise ValueError(f"Unsupported LLM_PROVIDER: {provider}")
# --- 5. Generation Logic ---
_JSON_RE = re.compile(r"\{.*\}", re.DOTALL)
def _safe_extract_json(text: str) -> dict | None:
if not text:
return None
text = text.strip()
try:
return json.loads(text)
except Exception:
pass
m = _JSON_RE.search(text)
if m:
try:
return json.loads(m.group(0))
except Exception:
return None
return None
def generate_annotations(
    selected_text: str,
    context_text: str | None = None,
    section_title: str | None = None,
    page_start: int | None = None,
    page_end: int | None = None,
    config: Settings | None = None
) -> List[Dict[str, Any]]:
    """
    Main entrypoint: Generate questions for selected text using LLM only.

    Args:
        selected_text: Highlighted passage (None is treated as "").
        context_text: Optional surrounding text used for grounding.
        section_title, page_start, page_end: Optional excerpt metadata for the prompt.
        config: Settings to use; defaults to the module-level ``settings``.

    Returns:
        Up to ``cfg.max_output_questions`` dicts with keys "category",
        "question" and "evidence_quote". Returns an empty list when there is
        no text to work from or when generation fails; entries that fail
        schema validation are dropped.

    Raises:
        ValueError: from get_llm() when ``cfg.llm_provider`` is unsupported
            (this happens before the error-swallowing try block).
    """
    selected_text = selected_text or ""
    # Nothing to ground questions in: skip the (slow, possibly billed) LLM call.
    if not selected_text.strip() and not (context_text or "").strip():
        return []
    cfg = config or settings
    # 1. Setup
    llm = get_llm(cfg)
    # NOTE(review): build_question_prompt is called without a question count,
    # so the count requested in the prompt is not taken from `cfg`; only the
    # truncation below uses cfg.max_output_questions.
    user_prompt = build_question_prompt(selected_text, context_text, section_title, page_start, page_end)
    # 2. Generate
    questions: List[Dict[str, Any]] = []
    try:
        raw = llm.generate_json(SYSTEM_PROMPT, user_prompt)
        parsed = _safe_extract_json(raw)
        if isinstance(parsed, dict) and isinstance(parsed.get("questions"), list):
            dropped = 0
            for q in parsed["questions"]:
                try:
                    # Pydantic ValidationError subclasses ValueError; a non-mapping
                    # entry makes the ** unpacking raise TypeError.
                    questions.append(GeneratedQuestion(**q).model_dump())
                except (TypeError, ValueError):
                    dropped += 1
            if dropped:
                print(f"[DEBUG] Dropped {dropped} question(s) that failed schema validation")
        else:
            print("[DEBUG] LLM output did not contain a 'questions' list")
        # Limit to max questions
        questions = questions[:cfg.max_output_questions]
    except Exception as e:
        print(f"LLM Generation failed: {e}")
        # In 'only llm' mode there is no fallback; return an empty list.
        return []
    return questions
# --- 6. CLI Test ---
if __name__ == "__main__":
    # Manual smoke run against whichever provider LLM_PROVIDER selects.
    demo_kwargs = {
        "selected_text": "BRCA1 mutations significantly increase the risk of developing breast cancer.",
        "context_text": "In this study of 500 patients, we observed that BRCA1 mutations significantly increase the risk of developing breast cancer compared to controls.",
        "section_title": "Abstract",
    }
    print("Generating annotations...")
    print(json.dumps(generate_annotations(**demo_kwargs), indent=2))