File size: 4,426 Bytes
9f1266b
b80958f
2a7248b
b80958f
cb2cd52
 
0e448fc
cb2cd52
 
 
af27c5f
bacd419
4a886f9
33b48d5
d37cdaf
af27c5f
b80958f
33b48d5
14d0614
b80958f
304ecf4
 
6ffac07
 
 
304ecf4
6ffac07
304ecf4
6ffac07
 
 
 
 
 
 
 
 
18af55b
 
7bf59df
 
 
 
 
0fea124
7bf59df
 
18af55b
0fea124
304ecf4
3a37689
 
 
 
 
7ff4c08
af27c5f
b80958f
 
 
b3928e7
b80958f
 
 
 
 
 
660ad64
bacd419
2434dc7
0f68754
2434dc7
 
 
 
 
 
0f68754
55d6354
 
bacd419
2434dc7
660ad64
 
b80958f
 
 
 
 
7ff4c08
b80958f
03f48fa
7ff4c08
660ad64
7ff4c08
c5193e0
660ad64
b80958f
 
 
03f48fa
3a37689
 
 
 
 
2a7248b
b80958f
 
 
 
0fea124
18af55b
5135d87
d37cdaf
b932d3d
 
 
 
d37cdaf
b932d3d
d37cdaf
 
 
 
522dea6
987cbb7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import streamlit as st
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from transformers import pipeline
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.llms import Ollama
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
from huggingface_hub import InferenceClient
import re


# Hugging Face API token read from the environment (None when unset).
# NOTE(review): HF_TOKEN is never referenced elsewhere in this file —
# presumably the env var itself is consumed by transformers/huggingface_hub
# when downloading gated models; verify before removing.
HF_TOKEN = os.environ.get("HF_TOKEN")

# ----------------------

# Prompt that instructs the model to wrap its final answer in <answer> tags;
# the Streamlit layer below extracts the last tagged block with a regex.
# The one-shot example nudges the model toward the expected output format.
prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=(
        "You are a knowledgeable agricultural research assistant.\n"
        "Use the context below to answer the question concisely.\n"
        "Respond ONLY with the final answer inside <answer> and </answer> tags.\n\n"
        "Example:\n"
        "Question: What is photosynthesis?\n"
        "Answer: <answer>Photosynthesis is the process by which plants convert sunlight into energy using chlorophyll, water, and carbon dioxide.</answer>\n\n"
        "Context:\n{context}\n\n"
        "Question: {question}\n"
        "Answer:"
    ),
)

# Canned prompts displayed in the "Try example questions" expander so users
# can see what kinds of queries the assistant handles.
EXAMPLE_QUESTIONS = [
    "What is agriculture?",
    "Why is crop rotation important?",
    "How does composting help farming?",
]

# Build the retriever over the pre-built local vector store.
@st.cache_resource
def load_retriever():
    """Load the FAISS index from ./vectorstore and return it as a retriever.

    Embeds queries with the all-MiniLM-L6-v2 sentence-transformer; cached by
    Streamlit so the index is deserialized only once per session.

    NOTE(review): allow_dangerous_deserialization=True unpickles the stored
    index — assumes ./vectorstore is always locally generated and trusted;
    confirm it never comes from an untrusted source.
    """
    sentence_embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    vector_db = FAISS.load_local(
        "./vectorstore",
        sentence_embeddings,
        allow_dangerous_deserialization=True,
    )
    return vector_db.as_retriever()

# Load the generation model via a HuggingFace pipeline
@st.cache_resource
def load_llm():
    """Load and cache Meta-Llama-3-8B-Instruct as a LangChain LLM.

    The model is quantized to 8-bit, with fp32 CPU offload so layers that do
    not fit on the GPU spill to system RAM. Returns a HuggingFacePipeline
    wrapping a text-generation pipeline capped at 256 new tokens.

    NOTE(review): this is a gated model — downloading it presumably requires
    the HF_TOKEN environment variable to be set; confirm in deployment.
    """
    # 8-bit quantization keeps the 8B model within a single consumer GPU.
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_enable_fp32_cpu_offload=True,
    )

    model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # device_map="auto" shards the model across available devices;
    # low_cpu_mem_usage avoids materializing a full fp32 copy while loading.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        quantization_config=quantization_config,
        low_cpu_mem_usage=True,
    )
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=256)

    return HuggingFacePipeline(pipeline=pipe)

# Setup RAG Chain
@st.cache_resource
def setup_qa():
    """Build and cache the RetrievalQA chain (FAISS retriever + local LLM).

    Returns a RetrievalQA chain using the custom <answer>-tagged prompt that
    also returns the source documents it retrieved.
    """
    retriever = load_retriever()
    # Stop generation at the closing tag so the model does not ramble past
    # its answer; the tags are stripped later with a regex.
    llm = load_llm().bind(stop=["</answer>"])
    # (Removed: an unused create_stuff_documents_chain result that was built
    # but never wired into the returned chain.)
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type="stuff",
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt},
    )
    return qa_chain


# Streamlit App UI
st.title("🌾 AgriQuery: RAG-Based Research Assistant")

# Show example questions users can try.
with st.expander("💡 Try example questions"):
    for q in EXAMPLE_QUESTIONS:
        st.markdown(f"- {q}")

query = st.text_input("Ask a question related to agriculture:")

if query:
    qa = setup_qa()
    with st.spinner("Thinking..."):
        result = qa.invoke({"query": query})
        raw_answer = result["result"]

    # The prompt asks the model to wrap its reply in <answer> tags; take the
    # LAST tagged block, since the model may echo the few-shot example first.
    matches = re.findall(r"<answer>(.*?)</answer>", raw_answer, re.DOTALL)
    if matches:
        clean_answer = matches[-1].strip()
    else:
        # Fallback: model ignored the tag instruction — show raw output.
        clean_answer = raw_answer.strip()

    st.success(clean_answer)
    st.success(f"Source Document(s): {result['source_documents']}")