| from typing import Dict, Optional |
|
|
| from schemas import CodeTaskType, CodeXRequest |
|
|
|
|
# Phrases that mark a message as a follow-up on previously generated code.
# They are checked by ContextManager.is_follow_up_request against the
# lower-cased user message, so every entry is written in lower case.
# Entries are grouped by intent and concatenated into a single list; the
# overall order of the entries is significant only for readability.
FOLLOW_UP_PATTERNS = (
    # Imperative edits aimed at "this", i.e. the code from the previous turn.
    [
        "update this",
        "modify this",
        "change this",
        "improve this",
        "refactor this",
        "fix this",
        "explain this",
        "review this",
    ]
    # Requests to bolt something onto the existing code.
    + [
        "add validation",
        "input validation",
        "add input validation",
        "add error handling",
        "add comments",
        "add logging",
        "add docstring",
        "add type hints",
    ]
    # Continuing or optimizing the current piece of work.
    + ["optimize this", "continue this", "continue from above"]
    # Explicit references back to earlier code.
    + [
        "same code",
        "this code",
        "that code",
        "the above code",
        "use the previous code",
        "based on the previous code",
    ]
    # "now <verb>" chaining: now add, now update, now change, now remove, now use.
    + [f"now {verb}" for verb in ("add", "update", "change", "remove", "use")]
)
|
|
|
|
class SessionArtifact:
    """A piece of code remembered for a session, plus metadata describing it.

    All fields are optional and stored exactly as given; any normalization
    (stripping, defaulting) is the caller's responsibility.

    Attributes:
        code: The code text that was produced or supplied.
        task_type: The kind of coding task the code came from.
        language: Programming language of ``code``, if known.
        framework: Framework the code targets, if known.
        last_user_goal: The user's request that led to this code.
    """

    def __init__(
        self,
        code: Optional[str] = None,
        task_type: Optional[CodeTaskType] = None,
        language: Optional[str] = None,
        framework: Optional[str] = None,
        last_user_goal: Optional[str] = None,
    ):
        self.code = code
        self.task_type = task_type
        self.language = language
        self.framework = framework
        self.last_user_goal = last_user_goal

    def __repr__(self) -> str:
        """Return a debug representation with long text fields abbreviated."""
        # reprlib caps string length so that logging an artifact does not
        # dump an entire source file (or a long user prompt) into the log.
        import reprlib

        return (
            f"{type(self).__name__}("
            f"code={reprlib.repr(self.code)}, "
            f"task_type={self.task_type!r}, "
            f"language={self.language!r}, "
            f"framework={self.framework!r}, "
            f"last_user_goal={reprlib.repr(self.last_user_goal)})"
        )
