Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoConfig | |
| import torch | |
| app = FastAPI(title="CodeT5+ Backend on HuggingFace") | |
| # ==== LOAD MODEL ==== | |
| model_name = "Salesforce/codet5p-770m" # model đa ngôn ngữ, không fine-tune Python-only | |
| print("Loading tokenizer + config...") | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| config = AutoConfig.from_pretrained(model_name) | |
| print("Loading model weights...") | |
| model = T5ForConditionalGeneration.from_pretrained( | |
| model_name, | |
| config=config | |
| ) | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| print("Running on:", device) | |
| model = model.to(device) | |
| model.eval() | |
| # ==== REQUEST / RESPONSE MODELS ==== | |
| class GenerateRequest(BaseModel): | |
| prompt: str # mô tả cần sinh code (nên gửi tiếng Anh) | |
| language: str | None = "Python" | |
| max_new_tokens: int = 128 | |
| num_beams: int = 1 # ít beam cho ổn định | |
| temperature: float = 0.3 # giảm randomness | |
| class FixRequest(BaseModel): | |
| code: str # code bị lỗi | |
| language: str | None = "Python" | |
| max_new_tokens: int = 128 | |
| num_beams: int = 1 | |
| temperature: float = 0.2 # thấp để sửa lỗi ổn định hơn | |
| class CompleteRequest(BaseModel): | |
| prefix: str # code phía trước con trỏ | |
| suffix: str = "" # code phía sau con trỏ (chưa dùng nhiều, vì Codet5 không phải infill) | |
| language: str | None = "Python" | |
| max_new_tokens: int = 64 # completion thường ngắn | |
| num_beams: int = 1 | |
| temperature: float = 0.3 | |
| class CodeResponse(BaseModel): | |
| output: str | |
| # ==== TIỆN ÍCH DÙNG CHUNG ==== | |
| def run_model(prompt: str, | |
| max_new_tokens: int, | |
| num_beams: int, | |
| temperature: float) -> str: | |
| inputs = tokenizer(prompt, return_tensors="pt").to(device) | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| **inputs, | |
| max_new_tokens=max_new_tokens, | |
| num_beams=num_beams, | |
| temperature=temperature, | |
| early_stopping=True, | |
| repetition_penalty=1.05, # nhẹ để giảm lặp | |
| ) | |
| text = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| return text.strip() | |
| # ==== ENDPOINT 1: TẠO CODE TỪ PROMPT ==== | |
| def generate_code(req: GenerateRequest): | |
| """ | |
| Sinh code từ mô tả. | |
| Lưu ý: Codet5+ “thích” prompt ngắn, dạng pattern. | |
| """ | |
| lang = req.language or "Python" | |
| # Prompt cực ngắn, đúng style CodeT5 (tránh essay dài) | |
| # Ví dụ: "Python code:\n# Task: Create a function that prints numbers from 1 to 10.\n" | |
| prompt = f"{lang} code:\n# Task: {req.prompt}\n" | |
| output = run_model( | |
| prompt, | |
| max_new_tokens=req.max_new_tokens, | |
| num_beams=req.num_beams, | |
| temperature=req.temperature, | |
| ) | |
| return CodeResponse(output=output) | |
| # ==== ENDPOINT 2: SỬA LỖI CODE ==== | |
| def fix_code(req: FixRequest): | |
| """ | |
| Sửa lỗi code: input là code sai, output là code đúng. | |
| """ | |
| lang = req.language or "Python" | |
| # Cũng giữ prompt thật đơn giản | |
| prompt = ( | |
| f"Fix the following {lang} code:\n" | |
| f"{req.code}\n\n" | |
| f"Fixed {lang} code:\n" | |
| ) | |
| output = run_model( | |
| prompt, | |
| max_new_tokens=req.max_new_tokens, | |
| num_beams=req.num_beams, | |
| temperature=req.temperature, | |
| ) | |
| return CodeResponse(output=output) | |
| # ==== ENDPOINT 3: GỢI Ý CODE (KIỂU CURSOR – DÙ CHỈ DÙNG PREFIX) ==== | |
| def complete_code(req: CompleteRequest): | |
| """ | |
| Gợi ý code tiếp theo dựa trên prefix. | |
| Lưu ý: Codet5p-770m không phải model infill thực sự, | |
| nên suffix ít tác dụng. Ở đây ta dùng chủ yếu prefix. | |
| """ | |
| lang = req.language or "Python" | |
| # Dùng prefix làm context, để model tiếp tục code. | |
| # Suffix có thể dùng để hiển thị phía client, còn model chủ yếu nhìn prefix. | |
| prompt = f"{lang} code:\n{req.prefix}" | |
| output = run_model( | |
| prompt, | |
| max_new_tokens=req.max_new_tokens, | |
| num_beams=req.num_beams, | |
| temperature=req.temperature, | |
| ) | |
| return CodeResponse(output=output) | |
| # ==== HEALTHCHECK ==== | |
| def root(): | |
| return {"status": "CodeT5+ backend is running 🚀"} | |