Aditya20040422 commited on
Commit
f802a5e
·
verified ·
1 Parent(s): baf004f

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +113 -0
main.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py (Final Version with Updated Response Structure)
2
+ import os
3
+ import chromadb
4
+ from fastapi import FastAPI
5
+ from pydantic import BaseModel
6
+ from dotenv import load_dotenv
7
+ from sentence_transformers import SentenceTransformer
8
+ import tiktoken
9
+ from groq import Groq
10
+ from googleapiclient.discovery import build
11
+
12
# --- INITIALIZATION ---
# Load secrets from a local .env file (GOOGLE_API_KEY, SEARCH_ENGINE_ID, and
# presumably GROQ_API_KEY) before anything below reads the environment.
load_dotenv()
app = FastAPI()
# NOTE(review): Groq() takes no explicit key here — it appears to rely on
# GROQ_API_KEY being set in the environment; confirm.
groq_client = Groq()

# Google Custom Search client, used as the web-search fallback for queries
# the local vector store cannot answer.
google_api_key = os.getenv("GOOGLE_API_KEY")
search_engine_id = os.getenv("SEARCH_ENGINE_ID")
google_search_service = build("customsearch", "v1", developerKey=google_api_key)

# Embedding model used to embed incoming queries for the vector search.
# trust_remote_code=True lets the model repo's own modeling code execute.
print("Loading embedding model: 'nomic-ai/nomic-embed-text-v1.5'...")
model = SentenceTransformer('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True)
print("Model loaded.")

# Persistent ChromaDB collection of pre-ingested documents.
# NOTE(review): get_collection (not get_or_create_collection) raises if
# "legal_docs" does not exist — ingestion must run before this service starts.
client = chromadb.PersistentClient(path="./chroma_db")
collection = client.get_collection(name="legal_docs")
print(f"Connected to ChromaDB. Documents in collection: {collection.count()}")
28
+
29
# --- DATA MODELS ---
class QueryRequest(BaseModel):
    """JSON request body for POST /api/ai/research/query."""

    # Free-text research question from the user.
    query: str
32
+
33
# --- TOKENIZER FUNCTION ---
# Cache of model name -> tiktoken encoding. Resolving an encoding is not
# free, and count_tokens is called once per context chunk in the packing
# loop, so memoize the lookup instead of repeating it on every call.
_ENCODING_CACHE = {}


def _get_encoding(model):
    """Resolve and memoize the tiktoken encoding for *model*."""
    try:
        return _ENCODING_CACHE[model]
    except KeyError:
        pass
    try:
        enc = tiktoken.encoding_for_model(model)
    except KeyError:
        # tiktoken raises KeyError for model names it does not know;
        # fall back to cl100k_base (the gpt-4 family encoding) so token
        # counting still works instead of crashing the request.
        enc = tiktoken.get_encoding("cl100k_base")
    _ENCODING_CACHE[model] = enc
    return enc


def count_tokens(text, model="gpt-4"):
    """Return the number of tokens *text* occupies under *model*'s tokenizer.

    Args:
        text: The string to tokenize.
        model: OpenAI-style model name whose encoding to use (default "gpt-4").

    Returns:
        int: token count of *text*.
    """
    return len(_get_encoding(model).encode(text))
37
+
38
# --- API ENDPOINT ---
@app.post("/api/ai/research/query")
async def research_query(request: QueryRequest):
    """Answer a research question via retrieval-augmented generation.

    Pipeline:
      1. Embed the query and search the local ChromaDB collection.
      2. If the best local hit is not similar enough, fall back to Google
         Custom Search snippets.
      3. Greedily pack context chunks into the prompt up to a token budget
         and ask the Groq-hosted LLM for an answer grounded in that context.

    Returns:
        {"answer": str, "citations": [str, ...]} on success — citations are
        deduplicated (order-preserving) and capped at 4;
        {"error": str} if anything in the pipeline fails.
    """
    try:
        context = ""
        citations = []

        # 1. Search local ChromaDB first.
        query_embedding = model.encode(request.query).tolist()
        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=4,
            include=['documents', 'metadatas', 'distances'],
        )

        # 2. Relevance gate: Chroma reports *distances* (lower == closer),
        # so the best local hit qualifies only when its distance is below
        # the threshold.
        SIMILARITY_THRESHOLD = 0.7
        if results['distances'] and results['distances'][0] and results['distances'][0][0] < SIMILARITY_THRESHOLD:
            print("INFO: Found relevant documents in local ChromaDB.")
            context_chunks = results['documents'][0]
            citations = [meta['source'] for meta in results['metadatas'][0]]
        else:
            # 3. Fall back to Google Custom Search when local recall is weak.
            print("INFO: No relevant results found locally. Falling back to Google Search.")
            search_results = google_search_service.cse().list(
                q=request.query, cx=search_engine_id, num=4
            ).execute()

            # Fetch the item list once instead of re-calling .get() per use.
            items = search_results.get('items')
            if not items:
                return {"answer": "I could not find any relevant information to answer your question.", "citations": []}

            context_chunks = [item.get('snippet', '') for item in items]
            citations = [item.get('link', '') for item in items]

        prompt_template = """
You are an expert legal AI assistant. Based ONLY on the following legal context, provide a detailed answer to the user's question. Structure your answer in multiple paragraphs and use bullet points for key points. Each bullet should be clear and informative. Do not use any outside knowledge.

CONTEXT:
{context}

QUESTION:
{query}

DETAILED ANSWER (use paragraphs and bullet points):
"""

        # Greedily append chunks while the fully-rendered prompt stays
        # within the token budget; stop at the first chunk that overflows.
        TOKEN_LIMIT = 7000
        for chunk in context_chunks:
            candidate = context + chunk + "\n\n---\n\n"
            candidate_prompt = prompt_template.format(context=candidate, query=request.query)
            if count_tokens(candidate_prompt) <= TOKEN_LIMIT:
                context = candidate
            else:
                break

        if not context:
            # Even the very first chunk alone exceeded the budget.
            return {"answer": "I found some information, but it was too long to process. Please try a more specific query.", "citations": []}

        final_prompt = prompt_template.format(context=context, query=request.query)

        # 4. Generate the answer from the packed context.
        completion = groq_client.chat.completions.create(
            model="openai/gpt-oss-120b",
            messages=[{"role": "user", "content": final_prompt}],
            temperature=0.2,  # low temperature: stay close to the supplied context
        )
        answer = completion.choices[0].message.content

        # Deduplicate citations preserving first-seen order, cap at 4.
        citations = list(dict.fromkeys(citations))[:4]
        return {"answer": answer, "citations": citations}

    except Exception as e:
        # API-boundary catch-all: log and return a generic error payload
        # rather than leaking a stack trace to the client.
        print(f"Error during query processing: {e}")
        return {"error": "Failed to process the request."}