File size: 3,661 Bytes
aa8691d
f7ef156
aa8691d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7ef156
296d0b2
 
8bd5b29
296d0b2
aa8691d
f7ef156
 
 
 
 
 
 
 
 
6abcc26
aa8691d
 
6abcc26
aa8691d
f7ef156
 
aa8691d
 
 
 
 
 
 
f7ef156
aa8691d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
import os
from pathlib import Path
from sambanova import SambaNova
from langchain_huggingface import HuggingFaceEmbeddings

from chatbot import (
    load_config,
    build_rag_corpus,
    retrieve_relevant_chunks,
    build_prompt,
    ask_model,
    format_answer,
)

# config.yaml sitting beside this script; init_resources() only reads it as a
# local-development fallback when the required environment variables are missing.
CONFIG_PATH = Path(__file__).with_name("config.yaml")
# Cache of shared resources; empty until init_resources() populates it.
RESOURCE_STATE = {}


def init_resources():
    """Build (once) and return the shared resources used to answer questions.

    Settings are resolved in this order, first non-empty value wins:

    1. Environment variables ``SAMBANOVA_API_KEY``, ``WEBSITE``,
       ``EMBED_MODEL`` and ``SYSTEM_PROMPT`` (e.g. Hugging Face Space secrets).
    2. ``CONFIG_PATH`` (config.yaml) keys ``sambanova_api_key``, ``website``,
       ``embedding_model`` and ``system_prompt`` -- only consulted when the
       API key or the website is missing from the environment.
    3. Built-in defaults for the embedding model and the system prompt.

    The result is cached in the module-level ``RESOURCE_STATE`` dict. Once it
    is populated, later calls return it immediately without rebuilding the
    embedding model, the RAG corpus or the SambaNova client.

    Returns:
        dict: ``RESOURCE_STATE`` with keys ``config``, ``website``,
        ``system_prompt``, ``embed_model``, ``corpus`` and ``client``.

    Raises:
        ValueError: if the API key or website is missing from the environment
            and config.yaml does not exist, or if neither source provides them.
    """
    if RESOURCE_STATE:
        return RESOURCE_STATE

    default_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
    default_system_prompt = "You are a helpful assistant."

    # Read env vars WITHOUT defaults: an unset variable must stay falsy so the
    # config.yaml fallback below can still fill it in. Defaults apply last.
    llm_api_key = os.getenv("SAMBANOVA_API_KEY")
    website = os.getenv("WEBSITE")
    embedding_model_name = os.getenv("EMBED_MODEL")
    system_prompt = os.getenv("SYSTEM_PROMPT")

    # Fallback to config.yaml (local development) if required env vars are missing.
    if not llm_api_key or not website:
        if not CONFIG_PATH.exists():
            raise ValueError("Please set SAMBANOVA_API_KEY and WEBSITE as secrets in your Hugging Face Space settings, or provide config.yaml for local development")
        # load_config is assumed to return a dict-like mapping (used via .get()).
        config = load_config(CONFIG_PATH)
        llm_api_key = llm_api_key or config.get("sambanova_api_key")
        website = website or config.get("website")
        embedding_model_name = embedding_model_name or config.get("embedding_model")
        system_prompt = system_prompt or config.get("system_prompt")

    if not llm_api_key or not website:
        raise ValueError("SAMBANOVA_API_KEY and WEBSITE are required. Set them as secrets in Hugging Face Space settings.")

    # Built-in defaults only when neither the environment nor config.yaml set a value.
    embedding_model_name = embedding_model_name or default_embedding_model
    system_prompt = system_prompt or default_system_prompt

    embed_model = HuggingFaceEmbeddings(model_name=embedding_model_name)
    corpus = build_rag_corpus({"embedding_model": embedding_model_name}, embed_model, website)
    client = SambaNova(
        api_key=llm_api_key,
        base_url="https://api.sambanova.ai/v1",
        timeout=30,
    )

    # Populate the cache only after everything was built successfully, so a
    # failure above leaves RESOURCE_STATE empty and the next call retries.
    RESOURCE_STATE.update(
        config={"embedding_model": embedding_model_name},
        website=website,
        system_prompt=system_prompt,
        embed_model=embed_model,
        corpus=corpus,
        client=client,
    )
    return RESOURCE_STATE


def answer_question(question: str, top_k: int = 4):
    """Answer *question* using retrieval-augmented generation.

    Args:
        question: The user's question from the UI textbox.
        top_k: Number of corpus chunks to retrieve and pass to the model.

    Returns:
        A ``(answer, citations)`` tuple of strings for the two output
        textboxes. ``citations`` holds one entry per retrieved chunk, showing
        at most the first 300 characters of its text ("..." marks a cut).
        A blank question returns a prompt message and empty citations without
        initialising resources or calling the model.
    """
    # Don't spend an embedding lookup and an LLM call on an empty submission.
    if not question or not question.strip():
        return "Please enter a question.", ""

    resources = init_resources()
    selected = retrieve_relevant_chunks(
        resources["corpus"],
        question,
        resources["embed_model"],
        top_k=top_k,
    )
    prompt = build_prompt(resources["system_prompt"], question, selected)
    raw_answer = ask_model(prompt, resources["client"])
    response = format_answer(raw_answer, selected)

    preview_chars = 300
    previews = []
    for index, chunk in enumerate(selected, start=1):
        text = chunk.text
        # Only signal truncation when the chunk was actually longer than the preview.
        suffix = "..." if len(text) > preview_chars else ""
        previews.append(f"Chunk {index}: {text[:preview_chars]}{suffix}")
    citations = "\n\n".join(previews)
    return response, citations


def main():
    """Initialise shared resources and serve the Gradio question-answering UI.

    Resources are built before the UI is constructed so the header can show
    the configured website and the number of corpus chunks.
    """
    state = init_resources()
    site = state["website"]
    chunk_count = len(state["corpus"])
    header = f"**Website:** {site}  \n**Chunks:** {chunk_count}"

    with gr.Blocks(title="RAG Chatbot") as app:
        gr.Markdown("# 🤖 RAG-Powered Chatbot")
        gr.Markdown(header)

        with gr.Row():
            # Left column: question input, ask button and the model's answer.
            with gr.Column(scale=3):
                question_box = gr.Textbox(
                    label="Ask a question",
                    placeholder="What services do you provide?",
                    lines=2,
                )
                ask_button = gr.Button("Ask")
                answer_box = gr.Textbox(label="Answer", lines=12, interactive=False)

            # Right column: the citations string returned by answer_question.
            with gr.Column(scale=1):
                citations_box = gr.Textbox(label="Citations", lines=20, interactive=False)

        ask_button.click(
            answer_question,
            inputs=[question_box],
            outputs=[answer_box, citations_box],
        )

    app.launch()


if __name__ == "__main__":
    main()