audit_assistant / src /gemini /file_search.py
akryldigital's picture
langchain.docstore Document
a533525 verified
"""
Gemini File Search Client
Handles interaction with Google Gemini File Search API for RAG.
"""
import logging
import os
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)

# google-genai is optional at import time: GeminiFileSearchClient.__init__ checks
# GEMINI_AVAILABLE and raises a descriptive ImportError when the SDK is missing.
try:
    from google import genai
    from google.genai import types
    GEMINI_AVAILABLE = True
except ImportError:
    GEMINI_AVAILABLE = False

# Try the legacy langchain import path first, then langchain_core (the location
# used by newer LangChain releases). Only ImportError triggers the fallback, so
# KeyboardInterrupt/SystemExit and genuine bugs are no longer swallowed.
try:
    from langchain.docstore.document import Document
except ImportError:
    from langchain_core.documents import Document
@dataclass
class GeminiFileSearchResult:
    """Result of a single Gemini File Search query.

    Attributes:
        answer: Answer text produced for the query. When the request fails,
            GeminiFileSearchClient.search puts an apology/error message here.
        sources: Document references extracted from the response; each dict
            is built with the keys "content", "filename", "score", "file_uri".
        grounding_metadata: Grounding metadata of the first response candidate,
            stored exactly as the SDK returned it (an SDK object, not
            necessarily a plain dict), or None when unavailable.
        query: The user's original query, without system instructions or
            filter hints.
    """
    answer: str
    sources: List[Dict[str, Any]]  # List of document references
    # Raw SDK grounding metadata (any type) or None; see class docstring.
    grounding_metadata: Optional[Any] = None
    query: str = ""
