Anurag Shirke committed on
Commit eefb354 · 1 Parent(s): 2688161

Adding main functionality for backend with endpoints (query, upload)

src/core/llm.py ADDED
@@ -0,0 +1,39 @@
+import ollama
+import os
+
+# --- Ollama Client Initialization ---
+
+def get_ollama_client():
+    """Initializes and returns the Ollama client."""
+    host = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
+    return ollama.Client(host=host)
+
+# --- Prompt Generation ---
+
+def format_prompt(query: str, context: list) -> str:
+    """Formats the prompt for the LLM with the retrieved context (Qdrant search results)."""
+    context_str = "\n".join([item.payload['text'] for item in context])
+    prompt = f"""**Instruction**:
+Answer the user's query based *only* on the provided context.
+If the context does not contain the answer, state that you cannot answer the question with the given information.
+Do not use any prior knowledge.
+
+**Context**:
+{context_str}
+
+**Query**:
+{query}
+
+**Answer**:
+"""
+    return prompt
+
+# --- LLM Interaction ---
+
+def generate_response(client: ollama.Client, model: str, prompt: str):
+    """Generates a response from the LLM."""
+    response = client.chat(
+        model=model,
+        messages=[{"role": "user", "content": prompt}]
+    )
+    return response['message']['content']
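A minimal sketch of exercising these helpers on their own, assuming the project root is on PYTHONPATH and a local Ollama server with the llama3 model is available; FakeHit is a hypothetical stand-in for a Qdrant search result (attribute-style payload access):

from src.core.llm import get_ollama_client, format_prompt, generate_response

class FakeHit:
    """Hypothetical stand-in for a Qdrant ScoredPoint: exposes a .payload dict."""
    def __init__(self, payload):
        self.payload = payload

context = [FakeHit({"text": "The warranty period is 24 months.", "source": "manual.pdf"})]
prompt = format_prompt("How long is the warranty?", context)

client = get_ollama_client()  # honours OLLAMA_HOST, defaults to localhost:11434
print(generate_response(client, "llama3", prompt))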
src/core/models.py ADDED
@@ -0,0 +1,8 @@
+from pydantic import BaseModel
+
+class QueryRequest(BaseModel):
+    query: str
+
+class QueryResponse(BaseModel):
+    answer: str
+    source_documents: list[dict]
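These Pydantic models define the /query request and response bodies; a small sketch with illustrative values (model_dump() is Pydantic v2, use .dict() on v1):

from src.core.models import QueryRequest, QueryResponse

req = QueryRequest(query="How long is the warranty?")
resp = QueryResponse(
    answer="The warranty period is 24 months.",
    source_documents=[{"source": "manual.pdf", "text": "The warranty period is 24 months.", "score": 0.83}],
)
print(req.model_dump(), resp.model_dump())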
src/core/processing.py ADDED
@@ -0,0 +1,25 @@
+import fitz  # PyMuPDF
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from sentence_transformers import SentenceTransformer
+
+def parse_pdf(file_path: str) -> str:
+    """Extracts text from a PDF file."""
+    doc = fitz.open(file_path)
+    text = ""
+    for page in doc:
+        text += page.get_text()
+    doc.close()
+    return text
+
+def chunk_text(text: str) -> list[str]:
+    """Splits text into smaller chunks."""
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=1000,
+        chunk_overlap=200,
+        length_function=len
+    )
+    return text_splitter.split_text(text)
+
+def get_embedding_model(model_name: str = 'all-MiniLM-L6-v2'):
+    """Loads the sentence-transformer model."""
+    return SentenceTransformer(model_name)
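A sketch of the ingestion steps these functions cover, assuming an illustrative local PDF path; all-MiniLM-L6-v2 produces 384-dimensional embeddings:

from src.core.processing import parse_pdf, chunk_text, get_embedding_model

text = parse_pdf("uploads/sample.pdf")   # illustrative path
chunks = chunk_text(text)                # ~1000-character chunks with 200-character overlap
model = get_embedding_model()            # all-MiniLM-L6-v2
embeddings = model.encode(chunks)        # numpy array of shape (len(chunks), 384)
print(len(chunks), embeddings.shape)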
src/core/vector_store.py ADDED
@@ -0,0 +1,47 @@
+from qdrant_client import QdrantClient, models
+import os
+import uuid
+
+# --- Qdrant Client Initialization ---
+
+def get_qdrant_client():
+    """Initializes and returns the Qdrant client."""
+    # Get Qdrant host from environment variable, default to localhost if not set
+    host = os.environ.get("QDRANT_HOST", "localhost")
+    client = QdrantClient(host=host, port=6333)
+    return client
+
+# --- Collection Management ---
+
+def create_collection_if_not_exists(client: QdrantClient, collection_name: str, vector_size: int):
+    """Creates a Qdrant collection if it doesn't already exist."""
+    try:
+        client.get_collection(collection_name=collection_name)
+    except Exception:  # get_collection raises if the collection does not exist
+        client.create_collection(
+            collection_name=collection_name,
+            vectors_config=models.VectorParams(size=vector_size, distance=models.Distance.COSINE),
+        )
+
+# --- Vector Operations ---
+
+def upsert_vectors(client: QdrantClient, collection_name: str, vectors, payloads):
+    """Upserts vectors and their payloads into the specified collection."""
+    client.upsert(
+        collection_name=collection_name,
+        points=models.Batch(
+            ids=[str(uuid.uuid4()) for _ in payloads],  # Batch requires explicit point IDs
+            vectors=vectors,
+            payloads=payloads
+        ),
+        wait=True
+    )
+
+def search_vectors(client: QdrantClient, collection_name: str, query_vector, limit: int = 5):
+    """Searches for similar vectors in the collection."""
+    return client.search(
+        collection_name=collection_name,
+        query_vector=query_vector,
+        limit=limit,
+        with_payload=True
+    )
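A sketch of the collection/upsert/search flow, assuming a Qdrant instance reachable on localhost:6333 (or via QDRANT_HOST) and 384-dimensional vectors; the values are placeholders:

from src.core.vector_store import get_qdrant_client, create_collection_if_not_exists, upsert_vectors, search_vectors

client = get_qdrant_client()
create_collection_if_not_exists(client, "knowledge_base", vector_size=384)

vectors = [[0.1] * 384, [0.2] * 384]     # placeholder embeddings
payloads = [{"text": "chunk one", "source": "a.pdf"}, {"text": "chunk two", "source": "a.pdf"}]
upsert_vectors(client, "knowledge_base", vectors, payloads)

hits = search_vectors(client, "knowledge_base", query_vector=[0.1] * 384, limit=2)
for hit in hits:
    print(hit.score, hit.payload["source"])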
src/main.py CHANGED
@@ -1,7 +1,109 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, UploadFile, File, HTTPException
+import shutil
+import os
+from .core.processing import parse_pdf, chunk_text, get_embedding_model
+from .core.vector_store import get_qdrant_client, create_collection_if_not_exists, upsert_vectors, search_vectors
+from .core.llm import get_ollama_client, format_prompt, generate_response
+from .core.models import QueryRequest, QueryResponse
 
 app = FastAPI()
 
+# --- Constants ---
+UPLOADS_DIR = "uploads"
+QDRANT_COLLECTION_NAME = "knowledge_base"
+OLLAMA_MODEL = "llama3"
+
+# --- Application Startup ---
+# Create uploads directory if it doesn't exist
+if not os.path.exists(UPLOADS_DIR):
+    os.makedirs(UPLOADS_DIR)
+
+# Load models and clients on startup
+embedding_model = get_embedding_model()
+qdrant_client = get_qdrant_client()
+ollama_client = get_ollama_client()
+
+# Get the size of the embeddings from the model
+embedding_size = embedding_model.get_sentence_embedding_dimension()
+
+# Create the Qdrant collection if it doesn't exist
+create_collection_if_not_exists(qdrant_client, QDRANT_COLLECTION_NAME, embedding_size)
+
+# --- API Endpoints ---
+@app.post("/upload")
+def upload_file(file: UploadFile = File(...)):
+    if not file.filename.lower().endswith(".pdf"):
+        raise HTTPException(status_code=400, detail="Invalid file type. Only PDFs are supported.")
+
+    file_path = os.path.join(UPLOADS_DIR, file.filename)
+
+    try:
+        with open(file_path, "wb") as buffer:
+            shutil.copyfileobj(file.file, buffer)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error saving file: {e}")
+
+    try:
+        text = parse_pdf(file_path)
+        if not text.strip():
+            raise HTTPException(status_code=400, detail="Could not extract text from the PDF.")
+
+        chunks = chunk_text(text)
+        embeddings = embedding_model.encode(chunks).tolist()  # plain lists for Qdrant
+        payloads = [{"text": chunk, "source": file.filename} for chunk in chunks]
+
+        upsert_vectors(qdrant_client, QDRANT_COLLECTION_NAME, embeddings, payloads)
+
+    except HTTPException:
+        raise  # preserve the intended status code (e.g. the 400 above)
+    except Exception as e:
+        os.remove(file_path)
+        raise HTTPException(status_code=500, detail=f"Error processing and storing file: {e}")
+    finally:
+        if os.path.exists(file_path):
+            os.remove(file_path)
+
+    return {
+        "filename": file.filename,
+        "message": "Successfully uploaded, processed, and stored.",
+        "num_chunks_stored": len(chunks)
+    }
+
+@app.post("/query", response_model=QueryResponse)
+def query_knowledge_base(request: QueryRequest):
+    try:
+        # 1. Embed the user's query
+        query_embedding = embedding_model.encode(request.query).tolist()
+
+        # 2. Search for relevant documents in Qdrant
+        search_results = search_vectors(
+            client=qdrant_client,
+            collection_name=QDRANT_COLLECTION_NAME,
+            query_vector=query_embedding,
+            limit=3  # Retrieve top 3 most relevant chunks
+        )
+
+        # 3. Format the prompt for the LLM
+        prompt = format_prompt(request.query, search_results)
+
+        # 4. Generate a response from the LLM
+        answer = generate_response(ollama_client, OLLAMA_MODEL, prompt)
+
+        # 5. Extract source documents for citation
+        source_documents = [
+            {
+                "source": result.payload["source"],
+                "text": result.payload["text"],
+                "score": result.score
+            }
+            for result in search_results
+        ]
+
+        return QueryResponse(answer=answer, source_documents=source_documents)
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error during query: {e}")
+
 @app.get("/health")
 def health_check():
     return {"status": "ok"}
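An end-to-end usage sketch for the two new endpoints, assuming the app is served locally with uvicorn src.main:app on port 8000 and manual.pdf is an illustrative file name:

import requests

BASE = "http://localhost:8000"  # assumes the API is running locally

# Upload a PDF so it gets parsed, chunked, embedded and stored in Qdrant.
with open("manual.pdf", "rb") as f:
    r = requests.post(f"{BASE}/upload", files={"file": ("manual.pdf", f, "application/pdf")})
print(r.json())

# Ask a question; the answer is grounded in the retrieved chunks.
r = requests.post(f"{BASE}/query", json={"query": "How long is the warranty?"})
print(r.json()["answer"])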