codeT5Extension / app.py
KazeStudy's picture
Update app.py fithly
3587c1f
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoConfig
import torch
app = FastAPI(title="CodeT5+ Backend on HuggingFace")
# ==== LOAD MODEL ====
model_name = "Salesforce/codet5p-770m" # model đa ngôn ngữ, không fine-tune Python-only
print("Loading tokenizer + config...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
print("Loading model weights...")
model = T5ForConditionalGeneration.from_pretrained(
model_name,
config=config
)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on:", device)
model = model.to(device)
model.eval()
# ==== REQUEST / RESPONSE MODELS ====
class GenerateRequest(BaseModel):
prompt: str # mô tả cần sinh code (nên gửi tiếng Anh)
language: str | None = "Python"
max_new_tokens: int = 128
num_beams: int = 1 # ít beam cho ổn định
temperature: float = 0.3 # giảm randomness
class FixRequest(BaseModel):
code: str # code bị lỗi
language: str | None = "Python"
max_new_tokens: int = 128
num_beams: int = 1
temperature: float = 0.2 # thấp để sửa lỗi ổn định hơn
class CompleteRequest(BaseModel):
prefix: str # code phía trước con trỏ
suffix: str = "" # code phía sau con trỏ (chưa dùng nhiều, vì Codet5 không phải infill)
language: str | None = "Python"
max_new_tokens: int = 64 # completion thường ngắn
num_beams: int = 1
temperature: float = 0.3
class CodeResponse(BaseModel):
output: str
# ==== TIỆN ÍCH DÙNG CHUNG ====
def run_model(prompt: str,
max_new_tokens: int,
num_beams: int,
temperature: float) -> str:
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
temperature=temperature,
early_stopping=True,
repetition_penalty=1.05, # nhẹ để giảm lặp
)
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return text.strip()
# ==== ENDPOINT 1: TẠO CODE TỪ PROMPT ====
@app.post("/generate-code", response_model=CodeResponse)
def generate_code(req: GenerateRequest):
"""
Sinh code từ mô tả.
Lưu ý: Codet5+ “thích” prompt ngắn, dạng pattern.
"""
lang = req.language or "Python"
# Prompt cực ngắn, đúng style CodeT5 (tránh essay dài)
# Ví dụ: "Python code:\n# Task: Create a function that prints numbers from 1 to 10.\n"
prompt = f"{lang} code:\n# Task: {req.prompt}\n"
output = run_model(
prompt,
max_new_tokens=req.max_new_tokens,
num_beams=req.num_beams,
temperature=req.temperature,
)
return CodeResponse(output=output)
# ==== ENDPOINT 2: SỬA LỖI CODE ====
@app.post("/fix-code", response_model=CodeResponse)
def fix_code(req: FixRequest):
"""
Sửa lỗi code: input là code sai, output là code đúng.
"""
lang = req.language or "Python"
# Cũng giữ prompt thật đơn giản
prompt = (
f"Fix the following {lang} code:\n"
f"{req.code}\n\n"
f"Fixed {lang} code:\n"
)
output = run_model(
prompt,
max_new_tokens=req.max_new_tokens,
num_beams=req.num_beams,
temperature=req.temperature,
)
return CodeResponse(output=output)
# ==== ENDPOINT 3: GỢI Ý CODE (KIỂU CURSOR – DÙ CHỈ DÙNG PREFIX) ====
@app.post("/complete-code", response_model=CodeResponse)
def complete_code(req: CompleteRequest):
"""
Gợi ý code tiếp theo dựa trên prefix.
Lưu ý: Codet5p-770m không phải model infill thực sự,
nên suffix ít tác dụng. Ở đây ta dùng chủ yếu prefix.
"""
lang = req.language or "Python"
# Dùng prefix làm context, để model tiếp tục code.
# Suffix có thể dùng để hiển thị phía client, còn model chủ yếu nhìn prefix.
prompt = f"{lang} code:\n{req.prefix}"
output = run_model(
prompt,
max_new_tokens=req.max_new_tokens,
num_beams=req.num_beams,
temperature=req.temperature,
)
return CodeResponse(output=output)
# ==== HEALTHCHECK ====
@app.get("/")
def root():
return {"status": "CodeT5+ backend is running 🚀"}