Anurag Shirke committed on
Commit ·
eefb354
1
Parent(s): 2688161
Adding main functionality for backend with endpoints (query, upload)
Browse files- src/core/llm.py +39 -0
- src/core/models.py +8 -0
- src/core/processing.py +25 -0
- src/core/vector_store.py +46 -0
- src/main.py +101 -1
src/core/llm.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ollama
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
# --- Ollama Client Initialization ---
|
| 5 |
+
|
| 6 |
+
def get_ollama_client():
    """Create an Ollama client pointed at OLLAMA_HOST (default http://localhost:11434)."""
    return ollama.Client(host=os.environ.get("OLLAMA_HOST", "http://localhost:11434"))
|
| 10 |
+
|
| 11 |
+
# --- Prompt Generation ---
|
| 12 |
+
|
| 13 |
+
def format_prompt(query: str, context: list) -> str:
    """Build a context-grounded prompt for the LLM from retrieved chunks.

    Bug fix: the /query endpoint passes raw Qdrant search results, whose
    items expose ``.payload`` as an attribute, while this function indexed
    them as dicts (``item['payload']``) — a guaranteed TypeError at runtime.
    Both shapes are now accepted.

    Args:
        query: The user's question.
        context: Retrieved items — either dicts with a ``'payload'`` key or
            objects with a ``.payload`` attribute; each payload holds ``'text'``.

    Returns:
        The formatted prompt string.
    """
    def _payload(item):
        # Plain dicts (tests, manual callers) vs. ScoredPoint-like objects.
        return item['payload'] if isinstance(item, dict) else item.payload

    context_str = "\n".join(_payload(item)['text'] for item in context)
    prompt = f"""**Instruction**:
Answer the user's query based *only* on the provided context.
If the context does not contain the answer, state that you cannot answer the question with the given information.
Do not use any prior knowledge.

**Context**:
{context_str}

**Query**:
{query}

**Answer**:
"""
    return prompt
|
| 30 |
+
|
| 31 |
+
# --- LLM Interaction ---
|
| 32 |
+
|
| 33 |
+
def generate_response(client: ollama.Client, model: str, prompt: str):
    """Send *prompt* to *model* via the Ollama chat API and return the reply text."""
    messages = [{"role": "user", "content": prompt}]
    reply = client.chat(model=model, messages=messages)
    return reply['message']['content']
|
src/core/models.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
|
| 3 |
+
class QueryRequest(BaseModel):
    """Request body for POST /query: the user's natural-language question."""
    query: str
|
| 5 |
+
|
| 6 |
+
class QueryResponse(BaseModel):
    """Response body for POST /query.

    ``answer`` is the LLM-generated reply; ``source_documents`` lists the
    retrieved chunks used as context (source, text, score per item).
    """
    answer: str
    source_documents: list[dict]
|
src/core/processing.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import fitz # PyMuPDF
|
| 2 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 3 |
+
from sentence_transformers import SentenceTransformer
|
| 4 |
+
|
| 5 |
+
def parse_pdf(file_path: str) -> str:
    """Extract the plain text of every page of the PDF at *file_path*.

    Fixes: the original leaked the document handle if ``get_text()`` raised
    (no close on the error path) and built the result with quadratic string
    concatenation. A context manager guarantees the handle is released, and
    ``join`` assembles the text in one pass.

    Args:
        file_path: Filesystem path to a PDF readable by PyMuPDF.

    Returns:
        Concatenated text of all pages (may be empty for image-only PDFs).
    """
    with fitz.open(file_path) as doc:
        return "".join(page.get_text() for page in doc)
