Spaces: Runtime error

Commit: Upload 3 files
- RAG_Solution_FIXED.py (added, +92 lines)
- app.py (added, +55 lines)
- requirements.txt (added, +17 lines)

RAG_Solution_FIXED.py (ADDED)
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
from pathlib import Path
from typing import List, Optional

from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
|
| 9 |
+
|
| 10 |
+
DEFAULT_EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
| 11 |
+
DEFAULT_GEN_MODEL = "google/flan-t5-base"
|
| 12 |
+
|
| 13 |
+
def load_text_files(data_dir: Path, files: Optional[List[Path]] = None) -> List[str]:
    """Read UTF-8 text from explicit *files*, or from every ``*.txt`` in *data_dir*.

    Args:
        data_dir: Directory scanned for ``*.txt`` files when *files* is not given.
        files: Optional explicit list of paths; when provided it overrides the
            *data_dir* scan entirely.

    Returns:
        One string per readable file (unreadable bytes are ignored), in glob-sorted
        order or in the order the caller supplied.

    Raises:
        FileNotFoundError: If no readable text file was found at all.
    """
    # Explicit file list wins over directory scanning.
    paths = files if files else sorted(Path(data_dir).glob("*.txt"))
    texts = [
        p.read_text(encoding="utf-8", errors="ignore")
        for p in paths
        if p.exists() and p.is_file()  # silently skip missing/non-file entries
    ]
    if not texts:
        raise FileNotFoundError(f"No text files found in {Path(data_dir).resolve()} or provided via --files")
    return texts
|
| 25 |
+
|
| 26 |
+
def build_chunks(texts: List[str], chunk_size: int = 800, chunk_overlap: int = 120):
    """Split raw document strings into overlapping retrieval chunks.

    The recursive splitter prefers paragraph, then line, then word boundaries
    before falling back to hard character cuts.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=["\n\n", "\n", " ", ""],
    )
    return splitter.create_documents(texts)
|
| 33 |
+
|
| 34 |
+
def build_vectorstore(docs):
    """Embed *docs* with the default sentence-transformer and index them in FAISS."""
    embedder = SentenceTransformerEmbeddings(model_name=DEFAULT_EMBED_MODEL)
    return FAISS.from_documents(docs, embedder)
|
| 38 |
+
|
| 39 |
+
def make_generator(model_name: str = DEFAULT_GEN_MODEL, device: int = -1):
    """Build a seq2seq text2text-generation pipeline (device=-1 selects CPU)."""
    return pipeline(
        "text2text-generation",
        model=AutoModelForSeq2SeqLM.from_pretrained(model_name),
        tokenizer=AutoTokenizer.from_pretrained(model_name),
        device=device,
    )
|
| 44 |
+
|
| 45 |
+
def format_prompt(question: str, contexts):
    """Assemble the grounded-QA prompt: instructions, context block, then question."""
    joined = "\n\n".join(doc.page_content for doc in contexts)
    instructions = (
        "You are an expert assistant. Use ONLY the context to answer.\n"
        "If the answer can't be found in the context, say you don't know.\n\n"
    )
    body = f"Context:\n{joined}\n\n" + f"Question: {question}\n" + "Answer:"
    return instructions + body
|
| 54 |
+
|
| 55 |
+
def answer_question(vs, generator, question: str, k: int = 4, max_new_tokens: int = 256):
    """Retrieve the top-*k* chunks for *question* and generate a grounded answer.

    Returns a tuple of (answer text, retrieved documents, source names).
    """
    hits = vs.similarity_search(question, k=k)
    generated = generator(
        format_prompt(question, hits),
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding keeps answers deterministic
    )
    answer = generated[0]["generated_text"]
    origins = [doc.metadata.get("source", "") for doc in hits]
    return answer, hits, origins
|
| 62 |
+
|
| 63 |
+
def main():
    """CLI entry point: build the index from disk and answer a single question."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--question", required=True, help="Question to ask the RAG system.")
    parser.add_argument("--data_dir", default="./data", help="Folder with .txt files.")
    parser.add_argument("--files", nargs="*", help="Specific files to use (overrides data_dir)")
    parser.add_argument("--k", type=int, default=4)
    parser.add_argument("--chunk_size", type=int, default=800)
    parser.add_argument("--chunk_overlap", type=int, default=120)
    parser.add_argument("--max_new_tokens", type=int, default=256)
    args = parser.parse_args()

    explicit_paths = [Path(f) for f in args.files] if args.files else None
    corpus = load_text_files(Path(args.data_dir), explicit_paths)
    chunks = build_chunks(corpus, chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap)
    store = build_vectorstore(chunks)
    gen = make_generator()

    answer, contexts, _sources = answer_question(
        store, gen, args.question, k=args.k, max_new_tokens=args.max_new_tokens
    )
    print("\n=== Answer ===\n", answer.strip())
    print("\n=== Top Sources (chunk previews) ===")
    for idx, doc in enumerate(contexts, 1):
        snippet = doc.page_content[:200].replace("\n", " ")
        print(f"[{idx}] {snippet}...")


