Sourudra committed on
Commit
424a2e7
·
verified ·
1 Parent(s): 86bbd4a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from transformers import DistilBertTokenizer, DistilBertModel
5
+ import faiss
6
+
7
# Load the DistilBERT model and tokenizer
# (shared below: used both to index the documents and to encode incoming questions).
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased")

# Example documents to simulate a knowledge base
# NOTE(review): this is a fixed in-memory corpus; presumably a real deployment
# would load documents from storage — confirm before extending.
documents = [
    "Python is a programming language that is widely used in data science and machine learning.",
    "The Eiffel Tower is a famous landmark located in Paris, France.",
    "Generative Adversarial Networks (GANs) are a class of machine learning models used for image generation.",
    "Hugging Face is a company specializing in natural language processing and machine learning."
]
18
+
19
# Tokenize the documents and create embeddings
def create_embeddings(documents):
    """Encode each document into a fixed-size dense vector.

    Each text is tokenized, passed through DistilBERT, and reduced to one
    vector by mean-pooling the last hidden states over the token dimension.

    Args:
        documents: list of strings to embed.

    Returns:
        np.ndarray of shape (len(documents), hidden_size), dtype float32
        (FAISS requires float32). For an empty input an empty
        (0, hidden_size) array is returned.
    """
    hidden_size = model.config.hidden_size
    if not documents:
        # np.array([]) would be 1-D and break IndexFlatL2(embeddings.shape[1])
        # downstream — return an explicitly 2-D empty array instead.
        return np.empty((0, hidden_size), dtype=np.float32)
    embeddings = []
    for doc in documents:
        inputs = tokenizer(doc, return_tensors="pt", padding=True, truncation=True, max_length=512)
        with torch.no_grad():  # inference only — no gradients needed
            outputs = model(**inputs)
        # Mean-pool over the sequence dimension -> one vector per document.
        embeddings.append(outputs.last_hidden_state.mean(dim=1).squeeze().numpy())
    # np.stack keeps the (n_docs, hidden_size) shape explicit; cast for FAISS.
    return np.stack(embeddings).astype(np.float32)
28
+
29
# Create FAISS index for document retrieval
def create_faiss_index(embeddings):
    """Build an exact L2-distance FAISS index over the given vectors.

    Args:
        embeddings: 2-D array-like of shape (n_docs, dim).

    Returns:
        A populated faiss.IndexFlatL2.
    """
    # FAISS expects C-contiguous float32 input; coerce defensively so the
    # index also accepts float64 or non-contiguous arrays without erroring.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    index = faiss.IndexFlatL2(vectors.shape[1])  # exact L2 search, no ANN approximation
    index.add(vectors)
    return index
34
+
35
# Create embeddings for the documents and the FAISS index.
# Built once at import time; retrieve_document() searches this index per query.
document_embeddings = create_embeddings(documents)
faiss_index = create_faiss_index(document_embeddings)
38
+
39
# Function to retrieve the most relevant document based on the question
def retrieve_document(question):
    """Return the knowledge-base document nearest to *question* in L2 distance.

    The question is embedded the same way as the documents (DistilBERT +
    mean-pooling), then the prebuilt FAISS index is searched for the
    single closest document vector.
    """
    # Encode the question into an embedding.
    encoded = tokenizer(question, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        model_out = model(**encoded)
    query_vec = model_out.last_hidden_state.mean(dim=1).squeeze().numpy()

    # Search the FAISS index for the single nearest neighbour.
    distances, indices = faiss_index.search(np.array([query_vec]), k=1)
    best = indices[0][0]
    return documents[best]  # the most relevant document text
50
+
51
# Function to answer the question using the retrieved document
def answer_question(question):
    """Look up the best-matching document and format it as the chatbot reply."""
    doc = retrieve_document(question)
    # The retrieved passage doubles as the "answer" in this minimal RAG demo.
    return f"Retrieved Document: {doc}\nAnswer: {doc}"
55
+
56
# Create a Gradio interface for the chatbot.
# inputs/outputs use the "text" shortcut (single-line textboxes).
interface = gr.Interface(
    fn=answer_question,
    inputs="text",
    outputs="text",
    live=True,  # re-run retrieval as the user types, not only on submit
    title="RAG-Based Question Answering with DistilBERT",
    description="Ask a question, and I will retrieve the most relevant document to answer it."
)

# Launch the Gradio app (blocking call; starts the local web server).
interface.launch()