Mihirsingh1101 committed on
Commit
6aef832
·
verified ·
1 Parent(s): 5598fc7

Upload 3 files

Browse files
Files changed (3) hide show
  1. RAG_Solution_FIXED.py +92 -0
  2. app.py +55 -0
  3. requirements.txt +17 -0
RAG_Solution_FIXED.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
from pathlib import Path
from typing import List, Optional

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import SentenceTransformerEmbeddings
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
9
+
10
+ DEFAULT_EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
11
+ DEFAULT_GEN_MODEL = "google/flan-t5-base"
12
+
13
def load_text_files(data_dir: Path, files: Optional[List[Path]] = None) -> List[str]:
    """Read the UTF-8 contents of the corpus text files.

    Args:
        data_dir: Directory scanned for ``*.txt`` files when *files* is not given.
        files: Optional explicit list of paths; when provided it overrides
            *data_dir* entirely. Unreadable/missing entries are skipped silently
            (best-effort, matching the CLI's ``--files`` semantics).

    Returns:
        One string per readable file, in sorted glob order (or the given order).

    Raises:
        FileNotFoundError: If no readable text file was found at all.
    """
    paths = files if files else sorted(Path(data_dir).glob("*.txt"))
    # Path.is_file() already returns False for non-existent paths, so a
    # separate exists() check is redundant.
    texts = [
        p.read_text(encoding="utf-8", errors="ignore")
        for p in paths
        if p.is_file()
    ]
    if not texts:
        raise FileNotFoundError(
            f"No text files found in {Path(data_dir).resolve()} or provided via --files"
        )
    return texts
25
+
26
def build_chunks(texts: List[str], chunk_size: int = 800, chunk_overlap: int = 120):
    """Split raw texts into overlapping character-based chunks.

    Returns a list of LangChain Document objects ready for embedding.
    """
    # Separators are tried in order: paragraph, line, word, then raw characters.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=["\n\n", "\n", " ", ""],
    )
    return splitter.create_documents(texts)
33
+
34
def build_vectorstore(docs):
    """Embed the chunk documents and index them in an in-memory FAISS store."""
    embedder = SentenceTransformerEmbeddings(model_name=DEFAULT_EMBED_MODEL)
    return FAISS.from_documents(docs, embedder)
38
+
39
def make_generator(model_name: str = DEFAULT_GEN_MODEL, device: int = -1):
    """Build a text2text-generation pipeline for the seq2seq model.

    *device* follows the transformers convention: -1 = CPU, >= 0 = CUDA index.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
44
+
45
def format_prompt(question: str, contexts):
    """Assemble the grounded-QA prompt from retrieved chunks and the question."""
    joined = "\n\n".join(d.page_content for d in contexts)
    instructions = (
        "You are an expert assistant. Use ONLY the context to answer.\n"
        "If the answer can't be found in the context, say you don't know.\n\n"
    )
    return instructions + f"Context:\n{joined}\n\nQuestion: {question}\nAnswer:"
54
+
55
def answer_question(vs, generator, question: str, k: int = 4, max_new_tokens: int = 256):
    """Retrieve the top-k chunks, generate a grounded answer, and report sources.

    Returns:
        A ``(answer_text, context_documents, source_names)`` tuple, where
        source_names are taken from each document's ``source`` metadata.
    """
    hits = vs.similarity_search(question, k=k)
    # Greedy decoding (do_sample=False) keeps answers deterministic.
    output = generator(
        format_prompt(question, hits),
        max_new_tokens=max_new_tokens,
        do_sample=False,
    )
    answer = output[0]["generated_text"]
    origins = [doc.metadata.get("source", "") for doc in hits]
    return answer, hits, origins
62
+
63
def main():
    """CLI entry point: build the index from disk, answer one question, print sources."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--question", required=True, help="Question to ask the RAG system.")
    parser.add_argument("--data_dir", default="./data", help="Folder with .txt files.")
    parser.add_argument("--files", nargs="*", help="Specific files to use (overrides data_dir)")
    parser.add_argument("--k", type=int, default=4)
    parser.add_argument("--chunk_size", type=int, default=800)
    parser.add_argument("--chunk_overlap", type=int, default=120)
    parser.add_argument("--max_new_tokens", type=int, default=256)
    args = parser.parse_args()

    explicit_files = [Path(f) for f in args.files] if args.files else None

    # Build the full pipeline: load -> chunk -> embed/index -> generator.
    raw_texts = load_text_files(Path(args.data_dir), explicit_files)
    chunks = build_chunks(raw_texts, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap)
    store = build_vectorstore(chunks)
    gen = make_generator()

    answer, contexts, _sources = answer_question(
        store, gen, args.question, k=args.k, max_new_tokens=args.max_new_tokens
    )
    print("\n=== Answer ===\n", answer.strip())
    print("\n=== Top Sources (chunk previews) ===")
    for idx, doc in enumerate(contexts, 1):
        snippet = doc.page_content[:200].replace("\n", " ")
        print(f"[{idx}] {snippet}...")
