Jayanthk2004 committed on
Commit
78e09e0
·
0 Parent(s):

RAG-Implementation

Browse files
.gitignore ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ .venv/
7
+ venv/
8
+ .env
9
+
10
+ # VSCode
11
+ .vscode/
12
+
13
+ # Streamlit
14
+ frontend/.streamlit/
README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # RAG Mini Project
2
+
3
+ Project structure and setup instructions.
backend/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Package marker
backend/chunker.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nltk.tokenize import sent_tokenize
2
+ import nltk
3
+ import re
4
+
5
# Ensure an NLTK sentence tokenizer is available at import time.
# Prefer the newer punkt_tab resource; fall back to the legacy punkt model.
# Bare `except:` is avoided — it would also swallow KeyboardInterrupt/SystemExit.
try:
    nltk.download('punkt_tab')
except Exception:
    # punkt_tab unavailable (old NLTK or no network for that resource)
    try:
        nltk.download('punkt')
    except Exception:
        # chunk_text() has a regex fallback, so this is a warning, not fatal
        print("Warning: NLTK punkt tokenizer not available")
14
+
15
def chunk_text(text, chunk_size=200, overlap=50):
    """
    Split *text* into sentence-aligned chunks with word-based overlap.

    Args:
        text (str): Input text to chunk
        chunk_size (int): Maximum words per chunk
        overlap (int): Number of overlapping words carried between chunks

    Returns:
        list: List of text chunks
    """
    try:
        # Preferred path: NLTK's trained sentence tokenizer.
        sentences = sent_tokenize(text)
    except LookupError:
        # Tokenizer data missing — fall back to naive punctuation splitting.
        raw_parts = re.split(r'[.!?]+', text)
        sentences = [part.strip() for part in raw_parts if part.strip()]

    chunks = []
    buffer = []        # sentences (plus any overlap text) of the chunk in progress
    buffer_words = 0   # running word count of `buffer`

    for sentence in sentences:
        sentence_words = len(sentence.split())

        if buffer_words + sentence_words <= chunk_size:
            # Sentence still fits — keep accumulating.
            buffer.append(sentence)
            buffer_words += sentence_words
            continue

        # Sentence would overflow the budget: flush the chunk in progress.
        if buffer:
            chunks.append(" ".join(buffer))

        if overlap > 0 and buffer:
            # Seed the next chunk with the trailing `overlap` words for context.
            tail = " ".join(buffer).split()[-overlap:]
            buffer = [" ".join(tail)]
            buffer_words = len(tail)
        else:
            buffer = []
            buffer_words = 0

        # The overflowing sentence starts (or continues) the new chunk.
        buffer.append(sentence)
        buffer_words += sentence_words

    # Flush whatever remains.
    if buffer:
        chunks.append(" ".join(buffer))

    return chunks
71
+
72
# Alternative chunking function without NLTK dependency
def chunk_text_simple(text, chunk_size=200, overlap=50):
    """
    Simple sliding-window text chunking without NLTK dependency.

    Args:
        text (str): Input text to chunk
        chunk_size (int): Maximum words per chunk
        overlap (int): Number of overlapping words between chunks

    Returns:
        list: List of text chunks
    """
    words = text.split()
    chunks = []

    # Advance by (chunk_size - overlap) words each step; max(1, ...) guards
    # against overlap >= chunk_size, which would otherwise never advance.
    step = max(1, chunk_size - overlap)

    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        # Fix: once a window reaches the end of the text, stop. The old loop
        # kept stepping and emitted a final chunk made entirely of words
        # already contained in the previous chunk (pure overlap residue).
        if start + chunk_size >= len(words):
            break

    return chunks
backend/embed_utils.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # backend/embed_utils.py
2
+
3
+ from sentence_transformers import SentenceTransformer
4
+ from backend.qdrant_client import qdrant_client, COLLECTION_NAME
5
+ from qdrant_client.http.models import PointStruct
6
+ from backend.chunker import chunk_text
7
+ from backend.llm_client import query_llm
8
+ import uuid
9
+
10
+ model = SentenceTransformer("all-MiniLM-L6-v2")
11
+
12
def process_document(text):
    """Chunk *text*, embed each chunk, and upsert the vectors into Qdrant."""
    pieces = chunk_text(text)
    embeddings = model.encode(pieces)

    stored_points = []
    for piece, embedding in zip(pieces, embeddings):
        stored_points.append(
            PointStruct(
                id=str(uuid.uuid4()),          # random unique point id
                vector=embedding.tolist(),
                payload={"text": piece},       # keep raw text for retrieval
            )
        )

    qdrant_client.upsert(collection_name=COLLECTION_NAME, points=stored_points)
20
+
21
def get_answer(question):
    """Embed *question*, fetch the 3 closest chunks, and ask the LLM with them."""
    question_vector = model.encode([question])[0].tolist()
    hits = qdrant_client.search(
        collection_name=COLLECTION_NAME, query_vector=question_vector, limit=3
    )
    retrieved_texts = [hit.payload["text"] for hit in hits]
    context = "\n".join(retrieved_texts)
    return query_llm(prompt=question, context=context)
