chayanbhansali commited on
Commit
8f28ceb
·
verified ·
1 Parent(s): d5877da

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from sentence_transformers import SentenceTransformer
4
+ from PyPDF2 import PdfReader
5
+ import numpy as np
6
+ import torch
7
+
8
+ class RAGChatbot:
9
+ def __init__(self, model_name="TheBloke/Mistral-7B-Instruct-v0.1-GPTQ", embedding_model="all-MiniLM-L6-v2"):
10
+ # Initialize tokenizer and model
11
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
13
+
14
+ # Initialize embedding model
15
+ self.embedding_model = SentenceTransformer(embedding_model)
16
+
17
+ # Initialize document storage
18
+ self.documents = []
19
+ self.embeddings = []
20
+
21
+ def extract_text_from_pdf(self, pdf_path):
22
+ reader = PdfReader(pdf_path)
23
+ text = ""
24
+ for page in reader.pages:
25
+ text += page.extract_text() + "\n"
26
+ return text
27
+
28
+ def load_documents(self, file_paths):
29
+ self.documents = []
30
+ self.embeddings = []
31
+
32
+ for file_path in file_paths:
33
+ if file_path.endswith('.pdf'):
34
+ text = self.extract_text_from_pdf(file_path)
35
+ else:
36
+ with open(file_path, 'r', encoding='utf-8') as f:
37
+ text = f.read()
38
+
39
+ # Split text into chunks
40
+ chunks = [text[i:i+500] for i in range(0, len(text), 500)]
41
+ self.documents.extend(chunks)
42
+
43
+ # Generate embeddings
44
+ self.embeddings = self.embedding_model.encode(self.documents)
45
+ return f"Loaded {len(self.documents)} text chunks from {len(file_paths)} files"
46
+
47
+ def retrieve_relevant_context(self, query, top_k=3):
48
+ if not self.documents:
49
+ return "No documents loaded"
50
+
51
+ # Generate query embedding
52
+ query_embedding = self.embedding_model.encode([query])[0]
53
+
54
+ # Calculate cosine similarities
55
+ similarities = np.dot(self.embeddings, query_embedding) / (
56
+ np.linalg.norm(self.embeddings, axis=1) * np.linalg.norm(query_embedding)
57
+ )
58
+
59
+ # Get top k most similar documents
60
+ top_indices = similarities.argsort()[-top_k:][::-1]
61
+ return " ".join([self.documents[i] for i in top_indices])
62
+
63
+ def generate_response(self, query, context):
64
+ # Construct prompt with context
65
+ full_prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
66
+
67
+ # Generate response
68
+ inputs = self.tokenizer(full_prompt, return_tensors="pt").to(self.model.device)
69
+ outputs = self.model.generate(**inputs, max_new_tokens=150)
70
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
71
+
72
+ return response.split("Answer:")[-1].strip()
73
+
74
+ def chat(self, query, history):
75
+ try:
76
+ # Retrieve relevant context
77
+ context = self.retrieve_relevant_context(query)
78
+
79
+ # Generate response
80
+ response = self.generate_response(query, context)
81
+
82
+ return response
83
+ except Exception as e:
84
+ return f"An error occurred: {str(e)}"
85
+
86
+ # Create Gradio interface
87
+ def create_interface():
88
+ rag_chatbot = RAGChatbot()
89
+
90
+ with gr.Blocks() as demo:
91
+ gr.Markdown("# RAG Chatbot with Hugging Face Models")
92
+
93
+ with gr.Row():
94
+ file_input = gr.File(label="Upload Documents", file_count="multiple", type="filepath")
95
+ load_btn = gr.Button("Load Documents")
96
+
97
+ status_output = gr.Textbox(label="Load Status")
98
+
99
+ chatbot = gr.Chatbot()
100
+ msg = gr.Textbox(label="Enter your query")
101
+ submit_btn = gr.Button("Send")
102
+ clear_btn = gr.Button("Clear Chat")
103
+
104
+ # Event handlers
105
+ load_btn.click(
106
+ rag_chatbot.load_documents,
107
+ inputs=[file_input],
108
+ outputs=[status_output]
109
+ )
110
+
111
+ submit_btn.click(
112
+ rag_chatbot.chat,
113
+ inputs=[msg, chatbot],
114
+ outputs=[chatbot]
115
+ ).then(
116
+ lambda: gr.Textbox(interactive=True),
117
+ None,
118
+ [msg]
119
+ )
120
+
121
+ msg.submit(
122
+ rag_chatbot.chat,
123
+ inputs=[msg, chatbot],
124
+ outputs=[chatbot]
125
+ ).then(
126
+ lambda: gr.Textbox(interactive=True),
127
+ None,
128
+ [msg]
129
+ )
130
+
131
+ clear_btn.click(lambda: None, None, [chatbot, msg])
132
+
133
+ return demo
134
+
135
+ # Launch the app
136
+ if __name__ == "__main__":
137
+ demo = create_interface()
138
+ demo.launch()