mjpsm committed on
Commit
891c78f
·
verified ·
1 Parent(s): d7c0849

Upload 8 files

Browse files
Files changed (8) hide show
  1. Dockerfile +13 -0
  2. __init__.py +3 -0
  3. config.py +2 -0
  4. main.py +26 -0
  5. model_loader.py +23 -0
  6. requirements.txt +7 -0
  7. routes.py +105 -0
  8. schemas.py +12 -0
Dockerfile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Container image for the FastAPI inference service.
FROM python:3.11-slim

WORKDIR /app

# NOTE(review): the sources are copied flat into /app, yet the modules import
# each other as the `app` package (`from app.main import app`). Putting / on
# the import path makes the /app directory itself importable under that name,
# so `uvicorn app:app` can resolve it. Confirm against the repository layout.
ENV PYTHONPATH=/

# Install dependencies before copying the sources so this layer stays cached
# when only application code changes.
COPY requirements.txt .

RUN pip install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 8000

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from app.main import app
2
+
3
+ __all__ = ["app"]
config.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
import os

# Identifier handed to `from_pretrained` for both the tokenizer and the model.
# Overridable through the MODEL_NAME environment variable; the default is the
# value that was previously hard-coded.
MODEL_NAME = os.environ.get(
    "MODEL_NAME",
    "mjpsm/qwen3-0.6-bash-experiment-model-final-merged",
)

