import gradio as gr
import os
import json
from huggingface_hub import InferenceClient, HfApi
from rag_query import retrieve
from embed_index import main as build_index
# ---------- LOAD SETTINGS ----------
def load_settings():
    """Return the app configuration parsed from ``config/settings.json``."""
    with open("config/settings.json", "r", encoding="utf-8") as handle:
        raw = handle.read()
    return json.loads(raw)
# Configuration is read once at import time and exposed as module constants.
SETTINGS = load_settings()

LLM_MODEL = SETTINGS["llm_model"]                # chat model served via HF Inference
FAISS_INDEX_PATH = SETTINGS["faiss_index_path"]  # vector index file
METADATA_PATH = SETTINGS["metadata_path"]        # per-chunk metadata file

# Files that make up the retrieval index; these get committed back to the Space.
INDEX_FILES = [FAISS_INDEX_PATH, METADATA_PATH]
# ---------- LOAD PROMPT ----------
def load_prompt():
    """Return the RAG prompt template text from ``prompts/rag_prompt.txt``."""
    with open("prompts/rag_prompt.txt", "r", encoding="utf-8") as handle:
        return handle.read()
# ---------- RAG CHAT ----------
def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a retrieval-grounded answer for *message*.

    Retrieves context chunks for the question, fills the prompt template,
    and yields the accumulated response after each streamed model chunk so
    the Gradio ChatInterface can render it incrementally.
    """
    client = InferenceClient(
        model=LLM_MODEL  # uses HF Space token implicitly
    )

    # Best-effort retrieval: a missing/broken index must not kill the chat.
    try:
        hits = retrieve(message)
    except Exception:
        hits = []

    # NOTE(review): the f-string interpolates the *entire* hit dict, not a
    # dedicated text field — looks odd; verify against retrieve()'s schema.
    blocks = [
        f"[{hit['condition']} β {hit['section']}]\n{hit}"
        for hit in hits
    ]
    context = "\n\n".join(blocks) if blocks else "No context available."

    prompt = load_prompt().format(context=context, question=message)
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": prompt},
    ]

    partial = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        stream=True,
    ):
        if chunk.choices and chunk.choices[0].delta.content:
            partial += chunk.choices[0].delta.content
            yield partial
# ---------- BUILD INDEX ----------
def rebuild_index_ui():
    """Rebuild the FAISS index (embed_index.main) and report status to the UI."""
    build_index()
    return "β Index rebuilt successfully."
# ---------- COMMIT TO HF ----------
def commit_index_ui():
    """Upload the index files back to the Space repo on the Hub.

    Returns a human-readable status string for the UI. The environment and
    every index file are validated *before* any upload starts, so a missing
    file can no longer leave a partially updated index in the repo (the old
    code checked existence per file mid-loop, after earlier uploads).
    """
    token = os.environ.get("HF_TOKEN")
    repo_id = os.environ.get("SPACE_ID")
    if not token:
        return "β HF_TOKEN not found in environment."
    if not repo_id:
        return "β SPACE_ID not found in environment."

    # Fail fast: verify all files up front instead of aborting mid-upload.
    for file_path in INDEX_FILES:
        if not os.path.exists(file_path):
            return f"β Missing file: {file_path}"

    api = HfApi(token=token)
    for file_path in INDEX_FILES:
        api.upload_file(
            path_or_fileobj=file_path,
            path_in_repo=file_path,  # mirror local layout inside the Space repo
            repo_id=repo_id,
            repo_type="space",
            commit_message="Update FAISS index"
        )
    return "β¬ Index committed to Hugging Face successfully."
# ---------- UI ----------
# Extra inputs are forwarded to respond() positionally, after (message, history).
_generation_controls = [
    gr.Textbox(
        value="You are a medical education assistant.",
        label="System message",
    ),
    gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
    gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
]

chatbot = gr.ChatInterface(
    respond,
    type="messages",
    additional_inputs=_generation_controls,
)
# Page layout: sidebar with index-maintenance controls, chat in the main area.
with gr.Blocks() as demo:
    with gr.Sidebar():
        build_btn = gr.Button("π¨ Build Index")
        commit_btn = gr.Button("β¬ Commit to HF")
        status_box = gr.Markdown()
        # Both maintenance actions report into the same status area.
        build_btn.click(rebuild_index_ui, outputs=status_box)
        commit_btn.click(commit_index_ui, outputs=status_box)
    chatbot.render()

if __name__ == "__main__":
    demo.launch()