# codeT5Extension / app.py
# KazeStudy's picture
# Update app.py secondly
# c46817f
# raw
# history blame
# 4.13 kB
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoConfig
import torch
app = FastAPI(title="CodeT5+ Backend on HuggingFace")
# ==== LOAD MODEL ====
model_name = "Salesforce/codet5p-770m" # model đa ngôn ngữ, không fine-tune Python-only
print("Loading tokenizer + config...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
print("Loading model weights...")
model = T5ForConditionalGeneration.from_pretrained(
model_name,
config=config
)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on:", device)
model = model.to(device)
model.eval()
# ==== REQUEST / RESPONSE MODELS ====
class GenerateRequest(BaseModel):
prompt: str # mô tả cần sinh code
language: str | None = "Python"
max_new_tokens: int = 128
num_beams: int = 4
temperature: float = 0.7
class FixRequest(BaseModel):
code: str # code bị lỗi
language: str | None = "Python"
max_new_tokens: int = 128
num_beams: int = 4
temperature: float = 0.3 # thấp để sửa lỗi ổn định hơn
class CompleteRequest(BaseModel):
prefix: str # code phía trước con trỏ
suffix: str = "" # code phía sau con trỏ (nếu có)
language: str | None = "Python"
max_new_tokens: int = 64 # completion thường ngắn
num_beams: int = 4
temperature: float = 0.7
class CodeResponse(BaseModel):
output: str
# ==== SHARED UTILITIES ====
def run_model(prompt: str,
max_new_tokens: int,
num_beams: int,
temperature: float) -> str:
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
temperature=temperature,
early_stopping=True,
)
text = tokenizer.decode(outputs[0], skip_special_tokens=True)
return text
# ==== ENDPOINT 1: GENERATE CODE FROM A PROMPT ====
@app.post("/generate-code", response_model=CodeResponse)
def generate_code(req: GenerateRequest):
lang = req.language or "Python"
prompt = (
f"Generate {lang} code ONLY.\n"
f"Do NOT use any other programming language.\n\n"
f"Task:\n{req.prompt}\n\n"
f"{lang} code:\n"
)
output = run_model(
prompt,
max_new_tokens=req.max_new_tokens,
num_beams=req.num_beams,
temperature=req.temperature,
)
return CodeResponse(output=output)
# ==== ENDPOINT 2: FIX BUGGY CODE ====
@app.post("/fix-code", response_model=CodeResponse)
def fix_code(req: FixRequest):
lang = req.language or "Python"
prompt = (
f"The following {lang} code contains bugs.\n"
f"Fix all bugs and return ONLY the corrected {lang} code.\n\n"
f"Buggy {lang} code:\n{req.code}\n\n"
f"Corrected {lang} code:\n"
)
output = run_model(
prompt,
max_new_tokens=req.max_new_tokens,
num_beams=req.num_beams,
temperature=req.temperature,
)
return CodeResponse(output=output)
# ==== ENDPOINT 3: CODE SUGGESTION (COMPLETION) ====
@app.post("/complete-code", response_model=CodeResponse)
def complete_code(req: CompleteRequest):
lang = req.language or "Python"
# prefix + suffix giống kiểu Copilot completion
prompt = (
f"Complete the following {lang} code.\n"
f"Only generate the missing code between the prefix and suffix.\n\n"
f"Prefix:\n{req.prefix}\n\n"
f"Suffix:\n{req.suffix}\n\n"
f"Missing {lang} code:\n"
)
output = run_model(
prompt,
max_new_tokens=req.max_new_tokens,
num_beams=req.num_beams,
temperature=req.temperature,
)
return CodeResponse(output=output)
# ==== HEALTHCHECK ====
@app.get("/")
def root():
return {"status": "CodeT5+ backend is running 🚀"}