KazeStudy committed on
Commit
c46817f
·
1 Parent(s): 71dad29

Update app.py secondly

Browse files
Files changed (1) hide show
  1. app.py +108 -25
app.py CHANGED
@@ -6,16 +6,15 @@ import torch
6
app = FastAPI(title="CodeT5+ Backend on HuggingFace")

# ==== LOAD MODEL ====
# NOTE(review): base and "fine-tuned" checkpoints are the same hub id, so no
# actual fine-tuned weights are loaded here — confirm this is intentional.
base_ckpt = "Salesforce/codet5p-770m"
finetuned_ckpt = "Salesforce/codet5p-770m"

print("Loading tokenizer + config...")
tokenizer = AutoTokenizer.from_pretrained(base_ckpt)
config = AutoConfig.from_pretrained(base_ckpt)

print("Loading fine-tuned model weights...")
model = T5ForConditionalGeneration.from_pretrained(finetuned_ckpt, config=config)
21
 
@@ -26,45 +25,129 @@ model = model.to(device)
26
  model.eval()
27
 
28
# ==== REQUEST / RESPONSE MODELS ====

class GenerateRequest(BaseModel):
    # One code-generation request: the user prompt plus decoding knobs.
    prompt: str
    language: str | None = "Python"
    task: str = "generate"
    max_new_tokens: int = 128
    num_beams: int = 4
    temperature: float = 0.7


class GenerateResponse(BaseModel):
    # The single decoded model output returned to the caller.
    output: str
39
 
40
 
41
def build_prompt(req: GenerateRequest):
    """Build the raw text prompt fed to the model for *req*.

    Known tasks ("generate", "fix") get a task-specific template; any other
    task value falls back to the user's prompt unchanged.
    """
    task = req.task
    if task == "fix":
        return (
            f"Fix the bug in the following {req.language} code:"
            f"\n{req.prompt}\n\nCorrected code:"
        )
    if task == "generate":
        return f"Generate {req.language} code:\n{req.prompt}"
    return req.prompt
48
 
49
 
50
@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Run one generation pass for *req* and return the decoded text.

    NOTE(review): temperature only influences sampling; with num_beams set and
    no do_sample=True it is effectively ignored — confirm this is intended.
    """
    gen_kwargs = dict(
        max_new_tokens=req.max_new_tokens,
        num_beams=req.num_beams,
        temperature=req.temperature,
        early_stopping=True,
    )
    encoded = tokenizer(build_prompt(req), return_tensors="pt").to(device)

    with torch.no_grad():
        output_ids = model.generate(**encoded, **gen_kwargs)

    decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return GenerateResponse(output=decoded)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
 
68
 
69
  @app.get("/")
70
  def root():
 
6
app = FastAPI(title="CodeT5+ Backend on HuggingFace")

# ==== LOAD MODEL ====
# Multilingual checkpoint; not a Python-only fine-tune.
model_name = "Salesforce/codet5p-770m"

print("Loading tokenizer + config...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)

print("Loading model weights...")
model = T5ForConditionalGeneration.from_pretrained(model_name, config=config)
20
 
 
25
  model.eval()
26
 
27
# ==== REQUEST / RESPONSE MODELS ====

class GenerateRequest(BaseModel):
    # Natural-language description of the code to generate.
    prompt: str
    language: str | None = "Python"
    max_new_tokens: int = 128
    num_beams: int = 4
    temperature: float = 0.7
35
 
36
+
37
class FixRequest(BaseModel):
    # The buggy code to repair.
    code: str
    language: str | None = "Python"
    max_new_tokens: int = 128
    num_beams: int = 4
    # Lower default temperature for more stable fixes.
    temperature: float = 0.3
43
+
44
+
45
class CompleteRequest(BaseModel):
    # Code before the cursor.
    prefix: str
    # Code after the cursor, if any.
    suffix: str = ""
    language: str | None = "Python"
    # Completions are usually short.
    max_new_tokens: int = 64
    num_beams: int = 4
    temperature: float = 0.7
52
 
53
 
54
class CodeResponse(BaseModel):
    # Decoded model output shared by all three endpoints.
    output: str
 
 
 
 
 
56
 
57
 
58
# ==== SHARED UTILITIES ====

def run_model(prompt: str,
              max_new_tokens: int,
              num_beams: int,
              temperature: float) -> str:
    """Tokenize *prompt*, run generation, and decode the best hypothesis.

    NOTE(review): temperature only influences sampling; without do_sample=True
    it is effectively ignored by beam search — confirm this is intended.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            temperature=temperature,
            early_stopping=True,
        )

    return tokenizer.decode(generated[0], skip_special_tokens=True)
77
+
78
+
79
# ==== ENDPOINT 1: GENERATE CODE FROM A PROMPT ====

@app.post("/generate-code", response_model=CodeResponse)
def generate_code(req: GenerateRequest):
    """Generate code in the requested language from a natural-language prompt."""
    lang = req.language or "Python"

    prompt = (
        f"Generate {lang} code ONLY.\n"
        f"Do NOT use any other programming language.\n\n"
        f"Task:\n{req.prompt}\n\n"
        f"{lang} code:\n"
    )

    return CodeResponse(
        output=run_model(
            prompt,
            max_new_tokens=req.max_new_tokens,
            num_beams=req.num_beams,
            temperature=req.temperature,
        )
    )
100
+
101
+
102
# ==== ENDPOINT 2: FIX BUGGY CODE ====

@app.post("/fix-code", response_model=CodeResponse)
def fix_code(req: FixRequest):
    """Ask the model to repair the submitted code and return only the fix."""
    lang = req.language or "Python"

    prompt = (
        f"The following {lang} code contains bugs.\n"
        f"Fix all bugs and return ONLY the corrected {lang} code.\n\n"
        f"Buggy {lang} code:\n{req.code}\n\n"
        f"Corrected {lang} code:\n"
    )

    return CodeResponse(
        output=run_model(
            prompt,
            max_new_tokens=req.max_new_tokens,
            num_beams=req.num_beams,
            temperature=req.temperature,
        )
    )
123
+
124
+
125
# ==== ENDPOINT 3: CODE COMPLETION ====

@app.post("/complete-code", response_model=CodeResponse)
def complete_code(req: CompleteRequest):
    """Fill in the code between a prefix and an optional suffix (Copilot-style)."""
    lang = req.language or "Python"

    # Prefix + suffix prompt, in the style of editor completions.
    prompt = (
        f"Complete the following {lang} code.\n"
        f"Only generate the missing code between the prefix and suffix.\n\n"
        f"Prefix:\n{req.prefix}\n\n"
        f"Suffix:\n{req.suffix}\n\n"
        f"Missing {lang} code:\n"
    )

    return CodeResponse(
        output=run_model(
            prompt,
            max_new_tokens=req.max_new_tokens,
            num_beams=req.num_beams,
            temperature=req.temperature,
        )
    )
148
+
149
 
150
+ # ==== HEALTHCHECK ====
151
 
152
  @app.get("/")
153
  def root():