# rag-embedder/utils/generator.py
# Uploaded by jackenmail — commit e6a70ac ("Upload 4 files", verified)
# ─────────────────────────────────────────────────────────────
# utils/generator.py
# Calls HF Inference API to generate answers from context
# ─────────────────────────────────────────────────────────────
import os
import time
from typing import Optional

import requests
class HFGenerator:
    """Generate answers with a hosted Hugging Face LLM.

    Builds a prompt from a question plus retrieved context chunks and posts
    it to the HF Inference API endpoint for the configured model.
    """

    def __init__(self, model: Optional[str] = None, token: Optional[str] = None):
        """Configure the endpoint URL and request headers.

        Args:
            model: Model repo id. Falls back to the ``LLM_MODEL`` environment
                variable, then to ``mistralai/Mistral-7B-Instruct-v0.1``.
            token: HF access token. Falls back to the ``HF_TOKEN`` environment
                variable; when neither is set, no Authorization header is sent.
        """
        self.model = model or os.getenv("LLM_MODEL", "mistralai/Mistral-7B-Instruct-v0.1")
        self.token = token or os.getenv("HF_TOKEN", "")
        self.api_url = f"https://api-inference.huggingface.co/models/{self.model}"
        # An empty token would yield the malformed header "Bearer "; send no
        # Authorization header at all in that case.
        self.headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}

    def _build_prompt(self, question: str, chunks: list) -> str:
        """Return the instruction prompt for *question* grounded in *chunks*.

        Args:
            question: The user's question.
            chunks: Retrieved context items; each must be a mapping with a
                ``"text"`` key (a missing key raises ``KeyError``).
        """
        context = "\n".join(f"- {c['text']}" for c in chunks)
        return (
            "Answer the question using only the context below.\n"
            'If the answer is not in the context, say "I don\'t have enough information."\n'
            "Context:\n"
            f"{context}\n"
            f"Question: {question}\n"
            "Answer:"
        )

    @staticmethod
    def _extract_answer(result) -> Optional[str]:
        """Return the stripped generated text from a decoded API payload.

        Returns ``None`` when the payload is not a non-empty list whose first
        element is a dict with a string (or absent) ``"generated_text"``.
        """
        if not (isinstance(result, list) and result and isinstance(result[0], dict)):
            return None
        text = result[0].get("generated_text", "")
        return text.strip() if isinstance(text, str) else None

    def generate(self, question: str, chunks: list, retries: int = 3) -> str:
        """Generate an answer from question + retrieved chunks.

        Args:
            question: The user's question.
            chunks: Retrieved context items, each a mapping with a ``"text"`` key.
            retries: Maximum number of requests to attempt. Only transport
                errors and HTTP 503 responses are retried.

        Returns:
            The generated answer, ``"Error generating answer."`` for a
            non-retryable HTTP error or malformed response, or
            ``"Failed to generate answer after retries."`` once all attempts
            are exhausted.
        """
        prompt = self._build_prompt(question, chunks)
        # The payload is identical for every attempt, so build it once.
        payload = {
            "inputs": prompt,
            "parameters": {
                "max_new_tokens": 200,
                "temperature": 0.3,
                "return_full_text": False,
            },
        }

        for attempt in range(retries):
            # Sleeping after the final attempt only delays the failure return.
            has_more_attempts = attempt < retries - 1

            try:
                response = requests.post(
                    self.api_url,
                    headers=self.headers,
                    json=payload,
                    timeout=60,
                )
            except requests.RequestException as e:
                print(f"Request failed: {e}")
                if has_more_attempts:
                    time.sleep(10)
                continue

            if response.status_code == 503:
                print(f"Model loading... retry {attempt + 1}/{retries}")
                if has_more_attempts:
                    time.sleep(20)
                continue

            if response.status_code == 200:
                try:
                    answer = self._extract_answer(response.json())
                except ValueError:
                    # Body was not valid JSON; report it via the error path below.
                    answer = None
                if answer is not None:
                    return answer

            print(f"Error {response.status_code}: {response.text[:100]}")
            return "Error generating answer."

        return "Failed to generate answer after retries."
# ── Quick test ────────────────────────────────────────────────
if __name__ == "__main__":
    # Manual smoke test: ask one question against a single hard-coded
    # context chunk and print whatever the generator returns.
    sample_chunks = [{"text": "Refunds are processed within 5 business days."}]
    generator = HFGenerator()
    reply = generator.generate("What is the refund policy?", sample_chunks)
    print(f"Answer: {reply}")