from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoConfig
import torch

app = FastAPI(title="CodeT5+ Backend on HuggingFace")

# ==== LOAD MODEL ====
model_name = "Salesforce/codet5p-770m"  # multilingual model, not fine-tuned for Python only
print("Loading tokenizer + config...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
print("Loading model weights...")
model = T5ForConditionalGeneration.from_pretrained(
    model_name,
    config=config,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on:", device)
model = model.to(device)
model.eval()

# ==== REQUEST / RESPONSE MODELS ====
class GenerateRequest(BaseModel):
    prompt: str                      # description of the code to generate
    language: str | None = "Python"
    max_new_tokens: int = 128
    num_beams: int = 4
    temperature: float = 0.7

class FixRequest(BaseModel):
    code: str                        # the buggy code
    language: str | None = "Python"
    max_new_tokens: int = 128
    num_beams: int = 4
    temperature: float = 0.3         # kept low so fixes are more deterministic

class CompleteRequest(BaseModel):
    prefix: str                      # code before the cursor
    suffix: str = ""                 # code after the cursor (if any)
    language: str | None = "Python"
    max_new_tokens: int = 64         # completions are usually short
    num_beams: int = 1               # Cursor-style completion typically uses a single beam
    temperature: float = 0.3         # lower for more stable output

class CodeResponse(BaseModel):
    output: str

# ==== SHARED UTILITIES ====
def run_model(prompt: str,
              max_new_tokens: int,
              num_beams: int,
              temperature: float) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=True,            # temperature has no effect unless sampling is enabled
            temperature=temperature,
            early_stopping=True,
            repetition_penalty=1.05,   # mild, to reduce repetition
        )
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return text.strip()
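
# Illustrative direct call (example values, not from the original file):
#   run_model("def fibonacci(n):", max_new_tokens=32, num_beams=1, temperature=0.3)
# returns the decoded continuation as a plain string.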

# ==== ENDPOINT 1: GENERATE CODE FROM A PROMPT ====
@app.post("/generate-code", response_model=CodeResponse)
def generate_code(req: GenerateRequest):
    lang = req.language or "Python"
    prompt = f"""
You are a helpful coding assistant.
Generate ONLY valid {lang} source code for the task below.
Do NOT add any explanations, comments in natural language, or markdown.
Return only raw {lang} code.
Task:
{req.prompt}
{lang} code:
""".strip()
    output = run_model(
        prompt,
        max_new_tokens=req.max_new_tokens,
        num_beams=req.num_beams,
        temperature=req.temperature,
    )
    return CodeResponse(output=output)
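
# Example request (assuming the app is served on the Hugging Face Spaces default port 7860):
#   curl -X POST http://localhost:7860/generate-code \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "write a function that reverses a string"}'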

# ==== ENDPOINT 2: FIX BUGGY CODE ====
@app.post("/fix-code", response_model=CodeResponse)
def fix_code(req: FixRequest):
    lang = req.language or "Python"
    prompt = f"""
The following {lang} code contains bugs.
Fix all bugs and return ONLY the corrected {lang} code.
Do NOT add any explanations or comments in natural language.
Buggy {lang} code:
{req.code}
Corrected {lang} code:
""".strip()
    output = run_model(
        prompt,
        max_new_tokens=req.max_new_tokens,
        num_beams=req.num_beams,
        temperature=req.temperature,
    )
    return CodeResponse(output=output)
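
# Example request (same assumed port; the payload mirrors FixRequest):
#   curl -X POST http://localhost:7860/fix-code \
#        -H "Content-Type: application/json" \
#        -d '{"code": "def add(a, b):\n    return a - b", "language": "Python"}'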

# ==== ENDPOINT 3: CURSOR-STYLE CODE COMPLETION ====
@app.post("/complete-code", response_model=CodeResponse)
def complete_code(req: CompleteRequest):
    lang = req.language or "Python"
    prompt = f"""
You are an AI code completion engine like Cursor or GitHub Copilot.
You will be given the prefix and suffix of a {lang} file.
Your task is to generate ONLY the missing {lang} code between them.
Rules:
- DO NOT repeat the prefix.
- DO NOT repeat the suffix.
- DO NOT add any explanations, natural language text, or markdown.
- DO NOT add imports/includes if they already appear in the prefix.
- Return ONLY raw {lang} code that can be directly inserted at the cursor.
Prefix:
{req.prefix}
<CURSOR HERE>
Suffix:
{req.suffix}
Missing {lang} code:
""".strip()
    output = run_model(
        prompt,
        max_new_tokens=req.max_new_tokens,
        num_beams=req.num_beams,
        temperature=req.temperature,
    )
    return CodeResponse(output=output)
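
# Example request (illustrative prefix/suffix; same assumed port):
#   curl -X POST http://localhost:7860/complete-code \
#        -H "Content-Type: application/json" \
#        -d '{"prefix": "def is_even(n):\n    return ", "suffix": "\n\nprint(is_even(4))"}'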

# ==== HEALTHCHECK ====
@app.get("/")
def root():
    return {"status": "CodeT5+ backend is running 🚀"}