hashan-7 commited on
Commit
4027d07
·
verified ·
1 Parent(s): 6c42e84

Update code

Browse files
Files changed (1) hide show
  1. code_retriever.py +38 -10
code_retriever.py CHANGED
@@ -54,36 +54,64 @@ def sort_evidence(evidence_list: List[RetrievedEvidence]) -> List[RetrievedEvide
54
  def trim_evidence(evidence_list: List[RetrievedEvidence]) -> List[RetrievedEvidence]:
55
  stack_items = [item for item in evidence_list if item.source_type == SourceType.STACKOVERFLOW]
56
  github_items = [item for item in evidence_list if item.source_type == SourceType.GITHUB]
 
 
 
 
57
 
58
  stack_items = stack_items[:3]
59
- github_items = github_items[:2]
60
 
61
- combined = stack_items + github_items
62
  combined = sort_evidence(combined)
63
 
64
  return combined[: settings.MAX_RETRIEVED_ITEMS]
65
 
66
 
67
- def should_use_github(request: CodeXRequest) -> bool:
 
 
 
 
 
 
 
 
 
 
 
68
  if not settings.ENABLE_GITHUB_SEARCH:
69
  return False
70
 
71
- if request.framework and request.framework.strip():
72
- return True
 
 
 
 
73
 
74
- if request.error_message and request.error_message.strip():
75
- return True
 
 
 
76
 
77
  return False
78
 
79
 
80
  def retrieve_code_evidence(task_type: CodeTaskType, request: CodeXRequest) -> List[RetrievedEvidence]:
81
- if task_type != CodeTaskType.FIX:
 
 
 
 
 
 
82
  return []
83
 
84
  collected: List[RetrievedEvidence] = []
85
 
86
- if settings.ENABLE_STACK_SEARCH:
87
  stack_results = search_stackoverflow(
88
  message=request.message,
89
  error_message=request.error_message,
@@ -94,7 +122,7 @@ def retrieve_code_evidence(task_type: CodeTaskType, request: CodeXRequest) -> Li
94
  )
95
  collected.extend(stack_results)
96
 
97
- if should_use_github(request):
98
  github_results = search_github(
99
  message=request.message,
100
  error_message=request.error_message,
 
54
  def trim_evidence(evidence_list: List[RetrievedEvidence]) -> List[RetrievedEvidence]:
55
  stack_items = [item for item in evidence_list if item.source_type == SourceType.STACKOVERFLOW]
56
  github_items = [item for item in evidence_list if item.source_type == SourceType.GITHUB]
57
+ other_items = [
58
+ item for item in evidence_list
59
+ if item.source_type not in {SourceType.STACKOVERFLOW, SourceType.GITHUB}
60
+ ]
61
 
62
  stack_items = stack_items[:3]
63
+ github_items = github_items[:3]
64
 
65
+ combined = stack_items + github_items + other_items
66
  combined = sort_evidence(combined)
67
 
68
  return combined[: settings.MAX_RETRIEVED_ITEMS]
69
 
70
 
71
+ def should_use_stack(task_type: CodeTaskType) -> bool:
72
+ if not settings.ENABLE_STACK_SEARCH:
73
+ return False
74
+
75
+ return task_type in {
76
+ CodeTaskType.FIX,
77
+ CodeTaskType.REVIEW,
78
+ CodeTaskType.REFACTOR,
79
+ }
80
+
81
+
82
+ def should_use_github(task_type: CodeTaskType, request: CodeXRequest) -> bool:
83
  if not settings.ENABLE_GITHUB_SEARCH:
84
  return False
85
 
86
+ if task_type == CodeTaskType.FIX:
87
+ return bool(
88
+ (request.framework and request.framework.strip())
89
+ or (request.error_message and request.error_message.strip())
90
+ or (request.language and request.language.strip())
91
+ )
92
 
93
+ if task_type in {CodeTaskType.REVIEW, CodeTaskType.REFACTOR}:
94
+ return bool(
95
+ (request.framework and request.framework.strip())
96
+ or (request.language and request.language.strip())
97
+ )
98
 
99
  return False
100
 
101
 
102
  def retrieve_code_evidence(task_type: CodeTaskType, request: CodeXRequest) -> List[RetrievedEvidence]:
103
+ supported_tasks = {
104
+ CodeTaskType.FIX,
105
+ CodeTaskType.REVIEW,
106
+ CodeTaskType.REFACTOR,
107
+ }
108
+
109
+ if task_type not in supported_tasks:
110
  return []
111
 
112
  collected: List[RetrievedEvidence] = []
113
 
114
+ if should_use_stack(task_type):
115
  stack_results = search_stackoverflow(
116
  message=request.message,
117
  error_message=request.error_message,
 
122
  )
123
  collected.extend(stack_results)
124
 
125
+ if should_use_github(task_type, request):
126
  github_results = search_github(
127
  message=request.message,
128
  error_message=request.error_message,