d3evil4 committed on
Commit
e536cd5
·
1 Parent(s): 35424c3

feat: huh

Browse files
Files changed (3) hide show
  1. Dockerfile +4 -3
  2. main.py +19 -177
  3. start.sh +46 -0
Dockerfile CHANGED
@@ -2,7 +2,7 @@ FROM ghcr.io/ggml-org/llama.cpp:full
2
 
3
  WORKDIR /app
4
 
5
- RUN apt update && apt install -y python3 python3-pip python3-venv
6
  RUN python3 -m venv /opt/venv
7
  ENV PATH="/opt/venv/bin:$PATH"
8
 
@@ -19,7 +19,8 @@ RUN python3 -c 'from huggingface_hub import hf_hub_download; \
19
 
20
  ENV HF_TOKEN=""
21
 
22
- COPY main.py /app/main.py
 
23
 
24
  ENTRYPOINT []
25
- CMD uvicorn main:app --host 0.0.0.0 --port 7860
 
2
 
3
  WORKDIR /app
4
 
5
+ RUN apt update && apt install -y python3 python3-pip python3-venv curl
6
  RUN python3 -m venv /opt/venv
7
  ENV PATH="/opt/venv/bin:$PATH"
8
 
 
19
 
20
  ENV HF_TOKEN=""
21
 
22
+ COPY main.py start.sh /app/
23
+ RUN chmod +x /app/start.sh
24
 
25
  ENTRYPOINT []
26
+ CMD ["/app/start.sh"]
main.py CHANGED
@@ -1,13 +1,7 @@
1
  from __future__ import annotations
2
 
3
- import asyncio
4
  import logging
5
- import os
6
- import shutil
7
- import subprocess
8
  import sys
9
- import time
10
- from contextlib import asynccontextmanager
11
  from typing import Any, AsyncIterator
12
 
13
  import httpx
@@ -23,92 +17,9 @@ logging.basicConfig(
23
  )
24
  logger = logging.getLogger("gemma4")
25
 
 
26
 
27
- def _find_llama_server() -> str:
28
- candidates = [
29
- "llama-server",
30
- "/llama-server",
31
- "/usr/local/bin/llama-server",
32
- "/usr/bin/llama-server",
33
- ]
34
- for c in candidates:
35
- found = shutil.which(c)
36
- if found:
37
- return found
38
- if os.path.isfile(c) and os.access(c, os.X_OK):
39
- return c
40
- raise RuntimeError(f"llama-server binary not found; searched: {candidates}")
41
-
42
-
43
- LLAMA_BASE = "http://localhost:8080"
44
- LLAMA_CMD = [
45
- _find_llama_server(),
46
- "-m", "/app/gemma-4-E2B-it-UD-Q5_K_XL.gguf",
47
- "--mmproj", "/app/mmproj-BF16.gguf",
48
- "--host", "0.0.0.0",
49
- "--port", "8080",
50
- "-t", "2",
51
- "--cache-type-k", "q8_0",
52
- "--cache-type-v", "iq4_nl",
53
- "-c", "128000",
54
- "-n", "38912",
55
- ]
56
- HEALTH_TIMEOUT = 300
57
- HEALTH_POLL_INTERVAL = 2
58
-
59
- _llama_proc: subprocess.Popen[bytes] | None = None
60
- _http_client: httpx.AsyncClient | None = None
61
-
62
-
63
- async def _wait_for_llama() -> None:
64
- assert _http_client is not None
65
- deadline = time.monotonic() + HEALTH_TIMEOUT
66
- while time.monotonic() < deadline:
67
- try:
68
- resp = await _http_client.get(f"{LLAMA_BASE}/health", timeout=5.0)
69
- if resp.status_code == 200:
70
- logger.info("llama.cpp server is healthy")
71
- return
72
- except httpx.TransportError:
73
- pass
74
- await asyncio.sleep(HEALTH_POLL_INTERVAL)
75
- raise RuntimeError("llama.cpp server did not become healthy within timeout")
76
-
77
-
78
- @asynccontextmanager
79
- async def lifespan(app: FastAPI) -> AsyncIterator[None]:
80
- global _llama_proc, _http_client
81
-
82
- _http_client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=10.0))
83
-
84
- logger.info("Starting llama.cpp server: %s", " ".join(LLAMA_CMD))
85
- _llama_proc = subprocess.Popen(
86
- LLAMA_CMD,
87
- stdout=subprocess.PIPE,
88
- stderr=subprocess.STDOUT,
89
- )
90
-
91
- try:
92
- await _wait_for_llama()
93
- except RuntimeError:
94
- _llama_proc.terminate()
95
- await _http_client.aclose()
96
- raise
97
-
98
- yield
99
-
100
- logger.info("Shutting down llama.cpp server")
101
- if _llama_proc and _llama_proc.poll() is None:
102
- _llama_proc.terminate()
103
- try:
104
- _llama_proc.wait(timeout=10)
105
- except subprocess.TimeoutExpired:
106
- _llama_proc.kill()
107
-
108
- await _http_client.aclose()
109
-
110
-
111
- app = FastAPI(title="Gemma 4 API", version="1.0.0", lifespan=lifespan)
112
 
