Kasher13 committed
Commit 96baa52 · verified · 1 Parent(s): 9578bc0

Update inference proxy (Groq/Gemini/CPU priority chain)
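
A minimal usage sketch (not part of the commit): once the Space is running, the proxy exposes the /health and /v1/chat/completions routes defined in app.py below. The base URL and request values here are placeholders to adapt to your own Space, and which ChatRequest fields are strictly required is an assumption, so the optional-looking ones are passed anyway.

import httpx

# Placeholder: replace with your own Space URL, e.g. https://<user>-<space>.hf.space
BASE_URL = "https://<user>-<space>.hf.space"

# /health reports which provider the proxy selected: "groq", "gemini", or "local-cpu".
print(httpx.get(f"{BASE_URL}/health", timeout=30.0).json())

# OpenAI-style chat completion, answered by whichever provider is active.
resp = httpx.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "model": "proxy-decides",  # the proxy resolves the real model from its env config
        "messages": [{"role": "user", "content": "Introduce yourself in one sentence."}],
        "max_tokens": 64,
        "temperature": 0.7,
        "top_p": 0.9,
    },
    timeout=120.0,
)
print(resp.json()["choices"][0]["message"]["content"])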

Files changed (1)
  1. app.py +144 -51
app.py CHANGED
@@ -1,19 +1,14 @@
  """
- deploy/tgi_space/app.py — Fallback inference server if TGI fails on CPU Space.
+ deploy/tgi_space/app.py — Smart inference proxy for persona generation.

- Serves OpenAI-compatible /v1/chat/completions using HuggingFace transformers directly.
- Use this Dockerfile instead when TGI has RUNTIME_ERROR on a Space:
+ Priority chain (first available wins):
+ 1. GROQ_API_KEY → Groq Cloud (fast, free: 14,400 req/day with llama-3.1-8b-instant)
+ 2. GEMINI_API_KEY → Gemini Flash (generous free: 1,500 req/day, 1M tok/day)
+ 3. Local CPU → transformers (slow fallback, only for smoke-testing)

- FROM python:3.11-slim
- WORKDIR /app
- COPY app.py .
- RUN pip install --no-cache-dir fastapi uvicorn transformers accelerate torch --index-url https://download.pytorch.org/whl/cpu
- ENV PORT=7860
- ENV MODEL_ID=Qwen/Qwen2.5-0.5B-Instruct
- ENV HF_HOME=/data
- CMD ["python", "app.py"]
-
- Usage (from setup_spaces.py): set FALLBACK_DOCKERFILE=deploy/tgi_space/Dockerfile.fallback
+ To activate a fast provider, set the env var in the HF Space settings:
+ - Groq: GROQ_API_KEY = gsk_... (free at https://console.groq.com)
+ - Gemini: GEMINI_API_KEY = AIza... (free at https://aistudio.google.com)
  """
  from __future__ import annotations

@@ -24,38 +19,48 @@ import time
  import uuid
  from typing import Any, Dict, List, Optional

- import torch
- from fastapi import FastAPI, HTTPException
+ import httpx
+ from fastapi import FastAPI, HTTPException, Request
  from fastapi.responses import JSONResponse
  from pydantic import BaseModel
- from transformers import AutoModelForCausalLM, AutoTokenizer

- app = FastAPI(title="Persona Inference Server")
+ app = FastAPI(title="Persona Inference Proxy")
+
+ # ---------------------------------------------------------------------------
+ # Config (env vars)
+ # ---------------------------------------------------------------------------

  MODEL_ID = os.environ.get("MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct")
  PORT = int(os.environ.get("PORT", 7860))
  MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", 600))

- # Load model at startup (float32 for CPU, bfloat16 if available)
+ GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
+ GROQ_MODEL = os.environ.get("GROQ_MODEL", "llama-3.1-8b-instant")
+
+ GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "")
+ GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-1.5-flash")
+
+ # Local model (loaded lazily — only if no fast provider)
  _tokenizer = None
  _model = None
+ _local_loaded = False


- def _load_model():
-     global _tokenizer, _model
-     print(f"Loading {MODEL_ID} ...")
-     _tokenizer = AutoTokenizer.from_pretrained(
-         MODEL_ID, trust_remote_code=True
-     )
-     dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
-     _model = AutoModelForCausalLM.from_pretrained(
-         MODEL_ID,
-         torch_dtype=dtype,
-         device_map="auto",
-         trust_remote_code=True,
-     )
-     _model.eval()
-     print(f"Model loaded: {MODEL_ID} dtype={dtype}")
+ def _infer_mode() -> str:
+     if GROQ_API_KEY:
+         return "groq"
+     if GEMINI_API_KEY:
+         return "gemini"
+     return "local-cpu"
+
+
+ def _active_model() -> str:
+     mode = _infer_mode()
+     if mode == "groq":
+         return GROQ_MODEL
+     if mode == "gemini":
+         return GEMINI_MODEL
+     return MODEL_ID


  # ---------------------------------------------------------------------------
@@ -76,24 +81,77 @@ class ChatRequest(BaseModel):


  # ---------------------------------------------------------------------------
- # Endpoints
+ # Provider implementations
  # ---------------------------------------------------------------------------

- @app.on_event("startup")
- async def startup():
-     loop = asyncio.get_event_loop()
-     await loop.run_in_executor(None, _load_model)
+ async def _call_groq(req: ChatRequest) -> dict:
+     payload = {
+         "model": GROQ_MODEL,
+         "messages": [{"role": m.role, "content": m.content} for m in req.messages],
+         "max_tokens": min(req.max_tokens, MAX_NEW_TOKENS),
+         "temperature": req.temperature,
+         "top_p": req.top_p,
+     }
+     async with httpx.AsyncClient(timeout=90.0) as client:
+         r = await client.post(
+             "https://api.groq.com/openai/v1/chat/completions",
+             json=payload,
+             headers={
+                 "Authorization": f"Bearer {GROQ_API_KEY}",
+                 "Content-Type": "application/json",
+             },
+         )
+     if r.status_code != 200:
+         raise HTTPException(status_code=r.status_code, detail=f"Groq error: {r.text[:200]}")
+     return r.json()
+
+
+ async def _call_gemini(req: ChatRequest) -> dict:
+     """Call Gemini via its OpenAI-compatible endpoint."""
+     payload = {
+         "model": GEMINI_MODEL,
+         "messages": [{"role": m.role, "content": m.content} for m in req.messages],
+         "max_tokens": min(req.max_tokens, MAX_NEW_TOKENS),
+         "temperature": req.temperature,
+     }
+     async with httpx.AsyncClient(timeout=90.0) as client:
+         r = await client.post(
+             "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
+             json=payload,
+             headers={
+                 "Authorization": f"Bearer {GEMINI_API_KEY}",
+                 "Content-Type": "application/json",
+             },
+         )
+     if r.status_code != 200:
+         raise HTTPException(status_code=r.status_code, detail=f"Gemini error: {r.text[:200]}")
+     return r.json()


- @app.get("/health")
- async def health():
-     return {"status": "ok", "model": MODEL_ID, "loaded": _model is not None}
+ def _load_local_model():
+     global _tokenizer, _model, _local_loaded
+     import torch
+     from transformers import AutoModelForCausalLM, AutoTokenizer
+     print(f"Loading local model {MODEL_ID} on CPU ...")
+     _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+     _model = AutoModelForCausalLM.from_pretrained(
+         MODEL_ID,
+         torch_dtype=torch.float32,
+         device_map="auto",
+         trust_remote_code=True,
+     )
+     _model.eval()
+     _local_loaded = True
+     print(f"Local model loaded: {MODEL_ID}")


- @app.post("/v1/chat/completions")
- async def chat_completions(req: ChatRequest):
-     if _model is None or _tokenizer is None:
-         raise HTTPException(status_code=503, detail="Model not loaded yet")
+ async def _call_local(req: ChatRequest) -> dict:
+     global _local_loaded
+     if not _local_loaded:
+         loop = asyncio.get_event_loop()
+         await loop.run_in_executor(None, _load_local_model)
+
+     import torch

      msgs = [{"role": m.role, "content": m.content} for m in req.messages]

@@ -121,16 +179,51 @@ async def chat_completions(req: ChatRequest):
          "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
          "object": "chat.completion",
          "created": int(time.time()),
-         "model": req.model,
-         "choices": [{
-             "index": 0,
-             "message": {"role": "assistant", "content": content},
-             "finish_reason": "stop",
-         }],
+         "model": MODEL_ID,
+         "choices": [{"index": 0, "message": {"role": "assistant", "content": content}, "finish_reason": "stop"}],
          "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
      }


+ # ---------------------------------------------------------------------------
+ # Startup: eagerly load local model only if no fast provider configured
+ # ---------------------------------------------------------------------------
+
+ @app.on_event("startup")
+ async def startup():
+     mode = _infer_mode()
+     print(f"Inference mode: {mode} active_model: {_active_model()}")
+     if mode == "local-cpu":
+         loop = asyncio.get_event_loop()
+         await loop.run_in_executor(None, _load_local_model)
+
+
+ # ---------------------------------------------------------------------------
+ # Endpoints
+ # ---------------------------------------------------------------------------
+
+ @app.get("/health")
+ async def health():
+     mode = _infer_mode()
+     loaded = True if mode != "local-cpu" else _local_loaded
+     return {
+         "status": "ok",
+         "mode": mode,
+         "model": _active_model(),
+         "loaded": loaded,
+     }
+
+
+ @app.post("/v1/chat/completions")
+ async def chat_completions(req: ChatRequest):
+     mode = _infer_mode()
+     if mode == "groq":
+         return await _call_groq(req)
+     if mode == "gemini":
+         return await _call_gemini(req)
+     return await _call_local(req)
+
+
  if __name__ == "__main__":
      import uvicorn
      uvicorn.run(app, host="0.0.0.0", port=PORT)
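
The module docstring above describes the priority chain in prose; as a complement, here is a small sketch of how the selection helpers could be smoke-tested locally. It assumes app.py's dependencies (fastapi, pydantic, httpx) are installed and that the script sits next to deploy/tgi_space/app.py; it only exercises _infer_mode and _active_model, never the network calls or the local model.

import os

# Any non-empty value selects Groq at import time; the key itself is only used
# when a request is actually proxied, so a dummy value is fine for this check.
os.environ["GROQ_API_KEY"] = "gsk_dummy"

import app  # noqa: E402  (keys are read at import time, so set them first)

assert app._infer_mode() == "groq"
assert app._active_model() == app.GROQ_MODEL  # "llama-3.1-8b-instant" unless GROQ_MODEL overrides it
print("mode:", app._infer_mode(), "model:", app._active_model())

Because the keys are read once at module import, changing GROQ_API_KEY or GEMINI_API_KEY on a running Space takes effect only after a restart, which is also when the startup hook decides whether the local CPU model needs to be loaded at all.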