File size: 3,581 Bytes
17adc75
a8b0c6a
694a672
a8b0c6a
694a672
a0f7d2a
a8b0c6a
a0f7d2a
694a672
 
 
 
 
 
 
 
 
 
 
a0f7d2a
a8b0c6a
 
 
 
 
 
694a672
a0f7d2a
 
 
17adc75
 
a8b0c6a
17adc75
 
a0f7d2a
17adc75
 
 
 
 
a0f7d2a
3193e82
a0f7d2a
 
694a672
 
 
 
a0f7d2a
 
 
 
 
 
17adc75
694a672
17adc75
a0f7d2a
 
 
 
17adc75
a0f7d2a
 
 
 
17adc75
 
a0f7d2a
17adc75
 
 
 
a0f7d2a
17adc75
a0f7d2a
 
 
17adc75
 
a8b0c6a
 
 
 
 
 
 
3193e82
 
694a672
3193e82
 
 
a8b0c6a
3193e82
 
 
a8b0c6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17adc75
 
 
 
a0f7d2a
 
 
17adc75
a0f7d2a
 
 
17adc75
 
 
 
 
a8b0c6a
 
 
 
 
 
 
 
 
 
 
 
 
 
17adc75
 
a8b0c6a
17adc75
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import gradio as gr
import os
import json
from huggingface_hub import InferenceClient, HfApi

from rag_query import retrieve
from embed_index import main as build_index


# ---------- LOAD SETTINGS ----------
def load_settings():
    """Read and parse the application settings from config/settings.json."""
    settings_path = "config/settings.json"
    with open(settings_path, encoding="utf-8") as fh:
        return json.load(fh)

# Application-wide settings, loaded once at import time.
SETTINGS = load_settings()

LLM_MODEL = SETTINGS["llm_model"]  # model id handed to InferenceClient in respond()
FAISS_INDEX_PATH = SETTINGS["faiss_index_path"]  # path to the FAISS index file
METADATA_PATH = SETTINGS["metadata_path"]  # path to the index's metadata file

# Files uploaded back to the Space repo by commit_index_ui().
INDEX_FILES = [
    FAISS_INDEX_PATH,
    METADATA_PATH
]


# ---------- LOAD PROMPT ----------
def load_prompt():
    """Return the raw RAG prompt template from prompts/rag_prompt.txt."""
    template_path = "prompts/rag_prompt.txt"
    with open(template_path, encoding="utf-8") as fh:
        return fh.read()


# ---------- RAG CHAT ----------
def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream an LLM answer for *message*, grounded in retrieved context.

    Yields the accumulated response text after each streamed chunk so the
    Gradio chat widget can render the answer incrementally. *history* is
    supplied by gr.ChatInterface but is not used here — the prompt is
    rebuilt from retrieval on every turn.
    """
    # Uses the Space's HF token implicitly; no explicit credential passed.
    client = InferenceClient(model=LLM_MODEL)

    # Best-effort retrieval: an empty context beats a hard failure in the UI.
    try:
        hits = retrieve(message)
    except Exception:
        hits = []

    # NOTE(review): the whole hit dict is interpolated after the header, so
    # condition/section appear twice — possibly a text field was intended;
    # preserved as-is. TODO confirm against rag_query.retrieve's schema.
    blocks = [
        f"[{hit['condition']} – {hit['section']}]\n{hit}"
        for hit in hits
    ]
    context = "\n\n".join(blocks) if blocks else "No context available."

    prompt = load_prompt().format(context=context, question=message)

    chat_messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": prompt},
    ]

    answer = ""
    stream = client.chat_completion(
        chat_messages,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        stream=True,
    )
    for chunk in stream:
        delta = chunk.choices[0].delta.content if chunk.choices else None
        if delta:
            answer += delta
            yield answer


# ---------- BUILD INDEX ----------
def rebuild_index_ui():
    """Rebuild the FAISS index and return a status string for the UI."""
    build_index()  # delegates to embed_index.main
    return "βœ… Index rebuilt successfully."


# ---------- COMMIT TO HF ----------
def commit_index_ui():
    """Upload the FAISS index files to this Space's repo on Hugging Face.

    Returns a human-readable status string for the UI. Requires the
    HF_TOKEN and SPACE_ID environment variables to be set.
    """
    token = os.environ.get("HF_TOKEN")
    repo_id = os.environ.get("SPACE_ID")

    if not token:
        return "❌ HF_TOKEN not found in environment."
    if not repo_id:
        return "❌ SPACE_ID not found in environment."

    # Validate every file up front: the original checked inside the upload
    # loop, so a missing second file aborted AFTER the first file had
    # already been uploaded, leaving the repo half-updated.
    for file_path in INDEX_FILES:
        if not os.path.exists(file_path):
            return f"❌ Missing file: {file_path}"

    api = HfApi(token=token)

    for file_path in INDEX_FILES:
        api.upload_file(
            path_or_fileobj=file_path,
            path_in_repo=file_path,
            repo_id=repo_id,
            repo_type="space",
            commit_message="Update FAISS index"
        )

    return "⬆ Index committed to Hugging Face successfully."


# ---------- UI ----------
# Chat tab: wires the streaming `respond` generator into a Gradio chat
# widget, exposing the system message and sampling knobs as extra inputs
# (passed to respond() after message/history, in this order).
chatbot = gr.ChatInterface(
    respond,
    type="messages",  # history delivered as role/content dicts
    additional_inputs=[
        gr.Textbox(
            value="You are a medical education assistant.",
            label="System message"
        ),
        gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
    ],
)

# Top-level layout: admin actions (index build / commit) live in a sidebar,
# with the chat interface rendered as the main content.
with gr.Blocks() as demo:
    with gr.Sidebar():
        build_btn = gr.Button("πŸ”¨ Build Index")
        commit_btn = gr.Button("⬆ Commit to HF")
        status_box = gr.Markdown()  # shared status line for both buttons

        # Both handlers return a status string displayed in status_box.
        build_btn.click(
            rebuild_index_ui,
            outputs=status_box
        )

        commit_btn.click(
            commit_index_ui,
            outputs=status_box
        )

    chatbot.render()


# Script entry point: serve the Gradio app.
if __name__ == "__main__":
    demo.launch()