# MatchingTool / app.py
# Source: Hugging Face Space by wahab5763 — "Update app.py" (commit fc3a2eb, verified)
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import streamlit as st
# Load models
@st.cache_resource
def load_embedding_model():
    """Load and cache the Sentence-BERT encoder used for all embeddings.

    Cached via st.cache_resource so the model is loaded once per server
    process, not once per Streamlit rerun.
    """
    model = SentenceTransformer('all-MiniLM-L6-v2')
    return model
@st.cache_resource
def load_text_generator():
    """Load and cache the t5-small pipeline that drafts the contextual answer."""
    generator = pipeline("text2text-generation", model="t5-small")
    return generator
@st.cache_resource
def load_rephrasing_model():
    """Load and cache the t5-base pipeline used to polish the drafted answer."""
    rephraser = pipeline("text2text-generation", model="t5-base")
    return rephraser
# Instantiate the three cached models once at import time; the cache_resource
# decorators make repeated Streamlit reruns reuse the same objects.
embedding_model = load_embedding_model()
text_generator = load_text_generator()
rephrasing_model = load_rephrasing_model()
# Preprocess and embed data
def preprocess_and_embed(data):
text_data = data.astype(str).apply(" ".join, axis=1) # Concatenate rows
embeddings = embedding_model.encode(text_data.tolist(), convert_to_numpy=True)
return text_data, embeddings
# Create FAISS index
def create_faiss_index(embeddings):
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)
return index
# Retrieve recommendations
def retrieve_recommendations(query, text_data, index, embeddings, k=5):
query_embedding = embedding_model.encode([query], convert_to_numpy=True)
distances, indices = index.search(query_embedding, k)
results = [{"text": text_data[i], "distance": d} for i, d in zip(indices[0], distances[0])]
return results
# Generate contextual response based on query and context
def generate_response(query, recommendations):
context = " ".join([rec["text"] for rec in recommendations])
limited_context = f"Context: {context}\nQuery: {query}\nAnswer:"
response = text_generator(limited_context, max_length=100, num_beams=3)
return response[0]['generated_text']
# Rephrase the generated response
def rephrase_response(response):
rephrased = rephrasing_model(f"Rephrase this: {response}", max_length=100, num_beams=3)
return rephrased[0]['generated_text']
# Streamlit UI
# ---------------------------------------------------------------------------
# Streamlit UI: upload a tabular file, embed its rows, answer free-text queries.
# ---------------------------------------------------------------------------
st.title("RAG-Based AI Recommendation System (Using SBERT)")

uploaded_file = st.file_uploader("Upload CSV/Excel file", type=["csv", "xlsx"])
if uploaded_file:
    # Dispatch the reader on the extension; the uploader only admits csv/xlsx.
    extension = uploaded_file.name.split(".")[-1]
    if extension == "csv":
        data = pd.read_csv(uploaded_file)
    else:
        data = pd.read_excel(uploaded_file)

    st.write("Data Preview:")
    st.write(data.head())

    # Embed every row once per upload and build the nearest-neighbour index.
    with st.spinner("Processing data and creating embeddings..."):
        text_data, embeddings = preprocess_and_embed(data)
        index = create_faiss_index(embeddings)
        st.success("Data processed and embeddings created!")

    query = st.text_input("Enter your query:")
    if query:
        # Stage 1: nearest-neighbour retrieval over the row embeddings.
        with st.spinner("Retrieving recommendations..."):
            recommendations = retrieve_recommendations(
                query, text_data, index, embeddings
            )
        st.write("Top Recommendations:")
        for rec in recommendations:
            st.write(f"- {rec['text']} (Distance: {rec['distance']:.2f})")

        # Stage 2: draft an answer grounded in the retrieved rows.
        with st.spinner("Generating contextual recommendation..."):
            refined_response = generate_response(query, recommendations)

        # Stage 3: polish the draft for grammar and clarity.
        with st.spinner("Rephrasing the recommendation for better grammar and clarity..."):
            final_response = rephrase_response(refined_response)
        st.write("Contextual Recommendation:")
        st.write(final_response)