# FastAPI backend serving Salesforce/codet5p-770m for code generation,
# code repair, and code completion (deployed as a HuggingFace Space).
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoConfig
import torch
# FastAPI application instance; the endpoints below are registered on it.
app = FastAPI(title="CodeT5+ Backend on HuggingFace")
# ==== LOAD MODEL ====
model_name = "Salesforce/codet5p-770m"  # multilingual model, not fine-tuned to Python only
print("Loading tokenizer + config...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
print("Loading model weights...")
model = T5ForConditionalGeneration.from_pretrained(
model_name,
config=config
)
# Prefer GPU when available; all inference below runs on this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on:", device)
model = model.to(device)
model.eval()  # inference-only mode (disables dropout, etc.)
# ==== REQUEST / RESPONSE MODELS ====
class GenerateRequest(BaseModel):
    # Request body for /generate-code.
    prompt: str  # description of the code to generate (English works best)
    language: str | None = "Python"
    max_new_tokens: int = 128
    num_beams: int = 1  # few beams, for stability
    temperature: float = 0.3  # low, to reduce randomness
class FixRequest(BaseModel):
    # Request body for /fix-code.
    code: str  # the buggy code to repair
    language: str | None = "Python"
    max_new_tokens: int = 128
    num_beams: int = 1
    temperature: float = 0.2  # lower, so fixes are more stable
class CompleteRequest(BaseModel):
    # Request body for /complete-code (cursor-style completion).
    prefix: str  # code before the cursor
    suffix: str = ""  # code after the cursor (largely unused: CodeT5 is not an infill model)
    language: str | None = "Python"
    max_new_tokens: int = 64  # completions are usually short
    num_beams: int = 1
    temperature: float = 0.3
class CodeResponse(BaseModel):
    # Shared response body: the model's generated / fixed / completed code.
    output: str
# ==== SHARED UTILITIES ====
def run_model(prompt: str,
              max_new_tokens: int,
              num_beams: int,
              temperature: float) -> str:
    """Tokenize *prompt*, generate with the global model, and return the text.

    Args:
        prompt: Fully formatted prompt string.
        max_new_tokens: Cap on the number of newly generated tokens.
        num_beams: Beam-search width (1 = greedy or pure sampling).
        temperature: Softmax temperature; only effective when sampling is
            enabled (see below) — otherwise transformers silently ignores it.

    Returns:
        The generated text with special tokens removed and surrounding
        whitespace stripped.
    """
    # Truncate over-long prompts so the encoder never exceeds the model's
    # maximum input length instead of erroring out.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    # BUG FIX: transformers ignores `temperature` unless `do_sample=True`;
    # the original call therefore always ran deterministic greedy/beam
    # search and the per-request temperature had no effect (only a warning
    # was emitted). Enable sampling whenever a non-default temperature is
    # requested, which is what the request models' low defaults intend.
    do_sample = temperature > 0.0 and temperature != 1.0
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
        "do_sample": do_sample,
        # early_stopping only applies to beam search; passing it with a
        # single beam triggers a transformers warning.
        "early_stopping": num_beams > 1,
        "repetition_penalty": 1.05,  # mild, to reduce repetition
    }
    if do_sample:
        gen_kwargs["temperature"] = temperature
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return text.strip()
# ==== ENDPOINT 1: GENERATE CODE FROM A PROMPT ====
@app.post("/generate-code", response_model=CodeResponse)
def generate_code(req: GenerateRequest):
    """
    Generate code from a description.
    Note: CodeT5+ responds best to short, pattern-style prompts.
    """
    language = req.language or "Python"
    # Keep the prompt minimal, in the style CodeT5 expects (no long essays), e.g.
    # "Python code:\n# Task: Create a function that prints numbers from 1 to 10.\n"
    task_prompt = f"{language} code:\n# Task: {req.prompt}\n"
    return CodeResponse(
        output=run_model(
            task_prompt,
            max_new_tokens=req.max_new_tokens,
            num_beams=req.num_beams,
            temperature=req.temperature,
        )
    )
# ==== ENDPOINT 2: FIX BUGGY CODE ====
@app.post("/fix-code", response_model=CodeResponse)
def fix_code(req: FixRequest):
    """
    Repair code: the input is broken code, the output is the corrected code.
    """
    language = req.language or "Python"
    # Keep the prompt deliberately simple.
    fix_prompt = (
        f"Fix the following {language} code:\n"
        f"{req.code}\n\n"
        f"Fixed {language} code:\n"
    )
    return CodeResponse(
        output=run_model(
            fix_prompt,
            max_new_tokens=req.max_new_tokens,
            num_beams=req.num_beams,
            temperature=req.temperature,
        )
    )
# ==== ENDPOINT 3: CODE COMPLETION (CURSOR-STYLE — PREFIX ONLY) ====
@app.post("/complete-code", response_model=CodeResponse)
def complete_code(req: CompleteRequest):
    """
    Suggest the next code based on the prefix.
    Note: codet5p-770m is not a true infill model, so the suffix has
    little effect; only the prefix is fed to the model here.
    """
    language = req.language or "Python"
    # The prefix is the model's context; the suffix is only useful
    # client-side for display.
    completion_prompt = f"{language} code:\n{req.prefix}"
    return CodeResponse(
        output=run_model(
            completion_prompt,
            max_new_tokens=req.max_new_tokens,
            num_beams=req.num_beams,
            temperature=req.temperature,
        )
    )
# ==== HEALTHCHECK ====
@app.get("/")
def root():
    """Liveness probe: confirms the service is up."""
    status_message = "CodeT5+ backend is running 🚀"
    return {"status": status_message}