# Author: mtyrrell
# Commit message: "ts file upload sources"
# Commit: 201e72b
import asyncio
import json
import logging
import os
import tempfile
from datetime import datetime
from typing import AsyncGenerator, Generator, Optional

import dotenv
import httpx
from gradio_client import Client, file

from utils import detect_file_type, convert_context_to_list, merge_state, getconfig
from models import GraphState
# Pull variables from a local .env file into the process environment
# (e.g. HF_TOKEN, which the service clients below read via os.getenv).
dotenv.load_dotenv()
logger = logging.getLogger(__name__)

# params.cfg is parsed a single time, at import, and shared by every node.
config = getconfig("params.cfg")

# Remote service endpoints; each falls back to a hard-coded URL when
# params.cfg does not provide one.
RETRIEVER = config.get(
    "retriever",
    "RETRIEVER",
    fallback="https://giz-chatfed-retriever.hf.space",
)
GENERATOR = config.get(
    "generator",
    "GENERATOR",
    fallback="https://giz-chatfed-generator.hf.space",
)
INGESTOR = config.get(
    "ingestor",
    "INGESTOR",
    fallback="https://mtyrrell-chatfed-ingestor.hf.space",
)
GEOJSON_INGESTOR = config.get(
    "ingestor",
    "GEOJSON_INGESTOR",
    fallback="https://giz-eudr-chatfed-ingestor.hf.space",
)

# Character budget for the context handed to the generator. There is no
# fallback, so params.cfg must define general.MAX_CONTEXT_CHARS.
MAX_CONTEXT_CHARS = int(
    config.get("general", "MAX_CONTEXT_CHARS")
)
#----------------------------------------
# LANGGRAPH NODE FUNCTIONS
#----------------------------------------
def detect_file_type_node(state: GraphState) -> GraphState:
    """Detect the uploaded file's type and choose the downstream workflow.

    GeoJSON uploads select the ``"geojson_direct"`` workflow; every other case
    (including no upload at all) selects ``"standard"``.

    Args:
        state: Graph state; reads ``file_content``, ``filename`` and
            ``metadata``.

    Returns:
        Partial state update with ``file_type``, ``workflow_type`` and a copy
        of ``metadata`` extended with both values. The input state's
        ``metadata`` dict is left unmodified.
    """
    file_type = "unknown"
    workflow_type = "standard"
    # Detection only runs when both the bytes and a name are present.
    if state.get("file_content") and state.get("filename"):
        file_type = detect_file_type(state["filename"], state["file_content"])
        if file_type == "geojson":
            workflow_type = "geojson_direct"

    # Copy instead of updating in place so the caller's state is not mutated;
    # `or {}` also tolerates an explicit metadata=None.
    metadata = dict(state.get("metadata") or {})
    metadata["file_type"] = file_type
    metadata["workflow_type"] = workflow_type
    return {
        "file_type": file_type,
        "workflow_type": workflow_type,
        "metadata": metadata,
    }
def ingest_node(state: GraphState) -> GraphState:
    """Send the uploaded file to the matching ingestor and return its output.

    GeoJSON files go to ``GEOJSON_INGESTOR``; everything else goes to
    ``INGESTOR``. The bytes are written to a temporary file (keeping the
    original extension) because the ingestor's ``/ingest`` endpoint is called
    with a file upload; the temporary file is always removed afterwards.

    Args:
        state: Graph state; reads ``file_content``, ``filename``,
            ``file_type`` and ``metadata``.

    Returns:
        Partial state update ``{"ingestor_context", "metadata"}``.
        ``ingestor_context`` is ``""`` when no file was given, the ingestor
        returned nothing, or ingestion failed. Failures are recorded in
        ``metadata`` (``ingestion_success=False``, ``ingestion_error``)
        instead of being raised. The input state's ``metadata`` is not mutated.
    """
    start_time = datetime.now()
    # Work on a copy so the caller's state is never mutated in place.
    metadata = dict(state.get("metadata") or {})

    if not state.get("file_content") or not state.get("filename"):
        logger.info("No file provided, skipping ingestion")
        return {"ingestor_context": "", "metadata": metadata}

    file_type = state.get("file_type", "unknown")
    logger.info(f"Ingesting {file_type} file: {state['filename']}")
    # Choose ingestor based on file type
    ingestor_url = GEOJSON_INGESTOR if file_type == "geojson" else INGESTOR

    try:
        logger.info(f"Using ingestor: {ingestor_url}")
        client = Client(ingestor_url, hf_token=os.getenv("HF_TOKEN"))

        suffix = os.path.splitext(state["filename"])[1]
        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        tmp_file_path = tmp_file.name
        try:
            # Close (and flush) the handle before the upload reads the path.
            with tmp_file:
                tmp_file.write(state["file_content"])
            ingestor_context = client.predict(file(tmp_file_path), api_name="/ingest")
        finally:
            # Also runs when the write itself fails, so no temp file is leaked.
            os.unlink(tmp_file_path)

        # Normalise "no output" to "" so downstream len()/concatenation is safe.
        if ingestor_context is None:
            ingestor_context = ""
        logger.info(f"Ingest result length: {len(ingestor_context) if ingestor_context else 0}")
        # A string payload starting with "Error:" is treated as an ingestor-side failure.
        if isinstance(ingestor_context, str) and ingestor_context.startswith("Error:"):
            raise Exception(ingestor_context)

        duration = (datetime.now() - start_time).total_seconds()
        metadata.update({
            "ingestion_duration": duration,
            "ingestor_context_length": len(ingestor_context) if ingestor_context else 0,
            "ingestion_success": True,
            "ingestor_used": ingestor_url,
        })
        return {"ingestor_context": ingestor_context, "metadata": metadata}
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Ingestion failed: {str(e)}")
        metadata.update({
            "ingestion_duration": duration,
            "ingestion_success": False,
            "ingestion_error": str(e),
            "ingestor_used": ingestor_url,
        })
        return {"ingestor_context": "", "metadata": metadata}
