File size: 2,320 Bytes
bebb279
 
 
 
 
e4544a7
bebb279
 
 
 
 
72a12a2
bebb279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4544a7
 
 
 
bebb279
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# app/tools/llm_answer.py
from __future__ import annotations
from typing import Optional, Dict, Any, List
import requests, json

HF_CHAT_URL = "https://router.huggingface.co/v1/chat/completions"

# System prompt sent with every chat-completion request; keeps the model
# grounded in the supplied table rows and the project's region-code scheme.
SYSTEM_PROMPT = """You are a BI copilot.
- NEVER invent numbers; only summarize from provided table rows.
- Use 3-letter region codes (NCR, BLR, MUM, HYD, CHN, PUN).
- Write one concise paragraph and up to 2 brief bullets with clear takeaways.
- If no rows are provided, say so explicitly; do not invent anything.
"""

class AnswerLLM:
    """Generate a natural-language summary of SQL results via an HF chat model.

    Posts an OpenAI-style chat-completions request to the Hugging Face router.
    When no token/model is configured (``enabled`` is False), ``generate``
    returns a deterministic one-line summary instead of calling the network,
    so callers always receive a string.
    """

    # Maximum number of result rows embedded in the prompt; larger result
    # sets are truncated to keep the request payload small.
    MAX_PREVIEW_ROWS = 50

    def __init__(self, model_id: str, token: Optional[str], temperature: float = 0.2, max_tokens: int = 300, timeout: int = 60):
        """Store request settings; ``enabled`` requires both a token and a model id.

        Args:
            model_id: HF model identifier used in the request payload.
            token: Bearer token for the HF router; ``None`` disables remote calls.
            temperature: Sampling temperature forwarded to the model.
            max_tokens: Completion-length cap forwarded to the model.
            timeout: Per-request timeout in seconds for ``requests.post``.
        """
        self.model_id = model_id
        self.token = token
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout
        self.enabled = bool(token and model_id)

    def set_token(self, token: Optional[str]) -> None:
        """Replace the API token at runtime and recompute ``enabled``."""
        self.token = token
        self.enabled = bool(token and self.model_id)

    def generate(self, question: str, sql: str, columns: List[str], rows: List[list]) -> str:
        """Return an answer grounded in the given result table.

        Args:
            question: The user's original natural-language question.
            sql: The SQL statement that produced ``rows`` (shown to the model).
            columns: Column names of the result set.
            rows: Result rows; truncated to ``MAX_PREVIEW_ROWS`` in the prompt.

        Returns:
            The model's answer text, or a deterministic summary when disabled.

        Raises:
            requests.HTTPError: On a non-2xx response from the router.
        """
        if not self.enabled:
            # Deterministic fallback: show up to 4 column names, and only
            # append "..." when columns were actually truncated (the original
            # unconditionally appended it, which was misleading).
            shown = columns[:4]
            suffix = "..." if len(columns) > 4 else ""
            return f"Rows: {len(rows)} | Columns: {shown}{suffix}"
        # Slicing is a no-op when rows is already small enough.
        preview = rows[: self.MAX_PREVIEW_ROWS]
        table_json = json.dumps({"columns": columns, "rows": preview}, ensure_ascii=False)
        payload = {
            "model": self.model_id,
            "stream": False,
            "messages": [
                {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
                {"role": "user", "content": [
                    {"type": "text", "text": f"Question: {question}\nSQL used:\n{sql}\n\nHere are the rows (JSON):\n{table_json}\n\nAnswer:"}
                ]},
            ],
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        headers = {
            "Authorization": f"Bearer {self.token}",
            "Accept": "application/json",
            # Disable compressed responses to simplify body handling.
            "Accept-Encoding": "identity",
        }
        r = requests.post(HF_CHAT_URL, headers=headers, json=payload, timeout=self.timeout)
        r.raise_for_status()
        # OpenAI-compatible response shape: first choice's message content.
        return r.json()["choices"][0]["message"]["content"]