jiminaa committed on
Commit
6482d6b
·
1 Parent(s): 32fdadf

OpenAI-compatible endpoint

Browse files
Files changed (2) hide show
  1. Dockerfile +0 -1
  2. main.py +75 -2
Dockerfile CHANGED
@@ -12,6 +12,5 @@ RUN pip uninstall -y gradio gradio-client || true \
12
 
13
  USER user
14
 
15
-
16
  # Start the FastAPI app on port 7860, the default port expected by Spaces
17
  CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
 
12
 
13
  USER user
14
 
 
15
  # Start the FastAPI app on port 7860, the default port expected by Spaces
16
  CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py CHANGED
@@ -7,8 +7,10 @@ from fastapi import FastAPI
7
  from fastapi.responses import StreamingResponse, RedirectResponse
8
  from pydantic import BaseModel
9
  import json
10
- from typing import List, Literal
11
  import os
 
 
12
 
13
 
14
  HF_TOKEN = os.getenv("HF_TOKEN")
@@ -123,6 +125,14 @@ class GenerateRequest(BaseModel):
123
  max_length: int = 256
124
  temperature: float = 0.7
125
 
 
 
 
 
 
 
 
 
126
  # fastAPI endpoints
127
 
128
  # return information about the API
@@ -177,7 +187,70 @@ async def generate_stream_api(request: GenerateRequest):
177
  headers={
178
  "Cache-Control": "no-cache", # Don't cache streaming responses
179
  "Connection": "keep-alive", # Keep connection open
180
- "X-Accel-Buffering": "no",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  }
182
  )
183
 
 
7
  from fastapi.responses import StreamingResponse, RedirectResponse
8
  from pydantic import BaseModel
9
  import json
10
+ from typing import List, Literal, Optional
11
  import os
12
+ import uuid
13
+ import time
14
 
15
 
16
  HF_TOKEN = os.getenv("HF_TOKEN")
 
125
  max_length: int = 256
126
  temperature: float = 0.7
127
 
128
# OpenAI-compatible request payload for the /v1/chat/completions endpoint,
# mirroring the fields an OpenAI-style client (e.g. HuggingFace
# InferenceClient) sends.
class ChatCompletionRequest(BaseModel):
    # NOTE(review): downstream code repurposes `model` as a language
    # selector (e.g. "English", "Spanish") — confirm against callers.
    model: str = "default"
    messages: List[Message]            # conversation turns, role + content
    max_tokens: Optional[int] = 256    # generation length cap
    temperature: Optional[float] = 0.7 # sampling temperature
    stream: Optional[bool] = True      # SSE streaming requested by client
136
  # fastAPI endpoints
137
 
138
  # return information about the API
 
187
  headers={
188
  "Cache-Control": "no-cache", # Don't cache streaming responses
189
  "Connection": "keep-alive", # Keep connection open
190
+ "X-Accel-Buffering": "no",
191
+ }
192
+ )
193
+
194
# OpenAI-compatible endpoint for HuggingFace InferenceClient.
# The `model` field doubles as a language selector (e.g. "English",
# "Spanish", "Korean"); unknown values fall back to "English".
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Serve chat completions in the OpenAI wire format.

    When ``request.stream`` is true (the default), emits Server-Sent
    Events of ``chat.completion.chunk`` objects followed by a terminal
    ``data: [DONE]`` sentinel. When false, collects the full generation
    and returns a single ``chat.completion`` JSON object — previously the
    flag was ignored and every client was streamed at.
    """
    messages_dicts = [{"role": msg.role, "content": msg.content} for msg in request.messages]

    # Use the model field as a language selector, defaulting to English
    # when the requested value has no registered adapter.
    language = request.model if request.model in adapter_paths else "English"

    chat_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
    created = int(time.time())
    max_tokens = request.max_tokens or 256
    temperature = request.temperature or 0.7

    # Non-streaming mode: honor stream=False with a complete response
    # (FastAPI serializes the dict to JSON).
    if not request.stream:
        text = "".join(
            generate_text_stream(messages_dicts, language, max_tokens, temperature)
        )
        return {
            "id": chat_id,
            "object": "chat.completion",
            "created": created,
            "model": language,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": text},
                "finish_reason": "stop",
            }],
        }

    def event_generator():
        # Build one OpenAI-shaped streaming chunk.
        def make_chunk(delta, finish_reason):
            return {
                "id": chat_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": language,
                "choices": [{
                    "index": 0,
                    "delta": delta,
                    "finish_reason": finish_reason,
                }],
            }

        try:
            for token in generate_text_stream(
                messages_dicts, language, max_tokens, temperature
            ):
                yield f"data: {json.dumps(make_chunk({'content': token}, None))}\n\n"
            # Final chunk carries finish_reason so clients close cleanly.
            yield f"data: {json.dumps(make_chunk({}, 'stop'))}\n\n"
        except Exception as e:
            # Surface the failure in-band: once streaming has begun the
            # HTTP status can no longer be changed.
            error_chunk = {"error": {"message": str(e), "type": "server_error"}}
            yield f"data: {json.dumps(error_chunk)}\n\n"
        finally:
            # Always terminate the stream — previously the error path
            # skipped [DONE], leaving SSE clients hanging.
            yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",  # don't cache streaming responses
            "Connection": "keep-alive",   # keep connection open
            "X-Accel-Buffering": "no",    # disable reverse-proxy buffering
        },
    )
256