mrmadblack commited on
Commit
3c98a54
·
verified ·
1 Parent(s): 583a3d1

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +118 -83
server.py CHANGED
@@ -1,102 +1,69 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from huggingface_hub import hf_hub_download
4
- import requests
5
  import subprocess
 
6
  import uvicorn
7
  import os
8
- import json
 
9
 
10
  app = FastAPI()
11
 
12
- MODELS = {
13
- "tinyllama": "models/tinyllama.gguf"
14
- }
15
-
16
- class ChatRequest(BaseModel):
17
- model: str
18
- messages: list
19
-
20
- class GenerateRequest(BaseModel):
21
- model: str
22
- prompt: str
23
-
24
-
25
  # ---------------------------
26
- # logging
27
  # ---------------------------
28
 
29
- def log(title, data):
30
- print("\n==============================")
31
- print(title)
32
- print(data)
33
- print("==============================\n")
34
 
35
 
36
  # ---------------------------
37
- # prompt builder
38
  # ---------------------------
39
 
40
- def build_prompt(messages):
41
-
42
- prompt = ""
43
-
44
- for m in messages:
45
- role = m.get("role", "user")
46
- content = m.get("content", "")
47
-
48
- if content.strip() == "":
49
- continue
50
-
51
- prompt += f"{role}: {content}\n"
52
-
53
- prompt += "assistant:"
54
 
55
- log("PROMPT", prompt)
56
 
57
- return prompt
 
 
58
 
59
 
60
  # ---------------------------
61
- # download model
62
  # ---------------------------
63
 
64
  os.makedirs("models", exist_ok=True)
65
 
66
- MODEL_FILES = {
67
- "tinyllama": (
68
- "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
69
- "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
70
- )
71
- }
72
-
73
- for name, (repo, file) in MODEL_FILES.items():
74
-
75
- path = f"models/{name}.gguf"
76
 
77
- if not os.path.exists(path):
78
 
79
- print(f"Downloading model {name}")
80
-
81
- downloaded = hf_hub_download(
82
- repo_id=repo,
83
- filename=file
84
- )
85
 
86
- os.system(f"cp {downloaded} {path}")
87
 
88
- print(f"Model ready: {path}")
89
 
90
 
91
  # ---------------------------
92
- # start llama-server
93
  # ---------------------------
94
 
95
  print("Starting llama-server...")
96
 
