jahnaviKolli commited on
Commit
4a06e3a
·
verified ·
1 Parent(s): 53d39bd

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ from transformers import pipeline
3
+ from sentence_transformers import CrossEncoder
4
+ import json
5
+ import gradio as gr
6
+ import pickle
7
+ import faiss
8
+ import numpy as np
9
+
10
+
11
# --- One-time startup: load the retrieval corpus, the vector index, and models ---

# Text chunks that back the knowledge base.
# NOTE(review): pickle is only acceptable because chunks.pkl is a trusted
# local artifact produced by our own indexing step — never load untrusted pickles.
with open("chunks.pkl", "rb") as chunk_file:
    chunks = pickle.load(chunk_file)

# Pre-built FAISS vector index over those same chunks.
index = faiss.read_index("gitlab_index.faiss")

# Sentence embedder used to encode incoming user queries.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Seq2seq model that writes the final answer from the retrieved context.
generator = pipeline("text2text-generation", model="google/flan-t5-base")
+
25
def generate_answer(context, question):
    """Generate a grounded answer to *question* using only *context*.

    Builds an instruction prompt for the seq2seq generator and keeps only the
    text following the last "Answer:" marker, in case the model echoes the
    prompt scaffold back.
    """
    prompt = f"""
    You are a helpful chatbot that answers questions for GitLab employees and applicants.
    Use only the provided context. Be concise and do not repeat sentences.
    If the answer is not in the context, respond with "I don't know."
    Context:
    {context}

    Question: {question}
    Answer:"""

    raw_output = generator(prompt, max_new_tokens=300, truncation=True)[0]["generated_text"]

    # Strip whitespace and drop any echoed "Answer:" prefix.
    answer = raw_output.strip().split("Answer:")[-1]
    return answer.strip()
+
40
+
41
+
42
# Cross-encoder used to re-score candidate chunks directly against the query.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')


def rerank_chunks(query, candidate_chunks, top_k=3):
    """Return the *top_k* chunks most relevant to *query* per the cross-encoder.

    Scores every (query, chunk) pair jointly, then keeps the best-scoring
    chunks in descending relevance order.
    """
    query_chunk_pairs = [[query, candidate] for candidate in candidate_chunks]
    relevance_scores = cross_encoder.predict(query_chunk_pairs)
    ranked = sorted(
        zip(candidate_chunks, relevance_scores),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return [candidate for candidate, _score in ranked[:top_k]]
+
50
+
51
+
52
def query_knowledge_base(query, top_k=10, rerank_top_k=2):
    """Retrieve the most relevant chunks for *query*, then rerank them.

    Args:
        query: User question as a plain string.
        top_k: Number of nearest neighbours fetched from the FAISS index
            before reranking.
        rerank_top_k: Number of chunks kept after cross-encoder reranking.
            Previously hard-coded to 2; kept as the default for compatibility.

    Returns:
        List of chunk strings, most relevant first.
    """
    query_embedding = embedding_model.encode([query])
    # Normalize so similarity search behaves like cosine similarity; clamp the
    # norm to avoid a division by zero on a degenerate all-zero embedding.
    norms = np.linalg.norm(query_embedding, axis=1, keepdims=True)
    query_embedding = query_embedding / np.maximum(norms, 1e-12)

    # FAISS expects float32 queries.
    distances, indices = index.search(query_embedding.astype(np.float32), top_k)
    # FAISS pads `indices` with -1 when the index holds fewer than top_k
    # vectors; a raw chunks[-1] lookup would silently return the wrong chunk.
    initial_results = [chunks[i] for i in indices[0] if i >= 0]

    return rerank_chunks(query, initial_results, top_k=rerank_top_k)
+
61
+
62
def rag_chatbot(question):
    """Answer *question* with retrieve-then-generate; never raises to the UI."""
    try:
        # Retrieve the best-matching knowledge-base chunks for the question.
        top_chunks = query_knowledge_base(question)

        # Fuse the chunks into a single context string for the generator.
        context = " ".join(top_chunks)

        return generate_answer(context, question)
    except Exception as e:
        # Boundary handler: surface any failure as chat text so the Gradio
        # interface keeps working instead of crashing.
        return f"An error occurred: {str(e)}"
+
77
+
78
+
79
+ # Gradio Interface
80
+
81
def chat_interface_fn(message, history):
    """Gradio ChatInterface callback: delegate to the RAG pipeline.

    *history* is supplied by Gradio but intentionally unused — every question
    is answered independently of prior turns.
    """
    return rag_chatbot(message)
+
85
+
86
# Build the chat UI and start serving it.
demo = gr.ChatInterface(
    fn=chat_interface_fn,
    chatbot=gr.Chatbot(type='messages'),
    title="GitLab All-Remote Hiring Assistant",
    description="Ask me about GitLab All-Remote Hiring!",
)
# NOTE(review): share=True opens a public tunnel and debug=True blocks with
# verbose logging — both are convenient for demos; confirm for production.
demo.launch(share=True, debug=True)