# Hugging Face Spaces upload header (jahnaviKolli — "Upload app.py", commit 4a06e3a, verified)
from sentence_transformers import SentenceTransformer
from transformers import pipeline
from sentence_transformers import CrossEncoder
import json
import gradio as gr
import pickle
import faiss
import numpy as np
# --- Startup: load the knowledge base and the models ---

# Pre-chunked documents (pickled list of passage strings).
with open("chunks.pkl", "rb") as f:
    chunks = pickle.load(f)

# FAISS vector index built over the chunk embeddings.
index = faiss.read_index("gitlab_index.faiss")

# Sentence embedder used to encode incoming queries at search time.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Seq2seq generator that produces the final answer text.
generator = pipeline("text2text-generation", model="google/flan-t5-base")
def generate_answer(context, question):
    """Generate an answer to *question* grounded in *context* via the seq2seq pipeline.

    Returns the model text with any echoed "Answer:" prefix stripped.
    """
    prompt = f"""
You are a helpful chatbot that answers questions for GitLab employees and applicants.
Use only the provided context. Be concise and do not repeat sentences.
If the answer is not in the context, respond with "I don't know."
Context:
{context}
Question: {question}
Answer:"""
    outputs = generator(prompt, max_new_tokens=300, truncation=True)
    raw = outputs[0]["generated_text"].strip()
    # Keep only the text after the final "Answer:" marker, in case the model echoes the prompt.
    answer = raw.split("Answer:")[-1]
    return answer.strip()
# Cross-encoder used to re-rank the FAISS candidates by relevance to the query.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
def rerank_chunks(query, candidate_chunks, top_k=3):
    """Re-rank *candidate_chunks* against *query* with the cross-encoder.

    Args:
        query: User question string.
        candidate_chunks: Candidate passage strings from the vector search.
        top_k: Maximum number of best-scoring chunks to return.

    Returns:
        Up to *top_k* chunks, ordered from most to least relevant.
    """
    # Guard: CrossEncoder.predict errors on an empty batch, so short-circuit.
    if not candidate_chunks:
        return []
    pairs = [[query, chunk] for chunk in candidate_chunks]
    scores = cross_encoder.predict(pairs)
    # Highest relevance score first.
    scored_chunks = sorted(zip(candidate_chunks, scores), key=lambda x: x[1], reverse=True)
    return [chunk for chunk, _ in scored_chunks[:top_k]]
def query_knowledge_base(query, top_k=10):
    """Retrieve and re-rank knowledge-base chunks for *query*.

    Embeds the query, searches the FAISS index for *top_k* candidates,
    then keeps the 2 best after cross-encoder re-ranking.

    Args:
        query: User question string.
        top_k: Number of candidates to pull from the vector index.

    Returns:
        List of the most relevant chunk strings (at most 2).
    """
    query_embedding = embedding_model.encode([query])
    # Normalize so inner-product search behaves like cosine similarity.
    query_embedding = query_embedding / np.linalg.norm(query_embedding, axis=1, keepdims=True)
    distances, indices = index.search(query_embedding, top_k)
    # FAISS pads the result with -1 when the index holds fewer than top_k
    # vectors; without the filter, chunks[-1] would silently return the
    # wrong passage.
    initial_results = [chunks[i] for i in indices[0] if i != -1]
    return rerank_chunks(query, initial_results, top_k=2)
def rag_chatbot(question):
    """End-to-end RAG pipeline: retrieve context, then generate an answer.

    Any failure is reported back as a plain string so the chat UI never crashes.
    """
    try:
        # Retrieve the most relevant chunks and fuse them into one context string.
        relevant_chunks = query_knowledge_base(question)
        combined_context = " ".join(relevant_chunks)
        # Ask the generator to answer from that context only.
        return generate_answer(combined_context, question)
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Gradio Interface
def chat_interface_fn(message, history):
    """Gradio ChatInterface callback: answer *message* (*history* is unused)."""
    return rag_chatbot(message)
# Build the chat UI and start serving it.
demo = gr.ChatInterface(
    fn=chat_interface_fn,
    chatbot=gr.Chatbot(type='messages'),
    title="GitLab All-Remote Hiring Assistant",
    description="Ask me about GitLab All-Remote Hiring!",
)
demo.launch(share=True, debug=True)