# BioRAG — src/bio_rag/generator.py
# (aseelflihan) feat: add token usage tracking and display, update sample
# questions for demo scenarios — commit 5cb4c11
from __future__ import annotations
import logging
from typing import Iterable
import os
from groq import Groq
from .retriever import RetrievedPassage
logger = logging.getLogger(__name__)
class TokenUsage:
    """Token accounting for a single LLM API call."""

    def __init__(self, prompt_tokens: int = 0, completion_tokens: int = 0) -> None:
        self.prompt_tokens = prompt_tokens
        self.completion_tokens = completion_tokens
        # Derived once here so callers can read the combined count directly.
        self.total_tokens = self.prompt_tokens + self.completion_tokens
# Answers are generated via the Groq API rather than locally loaded models.
class BiomedicalAnswerGenerator:
    """Generates answers for biomedical questions via the Groq chat API.

    The token usage of the most recent API call is kept in ``last_usage``
    so the UI layer can display prompt/completion token counts.
    """

    def __init__(self, model_name: str = "llama-3.1-8b-instant") -> None:
        self.model_name = model_name
        # Kept for compatibility with _format_prompt's seq2seq flag; the
        # Groq chat models used here are decoder-only.
        self._is_seq2seq = False
        # NOTE(review): a missing/invalid GROQ_API_KEY does not fail here —
        # it surfaces as a request error inside generate_direct().
        self.client = Groq(api_key=os.getenv("GROQ_API_KEY"))
        logger.info("Loaded Groq API Generator with model: %s", self.model_name)
        self.last_usage = TokenUsage()

    def generate(self, question: str, passages: Iterable[RetrievedPassage]) -> str:
        """Answer ``question`` grounded in the retrieved ``passages``."""
        passage_list = list(passages)
        prompt = _format_prompt(question, passage_list, seq2seq=self._is_seq2seq)
        return self.generate_direct(prompt, max_tokens=2048)

    def generate_direct(self, prompt: str, max_tokens: int = 512, is_json: bool = False) -> str:
        """Send ``prompt`` to the Groq API and return the reply text.

        Helper method for QueryProcessor and ClaimDecomposer. Updates
        ``self.last_usage`` with the token counts reported by the API.
        Never raises: any API failure is logged and a fallback string is
        returned instead.
        """
        try:
            kwargs = {
                "messages": [
                    {
                        "role": "user",
                        "content": prompt,
                    }
                ],
                "model": self.model_name,
                "temperature": 0.0,  # deterministic output for reproducibility
                "max_tokens": max_tokens,
            }
            if is_json:
                # Constrain the model reply to a valid JSON object.
                kwargs["response_format"] = {"type": "json_object"}
            response = self.client.chat.completions.create(**kwargs)
            usage = getattr(response, "usage", None)
            if usage:
                self.last_usage = TokenUsage(
                    prompt_tokens=usage.prompt_tokens or 0,
                    completion_tokens=usage.completion_tokens or 0,
                )
            content = response.choices[0].message.content
            # The API may return a null content field (e.g. on refusals);
            # previously that raised AttributeError inside this try block.
            # Treat it explicitly as a failed generation.
            if content is None:
                return "No answer generated."
            return content.strip()
        except Exception:
            # Best-effort by design: never propagate API errors to callers.
            # logger.exception preserves the traceback (logger.error did not).
            logger.exception("Error generating with Groq API")
            return "No answer generated."
def _format_prompt(question: str, passages: list[RetrievedPassage], *, seq2seq: bool = False) -> str:
evidence_block = "\n\n".join(
[f"[E{p.rank}] {p.text}" for p in passages]
)
return (
"You are a highly confident and authoritative medical consultant specializing in diabetes. You MUST answer the user's question directly and assertively.\n"
"NEVER refuse to answer. NEVER use standard AI disclaimers like 'I cannot provide medical advice', 'consult a doctor', or 'I am an AI'.\n"
"Your task is to provide exactly 5 to 7 sentences answering the medical question, incorporating concepts from the provided abstracts if possible, but you MUST provide a direct answer to the user's query no matter what.\n\n"
f"Context:\n{evidence_block}\n\n"
f"Question:\n{question}\n\n"
"Answer:\n"
)