Spaces:
Sleeping
Sleeping
| import json | |
| import re | |
| import httpx | |
| from typing import Optional, List, Literal, Any, Dict | |
| from pydantic import BaseModel | |
| from tenacity import retry, stop_after_attempt, wait_exponential | |
| # --- 1. Schemas --- | |
# Closed set of labels a generated question may carry.
QuestionCategory = Literal[
    "Clinical",
    "Mechanism",
    "Evidence",
    "Methods",
    "Limitations",
    "NextStep",
]


class GeneratedQuestion(BaseModel):
    """One question produced by the LLM for a paper excerpt.

    Instances are built from the dictionaries found under the "questions"
    key of the model's JSON reply; construction fails when a field is
    missing or `category` is not one of the QuestionCategory labels.
    """

    # One of the QuestionCategory labels.
    category: QuestionCategory
    # The question text itself.
    question: str
    # Passage the prompt asks the model to copy verbatim from the context.
    evidence_quote: str
| import os | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| # --- 2. Configuration --- | |
def get_hf_token_from_cache() -> str:
    """Return the locally stored HuggingFace token, or "" when unavailable.

    The token is the one saved by `huggingface-cli login`. Lookup is
    best-effort: any failure is reported on stdout and turned into an empty
    string so that importing this module never fails because of it.

    Returns:
        The token string, or "" if huggingface_hub is not installed, no token
        is stored, or the lookup raised.
    """
    try:
        import huggingface_hub
    except ImportError:
        print("[DEBUG] huggingface_hub not installed, cannot get token from cache")
        return ""

    try:
        # `get_token` is the supported accessor. `HfFolder` is the legacy one
        # and is missing from newer releases, so it is only used as a
        # fallback for old installations.
        get_token = getattr(huggingface_hub, "get_token", None)
        if get_token is None:
            get_token = huggingface_hub.HfFolder.get_token
        token = get_token()
    except Exception as e:
        # Best-effort boundary: a broken cache must not break start-up.
        print(f"[DEBUG] Could not get HF token from cache: {e}")
        return ""

    if token:
        print("[DEBUG] Found HuggingFace token from local cache")
        return token
    return ""
class Settings:
    """Runtime configuration, read from environment variables at construction.

    Every attribute has a default, so an empty environment yields a usable
    object. Integer variables that are unset or blank fall back to their
    default; a non-numeric value raises ValueError naming the variable.
    """

    @staticmethod
    def _env_int(name: str, default: int) -> int:
        """Read integer environment variable `name`, or `default` if unset/blank.

        Raises:
            ValueError: if the variable is set to something that is not an
                integer; the message names the offending variable.
        """
        raw = os.getenv(name, "").strip()
        if not raw:
            # A blank assignment such as `OLLAMA_TIMEOUT_S=` in a .env file
            # is treated the same as the variable being absent.
            return default
        try:
            return int(raw)
        except ValueError as err:
            raise ValueError(
                f"Environment variable {name} must be an integer, got {raw!r}"
            ) from err

    def __init__(self):
        # LLM Provider: 'ollama', 'openai_compat', or 'huggingface'
        self.llm_provider: str = os.getenv("LLM_PROVIDER", "huggingface")

        # Ollama settings
        self.ollama_base_url: str = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
        # Other model tags noted by the author: qwen2.5:3b-instruct, llama3.2
        self.ollama_model: str = os.getenv("OLLAMA_MODEL", "llama3.2")
        self.ollama_timeout_s: int = self._env_int("OLLAMA_TIMEOUT_S", 300)

        # OpenAI Compat settings
        self.openai_compat_base_url: str = os.getenv("OPENAI_COMPAT_BASE_URL", "http://localhost:8080/v1")
        self.openai_compat_model: str = os.getenv("OPENAI_COMPAT_MODEL", "gpt-4o")
        self.openai_compat_api_key: str = os.getenv("OPENAI_COMPAT_API_KEY", "not-needed")
        self.openai_compat_timeout_s: int = self._env_int("OPENAI_COMPAT_TIMEOUT_S", 120)

        # HuggingFace Serverless Inference settings
        self.hf_model: str = os.getenv("HF_MODEL", "microsoft/Phi-3-mini-4k-instruct")
        # Try env var first, then fall back to local cache token
        self.hf_api_key: str = os.getenv("HF_API_KEY", "") or get_hf_token_from_cache()
        # Request timeout in seconds that HuggingFace clients can read
        # instead of hard-coding their own.
        self.hf_timeout_s: int = self._env_int("HF_TIMEOUT_S", 120)

        # Gen Settings
        self.max_output_questions: int = self._env_int("MAX_OUTPUT_QUESTIONS", 6)
