navjotk committed (verified)
Commit 290b82a · 1 Parent(s): 09f7129

Update app.py

Files changed (1):
  app.py +17 -32
app.py CHANGED
@@ -2,8 +2,6 @@ import os
 from pathlib import Path
 import gradio as gr
 
-from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
-from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import FAISS
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_core.prompts import PromptTemplate
@@ -12,49 +10,31 @@ from langchain.llms import HuggingFacePipeline
 from transformers import pipeline
 
 # Constants
-DATA_PATH = "data/"
-DB_FAISS_PATH = "vectorstore/db_faiss"
+DB_FAISS_PATH = "vectorstore/db_faiss"  # Pre-generated FAISS directory
 EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
-MODEL_NAME = "MBZUAI/LaMini-Flan-T5-783M"  # Light model for CPU
-CHUNK_SIZE = 1500
-CHUNK_OVERLAP = 150
+MODEL_NAME = "MBZUAI/LaMini-Flan-T5-783M"  # Lightweight CPU-friendly model
 
-# Step 1: Load PDF documents and split into chunks
-def load_documents():
-    loader = DirectoryLoader(DATA_PATH, glob="*.pdf", loader_cls=PyPDFLoader)
-    documents = loader.load()
-    splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
-    return splitter.split_documents(documents)
-
-# Step 2: Create vectorstore if not exists
-def ensure_vector_store():
-    if not Path(DB_FAISS_PATH).exists():
-        print("Creating new vector store...")
-        documents = load_documents()
-        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
-        db = FAISS.from_documents(documents, embeddings)
-        db.save_local(DB_FAISS_PATH)
-    else:
-        print("Loading existing vector store...")
+# Step 1: Load FAISS vectorstore (already created offline)
+def load_vector_store():
     embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
     return FAISS.load_local(DB_FAISS_PATH, embeddings, allow_dangerous_deserialization=True)
 
-# Step 3: Load lightweight LLM using HuggingFace pipeline
+# Step 2: Load lightweight HuggingFace model (no token needed)
 def load_llm():
     pipe = pipeline("text2text-generation", model=MODEL_NAME)
     return HuggingFacePipeline(pipeline=pipe)
 
-# Step 4: Setup QA chain
+# Step 3: Setup QA chain
 def setup_chain():
     prompt_template = """
     Use the following context to answer the question.
     If the answer is not in the context, just say you don't know.
-
+
     Context: {context}
     Question: {question}
     """
     prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
-    retriever = ensure_vector_store().as_retriever(search_kwargs={"k": 3})
+    retriever = load_vector_store().as_retriever(search_kwargs={"k": 3})
     llm = load_llm()
     return RetrievalQA.from_chain_type(
         llm=llm,
@@ -66,12 +46,17 @@ def setup_chain():
 
 qa_chain = setup_chain()
 
-# Step 5: Gradio Interface
+# Step 4: Gradio Interface
 def rag_bot(query):
     result = qa_chain.invoke({"query": query})
     return result["result"]
 
-demo = gr.Interface(fn=rag_bot, inputs="text", outputs="text",
-                    title="TextileVision",
-                    description="Ask queries related to TextileVision")
+# Step 5: Launch Interface
+demo = gr.Interface(
+    fn=rag_bot,
+    inputs="text",
+    outputs="text",
+    title="TextileVision: AI Chatbot",
+    description="Ask queries about loom speed, yarn mixing, knitting prediction, and textile operations."
+)
 demo.launch()
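
The rewritten app.py assumes `vectorstore/db_faiss` already exists; the code that used to build it at startup is gone. A minimal offline builder, reconstructed from the removed load_documents()/ensure_vector_store() code and the deleted constants (the script name build_index.py is illustrative, not part of this commit):

# build_index.py: run once offline to produce the vectorstore/db_faiss
# directory that the updated app.py loads at startup.
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

DATA_PATH = "data/"  # deleted constant: folder holding the source PDFs
DB_FAISS_PATH = "vectorstore/db_faiss"
EMBEDDING_MODEL_NAME = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

# Load every PDF under data/ and split into overlapping chunks,
# using the deleted CHUNK_SIZE=1500 / CHUNK_OVERLAP=150 values.
loader = DirectoryLoader(DATA_PATH, glob="*.pdf", loader_cls=PyPDFLoader)
splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
chunks = splitter.split_documents(loader.load())

# Embed with the same model app.py uses, then persist the FAISS index.
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME)
FAISS.from_documents(chunks, embeddings).save_local(DB_FAISS_PATH)

Because app.py reloads an index written by this same pipeline, passing allow_dangerous_deserialization=True to FAISS.load_local stays reasonable: the pickled metadata being trusted is the app's own artifact, not user-supplied input.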
 
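To run the updated Space locally, a sketch (the dependency list is inferred from the imports; this commit pins no versions):

pip install gradio langchain langchain-community langchain-huggingface transformers torch sentence-transformers faiss-cpu pypdf
python build_index.py   # hypothetical helper sketched above; only needed if vectorstore/db_faiss is missing
python app.py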