class GeminiFileSearchClient:
    """Client for answering questions with the Gemini File Search tool (RAG).

    ``search`` sends the user question, wrapped in fixed instructions and
    optional filter hints, to ``client.models.generate_content`` with a
    ``FileSearch`` tool pointing at ``self.store_name``. It then extracts the
    answer text and the cited grounding chunks. ``format_sources_for_display``
    converts those chunks into ``Document`` objects for the UI.
    """

    # Model used when neither the constructor nor ``search`` names one.
    DEFAULT_MODEL = "gemini-2.5-flash"  # or "gemini-2.5-pro"

    # Resource-name prefix that the API expects in front of a store ID.
    _STORE_PREFIX = "fileSearchStores/"

    # (filters key, label used in the prompt); also the order of the hints.
    _FILTER_LABELS = (
        ("year", "Year"),
        ("sources", "Source"),
        ("district", "District"),
        ("filenames", "Filename"),
    )

    # Instructions prepended to every query. They travel inside the prompt text;
    # no separate system_instruction is configured on the request.
    _SYSTEM_INSTRUCTIONS = (
        "You are a helpful audit report assistant specialized in analyzing government audit reports from Uganda's Office of the Auditor General.\n"
        "CRITICAL RULES:\n"
        "1. **NO HALLUCINATION**: Only use information that is explicitly stated in the retrieved documents. Do not make up facts, numbers, or details.\n"
        "2. **Document References**: Always cite which documents you're using with [Doc i] references at the end of sentences that use specific information.\n"
        "3. **Formatting**: Structure your response with clear paragraphs, bullet points, or sections for readability.\n"
        "4. **Accuracy**: If the retrieved documents don't contain the requested information, explicitly state \"The retrieved documents do not contain information about [topic].\"\n"
        "5. **Years and Data**: Pay careful attention to years mentioned in documents. If a user asks about a specific year but documents show different years, explicitly state this.\n"
        "6. **District/Source Names**: Use the exact district and source names as they appear in the document metadata (e.g., \"Kalangala\" not \"Kalagala\").\n"
        "7. **Financial Data**: When providing financial figures, include the currency (UGX) and be precise about amounts.\n"
        "8. **Conversational Tone**: Be helpful, clear, and conversational while maintaining accuracy.\n"
        "IMPORTANT: Only use information from the retrieved documents. Do not use information from your training data unless it's explicitly mentioned in the retrieved documents."
    )

    # Fallbacks for recovering a file name from a chunk's string representation.
    _PDF_NAME_RE = re.compile(r"([A-Za-z0-9\s_-]+\.pdf)")
    _TITLE_RE = re.compile(r"title=['\"]([^'\"]+)['\"]")
    # First whole-word 20xx year in a file name.
    _YEAR_RE = re.compile(r"\b(20\d{2})\b")

    # (substring of the file name, district, source name); checked in order,
    # first hit wins.
    _KNOWN_SOURCES = (
        ("Kalangala", "Kalangala", "Kalangala DLG"),
        ("Gulu", "Gulu", "Gulu DLG"),
        ("KCCA", "Kampala", "KCCA"),
        ("MAAIF", None, "MAAIF"),
        ("MWTS", None, "MWTS"),
        ("Consolidated", None, "Consolidated"),
    )

    def __init__(
        self,
        api_key: Optional[str] = None,
        store_name: Optional[str] = None,
        model: Optional[str] = None,
    ):
        """Initialize the Gemini File Search client.

        Args:
            api_key: Gemini API key (defaults to the GEMINI_API_KEY env var).
            store_name: File search store name or bare store ID (defaults to the
                GEMINI_FILESTORE_NAME env var). Surrounding whitespace is ignored.
            model: Default model for ``search`` (defaults to DEFAULT_MODEL).

        Raises:
            ImportError: google-genai is not installed.
            ValueError: the API key or store name is missing.
        """
        if not GEMINI_AVAILABLE:
            raise ImportError("google-genai package not installed. Install with: pip install google-genai")

        self.api_key = api_key or os.getenv("GEMINI_API_KEY")
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY not found. Set it in .env file or pass as argument.")

        # Strip so a stray space/newline in .env does not become part of the name.
        store_name_raw = (store_name or os.getenv("GEMINI_FILESTORE_NAME") or "").strip()
        if not store_name_raw:
            raise ValueError("GEMINI_FILESTORE_NAME not found. Set it in .env file or pass as argument.")

        # Normalize to the full resource path (fileSearchStores/<id>).
        if store_name_raw.startswith(self._STORE_PREFIX):
            self.store_name = store_name_raw
        else:
            self.store_name = f"{self._STORE_PREFIX}{store_name_raw}"
        logger.info("📦 Using file search store: %s", self.store_name)

        self.client = genai.Client(api_key=self.api_key)
        self.model = model or self.DEFAULT_MODEL

    def search(
        self,
        query: str,
        filters: Optional[Dict[str, Any]] = None,
        model: Optional[str] = None,
    ) -> "GeminiFileSearchResult":
        """Answer ``query`` using Gemini with the File Search tool.

        Args:
            query: User question.
            filters: Optional hints under the keys "year", "sources", "district"
                and "filenames" (scalar or list values, any type). They are
                not sent as API filters; they are appended to the prompt text.
            model: Model name; defaults to ``self.model``.

        Returns:
            GeminiFileSearchResult with the answer and the extracted sources.
            API and parsing failures are logged and reported as an apology
            answer with no sources instead of being raised.
        """
        model = model or self.model
        filter_context = self._build_filter_context(filters)
        full_query = (
            f"{self._SYSTEM_INSTRUCTIONS}\n\nUser Question: {query}{filter_context}"
            "\n\nPlease provide a detailed, well-formatted response with proper document references."
        )

        try:
            response = self._generate(model, full_query)
            answer = self._extract_answer(response)
            sources, grounding_metadata = self._extract_sources(response)
            logger.info("✅ Total sources extracted: %d", len(sources))
            return GeminiFileSearchResult(
                answer=answer,
                sources=sources,
                grounding_metadata=grounding_metadata,
                query=query,
            )
        except Exception as e:
            # Boundary: surface the error to the user, but keep the traceback in logs.
            logger.exception("Gemini File Search query failed")
            return GeminiFileSearchResult(
                answer=f"I apologize, but I encountered an error: {str(e)}",
                sources=[],
                query=query,
            )

    @classmethod
    def _build_filter_context(cls, filters: Optional[Dict[str, Any]]) -> str:
        """Render ``filters`` as a prompt suffix, or "" when there is nothing to add."""
        if not filters:
            return ""
        filter_parts = []
        for key, label in cls._FILTER_LABELS:
            value = filters.get(key)
            if not value:
                continue
            values = list(value) if isinstance(value, (list, tuple, set)) else [value]
            # str() each value: years are often ints, and join() rejects non-str.
            filter_parts.append(f"{label}: {', '.join(str(v) for v in values)}")
        if not filter_parts:
            return ""
        return f"\n\nPlease focus on documents matching these criteria: {', '.join(filter_parts)}"

    @staticmethod
    def _file_search_config(store_name: str) -> Any:
        """Build a typed GenerateContentConfig enabling File Search on ``store_name``."""
        return types.GenerateContentConfig(
            tools=[
                types.Tool(
                    file_search=types.FileSearch(file_search_store_names=[store_name])
                )
            ]
        )

    def _generate(self, model: str, contents: str) -> Any:
        """Call generate_content with File Search, retrying once on failure.

        A first attempt uses the full store path. If that fails with an error
        mentioning "format", "invalid" or "too long", the call is retried with
        the bare store ID. Otherwise it is retried with a plain-dict config.

        Raises:
            Exception: both attempts failed (message includes the errors).
        """
        store_name = self.store_name
        try:
            return self.client.models.generate_content(
                model=model,
                contents=contents,
                config=self._file_search_config(store_name),
            )
        except Exception as api_error:
            error_str = str(api_error).lower()
            if any(marker in error_str for marker in ("format", "invalid", "too long")):
                # NOTE(review): "invalid" also matches unrelated errors; those
                # simply cost one extra failing call with the bare ID.
                logger.warning("Full path format failed, trying with just store ID: %s", api_error)
                if store_name.startswith(self._STORE_PREFIX):
                    store_name = store_name.split("/", 1)[1]
                try:
                    return self.client.models.generate_content(
                        model=model,
                        contents=contents,
                        config=self._file_search_config(store_name),
                    )
                except Exception as e2:
                    raise Exception(
                        f"Failed to call Gemini API with both formats. Full path error: {api_error}, ID-only error: {e2}"
                    ) from e2

            logger.warning("Primary API format failed, trying alternative: %s", api_error)
            try:
                return self.client.models.generate_content(
                    model=model,
                    contents=contents,
                    # generate_content accepts tools only via ``config`` (a bare
                    # ``tools=`` kwarg is a TypeError); a dict config is allowed.
                    config={"tools": [{"file_search": {"file_search_store_names": [store_name]}}]},
                )
            except Exception as e2:
                raise Exception(f"Failed to call Gemini API: {e2}") from e2

    @staticmethod
    def _extract_answer(response: Any) -> str:
        """Return the answer text of ``response``.

        Prefers ``response.text``. When that is empty/None, joins the non-empty
        text parts of the first candidate. Returns "" when the response exposes
        ``text`` but nothing is found, and ``str(response)`` for objects with
        neither ``text`` nor candidates.
        """
        try:
            text = response.text
        except (AttributeError, ValueError):
            text = None
        if text:
            return text

        candidates = getattr(response, "candidates", None)
        if candidates:
            content = getattr(candidates[0], "content", None)
            if isinstance(content, str):
                return content
            parts = (getattr(content, "parts", None) or []) if content else []
            # Skip parts whose text is None/empty (e.g. non-text parts).
            return " ".join(part.text for part in parts if getattr(part, "text", None))

        if hasattr(response, "text"):
            return ""
        return str(response)

    @staticmethod
    def _get_grounding_chunks(grounding_metadata: Any) -> Any:
        """Return grounding chunks from SDK-object or dict metadata, or None."""
        if hasattr(grounding_metadata, "grounding_chunks"):
            chunks = grounding_metadata.grounding_chunks
            logger.info(" Found grounding_chunks (attr): %d", len(chunks) if chunks else 0)
            return chunks
        if isinstance(grounding_metadata, dict) and "grounding_chunks" in grounding_metadata:
            chunks = grounding_metadata["grounding_chunks"]
            logger.info(" Found grounding_chunks (dict): %d", len(chunks) if chunks else 0)
            return chunks
        metadata_dict = getattr(grounding_metadata, "__dict__", None)
        if isinstance(metadata_dict, dict) and "grounding_chunks" in metadata_dict:
            chunks = metadata_dict["grounding_chunks"]
            logger.info(" Found grounding_chunks (__dict__): %d", len(chunks) if chunks else 0)
            return chunks
        return None

    @classmethod
    def _chunk_to_source(cls, chunk: Any) -> Optional[Dict[str, Any]]:
        """Convert one grounding chunk (dict or SDK object) into a source dict.

        Returns a dict with "content", "filename", "score" (float, 0.0 when
        absent or non-numeric) and "file_uri". Returns None when the chunk
        yields neither text nor a file name.
        """
        if isinstance(chunk, dict):
            chunk_data = chunk
        else:
            # Object format: copy the fields we understand into dict form.
            chunk_data = {}
            if hasattr(chunk, "chunk"):
                inner = chunk.chunk
                chunk_data["chunk"] = {
                    "text": getattr(inner, "text", ""),
                    "file_name": getattr(inner, "file_name", ""),
                }
            if hasattr(chunk, "relevance_score"):
                chunk_data["relevance_score"] = {
                    "score": getattr(chunk.relevance_score, "score", 0.0)
                }

        chunk_info = chunk_data.get("chunk") or {}
        if not isinstance(chunk_info, dict):
            chunk_info = {}
        text = chunk_info.get("text") or ""
        file_name = chunk_info.get("file_name") or ""
        file_uri = chunk_info.get("file_uri") or ""

        web = getattr(chunk, "web", None)
        if web:
            file_uri = getattr(web, "file_uri", "") or file_uri
            file_name = getattr(web, "title", "") or getattr(web, "filename", "") or file_name
            text = getattr(web, "text", "") or getattr(web, "content", "") or text

        # retrieved_context is where File Search responses carry the chunk data.
        rc = getattr(chunk, "retrieved_context", None)
        if rc:
            text = getattr(rc, "text", "") or text
            file_name = getattr(rc, "document_name", "") or file_name

        if not file_name:
            # Last resort: scrape a name out of the chunk's repr.
            chunk_str = str(chunk)
            pdf_match = cls._PDF_NAME_RE.search(chunk_str)
            if pdf_match:
                file_name = pdf_match.group(1)
            else:
                title_match = cls._TITLE_RE.search(chunk_str)
                if title_match:
                    file_name = title_match.group(1)
        if not file_name and file_uri:
            file_name = file_uri.rsplit("/", 1)[-1]

        score_data = chunk_data.get("relevance_score") or {}
        raw_score = score_data.get("score") if isinstance(score_data, dict) else None
        try:
            score = float(raw_score)
        except (TypeError, ValueError):
            score = 0.0  # None/non-numeric score would otherwise break formatting

        if not (text or file_name):
            return None
        return {
            "content": text,
            "filename": file_name,
            "score": score,
            "file_uri": file_uri,
        }

    def _extract_sources(self, response: Any) -> tuple:
        """Return (sources, grounding_metadata) taken from the first candidate.

        Sources come from the grounding chunks. When none are found, the
        fallback reads ``file_data`` URIs from the candidate's content parts.
        """
        sources: List[Dict[str, Any]] = []
        grounding_metadata = None
        logger.info("🔍 Extracting sources from Gemini response...")

        candidates = getattr(response, "candidates", None)
        if not candidates:
            return sources, grounding_metadata
        candidate = candidates[0]
        logger.info(" Found candidate, checking for grounding_metadata...")

        if hasattr(candidate, "grounding_metadata"):
            grounding_metadata = candidate.grounding_metadata
            logger.info(" Found grounding_metadata: %s", type(grounding_metadata))
            grounding_chunks = self._get_grounding_chunks(grounding_metadata)
            if grounding_chunks:
                logger.info(" Processing %d grounding chunks...", len(grounding_chunks))
                for idx, chunk in enumerate(grounding_chunks):
                    try:
                        source = self._chunk_to_source(chunk)
                    except Exception as e:
                        # Best effort: one malformed chunk must not drop the rest.
                        logger.warning("Error extracting chunk %d info: %s", idx + 1, e)
                        logger.debug("Traceback for chunk %d", idx + 1, exc_info=True)
                        continue
                    if source is None:
                        continue
                    sources.append(source)
                    logger.info(
                        "📄 Extracted source %d: %s (score: %.3f, content length: %d)",
                        idx + 1,
                        source["filename"],
                        source["score"],
                        len(source["content"]),
                    )
            else:
                logger.warning(" No grounding_chunks found in grounding_metadata")
        else:
            logger.warning(" Candidate does not have grounding_metadata attribute")

        if not sources:
            # Sometimes file references appear as file_data parts instead.
            logger.info(" No sources from grounding_metadata, trying alternative extraction...")
            content = getattr(candidate, "content", None)
            parts = (getattr(content, "parts", None) or []) if content else []
            for part in parts:
                file_data = getattr(part, "file_data", None)
                if isinstance(file_data, dict):
                    file_uri = file_data.get("file_uri")
                else:
                    file_uri = getattr(file_data, "file_uri", None)
                if not file_uri:
                    continue
                file_name = file_uri.rsplit("/", 1)[-1]
                sources.append({
                    "content": "",
                    "filename": file_name,
                    "score": 0.0,
                    "file_uri": file_uri,
                })
                logger.info("📄 Extracted source from file_data: %s", file_name)

        return sources, grounding_metadata

    @classmethod
    def _parse_filename_metadata(cls, filename: str) -> tuple:
        """Return (year, district, source_name) guessed from a report file name.

        e.g. "Kalangala DLG Report of Auditor General 2021.pdf" gives
        (2021, "Kalangala", "Kalangala DLG"). Unknown names give
        "Gemini File Search" as the source name.
        """
        year_match = cls._YEAR_RE.search(filename)
        year = int(year_match.group(1)) if year_match else None
        for needle, district, source_name in cls._KNOWN_SOURCES:
            if needle in filename:
                return year, district, source_name
        return year, None, "Gemini File Search"

    def format_sources_for_display(self, result: "GeminiFileSearchResult") -> List[Any]:
        """Convert ``result.sources`` into ``Document`` objects for the UI.

        Each Document carries the chunk text as ``page_content``. Its metadata
        holds the filename, source, score, position-based ids, and a year and
        district parsed from the filename. Gemini provides no page numbers.
        """
        formatted_sources = []
        for i, source in enumerate(result.sources):
            # ``or`` also covers an empty/None filename, not only a missing key.
            filename = str(source.get("filename") or "Unknown")
            year, district, source_name = self._parse_filename_metadata(filename)
            doc = Document(
                page_content=source.get("content") or "",
                metadata={
                    "filename": filename,
                    "source": source_name,
                    "score": source.get("score"),
                    "chunk_index": i,
                    "page": None,  # Gemini doesn't provide page numbers
                    "year": year,
                    "district": district,
                    "chunk_id": f"gemini_{i}",
                    "_id": f"gemini_{i}",
                },
            )
            formatted_sources.append(doc)
            logger.info("📋 Formatted source %d: %s (%s, %s)", i + 1, filename, year, source_name)
        logger.info("✅ Formatted %d sources for display", len(formatted_sources))
        return formatted_sources