if __name__ == "__main__":
    main()
|
app.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from flask import Flask, request, jsonify
from flask_cors import CORS
from pathlib import Path
from waitress import serve
import RAG_Solution_FIXED as rag  # Import your existing RAG code

app = Flask(__name__)
CORS(app)  # allow cross-origin requests from the frontend

# --- Load your RAG model components ---
# All model setup happens at import time so the server only accepts
# requests once everything below has finished loading.
print("--- RAG Model Initializing ---")

# Directory holding the .txt knowledge base
data_dir = Path("./data")

print("[1/4] Loading text files...")
texts = rag.load_text_files(data_dir)
print("   ...Done.")

print("[2/4] Building text chunks...")
docs = rag.build_chunks(texts)
print("   ...Done.")

print("[3/4] Building vector store with embeddings. This may take a moment...")
vs = rag.build_vectorstore(docs)
print("   ...Done.")

print("[4/4] Loading generative model (e.g., Flan-T5). This can be very slow...")
generator = rag.make_generator()
print("   ...Done.")

print("--- Model loading complete. Server is now starting. ---")
# ------------------------------------
| 39 |
+
@app.route('/ask', methods=['POST'])
def ask_question():
    """API endpoint that receives questions from the frontend.

    Expects a JSON body ``{"question": "..."}``. Returns ``{"answer": "..."}``
    on success, a JSON 400 for a missing/malformed body, or a JSON 500 with
    the error message if the RAG pipeline fails.
    """
    # silent=True returns None instead of raising on a missing/invalid JSON
    # body (e.g. wrong Content-Type), so the client always gets a clean JSON
    # 400 rather than Werkzeug's HTML error page.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'question' not in payload:
        return jsonify({'error': 'Missing question in request body'}), 400
    question = payload['question']
    try:
        answer, _, _ = rag.answer_question(vs, generator, question)
        return jsonify({'answer': answer.strip()})
    except Exception as e:
        # Surface the pipeline failure to the caller as JSON.
        return jsonify({'error': str(e)}), 500
|
| 51 |
+
|
| 52 |
+
if __name__ == '__main__':
    # serve() blocks until the process is stopped, so announce readiness
    # BEFORE starting it. The original printed afterwards (unreachable) and
    # named port 3000 while actually serving on 8765.
    print("Server has started on http://127.0.0.1:8765 and is ready for requests.")
    serve(app, host="127.0.0.1", port=8765)
|
requirements.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core AI and Language Model Libraries
|
| 2 |
+
langchain
|
| 3 |
+
langchain-community
|
| 4 |
+
transformers
|
| 5 |
+
sentence-transformers
|
| 6 |
+
torch
|
| 7 |
+
|
| 8 |
+
# Vector Store Library (CPU version is best for cloud deployment)
|
| 9 |
+
faiss-cpu
|
| 10 |
+
|
| 11 |
+
# Web Server Libraries
|
| 12 |
+
flask
|
| 13 |
+
flask-cors
|
| 14 |
+
waitress
|
| 15 |
+
|
| 16 |
+
# Additional Utilities
|
| 17 |
+
langchain-text-splitters
|