File size: 2,773 Bytes
18ccd52
 
464bae0
 
 
 
 
1452e7a
464bae0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1452e7a
464bae0
 
 
18ccd52
 
464bae0
 
 
 
 
 
 
 
 
7578d16
 
464bae0
 
7578d16
464bae0
 
 
 
 
 
 
 
 
 
 
 
 
 
18ccd52
 
464bae0
 
 
 
 
7578d16
 
18ccd52
1452e7a
d74d9aa
464bae0
 
 
18ccd52
464bae0
18ccd52
464bae0
 
 
18ccd52
464bae0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import logging
import time

from schemas import CodeXRequest, CodeXResponse, CodeTaskType
from code_router import detect_task_type
from prompt_builder import build_prompt
from model_client import model_client
from response_formatter import build_response, build_error_response
from code_retriever import retrieve_code_evidence
from config import settings


def should_use_retrieval(task_type: CodeTaskType, request: CodeXRequest) -> bool:
    """Decide whether code-evidence retrieval should run for this request.

    Retrieval needs both the per-request opt-in (``request.use_retrieval``)
    and the settings flag for the given task type. Only FIX, GENERATE and
    EXPLAIN have such a flag; any other task type never uses retrieval.

    Returns:
        ``True`` when both gates are open, otherwise ``False``.
    """
    if not request.use_retrieval:
        return False

    # (task type, deferred settings read) pairs, checked in order. A flag is
    # only read when its task type matches, like a short-circuit `and`.
    gated_tasks = (
        (CodeTaskType.FIX, lambda: settings.ENABLE_RETRIEVAL_FOR_FIX),
        (CodeTaskType.GENERATE, lambda: settings.ENABLE_RETRIEVAL_FOR_GENERATE),
        (CodeTaskType.EXPLAIN, lambda: settings.ENABLE_RETRIEVAL_FOR_EXPLAIN),
    )
    return any(
        task_type == kind and is_enabled()
        for kind, is_enabled in gated_tasks
    )


def get_retrieved_evidence(task_type: CodeTaskType, request: CodeXRequest):
    """Fetch supporting code evidence for ``request``.

    This is a thin indirection over ``retrieve_code_evidence``. Whatever that
    helper returns is passed back unchanged.
    """
    evidence = retrieve_code_evidence(task_type, request)
    return evidence


def process_codex_request(request: CodeXRequest) -> CodeXResponse:
    """Run one CodeX request end to end and return a response.

    Pipeline:
      1. detect the task type from the message, code, error and mode hint;
      2. optionally retrieve supporting code evidence;
      3. build the model prompt;
      4. call the model client;
      5. format the output with ``build_response``.

    Any exception raised along the way is logged with its traceback. It is
    then converted into an error response via ``build_error_response``
    instead of propagating to the caller.

    Args:
        request: The incoming CodeX request.

    Returns:
        The ``build_response`` result on success, or the
        ``build_error_response`` result on failure. Both receive the elapsed
        wall-clock time in whole milliseconds.
    """
    start_time = time.perf_counter()
    # Set as soon as detection succeeds. A later failure (retrieval, prompt
    # building, model call) can then report the detected task type instead
    # of falling back to the raw mode hint.
    task_type = None

    try:
        task_type = detect_task_type(
            message=request.message,
            code=request.code,
            error_message=request.error_message,
            mode_hint=request.mode,
        )

        evidence_list = []
        retrieval_used = False

        if should_use_retrieval(task_type, request):
            evidence_list = get_retrieved_evidence(task_type, request)
            # Retrieval only counts as "used" if it actually produced evidence.
            retrieval_used = len(evidence_list) > 0

        prompt = build_prompt(
            task_type=task_type,
            message=request.message,
            code=request.code,
            error_message=request.error_message,
            language=request.language,
            framework=request.framework,
            previous_context=request.previous_context,
            evidence_list=evidence_list,
        )

        model_output, model_used, used_fallback = model_client.generate(prompt)

        processing_time_ms = int((time.perf_counter() - start_time) * 1000)

        return build_response(
            task_type=task_type,
            model_output=model_output,
            model_used=model_used,
            used_fallback=used_fallback,
            retrieval_used=retrieval_used,
            source_count=len(evidence_list),
            processing_time_ms=processing_time_ms,
            original_code=request.code,
            sources=evidence_list,
        )

    except Exception as e:
        processing_time_ms = int((time.perf_counter() - start_time) * 1000)
        # This is the top-level boundary: keep the traceback in the logs,
        # since the client only receives str(e).
        logging.getLogger(__name__).exception("CodeX request processing failed")

        # Prefer the detected task type. Otherwise use the caller's mode
        # hint, and finally UNKNOWN.
        if task_type is not None:
            fallback_task = task_type
        elif request.mode:
            fallback_task = request.mode
        else:
            fallback_task = CodeTaskType.UNKNOWN

        return build_error_response(
            task_type=fallback_task,
            error_message=str(e),
            processing_time_ms=processing_time_ms,
        )