File size: 6,277 Bytes
e7ed4e6
 
 
 
 
 
 
 
c46817f
e7ed4e6
 
c46817f
 
e7ed4e6
c46817f
e7ed4e6
c46817f
e7ed4e6
 
 
 
 
 
 
 
 
 
c46817f
e9cc738
e7ed4e6
c46817f
e7ed4e6
 
d60814d
 
e7ed4e6
c46817f
 
 
 
 
d60814d
 
c46817f
 
 
 
 
 
 
e9cc738
 
e7ed4e6
 
c46817f
 
e7ed4e6
 
c46817f
e7ed4e6
e9cc738
c46817f
 
 
 
e7ed4e6
 
 
 
 
c46817f
 
 
 
e9cc738
e7ed4e6
 
 
e9cc738
c46817f
 
d60814d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c46817f
 
e9cc738
c46817f
 
 
 
e9cc738
 
 
 
 
d60814d
 
e9cc738
 
 
 
d60814d
e9cc738
c46817f
 
 
 
 
 
 
 
d60814d
 
c46817f
 
 
 
 
e9cc738
c46817f
 
 
 
e9cc738
 
 
 
d60814d
e9cc738
 
 
 
 
 
c46817f
 
 
 
 
 
 
 
d60814d
 
c46817f
 
 
e9cc738
 
c46817f
 
 
 
 
e9cc738
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c46817f
 
 
 
 
 
 
 
d60814d
 
c46817f
e7ed4e6
c46817f
e7ed4e6
e9cc738
e7ed4e6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoConfig
import torch

# FastAPI application object; the route decorators further down register on it.
app = FastAPI(title="CodeT5+ Backend on HuggingFace")

# ==== LOAD MODEL ====
# Everything in this section runs once, at import time.
# Hugging Face model id used by every from_pretrained() call below.
model_name = "Salesforce/codet5p-770m"  # multilingual model, not fine-tuned for Python only

print("Loading tokenizer + config...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)

print("Loading model weights...")
model = T5ForConditionalGeneration.from_pretrained(
    model_name,
    config=config
)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on:", device)

# Move the weights to the chosen device and switch to evaluation mode;
# the model is only used for generation in this file.
model = model.to(device)
model.eval()

# ==== REQUEST / RESPONSE MODELS ====


class GenerateRequest(BaseModel):
    # Request body of POST /generate-code.
    prompt: str                     # description of the code to generate
    language: str | None = "Python"  # target language; None falls back to "Python" in the endpoint
    max_new_tokens: int = 128       # forwarded to run_model as the generation length limit
    num_beams: int = 1              # fewer beams, chosen for stability
    temperature: float = 0.3        # kept low with the intent of reducing randomness


class FixRequest(BaseModel):
    # Request body of POST /fix-code.
    code: str                       # the buggy code to repair
    language: str | None = "Python"  # language of the code; None falls back to "Python" in the endpoint
    max_new_tokens: int = 128       # forwarded to run_model as the generation length limit
    num_beams: int = 1              # forwarded to run_model as the beam width
    temperature: float = 0.2        # low, with the intent of making bug fixing more stable


class CompleteRequest(BaseModel):
    # Request body of POST /complete-code.
    prefix: str                     # code before the cursor
    suffix: str = ""                # code after the cursor (if any)
    language: str | None = "Python"  # language of the file; None falls back to "Python" in the endpoint
    max_new_tokens: int = 64        # completions are usually short
    num_beams: int = 1              # Cursor-style completion usually uses a single beam
    temperature: float = 0.3        # kept low with the intent of more stable output


class CodeResponse(BaseModel):
    # Response body shared by the three POST endpoints in this file.
    output: str  # text produced by the model (after any clean-up done by the endpoint)


# ==== SHARED UTILITIES ====


def run_model(prompt: str,
              max_new_tokens: int,
              num_beams: int,
              temperature: float) -> str:
    """Generate text for ``prompt`` with the module-level tokenizer and model.

    Args:
        prompt: Full text that is tokenized and fed to the model.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_beams: Beam width; ``1`` means no beam search.
        temperature: Sampling temperature. A value greater than zero turns
            sampling on with that temperature; zero or a negative value
            selects deterministic decoding.

    Returns:
        The first generated sequence, decoded without special tokens and
        stripped of surrounding whitespace.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
        "repetition_penalty": 1.05,   # mild penalty to reduce repetition
    }

    # `temperature` is a sampling-only setting: without do_sample=True it is
    # ignored by generate(), so the request parameter would have no effect.
    if temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = float(temperature)
    else:
        generation_kwargs["do_sample"] = False

    # `early_stopping` is a beam-search setting; only pass it when beam search
    # is actually in use.
    if num_beams > 1:
        generation_kwargs["early_stopping"] = True

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return text.strip()


def clean_code(raw: str, lang: str) -> str:
    """Drop leading non-code lines (e.g. ':' or 'program:') from model output.

    Only the lines that precede the first code-looking line are removed; the
    middle and the end of the text are left untouched. When no line looks like
    code, nothing is removed.

    Args:
        raw: Text produced by the model.
        lang: Target language name, matched case-insensitively. ``None``, an
            empty string, or an unrecognised name selects a generic heuristic.

    Returns:
        The text from the first code-looking line to the end, stripped of
        surrounding whitespace; ``raw.strip()`` if that would be empty.
    """
    lines = raw.splitlines()
    if not lines:
        return raw.strip()

    lang_low = (lang or "").strip().lower()

    # Line starts that mark a Python statement. Besides definitions and
    # imports this includes common top-level statements, so that real code at
    # the top of the output is not mistaken for junk.
    python_prefixes = (
        "def ", "class ", "import ", "from ", "#", "@", "async ",
        "if ", "for ", "while ", "with ", "try:", "return ", "raise ",
        "assert ", "print(", '"""', "'''",
    )
    # Line starts that mark C/C++ code: any preprocessor directive, comments,
    # and the usual declaration keywords.
    c_prefixes = (
        "#", "//", "/*", "int ", "void ", "char ", "float ", "double ",
        "long ", "short ", "unsigned ", "signed ", "bool ", "struct ",
        "typedef ", "enum ", "union ", "static ", "const ", "extern ",
        "using ", "namespace ", "template", "class ",
    )
    # Fallback markers for every other language.
    generic_markers = (";", "{", "}", "=", "function ", "public ", "private ")

    def looks_like_code(s: str) -> bool:
        s = s.strip()
        if not s:
            return False

        if lang_low in ("python", "python3", "py"):
            # Assignments and bare calls count as code too; "=" is the same
            # marker the generic fallback relies on.
            return (
                s.startswith(python_prefixes)
                or "=" in s
                or (s.endswith(")") and "(" in s)
            )
        if lang_low in ("c", "c++", "cpp"):
            return s.startswith(c_prefixes)
        return any(marker in s for marker in generic_markers)

    # Index of the first code-looking line; 0 (keep everything) if none match.
    start = next(
        (i for i, line in enumerate(lines) if looks_like_code(line)),
        0,
    )

    cleaned = "\n".join(lines[start:]).strip()
    return cleaned or raw.strip()


