Spaces:
Running
Running
| """ | |
| Context Service for managing context documents in trace metadata. | |
| """ | |
| import json | |
| import uuid | |
| from datetime import datetime | |
| from typing import List, Dict, Any, Optional | |
| from sqlalchemy.orm import Session | |
| from sqlalchemy.orm.attributes import flag_modified | |
| from backend.database.models import Trace | |
| from backend.models import ( | |
| ContextDocument, | |
| CreateContextRequest, | |
| UpdateContextRequest, | |
| ContextDocumentType | |
| ) | |
class ContextService:
    """Service for managing context documents stored in trace metadata.

    Context documents are persisted as a list of JSON-serialisable dicts under
    the ``"context_documents"`` key of ``Trace.trace_metadata``. Every mutating
    method loads the current list, changes it in memory, and writes the whole
    list back through ``_update_trace_metadata``, which commits the session.
    """

    # Maximum number of context documents a single trace may hold.
    MAX_CONTEXT_DOCUMENTS = 20
    # Maximum length, in characters, of uploaded file content.
    MAX_FILE_CONTENT_LENGTH = 100000

    def __init__(self, db: Session):
        """Bind the service to a database session.

        Args:
            db: SQLAlchemy session used for every trace query and commit.
        """
        self.db = db

    def create_context_document(
        self,
        trace_id: str,
        title: str,
        document_type: ContextDocumentType,
        content: str,
        file_name: Optional[str] = None
    ) -> ContextDocument:
        """Create a new context document for a trace and persist it.

        Args:
            trace_id: Identifier of the trace that owns the document.
            title: Document title; must be unique within the trace.
            document_type: A ``ContextDocumentType`` member or its value.
            content: Document body.
            file_name: Optional name of the file the content came from.

        Returns:
            The newly created ``ContextDocument``.

        Raises:
            ValueError: If the document type is invalid, the trace does not
                exist, the trace already holds ``MAX_CONTEXT_DOCUMENTS``
                documents, or another document already uses ``title``.
        """
        self.validate_document_type(document_type)

        current_docs = self.get_context_documents(trace_id)

        if len(current_docs) >= self.MAX_CONTEXT_DOCUMENTS:
            raise ValueError(
                f"Maximum of {self.MAX_CONTEXT_DOCUMENTS} context documents per trace allowed"
            )

        if any(doc.title == title for doc in current_docs):
            raise ValueError(f"Context document with title '{title}' already exists")

        new_doc = ContextDocument(
            id=self._generate_context_id(),
            title=title,
            document_type=document_type,
            content=content,
            file_name=file_name,
            created_at=datetime.utcnow(),
            is_active=True
        )

        current_docs.append(new_doc)
        self._update_trace_metadata(trace_id, current_docs)
        return new_doc

    def get_context_documents(self, trace_id: str) -> List[ContextDocument]:
        """Return all context documents stored on a trace.

        Args:
            trace_id: Identifier of the trace to read.

        Returns:
            The documents parsed from ``trace_metadata["context_documents"]``.
            The list is empty when the trace has no metadata, no such key, or
            the key holds ``None``.

        Raises:
            ValueError: If the trace does not exist.
        """
        trace = self._get_trace(trace_id)
        metadata = trace.trace_metadata or {}
        # ``or []`` also guards against the key being present but set to None.
        docs_data = metadata.get("context_documents") or []
        return [ContextDocument.model_validate(doc_data) for doc_data in docs_data]

    def update_context_document(
        self,
        trace_id: str,
        context_id: str,
        updates: UpdateContextRequest
    ) -> ContextDocument:
        """Apply a partial update to an existing context document.

        Only fields explicitly set on ``updates`` are applied. All validation
        runs before the document is mutated.

        Args:
            trace_id: Identifier of the trace that owns the document.
            context_id: Identifier of the document to update.
            updates: Request carrying the fields to change.

        Returns:
            The updated ``ContextDocument``.

        Raises:
            ValueError: If the trace or document does not exist, the new
                document type is invalid, or the new title collides with
                another document's title.
        """
        current_docs = self.get_context_documents(trace_id)

        doc_index = next(
            (i for i, doc in enumerate(current_docs) if doc.id == context_id),
            None,
        )
        if doc_index is None:
            raise ValueError(f"Context document {context_id} not found")

        # model_dump is the Pydantic v2 spelling used elsewhere in this class
        # (.dict() is deprecated); exclude_unset keeps only caller-set fields.
        update_data = updates.model_dump(exclude_unset=True)

        # Keep type validation consistent with create_context_document.
        if update_data.get("document_type") is not None:
            self.validate_document_type(update_data["document_type"])

        new_title = update_data.get("title")
        if new_title and any(
            other.title == new_title
            for i, other in enumerate(current_docs)
            if i != doc_index
        ):
            raise ValueError(f"Context document with title '{new_title}' already exists")

        doc = current_docs[doc_index]
        for field, value in update_data.items():
            setattr(doc, field, value)

        self._update_trace_metadata(trace_id, current_docs)
        return doc

    def delete_context_document(self, trace_id: str, context_id: str) -> bool:
        """Delete a context document from a trace.

        Args:
            trace_id: Identifier of the trace that owns the document.
            context_id: Identifier of the document to remove.

        Returns:
            ``True`` once the document has been removed and persisted.

        Raises:
            ValueError: If the trace or document does not exist.
        """
        current_docs = self.get_context_documents(trace_id)

        updated_docs = [doc for doc in current_docs if doc.id != context_id]
        if len(updated_docs) == len(current_docs):
            raise ValueError(f"Context document {context_id} not found")

        self._update_trace_metadata(trace_id, updated_docs)
        return True

    def validate_document_type(self, document_type: ContextDocumentType) -> bool:
        """Check that ``document_type`` is a valid ``ContextDocumentType``.

        Accepts either an enum member or one of the enum's values.

        Returns:
            ``True`` when the type is valid.

        Raises:
            ValueError: If the type is neither a member nor a member's value.
        """
        # For a plain (non-str) Enum, a member does not compare equal to its
        # .value, so the value-membership test alone would reject it.
        if isinstance(document_type, ContextDocumentType):
            return True
        valid_types = [item.value for item in ContextDocumentType]
        if document_type not in valid_types:
            raise ValueError(f"Invalid document type. Must be one of: {valid_types}")
        return True

    def process_uploaded_file(
        self,
        file_content: str,
        trace_id: str,
        title: str,
        document_type: ContextDocumentType,
        file_name: str
    ) -> ContextDocument:
        """Store uploaded file content as a new context document.

        Args:
            file_content: Decoded text of the uploaded file.
            trace_id: Identifier of the trace that owns the document.
            title: Document title; must be unique within the trace.
            document_type: A ``ContextDocumentType`` member or its value.
            file_name: Original name of the uploaded file.

        Returns:
            The newly created ``ContextDocument``.

        Raises:
            ValueError: If the content exceeds ``MAX_FILE_CONTENT_LENGTH``
                characters, or for any reason listed on
                ``create_context_document``.
        """
        if len(file_content) > self.MAX_FILE_CONTENT_LENGTH:
            raise ValueError(
                f"File content exceeds maximum length of {self.MAX_FILE_CONTENT_LENGTH:,} characters"
            )

        return self.create_context_document(
            trace_id=trace_id,
            title=title,
            document_type=document_type,
            content=file_content,
            file_name=file_name
        )

    def _get_trace(self, trace_id: str) -> Trace:
        """Fetch a trace by id, raising ``ValueError`` if it does not exist."""
        trace = self.db.query(Trace).filter(Trace.trace_id == trace_id).first()
        if not trace:
            raise ValueError(f"Trace {trace_id} not found")
        return trace

    def _update_trace_metadata(self, trace_id: str, context_documents: List[ContextDocument]) -> None:
        """Replace a trace's stored context documents and commit.

        Args:
            trace_id: Identifier of the trace to update.
            context_documents: Full list of documents to store; it replaces
                whatever was stored before.

        Raises:
            ValueError: If the trace does not exist.
            Exception: Any error raised by ``commit()`` is re-raised after
                the session has been rolled back.
        """
        trace = self._get_trace(trace_id)

        if not trace.trace_metadata:
            trace.trace_metadata = {}

        # mode='json' serialises datetimes and enums to JSON-compatible values.
        docs_data = [doc.model_dump(mode='json') for doc in context_documents]
        trace.trace_metadata["context_documents"] = docs_data

        # In-place mutation of a JSON column is not change-tracked by default,
        # so flag the attribute explicitly.
        flag_modified(trace, "trace_metadata")
        try:
            self.db.commit()
        except Exception:
            # Leave the session usable for the caller after a failed commit.
            self.db.rollback()
            raise

    def _generate_context_id(self) -> str:
        """Return a new random UUID4 string to use as a document id."""
        return str(uuid.uuid4())