Kakarot21 committed on
Commit
0472254
Β·
1 Parent(s): ac4addf

feat: Implement a basic RAG chatbot application using local ChromaDB, HuggingFace embeddings and model, and a Gradio interface.

Browse files
Files changed (4) hide show
  1. app.py +91 -0
  2. data/book.txt +0 -0
  3. data_cutter.py +63 -0
  4. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from langchain_huggingface import HuggingFaceEmbeddings
3
+ from langchain_chroma import Chroma
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
5
+ import torch
6
+ import os
7
+ from data_cutter import create_db
8
+
9
+ # Constants
10
+ CHROMA_PATH = "chroma_db"
11
+ MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
12
+
13
+ print("πŸš€ Starting app...")
14
+
15
+ # 1. Initialize/Load Database
16
+ print("πŸ”„ Initializing database from data folder...")
17
+ # We rebuild the DB on startup to ensure it matches the current data
18
+ try:
19
+ vectorstore = create_db()
20
+ print("βœ… Database created successfully!")
21
+ except Exception as e:
22
+ print(f"❌ Error creating database: {e}")
23
+ # Fallback: try to load if exists, though create_db should have handled it
24
+ if os.path.exists(CHROMA_PATH):
25
+ print("⚠️ Attempting to load existing database...")
26
+ embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
27
+ vectorstore = Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)
28
+ else:
29
+ raise e
30
+
31
+ # 2. Load AI Model
32
+ print(f"πŸ€– Loading AI Model ({MODEL_ID})...")
33
+ try:
34
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
35
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
36
+
37
+ pipe = pipeline(
38
+ "text-generation",
39
+ model=model,
40
+ tokenizer=tokenizer,
41
+ max_new_tokens=512,
42
+ device=-1, # Run on CPU
43
+ do_sample=True,
44
+ temperature=0.7,
45
+ top_p=0.9,
46
+ )
47
+ print("βœ… AI Model loaded successfully!")
48
+ except Exception as e:
49
+ print(f"❌ Error loading model: {e}")
50
+ raise e
51
+
52
def chat_function(message, history):
    """Answer *message* using RAG: retrieve relevant chunks, then generate.

    Args:
        message: The user's question (plain string from Gradio).
        history: Prior chat turns supplied by gr.ChatInterface (unused; the
            prompt is rebuilt from scratch on every call).

    Returns:
        The model's answer as a plain string, with chat-template special
        tokens stripped.
    """
    print(f"πŸ“¨ Received query: {message}")

    # Retrieve the 3 most similar chunks and join them into one context blob.
    results = vectorstore.similarity_search(message, k=3)
    context = "\n\n".join(doc.page_content for doc in results)

    # Build a ChatML prompt grounding the model in the retrieved context only.
    messages = [
        {"role": "system", "content": "You are a helpful assistant. Answer the user's question based ONLY on the provided context. If the answer is not in the context, say you don't know."},
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {message}"}
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # The pipeline returns the prompt plus the completion (return_full_text
    # defaults to True), so the assistant turn must be extracted below.
    outputs = pipe(prompt)
    generated_text = outputs[0]['generated_text']

    # Extract the assistant's turn from the full generated text.
    if "<|im_start|>assistant" in generated_text:
        response_text = generated_text.split("<|im_start|>assistant")[-1].strip()
    elif prompt in generated_text:
        response_text = generated_text.replace(prompt, "").strip()
    else:
        response_text = generated_text

    # BUG FIX: Qwen terminates its turn with the <|im_end|> special token,
    # which the extraction above left attached; strip it so the raw token
    # text never reaches the UI.
    response_text = response_text.removesuffix("<|im_end|>").strip()

    return response_text
# Gradio UI wiring: a chat widget whose callbacks go through chat_function.
# NOTE: the top-level name `demo` is what Hugging Face Spaces auto-launches.
demo = gr.ChatInterface(
    chat_function,
    type="messages",
    title="RAG Chat with Your Data",
    description=f"Ask questions about your documents. Powered by {MODEL_ID}.",
    examples=["What is the main topic?", "Summarize the content."],
)

if __name__ == "__main__":
    demo.launch()
data/book.txt ADDED
The diff for this file is too large to render. See raw diff
 
data_cutter.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Build a local ChromaDB vector store from the text files in ./data."""
import os

from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Pull variables (e.g. API keys) from a local .env file, when present.
load_dotenv()

# Folder scanned for *.txt source documents.
DATA_PATH = "data"
# On-disk location of the persisted Chroma vector store.
CHROMA_PATH = "chroma_db"
def load_documents(data_path=DATA_PATH, glob="*.txt"):
    """Load every text file matching *glob* under *data_path*.

    Generalized from the hard-coded constants: defaults preserve the
    original behavior (all *.txt files under DATA_PATH).

    Args:
        data_path: Directory to scan for documents.
        glob: Filename pattern passed to DirectoryLoader.

    Returns:
        A list of LangChain Document objects, one per matching file.
    """
    loader = DirectoryLoader(data_path, glob=glob, loader_cls=TextLoader)
    documents = loader.load()
    return documents
def create_db(persist_directory=CHROMA_PATH, chunk_size=1000, chunk_overlap=500):
    """Split the source documents into chunks and embed them into a fresh ChromaDB.

    Any existing store at *persist_directory* is deleted first, so the index
    always reflects the current contents of the data folder.

    Generalized from hard-coded constants; the defaults reproduce the
    original behavior exactly.

    Args:
        persist_directory: Where the Chroma store is (re)created on disk.
        chunk_size: Maximum characters per chunk.
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        The populated Chroma vector store.
    """
    documents = load_documents()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        add_start_index=True,  # record each chunk's offset within its source doc
    )
    chunks = text_splitter.split_documents(documents)

    print(f"Loaded {len(documents)} document(s)")
    print(f"Split into {len(chunks)} chunks")

    # Rebuild from scratch: drop any stale store so old chunks cannot linger.
    if os.path.exists(persist_directory):
        print(f"\nClearing existing database at {persist_directory}...")
        import shutil
        shutil.rmtree(persist_directory)

    # Embed with a small local sentence-transformer and persist to disk.
    print("\nCreating ChromaDB vector store with HuggingFace embeddings (all-MiniLM-L6-v2)...")
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

    vectorstore = Chroma.from_documents(
        documents=chunks,
        embedding=embeddings,
        persist_directory=persist_directory,
    )

    print(f"βœ… Successfully created ChromaDB with {len(chunks)} chunks!")
    print(f"πŸ“ Database saved to: {persist_directory}")
    return vectorstore
if __name__ == "__main__":
    # Build the store, then smoke-test it with one similarity query.
    store = create_db()

    print(f"\nπŸ” Testing vector store with a sample query...")
    hits = store.similarity_search("Alice", k=3)
    print(f"Found {len(hits)} relevant chunks for query 'Alice'")
    print(f"\nFirst result preview:")
    if hits:
        print(f"{hits[0].page_content[:200]}...")
    else:
        print("No results")
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ python-dotenv
3
+ langchain-huggingface
4
+ langchain-chroma
5
+ langchain-community
6
+ langchain-text-splitters
7
+ transformers
8
+ torch
9
+