def geojson_direct_result_node(state: GraphState) -> GraphState:
    """Use the GeoJSON ingestor's output directly as the final result.

    Args:
        state: Graph state; reads ``ingestor_context`` and ``metadata``.

    Returns:
        Partial state update with ``result`` (the ingestor output, or a
        placeholder message when it is empty/missing) and a copy of
        ``metadata`` extended with ``processing_type`` and ``result_length``.
        The input state's ``metadata`` is not mutated.
    """
    logger.info("Processing GeoJSON file - returning direct results")
    ingestor_context = state.get("ingestor_context", "")
    result = ingestor_context if ingestor_context else "No results from GeoJSON processing."

    # Copy so the caller's state is not mutated in place.
    metadata = dict(state.get("metadata") or {})
    metadata.update({
        "processing_type": "geojson_direct",
        "result_length": len(result),
    })
    return {"result": result, "metadata": metadata}
def retrieve_node(state: GraphState) -> GraphState:
    """Fetch context relevant to ``state["query"]`` from the retriever service.

    Calls the ``/retrieve`` endpoint of the gradio app at ``RETRIEVER`` with
    the query and the four optional filters (missing or ``None`` filters are
    sent as ``""``).

    Args:
        state: Graph state; reads ``query``, ``reports_filter``,
            ``sources_filter``, ``subtype_filter``, ``year_filter`` and
            ``metadata``.

    Returns:
        Partial state update ``{"context", "metadata"}``. ``context`` is what
        the endpoint returned, or ``""`` if it returned nothing or the call
        failed. Failures are recorded in ``metadata`` instead of being raised.
        The input state's ``metadata`` is not mutated.
    """
    start_time = datetime.now()
    # Work on a copy so the caller's state is never mutated in place.
    metadata = dict(state.get("metadata") or {})
    logger.info(f"Retrieval: {state['query'][:50]}...")
    try:
        client = Client(RETRIEVER, hf_token=os.getenv("HF_TOKEN"))
        context = client.predict(
            query=state["query"],
            reports_filter=state.get("reports_filter") or "",
            sources_filter=state.get("sources_filter") or "",
            subtype_filter=state.get("subtype_filter") or "",
            year_filter=state.get("year_filter") or "",
            api_name="/retrieve",
        )
        # Normalise "no output" to "" so downstream len() calls are safe.
        if context is None:
            context = ""

        duration = (datetime.now() - start_time).total_seconds()
        metadata.update({
            "retrieval_duration": duration,
            "context_length": len(context) if context else 0,
            "retrieval_success": True,
        })
        return {"context": context, "metadata": metadata}
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Retrieval failed: {str(e)}")
        metadata.update({
            "retrieval_duration": duration,
            "retrieval_success": False,
            "retrieval_error": str(e),
        })
        return {"context": "", "metadata": metadata}
