likhonsheikh committed
Commit b1751bb · verified · 1 Parent(s): dce0160

Upload app.py with huggingface_hub
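
The commit message refers to huggingface_hub; a file like this is typically pushed with its upload_file helper. A minimal sketch (the repo id is a placeholder, not taken from this commit):

from huggingface_hub import upload_file

upload_file(
    path_or_fileobj="app.py",        # local file to upload
    path_in_repo="app.py",           # destination path in the repo
    repo_id="<user>/<space-name>",   # placeholder; the actual Space id
    repo_type="space",
)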

Files changed (1)
  app.py +326 -0
app.py ADDED
@@ -0,0 +1,326 @@
"""
Anthropic-Compatible API Endpoint
Lightweight CPU-based implementation for Hugging Face Spaces
"""

import os
import time
import uuid
from typing import List, Optional, Union
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, Header, Request
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import json

# ============== Configuration ==============
MODEL_ID = "HuggingFaceTB/SmolLM2-135M-Instruct"  # Ultra-lightweight 135M model
MAX_TOKENS_DEFAULT = 1024
DEVICE = "cpu"

# Global model and tokenizer
model = None
tokenizer = None

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load model on startup"""
    global model, tokenizer
    print(f"Loading model: {MODEL_ID}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        device_map=DEVICE,
        low_cpu_mem_usage=True
    )
    model.eval()
    print("Model loaded successfully!")

    yield

    # Cleanup
    del model, tokenizer

app = FastAPI(
    title="Anthropic-Compatible API",
    description="Lightweight CPU-based API with Anthropic Messages API compatibility",
    version="1.0.0",
    lifespan=lifespan
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ============== Pydantic Models (Anthropic-Compatible) ==============

class ContentBlock(BaseModel):
    type: str = "text"
    text: str

class Message(BaseModel):
    role: str
    content: Union[str, List[ContentBlock]]
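# Both request shapes below parse as a valid Message, mirroring Anthropic's
# string-or-block-list content field (values are illustrative):
#   {"role": "user", "content": "Hi"}
#   {"role": "user", "content": [{"type": "text", "text": "Hi"}]}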
class MessageRequest(BaseModel):
    model: str
    messages: List[Message]
    max_tokens: int = MAX_TOKENS_DEFAULT
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 0.9
    top_k: Optional[int] = 50
    stream: Optional[bool] = False
    system: Optional[str] = None
    stop_sequences: Optional[List[str]] = None  # accepted for API compatibility, but not enforced during generation

class Usage(BaseModel):
    input_tokens: int
    output_tokens: int

class MessageResponse(BaseModel):
    id: str
    type: str = "message"
    role: str = "assistant"
    content: List[ContentBlock]
    model: str
    stop_reason: str = "end_turn"
    stop_sequence: Optional[str] = None
    usage: Usage

class ErrorResponse(BaseModel):
    type: str = "error"
    error: dict

# ============== Helper Functions ==============

def format_messages(messages: List[Message], system: Optional[str] = None) -> str:
    """Format messages into a prompt string"""
    formatted_messages = []

    if system:
        formatted_messages.append({"role": "system", "content": system})

    for msg in messages:
        content = msg.content
        if isinstance(content, list):
            content = " ".join([block.text for block in content if block.type == "text"])
        formatted_messages.append({"role": msg.role, "content": content})

    # Use chat template if available
    if tokenizer.chat_template:
        return tokenizer.apply_chat_template(
            formatted_messages,
            tokenize=False,
            add_generation_prompt=True
        )

    # Fallback simple format
    prompt = ""
    for msg in formatted_messages:
        role = msg["role"].capitalize()
        prompt += f"{role}: {msg['content']}\n"
    prompt += "Assistant: "
    return prompt
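# Illustration of the fallback branch (hypothetical inputs, and assumes the
# tokenizer has no chat template):
#   format_messages([Message(role="user", content="Hi")], system="Be brief")
# would produce:
#   System: Be brief
#   User: Hi
#   Assistant: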
def generate_id() -> str:
    """Generate a unique message ID"""
    return f"msg_{uuid.uuid4().hex[:24]}"

# ============== API Endpoints ==============

@app.get("/")
async def root():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "model": MODEL_ID,
        "api_version": "2023-06-01",
        "compatibility": "anthropic-messages-api"
    }

@app.get("/v1/models")
async def list_models():
    """List available models (Anthropic-compatible)"""
    return {
        "object": "list",
        "data": [
            {
                "id": "smollm2-135m",
                "object": "model",
                "created": int(time.time()),
                "owned_by": "huggingface",
                "display_name": "SmolLM2 135M Instruct"
            }
        ]
    }

@app.post("/v1/messages")
async def create_message(
    request: MessageRequest,
    x_api_key: Optional[str] = Header(None, alias="x-api-key"),
    anthropic_version: Optional[str] = Header(None, alias="anthropic-version")
):
    """
    Create a message (Anthropic Messages API compatible)
    """
    try:
        # Format the prompt
        prompt = format_messages(request.messages, request.system)

        # Tokenize
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        input_token_count = inputs.input_ids.shape[1]

        if request.stream:
            return await stream_response(request, inputs, input_token_count)

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_tokens,
                temperature=request.temperature if request.temperature > 0 else 1.0,
                top_p=request.top_p,
                top_k=request.top_k,
                do_sample=request.temperature > 0,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Decode only new tokens
        generated_tokens = outputs[0][input_token_count:]
        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        output_token_count = len(generated_tokens)

        # Build response
        response = MessageResponse(
            id=generate_id(),
            content=[ContentBlock(type="text", text=generated_text.strip())],
            model=request.model,
            stop_reason="end_turn",
            usage=Usage(
                input_tokens=input_token_count,
                output_tokens=output_token_count
            )
        )

        return response

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
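# Example call against this endpoint (assumes the app is reachable on
# localhost:7860; the model field is only echoed back in the response):
#   curl -s http://localhost:7860/v1/messages \
#     -H "content-type: application/json" \
#     -d '{"model": "smollm2-135m", "max_tokens": 64,
#          "messages": [{"role": "user", "content": "Say hello."}]}'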
async def stream_response(request: MessageRequest, inputs, input_token_count: int):
    """Stream response using SSE (Server-Sent Events)"""

    message_id = generate_id()

    async def generate():
        # Send message_start event
        start_event = {
            "type": "message_start",
            "message": {
                "id": message_id,
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": request.model,
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {"input_tokens": input_token_count, "output_tokens": 0}
            }
        }
        yield f"event: message_start\ndata: {json.dumps(start_event)}\n\n"

        # Send content_block_start
        block_start = {
            "type": "content_block_start",
            "index": 0,
            "content_block": {"type": "text", "text": ""}
        }
        yield f"event: content_block_start\ndata: {json.dumps(block_start)}\n\n"

        # Setup streamer
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        generation_kwargs = {
            **inputs,
            "max_new_tokens": request.max_tokens,
            "temperature": request.temperature if request.temperature > 0 else 1.0,
            "top_p": request.top_p,
            "top_k": request.top_k,
            "do_sample": request.temperature > 0,
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
            "streamer": streamer,
        }

        # Run generation in a thread
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Note: iterating the blocking streamer inside this async generator
        # stalls the event loop between tokens; acceptable for a single-user
        # demo Space, but a busier server would hand this off to a thread.
        output_tokens = 0
        for text in streamer:
            if text:
                output_tokens += len(tokenizer.encode(text, add_special_tokens=False))
                delta_event = {
                    "type": "content_block_delta",
                    "index": 0,
                    "delta": {"type": "text_delta", "text": text}
                }
                yield f"event: content_block_delta\ndata: {json.dumps(delta_event)}\n\n"

        thread.join()

        # Send content_block_stop
        block_stop = {"type": "content_block_stop", "index": 0}
        yield f"event: content_block_stop\ndata: {json.dumps(block_stop)}\n\n"

        # Send message_delta
        delta = {
            "type": "message_delta",
            "delta": {"stop_reason": "end_turn", "stop_sequence": None},
            "usage": {"output_tokens": output_tokens}
        }
        yield f"event: message_delta\ndata: {json.dumps(delta)}\n\n"

        # Send message_stop
        yield f"event: message_stop\ndata: {json.dumps({'type': 'message_stop'})}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no"
        }
    )
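# The generator above emits Anthropic's streaming event sequence:
#   message_start -> content_block_start -> content_block_delta (repeated)
#   -> content_block_stop -> message_delta -> message_stop
# Each event is a standard SSE frame: "event: <name>\ndata: <json>\n\n".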
# Token counting endpoint
@app.post("/v1/messages/count_tokens")
async def count_tokens(request: MessageRequest):
    """Count tokens for a message request"""
    prompt = format_messages(request.messages, request.system)
    tokens = tokenizer.encode(prompt)
    return {"input_tokens": len(tokens)}
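# Example (hypothetical count): POSTing the same JSON body used for
# /v1/messages here returns something like {"input_tokens": 23} without
# running any generation.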
# Health check
@app.get("/health")
async def health():
    return {"status": "ok", "model_loaded": model is not None}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
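
Because the routes mimic the Anthropic Messages API, the official anthropic Python SDK can be pointed at this server. A minimal client-side sketch (assumes the SDK is installed and the app is running on localhost:7860; the server never validates the key, so any placeholder value works):

from anthropic import Anthropic

# Point the SDK at the local server instead of api.anthropic.com
client = Anthropic(base_url="http://localhost:7860", api_key="placeholder")

msg = client.messages.create(
    model="smollm2-135m",   # echoed back by the server
    max_tokens=64,
    messages=[{"role": "user", "content": "Say hello."}],
)
print(msg.content[0].text)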