File size: 5,281 Bytes
18ccd52
 
22e6a73
464bae0
 
 
 
1452e7a
7f49f70
22e6a73
464bae0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c42e84
 
 
 
 
 
464bae0
 
 
 
1452e7a
464bae0
 
22e6a73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464bae0
18ccd52
 
464bae0
07a6927
 
 
 
2472c6c
07a6927
22e6a73
 
 
 
 
 
 
 
464bae0
7f49f70
 
 
 
464bae0
 
 
7578d16
 
07a6927
7f49f70
7578d16
464bae0
 
 
7f49f70
 
 
 
 
 
464bae0
 
 
 
 
18ccd52
 
7f49f70
464bae0
 
 
 
7578d16
 
18ccd52
7f49f70
d74d9aa
464bae0
 
7f49f70
 
 
 
 
 
 
 
 
 
 
 
464bae0
18ccd52
464bae0
18ccd52
464bae0
 
 
18ccd52
464bae0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import logging
import time

from schemas import CodeXRequest, CodeXResponse, CodeTaskType, ResponseMeta
from code_router import detect_task_type
from prompt_builder import build_prompt
from model_client import model_client
from response_formatter import build_response, build_error_response
from code_retriever import retrieve_code_evidence
from context_manager import context_manager
from scope_router import decide_scope, ScopeRoute
from config import settings


def should_use_retrieval(task_type: CodeTaskType, request: CodeXRequest) -> bool:
    """Decide whether code-evidence retrieval should run for this request.

    Two conditions must both hold:

    * the request itself opted in (``request.use_retrieval`` is truthy), and
    * the ``settings`` flag for the request's task type is truthy.

    Only FIX, GENERATE, EXPLAIN, REVIEW and REFACTOR have a flag; any other
    task type never uses retrieval.
    """
    if request.use_retrieval:
        # Ordered (task type, name of the settings flag enabling retrieval).
        # Matching uses ``==`` and reads a flag only for the matching task type.
        retrieval_flags = (
            (CodeTaskType.FIX, "ENABLE_RETRIEVAL_FOR_FIX"),
            (CodeTaskType.GENERATE, "ENABLE_RETRIEVAL_FOR_GENERATE"),
            (CodeTaskType.EXPLAIN, "ENABLE_RETRIEVAL_FOR_EXPLAIN"),
            (CodeTaskType.REVIEW, "ENABLE_RETRIEVAL_FOR_REVIEW"),
            (CodeTaskType.REFACTOR, "ENABLE_RETRIEVAL_FOR_REFACTOR"),
        )
        return any(
            task_type == candidate and getattr(settings, flag_name)
            for candidate, flag_name in retrieval_flags
        )

    return False


def get_retrieved_evidence(task_type: CodeTaskType, request: CodeXRequest):
    """Fetch supporting code evidence for ``request``.

    Thin wrapper around ``retrieve_code_evidence``: the task type and request
    are forwarded positionally and the retriever's result is returned as-is.
    """
    evidence = retrieve_code_evidence(task_type, request)
    return evidence


def build_scope_response(
    route: ScopeRoute,
    message: str,
    processing_time_ms: int,
) -> CodeXResponse:
    """Build the response returned when the scope router stops a request.

    The router's ``message`` becomes both ``answer`` and ``explanation``.
    Exactly one warning describing ``route`` is attached; routes other than
    RESTRICTED, IMAGE, NON_CODE and GREETING get a generic warning. The
    response is attributed to ``"scope_router"``, with task type UNKNOWN,
    no code output, no sources, no fallback and no retrieval.
    """
    route_warnings = (
        (ScopeRoute.RESTRICTED, "Restricted or unsafe request detected."),
        (ScopeRoute.IMAGE, "Image-related request detected."),
        (ScopeRoute.NON_CODE, "Out-of-scope non-code request detected."),
        (ScopeRoute.GREETING, "Greeting request detected."),
    )
    # First ``==`` match wins, mirroring an if/elif chain (a dict lookup would
    # rely on hashing, which may differ from ``==`` for mixed-in enums).
    warning_text = next(
        (text for known_route, text in route_warnings if route == known_route),
        "Unsupported or unclear request detected.",
    )

    meta = ResponseMeta(
        used_model="scope_router",
        fallback_used=False,
        retrieval_used=False,
        source_count=0,
        processing_time_ms=processing_time_ms,
    )

    return CodeXResponse(
        answer=message,
        task_type=CodeTaskType.UNKNOWN,
        code_output=None,
        explanation=message,
        warnings=[warning_text],
        sources=[],
        needs_clarification=False,
        meta=meta,
    )