def _build_generator_context(state: GraphState) -> list:
    """Assemble the generator's context list within the MAX_CONTEXT_CHARS budget.

    The ingested upload text (if any) comes first, cut at MAX_CONTEXT_CHARS
    plus a truncation marker. Retrieved passages then fill the remaining
    budget in order: the first passage that does not fit is truncated (only
    when more than 100 characters remain) and everything after it is dropped.
    Neither the state nor the retriever's items are mutated.
    """
    # `or ""` also covers keys that are present but set to None.
    retrieved_context = state.get("context") or ""
    ingestor_context = state.get("ingestor_context") or ""
    logger.info(f"Context lengths - Ingestor: {len(ingestor_context)}, Retrieved: {len(retrieved_context)}")

    context_list = []
    total_context_chars = 0

    if ingestor_context:
        if len(ingestor_context) > MAX_CONTEXT_CHARS:
            truncated_ingestor = ingestor_context[:MAX_CONTEXT_CHARS] + "...\n[Content truncated due to length]"
        else:
            truncated_ingestor = ingestor_context
        context_list.append({
            "answer": truncated_ingestor,
            "answer_metadata": {
                "filename": state.get("filename", "Uploaded Document"),
                "page": "Unknown",
                "year": "Unknown",
                "source": "Ingestor",
            },
        })
        total_context_chars += len(truncated_ingestor)

    if retrieved_context and total_context_chars < MAX_CONTEXT_CHARS:
        remaining_chars = MAX_CONTEXT_CHARS - total_context_chars
        for item in convert_context_to_list(retrieved_context):
            item_text = item.get("answer", "")
            if len(item_text) <= remaining_chars:
                context_list.append(item)
                remaining_chars -= len(item_text)
                continue
            if remaining_chars > 100:
                # Truncate a copy so the retriever's item is left untouched.
                truncated_item = dict(item)
                truncated_item["answer"] = item_text[:remaining_chars - 50] + "...\n[Content truncated]"
                context_list.append(truncated_item)
            break

    return context_list


def _log_generator_payload(payload: dict, final_context_size: int) -> None:
    """Log the query and a per-item summary of the context about to be sent."""
    # ===== COMPREHENSIVE LOGGING OF PAYLOAD BEING SENT TO GENERATOR =====
    context_list = payload["context"]
    logger.info("=" * 80)
    logger.info("PAYLOAD BEING SENT TO GENERATOR")
    logger.info("=" * 80)
    logger.info(f"Query: {payload['query']}")
    logger.info(f"Number of context items: {len(context_list)}")
    logger.info(f"Total context size: {final_context_size} characters")
    # Log each context item in detail
    for i, context_item in enumerate(context_list):
        item_meta = context_item.get('answer_metadata', {})
        logger.info(f"Context Item {i+1}:")
        logger.info(f" - Source: {item_meta.get('source', 'Unknown')}")
        logger.info(f" - Filename: {item_meta.get('filename', 'Unknown')}")
        logger.info(f" - Page: {item_meta.get('page', 'Unknown')}")
        logger.info(f" - Year: {item_meta.get('year', 'Unknown')}")
        answer_text = context_item.get('answer', '')
        logger.info(f" - Content length: {len(answer_text)} characters")
        logger.info(f" - Content preview: {answer_text[:200]}{'...' if len(answer_text) > 200 else ''}")
        logger.info(" " + "-" * 50)
    # Log the complete payload structure
    logger.info("Complete payload structure:")
    logger.info(f" - Query: {payload['query']}")
    logger.info(f" - Context items: {len(context_list)}")
    for i, item in enumerate(context_list):
        logger.info(f" Item {i+1}: {len(item.get('answer', ''))} chars from {item.get('answer_metadata', {}).get('source', 'Unknown')}")


def _resolve_generator_url(generator: str) -> str:
    """Map a bare Space id (e.g. ``owner/space_name``) to its ``https://….hf.space`` URL.

    Values that already start with ``http`` are returned unchanged.
    """
    if generator.startswith('http'):
        return generator
    space_name = generator.replace('/', '-').replace('_', '-')
    return f"https://{space_name}.hf.space"


def _decode_data_chunk(data_content: str) -> str:
    """Turn one SSE ``data`` payload into a text chunk.

    JSON-encoded strings are unwrapped. Anything else (invalid JSON, or JSON
    that is not a string, such as ``42`` or ``null``) is used verbatim, so the
    chunk is always text.
    """
    try:
        decoded = json.loads(data_content)
    except json.JSONDecodeError:
        return data_content
    return decoded if isinstance(decoded, str) else data_content


