Spaces:
Runtime error
Runtime error
Upload main.py
Browse files
main.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# main.py (Final Version with Updated Response Structure)
|
| 2 |
+
import os
|
| 3 |
+
import chromadb
|
| 4 |
+
from fastapi import FastAPI
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
from sentence_transformers import SentenceTransformer
|
| 8 |
+
import tiktoken
|
| 9 |
+
from groq import Groq
|
| 10 |
+
from googleapiclient.discovery import build
|
| 11 |
+
|
| 12 |
+
# --- INITIALIZATION ---
|
| 13 |
+
# Load environment variables from a local .env file before any client reads them.
load_dotenv()
app = FastAPI()
groq_client = Groq()

# Credentials for the Google Custom Search fallback used by the endpoint.
google_api_key = os.getenv("GOOGLE_API_KEY")
search_engine_id = os.getenv("SEARCH_ENGINE_ID")
if not google_api_key or not search_engine_id:
    # Surface the misconfiguration at startup instead of only as a generic
    # failure on the first request that needs the fallback.
    print("WARNING: GOOGLE_API_KEY and/or SEARCH_ENGINE_ID is not set; the Google Search fallback is expected to fail.")
google_search_service = build("customsearch", "v1", developerKey=google_api_key)

# Embedding model used to encode incoming queries.
print("Loading embedding model: 'nomic-ai/nomic-embed-text-v1.5'...")
model = SentenceTransformer('nomic-ai/nomic-embed-text-v1.5', trust_remote_code=True)
print("Model loaded.")

# On-disk vector store; the "legal_docs" collection is fetched (not created) here.
client = chromadb.PersistentClient(path="./chroma_db")
collection = client.get_collection(name="legal_docs")
print(f"Connected to ChromaDB. Documents in collection: {collection.count()}")
|
| 28 |
+
|
| 29 |
+
# --- DATA MODELS ---
|
| 30 |
+
class QueryRequest(BaseModel):
    """Request body accepted by the research query endpoint."""

    # The user's question; it is embedded for the local search, sent to the
    # web search fallback, and inserted into the LLM prompt.
    query: str
|
| 32 |
+
|
| 33 |
+
# --- TOKENIZER FUNCTION ---
|
| 34 |
+
def count_tokens(text, model="gpt-4"):
    """Return the number of tokens in *text* under the tiktoken encoding for *model*.

    Args:
        text: The string to tokenize.
        model: Model name used to pick the tiktoken encoding.

    Returns:
        The token count as an int.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # tiktoken raises KeyError for model names it cannot map; fall back to
        # a general-purpose encoding so callers still get a usable estimate.
        encoding = tiktoken.get_encoding("cl100k_base")
    # By default tiktoken raises ValueError when the text contains a
    # special-token string such as "<|endoftext|>". User queries and retrieved
    # snippets are untrusted, so count such strings as ordinary text instead.
    return len(encoding.encode(text, disallowed_special=()))
|
| 37 |
+
|
| 38 |
+
# --- API ENDPOINT ---
|
| 39 |
+
# Prompt sent to the LLM; {context} and {query} are filled in per request.
_PROMPT_TEMPLATE = """
You are an expert legal AI assistant. Based ONLY on the following legal context, provide a detailed answer to the user's question. Structure your answer in multiple paragraphs and use bullet points for key points. Each bullet should be clear and informative. Do not use any outside knowledge.

CONTEXT:
{context}

QUESTION:
{query}

