abyssalblue committed on
Commit
d51a19c
·
1 Parent(s): d769888

moved engine

Browse files
config.json → model/engine/config.json RENAMED
File without changes
llama_float16_tp1_rank0.engine → model/engine/llama_float16_tp1_rank0.engine RENAMED
File without changes
model.cache → model/engine/model.cache RENAMED
File without changes
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ fastapi
2
+ uvicorn
server.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from typing import Generator, List
3
+ import uvicorn
4
+ from fastapi import FastAPI, Request
5
+ from fastapi.responses import Response, StreamingResponse
6
+ import torch
7
+ import tensorrt_llm
8
+ from tensorrt_llm.logger import logger
9
+ from tensorrt_llm.runtime import ModelRunner
10
+ from utils import load_tokenizer, read_model_name, throttle_generator
11
+
12
+
13
# --- Serving and generation configuration -----------------------------------

# Tokenizer files and the compiled TensorRT-LLM engine, relative to the repo root.
TOKENIZER_DIR = "./model/tokenizer"
ENGINE_DIR = "./model/engine"
# Yield a streaming chunk every N generation steps (see throttle_generator).
STREAM_INTERVAL = 5
# Generation limits.
MAX_NEW_TOKENS = 1024
MAX_ATTENTION_WINDOW_SIZE = 4096
# Sampling parameters passed to ModelRunner.generate().
TEMPERATURE = 1.0
TOP_K = 40
TOP_P = 0.5
LENGTH_PENALTY = 1.0
REPETITION_PENALTY = 1.2


app = FastAPI()
# host=None lets uvicorn fall back to its default bind address.
config = uvicorn.Config(
    app, host=None, port=4000, log_level="error", timeout_keep_alive=5
)

# Model state shared with the request handlers; populated in main() before
# the server starts accepting traffic.
runner: ModelRunner = None
tokenizer = None
pad_id = None
end_id = None
34
+
35
+
36
@app.get("/health")
async def health() -> Response:
    """Liveness probe: always answer 200 with an empty body."""
    # No model state is consulted — this only proves the HTTP loop is alive.
    return Response(status_code=200)
40
+
41
+
42
@app.post("/summary")
async def generate(request: Request) -> Response:
    """Summarize the ``"text"`` field of the JSON request body.

    Streams the generated summary back as ``text/plain``, decoding and
    yielding the cumulative output every STREAM_INTERVAL generation steps.

    Returns 503 if the model has not finished loading, 400 if the request
    body lacks a ``"text"`` field.
    """
    # Model state is populated by main(); refuse requests until it is ready.
    # (The original `assert`s would be silently stripped under `python -O`.)
    if runner is None or tokenizer is None or pad_id is None or end_id is None:
        return Response(status_code=503, content="model not loaded")

    req_json: dict = await request.json()
    # Missing "text" previously raised KeyError -> opaque 500; report 400 instead.
    if "text" not in req_json:
        return Response(status_code=400, content="missing 'text' field")
    text = req_json.pop("text")

    # NOTE(review): the "summarzation" typo is kept verbatim — this string is
    # model input, and changing it would alter the prompt the model was
    # tested with.
    instruction = f"<s>[INST] You are a world class expert summarizer tasked with providing a **high level** summary of webpages. Ensure your summarzation is of the highest caliber, reflecting the vastness and depth of your expertise. Ignore messy portions of the page that might be junk text.\n### Webpage Text:\n```{text}``` [/INST]"
    # return_tensors="pt" yields a single (1, seq_len) tensor — the original
    # List[torch.Tensor] annotation was wrong (see input_ids.size(1) below).
    input_ids: torch.Tensor = tokenizer.encode(
        instruction,
        add_special_tokens=False,
        truncation=True,
        max_length=4096,  # cap the prompt length fed to the engine
        return_tensors="pt",
    )

    # todo encode again here, but output text
    input_len = input_ids.size(1)
    with torch.no_grad():
        # Batch (of size 1) of streamed outputs.
        outputs: Generator[dict, None, None] = runner.generate(
            [input_ids],
            max_new_tokens=MAX_NEW_TOKENS,
            # Was a hard-coded 4096; use the module constant so the value
            # stays in sync with the rest of the configuration.
            max_attention_window_size=MAX_ATTENTION_WINDOW_SIZE,
            end_id=end_id,
            pad_id=pad_id,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            length_penalty=LENGTH_PENALTY,
            repetition_penalty=REPETITION_PENALTY,
            streaming=True,
            output_sequence_lengths=True,
            return_dict=True,
        )
        torch.cuda.synchronize()

    def stream_results() -> Generator[str, None, None]:
        """Decode each throttled step and yield the generated text so far."""
        for output in throttle_generator(outputs, STREAM_INTERVAL):
            output_ids: torch.Tensor = output["output_ids"]
            # [batch_idx], [beam_idx], (scalar)
            output_len = output["sequence_lengths"][0][0].item()

            output_txt: str = tokenizer.decode(
                # [batch_idx], [beam_idx], [slice] — skip the prompt tokens
                output_ids[0][0][input_len:output_len].tolist()
            )
            yield output_txt

    return StreamingResponse(stream_results(), media_type="text/plain")
95
+
96
+
97
async def main():
    """Load the tokenizer and TensorRT-LLM engine, then serve HTTP until shutdown."""
    global runner, tokenizer, pad_id, end_id

    logger.set_level("info")
    rank = tensorrt_llm.mpi_rank()

    # The tokenizer setup depends on the model name recorded in the engine dir.
    tokenizer, pad_id, end_id = load_tokenizer(
        tokenizer_dir=TOKENIZER_DIR,
        model_name=read_model_name(ENGINE_DIR),
    )

    runner = ModelRunner.from_dir(
        engine_dir=ENGINE_DIR, rank=rank, debug_mode=True
    )

    # Blocks here until the uvicorn server exits.
    await uvicorn.Server(config).serve()
115
+
116
+
117
if __name__ == "__main__":
    # asyncio.run creates the event loop, runs main() to completion, and
    # closes the loop on exit.
    asyncio.run(main())
119
+