113
  app.add_middleware(
114
  CORSMiddleware,
@@ -118,6 +29,8 @@ app.add_middleware(
118
  allow_headers=["*"],
119
  )
120
 
 
 
121
 
122
  @app.middleware("http")
123
  async def log_requests(request: Request, call_next):
@@ -127,17 +40,10 @@ async def log_requests(request: Request, call_next):
127
  return response
128
 
129
 
130
- def _client() -> httpx.AsyncClient:
131
- if _http_client is None:
132
- raise HTTPException(status_code=503, detail="Service not initialized")
133
- return _http_client
134
-
135
-
136
  async def _proxy_stream(url: str, payload: dict[str, Any]) -> AsyncIterator[bytes]:
137
- async with _client().stream("POST", url, json=payload) as resp:
138
  if resp.status_code != 200:
139
- body = await resp.aread()
140
- yield body
141
  return
142
  async for chunk in resp.aiter_bytes():
143
  yield chunk
@@ -147,24 +53,6 @@ async def _proxy_stream(url: str, payload: dict[str, Any]) -> AsyncIterator[byte
147
  # Pydantic models
148
  # ---------------------------------------------------------------------------
149
 
150
- class MessageContent(BaseModel):
151
- role: str
152
- content: Any
153
-
154
-
155
- class ChatRequest(BaseModel):
156
- model: str | None = None
157
- messages: list[dict[str, Any]]
158
- max_tokens: int | None = Field(default=None, alias="max_tokens")
159
- temperature: float | None = None
160
- top_p: float | None = None
161
- stream: bool = False
162
- stop: list[str] | None = None
163
- extra: dict[str, Any] = Field(default_factory=dict)
164
-
165
- model_config = {"extra": "allow", "populate_by_name": True}
166
-
167
-
168
  class SimpleChatRequest(BaseModel):
169
  messages: list[dict[str, Any]]
170
  max_tokens: int = 2048
@@ -193,24 +81,17 @@ class VisionRequest(BaseModel):
193
  @app.get("/health")
194
  async def health() -> dict[str, Any]:
195
  try:
196
- resp = await _client().get(f"{LLAMA_BASE}/health", timeout=5.0)
197
  llama_status = resp.json() if resp.status_code == 200 else {"status": "error"}
198
  except httpx.TransportError:
199
  raise HTTPException(status_code=503, detail="llama.cpp server unreachable")
200
-
201
- try:
202
- models_resp = await _client().get(f"{LLAMA_BASE}/v1/models", timeout=5.0)
203
- models = models_resp.json() if models_resp.status_code == 200 else {}
204
- except httpx.TransportError:
205
- models = {}
206
-
207
- return {"status": "ok", "llama": llama_status, "models": models}
208
 
209
 
210
  @app.get("/v1/models")
211
  async def list_models() -> Any:
212
  try:
213
- resp = await _client().get(f"{LLAMA_BASE}/v1/models", timeout=10.0)
214
  except httpx.TransportError as exc:
215
  raise HTTPException(status_code=503, detail=str(exc))
216
  return resp.json()
@@ -219,26 +100,17 @@ async def list_models() -> Any:
219
  @app.post("/v1/chat/completions")
220
  async def chat_completions(request: Request) -> Any:
221
  payload = await request.json()
222
- stream = payload.get("stream", False)
223
-
224
- if stream:
225
  return StreamingResponse(
226
  _proxy_stream(f"{LLAMA_BASE}/v1/chat/completions", payload),
227
  media_type="text/event-stream",
228
  )
229
-
230
  try:
231
- resp = await _client().post(
232
- f"{LLAMA_BASE}/v1/chat/completions",
233
- json=payload,
234
- timeout=300.0,
235
- )
236
  except httpx.TransportError as exc:
237
  raise HTTPException(status_code=503, detail=str(exc))
238
-
239
  if resp.status_code != 200:
240
  raise HTTPException(status_code=resp.status_code, detail=resp.text)
241
-
242
  return resp.json()
243
 
244
 
@@ -250,25 +122,17 @@ async def chat(req: SimpleChatRequest) -> Any:
250
  "temperature": req.temperature,
251
  "stream": req.stream,
252
  }
253
-
254
  if req.stream:
255
  return StreamingResponse(
256
  _proxy_stream(f"{LLAMA_BASE}/v1/chat/completions", payload),
257
  media_type="text/event-stream",
258
  )
259
-
260
  try:
261
- resp = await _client().post(
262
- f"{LLAMA_BASE}/v1/chat/completions",
263
- json=payload,
264
- timeout=300.0,
265
- )
266
  except httpx.TransportError as exc:
267
  raise HTTPException(status_code=503, detail=str(exc))
268
-
269
  if resp.status_code != 200:
270
  raise HTTPException(status_code=resp.status_code, detail=resp.text)
271
-
272
  return resp.json()
273
 
274
 
@@ -280,64 +144,42 @@ async def generate(req: GenerateRequest) -> Any:
280
  "temperature": req.temperature,
281
  "stream": req.stream,
282
  }