| settings = Settings() | |
| # --- 3. Prompts --- | |
# System message sent with every request: keeps the model grounded in the
# supplied excerpt and requires a verbatim evidence quote per question.
SYSTEM_PROMPT = " ".join(
    [
        "You are a biomedical paper reading assistant.",
        "Only use the provided text. Do not add external facts.",
        "Every question MUST include an evidence_quote copied verbatim from the provided text.",
    ]
)
| def build_question_prompt(selected_text: str, context_text: str | None, section_title: str | None, page_start: int | None, page_end: int | None) -> str: | |
| meta = [] | |
| if section_title: | |
| meta.append(f"Section: {section_title}") | |
| if page_start is not None: | |
| meta.append(f"Pages: {page_start}-{page_end or page_start}") | |
| meta_block = "\n".join(meta) if meta else "Section: Unknown" | |
| context = (context_text or "").strip() | |
| if not context: | |
| context = selected_text.strip() | |
| max_q = settings.max_output_questions | |
| return f"""Task: Generate good questions from this paper excerpt. | |
| Excerpt metadata: | |
| {meta_block} | |
| Highlighted text: | |
| {selected_text.strip()} | |
| Context (use this for grounding; do not go beyond it): | |
| {context} | |
| Output STRICT JSON with this schema: | |
| {{ | |
| "questions": [ | |
| {{ | |
| "category": "Clinical|Mechanism|Evidence|Methods|Limitations|NextStep", | |
| "question": "...", | |
| "evidence_quote": "..." | |
| }} | |
| ] | |
| }} | |
| Rules: | |
| - Output {max_q} questions. | |
| - Questions must be specific and actionable. | |
| - evidence_quote MUST be a verbatim substring from the Context text. | |
| """ | |
| # --- 4. LLM Clients --- | |
class LLMError(RuntimeError):
    """Raised by the LLM clients when a generation request fails or times out."""
class BaseLLM:
    """Common interface of the LLM backends defined in this module."""

    def generate_json(self, system_prompt: str, user_prompt: str) -> str:
        """Return the model's raw text reply for the given prompts.

        Subclasses override this; the base implementation always raises
        NotImplementedError.
        """
        raise NotImplementedError()
class OllamaLLM(BaseLLM):
    """Client for an Ollama server's `/api/generate` endpoint."""

    def __init__(self, cfg: Settings):
        """Copy the Ollama base URL (without trailing slash), model and timeout from `cfg`."""
        self.base_url = cfg.ollama_base_url.rstrip("/")
        self.model = cfg.ollama_model
        self.timeout = cfg.ollama_timeout_s

    def generate_json(self, system_prompt: str, user_prompt: str) -> str:
        """Request a non-streamed JSON-format completion and return its text.

        Args:
            system_prompt: Sent as the request's "system" field.
            user_prompt: Sent as the request's "prompt" field.

        Returns:
            The stripped "response" field of the reply, or "" when that field
            is absent or null.

        Raises:
            LLMError: on timeout, non-2xx status, or any other failure; the
                original exception is attached as `__cause__`.
        """
        url = f"{self.base_url}/api/generate"
        payload = {
            "model": self.model,
            "prompt": user_prompt,
            "system": system_prompt,
            "format": "json",
            "stream": False,
            "options": {"temperature": 0.4, "top_p": 0.9, "num_predict": 700},
        }
        print(f"[DEBUG] Ollama request to {url} with model={self.model}")
        try:
            with httpx.Client(timeout=self.timeout) as client:
                r = client.post(url, json=payload)
                print(f"[DEBUG] Ollama response status: {r.status_code}")
                if r.status_code != 200:
                    print(f"[DEBUG] Ollama error response: {r.text}")
                r.raise_for_status()
                data = r.json()
                # `or ""` also covers an explicit null, which would otherwise
                # fail on `.strip()`.
                return (data.get("response") or "").strip()
        except httpx.TimeoutException as e:
            print(f"[DEBUG] Ollama timeout: {e}")
            raise LLMError(f"Ollama generate timed out after {self.timeout}s: {e}") from e
        except Exception as e:
            print(f"[DEBUG] Ollama exception type={type(e).__name__}: {e}")
            raise LLMError(f"Ollama generate failed: {e}") from e
