navjotk commited on
Commit
ade6dde
·
verified ·
1 Parent(s): 321d947

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chainlit as cl
2
+ from langchain.vectorstores import FAISS
3
+ from langchain.embeddings import HuggingFaceEmbeddings
4
+ from langchain.document_loaders import TextLoader
5
+ from langchain.text_splitter import CharacterTextSplitter
6
+ from langchain.chains import RetrievalQA
7
+ from langchain.prompts import PromptTemplate
8
+ from langchain.llms import HuggingFacePipeline
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
10
+ import torch
11
+ import os
12
+
13
# === Step 1: Build or load Vectorstore ===

# Directory where the FAISS index is persisted between runs.
VECTORSTORE_DIR = "vectorstore"
# Plain-text corpus the QA chain answers from.
DATA_PATH = "data/textile_notes.txt" # Your textile documents path
17
+
18
def build_vectorstore():
    """Index the textile corpus into a FAISS store and persist it.

    Loads the raw text file at DATA_PATH, splits it into overlapping
    500-character chunks, embeds them with BGE-small, writes the index
    to VECTORSTORE_DIR, and returns the in-memory store.
    """
    raw_docs = TextLoader(DATA_PATH).load()
    chunker = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    pieces = chunker.split_documents(raw_docs)

    store = FAISS.from_documents(
        pieces,
        HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5"),
    )
    store.save_local(VECTORSTORE_DIR)
    return store
28
+
29
def load_vectorstore():
    """Reload the FAISS index previously saved by build_vectorstore()."""
    embedder = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
    return FAISS.load_local(VECTORSTORE_DIR, embedder)
32
+
33
# === Step 2: Load LLM and create QA chain ===

def _get_vectorstore():
    """Return the persisted FAISS store, building it on first run."""
    if os.path.exists(VECTORSTORE_DIR):
        return load_vectorstore()
    return build_vectorstore()


def _load_llm():
    """Load Mistral-7B-Instruct and wrap it as a LangChain LLM."""
    model_id = "mistralai/Mistral-7B-Instruct-v0.1"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,  # half precision: fits inference memory budgets
        device_map="auto",          # let accelerate place layers on available devices
    )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
    )
    return HuggingFacePipeline(pipeline=pipe)


def _build_prompt():
    """Prompt that keeps answers grounded in the retrieved context."""
    prompt_template = """
Answer the question using ONLY the context below.
Be clear, helpful, and friendly.

Context:
{context}

Question:
{question}
"""
    return PromptTemplate(
        template=prompt_template,
        input_variables=["context", "question"],
    )


def load_qa_chain():
    """Build (once) and return the RetrievalQA chain.

    The chain is cached on the function object: without the cache,
    every @cl.on_chat_start would reload the 7B-parameter model and
    the vectorstore from scratch, which takes minutes and can exhaust
    GPU memory. Returns a chain with output keys "result" and
    "source_documents" (because return_source_documents=True).
    """
    chain = getattr(load_qa_chain, "_cached_chain", None)
    if chain is None:
        chain = RetrievalQA.from_chain_type(
            llm=_load_llm(),
            chain_type="stuff",  # concatenate all retrieved chunks into one prompt
            retriever=_get_vectorstore().as_retriever(search_kwargs={"k": 4}),
            chain_type_kwargs={"prompt": _build_prompt()},
            return_source_documents=True,
        )
        load_qa_chain._cached_chain = chain
    return chain
86
+
87
# === Chainlit event handlers ===

@cl.on_chat_start
async def on_chat_start():
    """Attach the QA chain to this session and greet the user."""
    cl.user_session.set("qa_chain", load_qa_chain())
    greeting = "👋 Hi! Ask me anything about textile — I'll answer using our custom documents."
    await cl.Message(greeting).send()
94
+
95
@cl.on_message
async def on_message(message: cl.Message):
    """Answer a user message via the session's RetrievalQA chain.

    Bug fix: the chain is built with return_source_documents=True, so
    it has two output keys ("result", "source_documents") and
    `qa.run(...)` raises ValueError ("run not supported when there is
    not exactly one output key"). Call the chain with a query dict and
    pick the "result" key instead. The synchronous chain call is
    wrapped in cl.make_async so the 7B-model generation does not block
    Chainlit's event loop.
    """
    qa = cl.user_session.get("qa_chain")
    response = await cl.make_async(qa)({"query": message.content})
    await cl.Message(response["result"]).send()