283
-
284
  if req.stream:
285
  return StreamingResponse(
286
  _proxy_stream(f"{LLAMA_BASE}/v1/chat/completions", payload),
287
  media_type="text/event-stream",
288
  )
289
-
290
  try:
291
- resp = await _client().post(
292
- f"{LLAMA_BASE}/v1/chat/completions",
293
- json=payload,
294
- timeout=300.0,
295
- )
296
  except httpx.TransportError as exc:
297
  raise HTTPException(status_code=503, detail=str(exc))
298
-
299
  if resp.status_code != 200:
300
  raise HTTPException(status_code=resp.status_code, detail=resp.text)
301
-
302
  return resp.json()
303
 
304
 
305
  @app.post("/vision")
306
  async def vision(req: VisionRequest) -> Any:
307
- image_content: dict[str, Any]
308
  if req.image.startswith("http://") or req.image.startswith("https://"):
309
- image_content = {"type": "image_url", "image_url": {"url": req.image}}
310
  else:
311
  image_content = {
312
  "type": "image_url",
313
  "image_url": {"url": f"data:image/jpeg;base64,{req.image}"},
314
  }
315
-
316
  payload: dict[str, Any] = {
317
- "messages": [
318
- {
319
- "role": "user",
320
- "content": [
321
- {"type": "text", "text": req.prompt},
322
- image_content,
323
- ],
324
- }
325
- ],
326
  "max_tokens": req.max_tokens,
327
  "temperature": req.temperature,
328
  "stream": False,
329
  }
330
-
331
  try:
332
- resp = await _client().post(
333
- f"{LLAMA_BASE}/v1/chat/completions",
334
- json=payload,
335
- timeout=300.0,
336
- )
337
  except httpx.TransportError as exc:
338
  raise HTTPException(status_code=503, detail=str(exc))
339
-
340
  if resp.status_code != 200:
341
  raise HTTPException(status_code=resp.status_code, detail=resp.text)
342
-
343
  return resp.json()
 
1
  from __future__ import annotations
2
 
 
3
  import logging
 
 
 
4
  import sys
 
 
5
  from typing import Any, AsyncIterator
6
 
7
  import httpx
 
17
  )
18
  logger = logging.getLogger("gemma4")
19
 
20
+ LLAMA_BASE = "http://127.0.0.1:8080"
21
 
22
+ app = FastAPI(title="Gemma 4 API", version="1.0.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  app.add_middleware(
25
  CORSMiddleware,
 
29
  allow_headers=["*"],
30
  )
31
 
32
+ _client = httpx.AsyncClient(timeout=httpx.Timeout(300.0, connect=10.0))
33
+
34
 
35
  @app.middleware("http")
36
  async def log_requests(request: Request, call_next):
 
40
  return response
41
 
42
 
 
 
 
 
 
 
43
  async def _proxy_stream(url: str, payload: dict[str, Any]) -> AsyncIterator[bytes]:
44
+ async with _client.stream("POST", url, json=payload) as resp:
45
  if resp.status_code != 200:
46
+ yield await resp.aread()
 
47
  return
48
  async for chunk in resp.aiter_bytes():
49
  yield chunk
 
53
  # Pydantic models
54
  # ---------------------------------------------------------------------------
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  class SimpleChatRequest(BaseModel):
57
  messages: list[dict[str, Any]]
58
  max_tokens: int = 2048
 
81
  @app.get("/health")
82
  async def health() -> dict[str, Any]:
83
  try:
84
+ resp = await _client.get(f"{LLAMA_BASE}/health", timeout=5.0)
85
  llama_status = resp.json() if resp.status_code == 200 else {"status": "error"}
86
  except httpx.TransportError:
87
  raise HTTPException(status_code=503, detail="llama.cpp server unreachable")
88
+ return {"status": "ok", "llama": llama_status}
 
 
 
 
 
 
 
89
 
90
 
91
  @app.get("/v1/models")
92
  async def list_models() -> Any:
93
  try:
94
+ resp = await _client.get(f"{LLAMA_BASE}/v1/models", timeout=10.0)
95
  except httpx.TransportError as exc:
96
  raise HTTPException(status_code=503, detail=str(exc))
97
  return resp.json()
 
100
  @app.post("/v1/chat/completions")
101
  async def chat_completions(request: Request) -> Any:
102
  payload = await request.json()
103
+ if payload.get("stream", False):
 
 
104
  return StreamingResponse(
105
  _proxy_stream(f"{LLAMA_BASE}/v1/chat/completions", payload),
106
  media_type="text/event-stream",
107
  )
 
108
  try:
109
+ resp = await _client.post(f"{LLAMA_BASE}/v1/chat/completions", json=payload, timeout=300.0)
 
 
 
 
110
  except httpx.TransportError as exc:
111
  raise HTTPException(status_code=503, detail=str(exc))
 
112
  if resp.status_code != 200:
113
  raise HTTPException(status_code=resp.status_code, detail=resp.text)
 
114
  return resp.json()
115
 
116
 
 
122
  "temperature": req.temperature,
123
  "stream": req.stream,
124
  }
 
125
  if req.stream:
126
  return StreamingResponse(
127
  _proxy_stream(f"{LLAMA_BASE}/v1/chat/completions", payload),
128
  media_type="text/event-stream",
129
  )
 
130
  try:
131
+ resp = await _client.post(f"{LLAMA_BASE}/v1/chat/completions", json=payload, timeout=300.0)
 
 
 
 
132
  except httpx.TransportError as exc:
133
  raise HTTPException(status_code=503, detail=str(exc))
 
134
  if resp.status_code != 200:
