File size: 3,828 Bytes
941240b
 
31ce18a
772864e
559a9a0
71b6f6e
 
0c9a8f6
c8b76d0
772864e
c8b76d0
772864e
 
0a6a1d9
 
 
 
 
 
c8b76d0
71b6f6e
c426283
71b6f6e
a9e1267
31ce18a
 
a9e1267
71b6f6e
0a510f7
71b6f6e
abd8f5a
a9e1267
 
f277648
9087b24
 
 
f277648
a9e1267
c426283
772864e
3f3e3b8
772864e
86d9f08
be9b0cf
86d9f08
be9b0cf
772864e
 
 
86d9f08
a2d4cd4
772864e
86d9f08
 
 
 
772864e
 
 
 
be9b0cf
 
86d9f08
772864e
abd8f5a
52fa7cc
71b6f6e
 
 
772864e
967d8a3
 
772864e
 
fb273e0
ec49bb3
fb273e0
 
b3d2978
fb273e0
 
967d8a3
fb273e0
abd8f5a
c426283
772864e
71b6f6e
772864e
be17c77
e230469
71b6f6e
 
0c9a8f6
 
 
772864e
 
 
 
 
71b6f6e
 
 
 
 
 
 
9b41328
71b6f6e
 
559a9a0
a9e1267
71b6f6e
 
f48a0cf
81f51c9
f48a0cf
 
3bcc980
f48a0cf
 
3bcc980
 
81f51c9
b0c0a48
 
 
772864e
b0c0a48
 
 
772864e
b0c0a48
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
import re
import os
import traceback
from huggingface_hub import login


# Authenticate against the Hugging Face Hub when a token is provided
# via the environment; otherwise continue anonymously (gated models
# such as Phi-3.5 may then fail to download).
token = os.getenv("HF_TOKEN")
print("πŸ”‘ HF_TOKEN available?", token is not None)
if not token:
    print("❌ No HF_TOKEN found in environment")
else:
    login(token=token)


def build_qa():
    """Build and return the RAG question-answering chain.

    Pipeline: MiniLM embeddings -> persisted Chroma retriever (top-3
    chunks) -> prompt template -> Phi-3.5-mini-instruct generation ->
    first-sentence extraction.

    Returns:
        A LangChain runnable; ``invoke(question: str)`` yields a short
        one-sentence answer string.
    """
    print("πŸš€ Starting QA pipeline...")

    # 1. Embeddings — must match the model used when the "db" collection
    # was originally built, or retrieval quality collapses.
    print("πŸ”Ή Loading embeddings...")
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    # 2. Load the persisted vector DB.
    print("πŸ”Ή Loading Chroma DB...")
    vectorstore = Chroma(
        persist_directory="db",
        collection_name="rag-docs",
        embedding_function=embeddings,
    )
    # NOTE(review): _collection is a private Chroma attribute — fine for a
    # startup sanity log, but may break across langchain/chromadb upgrades.
    print("πŸ“‚ Docs in DB:", vectorstore._collection.count())

    # 3. Load the LLM (Phi-3.5-mini-instruct).
    print("πŸ”Ή Loading LLM...")
    model_id = "microsoft/Phi-3.5-mini-instruct"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",        # place layers on available devices
        torch_dtype="auto",       # use the checkpoint's native dtype
        trust_remote_code=True,
    )
    model.config.use_cache = False

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=80,        # keep answers short
        do_sample=False,          # greedy decoding -> deterministic output
        # FIX: dropped temperature=0.2 — it is ignored when do_sample=False,
        # and recent transformers releases emit a warning for the combination.
        repetition_penalty=1.2,
        eos_token_id=tokenizer.eos_token_id,
        return_full_text=False,   # return only the generated continuation
    )

    llm = HuggingFacePipeline(pipeline=pipe)

    # 4. Retriever: top-3 most similar chunks per query.
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

    # 5. Prompt: forces a single factual sentence and an explicit
    # "I don't know." fallback.
    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""Answer the question using the context below.
        Respond in ONE short factual sentence only.
        If you don't know, say "I don't know."
        
        Context:
        {context}
        
        Question:
        {question}
        
        Answer:""",
    )

    # 6. Helpers
    def format_docs(docs):
        """Join the non-empty page contents of retrieved docs with newlines."""
        texts = [doc.page_content.strip() for doc in docs if doc.page_content]
        return "\n".join(texts)

    def hf_to_str(x):
        """Normalize LLM output to a single sentence.

        Accepts either the raw HF pipeline list-of-dicts or a plain string
        (HuggingFacePipeline usually yields the latter).
        """
        # FIX: guard against an empty list and non-dict elements; the
        # original indexed x[0] after only checking isinstance(x, list),
        # which raised IndexError on [] and TypeError on non-dict items.
        if isinstance(x, list) and x and isinstance(x[0], dict) and "generated_text" in x[0]:
            text = x[0]["generated_text"]
        else:
            text = str(x)
        text = re.sub(r"\s+", " ", text).strip()
        # Keep only the first sentence of the generation.
        return re.split(r"(?<=[.!?])\s+", text)[0]

    # 7. Chain: fan the question into {context, question}, fill the prompt,
    # generate, then post-process down to one sentence.
    rag_chain = (
        {
            "context": retriever | format_docs,
            "question": RunnablePassthrough(),
        }
        | prompt
        | llm
        | hf_to_str
        | StrOutputParser()
    )

    print("βœ… QA pipeline ready.")
    return rag_chain


# Build once
try:
    qa_pipeline = build_qa()
    print("βœ… qa_pipeline built successfully:", type(qa_pipeline))
except Exception as e:
    qa_pipeline = None
    print("❌ Failed to build QA pipeline")
    print("Error message:", str(e))
    traceback.print_exc()


def get_answer(query: str) -> str:
    """Answer *query* via the module-level RAG chain.

    Returns the chain's answer string, or a human-readable status/error
    message when the chain is unavailable or the invocation fails.
    """
    if qa_pipeline is None:
        return "⚠️ QA pipeline not initialized."
    try:
        return qa_pipeline.invoke(query)
    except Exception as e:
        return f"❌ QA run failed: {e}"