|
| 13 |
+
|
| 14 |
+
def chunk_text(text: str) -> list[str]:
    """Split *text* into overlapping chunks (1000 chars, 200-char overlap)."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )
    return splitter.split_text(text)
|
| 22 |
+
|
| 23 |
+
def get_embedding_model(model_name: str = 'all-MiniLM-L6-v2'):
    """Load and return a sentence-transformers embedding model.

    Args:
        model_name: Hugging Face model identifier; defaults to
            'all-MiniLM-L6-v2'.
    """
    return SentenceTransformer(model_name)
|
src/core/vector_store.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from qdrant_client import QdrantClient, models
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
# --- Qdrant Client Initialization ---
|
| 5 |
+
|
| 6 |
+
def get_qdrant_client():
    """Initialize a Qdrant client from environment configuration.

    Generalization: the port was hard-coded to 6333; it is now also
    configurable via ``QDRANT_PORT`` (default unchanged), matching how
    ``QDRANT_HOST`` already works.

    Returns:
        A connected ``QdrantClient`` (defaults: localhost:6333).
    """
    host = os.environ.get("QDRANT_HOST", "localhost")
    port = int(os.environ.get("QDRANT_PORT", "6333"))
    return QdrantClient(host=host, port=port)
|
| 12 |
+
|
| 13 |
+
# --- Collection Management ---
|
| 14 |
+
|
| 15 |
+
def create_collection_if_not_exists(client: QdrantClient, collection_name: str, vector_size: int):
    """Create *collection_name* with cosine-distance vectors of *vector_size* if it is absent.

    NOTE(review): the broad ``except Exception`` treats ANY failure of
    ``get_collection`` — including connectivity errors — as "collection
    missing" and then attempts creation; consider ``client.collection_exists``
    if the installed qdrant-client version supports it — TODO confirm.
    """
    try:
        client.get_collection(collection_name=collection_name)
    except Exception:  # get_collection raises when the collection does not exist
        client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(size=vector_size, distance=models.Distance.COSINE),
        )
|
| 24 |
+
|
| 25 |
+
# --- Vector Operations ---
|
| 26 |
+
|
| 27 |
+
def upsert_vectors(client: QdrantClient, collection_name: str, vectors, payloads):
    """Upsert vectors and their payloads into *collection_name*.

    Bug fix: ``models.Batch`` requires explicit point ids — passing
    ``ids=None`` fails Pydantic validation; Qdrant does not assign ids for
    you. A UUID string is now generated per point. Embeddings coming from
    ``SentenceTransformer.encode`` are numpy arrays, so they are converted
    to plain lists for serialization.

    Args:
        client: Connected Qdrant client.
        collection_name: Target collection.
        vectors: Sequence (or numpy array) of embedding vectors.
        payloads: One payload dict per vector (same length/order).
    """
    import uuid  # local import keeps this fix self-contained within the block

    # numpy arrays don't JSON-serialize; normalize to nested lists.
    if hasattr(vectors, "tolist"):
        vectors = vectors.tolist()

    client.upsert(
        collection_name=collection_name,
        points=models.Batch(
            ids=[str(uuid.uuid4()) for _ in vectors],
            vectors=vectors,
            payloads=payloads,
        ),
        wait=True,
    )
|
| 38 |
+
|
| 39 |
+
def search_vectors(client: QdrantClient, collection_name: str, query_vector, limit: int = 5):
    """Return up to *limit* nearest points to *query_vector*, payloads included."""
    return client.search(
        collection_name=collection_name,
        query_vector=query_vector,
        with_payload=True,
        limit=limit,
    )
|
src/main.py
CHANGED
|
@@ -1,7 +1,107 @@
|
|
| 1 |
-
from fastapi import FastAPI
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
app = FastAPI()
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
@app.get("/health")
|
| 6 |
def health_check():
|
| 7 |
return {"status": "ok"}
|
|
|
|
| 1 |
+
from fastapi import FastAPI, UploadFile, File, HTTPException
|
| 2 |
+
import shutil
|
| 3 |
+
import os
|
| 4 |
+
from .core.processing import parse_pdf, chunk_text, get_embedding_model
|
| 5 |
+
from .core.vector_store import get_qdrant_client, create_collection_if_not_exists, upsert_vectors, search_vectors
|
| 6 |
+
from .core.llm import get_ollama_client, format_prompt, generate_response
|
| 7 |
+
from .core.models import QueryRequest, QueryResponse
|
| 8 |
|
| 9 |
app = FastAPI()
|
| 10 |
|
| 11 |
+
# --- Constants ---
UPLOADS_DIR = "uploads"                    # scratch dir for uploaded PDFs
QDRANT_COLLECTION_NAME = "knowledge_base"  # single shared vector collection
OLLAMA_MODEL = "llama3"                    # chat model used for answers

# --- Application Startup ---
# Ensure the uploads directory exists. exist_ok=True replaces the racy
# exists()-then-makedirs() pattern (TOCTOU) and is idempotent.
os.makedirs(UPLOADS_DIR, exist_ok=True)

# Load the embedding model and service clients once at import time so
# request handlers don't pay the initialization cost per call.
embedding_model = get_embedding_model()
qdrant_client = get_qdrant_client()
ollama_client = get_ollama_client()

# The collection's vector size must match the embedding dimensionality.
embedding_size = embedding_model.get_sentence_embedding_dimension()

create_collection_if_not_exists(qdrant_client, QDRANT_COLLECTION_NAME, embedding_size)
|
| 31 |
+
|
| 32 |
+
# --- API Endpoints ---
|
| 33 |
+
@app.post("/upload")
def upload_file(file: UploadFile = File(...)):
    """Accept a PDF upload, extract/chunk/embed its text, and store it in Qdrant.

    Raises:
        HTTPException 400: non-PDF upload, or a PDF with no extractable text.
        HTTPException 500: failure while saving, processing, or storing.

    Bug fixes vs. the original:
      * The 400 raised for an empty PDF was caught by ``except Exception``
        and re-raised as a misleading 500; HTTPException is now re-raised
        untouched.
      * The client-supplied filename is reduced to its basename to prevent
        path traversal out of UPLOADS_DIR.
      * Cleanup is handled once in ``finally`` (the original also removed
        the file in the except branch, then again in finally).
    """
    if not file.filename.lower().endswith(".pdf"):
        raise HTTPException(status_code=400, detail="Invalid file type. Only PDFs are supported.")

    # basename() strips any directory components an attacker might embed.
    file_path = os.path.join(UPLOADS_DIR, os.path.basename(file.filename))

    # Persist the upload to disk so PyMuPDF can open it by path.
    try:
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error saving file: {e}")

    try:
        text = parse_pdf(file_path)
        if not text.strip():
            raise HTTPException(status_code=400, detail="Could not extract text from the PDF.")

        chunks = chunk_text(text)
        embeddings = embedding_model.encode(chunks)
        payloads = [{"text": chunk, "source": file.filename} for chunk in chunks]

        upsert_vectors(qdrant_client, QDRANT_COLLECTION_NAME, embeddings, payloads)
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 400 above) as-is.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing and storing file: {e}")
    finally:
        # Always remove the temp file, on success or failure.
        if os.path.exists(file_path):
            os.remove(file_path)

    return {
        "filename": file.filename,
        "message": "Successfully uploaded, processed, and stored.",
        "num_chunks_stored": len(chunks),
    }
|
| 69 |
+
|
| 70 |
+
@app.post("/query", response_model=QueryResponse)
def query_knowledge_base(request: QueryRequest):
    """Answer a question with RAG: embed the query, retrieve from Qdrant, generate via Ollama.

    Raises:
        HTTPException 500: any failure in the embed/search/generate pipeline.

    Bug fixes vs. the original:
      * An HTTPException raised inside the pipeline would have been wrapped
        into a generic 500; it is now re-raised untouched.
      * ``encode`` returns a numpy array; it is converted to a plain list so
        it serializes cleanly for the Qdrant client.

    NOTE(review): ``format_prompt`` indexes context items as dicts while
    ``search_vectors`` returns objects with a ``.payload`` attribute —
    verify the two agree in the deployed version.
    """
    try:
        # 1. Embed the user's query.
        query_embedding = embedding_model.encode(request.query)
        if hasattr(query_embedding, "tolist"):
            query_embedding = query_embedding.tolist()

        # 2. Retrieve the most relevant chunks from the knowledge base.
        search_results = search_vectors(
            client=qdrant_client,
            collection_name=QDRANT_COLLECTION_NAME,
            query_vector=query_embedding,
            limit=3,  # top 3 most relevant chunks
        )

        # 3. Build the grounded prompt for the LLM.
        prompt = format_prompt(request.query, search_results)

        # 4. Generate the answer.
        answer = generate_response(ollama_client, OLLAMA_MODEL, prompt)

        # 5. Surface the retrieved chunks as citations.
        source_documents = [
            {
                "source": result.payload["source"],
                "text": result.payload["text"],
                "score": result.score,
            }
            for result in search_results
        ]

        return QueryResponse(answer=answer, source_documents=source_documents)

    except HTTPException:
        raise  # keep deliberate HTTP errors instead of masking them as 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during query: {e}")
|
| 104 |
+
|
| 105 |
@app.get("/health")
def health_check():
    """Liveness probe: always reports the service as up."""
    return {"status": "ok"}
|