# ==== ENDPOINT 1: GENERATE CODE FROM A PROMPT ====


# POST /generate-code: wrap the task description in an instruction prompt for
# the requested language, run the model, and trim leading non-code lines from
# the result before returning it.
@app.post("/generate-code", response_model=CodeResponse)
def generate_code(req: GenerateRequest):
    # A missing/empty language falls back to Python.
    lang = req.language if req.language else "Python"

    instruction = f"""
You are a helpful coding assistant.

Generate ONLY valid {lang} source code for the task below.
Do NOT add any explanations, comments in natural language, or markdown.
Do NOT repeat the task description.
Return only raw {lang} code that can be run.

Task:
{req.prompt}

Begin {lang} code now:
""".strip()

    # Generation settings are taken straight from the request.
    settings = {
        "max_new_tokens": req.max_new_tokens,
        "num_beams": req.num_beams,
        "temperature": req.temperature,
    }
    generated = run_model(instruction, **settings)

    return CodeResponse(output=clean_code(generated, lang))


# ==== ENDPOINT 2: FIX BUGGY CODE ====


# POST /fix-code: ask the model for a corrected version of the submitted code
# and trim leading non-code lines from the result before returning it.
@app.post("/fix-code", response_model=CodeResponse)
def fix_code(req: FixRequest):
    # A missing/empty language falls back to Python.
    lang = req.language if req.language else "Python"

    instruction = f"""
The following {lang} code contains bugs.
Fix all bugs and return ONLY the corrected {lang} code.
Do NOT add any explanations or comments in natural language.
Do NOT change the language or rewrite the task.

Buggy {lang} code:
{req.code}

Corrected {lang} code:
""".strip()

    # Generation settings are taken straight from the request.
    settings = {
        "max_new_tokens": req.max_new_tokens,
        "num_beams": req.num_beams,
        "temperature": req.temperature,
    }
    repaired = run_model(instruction, **settings)

    return CodeResponse(output=clean_code(repaired, lang))


# ==== ENDPOINT 3: CURSOR-STYLE CODE SUGGESTIONS (COMPLETION) ====


# POST /complete-code: give the model the text before and after the cursor and
# return what it generates for the gap between them.
@app.post("/complete-code", response_model=CodeResponse)
def complete_code(req: CompleteRequest):
    # A missing/empty language falls back to Python.
    lang = req.language if req.language else "Python"

    instruction = f"""
You are an AI code completion engine like Cursor or GitHub Copilot.

You will be given the prefix and suffix of a {lang} file.
Your task is to generate ONLY the missing {lang} code between them.

Rules:
- DO NOT repeat the prefix.
- DO NOT repeat the suffix.
- DO NOT add any explanations, natural language text, or markdown.
- DO NOT add imports/includes if they already appear in the prefix.
- Return ONLY raw {lang} code that can be directly inserted at the cursor.

Prefix:
{req.prefix}

<CURSOR HERE>

Suffix:
{req.suffix}

Missing {lang} code:
""".strip()

    # Generation settings are taken straight from the request.
    settings = {
        "max_new_tokens": req.max_new_tokens,
        "num_beams": req.num_beams,
        "temperature": req.temperature,
    }
    snippet = run_model(instruction, **settings)

    # Completions are usually short snippets, so clean_code is deliberately
    # skipped here to avoid trimming them by mistake.
    snippet = snippet.strip()
    return CodeResponse(output=snippet)


# ==== HEALTHCHECK ====


@app.get("/")
def root():
    # Health check: a fixed payload telling the caller the service is up.
    health = {"status": "CodeT5+ backend is running 🚀"}
    return health