# city/deploy/main.py
import json

import uvicorn
import openai
import torch
from typing import List, Dict, Any
from fastapi import FastAPI, Form, HTTPException
from rapidfuzz import process, fuzz
from pypinyin import pinyin, Style
# torch and transformers are only needed by the commented-out
# local-model code paths in InferenceClass below.
from transformers import (
    AutoProcessor,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
)
def get_sentence_with_pinyin(user_input_sentence, location_dict, score_cutoff=20):
    """Fuzzy-match a known location name mentioned in the user sentence.

    Transliterates every known location and the user sentence to pinyin,
    then picks the best token-set match above `score_cutoff`. Returns the
    matched location name, or "" if nothing clears the cutoff.
    """
    pinyin_dict = {}
    for location in location_dict:
        pinyin_name = ''.join(item[0] for item in pinyin(location, style=Style.NORMAL))
        pinyin_dict[pinyin_name] = location
    user_pinyin_sentence = ''.join(item[0] for item in pinyin(user_input_sentence, style=Style.NORMAL))
    best_match_pinyin = process.extractOne(
        query=user_pinyin_sentence,
        choices=list(pinyin_dict.keys()),
        scorer=fuzz.token_set_ratio,
        score_cutoff=score_cutoff,
    )
    if best_match_pinyin and best_match_pinyin[0] in pinyin_dict:
        return pinyin_dict[best_match_pinyin[0]]
    return ""
class InferenceClass:
    def __init__(self, base_url="http://localhost:8087/v1"):
        self.model_id = "/root/merged"
        self.client = openai.OpenAI(base_url=base_url, api_key="EMPTY")
        # Disabled BERT intent-classifier setup:
        # model_name = "/root/bert_classifier"
        # id2label = json.load(open('{}/config.json'.format(model_name)))["id2label"]
        # self.id2label = {int(k): id2label[k] for k in id2label}
        # self.label2id = {v: k for k, v in id2label.items()}
        ###############################
        # Disabled local-LLM setup (superseded by the OpenAI-compatible client above):
        # MODEL_NAME = "/root/merged"
        # self.llm = AutoModelForCausalLM.from_pretrained(
        #     MODEL_NAME,
        #     device_map="auto",
        #     trust_remote_code=True,
        #     torch_dtype=torch.bfloat16,
        # )
        # self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # self.llm_processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
        ###################################
        # num_labels = len(self.id2label)
        # self.bert_tokenizer = AutoTokenizer.from_pretrained(model_name)
        # self.bert_classifier = AutoModelForSequenceClassification.from_pretrained(
        #     model_name,
        #     num_labels=num_labels,
        #     id2label=self.id2label,
        #     label2id=self.label2id,
        # ).to('cuda')
        # Canned replies per stage: roughly "Found the transit route to {}" and
        # "This trip doesn't need / can't use the HSR, please give me another destination":
        # self.quick_response = {'旅1-2-ch': ["幫你找到{}的交通路線"]}
        # self.quick_reject = {'旅1-2-ch': ["此行程無須搭乘或無法搭乘高鐵抵達,請重新給我一個目的地"]}
        # self.rag_template = {}
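        # NOTE (assumption): the OpenAI-compatible endpoint at `base_url` is
        # served externally; e.g. a vLLM server would be started with
        #   python -m vllm.entrypoints.openai.api_server --model /root/merged --port 8087
        # Any server exposing /v1/completions for `model_id` works the same way.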
    def call_bert_class(self, query):
        # BERT intent routing is disabled; every query is routed to the SLM.
        return "SLM"
        # inputs = self.bert_tokenizer(query, padding=True, truncation=True, return_tensors="pt").to('cuda')
        # with torch.no_grad():
        #     outputs = self.bert_classifier(**inputs)
        # logits = outputs.logits
        # confidence = max(torch.nn.functional.softmax(logits[0], dim=-1))
        # predictions = torch.argmax(logits, dim=-1)
        # pred_class = self.id2label[predictions.cpu().numpy()[0]]
        # print('Bert Confidence', confidence, pred_class)
        # if confidence < 0.92: return "SLM"
        # return pred_class
    def call_gpt(self, messages):
        prompt = self.apply_template(messages)
        print('input prompt is : ')
        print(prompt)
        print('------' * 30)
        # Disabled local-generation path (superseded by the OpenAI-compatible client):
        # inp = self.llm_processor(text=prompt, add_special_tokens=True, return_tensors='pt')
        # inp = {k: inp[k].to('cuda') for k in inp}
        # out = self.llm.generate(**inp, pad_token_id=self.tokenizer.eos_token_id,
        #                         max_new_tokens=1024,
        #                         temperature=0.01)
        # predict = self.llm_processor.decode(out[0][len(inp['input_ids'][0]):])
        # print('output is : ')
        # print(predict)
        # print('------' * 30)
        # return predict
        try:
            response = self.client.completions.create(
                model=self.model_id,
                prompt=prompt,
                max_tokens=8192,
                temperature=0.001,
                top_p=0.001,
            )
            return response.choices[0].text
        except Exception as e:
            print('call_gpt error', e, 'input', messages)
            return ""
    def build_prompt(self, messages):
        """Render chat messages into Llama-3-style header/eot tags."""
        role_to_tag = {
            "user": "user",
            "assistant": "assistant",
            "function": "assistant",
            "observation": "ipython",
            "tool": "ipython",
            "system": "system",
        }
        prompt = ""
        for msg in messages:
            tag = role_to_tag.get(msg["role"], "user")
            prompt += f"<|start_header_id|>{tag}<|end_header_id|>\n\n{msg['content']}<|eot_id|>"
        # If the conversation ends with a non-assistant turn, open an
        # assistant header so the model continues from there.
        if messages[-1]['role'] in {"user", "observation", "tool"}:
            prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
        return prompt
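
    # Illustrative build_prompt output for [{'role': 'user', 'content': 'hi'}]:
    #   <|start_header_id|>user<|end_header_id|>\n\nhi<|eot_id|>
    #   <|start_header_id|>assistant<|end_header_id|>\n\n
    # (shown on two lines here; the actual return value is one string)
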
    def apply_template(self, messages):
        """Keep only the most recent session and render it into the model's
        input/output templates."""
        gpt_out_template = "{RAG_Keywords}<Next step>{intent};\n<Text Output>{output}\n"
        human_in_template = "{RAG_data}{Cloud_data}<User Query>{query}{LOC}\n"
        # A session marker is a 7-character human turn such as "旅運1-2開始" or
        # "市民1-2開始"; whenever a newer stage (string-wise greater, e.g. "1-3"
        # after "1-2") appears, discard everything before it.
        short_messages = []
        pre = ""
        for conv in messages:
            if conv['from'] == 'human' and ("旅運" in conv['value'] or "市民" in conv['value']) \
                    and "開始" in conv['value'] and len(conv['value']) == 7:
                stage = conv['value'].replace("旅運", '').replace("市民", '').replace("開始", '')
                if '-' in stage and pre < stage:
                    pre = stage
                    short_messages = []
            short_messages.append(conv)
        msg = []
        for c in short_messages:
            obj = {'role': None}
            if c['from'] == 'human':
                rag_data = ""
                if c.get('RAG'):
                    rag_data = "<RAG data>"
                    for k in c['RAG']:
                        if c['RAG'][k] is None:
                            continue
                        rag_data += k + ":" + c['RAG'][k]['name'] + ", content:" + str(c['RAG'][k]['content']) + ';\n'
                cloud_data = "<Cloud data>" + c['Cloud'] + ';\n' if c.get('Cloud') else ""
                loc_info = ", current location:" + c["Loc"] if c.get("Loc") else ""
                obj.update({'role': 'user', 'content': human_in_template.format(
                    RAG_data=rag_data, Cloud_data=cloud_data, query=c['value'], LOC=loc_info)})
            elif c['from'] == 'gpt':
                rag_in = "<keywords>" + str(c['RAG_KEYWORDS']) + ";\n" if c.get('RAG_KEYWORDS') else ""
                intent = "False" if 'Next' in c and not c['Next'] else "True"
                obj.update({'role': 'assistant', 'content': gpt_out_template.format(
                    RAG_Keywords=rag_in, intent=intent, output=c['value'])})
            msg.append(obj)
        return self.build_prompt(msg)
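
    # Illustrative rendering (hypothetical data):
    #   human turn -> "<RAG data>spot:Sizihwan Bay, content:[...];\n<User Query>how do I get there, current location:Kaohsiung Station\n"
    #   gpt turn   -> "<keywords>['Sizihwan Bay'];\n<Next step>True;\n<Text Output>...\n"
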
    def generate(self, chat_history):
        """Run one turn: call the model and parse its
        '<keywords>/<Next step>/<Text Output>' sections back into the history.

        input:
            chat_history : list of dict turns
        return:
            chat_history with the model's turn appended
        """
        # BERT-based routing is disabled; always take the SLM path.
        # if len(chat_history) > 1:
        #     bert_query = chat_history[-2]['value'] + "\n" + chat_history[-1]['value']
        #     bert_class = self.call_bert_class(bert_query)
        #     print('Bert query : ', bert_query)
        # else:
        bert_class = "SLM"
        print('Bert out : ', bert_class)
        if bert_class == "SLM":
            model_output = self.call_gpt(chat_history)
            # Normalize known glitches in the model's output.
            model_output = (model_output
                            .replace("<RAG data>", "<keywords>")
                            .replace('晚上晚上', '晚上')
                            .replace("ShoushanLover'sObservatory", "Shoushan Lover's Observatory")
                            .replace('SizihwanBay', 'Sizihwan Bay'))
            print(model_output)
            tmp = {'from': 'gpt', 'Next': False}
            outs = model_output.split(';\n')
            print(outs)
            if '<keywords>' in outs[0]:
                tmp['RAG_KEYWORDS'] = outs[0].replace('<keywords>', '').replace(' ', '')
            if len(outs) > 1 and '<Next step>' in outs[1]:
                tmp['Next'] = 'True' in outs[1]
            elif '<Next step>' in outs[0]:
                tmp['Next'] = 'True' in outs[0]
            tmp['value'] = outs[-1].replace('<Text Output>', '').replace('\n<|eot_id|>', '')
            # if tmp['value'] == "": tmp['tag'] = 'quick_response'
            chat_history.append(tmp)
        else:
            # if stage not in self.quick_response:
            #     print("Error : Current stage can not find in quick response template")
            # else:
            chat_history.append({'from': 'gpt', 'value': "", "Next": True, 'quick': True})
        return chat_history
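
    # Illustrative raw model output and how generate() parses it (hypothetical):
    #   "<keywords>['Sizihwan Bay'];\n<Next step>True;\n<Text Output>Here is the route..."
    # Splitting on ';\n' yields the RAG keywords, the Next-step flag, and the
    # user-facing text, which are written back into the appended history dict.
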
pipeline = InferenceClass()

app = FastAPI(
    title="Audio LLM API",
    description="An API that accepts a JSON-encoded chat history and returns it with the model's reply appended.",
)
@app.post("/service_LLM/")
async def generate(
    data: str = Form(..., description="A JSON string representing a list of chat history dictionaries.")
) -> List[Dict[str, Any]]:
    try:
        input_data_list = json.loads(data)
        if not isinstance(input_data_list, list) or not all(isinstance(item, dict) for item in input_data_list):
            raise ValueError("The provided data is not a list of dictionaries.")
    except json.JSONDecodeError:
        raise HTTPException(
            status_code=422,
            detail="Invalid JSON format for 'data' field. Please provide a valid JSON string."
        )
    except ValueError as e:
        raise HTTPException(
            status_code=422,
            detail=str(e)
        )
    output_data = pipeline.generate(input_data_list)
    print(output_data)
    return output_data
# uvicorn main:app --host 0.0.0.0 --port 8088 --log-level info --workers 1 >> ./log.txt
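# Example request (illustrative payload):
# curl -X POST http://localhost:8088/service_LLM/ \
#      -F 'data=[{"from": "human", "value": "Take me to Sizihwan Bay"}]'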
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8088, reload=True)