File size: 3,154 Bytes
e6a70ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# ─────────────────────────────────────────────────────────────
# utils/generator.py
# Calls HF Inference API to generate answers from context
# ─────────────────────────────────────────────────────────────

import os
import time
from typing import Optional

import requests


class HFGenerator:
    """
    Generates answers using a free HF LLM
    given a question and retrieved context chunks.

    Attributes:
        model:   Model id. Taken from the ``model`` argument, else the
                 ``LLM_MODEL`` environment variable, else
                 ``mistralai/Mistral-7B-Instruct-v0.1``.
        token:   API token. Taken from the ``token`` argument, else the
                 ``HF_TOKEN`` environment variable, else an empty string.
        api_url: Endpoint URL built from ``model``.
        headers: HTTP headers sent with every request.
    """

    # Seconds to wait before retrying after a 503 response.
    MODEL_LOADING_WAIT = 20
    # Seconds to wait before retrying after the request itself failed.
    REQUEST_ERROR_WAIT = 10

    def __init__(self, model: Optional[str] = None, token: Optional[str] = None):
        """Resolve model and token from arguments or environment variables."""
        self.model   = model or os.getenv("LLM_MODEL", "mistralai/Mistral-7B-Instruct-v0.1")
        self.token   = token or os.getenv("HF_TOKEN", "")
        self.api_url = f"https://api-inference.huggingface.co/models/{self.model}"
        # Only attach credentials when a token is configured: a bare
        # "Bearer " value with no token is not a usable credential.
        self.headers = {"Authorization": f"Bearer {self.token}"} if self.token else {}

    def _build_prompt(self, question: str, chunks: list) -> str:
        """
        Build the instruction prompt sent to the model.

        Each chunk must be a mapping with a ``"text"`` key; the texts are
        listed as bullet lines under ``Context:``.

        Raises:
            KeyError: if a chunk has no ``"text"`` key.
        """
        context = "\n".join(f"- {c['text']}" for c in chunks)
        return f"""Answer the question using only the context below.
If the answer is not in the context, say "I don't have enough information."

Context:
{context}

Question: {question}
Answer:"""

    @staticmethod
    def _extract_answer(result) -> Optional[str]:
        """
        Pull the generated text out of a decoded JSON response body.

        Returns the stripped ``generated_text`` of the first list item
        (an empty string when that key is absent), or ``None`` when the
        payload does not have the expected list-of-dicts shape.
        """
        if not isinstance(result, list) or not result:
            return None
        first = result[0]
        if not isinstance(first, dict):
            return None
        text = first.get("generated_text", "")
        return text.strip() if isinstance(text, str) else None

    def generate(self, question: str, chunks: list, retries: int = 3) -> str:
        """
        Generate an answer from question + retrieved chunks.

        Args:
            question: The user question inserted into the prompt.
            chunks:   Retrieved context chunks (mappings with a ``"text"`` key).
            retries:  Maximum number of request attempts.

        Returns:
            The generated text on success. Otherwise one of two fixed
            messages: ``"Error generating answer."`` for a non-retryable
            response, or ``"Failed to generate answer after retries."``
            when every attempt failed or returned 503.
        """
        prompt = self._build_prompt(question, chunks)
        # The request body is identical for every attempt, so build it once.
        payload = {
            "inputs"    : prompt,
            "parameters": {
                "max_new_tokens"  : 200,
                "temperature"     : 0.3,
                "return_full_text": False,
            },
        }

        for attempt in range(retries):
            # Sleeping after the final attempt would only delay the failure
            # message, so waits are skipped once no retry is left.
            is_last_attempt = attempt == retries - 1

            try:
                response = requests.post(
                    self.api_url,
                    headers=self.headers,
                    json=payload,
                    timeout=60,
                )
            except requests.RequestException as e:
                # Transport-level failure (connection, timeout, ...): retry.
                print(f"Request failed: {e}")
                if not is_last_attempt:
                    time.sleep(self.REQUEST_ERROR_WAIT)
                continue

            if response.status_code == 503:
                print(f"Model loading... retry {attempt + 1}/{retries}")
                if not is_last_attempt:
                    time.sleep(self.MODEL_LOADING_WAIT)
                continue

            if response.status_code == 200:
                try:
                    result = response.json()
                except ValueError:
                    # Body was not valid JSON; report it as an error below.
                    result = None
                answer = self._extract_answer(result)
                if answer is not None:
                    return answer

            # Any other status, or a 200 with an unexpected body, is treated
            # as non-retryable.
            print(f"Error {response.status_code}: {response.text[:100]}")
            return "Error generating answer."

        return "Failed to generate answer after retries."


# ── Quick test ────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke test: ask one question against a single hard-coded context chunk
    # and print whatever string generate() returns.
    sample_chunks = [{"text": "Refunds are processed within 5 business days."}]
    generator = HFGenerator()
    reply = generator.generate("What is the refund policy?", sample_chunks)
    print(f"Answer: {reply}")