135
  raise HTTPException(status_code=resp.status_code, detail=resp.text)
 
136
  return resp.json()
137
 
138
 
 
144
  "temperature": req.temperature,
145
  "stream": req.stream,
146
  }
 
147
  if req.stream:
148
  return StreamingResponse(
149
  _proxy_stream(f"{LLAMA_BASE}/v1/chat/completions", payload),
150
  media_type="text/event-stream",
151
  )
 
152
  try:
153
+ resp = await _client.post(f"{LLAMA_BASE}/v1/chat/completions", json=payload, timeout=300.0)
 
 
 
 
154
  except httpx.TransportError as exc:
155
  raise HTTPException(status_code=503, detail=str(exc))
 
156
  if resp.status_code != 200:
157
  raise HTTPException(status_code=resp.status_code, detail=resp.text)
 
158
  return resp.json()
159
 
160
 
161
  @app.post("/vision")
162
  async def vision(req: VisionRequest) -> Any:
 
163
  if req.image.startswith("http://") or req.image.startswith("https://"):
164
+ image_content: dict[str, Any] = {"type": "image_url", "image_url": {"url": req.image}}
165
  else:
166
  image_content = {
167
  "type": "image_url",
168
  "image_url": {"url": f"data:image/jpeg;base64,{req.image}"},
169
  }
 
170
  payload: dict[str, Any] = {
171
+ "messages": [{
172
+ "role": "user",
173
+ "content": [{"type": "text", "text": req.prompt}, image_content],
174
+ }],
 
 
 
 
 
175
  "max_tokens": req.max_tokens,
176
  "temperature": req.temperature,
177
  "stream": False,
178
  }
 
179
  try:
180
+ resp = await _client.post(f"{LLAMA_BASE}/v1/chat/completions", json=payload, timeout=300.0)
 
 
 
 
181
  except httpx.TransportError as exc:
182
  raise HTTPException(status_code=503, detail=str(exc))
 
183
  if resp.status_code != 200:
184
  raise HTTPException(status_code=resp.status_code, detail=resp.text)
 
185
  return resp.json()
start.sh ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -e
3
+
4
+ # Find llama-server binary
5
+ LLAMA_BIN=$(find /usr /app /llama.cpp /usr/local / -maxdepth 6 -name "llama-server" -type f 2>/dev/null | head -1)
6
+
7
+ if [ -z "$LLAMA_BIN" ]; then
8
+ echo "ERROR: llama-server binary not found"
9
+ exit 1
10
+ fi
11
+
12
+ echo "Found llama-server at: $LLAMA_BIN"
13
+
14
+ "$LLAMA_BIN" \
15
+ -m /app/gemma-4-E2B-it-UD-Q5_K_XL.gguf \
16
+ --mmproj /app/mmproj-BF16.gguf \
17
+ --host 127.0.0.1 \
18
+ --port 8080 \
19
+ -t 2 \
20
+ --cache-type-k q8_0 \
21
+ --cache-type-v iq4_nl \
22
+ -c 128000 \
23
+ -n 38912 &
24
+
25
+ LLAMA_PID=$!
26
+ echo "llama-server started (PID $LLAMA_PID)"
27
+
28
+ # Wait up to 5 minutes for llama-server to be healthy
29
+ echo "Waiting for llama-server to be ready..."
30
+ for i in $(seq 1 150); do
31
+ if curl -sf http://127.0.0.1:8080/health > /dev/null 2>&1; then
32
+ echo "llama-server is ready"
33
+ break
34
+ fi
35
+ if ! kill -0 "$LLAMA_PID" 2>/dev/null; then
36
+ echo "ERROR: llama-server process died"
37
+ exit 1
38
+ fi
39
+ if [ "$i" -eq 150 ]; then
40
+ echo "ERROR: llama-server did not become ready in time"
41
+ exit 1
42
+ fi
43
+ sleep 2
44
+ done
45
+
46
+ exec uvicorn main:app --host 0.0.0.0 --port 7860