File size: 4,996 Bytes
import re
from typing import Dict, Optional

from schemas import CodeTaskType, CodeXRequest
# Phrases that mark a message as a follow-up to previously produced code.
# ContextManager.is_follow_up_request tests each one as a plain substring of
# the stripped, lower-cased user message, so every entry is lower case.
FOLLOW_UP_PATTERNS = [
    # "<verb> this": the user points at existing code without pasting it.
    "update this",
    "modify this",
    "change this",
    "improve this",
    "refactor this",
    "fix this",
    "explain this",
    "review this",
    # Requests to bolt a specific concern onto existing code.
    "add validation",
    "input validation",
    "add input validation",
    "add error handling",
    "add comments",
    "add logging",
    "add docstring",
    "add type hints",
    # Further "<verb> this" forms.
    "optimize this",
    "continue this",
    # Explicit references back to earlier code or output.
    "continue from above",
    "same code",
    "this code",
    "that code",
    "the above code",
    "use the previous code",
    "based on the previous code",
    # "now <verb>": a next step in an ongoing exchange.
    "now add",
    "now update",
    "now change",
    "now remove",
    "now use",
]
class SessionArtifact:
    """Record of the code held for one session, plus descriptive metadata.

    Every field is optional and is stored exactly as passed in; this class
    performs no validation or normalization of its own.
    """

    def __init__(
        self,
        code: Optional[str] = None,
        task_type: Optional[CodeTaskType] = None,
        language: Optional[str] = None,
        framework: Optional[str] = None,
        last_user_goal: Optional[str] = None,
    ):
        # The code text itself and the kind of task it came from.
        self.code, self.task_type = code, task_type
        # Language / framework hints that accompany the code.
        self.language, self.framework = language, framework
        # The user goal recorded alongside this code.
        self.last_user_goal = last_user_goal
class ContextManager:
    """In-memory store holding one code artifact per session.

    The store is used to turn short follow-up messages ("add logging",
    "fix this") into self-contained requests by re-attaching the code that
    was saved for the same session.

    NOTE: artifacts live in a plain dict that this class never evicts from,
    and no locking is performed around it.
    """

    # Imperative verbs that hint at an edit of earlier code.  They only count
    # as whole words followed by whitespace: a plain substring test would
    # treat "because " as "use " and "exchange " as "change ".
    _LOOSE_FOLLOW_UP_RE = re.compile(
        r"\b(?:add|update|change|modify|remove|replace|rename|use|include)\s"
    )

    def __init__(self):
        self._store: Dict[str, SessionArtifact] = {}

    @staticmethod
    def _clean(value: Optional[str]) -> Optional[str]:
        """Return ``value`` stripped, or ``None`` when it is missing or blank."""
        if not value:
            return None
        return str(value).strip() or None

    @classmethod
    def _has_loose_signal(cls, message: str) -> bool:
        """Tell whether ``message`` contains a loose edit verb as a whole word.

        ``message`` is expected to be lower-cased already; the pattern only
        contains lower-case verbs.
        """
        return cls._LOOSE_FOLLOW_UP_RE.search(message) is not None

    def normalize_session_id(self, session_id: Optional[str]) -> str:
        """Return the stripped session id, or ``"default"`` if missing/blank."""
        return self._clean(session_id) or "default"

    def get_artifact(self, session_id: Optional[str]) -> Optional[SessionArtifact]:
        """Return the artifact stored for the session, or ``None`` if absent."""
        return self._store.get(self.normalize_session_id(session_id))

    def save_artifact(
        self,
        session_id: Optional[str],
        code: Optional[str],
        task_type: Optional[CodeTaskType] = None,
        language: Optional[str] = None,
        framework: Optional[str] = None,
        last_user_goal: Optional[str] = None,
    ) -> None:
        """Store ``code`` and its metadata, replacing the session's artifact.

        Nothing is stored (and any existing artifact is kept) when ``code``
        is missing or whitespace-only.  Text fields are stripped, and blank
        ones are stored as ``None`` rather than as empty strings.
        """
        cleaned_code = self._clean(code)
        if cleaned_code is None:
            return

        self._store[self.normalize_session_id(session_id)] = SessionArtifact(
            code=cleaned_code,
            task_type=task_type,
            language=self._clean(language),
            framework=self._clean(framework),
            last_user_goal=self._clean(last_user_goal),
        )

    def has_artifact(self, session_id: Optional[str]) -> bool:
        """Return ``True`` when the session has an artifact with non-blank code."""
        artifact = self.get_artifact(session_id)
        return bool(artifact and artifact.code and artifact.code.strip())

    def is_follow_up_request(self, request: CodeXRequest) -> bool:
        """Decide whether ``request`` refers back to previously produced code.

        A request that carries its own non-blank ``code`` is never a
        follow-up, and neither is one with an empty message.  Otherwise the
        lower-cased message is a follow-up when it contains any entry of
        ``FOLLOW_UP_PATTERNS``, or when the request has a non-blank
        ``session_id`` and the message contains a loose edit verb.
        """
        if self._clean(request.code):
            return False

        message = (request.message or "").strip().lower()
        if not message:
            return False

        if any(pattern in message for pattern in FOLLOW_UP_PATTERNS):
            return True

        # The loose verbs are weak evidence, so they are only trusted when
        # the request itself names a session.
        # NOTE(review): this looks at request.session_id, not at the
        # session_id argument given to enrich_request_with_context; confirm
        # that callers always pass the same value for both.
        if self._clean(request.session_id) is None:
            return False
        return self._has_loose_signal(message)

    def enrich_request_with_context(
        self,
        request: CodeXRequest,
        session_id: Optional[str],
    ) -> CodeXRequest:
        """Attach the session's stored code to a follow-up request.

        The original ``request`` object is returned untouched when it already
        contains code, when the session has no stored code, or when the
        message is not recognised as a follow-up.  Otherwise a new
        ``CodeXRequest`` is returned whose ``code`` is the stored code, whose
        ``language``/``framework`` fall back to the stored values, and whose
        ``previous_context`` is prefixed with the stored user goal.
        """
        if self._clean(request.code):
            return request

        artifact = self.get_artifact(session_id)
        if not artifact or not artifact.code:
            return request

        if not self.is_follow_up_request(request):
            return request

        context_parts = []
        if artifact.last_user_goal:
            context_parts.append(f"Previous User Goal:\n{artifact.last_user_goal}")
        caller_context = self._clean(request.previous_context)
        if caller_context:
            context_parts.append(caller_context)

        # With nothing to merge, hand the caller's value through as it was.
        merged_previous_context = (
            "\n\n".join(context_parts) if context_parts else request.previous_context
        )

        # Only the fields listed here are carried over to the new request.
        return CodeXRequest(
            message=request.message,
            session_id=request.session_id,
            mode=request.mode,
            language=request.language or artifact.language,
            code=artifact.code,
            error_message=request.error_message,
            framework=request.framework or artifact.framework,
            file_name=request.file_name,
            previous_context=merged_previous_context,
            use_retrieval=request.use_retrieval,
        )
# Shared module-level instance, created once when this module is imported.
context_manager = ContextManager()