Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,7 +5,7 @@ import chromadb
|
|
| 5 |
import tiktoken
|
| 6 |
from groq import Groq
|
| 7 |
from googleapiclient.discovery import build
|
| 8 |
-
from
|
| 9 |
|
| 10 |
# --- INITIALIZATION ---
|
| 11 |
app = FastAPI()
|
|
@@ -14,95 +14,44 @@ groq_client = Groq()
|
|
| 14 |
# Environment variables
|
| 15 |
google_api_key = os.getenv("GOOGLE_API_KEY")
|
| 16 |
search_engine_id = os.getenv("SEARCH_ENGINE_ID")
|
| 17 |
-
NOMIC_API_KEY = os.getenv("NOMIC_API_KEY")
|
| 18 |
|
| 19 |
# Google Custom Search setup
|
| 20 |
google_search_service = build("customsearch", "v1", developerKey=google_api_key)
|
| 21 |
|
| 22 |
-
#
|
| 23 |
-
|
| 24 |
-
print("
|
| 25 |
|
| 26 |
# Connect to local ChromaDB
|
| 27 |
client = chromadb.PersistentClient(path="./chroma_db")
|
| 28 |
collection = client.get_collection(name="legal_docs")
|
| 29 |
-
print(f"Connected to ChromaDB. Documents in collection: {collection.count()}")
|
| 30 |
|
| 31 |
-
# ---
|
| 32 |
class QueryRequest(BaseModel):
|
| 33 |
query: str
|
| 34 |
|
| 35 |
-
# ---
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
# 2. Check if the best result is relevant enough
|
| 56 |
-
SIMILARITY_THRESHOLD = 0.7
|
| 57 |
-
if results['distances'] and results['distances'][0] and results['distances'][0][0] < SIMILARITY_THRESHOLD:
|
| 58 |
-
print("INFO: Found relevant documents in local ChromaDB.")
|
| 59 |
-
context_chunks = results['documents'][0]
|
| 60 |
-
citations = [meta['source'] for meta in results['metadatas'][0]]
|
| 61 |
-
else:
|
| 62 |
-
# 3. Fallback to Google Custom Search
|
| 63 |
-
print("INFO: No relevant results locally. Using Google Search.")
|
| 64 |
-
search_results = google_search_service.cse().list(
|
| 65 |
-
q=request.query, cx=search_engine_id, num=3
|
| 66 |
-
).execute()
|
| 67 |
-
|
| 68 |
-
if not search_results.get('items'):
|
| 69 |
-
return {"answer": "I could not find any relevant information.", "citations": []}
|
| 70 |
-
|
| 71 |
-
context_chunks = [item.get('snippet', '') for item in search_results.get('items', [])]
|
| 72 |
-
citations = [item.get('link', '') for item in search_results.get('items', [])]
|
| 73 |
-
|
| 74 |
-
# 4. Build context respecting token limits
|
| 75 |
-
prompt_template = """
|
| 76 |
-
You are an expert legal AI assistant. Based ONLY on the following legal context, provide a concise answer.
|
| 77 |
-
CONTEXT: {context}
|
| 78 |
-
QUESTION: {query}
|
| 79 |
-
CONCISE ANSWER:
|
| 80 |
-
"""
|
| 81 |
-
TOKEN_LIMIT = 7000
|
| 82 |
-
for chunk in context_chunks:
|
| 83 |
-
temp_context = context + chunk + "\n\n---\n\n"
|
| 84 |
-
temp_prompt = prompt_template.format(context=temp_context, query=request.query)
|
| 85 |
-
if count_tokens(temp_prompt) <= TOKEN_LIMIT:
|
| 86 |
-
context = temp_context
|
| 87 |
-
else:
|
| 88 |
-
break
|
| 89 |
-
|
| 90 |
-
if not context:
|
| 91 |
-
return {"answer": "Information found is too long to process. Try a more specific query.", "citations": []}
|
| 92 |
-
|
| 93 |
-
final_prompt = prompt_template.format(context=context, query=request.query)
|
| 94 |
-
|
| 95 |
-
# 5. Generate answer with Groq AI
|
| 96 |
-
completion = groq_client.chat.completions.create(
|
| 97 |
-
model="openai/gpt-oss-120b",
|
| 98 |
-
messages=[{"role": "user", "content": final_prompt}],
|
| 99 |
-
temperature=0.2
|
| 100 |
)
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
return {"answer": answer, "citations": list(set(citations))}
|
| 105 |
|
| 106 |
-
|
| 107 |
-
print(f"Error during query processing: {e}")
|
| 108 |
-
return {"error": "Failed to process the request."}
|
|
|
|
| 5 |
import tiktoken
|
| 6 |
from groq import Groq
|
| 7 |
from googleapiclient.discovery import build
|
| 8 |
+
from sentence_transformers import SentenceTransformer
|
| 9 |
|
| 10 |
# --- INITIALIZATION ---

app = FastAPI()

# Credentials are supplied via environment variables (HF Space secrets).
google_api_key = os.getenv("GOOGLE_API_KEY")
search_engine_id = os.getenv("SEARCH_ENGINE_ID")

# Client handle for the Google Custom Search JSON API.
google_search_service = build("customsearch", "v1", developerKey=google_api_key)

# Local sentence-embedding model used to embed incoming queries
# (runs inside the Space; no external embedding API required).
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
print("Loaded SentenceTransformer embeddings.")

# Persistent ChromaDB store holding the pre-indexed legal documents.
client = chromadb.PersistentClient(path="./chroma_db")
collection = client.get_collection(name="legal_docs")
|
|
|
| 28 |
|
| 29 |
+
# --- REQUEST MODEL ---

# Schema for the POST /query request body.
class QueryRequest(BaseModel):
    # Natural-language question to run against the legal-document index.
    query: str
| 32 |
|
| 33 |
+
# --- API ROUTES ---

@app.post("/query")
async def query_api(request: QueryRequest):
    """Answer a query from the local ChromaDB index, falling back to
    Google Custom Search when no local documents match.

    Args:
        request: Parsed request body carrying the user's query string.

    Returns:
        dict: ``{"answer": <response text>}`` — either the matching local
        documents or a list of external search results.
    """
    # Embed the query with the same local model used to index the corpus.
    query_embedding = embedder.encode(request.query).tolist()

    # Nearest-neighbour search in ChromaDB.
    results = collection.query(query_embeddings=[query_embedding], n_results=3)

    # Chroma returns one inner list per query embedding, e.g.
    # {"documents": [[...]]}. An *empty* result is [[]], which is truthy,
    # so the inner list must be checked — otherwise the Google fallback
    # below is unreachable.
    documents = results.get("documents") or []
    if documents and documents[0]:
        context_docs = [doc for sublist in documents for doc in sublist]
        response_text = f"Relevant documents found:\n{context_docs}"
    else:
        # fallback to Google search
        response_text = "No relevant local docs found. Searching externally..."
        google_results = (
            google_search_service.cse()
            .list(q=request.query, cx=search_engine_id, num=3)
            .execute()
        )
        response_text += "\n\nExternal sources:\n"
        # Use .get() so a search item missing 'title' or 'link' cannot
        # crash the endpoint with a KeyError.
        for item in google_results.get("items", []):
            response_text += f"- {item.get('title', '')}: {item.get('link', '')}\n"

    return {"answer": response_text}
|
|
|
|
|