97
  subprocess.Popen([
98
  "./llama.cpp/build/bin/llama-server",
99
- "-m", "models/tinyllama.gguf",
100
  "--host", "0.0.0.0",
101
  "--port", "8080",
102
  "-c", "2048"
@@ -104,7 +71,30 @@ subprocess.Popen([
104
 
105
 
106
  # ---------------------------
107
- # root
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  # ---------------------------
109
 
110
  @app.get("/")
@@ -113,74 +103,119 @@ def root():
113
 
114
 
115
  # ---------------------------
116
- # model list
117
  # ---------------------------
118
 
119
  @app.get("/api/tags")
120
- def list_models():
 
 
 
 
121
 
122
  return {
123
  "models": [
124
- {"name": "tinyllama"}
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  ]
126
  }
127
 
128
 
129
  # ---------------------------
130
- # chat endpoint
131
  # ---------------------------
132
 
133
- @app.post("/api/chat")
134
- def chat(req: ChatRequest):
135
 
136
- prompt = build_prompt(req.messages)
137
 
138
  response = requests.post(
139
  "http://localhost:8080/completion",
140
  json={
141
- "prompt": prompt,
142
  "n_predict": 200
143
  }
144
  )
145
 
146
- data = response.json()
 
 
147
 
148
  return {
149
  "model": req.model,
150
- "message": {
151
- "role": "assistant",
152
- "content": data["content"]
153
- },
154
- "done": True
 
 
 
 
 
155
  }
156
 
157
 
158
  # ---------------------------
159
- # generate endpoint
160
  # ---------------------------
161
 
162
- @app.post("/api/generate")
163
- def generate(req: GenerateRequest):
 
 
 
 
164
 
165
  response = requests.post(
166
  "http://localhost:8080/completion",
167
  json={
168
- "prompt": req.prompt,
169
  "n_predict": 200
170
  }
171
  )
172
 
173
- data = response.json()
 
 
174
 
175
  return {
176
  "model": req.model,
177
- "response": data["content"],
178
- "done": True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  }
180
 
181
 
182
  # ---------------------------
183
- # start API
184
  # ---------------------------
185
 
186
  if __name__ == "__main__":
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from huggingface_hub import hf_hub_download
 
4
  import subprocess
5
+ import requests
6
  import uvicorn
7
  import os
8
+ import time
9
+ import hashlib
10
 
11
  app = FastAPI()
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
# ---------------------------
# MODEL CONFIG
# ---------------------------

# Name exposed through the Ollama-style API and local path of the weights.
MODEL_NAME = "tinyllama"
MODEL_PATH = "models/tinyllama.gguf"

# HuggingFace Hub repo and file the weights are fetched from.
MODEL_REPO = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
MODEL_FILE = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
22
 
23
 
24
# ---------------------------
# REQUEST MODELS
# ---------------------------

class ChatRequest(BaseModel):
    """Body of POST /api/chat: target model name plus an Ollama-style
    list of message dicts ({"role": ..., "content": ...})."""
    model: str
    messages: list
 
 
 
 
 
 
 
 
 
 
 
31
 
 
32
 
33
class GenerateRequest(BaseModel):
    """Body of POST /api/generate: target model name and a raw prompt string."""
    model: str
    prompt: str
36
 
37
 
38
# ---------------------------
# DOWNLOAD MODEL
# ---------------------------

import shutil  # for a checked, shell-free copy out of the HF cache

os.makedirs("models", exist_ok=True)

# Fetch the weights once; subsequent container starts reuse the local copy.
if not os.path.exists(MODEL_PATH):

    print("Downloading model from HuggingFace")

    # hf_hub_download returns the path of the file inside the HF cache.
    downloaded = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=MODEL_FILE
    )

    # shutil.copy instead of os.system("cp ..."): no shell string to break on
    # spaces, portable, and it raises on failure instead of silently
    # returning a nonzero exit code that nothing checked.
    shutil.copy(downloaded, MODEL_PATH)

    print("Model downloaded")
56
 
57
 
58
  # ---------------------------
59
+ # START LLAMA SERVER
60
  # ---------------------------
61
 
62
  print("Starting llama-server...")
63
 
64
  subprocess.Popen([
65
  "./llama.cpp/build/bin/llama-server",
66
+ "-m", MODEL_PATH,
67
  "--host", "0.0.0.0",
68
  "--port", "8080",
69
  "-c", "2048"
 
71
 
72
 
73
# ---------------------------
# PROMPT BUILDER
# ---------------------------

def build_prompt(messages):
    """Flatten an Ollama-style message list into a plain-text prompt.

    Each message dict is rendered as ``role: content`` on its own line;
    messages with empty or missing content are skipped. The prompt ends
    with ``assistant:`` so the model continues as the assistant.

    Args:
        messages: list of dicts with optional "role" and "content" keys.

    Returns:
        The assembled prompt string.
    """
    prompt = ""

    for m in messages:
        # Default a missing role to "user" (previously m.get("role") with no
        # default rendered "None: ..." lines for role-less messages).
        role = m.get("role") or "user"
        # Tolerate content=None as well as a missing key.
        content = (m.get("content") or "").strip()

        if not content:
            continue

        prompt += f"{role}: {content}\n"

    prompt += "assistant:"
    return prompt
94
+
95
+
96
+ # ---------------------------
97
+ # ROOT
98
  # ---------------------------
99
 
100
  @app.get("/")
 
103
 
104
 
105
# ---------------------------
# MODEL LIST (OLLAMA FORMAT)
# ---------------------------

# Cache of {path: (size, sha256-hex)} so the GGUF file (hundreds of MB)
# is hashed once, not on every /api/tags request.
_MODEL_STAT_CACHE = {}


def _model_stat(path):
    """Return (size_bytes, sha256_hexdigest) for *path*, hashing lazily.

    Streams the file in 1 MiB chunks instead of reading it whole into
    memory, and closes the handle via the context manager (the previous
    ``open(path, "rb").read()`` loaded the entire file per request and
    leaked the file object).
    """
    if path not in _MODEL_STAT_CACHE:
        h = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1024 * 1024), b""):
                h.update(chunk)
        _MODEL_STAT_CACHE[path] = (os.path.getsize(path), h.hexdigest())
    return _MODEL_STAT_CACHE[path]


@app.get("/api/tags")
def tags():
    """Ollama-compatible model listing for the single bundled model."""
    size, digest = _model_stat(MODEL_PATH)

    return {
        "models": [
            {
                "name": MODEL_NAME,
                "model": MODEL_NAME,
                # time.gmtime(): the trailing "Z" claims UTC, so format UTC
                # (bare strftime used local time).
                "modified_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
                "size": size,
                "digest": digest,
                "details": {
                    "format": "gguf",
                    "family": "llama",
                    "families": ["llama"],
                    "parameter_size": "1.1B",
                    "quantization_level": "Q4_K_M"
                }
            }
        ]
    }
134
 
135
 
136
# ---------------------------
# GENERATE ENDPOINT
# ---------------------------

@app.post("/api/generate")
def generate(req: GenerateRequest):
    """Ollama-compatible /api/generate: one-shot completion of a raw prompt.

    Proxies the prompt to the local llama-server /completion endpoint and
    wraps the result in Ollama's non-streaming response shape.
    """
    start = time.time()

    response = requests.post(
        "http://localhost:8080/completion",
        json={
            "prompt": req.prompt,
            "n_predict": 200
        },
        # Bound the wait so a wedged llama-server can't hang the API forever.
        timeout=300
    )
    # Surface HTTP failures explicitly instead of a confusing KeyError
    # from response.json()["content"] below.
    response.raise_for_status()

    text = response.json()["content"].strip()

    # Ollama reports durations in nanoseconds.
    duration = int((time.time() - start) * 1e9)

    return {
        "model": req.model,
        # time.gmtime(): the trailing "Z" claims UTC, so format UTC.
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "response": text,
        "done": True,
        "done_reason": "stop",
        "total_duration": duration,
        "load_duration": 0,
        "prompt_eval_count": 0,
        "prompt_eval_duration": 0,
        # Crude token estimate: whitespace-separated word count.
        "eval_count": len(text.split()),
        "eval_duration": duration
    }
170
 
171
 
172
# ---------------------------
# CHAT ENDPOINT
# ---------------------------

@app.post("/api/chat")
def chat(req: ChatRequest):
    """Ollama-compatible /api/chat: flatten messages and complete.

    Builds a plain-text prompt from the message list, proxies it to the
    local llama-server /completion endpoint, and wraps the result in
    Ollama's non-streaming chat response shape.
    """
    start = time.time()

    prompt = build_prompt(req.messages)

    response = requests.post(
        "http://localhost:8080/completion",
        json={
            "prompt": prompt,
            "n_predict": 200
        },
        # Bound the wait so a wedged llama-server can't hang the API forever.
        timeout=300
    )
    # Surface HTTP failures explicitly instead of a confusing KeyError
    # from response.json()["content"] below.
    response.raise_for_status()

    text = response.json()["content"].strip()

    # Ollama reports durations in nanoseconds.
    duration = int((time.time() - start) * 1e9)

    return {
        "model": req.model,
        # time.gmtime(): the trailing "Z" claims UTC, so format UTC.
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "message": {
            "role": "assistant",
            "content": text,
            "thinking": "",
            "tool_calls": [],
            "images": []
        },
        "done": True,
        "done_reason": "stop",
        "total_duration": duration,
        "load_duration": 0,
        "prompt_eval_count": 0,
        "prompt_eval_duration": 0,
        # Crude token estimate: whitespace-separated word count.
        "eval_count": len(text.split()),
        "eval_duration": duration,
        "logprobs": []
    }
215
 
216
 
217
  # ---------------------------
218
+ # START API
219
  # ---------------------------
220
 
221
  if __name__ == "__main__":