class OpenAICompatLLM(BaseLLM):
    """Client for any server exposing an OpenAI-style `/chat/completions` route."""

    def __init__(self, cfg: Settings):
        """Copy the base URL (without trailing slash), model, API key and timeout from `cfg`."""
        self.base_url = cfg.openai_compat_base_url.rstrip("/")
        self.model = cfg.openai_compat_model
        self.api_key = cfg.openai_compat_api_key
        self.timeout = cfg.openai_compat_timeout_s

    def generate_json(self, system_prompt: str, user_prompt: str) -> str:
        """Request a JSON-object chat completion and return the reply text.

        Args:
            system_prompt: Content of the "system" message.
            user_prompt: Content of the "user" message.

        Returns:
            The stripped content of the first choice, or "" when it is null.

        Raises:
            LLMError: on any failure (transport, HTTP status, unexpected
                reply shape); the original exception is attached as
                `__cause__`.
        """
        url = f"{self.base_url}/chat/completions"
        headers = {"Content-Type": "application/json"}
        # "not-needed" is the configured placeholder for key-less local
        # servers; no Authorization header is sent in that case.
        if self.api_key and self.api_key != "not-needed":
            headers["Authorization"] = f"Bearer {self.api_key}"
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "temperature": 0.4,
            "top_p": 0.9,
            "max_tokens": 900,
            "response_format": {"type": "json_object"},
        }
        try:
            with httpx.Client(timeout=self.timeout) as client:
                r = client.post(url, headers=headers, json=payload)
                r.raise_for_status()
                data = r.json()
                return (data["choices"][0]["message"]["content"] or "").strip()
        except Exception as e:
            raise LLMError(f"OpenAI-compat generate failed: {e}") from e
class HuggingFaceLLM(BaseLLM):
    """HuggingFace LLM using router.huggingface.co (OpenAI-compatible API format)"""

    def __init__(self, cfg: Settings):
        """Copy the model name and API key from `cfg` and pick the timeout."""
        self.model = cfg.hf_model
        self.api_key = cfg.hf_api_key
        # Use `hf_timeout_s` when the settings object has one; otherwise
        # keep the long-standing 120 second timeout.
        self.timeout = getattr(cfg, "hf_timeout_s", 120)

    def generate_json(self, system_prompt: str, user_prompt: str) -> str:
        """Request a chat completion from the HuggingFace router.

        Args:
            system_prompt: Content of the "system" message.
            user_prompt: Content of the "user" message.

        Returns:
            The stripped content of the first choice, or "" when the reply
            has no choices or the content is null.

        Raises:
            LLMError: when no usable API key is configured, or on any request
                failure (the original exception is attached as `__cause__`).
        """
        # HuggingFace router with OpenAI-compatible format (hosted on HuggingFace)
        url = "https://router.huggingface.co/v1/chat/completions"

        has_key = bool(self.api_key and self.api_key != 'your_huggingface_api_key_here')
        print(f"[DEBUG] HuggingFace request to model: {self.model}")
        print(f"[DEBUG] API key present: {has_key}")
        if not has_key:
            # Fail before the network call: a blank or placeholder bearer
            # token would only come back as an opaque HTTP error.
            raise LLMError(
                "HuggingFace generate failed: no API key configured "
                "(set HF_API_KEY or run `huggingface-cli login`)"
            )

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        # OpenAI-compatible chat format (works with HuggingFace models)
        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "max_tokens": 800,
            "temperature": 0.4,
        }
        try:
            with httpx.Client(timeout=self.timeout) as client:
                r = client.post(url, headers=headers, json=payload)
                print(f"[DEBUG] HuggingFace response status: {r.status_code}")
                if r.status_code != 200:
                    print(f"[DEBUG] HuggingFace error response: {r.text}")
                r.raise_for_status()
                # OpenAI-compatible response format
                data = r.json()
                if "choices" in data and len(data["choices"]) > 0:
                    # Content may be null; treat that as an empty reply.
                    return (data["choices"][0]["message"]["content"] or "").strip()
                return ""
        except Exception as e:
            print(f"[DEBUG] HuggingFace exception: {type(e).__name__}: {e}")
            raise LLMError(f"HuggingFace generate failed: {e}") from e
