File size: 4,051 Bytes
33866cf
 
 
 
 
 
 
 
b55d380
372064c
b5b64fb
2a6e98a
c97a441
5230b46
b4b1755
5230b46
b55d380
4a65690
33866cf
 
 
 
 
 
 
 
4a65690
33866cf
 
 
 
 
 
 
4a65690
33866cf
 
 
 
 
 
 
 
5230b46
 
 
 
 
 
 
 
 
b55d380
4a65690
 
b55d380
5230b46
33866cf
 
 
 
5230b46
33866cf
 
 
 
 
 
 
 
 
 
 
 
 
4a65690
b5b64fb
4a65690
b5b64fb
 
 
 
5230b46
b5b64fb
 
 
 
5230b46
b5b64fb
 
 
 
4a65690
b5b64fb
 
4a65690
 
 
 
 
33866cf
4a65690
 
33866cf
4a65690
5230b46
 
 
 
 
 
 
 
 
 
4a65690
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import streamlit as st
import requests
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer
import faiss
from groq import Groq

# Fetch the Groq API key once at import time.
# NOTE(review): the environment variable here is 'GroqApi', but the error
# message in query_llm historically referred to GROQ_API_KEY — confirm
# which name the deployment actually sets.
API_KEY = os.environ.get('GroqApi')

# Module-level state populated by the "Process Tariff Data" button.
# (The former `global CHUNKS, INDEX, MODEL` statement was a no-op at
# module scope — module scope *is* the global scope — so it is removed.)
CHUNKS = None  # list[str] of text chunks, or None until data is processed
INDEX = None   # FAISS index over the chunk embeddings, or None
MODEL = None   # SentenceTransformer used to embed chunks/queries, or None

# Function to scrape tariff data
def scrape_tariff_data(url):
    """Fetch *url* and return the text of all <p> elements, newline-joined.

    Raises:
        requests.HTTPError: on a non-2xx response (previously the error
            page itself was silently parsed as tariff data).
        requests.Timeout: if the server does not answer within 10 seconds
            (previously the request could hang the UI indefinitely).
    """
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    soup = BeautifulSoup(response.text, 'html.parser')
    # Collect every paragraph's stripped text; join into one document string.
    return "\n".join(p.text.strip() for p in soup.find_all('p'))

# Function to chunk text into manageable sizes
def chunk_text(text, max_length=512):
    """Split *text* into word-based chunks of at most *max_length* words.

    Returns a list of strings; an empty/whitespace-only input yields [].
    """
    words = text.split()
    return [
        " ".join(words[start:start + max_length])
        for start in range(0, len(words), max_length)
    ]

# Function to create embeddings and FAISS index
def create_faiss_index(chunks, model_name='all-MiniLM-L6-v2'):
    """Embed *chunks* with a SentenceTransformer and build a flat L2 index.

    Returns a (index, embeddings, model) tuple so callers can reuse the
    same model for query-time encoding.
    """
    embedder = SentenceTransformer(model_name)
    vectors = embedder.encode(chunks)
    # IndexFlatL2 performs exact L2 search over the full embedding matrix.
    idx = faiss.IndexFlatL2(vectors.shape[1])
    idx.add(vectors)
    return idx, vectors, embedder

# Function to search FAISS for relevant chunks
def search_faiss(query, index, chunks, model, top_k=5):
    """Return up to *top_k* chunks most similar to *query*.

    FAISS pads `indices` with -1 when the index holds fewer than top_k
    vectors. The previous filter (`i < len(chunks)`) let -1 through,
    which silently returned the *last* chunk via negative indexing;
    requiring `0 <= i` drops those padding slots instead.
    """
    query_embedding = model.encode([query])
    distances, indices = index.search(query_embedding, top_k)
    return [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]

# Function to query the Groq API with augmented query
def query_llm(prompt, context):
    """Send *prompt* augmented with *context* to the Groq chat API.

    Returns the model's reply text, or an error string when no API key
    is configured (the key is read from the 'GroqApi' environment
    variable at module load — the old message wrongly named
    GROQ_API_KEY, sending users to set the wrong variable).
    """
    if not API_KEY:
        return "Error: the 'GroqApi' environment variable is not set."

    client = Groq(api_key=API_KEY)
    # Prepend the retrieved chunks so the model answers from the scraped data.
    augmented_prompt = f"Based on the following data:\n\n{context}\n\nAnswer the question: {prompt}"
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": augmented_prompt,
            }
        ],
        model="llama3-8b-8192",
    )
    return chat_completion.choices[0].message.content

# Streamlit UI
st.title("RAG-Based Tariff Data Application")

url = st.text_input("Enter Tariff Data URL", "https://iesco.com.pk/index.php/customer-services/tariff-guide")

if st.button("Process Tariff Data"):
    with st.spinner("Extracting and processing data..."):
        try:
            text = scrape_tariff_data(url)
            if not text:
                st.error("Failed to scrape data from the provided URL.")
                st.stop()

            chunks = chunk_text(text)
            if not chunks:
                st.error("No data available for processing.")
                st.stop()

            index, embeddings, model = create_faiss_index(chunks)
            if not index:
                st.error("Failed to create FAISS index.")
                st.stop()

            # Streamlit reruns this whole script on every widget interaction,
            # resetting module globals — which made the query button always
            # report "not processed". session_state survives reruns.
            st.session_state["chunks"] = chunks
            st.session_state["index"] = index
            st.session_state["model"] = model

            st.success("Data processed and indexed!")
            st.write("Number of chunks processed:", len(chunks))

        except Exception as e:
            st.error(f"Error processing data: {e}")

st.header("Query the Tariff Data")
prompt = st.text_input("Enter your query")

if st.button("Get Answer"):
    if prompt:
        with st.spinner("Fetching response..."):
            try:
                # Pull the processed artifacts persisted by the button above.
                index = st.session_state.get("index")
                chunks = st.session_state.get("chunks")
                model = st.session_state.get("model")
                if not (index and chunks and model):
                    st.error("Data has not been processed yet. Please process the data first.")
                else:
                    # Retrieve relevant chunks
                    relevant_chunks = search_faiss(prompt, index, chunks, model)
                    context = "\n".join(relevant_chunks)

                    # Query the LLM with context
                    response = query_llm(prompt, context)
                    st.write(response)
            except Exception as e:
                st.error(f"Error querying the model: {e}")
    else:
        st.warning("Please enter a query to continue.")