|
|
|
|
class ContextManager:
    """In-memory store of the latest code artifact per session.

    A follow-up message that carries no code of its own (e.g. "now add
    logging") can be enriched with the code saved earlier for the same
    session, so the downstream handler has something to operate on. Each
    session keeps only its most recent artifact; saving replaces it.
    """

    # Generic edit verbs. These count as a follow-up signal only when the
    # request carries a non-blank session_id (see is_follow_up_request).
    _LOOSE_FOLLOW_UP_SIGNALS = (
        "add ",
        "update ",
        "change ",
        "modify ",
        "remove ",
        "replace ",
        "rename ",
        "use ",
        "include ",
    )

    def __init__(self):
        # Maps normalized session id -> latest artifact for that session.
        self._store: Dict[str, SessionArtifact] = {}

    @staticmethod
    def _contains_phrase(text: str, phrase: str) -> bool:
        """Return True if *phrase* occurs in *text* starting at a word boundary.

        A plain substring test lets "use " match inside "because " and
        "change " inside "exchange ", which misclassifies unrelated requests
        as follow-ups. Only the start of the phrase is anchored, so
        "add docstring" still matches "add docstrings".
        """
        import re

        return re.search(r"\b" + re.escape(phrase), text) is not None

    @staticmethod
    def _clean_optional(value: Optional[str]) -> Optional[str]:
        """Strip *value*; map None, empty and whitespace-only strings to None."""
        if value is None:
            return None
        return value.strip() or None

    @staticmethod
    def _trim_code(code: str) -> str:
        """Remove leading blank lines and trailing whitespace from *code*.

        Unlike ``str.strip`` this keeps the indentation of the first code
        line. An indented multi-line snippet therefore keeps consistent
        relative indentation instead of having only its first line dedented.
        """
        import re

        return re.sub(r"\A(?:[ \t]*\r?\n)+", "", code.rstrip())

    def normalize_session_id(self, session_id: Optional[str]) -> str:
        """Return the stripped session id, or "default" when it is missing or blank."""
        if session_id and str(session_id).strip():
            return str(session_id).strip()
        return "default"

    def get_artifact(self, session_id: Optional[str]) -> Optional[SessionArtifact]:
        """Return the artifact stored for *session_id*, or None if there is none."""
        normalized_session_id = self.normalize_session_id(session_id)
        return self._store.get(normalized_session_id)

    def save_artifact(
        self,
        session_id: Optional[str],
        code: Optional[str],
        task_type: Optional[CodeTaskType] = None,
        language: Optional[str] = None,
        framework: Optional[str] = None,
        last_user_goal: Optional[str] = None,
    ) -> None:
        """Store *code* as the latest artifact for *session_id*.

        Empty or whitespace-only code is ignored, and any existing artifact
        is left untouched. Otherwise the previous artifact for the session is
        replaced.
        Text metadata is stripped; blank values are stored as None.

        Args:
            session_id: Session key; None/blank maps to "default".
            code: The code to remember.
            task_type: Task kind the code was produced for.
            language: Programming language of the code.
            framework: Framework the code targets.
            last_user_goal: The user request that produced the code.
        """
        if not code or not str(code).strip():
            return

        normalized_session_id = self.normalize_session_id(session_id)
        self._store[normalized_session_id] = SessionArtifact(
            code=self._trim_code(code),
            task_type=task_type,
            language=self._clean_optional(language),
            framework=self._clean_optional(framework),
            last_user_goal=self._clean_optional(last_user_goal),
        )

    def has_artifact(self, session_id: Optional[str]) -> bool:
        """Return True if the session has an artifact with non-blank code."""
        artifact = self.get_artifact(session_id)
        return bool(artifact and artifact.code and artifact.code.strip())

    def is_follow_up_request(self, request: CodeXRequest) -> bool:
        """Decide whether *request* refers back to previously generated code.

        A request that carries its own (non-blank) code, or has a blank
        message, is never a follow-up. Otherwise it is a follow-up when the
        lower-cased message contains one of FOLLOW_UP_PATTERNS. Failing that,
        it is a follow-up when the request has a non-blank session_id and the
        message contains one of the generic edit verbs. All phrases must
        start at a word boundary.
        """
        if request.code and request.code.strip():
            return False

        message = (request.message or "").strip().lower()
        if not message:
            return False

        if any(self._contains_phrase(message, pattern) for pattern in FOLLOW_UP_PATTERNS):
            return True

        # Generic verbs are too ambiguous on their own; only trust them when
        # the client explicitly scoped the request to a session.
        has_session = bool(request.session_id and str(request.session_id).strip())
        return has_session and any(
            self._contains_phrase(message, signal)
            for signal in self._LOOSE_FOLLOW_UP_SIGNALS
        )

    def enrich_request_with_context(
        self,
        request: CodeXRequest,
        session_id: Optional[str],
    ) -> CodeXRequest:
        """Attach the session's stored code to a code-less follow-up request.

        *request* is returned unchanged when it already carries code, when no
        artifact with code exists for *session_id*, or when it is not a
        follow-up. Otherwise a new CodeXRequest is returned that:

        - uses the stored code,
        - fills language/framework from the artifact when the request
          lacks them, and
        - prefixes previous_context with the previous user goal.

        NOTE(review): the artifact is looked up with the *session_id*
        argument, but is_follow_up_request checks ``request.session_id`` for
        its loose-signal rule. Confirm that callers keep the two in sync.
        """
        if request.code and request.code.strip():
            return request

        artifact = self.get_artifact(session_id)
        if not artifact or not artifact.code:
            return request

        if not self.is_follow_up_request(request):
            return request

        enriched_language = request.language or artifact.language
        enriched_framework = request.framework or artifact.framework

        previous_context_parts = []
        if artifact.last_user_goal:
            previous_context_parts.append(f"Previous User Goal:\n{artifact.last_user_goal}")
        if request.previous_context and request.previous_context.strip():
            previous_context_parts.append(request.previous_context.strip())

        merged_previous_context = (
            "\n\n".join(previous_context_parts)
            if previous_context_parts
            else request.previous_context
        )

        # NOTE(review): only the fields listed here are carried over; if
        # CodeXRequest gains more fields they will be dropped. Confirm
        # against the schema.
        return CodeXRequest(
            message=request.message,
            session_id=request.session_id,
            mode=request.mode,
            language=enriched_language,
            code=artifact.code,
            error_message=request.error_message,
            framework=enriched_framework,
            file_name=request.file_name,
            previous_context=merged_previous_context,
            use_retrieval=request.use_retrieval,
        )
|
|
|
|
# Single module-level instance: any code that imports ``context_manager``
# from this module shares the same in-memory session store.
context_manager: ContextManager = ContextManager()