# Page-header residue from the source this file was copied from:
# author: mtyrrell | commit message: "refactor" | commit: f5bde8f
import configparser
import logging
import os
import ast
import json
from dotenv import load_dotenv
from typing import Optional, Dict, Any, List
from models import GraphState
# Load variables from a .env file via python-dotenv so that the os.getenv()
# lookups in get_auth() (HF_TOKEN, QDRANT_API_KEY) can pick them up.
load_dotenv()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def getconfig(configfile_path: str):
    """Read an INI-style config file.

    Args:
        configfile_path: Path to the config file.

    Returns:
        A populated ``configparser.ConfigParser``, or ``None`` when no path is
        given or the file cannot be opened, decoded, or parsed. In the failure
        cases a warning naming the path and the error is logged.
    """
    # Same logger object as the module-level ``logger``; the original used the
    # root logger via ``logging.warning``.
    log = logging.getLogger(__name__)
    if not configfile_path:
        log.warning("No config file path given")
        return None

    config = configparser.ConfigParser()
    try:
        # ``with`` guarantees the handle is closed (it previously leaked).
        with open(configfile_path, encoding="utf-8") as fh:
            config.read_file(fh)
    except (OSError, UnicodeDecodeError, configparser.Error) as err:
        # Missing/unreadable file, bad encoding, or malformed INI content:
        # keep the best-effort contract of returning None.
        log.warning("Could not read config file %r: %s", configfile_path, err)
        return None
    return config
def get_auth(provider: str) -> dict:
    """Return the authentication config for a supported provider.

    The API key is read from the environment on every call:
    ``HF_TOKEN`` for "huggingface", ``QDRANT_API_KEY`` for "qdrant".

    Args:
        provider: Provider name, matched case-insensitively.

    Returns:
        A new dict ``{"api_key": value}``. A missing or empty key is normalised
        to ``None`` and a warning is logged.

    Raises:
        ValueError: If ``provider`` is not supported.
    """
    env_var_by_provider = {
        "huggingface": "HF_TOKEN",
        "qdrant": "QDRANT_API_KEY",
    }
    provider = provider.lower()
    if provider not in env_var_by_provider:
        raise ValueError(f"Unsupported provider: {provider}")

    # ``or None`` folds an empty string into None, as before.
    api_key = os.getenv(env_var_by_provider[provider]) or None
    if api_key is None:
        # Use this module's logger (same as the module-level ``logger``)
        # rather than the root logger; lazy %-formatting.
        logging.getLogger(__name__).warning(
            "No API key found for provider '%s'", provider
        )
    return {"api_key": api_key}
def detect_file_type(filename: str, file_content: Optional[bytes] = None) -> str:
    """Classify a file as 'geojson', 'json', 'text' or 'unknown'.

    Classification is by case-insensitive file extension. A '.json' file is
    upgraded to 'geojson' when ``file_content`` is given and parses to a JSON
    object whose top-level ``type`` is a GeoJSON type.

    Args:
        filename: File name or path; only the extension is inspected.
        file_content: Optional raw file bytes, decoded as UTF-8. An already
            decoded ``str`` is also accepted.

    Returns:
        One of 'geojson', 'json', 'text', 'unknown'.
    """
    if not filename:
        return "unknown"

    _, ext = os.path.splitext(filename.lower())
    file_type_mappings = {
        ".geojson": "geojson",
        ".json": "json",
        ".pdf": "text",
        ".docx": "text",
        ".doc": "text",
        ".txt": "text",
        ".md": "text",
        ".csv": "text",
        ".xlsx": "text",
        ".xls": "text",
    }
    detected_type = file_type_mappings.get(ext, "unknown")

    # A plain ".json" upload may really be GeoJSON: sniff the top-level "type".
    if detected_type == "json" and file_content:
        geojson_types = frozenset({
            "FeatureCollection", "Feature", "Point", "LineString", "Polygon",
            "MultiPoint", "MultiLineString", "MultiPolygon", "GeometryCollection",
        })
        try:
            if isinstance(file_content, (bytes, bytearray)):
                text = file_content.decode("utf-8")
            else:
                text = file_content
            data = json.loads(text)
        except (ValueError, TypeError, RecursionError):
            # ValueError covers UnicodeDecodeError and json.JSONDecodeError;
            # unparseable content keeps the extension-based 'json' label.
            data = None
        geo_type = data.get("type") if isinstance(data, dict) else None
        # The str check avoids hashing an unhashable "type" value (e.g. a list).
        if isinstance(geo_type, str) and geo_type in geojson_types:
            detected_type = "geojson"

    # Same logger object as the module-level ``logger``.
    logging.getLogger(__name__).info(
        "Detected file type: %s for file: %s", detected_type, filename
    )
    return detected_type
def convert_context_to_list(context: str) -> List[Dict[str, Any]]:
    """Normalise retriever context into the list form the generator expects.

    If ``context`` is a string holding a Python list literal (optionally
    surrounded by whitespace), it is parsed with ``ast.literal_eval`` and
    returned. ``ast.literal_eval`` accepts only literals, so no code is
    executed. In every other case, the whole value is wrapped as a single
    item with placeholder metadata. This covers free text, malformed
    literals, literals that are not a list, and non-str input.

    Args:
        context: Retrieved context, either a list literal or free text.

    Returns:
        The parsed list, or ``[{"answer": context, "answer_metadata": {...}}]``.
    """
    if isinstance(context, str) and context.lstrip().startswith("["):
        try:
            parsed = ast.literal_eval(context.strip())
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            parsed = None
        # e.g. "[1], [2]" parses to a tuple; only a real list is returned.
        if isinstance(parsed, list):
            return parsed

    return [{
        "answer": context,
        "answer_metadata": {
            "filename": "Retrieved Context",
            "page": "Unknown",
            "year": "Unknown",
            "source": "Retriever"
        }
    }]
def merge_state(base_state: GraphState, updates: dict) -> GraphState:
    """Build a new state from ``base_state`` overlaid with ``updates``.

    Neither argument is modified. The result is a fresh shallow-copied dict.
    Keys present in both inputs take their value from ``updates``.
    """
    combined = dict(base_state)
    combined.update(updates)
    return combined