async def generate_node_streaming(state: GraphState) -> AsyncGenerator[GraphState, None]:
    """Stream an answer from the generator's ``/generate/stream`` SSE endpoint.

    Builds a size-limited context list from ``ingestor_context`` and
    ``context``, POSTs ``{"query", "context"}`` to the generator, and relays
    the server-sent events as partial state updates.

    Args:
        state: Graph state; reads ``query``, ``context``,
            ``ingestor_context``, ``filename`` and ``metadata``.

    Yields:
        ``{"result": <text chunk>, "metadata": ...}`` per ``data`` event,
        ``{"sources": [...], "metadata": ...}`` per ``sources`` event, and on
        any failure a final ``{"result": "Error: ...", "metadata": ...}``.
        Exceptions are not propagated. All yields share one metadata dict that
        is a copy of ``state["metadata"]`` updated as the stream progresses;
        the input state is not mutated.
    """
    start_time = datetime.now()
    logger.info(f"Generation (streaming): {state['query'][:50]}...")
    metadata = dict(state.get("metadata") or {})
    try:
        context_list = _build_generator_context(state)
        final_context_size = sum(len(item.get("answer", "")) for item in context_list)
        logger.info(f"Final context size: {final_context_size} characters (limit: {MAX_CONTEXT_CHARS})")

        payload = {"query": state["query"], "context": context_list}
        _log_generator_payload(payload, final_context_size)

        generator_url = _resolve_generator_url(GENERATOR)
        logger.info(f"Sending request to generator URL: {generator_url}/generate/stream")
        logger.info("=" * 80)

        # NOTE(review): verify=False disables TLS certificate verification for
        # this request; confirm this is intentional.
        async with httpx.AsyncClient(timeout=300.0, verify=False) as client:
            async with client.stream(
                "POST",
                f"{generator_url}/generate/stream",
                json=payload,
                headers={"Content-Type": "application/json"},
            ) as response:
                if response.status_code != 200:
                    raise Exception(f"Generator returned status {response.status_code}")

                current_text = ""
                # The most recent "event:" name applies to the following "data:" lines.
                event_type = None
                async for line in response.aiter_lines():
                    if not line.strip():
                        continue
                    if line.startswith("event: "):
                        event_type = line[7:].strip()
                        continue
                    if not line.startswith("data: "):
                        continue
                    data_content = line[6:].strip()

                    if event_type == "data":
                        chunk = _decode_data_chunk(data_content)
                        current_text += chunk
                        metadata.update({
                            "generation_duration": (datetime.now() - start_time).total_seconds(),
                            "result_length": len(current_text),
                            "generation_success": True,
                            "streaming": True,
                            "context_chars_used": final_context_size,
                        })
                        yield {"result": chunk, "metadata": metadata}
                    elif event_type == "sources":
                        try:
                            sources_data = json.loads(data_content)
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to parse sources: {data_content}")
                            continue
                        sources = sources_data.get("sources", [])
                        metadata.update({
                            "sources_received": True,
                            "sources_count": len(sources),
                        })
                        yield {"sources": sources, "metadata": metadata}
                    elif event_type == "end":
                        logger.info("Generator stream ended")
                        break
                    elif event_type == "error":
                        try:
                            error_data = json.loads(data_content)
                        except json.JSONDecodeError:
                            error_data = None
                        # Only a JSON object carries an "error" field; any other
                        # payload is reported verbatim.
                        if isinstance(error_data, dict):
                            raise Exception(error_data.get("error", "Unknown error"))
                        raise Exception(data_content)
    except Exception as e:
        duration = (datetime.now() - start_time).total_seconds()
        logger.error(f"Streaming generation failed: {str(e)}")
        metadata.update({
            "generation_duration": duration,
            "generation_success": False,
            "generation_error": str(e),
            "streaming": True,
        })
        yield {"result": f"Error: {str(e)}", "metadata": metadata}
def route_workflow(state: GraphState) -> str:
    """Name the branch to follow after ingestion.

    Returns ``state["workflow_type"]`` when that key is present, and
    ``"standard"`` otherwise.
    """
    default_branch = "standard"
    branch = state.get("workflow_type", default_branch)
    return branch
#----------------------------------------
# UNIFIED STREAMING PROCESSOR
#----------------------------------------
def _format_sources_markdown(sources) -> str:
    """Render sources as the numbered Markdown list appended in "gradio" mode."""
    parts = ["\n\n**Sources:**\n"]
    for i, source in enumerate(sources, 1):
        if isinstance(source, dict):
            title = source.get('title', 'Unknown')
            link = source.get('link', '#')
            parts.append(f"{i}. [{title}]({link})\n")
        else:
            parts.append(f"{i}. {source}\n")
    return "".join(parts)


