Update code
Browse files- code_retriever.py +38 -10
code_retriever.py
CHANGED
|
@@ -54,36 +54,64 @@ def sort_evidence(evidence_list: List[RetrievedEvidence]) -> List[RetrievedEvide
|
|
| 54 |
def trim_evidence(evidence_list: List[RetrievedEvidence]) -> List[RetrievedEvidence]:
|
| 55 |
stack_items = [item for item in evidence_list if item.source_type == SourceType.STACKOVERFLOW]
|
| 56 |
github_items = [item for item in evidence_list if item.source_type == SourceType.GITHUB]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
stack_items = stack_items[:3]
|
| 59 |
-
github_items = github_items[:
|
| 60 |
|
| 61 |
-
combined = stack_items + github_items
|
| 62 |
combined = sort_evidence(combined)
|
| 63 |
|
| 64 |
return combined[: settings.MAX_RETRIEVED_ITEMS]
|
| 65 |
|
| 66 |
|
| 67 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
if not settings.ENABLE_GITHUB_SEARCH:
|
| 69 |
return False
|
| 70 |
|
| 71 |
-
if
|
| 72 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
-
if
|
| 75 |
-
return
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
return False
|
| 78 |
|
| 79 |
|
| 80 |
def retrieve_code_evidence(task_type: CodeTaskType, request: CodeXRequest) -> List[RetrievedEvidence]:
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
return []
|
| 83 |
|
| 84 |
collected: List[RetrievedEvidence] = []
|
| 85 |
|
| 86 |
-
if
|
| 87 |
stack_results = search_stackoverflow(
|
| 88 |
message=request.message,
|
| 89 |
error_message=request.error_message,
|
|
@@ -94,7 +122,7 @@ def retrieve_code_evidence(task_type: CodeTaskType, request: CodeXRequest) -> Li
|
|
| 94 |
)
|
| 95 |
collected.extend(stack_results)
|
| 96 |
|
| 97 |
-
if should_use_github(request):
|
| 98 |
github_results = search_github(
|
| 99 |
message=request.message,
|
| 100 |
error_message=request.error_message,
|
|
|
|
| 54 |
def trim_evidence(evidence_list: List[RetrievedEvidence]) -> List[RetrievedEvidence]:
|
| 55 |
stack_items = [item for item in evidence_list if item.source_type == SourceType.STACKOVERFLOW]
|
| 56 |
github_items = [item for item in evidence_list if item.source_type == SourceType.GITHUB]
|
| 57 |
+
other_items = [
|
| 58 |
+
item for item in evidence_list
|
| 59 |
+
if item.source_type not in {SourceType.STACKOVERFLOW, SourceType.GITHUB}
|
| 60 |
+
]
|
| 61 |
|
| 62 |
stack_items = stack_items[:3]
|
| 63 |
+
github_items = github_items[:3]
|
| 64 |
|
| 65 |
+
combined = stack_items + github_items + other_items
|
| 66 |
combined = sort_evidence(combined)
|
| 67 |
|
| 68 |
return combined[: settings.MAX_RETRIEVED_ITEMS]
|
| 69 |
|
| 70 |
|
| 71 |
+
def should_use_stack(task_type: CodeTaskType) -> bool:
|
| 72 |
+
if not settings.ENABLE_STACK_SEARCH:
|
| 73 |
+
return False
|
| 74 |
+
|
| 75 |
+
return task_type in {
|
| 76 |
+
CodeTaskType.FIX,
|
| 77 |
+
CodeTaskType.REVIEW,
|
| 78 |
+
CodeTaskType.REFACTOR,
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def should_use_github(task_type: CodeTaskType, request: CodeXRequest) -> bool:
|
| 83 |
if not settings.ENABLE_GITHUB_SEARCH:
|
| 84 |
return False
|
| 85 |
|
| 86 |
+
if task_type == CodeTaskType.FIX:
|
| 87 |
+
return bool(
|
| 88 |
+
(request.framework and request.framework.strip())
|
| 89 |
+
or (request.error_message and request.error_message.strip())
|
| 90 |
+
or (request.language and request.language.strip())
|
| 91 |
+
)
|
| 92 |
|
| 93 |
+
if task_type in {CodeTaskType.REVIEW, CodeTaskType.REFACTOR}:
|
| 94 |
+
return bool(
|
| 95 |
+
(request.framework and request.framework.strip())
|
| 96 |
+
or (request.language and request.language.strip())
|
| 97 |
+
)
|
| 98 |
|
| 99 |
return False
|
| 100 |
|
| 101 |
|
| 102 |
def retrieve_code_evidence(task_type: CodeTaskType, request: CodeXRequest) -> List[RetrievedEvidence]:
|
| 103 |
+
supported_tasks = {
|
| 104 |
+
CodeTaskType.FIX,
|
| 105 |
+
CodeTaskType.REVIEW,
|
| 106 |
+
CodeTaskType.REFACTOR,
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
if task_type not in supported_tasks:
|
| 110 |
return []
|
| 111 |
|
| 112 |
collected: List[RetrievedEvidence] = []
|
| 113 |
|
| 114 |
+
if should_use_stack(task_type):
|
| 115 |
stack_results = search_stackoverflow(
|
| 116 |
message=request.message,
|
| 117 |
error_message=request.error_message,
|
|
|
|
| 122 |
)
|
| 123 |
collected.extend(stack_results)
|
| 124 |
|
| 125 |
+
if should_use_github(task_type, request):
|
| 126 |
github_results = search_github(
|
| 127 |
message=request.message,
|
| 128 |
error_message=request.error_message,
|