File size: 2,427 Bytes
97bec02
 
 
 
76ce598
97bec02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
from typing import List, Dict, Tuple
import numpy as np
from openai import OpenAI
from faq_store import FAQ_ENTRIES, FAQ_VECS

# Minimum cosine similarity the best FAQ match must reach before we answer
# from the FAQ store; below this we return FALLBACK_MESSAGE instead.
RAG_CONFIDENCE_THRESHOLD = 0.6
# Number of top-scoring FAQ entries injected into the chat context.
MAX_FAQ_MATCHES = 3
# OpenAI model names: one for query embeddings, one for answer generation.
_EMBED_MODEL = "text-embedding-3-small"
_CHAT_MODEL = "gpt-4o-mini"

# System prompt constraining the assistant to FAQ-grounded answers only.
SYSTEM_PROMPT = (
    "You are a helpful assistant for ScanAssured, a medical document OCR and NER app. "
    "Answer only based on the provided FAQ context. "
    "You do NOT have access to any user scan results or personal medical data. "
    "For personal medical advice, always direct users to a qualified healthcare professional. "
    "Keep answers concise and clear."
)
# Returned verbatim when no FAQ entry scores above the confidence threshold.
FALLBACK_MESSAGE = (
    "I'm not certain about that. Please consult a qualified healthcare professional "
    "for personal medical advice, or refer to the app documentation for usage questions."
)

# Module-level OpenAI client; reads the API key from the environment.
# NOTE(review): this is the *sync* client used inside an async handler below —
# calls will block the event loop; consider AsyncOpenAI. Confirm with callers.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Cache of question text -> embedding vector, to avoid re-embedding repeats.
_query_cache: dict[str, np.ndarray] = {}


def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine similarity of vectors *a* and *b* as a Python float.

    Fix: guard against a zero-magnitude vector. The original divided by the
    product of the norms unconditionally, which produces NaN/inf (plus a
    NumPy runtime warning) for a degenerate all-zeros embedding. We return
    0.0 in that case — "no similarity" — which is backward-compatible for
    every nonzero input.
    """
    denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    if denom == 0.0:
        return 0.0
    return float(a.dot(b)) / denom


async def get_answer(question: str, history: List[Dict]) -> Tuple[str, List[Dict]]:
    """Answer *question* via RAG over the FAQ store.

    Embeds the question, ranks FAQ entries by cosine similarity, and — if the
    best match clears RAG_CONFIDENCE_THRESHOLD — asks the chat model to answer
    using the top MAX_FAQ_MATCHES FAQ answers as context.

    Args:
        question: The user's question text.
        history: Prior chat turns, each a dict with "role" and "content" keys.

    Returns:
        (answer, citations): the model's answer (or FALLBACK_MESSAGE) and a
        list of {"id", "source"} dicts for the FAQ entries used (empty when
        falling back).

    NOTE(review): the module-level ``client`` is the synchronous OpenAI
    client, so both API calls below block the event loop despite the async
    signature — consider AsyncOpenAI. Left unchanged here to keep the
    module-level interface intact.
    """
    # Embed the question, reusing the cache when possible (single lookup
    # via .get instead of the `in`-then-index double lookup).
    vec = _query_cache.get(question)
    if vec is None:
        resp = client.embeddings.create(model=_EMBED_MODEL, input=question)
        vec = np.array(resp.data[0].embedding, dtype=np.float32)
        # Fix: bound the cache. It previously grew without limit — one
        # float32 vector leaked per unique question. Dicts preserve
        # insertion order, so popping the first key gives FIFO eviction.
        if len(_query_cache) >= 256:
            _query_cache.pop(next(iter(_query_cache)))
        _query_cache[question] = vec

    # Rank every FAQ vector by similarity to the query, best first.
    scores = [(fid, cosine(vec, fvec)) for fid, fvec in FAQ_VECS]
    scores.sort(key=lambda item: item[1], reverse=True)

    # Low-confidence (or empty store): refuse rather than hallucinate.
    if not scores or scores[0][1] < RAG_CONFIDENCE_THRESHOLD:
        return FALLBACK_MESSAGE, []

    # Materialize the top matches with the fields downstream consumers need.
    matches = []
    for fid, score in scores[:MAX_FAQ_MATCHES]:
        faq = FAQ_ENTRIES[fid]
        matches.append({"id": fid, "answer": faq["answer"], "source": faq["source"], "score": score})

    # Build the chat transcript: system prompt, prior turns, FAQ context as
    # extra system messages, then the current question last.
    messages: List[Dict] = [{"role": "system", "content": SYSTEM_PROMPT}]
    for msg in history:
        messages.append({"role": msg["role"], "content": msg["content"]})
    for faq in matches:
        messages.append({"role": "system", "content": faq["answer"]})
    messages.append({"role": "user", "content": question})

    chat_resp = client.chat.completions.create(
        model=_CHAT_MODEL,
        messages=messages,
        stream=False,
    )
    answer = chat_resp.choices[0].message.content

    citations = [{"id": faq["id"], "source": faq["source"]} for faq in matches]
    return answer, citations