# Value passed as `max_new_tokens` to `model.generate` on every request.
# Overridable through the MAX_NEW_TOKENS environment variable (must parse as
# an integer, otherwise importing this module raises ValueError).
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "128"))
main.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import asynccontextmanager
2
+
3
+ from fastapi import FastAPI
4
+
5
+ from app.model_loader import load_model
6
+ from app.routes import router
7
+
8
+
9
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the tokenizer and model at startup and release them at shutdown.

    The objects returned by ``load_model()`` are stored on ``app.state`` as
    ``tokenizer`` and ``model`` before the application starts serving, and
    both attributes are reset to ``None`` once the context exits.
    """
    tokenizer, model = load_model()
    app.state.tokenizer = tokenizer
    app.state.model = model

    try:
        yield
    finally:
        # Run the teardown even when the context is exited by an exception
        # thrown into the generator; without try/finally it would be skipped.
        app.state.tokenizer = None
        app.state.model = None
19
+
20
+
21
# ASGI application object; `lifespan` is run by FastAPI around the serving
# period, and all HTTP endpoints come from the shared router.
app = FastAPI(title="Qwen Bash Tool Calling API", lifespan=lifespan)
app.include_router(router)
model_loader.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.config import MODEL_NAME
2
+
3
+
4
def load_model():
    """Load the tokenizer and causal LM named by ``MODEL_NAME``.

    Progress messages are printed to stdout before and after each step.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been switched to
        evaluation mode via ``eval()``.
    """
    # Imported here rather than at module level, so importing this module
    # does not itself import transformers.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # NOTE(review): trust_remote_code=True permits running code shipped with
    # the model repository; confirm it is actually required for this model.
    # NOTE(review): the reason for the empty extra_special_tokens override is
    # not evident from this file; confirm before removing it.
    print("Loading tokenizer...")
    tok = AutoTokenizer.from_pretrained(
        MODEL_NAME, trust_remote_code=True, extra_special_tokens={}
    )

    print("Loading model...")
    net = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, trust_remote_code=True, low_cpu_mem_usage=True
    )
    net.eval()

    print("Model loaded.")
    return tok, net
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ torch
4
+ transformers
5
+ accelerate
6
+ sentencepiece
7
+ safetensors
routes.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import time
3
+
4
+ from fastapi import APIRouter, Request
5
+
6
+ from app.config import MAX_NEW_TOKENS, MODEL_NAME
7
+ from app.schemas import PredictionResponse, PromptRequest
8
+
9
+
10
+ router = APIRouter()
11
+
12
# Captures the value of a "command" JSON string field in the model output.
# The value may contain backslash escapes (for example \" around a quoted
# argument); stopping at the first quote character would cut such commands
# short, so an escape sequence is consumed as a unit. The captured text is
# returned exactly as written, i.e. still JSON-escaped.
COMMAND_PATTERN = re.compile(
    r'"command"\s*:\s*"((?:[^"\\]|\\.)+)"',
)
15
+
16
+
17
@router.get("/")
def root():
    """Return a fixed payload indicating the service is running."""
    return {"status": "running"}
22
+
23
+
24
@router.get("/health")
def health(request: Request):
    """Report service health and whether the model objects are available.

    ``model_loaded`` is true only when both ``model`` and ``tokenizer`` exist
    on ``request.app.state`` and neither is ``None``.
    """
    state = request.app.state
    # A missing attribute and an attribute set to None are treated alike.
    model_loaded = all(
        getattr(state, attr, None) is not None
        for attr in ("model", "tokenizer")
    )

    return {
        "status": "healthy",
        "model_loaded": model_loaded,
        "model_name": MODEL_NAME,
    }
38
+
39
+
40
@router.get("/model-info")
def model_info():
    """Return the configured model name."""
    return {"model_name": MODEL_NAME}
45
+
46
+
47
@router.post("/predict", response_model=PredictionResponse)
def predict(payload: PromptRequest, request: Request):
    """Generate a completion for the prompt and extract a command from it.

    The prompt is wrapped as a single user message, rendered through the
    tokenizer's chat template, and decoded greedily (``do_sample=False``) for
    at most ``MAX_NEW_TOKENS`` new tokens. Only the newly generated tokens are
    decoded; ``COMMAND_PATTERN`` is then searched in that text.

    Args:
        payload: Request body carrying the user prompt.
        request: Incoming request; the tokenizer and model are read from
            ``request.app.state``.

    Returns:
        A ``PredictionResponse`` with the prompt, the extracted command
        (``None`` when the pattern does not match), the decoded output and
        the elapsed time in seconds rounded to three decimals.

    Raises:
        HTTPException: 503 when the tokenizer or model is missing or ``None``.
    """
    import torch
    from fastapi import HTTPException

    # perf_counter is monotonic, so the measured latency is not affected by
    # wall-clock adjustments the way time.time() differences are.
    start_time = time.perf_counter()

    # The attributes are absent when startup never ran and None after
    # shutdown; answer with an explicit 503 rather than an opaque 500.
    tokenizer = getattr(request.app.state, "tokenizer", None)
    model = getattr(request.app.state, "model", None)
    if tokenizer is None or model is None:
        raise HTTPException(status_code=503, detail="Model is not loaded.")

    messages = [
        {
            "role": "user",
            "content": payload.prompt,
        }
    ]

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(
        text,
        return_tensors="pt",
    ).to(model.device)

    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
        )

    # The output sequence starts with the prompt tokens; slice them off so
    # only the continuation is decoded.
    prompt_token_count = inputs["input_ids"].shape[1]
    generated_tokens = output[0][prompt_token_count:]

    response = tokenizer.decode(
        generated_tokens,
        skip_special_tokens=True,
    )

    match = COMMAND_PATTERN.search(response)
    command = match.group(1) if match else None

    latency_seconds = round(time.perf_counter() - start_time, 3)

    return PredictionResponse(
        prompt=payload.prompt,
        command=command,
        raw_output=response,
        latency_seconds=latency_seconds,
    )
schemas.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
+
4
class PromptRequest(BaseModel):
    """Request body carrying a single text prompt."""

    # Text to send to the model.
    prompt: str
6
+
7
+
8
class PredictionResponse(BaseModel):
    """Response body describing one prediction."""

    # The prompt the prediction was produced for.
    prompt: str
    # Extracted command string, or None when none was found.
    command: str | None
    # Generated text before any command extraction.
    raw_output: str
    # Time taken to produce the prediction, in seconds.
    latency_seconds: float