90
+
91
# Run the CLI only when executed as a script, not when imported (e.g. by app.py).
if __name__ == "__main__":
    main()
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from flask import Flask, request, jsonify
from flask_cors import CORS
from pathlib import Path
from waitress import serve
import RAG_Solution_FIXED as rag  # Import your existing RAG code

app = Flask(__name__)
CORS(app)  # allow cross-origin calls from the browser frontend

# --- Load the RAG components once, at import time, so requests are fast ---
print("--- RAG Model Initializing ---")

# Directory containing the .txt knowledge-base files.
data_dir = Path("./data")

# 1. Load text files
print("[1/4] Loading text files...")
texts = rag.load_text_files(data_dir)
print(" ...Done.")

# 2. Build chunks
print("[2/4] Building text chunks...")
docs = rag.build_chunks(texts)
print(" ...Done.")

# 3. Build vector store (This can be slow)
print("[3/4] Building vector store with embeddings. This may take a moment...")
vs = rag.build_vectorstore(docs)
print(" ...Done.")

# 4. Load the generative model (This is the slowest part)
print("[4/4] Loading generative model (e.g., Flan-T5). This can be very slow...")
generator = rag.make_generator()
print(" ...Done.")

print("--- Model loading complete. Server is now starting. ---")
# ------------------------------------


@app.route('/ask', methods=['POST'])
def ask_question():
    """API endpoint: accepts {"question": ...} JSON and returns {"answer": ...}."""
    # get_json(silent=True) returns None instead of raising on a missing or
    # malformed JSON body, so bad requests get a clean 400 rather than a 500.
    payload = request.get_json(silent=True)
    if not payload or 'question' not in payload:
        return jsonify({'error': 'Missing question in request body'}), 400
    question = payload['question']
    try:
        answer, _, _ = rag.answer_question(vs, generator, question)
        return jsonify({'answer': answer.strip()})
    except Exception as e:
        # Surface the failure to the client rather than a bare 500 page.
        return jsonify({'error': str(e)}), 500


if __name__ == '__main__':
    # BUG FIX: serve() blocks until shutdown, so the old print placed after it
    # never executed — and it advertised port 3000 while the server actually
    # listens on 8765. Announce the real address before starting.
    print("Server starting on http://127.0.0.1:8765 and will be ready for requests.")
    serve(app, host="127.0.0.1", port=8765)
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core AI and Language Model Libraries
2
+ langchain
3
+ langchain-community
4
+ transformers
5
+ sentence-transformers
6
+ torch
7
+
8
+ # Vector Store Library (CPU version is best for cloud deployment)
9
+ faiss-cpu
10
+
11
+ # Web Server Libraries
12
+ flask
13
+ flask-cors
14
+ waitress
15
+
16
+ # Additional Utilities
17
+ langchain-text-splitters