hashan-7 commited on
Commit
4e93ece
·
verified ·
1 Parent(s): 73a75a4

add the code

Browse files
Files changed (1) hide show
  1. context_manager.py +133 -0
context_manager.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional
2
+
3
+ from schemas import CodeTaskType, CodeXRequest
4
+
5
+
6
+ FOLLOW_UP_PATTERNS = [
7
+ "update this",
8
+ "modify this",
9
+ "change this",
10
+ "improve this",
11
+ "refactor this",
12
+ "fix this",
13
+ "explain this",
14
+ "review this",
15
+ "add validation",
16
+ "add error handling",
17
+ "add comments",
18
+ "optimize this",
19
+ "continue this",
20
+ "continue from above",
21
+ "same code",
22
+ "this code",
23
+ "that code",
24
+ "the above code",
25
+ "use the previous code",
26
+ "based on the previous code",
27
+ ]
28
+
29
+
30
+ class SessionArtifact:
31
+ def __init__(
32
+ self,
33
+ code: Optional[str] = None,
34
+ task_type: Optional[CodeTaskType] = None,
35
+ language: Optional[str] = None,
36
+ framework: Optional[str] = None,
37
+ last_user_goal: Optional[str] = None,
38
+ ):
39
+ self.code = code
40
+ self.task_type = task_type
41
+ self.language = language
42
+ self.framework = framework
43
+ self.last_user_goal = last_user_goal
44
+
45
+
46
+ class ContextManager:
47
+ def __init__(self):
48
+ self._store: Dict[str, SessionArtifact] = {}
49
+
50
+ def normalize_session_id(self, session_id: Optional[str]) -> str:
51
+ if session_id and str(session_id).strip():
52
+ return str(session_id).strip()
53
+ return "default"
54
+
55
+ def get_artifact(self, session_id: Optional[str]) -> Optional[SessionArtifact]:
56
+ normalized_session_id = self.normalize_session_id(session_id)
57
+ return self._store.get(normalized_session_id)
58
+
59
+ def save_artifact(
60
+ self,
61
+ session_id: Optional[str],
62
+ code: Optional[str],
63
+ task_type: Optional[CodeTaskType] = None,
64
+ language: Optional[str] = None,
65
+ framework: Optional[str] = None,
66
+ last_user_goal: Optional[str] = None,
67
+ ) -> None:
68
+ if not code or not str(code).strip():
69
+ return
70
+
71
+ normalized_session_id = self.normalize_session_id(session_id)
72
+ self._store[normalized_session_id] = SessionArtifact(
73
+ code=code.strip(),
74
+ task_type=task_type,
75
+ language=language.strip() if language else None,
76
+ framework=framework.strip() if framework else None,
77
+ last_user_goal=last_user_goal.strip() if last_user_goal else None,
78
+ )
79
+
80
+ def has_artifact(self, session_id: Optional[str]) -> bool:
81
+ artifact = self.get_artifact(session_id)
82
+ return bool(artifact and artifact.code and artifact.code.strip())
83
+
84
+ def is_follow_up_request(self, request: CodeXRequest) -> bool:
85
+ if request.code and request.code.strip():
86
+ return False
87
+
88
+ message = (request.message or "").strip().lower()
89
+ if not message:
90
+ return False
91
+
92
+ return any(pattern in message for pattern in FOLLOW_UP_PATTERNS)
93
+
94
+ def enrich_request_with_context(
95
+ self,
96
+ request: CodeXRequest,
97
+ session_id: Optional[str],
98
+ ) -> CodeXRequest:
99
+ if request.code and request.code.strip():
100
+ return request
101
+
102
+ artifact = self.get_artifact(session_id)
103
+ if not artifact or not artifact.code:
104
+ return request
105
+
106
+ if not self.is_follow_up_request(request):
107
+ return request
108
+
109
+ enriched_language = request.language or artifact.language
110
+ enriched_framework = request.framework or artifact.framework
111
+
112
+ previous_context_parts = []
113
+ if artifact.last_user_goal:
114
+ previous_context_parts.append(f"Previous User Goal:\n{artifact.last_user_goal}")
115
+ if request.previous_context and request.previous_context.strip():
116
+ previous_context_parts.append(request.previous_context.strip())
117
+
118
+ merged_previous_context = "\n\n".join(previous_context_parts) if previous_context_parts else request.previous_context
119
+
120
+ return CodeXRequest(
121
+ message=request.message,
122
+ mode=request.mode,
123
+ language=enriched_language,
124
+ code=artifact.code,
125
+ error_message=request.error_message,
126
+ framework=enriched_framework,
127
+ file_name=request.file_name,
128
+ previous_context=merged_previous_context,
129
+ use_retrieval=request.use_retrieval,
130
+ )
131
+
132
+
133
+ context_manager = ContextManager()