| | from fastapi import FastAPI, Request |
| | from transformers import AutoTokenizer, AutoModel |
| | import uvicorn, json, datetime |
| | import torch |
| |
|
# Device type and (optional) device index used for GPU cache cleanup.
DEVICE = "cuda"
DEVICE_ID = "0"
# An empty DEVICE_ID yields the bare device type; otherwise "<type>:<index>".
CUDA_DEVICE = DEVICE if not DEVICE_ID else f"{DEVICE}:{DEVICE_ID}"
| |
|
| |
|
def torch_gc():
    """Empty the CUDA cache and collect CUDA IPC memory on ``CUDA_DEVICE``.

    Does nothing when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    with torch.cuda.device(CUDA_DEVICE):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
| |
|
| |
|
# ASGI application object; routes are registered on it via decorators below
# and it is handed to uvicorn in the ``__main__`` block.
app = FastAPI()
| |
|
| |
|
@app.post("/")
async def create_item(request: Request):
    """Answer one chat request using the module-level model and tokenizer.

    The JSON request body is read as a mapping with these optional keys:
      - ``prompt``: passed to ``model.chat`` as the query.
      - ``history``: forwarded to ``model.chat`` as ``history``.
      - ``max_length``, ``top_p``, ``temperature``: generation settings; a
        missing or falsy value falls back to 2048, 0.7 and 0.95.

    Returns:
        A dict with the generated ``response``, the ``history`` returned by
        ``model.chat``, a fixed ``status`` of 200 and the local ``time``
        (``%Y-%m-%d %H:%M:%S``) at which the reply was produced.
    """
    # ``model`` and ``tokenizer`` are module globals assigned in the
    # ``__main__`` block. They are only read here, so no ``global`` statement
    # is required.
    # ``request.json()`` already returns decoded data; no dumps/loads
    # round-trip is needed.
    payload = await request.json()
    prompt = payload.get('prompt')

    response, history = model.chat(
        tokenizer,
        prompt,
        history=payload.get('history'),
        max_length=payload.get('max_length') or 2048,
        top_p=payload.get('top_p') or 0.7,
        temperature=payload.get('temperature') or 0.95,
    )

    # Named ``timestamp`` to avoid shadowing the stdlib ``time`` module name.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": response,
        "history": history,
        "status": 200,
        "time": timestamp,
    }

    # Same log text as before; the f-string also tolerates a non-str prompt,
    # which string concatenation would reject with a TypeError.
    print(f'[{timestamp}] ", prompt:"{prompt}", response:"{response!r}"')

    # Release cached GPU memory after each generation so it does not pile up
    # across requests; without this call the helper was never used.
    torch_gc()
    return answer
| |
|
| |
|
if __name__ == '__main__':
    # Load the tokenizer and model once at startup; ``create_item`` reads both
    # as module globals. NOTE(review): ``trust_remote_code=True`` lets
    # transformers run the model code published with "THUDM/chatglm-6b".
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    # Place the half-precision model on CUDA_DEVICE (rather than the implicit
    # current device) so it lives on the same device that ``torch_gc`` cleans.
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().to(CUDA_DEVICE)
    model.eval()
    # Single worker: the model is held in this process's globals.
    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
| |
|