DETAILED ANSWER (use paragraphs and bullet points):
"""

# The best local hit must have a distance strictly below this value to be used.
_MAX_RELEVANT_DISTANCE = 0.7
# Upper bound, in tokens as measured by count_tokens, for the assembled prompt.
_PROMPT_TOKEN_LIMIT = 7000
# Separator appended after every chunk placed in the context.
_CHUNK_SEPARATOR = "\n\n---\n\n"
# Maximum number of citations returned to the caller.
_MAX_CITATIONS = 4


def _search_local(query):
    """Return (chunks, sources) from ChromaDB, or two empty lists if the best hit is not close enough."""
    query_embedding = model.encode(query).tolist()
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=4,
        include=['documents', 'metadatas', 'distances'],
    )
    distances = results['distances']
    if not (distances and distances[0] and distances[0][0] < _MAX_RELEVANT_DISTANCE):
        return [], []
    chunks = results['documents'][0]
    # A metadata entry may be missing or lack 'source'; treat that as "no
    # citation" rather than failing the whole request.
    sources = [(meta or {}).get('source', '') for meta in results['metadatas'][0]]
    return chunks, sources


def _search_web(query):
    """Return (snippets, links) from Google Custom Search; both lists are empty when there are no items."""
    search_results = google_search_service.cse().list(
        q=query, cx=search_engine_id, num=4
    ).execute()
    items = search_results.get('items') or []
    chunks = [item.get('snippet', '') for item in items]
    sources = [item.get('link', '') for item in items]
    return chunks, sources


def _build_context(passages, query):
    """Concatenate (chunk, source) passages in order while the full prompt fits the token limit.

    Returns (context, used_sources), where used_sources lists only the sources
    of the chunks that were actually included in the context.
    """
    context = ""
    used_sources = []
    for chunk, source in passages:
        candidate = context + chunk + _CHUNK_SEPARATOR
        prompt = _PROMPT_TEMPLATE.format(context=candidate, query=query)
        if count_tokens(prompt) > _PROMPT_TOKEN_LIMIT:
            # Stop at the first chunk that does not fit (chunks are in result order).
            break
        context = candidate
        used_sources.append(source)
    return context, used_sources


@app.post("/api/ai/research/query")
async def research_query(request: QueryRequest):
    """Answer a question from local ChromaDB documents, falling back to Google Search.

    Args:
        request: Body carrying the user's ``query`` string.

    Returns:
        ``{"answer": str, "citations": list[str]}`` on success (also used for
        the "nothing found" and "too long" messages, with empty citations), or
        ``{"error": str}`` if any step raises.

    NOTE(review): the embedding, ChromaDB, Google and Groq calls below look
    like synchronous client calls inside an ``async def`` handler; confirm
    whether they should be offloaded to a worker thread.
    """
    try:
        # 1. Search local ChromaDB first.
        chunks, sources = _search_local(request.query)
        if chunks:
            print("INFO: Found relevant documents in local ChromaDB.")
        else:
            # 2. Otherwise fall back to Google Custom Search.
            print("INFO: No relevant results found locally. Falling back to Google Search.")
            chunks, sources = _search_web(request.query)

        # Drop empty chunks so they neither pad the context nor earn a citation.
        passages = [(chunk, source) for chunk, source in zip(chunks, sources) if chunk]
        if not passages:
            return {"answer": "I could not find any relevant information to answer your question.", "citations": []}

        # 3. Fit as many chunks as the token budget allows, remembering which
        #    sources were actually used so the citations match the context.
        context, citations = _build_context(passages, request.query)
        if not context:
            return {"answer": "I found some information, but it was too long to process. Please try a more specific query.", "citations": []}

        final_prompt = _PROMPT_TEMPLATE.format(context=context, query=request.query)

        # 4. Generate the answer from the chosen context.
        completion = groq_client.chat.completions.create(
            model="openai/gpt-oss-120b",
            messages=[{"role": "user", "content": final_prompt}],
            temperature=0.2,
        )
        answer = completion.choices[0].message.content

        # De-duplicate while preserving order, drop blank entries, cap the count.
        citations = [source for source in dict.fromkeys(citations) if source][:_MAX_CITATIONS]
        return {"answer": answer, "citations": citations}

    except Exception as e:
        # Endpoint boundary: report a generic error to the client and log the cause.
        print(f"Error during query processing: {e}")
        return {"error": "Failed to process the request."}
|