backend/llm_client.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ from dotenv import load_dotenv
3
+ import os
4
+
5
+ load_dotenv()
6
+
7
+ api_key = os.getenv("OPENROUTER_API_KEY")
8
+ model_name = os.getenv("OPENROUTER_MODEL", "sarvamai/sarvam-m:free")
9
+
10
+ client = OpenAI(base_url="https://openrouter.ai/api/v1", api_key=api_key)
11
+
12
def query_llm(prompt: str, context: str = "") -> str:
    """Send *prompt* (optionally grounded in *context*) to the chat model.

    Returns the model's reply text.
    """
    user_content = f"Context:\n{context}\n\nQuestion:\n{prompt}"
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_content},
    ]
    completion = client.chat.completions.create(model=model_name, messages=messages)
    return completion.choices[0].message.content
backend/main.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from fastapi.responses import JSONResponse
3
+ from backend.embed_utils import process_document, get_answer
4
+ import os
5
+
6
+ app = FastAPI()
7
+
8
@app.post("/upload")
def upload_doc(file: UploadFile = File(...)):
    """Accept a UTF-8 text upload, chunk/embed it, and store it in the vector DB."""
    try:
        raw_bytes = file.file.read()
        content = raw_bytes.decode("utf-8")
        process_document(content)
    except UnicodeDecodeError:
        # Binary / non-UTF-8 uploads are a client error, not a server failure.
        raise HTTPException(status_code=400, detail="❌ File must be UTF-8 encoded plain text.")
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"⚠️ Internal Server Error: {str(e)}"})
    return {"status": "✅ Document processed and stored in vector DB."}
18
+
19
@app.get("/ask")
def ask_question(question: str):
    """Answer *question* from stored documents; 500 with an error payload on failure."""
    try:
        return {"answer": get_answer(question)}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": f"⚠️ Failed to retrieve answer: {str(e)}"})
backend/qdrant_client.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from qdrant_client import QdrantClient
2
+ from qdrant_client.http.models import VectorParams, Distance
3
+ import os
4
+ from dotenv import load_dotenv
5
+
6
load_dotenv()

# Endpoint and credentials come from the environment so deployments can point
# at a different cluster without code changes.
QDRANT_HOST = os.getenv(
    "QDRANT_HOST",
    "https://af4d46cf-7554-4390-a899-d26487a92023.eu-central-1-0.aws.cloud.qdrant.io"
)
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
COLLECTION_NAME = "rag_collection"

qdrant_client = QdrantClient(
    url=QDRANT_HOST,
    api_key=QDRANT_API_KEY,
    prefer_grpc=False,          # REST keeps firewalls/proxies simple
    check_compatibility=False,
)

# Optional: Print connection test
print("✅ Connected to Qdrant via REST")

# Create the collection only if it is missing. Fix: use create_collection
# instead of the deprecated, destructive recreate_collection — the guard
# already ensures we never touch an existing collection, and recreate would
# drop all stored vectors if the guard were ever bypassed.
existing_collections = qdrant_client.get_collections().collections
if COLLECTION_NAME not in [col.name for col in existing_collections]:
    qdrant_client.create_collection(
        collection_name=COLLECTION_NAME,
        # 384 dims matches the all-MiniLM-L6-v2 embedding model used upstream
        vectors_config=VectorParams(size=384, distance=Distance.COSINE),
    )
frontend/app.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Streamlit UI
import streamlit as st
import requests
import nltk

# NOTE(review): punkt looks unused here (chunking happens in the backend) —
# kept to avoid changing runtime behavior; confirm and remove.
nltk.download('punkt')

API_BASE = "http://localhost:8000"

st.title("📄 RAG Mini App")

st.subheader("Upload Document")
uploaded_file = st.file_uploader("Upload a .txt file", type=["txt"])

if uploaded_file:
    content = uploaded_file.read().decode("utf-8")
    response = requests.post(f"{API_BASE}/upload", files={"file": ("doc.txt", content)})
    # Fix: report the backend's actual outcome instead of unconditionally
    # claiming success even on a 4xx/5xx response.
    if response.ok:
        st.success("Uploaded and processed.")
    else:
        st.error(f"Upload failed ({response.status_code}): {response.text}")

st.subheader("Ask a Question")
question = st.text_input("Enter your question")

if st.button("Get Answer") and question:
    response = requests.get(f"{API_BASE}/ask", params={"question": question})
    payload = response.json()
    # Fix: the backend returns {"answer": ...} on success but {"error": ...}
    # on failure; indexing "answer" blindly raised KeyError on errors.
    if response.ok and "answer" in payload:
        st.markdown("**Answer:**")
        st.write(payload["answer"])
    else:
        st.error(payload.get("error", "Failed to retrieve an answer."))
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Add your Python dependencies here
2
+ fastapi
3
+ uvicorn
4
+ qdrant-client
5
+ sentence-transformers
6
+ streamlit
7
+ openai
8
+ python-dotenv
9
+ requests
10
+ nltk
11
+ python-multipart
12
+