from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoConfig
import torch

app = FastAPI(title="CodeT5+ Backend on HuggingFace")

# ==== LOAD MODEL ====
model_name = "Salesforce/codet5p-770m"  # model đa ngôn ngữ, không fine-tune Python-only

print("Loading tokenizer + config...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)

print("Loading model weights...")
model = T5ForConditionalGeneration.from_pretrained(
    model_name,
    config=config
)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on:", device)

model = model.to(device)
model.eval()

# ==== REQUEST / RESPONSE MODELS ====


class GenerateRequest(BaseModel):
    prompt: str                     # description of the code to generate (English recommended)
    language: str | None = "Python"
    max_new_tokens: int = 128
    num_beams: int = 1              # few beams for more stable output
    temperature: float = 0.3        # low to reduce randomness


class FixRequest(BaseModel):
    code: str                       # the buggy code
    language: str | None = "Python"
    max_new_tokens: int = 128
    num_beams: int = 1
    temperature: float = 0.2        # kept low so fixes are more deterministic


class CompleteRequest(BaseModel):
    prefix: str                     # code before the cursor
    suffix: str = ""                # code after the cursor (largely unused: CodeT5 is not an infilling model)
    language: str | None = "Python"
    max_new_tokens: int = 64        # completions are usually short
    num_beams: int = 1
    temperature: float = 0.3


class CodeResponse(BaseModel):
    output: str


# ==== SHARED UTILITIES ====


def run_model(prompt: str,
              max_new_tokens: int,
              num_beams: int,
              temperature: float) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=temperature > 0,  # temperature is ignored unless sampling is enabled
            temperature=temperature,
            early_stopping=True,
            repetition_penalty=1.05,  # mild penalty to reduce repetition
        )

    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return text.strip()
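

# Quick sanity check for run_model (a sketch with a hypothetical prompt;
# uncomment to verify generation works once the model has loaded):
# print(run_model("Python code:\n# Task: Return the sum of a list.\n",
#                 max_new_tokens=64, num_beams=1, temperature=0.3))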


# ==== ENDPOINT 1: GENERATE CODE FROM A PROMPT ====


@app.post("/generate-code", response_model=CodeResponse)
def generate_code(req: GenerateRequest):
    """
    Sinh code từ mô tả.
    Lưu ý: Codet5+ “thích” prompt ngắn, dạng pattern.
    """
    lang = req.language or "Python"

    # Prompt cực ngắn, đúng style CodeT5 (tránh essay dài)
    # Ví dụ: "Python code:\n# Task: Create a function that prints numbers from 1 to 10.\n"
    prompt = f"{lang} code:\n# Task: {req.prompt}\n"

    output = run_model(
        prompt,
        max_new_tokens=req.max_new_tokens,
        num_beams=req.num_beams,
        temperature=req.temperature,
    )

    return CodeResponse(output=output)
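

# Example request against /generate-code (a sketch with a hypothetical
# payload; assumes the server is listening on localhost:8000):
#
#   curl -X POST http://localhost:8000/generate-code \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Create a function that prints numbers from 1 to 10."}'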


# ==== ENDPOINT 2: FIX CODE ====


@app.post("/fix-code", response_model=CodeResponse)
def fix_code(req: FixRequest):
    """
    Sửa lỗi code: input là code sai, output là code đúng.
    """
    lang = req.language or "Python"

    # Cũng giữ prompt thật đơn giản
    prompt = (
        f"Fix the following {lang} code:\n"
        f"{req.code}\n\n"
        f"Fixed {lang} code:\n"
    )

    output = run_model(
        prompt,
        max_new_tokens=req.max_new_tokens,
        num_beams=req.num_beams,
        temperature=req.temperature,
    )

    return CodeResponse(output=output)
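

# Example request against /fix-code (hypothetical payload; the \n escapes
# become real newlines once the JSON is parsed):
#
#   curl -X POST http://localhost:8000/fix-code \
#        -H "Content-Type: application/json" \
#        -d '{"code": "def add(a, b):\n    return a - b"}'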


# ==== ENDPOINT 3: CODE COMPLETION (CURSOR-STYLE, PREFIX ONLY) ====


@app.post("/complete-code", response_model=CodeResponse)
def complete_code(req: CompleteRequest):
    """
    Gợi ý code tiếp theo dựa trên prefix.
    Lưu ý: Codet5p-770m không phải model infill thực sự,
    nên suffix ít tác dụng. Ở đây ta dùng chủ yếu prefix.
    """
    lang = req.language or "Python"

    # Dùng prefix làm context, để model tiếp tục code.
    # Suffix có thể dùng để hiển thị phía client, còn model chủ yếu nhìn prefix.
    prompt = f"{lang} code:\n{req.prefix}"

    output = run_model(
        prompt,
        max_new_tokens=req.max_new_tokens,
        num_beams=req.num_beams,
        temperature=req.temperature,
    )

    return CodeResponse(output=output)
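

# Example request against /complete-code (hypothetical payload):
#
#   curl -X POST http://localhost:8000/complete-code \
#        -H "Content-Type: application/json" \
#        -d '{"prefix": "def fib(n):\n    "}'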


# ==== HEALTHCHECK ====


@app.get("/")
def root():
    return {"status": "CodeT5+ backend is running 🚀"}