_logger = logging.getLogger(__name__)


def _elapsed_ms(start_time: float) -> int:
    """Return whole milliseconds elapsed since ``start_time`` (a ``time.perf_counter()`` reading)."""
    return int((time.perf_counter() - start_time) * 1000)


def process_codex_request(request: CodeXRequest) -> CodeXResponse:
    """Run one CodeX request through the full pipeline.

    Steps:

    1. Enrich the request with stored session context.
    2. Ask the scope router whether it should reach the code model at all.
    3. Detect the task type.
    4. Optionally retrieve supporting code evidence.
    5. Build the prompt, call the model and format the response.
    6. If the model produced non-blank code, persist it as a session artifact.

    Args:
        request: The incoming request.

    Returns:
        A ``CodeXResponse``. A scope-router rejection yields the result of
        ``build_scope_response``. Any exception raised while producing the
        response is logged and converted into a ``build_error_response``
        result rather than propagated. A failure while saving the artifact
        is logged and does not discard the already-built response.
    """
    start_time = time.perf_counter()
    # Set once detection succeeds so an error response can report the real
    # task type instead of only the caller's mode hint.
    task_type = None

    try:
        enriched_request = context_manager.enrich_request_with_context(
            request=request,
            session_id=request.session_id,
        )

        scope_decision = decide_scope(enriched_request)
        if not scope_decision.should_continue_to_codex:
            return build_scope_response(
                route=scope_decision.route,
                message=scope_decision.message,
                processing_time_ms=_elapsed_ms(start_time),
            )

        task_type = detect_task_type(
            message=enriched_request.message,
            code=enriched_request.code,
            error_message=enriched_request.error_message,
            mode_hint=enriched_request.mode,
        )

        evidence_list = []
        if should_use_retrieval(task_type, enriched_request):
            # Treat a retriever that returns None as "no evidence" rather than
            # crashing on len() below.
            evidence_list = get_retrieved_evidence(task_type, enriched_request) or []
        retrieval_used = bool(evidence_list)

        prompt = build_prompt(
            task_type=task_type,
            message=enriched_request.message,
            code=enriched_request.code,
            error_message=enriched_request.error_message,
            language=enriched_request.language,
            framework=enriched_request.framework,
            previous_context=enriched_request.previous_context,
            evidence_list=evidence_list,
        )

        model_output, model_used, used_fallback = model_client.generate(prompt)

        response = build_response(
            task_type=task_type,
            model_output=model_output,
            model_used=model_used,
            used_fallback=used_fallback,
            retrieval_used=retrieval_used,
            source_count=len(evidence_list),
            processing_time_ms=_elapsed_ms(start_time),
            original_code=enriched_request.code,
            sources=evidence_list,
        )

    except Exception as e:
        # Top-level boundary: keep the traceback in the logs, then convert the
        # failure into a structured error response for the caller.
        _logger.exception("CodeX request processing failed")
        if task_type is not None:
            fallback_task = task_type
        else:
            fallback_task = request.mode if request.mode else CodeTaskType.UNKNOWN

        return build_error_response(
            task_type=fallback_task,
            # Exceptions raised without arguments stringify to "", so fall back to the class name.
            error_message=str(e) or type(e).__name__,
            processing_time_ms=_elapsed_ms(start_time),
        )

    # Saving the artifact is best-effort: a storage failure must not replace
    # an answer the model already produced with an error response.
    if response.code_output and response.code_output.strip():
        try:
            context_manager.save_artifact(
                session_id=enriched_request.session_id,
                code=response.code_output,
                task_type=task_type,
                language=enriched_request.language,
                framework=enriched_request.framework,
                last_user_goal=enriched_request.message,
            )
        except Exception:
            _logger.exception(
                "Failed to save code artifact for session %s",
                enriched_request.session_id,
            )

    return response