File size: 2,268 Bytes
7838fc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import re
from schemas import CodeTaskType


# Keywords signalling a request to produce new code. Each keyword is wrapped
# in \b word boundaries, so e.g. "rebuild" does not count as "build".
GENERATE_PATTERNS = [
    rf"\b{keyword}\b"
    for keyword in (
        "create",
        "generate",
        "build",
        "write",
        "make",
        "develop",
        "implement",
    )
]

# Keywords/phrases signalling that the user wants something repaired.
# NOTE(review): only the exact word forms listed here match, because each is
# bounded by \b -- e.g. r"\berror\b" does not match "errors" and
# r"\bfailed\b" does not match "fails".
FIX_PATTERNS = [
    rf"\b{keyword}\b"
    for keyword in (
        "fix",
        "solve",
        "correct",
        "repair",
        "not working",
        "issue",
        "bug",
        "error",
        "exception",
        "crash",
        "failed",
        "problem",
    )
]

# Keywords/phrases signalling a request to explain existing code. The
# multi-word phrases match literally (single spaces), which works because
# detect_task_type lowercases the message and collapses whitespace first.
EXPLAIN_PATTERNS = [
    rf"\b{phrase}\b"
    for phrase in (
        "explain",
        "what does this do",
        "how does this work",
        "describe",
        "meaning",
        "understand",
    )
]

# Keywords/phrases signalling a request to restructure or improve code.
# "make this cleaner" is listed explicitly because r"\bclean\b" does not match
# "cleaner"; "make this better" is already covered by r"\bbetter\b" but is
# kept for parity with the original list.
REFACTOR_PATTERNS = [
    rf"\b{phrase}\b"
    for phrase in (
        "refactor",
        "clean",
        "improve",
        "optimize",
        "better",
        "make this cleaner",
        "make this better",
    )
]

# Keywords/phrases signalling a request for a code review. "code review" is
# already subsumed by r"\breview\b" but is kept for parity with the original.
REVIEW_PATTERNS = [
    rf"\b{phrase}\b"
    for phrase in (
        "review",
        "check this code",
        "audit",
        "inspect",
        "code review",
    )
]


def normalize_text(text: str) -> str:
    """Return *text* lowercased, trimmed, with whitespace runs collapsed.

    ``None`` or an empty string yields ``""``. Any run of whitespace
    characters (spaces, tabs, newlines, ...) inside the text becomes a single
    space, so multi-word patterns can be matched with plain single spaces.
    """
    if not text:
        return ""
    # Lowercasing never creates or removes whitespace, so collapsing first
    # and lowercasing afterwards gives the same result as the reverse order.
    single_spaced = re.sub(r"\s+", " ", text.strip())
    return single_spaced.lower()


def contains_pattern(text: str, patterns: list[str]) -> bool:
    """Return True if any regex in *patterns* matches anywhere in *text*.

    Patterns are tried in order and the search stops at the first hit; an
    empty *patterns* list never matches.
    """
    for candidate in patterns:
        if re.search(candidate, text) is not None:
            return True
    return False


def detect_task_type(
    message: str,
    code: str | None = None,
    error_message: str | None = None,
    mode_hint: CodeTaskType | None = None,
) -> CodeTaskType:
    """Classify a coding request into a ``CodeTaskType``.

    Resolution order (first applicable rule wins):

    1. An explicit ``mode_hint`` other than ``CodeTaskType.UNKNOWN``.
    2. A non-blank ``error_message`` -> ``FIX``.
    3. Keyword patterns on the normalized ``message``, checked in priority
       order FIX, EXPLAIN, REFACTOR, REVIEW, GENERATE.
    4. Non-blank ``code`` with no keyword match -> ``EXPLAIN``.
    5. Otherwise ``UNKNOWN``.

    Args:
        message: Free-text user request. ``None``/empty is tolerated because
            it is passed through ``normalize_text``.
        code: Optional code snippet sent with the request.
        error_message: Optional error text sent with the request.
        mode_hint: Optional caller-chosen task type that overrides detection.

    Returns:
        The detected ``CodeTaskType``.
    """
    # Compare against None explicitly: a bare truthiness check would silently
    # drop a hint whose enum value happens to be falsy (e.g. 0 or "").
    if mode_hint is not None and mode_hint != CodeTaskType.UNKNOWN:
        return mode_hint

    if error_message and error_message.strip():
        return CodeTaskType.FIX

    normalized_message = normalize_text(message)

    # Priority matters: the first rule whose patterns match decides the type,
    # so e.g. "explain this bug" is classified as FIX, not EXPLAIN.
    keyword_rules = (
        (FIX_PATTERNS, CodeTaskType.FIX),
        (EXPLAIN_PATTERNS, CodeTaskType.EXPLAIN),
        (REFACTOR_PATTERNS, CodeTaskType.REFACTOR),
        (REVIEW_PATTERNS, CodeTaskType.REVIEW),
        (GENERATE_PATTERNS, CodeTaskType.GENERATE),
    )
    for patterns, task_type in keyword_rules:
        if contains_pattern(normalized_message, patterns):
            return task_type

    # Bare code with no recognizable instruction defaults to an explanation.
    # (An error message cannot be present here: that case returned above.)
    if code and code.strip():
        return CodeTaskType.EXPLAIN

    return CodeTaskType.UNKNOWN