# audit_assistant/src/gemini/file_search.py
# Author: Ara Yeroyan — "refactor + add gemini" (commit 72eb0bf)
"""
Gemini File Search Client
Handles interaction with Google Gemini File Search API for RAG.
"""
import os
import logging
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
logger = logging.getLogger(__name__)
try:
from google import genai
from google.genai import types
GEMINI_AVAILABLE = True
except ImportError:
GEMINI_AVAILABLE = False
@dataclass
class GeminiFileSearchResult:
    """Result from Gemini File Search query.

    Attributes:
        answer: Generated answer text from the model.
        sources: List of document references (dicts with "content",
            "filename", "score") extracted from grounding metadata.
        grounding_metadata: Grounding metadata from the response, if any.
        query: The original user query that produced this result.
    """
    answer: str
    sources: List[Dict[str, Any]]  # List of document references
    # NOTE(review): the client stores the raw SDK grounding-metadata object
    # here, not necessarily a dict — the annotation is approximate.
    grounding_metadata: Optional[Dict[str, Any]] = None
    query: str = ""
class GeminiFileSearchClient:
    """Client for interacting with Gemini File Search API.

    Wraps the ``google-genai`` SDK to issue grounded (RAG) queries against
    a pre-built File Search store and normalizes the response into a
    :class:`GeminiFileSearchResult`.
    """

    def __init__(self, api_key: Optional[str] = None, store_name: Optional[str] = None):
        """
        Initialize Gemini File Search client.

        Args:
            api_key: Gemini API key (defaults to GEMINI_API_KEY env var)
            store_name: File search store name (defaults to GEMINI_FILESTORE_NAME env var)

        Raises:
            ImportError: If the google-genai package is not installed.
            ValueError: If the API key or store name cannot be resolved.
        """
        if not GEMINI_AVAILABLE:
            raise ImportError("google-genai package not installed. Install with: pip install google-genai")
        self.api_key = api_key or os.getenv("GEMINI_API_KEY")
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY not found. Set it in .env file or pass as argument.")
        self.store_name = store_name or os.getenv("GEMINI_FILESTORE_NAME")
        if not self.store_name:
            raise ValueError("GEMINI_FILESTORE_NAME not found. Set it in .env file or pass as argument.")
        self.client = genai.Client(api_key=self.api_key)
        self.model = "gemini-2.5-flash"  # or "gemini-2.5-pro"

    @staticmethod
    def _as_str_list(value: Any) -> List[str]:
        """Coerce a scalar or list filter value into a list of strings.

        Fix: the original joined raw values with ``', '.join(...)``, which
        raises TypeError when a caller passes non-string values (e.g. a
        year filter as an int).
        """
        values = value if isinstance(value, list) else [value]
        return [str(v) for v in values]

    def _build_filter_context(self, filters: Optional[Dict[str, Any]]) -> str:
        """Render optional filters as a natural-language prompt suffix.

        Gemini File Search doesn't support explicit filters in the API,
        so they are appended to the query as guidance text.
        """
        if not filters:
            return ""
        # (filter key, display label) — keys match what callers pass today.
        labels = [
            ("year", "Year"),
            ("sources", "Source"),
            ("district", "District"),
            ("filenames", "Filename"),
        ]
        filter_parts = [
            f"{label}: {', '.join(self._as_str_list(filters[key]))}"
            for key, label in labels
            if filters.get(key)
        ]
        if not filter_parts:
            return ""
        return f"\n\nPlease focus on documents matching these criteria: {', '.join(filter_parts)}"

    def _call_api(self, model: str, full_query: str) -> Any:
        """Invoke ``generate_content`` with the File Search tool attached.

        Tries the documented typed-config format first, then falls back to
        a raw dict tool spec for SDK variants that lack the typed classes.

        Raises:
            RuntimeError: If both call formats fail (cause chained).
        """
        try:
            # Documented format:
            # https://ai.google.dev/gemini-api/docs/file-search
            return self.client.models.generate_content(
                model=model,
                contents=full_query,
                config=types.GenerateContentConfig(
                    tools=[
                        types.Tool(
                            file_search=types.FileSearch(
                                file_search_store_names=[self.store_name]
                            )
                        )
                    ]
                )
            )
        except (AttributeError, TypeError) as e:
            # Fallback: try alternative (plain-dict) format.
            logger.warning(f"Primary API format failed, trying alternative: {e}")
            try:
                return self.client.models.generate_content(
                    model=model,
                    contents=full_query,
                    tools=[{
                        "file_search": {
                            "file_search_store_names": [self.store_name]
                        }
                    }]
                )
            except Exception as e2:
                # Chain the cause so the original traceback is preserved.
                raise RuntimeError(f"Failed to call Gemini API: {e2}") from e2

    @staticmethod
    def _extract_answer(response: Any) -> str:
        """Pull the answer text out of a ``generate_content`` response.

        Handles the convenience ``.text`` accessor, candidate/parts
        structures, and falls back to ``str(response)``.
        """
        if hasattr(response, 'text'):
            # .text may be None when the response carried no text part.
            return response.text or ""
        if hasattr(response, 'candidates') and response.candidates:
            answer = ""
            candidate = response.candidates[0]
            if hasattr(candidate, 'content') and candidate.content:
                if hasattr(candidate.content, 'parts'):
                    answer = " ".join(
                        part.text for part in candidate.content.parts
                        if hasattr(part, 'text')
                    )
                elif isinstance(candidate.content, str):
                    answer = candidate.content
            return answer
        return str(response)

    @staticmethod
    def _extract_sources(grounding_metadata: Any) -> List[Dict[str, Any]]:
        """Convert grounding chunks into uniform source dicts.

        Handles both the typed-object and plain-dict response formats the
        SDK may return. Chunks with neither text nor filename are skipped;
        per-chunk extraction errors are logged and skipped (best-effort).
        """
        grounding_chunks = None
        if hasattr(grounding_metadata, 'grounding_chunks'):
            grounding_chunks = grounding_metadata.grounding_chunks
        elif isinstance(grounding_metadata, dict) and 'grounding_chunks' in grounding_metadata:
            grounding_chunks = grounding_metadata['grounding_chunks']

        sources: List[Dict[str, Any]] = []
        if not grounding_chunks:
            return sources
        for chunk in grounding_chunks:
            try:
                if isinstance(chunk, dict):
                    chunk_data = chunk
                else:
                    # Object format — normalize to dict-like access.
                    chunk_data = {}
                    if hasattr(chunk, 'chunk'):
                        chunk_obj = chunk.chunk
                        chunk_data['chunk'] = {
                            'text': getattr(chunk_obj, 'text', ''),
                            'file_name': getattr(chunk_obj, 'file_name', '')
                        }
                    if hasattr(chunk, 'relevance_score'):
                        score_obj = chunk.relevance_score
                        chunk_data['relevance_score'] = {
                            'score': getattr(score_obj, 'score', 0.0)
                        }
                chunk_info = chunk_data.get('chunk', {})
                text = chunk_info.get('text', '') if isinstance(chunk_info, dict) else ''
                file_name = chunk_info.get('file_name', '') if isinstance(chunk_info, dict) else ''
                score_data = chunk_data.get('relevance_score', {})
                score = score_data.get('score', 0.0) if isinstance(score_data, dict) else 0.0
                if text or file_name:  # Only add if we have content
                    sources.append({
                        "content": text,
                        "filename": file_name,
                        "score": score,
                    })
            except Exception as e:
                logger.warning(f"Error extracting chunk info: {e}")
                continue
        return sources

    def search(
        self,
        query: str,
        filters: Optional[Dict[str, Any]] = None,
        model: Optional[str] = None
    ) -> GeminiFileSearchResult:
        """
        Search using Gemini File Search.

        Args:
            query: User query
            filters: Optional filters (year, source, district, etc.)
            model: Model to use (defaults to gemini-2.5-flash)

        Returns:
            GeminiFileSearchResult with answer and sources. On failure,
            the result carries an apology message and empty sources
            (this method never raises).
        """
        model = model or self.model
        filter_context = self._build_filter_context(filters)
        # Explicit instruction to only use information from retrieved documents.
        instruction = "\n\nIMPORTANT: Only use information from the retrieved documents. Do not use information from your training data unless it's explicitly mentioned in the retrieved documents. If the retrieved documents don't contain the requested information, clearly state that.\n\n"
        full_query = query + filter_context + instruction
        try:
            response = self._call_api(model, full_query)
            answer = self._extract_answer(response)

            # Extract grounding metadata (document references).
            sources: List[Dict[str, Any]] = []
            grounding_metadata = None
            if hasattr(response, 'candidates') and response.candidates:
                candidate = response.candidates[0]
                if hasattr(candidate, 'grounding_metadata'):
                    grounding_metadata = candidate.grounding_metadata
                    sources = self._extract_sources(grounding_metadata)

            return GeminiFileSearchResult(
                answer=answer,
                sources=sources,
                grounding_metadata=grounding_metadata,
                query=query
            )
        except Exception as e:
            # Best-effort contract: log and return an error result instead
            # of raising (the original swallowed the traceback silently).
            logger.exception("Gemini File Search query failed")
            return GeminiFileSearchResult(
                answer=f"I apologize, but I encountered an error: {str(e)}",
                sources=[],
                query=query
            )

    def format_sources_for_display(self, result: GeminiFileSearchResult) -> List[Any]:
        """
        Format Gemini sources to match the format expected by the UI.

        Returns list of Document objects compatible with existing display
        code. ``page``/``year``/``district`` are filled with None because
        Gemini grounding chunks do not carry that metadata.
        """
        # Local import: langchain is only needed for UI formatting.
        from langchain.docstore.document import Document
        formatted_sources = []
        for i, source in enumerate(result.sources):
            # Create a Document object compatible with existing code.
            doc = Document(
                page_content=source.get("content", ""),
                metadata={
                    "filename": source.get("filename", "Unknown"),
                    "source": "Gemini File Search",
                    "score": source.get("score"),
                    "chunk_index": i,
                    # Default fields that downstream display code may expect.
                    "page": None,
                    "year": None,
                    "district": None,
                }
            )
            formatted_sources.append(doc)
        return formatted_sources