async def _run_node_in_thread(node, state):
    """Run a synchronous node in the default executor so its blocking I/O does not stall the event loop."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, node, state)


async def process_query_streaming(
    query: str,
    file_upload=None,
    file_content: Optional[bytes] = None,
    filename: Optional[str] = None,
    reports_filter: str = "",
    sources_filter: str = "",
    subtype_filter: str = "",
    year_filter: str = "",
    output_format: str = "structured"
):
    """
    Unified streaming function supporting both file objects and raw content.

    Pipeline: detect file type -> ingest -> either return the GeoJSON ingestor
    output directly, or retrieve context and stream a generated answer.

    Args:
        query: User query string
        file_upload: Uploaded file (optional): a path (``str``/``os.PathLike``)
            or an object exposing the path as ``.name``. When given, it
            overrides ``file_content`` and ``filename``.
        file_content: Raw file bytes (optional, alternative to file_upload)
        filename: Filename for raw content (required if file_content provided)
        reports_filter, sources_filter, subtype_filter, year_filter: Retriever
            filters; ``None`` is treated as ``""``.
        output_format: "structured" yields dicts
            ``{"type": "data"|"sources"|"end"|"error", "content": ...}``;
            "gradio" yields the accumulated response text after every chunk.
    """
    # Handle file_upload if provided
    if file_upload is not None:
        try:
            # Accept a plain path as well as an object carrying it in `.name`.
            upload_path = (
                file_upload
                if isinstance(file_upload, (str, os.PathLike))
                else file_upload.name
            )
            with open(upload_path, 'rb') as f:
                file_content = f.read()
            filename = os.path.basename(upload_path)
            logger.info(f"File uploaded: {filename}, size: {len(file_content)} bytes")
        except Exception as e:
            logger.error(f"Error reading uploaded file: {str(e)}")
            if output_format == "structured":
                yield {"type": "error", "content": f"Error reading file: {str(e)}"}
            else:
                yield f"Error reading file: {str(e)}"
            return

    start_time = datetime.now()
    session_id = f"stream_{start_time.strftime('%Y%m%d_%H%M%S')}"
    try:
        # Build initial state
        initial_state = {
            "query": query,
            "context": "",
            "ingestor_context": "",
            "result": "",
            "sources": [],
            "reports_filter": reports_filter or "",
            "sources_filter": sources_filter or "",
            "subtype_filter": subtype_filter or "",
            "year_filter": year_filter or "",
            "file_content": file_content,
            "filename": filename,
            "file_type": "unknown",
            "workflow_type": "standard",
            "metadata": {
                "session_id": session_id,
                "start_time": start_time.isoformat(),
                "has_file_attachment": file_content is not None
            }
        }

        # Execute workflow nodes. ingest_node and retrieve_node make synchronous
        # gradio_client.predict calls, so they run in a worker thread instead
        # of blocking the event loop for the whole network round-trip.
        state = merge_state(initial_state, detect_file_type_node(initial_state))
        state = merge_state(state, await _run_node_in_thread(ingest_node, state))

        if route_workflow(state) == "geojson_direct":
            final_state = geojson_direct_result_node(state)
            if output_format == "structured":
                yield {"type": "data", "content": final_state["result"]}
                yield {"type": "end", "content": ""}
            else:
                yield final_state["result"]
            return

        state = merge_state(state, await _run_node_in_thread(retrieve_node, state))
        sources_collected = None
        accumulated_response = "" if output_format == "gradio" else None

        async for partial_state in generate_node_streaming(state):
            if "result" in partial_state:
                if output_format == "structured":
                    yield {"type": "data", "content": partial_state["result"]}
                else:
                    accumulated_response += partial_state["result"]
                    yield accumulated_response
            if "sources" in partial_state:
                sources_collected = partial_state["sources"]

        # Format and yield sources
        if sources_collected:
            if output_format == "structured":
                yield {"type": "sources", "content": sources_collected}
            else:
                accumulated_response += _format_sources_markdown(sources_collected)
                yield accumulated_response

        if output_format == "structured":
            yield {"type": "end", "content": ""}
    except Exception as e:
        logger.error(f"Streaming pipeline failed: {str(e)}")
        if output_format == "structured":
            yield {"type": "error", "content": f"Error: {str(e)}"}
        else:
            yield f"Error: {str(e)}"