Devaholic committed on
Commit
2addc01
·
0 Parent(s):

First commit

Browse files
Files changed (9) hide show
  1. .gitignore +14 -0
  2. .python-version +1 -0
  3. Dockerfile +16 -0
  4. README.md +110 -0
  5. main.py +107 -0
  6. pyproject.toml +20 -0
  7. requirements.txt +62 -0
  8. train.py +121 -0
  9. uv.lock +0 -0
.gitignore ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+
12
+ results/
13
+
14
+ __huggingface_repos__.json
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Base image kept in line with the interpreter pinned in .python-version (3.12)
# so that development and the container run the same Python minor version.
FROM python:3.12-slim

# Install uv by copying its binary out of the published uv image.
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv

WORKDIR /app

# Dependencies are installed before the application code is copied so that
# this layer is reused when only main.py or the model changes.
COPY requirements.txt .
RUN uv pip install --system --no-cache -r requirements.txt

# results/ is gitignored; it must exist in the build context (i.e. the model
# has to be trained first) or this COPY step fails.
COPY main.py .
COPY results ./results

EXPOSE 8000

CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
README.md ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Sentiment API
2
+
3
+ Fine-tune **DistilBERT** on the SST-2 dataset and serve it as a REST API with FastAPI.
4
+
5
+ ## Overview
6
+
7
+ | | |
8
+ |---|---|
9
+ | **Model** | `distilbert-base-uncased` fine-tuned on SST-2 |
10
+ | **Task** | Binary sentiment classification (POSITIVE / NEGATIVE) |
11
+ | **Dataset** | [GLUE SST-2](https://huggingface.co/datasets/glue) — Stanford Sentiment Treebank (movie-review sentences) |
12
+ | **Serving** | FastAPI + Uvicorn |
13
+ | **Packaging** | Docker |
14
+ | **Deps** | [uv](https://github.com/astral-sh/uv) |
15
+
16
+ ## Project structure
17
+
18
+ ```
19
+ .
20
+ ├── main.py # FastAPI inference server
21
+ ├── train.py # Training script (fine-tunes DistilBERT, saves to results/)
22
+ ├── Dockerfile # Production container
23
+ ├── pyproject.toml # Project metadata and dependencies
24
+ ├── requirements.txt # Pinned requirements for Docker
25
+ └── results/ # Training output — gitignored
26
+ └── best_model/ # Saved model loaded by the API
27
+ ```
28
+
29
+ ## Quickstart
30
+
31
+ ### 1. Install dependencies
32
+
33
+ ```bash
34
+ uv sync
35
+ ```
36
+
37
+ ### 2. Train the model
38
+
39
+ ```bash
40
+ uv run python train.py
41
+ ```
42
+
43
+ This downloads `distilbert-base-uncased` and the SST-2 dataset from HuggingFace, fine-tunes the model, and saves the best checkpoint to `results/best_model/`.
44
+
45
+ ### 3. Run the API
46
+
47
+ ```bash
48
+ uv run python -m uvicorn main:app --reload
49
+ ```
50
+
51
+ > **Note (Windows):** `fastapi dev` / `uvicorn` trampolines are broken in some uv versions on Windows. Use `python -m uvicorn` instead.
52
+
53
+ The API is available at `http://localhost:8000`. Interactive docs at `http://localhost:8000/docs`.
54
+
55
+ ## API endpoints
56
+
57
+ ### `GET /`
58
+ Health check.
59
+ ```json
60
+ {"status": "ok", "model": "./results/best_model"}
61
+ ```
62
+
63
+ ### `POST /predict`
64
+ Single text prediction.
65
+
66
+ **Request:**
67
+ ```json
68
+ {"text": "This movie was absolutely fantastic!"}
69
+ ```
70
+
71
+ **Response:**
72
+ ```json
73
+ {
74
+ "text": "This movie was absolutely fantastic!",
75
+ "label": "POSITIVE",
76
+ "score": 0.9987,
77
+ "latency_ms": 12.4
78
+ }
79
+ ```
80
+
81
+ ### `POST /predict/batch`
82
+ Batch prediction (up to 32 texts).
83
+
84
+ **Request:**
85
+ ```json
86
+ {"texts": ["Great film!", "Terrible waste of time."]}
87
+ ```
88
+
89
+ **Response:**
90
+ ```json
91
+ {
92
+ "results": [
93
+ {"text": "Great film!", "label": "POSITIVE", "score": 0.9981, "latency_ms": 6.1},
94
+ {"text": "Terrible waste of time.", "label": "NEGATIVE", "score": 0.9973, "latency_ms": 6.1}
95
+ ],
96
+ "total_latency_ms": 12.3
97
+ }
98
+ ```
99
+
100
+ ## Docker
101
+
102
+ ```bash
103
+ # Build
104
+ docker build -t ml-api .
105
+
106
+ # Run
107
+ docker run -p 8000:8000 ml-api
108
+ ```
109
+
110
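+ The Dockerfile installs dependencies via uv, copies `main.py` and the trained model in `results/` into the image, and serves the API on port 8000. Because `results/` is gitignored, run the training step before building the image.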
main.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from contextlib import asynccontextmanager
3
+
4
+ from fastapi import FastAPI, HTTPException
5
+ from pydantic import BaseModel, Field
6
+ from transformers import pipeline
7
+
8
# Directory the text-classification pipeline (model and tokenizer) is loaded from.
MODEL_PATH = "./results/best_model"

# Module-level registry: the lifespan handler stores the loaded pipeline under
# the "pipe" key and the route handlers read it from there.
ml: dict = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the classification pipeline on startup and release it on shutdown.

    The pipeline is built once from ``MODEL_PATH`` with truncation enabled and
    ``max_length=128``, then stored in ``ml["pipe"]``.

    Args:
        app: The FastAPI application this lifespan is attached to. It is not
            used inside the handler.
    """
    print(f"Loading model from {MODEL_PATH} ...")
    ml["pipe"] = pipeline(
        "text-classification",
        model=MODEL_PATH,
        tokenizer=MODEL_PATH,
        truncation=True,
        max_length=128,
    )
    print("Model is ready")
    try:
        yield
    finally:
        # Drop the pipeline reference even when an exception is thrown into
        # the generator at the yield point.
        ml.clear()


# ASGI application object; served as "main:app" by uvicorn.
app = FastAPI(
    title="Sentiment API",
    description="DistilBERT fine-tuned on SST-2 — binary sentiment classification",
    version="1.0.0",
    lifespan=lifespan,
)
34
+
35
+
36
class PredictRequest(BaseModel):
    """Request body carrying a single text of 1 to 512 characters."""

    text: str = Field(
        ...,
        min_length=1,
        max_length=512,
        # ``examples`` is the supported Field argument in Pydantic v2; a bare
        # ``example=`` keyword is treated as a deprecated extra argument.
        examples=["This movie was absolutely fantastic!"],
    )
43
+
44
+
45
class PredictResponse(BaseModel):
    # Result for one classified text.

    # The text the prediction refers to.
    text: str
    # Predicted class name: "POSITIVE" | "NEGATIVE".
    label: str
    # Confidence, 0–1.
    score: float
    # Latency attributed to this prediction, in milliseconds.
    latency_ms: float
50
+
51
+
52
class BatchRequest(BaseModel):
    """Request body carrying a list of 1 to 32 texts to classify together."""

    # NOTE(review): only the list length is constrained here; the individual
    # strings are not length-limited the way PredictRequest.text is. Confirm
    # whether per-item limits are wanted.
    texts: list[str] = Field(
        ...,
        min_length=1,
        max_length=32,
        # ``examples`` takes a list of example values, and each example for
        # this field is itself a list of strings. It replaces the deprecated
        # extra ``example=`` keyword.
        examples=[["Great film!", "Terrible waste of time."]],
    )
59
+
60
+
61
class BatchResponse(BaseModel):
    # Result for a whole batch.

    # One entry per classified text.
    results: list[PredictResponse]
    # Latency of the whole batch, in milliseconds.
    total_latency_ms: float
64
+
65
+
66
@app.get("/", tags=["health"])
def health():
    # Health check: a fixed "ok" status plus the configured model path.
    payload = dict(status="ok", model=MODEL_PATH)
    return payload
69
+
70
+
71
@app.post("/predict", response_model=PredictResponse, tags=["inference"])
def predict(req: PredictRequest) -> PredictResponse:
    """Classify a single text and report its label, score and latency.

    Args:
        req: Validated request body carrying the text to classify.

    Returns:
        The input text, the predicted label, the score rounded to 4 decimal
        places and the inference time in milliseconds rounded to 2.

    Raises:
        HTTPException: With status 503 when no pipeline is stored in ``ml``.
    """
    # Single lookup instead of a membership test followed by indexing, so a
    # concurrent ``ml.clear()`` cannot cause a KeyError between the two steps.
    pipe = ml.get("pipe")
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    t0 = time.perf_counter()
    # The pipeline returns a list of results; take the only entry.
    result = pipe(req.text)[0]
    latency = (time.perf_counter() - t0) * 1000

    return PredictResponse(
        text=req.text,
        label=result["label"],
        score=round(result["score"], 4),
        latency_ms=round(latency, 2),
    )
86
+
87
+
88
@app.post("/predict/batch", response_model=BatchResponse, tags=["inference"])
def predict_batch(req: BatchRequest) -> BatchResponse:
    """Classify a batch of texts in one pipeline call.

    Args:
        req: Validated request body carrying the texts to classify.

    Returns:
        One result per input text, in input order, plus the latency of the
        whole batch in milliseconds rounded to 2 decimal places.

    Raises:
        HTTPException: With status 503 when no pipeline is stored in ``ml``.
    """
    # Single lookup instead of a membership test followed by indexing, so a
    # concurrent ``ml.clear()`` cannot cause a KeyError between the two steps.
    pipe = ml.get("pipe")
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    t0 = time.perf_counter()
    raw = pipe(req.texts)
    total_latency = (time.perf_counter() - t0) * 1000

    # The batch is timed as a whole, so every item reports the same average
    # share. It is computed once rather than per item. ``req.texts`` holds at
    # least one entry (enforced by BatchRequest), so the division is safe.
    per_item_latency = round(total_latency / len(req.texts), 2)

    results = [
        PredictResponse(
            text=text,
            label=r["label"],
            score=round(r["score"], 4),
            latency_ms=per_item_latency,
        )
        for text, r in zip(req.texts, raw)
    ]

    return BatchResponse(results=results, total_latency_ms=round(total_latency, 2))
pyproject.toml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "hf-training"
3
+ version = "0.1.0"
4
+ description = "Fine-tune DistilBERT on SST-2 sentiment classification"
5
+ requires-python = ">=3.10"
6
+ dependencies = [
7
+ "torch>=2.2.0",
8
+ "transformers>=4.40.0",
9
+ "datasets>=2.19.0",
10
+ "evaluate>=0.4.1",
11
+ "accelerate>=0.29.0",
12
+ "scikit-learn>=1.4.0",
13
+ "fastapi>=0.111.0",
14
+ "uvicorn[standard]>=0.29.0",
15
+ ]
16
+
17
+ [dependency-groups]
18
+ dev = [
19
+ "ipykernel>=6.29.0",
20
+ ]
requirements.txt ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ annotated-doc==0.0.4
2
+ annotated-types==0.7.0
3
+ anyio==4.12.1
4
+ certifi==2026.2.25
5
+ click==8.3.1
6
+ colorama==0.4.6
7
+ dnspython==2.8.0
8
+ email-validator==2.3.0
9
+ fastapi==0.135.1
10
+ fastapi-cli==0.0.24
11
+ fastapi-cloud-cli==0.14.0
12
+ fastar==0.8.0
13
+ filelock==3.25.0
14
+ fsspec==2026.2.0
15
+ h11==0.16.0
16
+ hf-xet==1.3.2
17
+ httpcore==1.0.9
18
+ httptools==0.7.1
19
+ httpx==0.28.1
20
+ huggingface-hub==1.6.0
21
+ idna==3.11
22
+ jinja2==3.1.6
23
+ joblib==1.5.3
24
+ markdown-it-py==4.0.0
25
+ markupsafe==3.0.3
26
+ mdurl==0.1.2
27
+ mpmath==1.3.0
28
+ networkx==3.6.1
29
+ numpy==2.4.2
30
+ packaging==26.0
31
+ pydantic==2.12.5
32
+ pydantic-core==2.41.5
33
+ pydantic-extra-types==2.11.0
34
+ pydantic-settings==2.13.1
35
+ pygments==2.19.2
36
+ python-dotenv==1.2.2
37
+ python-multipart==0.0.22
38
+ pyyaml==6.0.3
39
+ regex==2026.2.28
40
+ rich==14.3.3
41
+ rich-toolkit==0.19.7
42
+ rignore==0.7.6
43
+ safetensors==0.7.0
44
+ scikit-learn==1.8.0
45
+ scipy==1.17.1
46
+ sentry-sdk==2.54.0
47
+ setuptools==82.0.0
48
+ shellingham==1.5.4
49
+ starlette==0.52.1
50
+ sympy==1.14.0
51
+ threadpoolctl==3.6.0
52
+ tokenizers==0.22.2
53
+ torch==2.10.0
54
+ tqdm==4.67.3
55
+ transformers==5.3.0
56
+ typer==0.24.1
57
+ typing-extensions==4.15.0
58
+ typing-inspection==0.4.2
59
+ urllib3==2.6.3
60
+ uvicorn==0.41.0
61
+ watchfiles==1.1.1
62
+ websockets==16.0
train.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from contextlib import asynccontextmanager
3
+ from typing import Optional
4
+
5
+ from fastapi import FastAPI, HTTPException
6
+ from pydantic import BaseModel, Field
7
+ from transformers import pipeline
8
+
9
# ── Config ────────────────────────────────────────────────────────────────────

# Directory the text-classification pipeline (model and tokenizer) is loaded from.
MODEL_PATH = "./results/best_model"  # produced by train.py


# ── Lifespan (load model once on startup) ─────────────────────────────────────

# Module-level registry: the lifespan handler stores the loaded pipeline under
# the "pipe" key and the route handlers read it from there.
ml: dict = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the classification pipeline on startup and release it on shutdown.

    The pipeline is built once from ``MODEL_PATH`` with truncation enabled and
    ``max_length=128``, then stored in ``ml["pipe"]``.

    Args:
        app: The FastAPI application this lifespan is attached to. It is not
            used inside the handler.
    """
    print(f"Loading model from {MODEL_PATH} ...")
    ml["pipe"] = pipeline(
        "text-classification",
        model=MODEL_PATH,
        tokenizer=MODEL_PATH,
        truncation=True,
        max_length=128,
    )
    print("Model ready ✅")
    try:
        yield
    finally:
        # Drop the pipeline reference even when an exception is thrown into
        # the generator at the yield point.
        ml.clear()


# ── App ───────────────────────────────────────────────────────────────────────

app = FastAPI(
    title="Sentiment API",
    description="DistilBERT fine-tuned on SST-2 — binary sentiment classification",
    version="1.0.0",
    lifespan=lifespan,
)
42
+
43
+
44
+ # ── Schemas ───────────────────────────────────────────────────────────────────
45
+
46
+
47
class PredictRequest(BaseModel):
    """Request body carrying a single text of 1 to 512 characters."""

    text: str = Field(
        ...,
        min_length=1,
        max_length=512,
        # ``examples`` is the supported Field argument in Pydantic v2; a bare
        # ``example=`` keyword is treated as a deprecated extra argument.
        examples=["This movie was absolutely fantastic!"],
    )
54
+
55
+
56
class PredictResponse(BaseModel):
    # Result for one classified text.

    # The text the prediction refers to.
    text: str
    # Predicted class name: "POSITIVE" | "NEGATIVE".
    label: str
    # Confidence, 0–1.
    score: float
    # Latency attributed to this prediction, in milliseconds.
    latency_ms: float
61
+
62
+
63
class BatchRequest(BaseModel):
    """Request body carrying a list of 1 to 32 texts to classify together."""

    # NOTE(review): only the list length is constrained here; the individual
    # strings are not length-limited the way PredictRequest.text is. Confirm
    # whether per-item limits are wanted.
    texts: list[str] = Field(
        ...,
        min_length=1,
        max_length=32,
        # ``examples`` takes a list of example values, and each example for
        # this field is itself a list of strings. It replaces the deprecated
        # extra ``example=`` keyword.
        examples=[["Great film!", "Terrible waste of time."]],
    )
70
+
71
+
72
class BatchResponse(BaseModel):
    # Result for a whole batch.

    # One entry per classified text.
    results: list[PredictResponse]
    # Latency of the whole batch, in milliseconds.
    total_latency_ms: float
75
+
76
+
77
+ # ── Routes ────────────────────────────────────────────────────────────────────
78
+
79
+
80
@app.get("/", tags=["health"])
def health():
    # Health check: a fixed "ok" status plus the configured model path.
    payload = dict(status="ok", model=MODEL_PATH)
    return payload
83
+
84
+
85
@app.post("/predict", response_model=PredictResponse, tags=["inference"])
def predict(req: PredictRequest) -> PredictResponse:
    """Classify a single text and report its label, score and latency.

    Args:
        req: Validated request body carrying the text to classify.

    Returns:
        The input text, the predicted label, the score rounded to 4 decimal
        places and the inference time in milliseconds rounded to 2.

    Raises:
        HTTPException: With status 503 when no pipeline is stored in ``ml``.
    """
    # Single lookup instead of a membership test followed by indexing, so a
    # concurrent ``ml.clear()`` cannot cause a KeyError between the two steps.
    pipe = ml.get("pipe")
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    t0 = time.perf_counter()
    # The pipeline returns a list of results; take the only entry.
    result = pipe(req.text)[0]
    latency = (time.perf_counter() - t0) * 1000

    return PredictResponse(
        text=req.text,
        label=result["label"],
        score=round(result["score"], 4),
        latency_ms=round(latency, 2),
    )
100
+
101
+
102
@app.post("/predict/batch", response_model=BatchResponse, tags=["inference"])
def predict_batch(req: BatchRequest) -> BatchResponse:
    """Classify a batch of texts in one pipeline call.

    Args:
        req: Validated request body carrying the texts to classify.

    Returns:
        One result per input text, in input order, plus the latency of the
        whole batch in milliseconds rounded to 2 decimal places.

    Raises:
        HTTPException: With status 503 when no pipeline is stored in ``ml``.
    """
    # Single lookup instead of a membership test followed by indexing, so a
    # concurrent ``ml.clear()`` cannot cause a KeyError between the two steps.
    pipe = ml.get("pipe")
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    t0 = time.perf_counter()
    raw = pipe(req.texts)
    total_latency = (time.perf_counter() - t0) * 1000

    # The batch is timed as a whole, so every item reports the same average
    # share. It is computed once rather than per item. ``req.texts`` holds at
    # least one entry (enforced by BatchRequest), so the division is safe.
    per_item_latency = round(total_latency / len(req.texts), 2)

    results = [
        PredictResponse(
            text=text,
            label=r["label"],
            score=round(r["score"], 4),
            latency_ms=per_item_latency,
        )
        for text, r in zip(req.texts, raw)
    ]

    return BatchResponse(results=results, total_latency_ms=round(total_latency, 2))
uv.lock ADDED
The diff for this file is too large to render. See raw diff