"""FastAPI backend serving Salesforce CodeT5+ for code generation,
bug fixing, and cursor-style code completion."""

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoConfig
import torch

app = FastAPI(title="CodeT5+ Backend on HuggingFace")

# ==== LOAD MODEL ====
model_name = "Salesforce/codet5p-770m"  # multilingual model, not fine-tuned for Python only

print("Loading tokenizer + config...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)

print("Loading model weights...")
model = T5ForConditionalGeneration.from_pretrained(model_name, config=config)

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on:", device)
model = model.to(device)
model.eval()


# ==== REQUEST / RESPONSE MODELS ====
class GenerateRequest(BaseModel):
    # Natural-language description of the code to generate.
    prompt: str
    language: str | None = "Python"
    max_new_tokens: int = 128
    num_beams: int = 1        # fewer beams for stability
    temperature: float = 0.3  # low randomness


class FixRequest(BaseModel):
    # Buggy source code to repair.
    code: str
    language: str | None = "Python"
    max_new_tokens: int = 128
    num_beams: int = 1
    temperature: float = 0.2  # low so fixes are more deterministic


class CompleteRequest(BaseModel):
    # Code before the cursor.
    prefix: str
    # Code after the cursor (if any).
    suffix: str = ""
    language: str | None = "Python"
    max_new_tokens: int = 64  # completions are usually short
    num_beams: int = 1        # Cursor-style completion typically uses a single beam
    temperature: float = 0.3


class CodeResponse(BaseModel):
    output: str


# ==== SHARED HELPERS ====
def run_model(prompt: str, max_new_tokens: int, num_beams: int, temperature: float) -> str:
    """Tokenize `prompt`, run generation on the loaded model, and return the decoded text.

    Fixes vs. the original implementation:
    - `temperature` was passed while `do_sample` stayed False, so transformers
      silently ignored it (with a warning). Sampling is now enabled whenever
      temperature > 0 and beam search is not requested, so the per-endpoint
      temperature settings actually take effect.
    - `early_stopping` is only meaningful for beam search; it is now sent only
      when `num_beams > 1`, silencing the spurious warning for the default
      single-beam case.
    """
    # Truncate over-long prompts instead of overflowing the encoder context.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "num_beams": num_beams,
        "repetition_penalty": 1.05,  # mild, just to reduce repetition
    }
    if num_beams > 1:
        gen_kwargs["early_stopping"] = True
    elif temperature and temperature > 0:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = temperature

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()


def clean_code(raw: str, lang: str) -> str:
    """Drop junk leading lines from the model output (e.g. ':', 'program:')
    so the result looks like clean code.

    Only the head of the text is trimmed; the middle and end are untouched.
    Falls back to the raw (stripped) text when nothing code-like is found.
    """
    lines = raw.splitlines()
    if not lines:
        return raw.strip()

    lang_low = (lang or "").lower()

    def looks_like_code(s: str) -> bool:
        s = s.strip()
        if not s:
            return False
        if lang_low == "python":
            # Python code usually starts with import/def/class/# comment.
            prefixes = ("def ", "class ", "import ", "from ", "#", "@")
            return s.startswith(prefixes)
        elif lang_low in ("c", "c++", "cpp"):
            prefixes = ("#include", "int ", "void ", "char ", "float ",
                        "double ", "struct ", "typedef ")
            return s.startswith(prefixes)
        else:
            # Fallback heuristic for other languages.
            return any(ch in s for ch in (";", "{", "}", "=", "function ", "public ", "private "))

    start = 0
    for i, line in enumerate(lines):
        if looks_like_code(line):
            start = i
            break

    cleaned = "\n".join(lines[start:]).strip()
    return cleaned if cleaned else raw.strip()


# ==== ENDPOINT 1: GENERATE CODE FROM A PROMPT ====
@app.post("/generate-code", response_model=CodeResponse)
def generate_code(req: GenerateRequest):
    """Generate source code in `req.language` from a natural-language task."""
    lang = req.language or "Python"
    prompt = f"""
You are a helpful coding assistant.
Generate ONLY valid {lang} source code for the task below.
Do NOT add any explanations, comments in natural language, or markdown.
Do NOT repeat the task description.
Return only raw {lang} code that can be run.

Task: {req.prompt}

Begin {lang} code now:
""".strip()

    output = run_model(
        prompt,
        max_new_tokens=req.max_new_tokens,
        num_beams=req.num_beams,
        temperature=req.temperature,
    )
    output = clean_code(output, lang)
    return CodeResponse(output=output)


# ==== ENDPOINT 2: FIX BUGGY CODE ====
@app.post("/fix-code", response_model=CodeResponse)
def fix_code(req: FixRequest):
    """Return a corrected version of the buggy code in `req.code`."""
    lang = req.language or "Python"
    prompt = f"""
The following {lang} code contains bugs.
Fix all bugs and return ONLY the corrected {lang} code.
Do NOT add any explanations or comments in natural language.
Do NOT change the language or rewrite the task.

Buggy {lang} code:
{req.code}

Corrected {lang} code:
""".strip()

    output = run_model(
        prompt,
        max_new_tokens=req.max_new_tokens,
        num_beams=req.num_beams,
        temperature=req.temperature,
    )
    output = clean_code(output, lang)
    return CodeResponse(output=output)


# ==== ENDPOINT 3: CURSOR-STYLE CODE COMPLETION ====
@app.post("/complete-code", response_model=CodeResponse)
def complete_code(req: CompleteRequest):
    """Fill in the code between `req.prefix` and `req.suffix` at the cursor."""
    lang = req.language or "Python"
    prompt = f"""
You are an AI code completion engine like Cursor or GitHub Copilot.
You will be given the prefix and suffix of a {lang} file.
Your task is to generate ONLY the missing {lang} code between them.

Rules:
- DO NOT repeat the prefix.
- DO NOT repeat the suffix.
- DO NOT add any explanations, natural language text, or markdown.
- DO NOT add imports/includes if they already appear in the prefix.
- Return ONLY raw {lang} code that can be directly inserted at the cursor.

Prefix:
{req.prefix}

Suffix:
{req.suffix}

Missing {lang} code:
""".strip()

    output = run_model(
        prompt,
        max_new_tokens=req.max_new_tokens,
        num_beams=req.num_beams,
        temperature=req.temperature,
    )
    # Completions are short snippets; skip clean_code to avoid trimming valid code.
    return CodeResponse(output=output.strip())


# ==== HEALTHCHECK ====
@app.get("/")
def root():
    """Liveness probe."""
    return {"status": "CodeT5+ backend is running 🚀"}