Forol committed on
Commit
381d980
·
verified ·
1 Parent(s): fff978d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, json, time, uuid
2
+ from threading import Thread
3
+
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
6
+ from fastapi import FastAPI, HTTPException, Depends
7
+ from fastapi.responses import StreamingResponse, JSONResponse
8
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
9
+ from pydantic import BaseModel
10
+ from typing import List, Optional
11
+
12
# ── Config ────────────────────────────────────────────────────────────────────
# Every setting is read once, at import time, from the process environment.
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen2.5-1.5B-Instruct")
API_KEY = os.getenv("API_KEY", "my-secret-key")  # set this in Space secrets
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only selected when a CUDA device is present.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

if API_KEY == "my-secret-key":
    # The fallback key is visible in the source, so it offers no protection.
    # Keep accepting it (existing deployments rely on it) but make it loud.
    print(
        "WARNING: API_KEY is the built-in default; "
        "set a real key in the Space secrets."
    )

print(f"Loading {MODEL_ID} on {DEVICE} ...")

# NOTE(review): trust_remote_code=True lets the model repository run its own
# Python code at load time. MODEL_ID comes from the environment, so only point
# it at repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    device_map="auto",
    trust_remote_code=True,
)
model.eval()  # inference only
print("Model ready.")

# ── FastAPI app ───────────────────────────────────────────────────────────────
app = FastAPI(title="LLM API", version="1.0")
# auto_error=False: a missing Authorization header yields `None` credentials
# instead of an automatic error, so verify_key decides how to respond.
bearer = HTTPBearer(auto_error=False)
33
+
34
def verify_key(creds: Optional[HTTPAuthorizationCredentials] = Depends(bearer)):
    """Reject the request unless it carries the configured bearer token.

    Authentication is skipped entirely when ``API_KEY`` is empty.

    Args:
        creds: Parsed ``Authorization: Bearer ...`` header, or ``None`` when
            the header is absent.

    Raises:
        HTTPException: 401 when the token is missing or does not match.
    """
    import hmac  # local import so this block does not depend on file-level edits

    if not API_KEY:
        return

    supplied = creds.credentials if creds is not None else ""
    # Constant-time comparison avoids leaking how much of the key matched.
    # Compare bytes: compare_digest rejects non-ASCII `str` arguments.
    if not hmac.compare_digest(supplied.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(
            status_code=401,
            detail="Invalid or missing API key",
            headers={"WWW-Authenticate": "Bearer"},
        )
37
+
38
+ # ── Schemas ───────────────────────────────────────────────────────────────────
39
class Message(BaseModel):
    """One entry of the conversation passed to the tokenizer's chat template."""

    # Speaker label for this turn; forwarded unchanged to the chat template.
    role: str
    # Text of the turn.
    content: str
42
+
43
class ChatRequest(BaseModel):
    """Request body accepted by ``POST /v1/chat/completions``."""

    # Accepted in the request body; defaults to the configured MODEL_ID.
    model: Optional[str] = MODEL_ID
    # Conversation so far, in order.
    messages: List[Message]
    # Passed to generation as `max_new_tokens`.
    max_tokens: Optional[int] = 512
    # Sampling is enabled only when this is greater than zero.
    temperature: Optional[float] = 0.7
    # When true, the reply is sent as a server-sent-event stream.
    stream: Optional[bool] = False
49
+
50
+ # ── Routes ────────────────────────────────────────────────────────────────────
51
@app.get("/")
def root():
    """Unauthenticated status endpoint reporting the model id and device."""
    info = dict(status="ok", model=MODEL_ID, device=DEVICE)
    return info
54
+
55
@app.get("/v1/models")
def list_models(_=Depends(verify_key)):
    """Return a one-element model list describing the configured model.

    Requires a valid API key (enforced by the ``verify_key`` dependency).
    """
    entry = {"id": MODEL_ID, "object": "model", "owned_by": "user"}
    listing = {"object": "list", "data": [entry]}
    return listing
61
+
62
@app.post("/v1/chat/completions")
def chat_completions(req: ChatRequest, _=Depends(verify_key)):
    """Generate an assistant reply for a chat conversation.

    Args:
        req: Parsed request body. ``max_tokens`` and ``temperature`` may be
            ``None`` (explicit JSON null); the schema defaults are used then.
        _: Placeholder for the ``verify_key`` dependency (authentication).

    Returns:
        A ``StreamingResponse`` of server-sent events when ``req.stream`` is
        truthy, otherwise a ``JSONResponse`` holding the full completion and
        token usage.

    Raises:
        HTTPException: 400 when ``messages`` is empty or ``max_tokens`` is
            less than 1.
    """
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages must not be empty")

    # The Optional fields accept null; fall back to the same defaults the
    # schema declares so `None` never reaches a comparison or `generate`.
    max_new_tokens = 512 if req.max_tokens is None else req.max_tokens
    temperature = 0.7 if req.temperature is None else req.temperature
    if max_new_tokens < 1:
        raise HTTPException(status_code=400, detail="max_tokens must be at least 1")

    msgs = [{"role": m.role, "content": m.content} for m in req.messages]

    # Apply chat template
    text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    do_sample = temperature > 0
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        # Temperature is only meaningful when sampling; a zero or negative
        # value is not forwarded on the greedy path.
        gen_kwargs["temperature"] = temperature

    cid = f"chatcmpl-{uuid.uuid4().hex[:12]}"

    def make_chunk(delta, finish_reason):
        """Build one `chat.completion.chunk` payload."""
        return {
            "id": cid, "object": "chat.completion.chunk",
            "created": int(time.time()), "model": MODEL_ID,
            "choices": [{"delta": delta, "index": 0, "finish_reason": finish_reason}],
        }

    def sse(payload):
        """Frame a payload as a single server-sent event."""
        return f"data: {json.dumps(payload)}\n\n"

    # ── Streaming ──────────────────────────────────────────────────────────────
    if req.stream:
        streamer = TextIteratorStreamer(
            tokenizer, skip_special_tokens=True, skip_prompt=True
        )
        gen_kwargs["streamer"] = streamer
        failures = []

        def run_generation():
            try:
                model.generate(**gen_kwargs)
            except Exception as exc:
                # Without this the streamer never receives its stop signal and
                # the response below would block forever.
                failures.append(exc)
                streamer.end()
                raise

        Thread(target=run_generation, daemon=True).start()

        def event_stream():
            for token in streamer:
                if not token:
                    # The streamer can hand back empty strings; skip them.
                    continue
                yield sse(make_chunk({"content": token}, None))
            if failures:
                # Tell the client the stream was cut short by a server error.
                yield sse({"error": {"message": "generation failed", "type": "server_error"}})
            else:
                yield sse(make_chunk({}, "stop"))
            yield "data: [DONE]\n\n"

        return StreamingResponse(event_stream(), media_type="text/event-stream")

    # ── Non-streaming ──────────────────────────────────────────────────────────
    with torch.no_grad():
        output = model.generate(**gen_kwargs)

    # `generate` returns prompt + completion; slice the prompt off.
    prompt_len = inputs.input_ids.shape[1]
    reply = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
    comp_tokens = output.shape[1] - prompt_len

    return JSONResponse({
        "id": cid, "object": "chat.completion",
        "created": int(time.time()), "model": MODEL_ID,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": reply},
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": prompt_len,
            "completion_tokens": comp_tokens,
            "total_tokens": prompt_len + comp_tokens,
        }
    })