KazeStudy committed on
Commit e7ed4e6 · 1 Parent(s): 0c7656e

Add application file

Files changed (2)
  1. Dockerfile +14 -0
  2. app.py +71 -0
Dockerfile ADDED
@@ -0,0 +1,14 @@
+ FROM python:3.10
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ COPY --chown=user . /app
+
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
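
The Dockerfile installs from a requirements.txt that is not part of this commit. A minimal sketch of what it would need to contain, inferred from the imports in app.py plus the uvicorn entrypoint (the file itself and any version pins are assumptions, not in the source):

    fastapi
    uvicorn[standard]
    pydantic
    transformers
    torch
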
app.py ADDED
@@ -0,0 +1,71 @@
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+ from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoConfig
+ import torch
+
+ app = FastAPI(title="CodeT5+ Backend on HuggingFace")
+
+ # ==== LOAD MODEL ====
+ base_ckpt = "Salesforce/codet5p-770m"
+ finetuned_ckpt = "OSS-forge/codet5p-770m-pyresbugs"
+
+ print("Loading tokenizer + config...")
+ tokenizer = AutoTokenizer.from_pretrained(base_ckpt)
+ config = AutoConfig.from_pretrained(base_ckpt)
+
+ print("Loading fine-tuned model weights...")
+ model = T5ForConditionalGeneration.from_pretrained(
+     finetuned_ckpt,
+     config=config
+ )
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print("Running on:", device)
+
+ model = model.to(device)
+ model.eval()
+
+ # ==== REQUEST / RESPONSE MODELS ====
+ class GenerateRequest(BaseModel):
+     prompt: str
+     language: str | None = "Python"
+     task: str = "generate"
+     max_new_tokens: int = 128
+     num_beams: int = 4
+     temperature: float = 0.7
+
+ class GenerateResponse(BaseModel):
+     output: str
+
+
+ def build_prompt(req: GenerateRequest):
+     if req.task == "generate":
+         return f"Generate {req.language} code:\n{req.prompt}"
+     elif req.task == "fix":
+         return f"Fix the bug in the following {req.language} code:\n{req.prompt}\n\nCorrected code:"
+     else:
+         return req.prompt
+
+
+ @app.post("/generate", response_model=GenerateResponse)
+ def generate(req: GenerateRequest):
+
+     prompt = build_prompt(req)
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=req.max_new_tokens,
+             num_beams=req.num_beams,
+             temperature=req.temperature,  # note: temperature has no effect unless do_sample=True
+             early_stopping=True
+         )
+
+     text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     return GenerateResponse(output=text)
+
+
+ @app.get("/")
+ def root():
+     return {"status": "CodeT5+ backend is running 🚀"}
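
As a usage sketch (a hypothetical client, not part of the commit): once the container is up, the /generate endpoint can be exercised with a small requests script. The JSON field names mirror GenerateRequest above; the localhost:7860 address assumes the container port from the Dockerfile CMD is published as-is.

    import requests

    # Hypothetical client for the /generate endpoint defined in app.py;
    # assumes the service is reachable at localhost:7860.
    resp = requests.post(
        "http://localhost:7860/generate",
        json={
            "prompt": "def add(a, b):\n    return a - b",
            "language": "Python",
            "task": "fix",          # routes to the bug-fix prompt template
            "max_new_tokens": 64,
        },
    )
    resp.raise_for_status()
    print(resp.json()["output"])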