def get_llm(cfg: Settings) -> BaseLLM:
    """Instantiate the LLM client named by `cfg.llm_provider`.

    The provider name is matched case-insensitively after trimming
    whitespace.

    Raises:
        ValueError: if the provider is not "ollama", "openai_compat" or
            "huggingface".
    """
    backends = {
        "ollama": OllamaLLM,
        "openai_compat": OpenAICompatLLM,
        "huggingface": HuggingFaceLLM,
    }
    provider = (cfg.llm_provider or "").lower().strip()
    backend_cls = backends.get(provider)
    if backend_cls is None:
        raise ValueError(f"Unsupported LLM_PROVIDER: {provider}")
    return backend_cls(cfg)
| # --- 5. Generation Logic --- | |
| _JSON_RE = re.compile(r"\{.*\}", re.DOTALL) | |
| def _safe_extract_json(text: str) -> dict | None: | |
| if not text: | |
| return None | |
| text = text.strip() | |
| try: | |
| return json.loads(text) | |
| except Exception: | |
| pass | |
| m = _JSON_RE.search(text) | |
| if m: | |
| try: | |
| return json.loads(m.group(0)) | |
| except Exception: | |
| return None | |
| return None | |
def generate_annotations(
    selected_text: str,
    context_text: str | None = None,
    section_title: str | None = None,
    page_start: int | None = None,
    page_end: int | None = None,
    config: Settings | None = None
) -> List[Dict[str, Any]]:
    """Generate validated questions for a highlighted passage via the LLM.

    Args:
        selected_text: The highlighted passage.
        context_text: Optional surrounding text used for grounding.
        section_title: Optional section name for the prompt metadata.
        page_start: Optional first page of the excerpt.
        page_end: Optional last page of the excerpt.
        config: Settings to use; the module-level `settings` when omitted.

    Returns:
        At most `max_output_questions` dictionaries with the
        GeneratedQuestion fields. Items that fail validation are dropped.
        An empty list is returned when the LLM call or parsing fails.

    Raises:
        ValueError: from get_llm, which runs before the guarded section, when
            the configured provider is unsupported.
    """
    active_cfg = config or settings

    # Client construction and prompt building happen outside the try block,
    # so their errors propagate to the caller.
    # NOTE(review): the prompt's question count comes from the module-level
    # `settings`, not from `config` - confirm whether that is intended.
    client = get_llm(active_cfg)
    prompt = build_question_prompt(selected_text, context_text, section_title, page_start, page_end)

    try:
        reply = client.generate_json(SYSTEM_PROMPT, prompt)
        payload = _safe_extract_json(reply)

        candidates = payload.get("questions") if isinstance(payload, dict) else None
        if not isinstance(candidates, list):
            candidates = []

        validated = []
        for candidate in candidates:
            try:
                # Pydantic validation; malformed entries are skipped.
                validated.append(GeneratedQuestion(**candidate).model_dump())
            except Exception:
                continue

        # Keep only the first N valid questions.
        return validated[:active_cfg.max_output_questions]
    except Exception as e:
        print(f"LLM Generation failed: {e}")
        # LLM-only mode: no fallback generator, so report nothing.
        return []
| # --- 6. CLI Test --- | |
if __name__ == "__main__":
    # Manual smoke test: run one generation against the configured provider
    # and pretty-print whatever comes back.
    demo_highlight = "BRCA1 mutations significantly increase the risk of developing breast cancer."
    demo_context = "In this study of 500 patients, we observed that BRCA1 mutations significantly increase the risk of developing breast cancer compared to controls."

    print("Generating annotations...")
    demo_output = generate_annotations(
        selected_text=demo_highlight,
        context_text=demo_context,
        section_title="Abstract",
    )
    print(json.dumps(demo_output, indent=2))