Spaces:
Running
Running
from __future__ import annotations

import logging
import os
from typing import Iterable

from groq import Groq

from .retriever import RetrievedPassage
| logger = logging.getLogger(__name__) | |
class TokenUsage:
    """Token accounting for a single LLM call.

    Attributes:
        prompt_tokens: Tokens consumed by the prompt/input.
        completion_tokens: Tokens produced in the completion/output.
        total_tokens: Sum of the two, computed once at construction time
            (it is NOT recomputed if the other attributes are mutated later).
    """

    def __init__(self, prompt_tokens: int = 0, completion_tokens: int = 0) -> None:
        self.prompt_tokens = prompt_tokens
        self.completion_tokens = completion_tokens
        self.total_tokens = prompt_tokens + completion_tokens

    def __repr__(self) -> str:
        # Debug-friendly representation; added for observability, callers
        # relying only on the attributes are unaffected.
        return (
            f"{type(self).__name__}(prompt_tokens={self.prompt_tokens}, "
            f"completion_tokens={self.completion_tokens})"
        )
# Switch to use Groq API instead of local Models
class BiomedicalAnswerGenerator:
    """Generates answers using a biomedical LLM via Groq API."""

    def __init__(self, model_name: str = "llama-3.1-8b-instant") -> None:
        """Set up the Groq client for *model_name*.

        The API key is read from the GROQ_API_KEY environment variable
        (a missing variable surfaces later, when the first request is made).
        """
        self.model_name = model_name
        # Chat models are decoder-only; kept so _format_prompt's seq2seq
        # flag has a value to receive.
        self._is_seq2seq = False
        self.client = Groq(api_key=os.getenv("GROQ_API_KEY"))
        logger.info("Loaded Groq API Generator with model: %s", self.model_name)
        # Usage of the most recent successful call; see generate_direct.
        self.last_usage = TokenUsage()

    def generate(self, question: str, passages: Iterable[RetrievedPassage]) -> str:
        """Answer *question* using the retrieved *passages* as evidence."""
        evidence = list(passages)
        prompt = _format_prompt(question, evidence, seq2seq=self._is_seq2seq)
        return self.generate_direct(prompt, max_tokens=2048)

    def generate_direct(self, prompt: str, max_tokens: int = 512, is_json: bool = False) -> str:
        """Helper method for QueryProcessor and ClaimDecomposer to generate text."""
        request = {
            "messages": [
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            "model": self.model_name,
            # Deterministic decoding for reproducible evaluation runs.
            "temperature": 0.0,
            "max_tokens": max_tokens,
        }
        if is_json:
            # Ask Groq to constrain the output to a JSON object.
            request["response_format"] = {"type": "json_object"}
        try:
            response = self.client.chat.completions.create(**request)
            usage = getattr(response, "usage", None)
            if usage:
                self.last_usage = TokenUsage(
                    prompt_tokens=usage.prompt_tokens or 0,
                    completion_tokens=usage.completion_tokens or 0,
                )
            return response.choices[0].message.content.strip()
        except Exception as e:
            # Best-effort boundary: any API/parse failure is logged and
            # collapsed into a sentinel answer rather than raised.
            logger.error("Error generating with Groq API: %s", e)
            return "No answer generated."
def _format_prompt(question: str, passages: list[RetrievedPassage], *, seq2seq: bool = False) -> str:
    """Assemble the LLM prompt: persona instructions, ranked evidence, question.

    Each passage is labelled ``[E<rank>]`` so the model can cite it. The
    *seq2seq* flag is accepted for interface compatibility but currently
    does not alter the prompt layout.
    """
    evidence_block = "\n\n".join(
        f"[E{passage.rank}] {passage.text}" for passage in passages
    )
    instructions = (
        "You are a highly confident and authoritative medical consultant specializing in diabetes. You MUST answer the user's question directly and assertively.\n"
        "NEVER refuse to answer. NEVER use standard AI disclaimers like 'I cannot provide medical advice', 'consult a doctor', or 'I am an AI'.\n"
        "Your task is to provide exactly 5 to 7 sentences answering the medical question, incorporating concepts from the provided abstracts if possible, but you MUST provide a direct answer to the user's query no matter what.\n\n"
    )
    return (
        instructions
        + f"Context:\n{evidence_block}\n\n"
        + f"Question:\n{question}\n\n"
        + "Answer:\n"
    )