|
|
import uvicorn |
|
|
from fastapi import FastAPI, UploadFile, File, Form, HTTPException |
|
|
from typing import List, Dict, Any, Optional |
|
|
import json |
|
|
import os |
|
|
from transformers import AutoProcessor, AutoModel, AutoTokenizer,AutoModelForSequenceClassification,AutoModelForCausalLM |
|
|
import torch, torchaudio |
|
|
import os |
|
|
import openai |
|
|
import copy |
|
|
import numpy as np |
|
|
from rapidfuzz import process, fuzz |
|
|
from pypinyin import pinyin, Style |
|
|
from copy import deepcopy |
|
|
def get_sentence_with_pinyin(user_input_sentence, location_dict, score_cutoff=20):
    """Fuzzy-match a user sentence against known location names via pinyin.

    Both the candidate location names and the user's sentence are romanized
    to pinyin, then compared with a token-set ratio so word order and extra
    words in the sentence don't matter.

    Returns the original (non-romanized) location name of the best match, or
    "" when nothing scores at or above ``score_cutoff``.
    """
    def _romanize(text):
        # pypinyin returns a list of [syllable] lists; keep the first reading.
        return ''.join(syllable[0] for syllable in pinyin(text, style=Style.NORMAL))

    # Map romanized name -> original name (later duplicates overwrite earlier).
    romanized_to_location = {_romanize(name): name for name in location_dict}

    query_romanized = _romanize(user_input_sentence)

    best = process.extractOne(
        query=query_romanized,
        choices=list(romanized_to_location.keys()),
        scorer=fuzz.token_set_ratio,
        score_cutoff=score_cutoff,
    )

    # extractOne yields (choice, score, index) or None below the cutoff.
    if best is None:
        return ""
    return romanized_to_location.get(best[0], "")
|
|
|
|
|
|
|
|
class InferenceClass:
    """Chat pipeline around an OpenAI-compatible completion server.

    Converts the project's ``{'from': ..., 'value': ...}`` chat-history format
    into a Llama-3-style header/eot prompt, queries the local endpoint, and
    parses the model's tagged output back into a history entry.
    """

    def __init__(self, base_url="http://localhost:8087/v1"):
        # Model identifier as registered with the serving backend.
        self.model_id = "/root/merged"
        # The local server ignores the API key, but the client requires one.
        self.client = openai.OpenAI(base_url=base_url, api_key="EMPTY")

    def call_bert_class(self, query):
        """Placeholder intent classifier: currently always routes to "SLM"."""
        return "SLM"

    def call_gpt(self, messages):
        """Format *messages* into a prompt and call the completion endpoint.

        Returns the generated text, or "" if the request fails for any reason.
        """
        prompt = self.apply_template(messages)
        print('input prompt is : ')
        print(prompt)
        print('------'*30)
        try:
            response = self.client.completions.create(
                model=self.model_id, prompt=prompt, max_tokens=8192, temperature=0.001,
                top_p=0.001,
            )
            return response.choices[0].text
        except Exception as e:  # FIX: was a bare `except:`; surface the cause
            print('call_gpt error', e, 'input', messages)
            return ""

    def build_prompt(self, messages):
        """Render role-tagged messages into a Llama-3 header/eot prompt.

        *messages* is a list of ``{'role': ..., 'content': ...}`` dicts.
        When the final turn is from the user side, an empty assistant header
        is appended so the model continues as the assistant.
        """
        role_to_tag = {
            "user": "user",
            "assistant": "assistant",
            "function": "assistant",
            "observation": "ipython",
            # FIX: was "ipyton" (typo); "tool" now uses the same tag as
            # "observation", matching the Llama-3 tool-response convention.
            "tool": "ipython",
            "system": "system"
        }

        prompt = ""
        for msg in messages:
            # Unknown roles deliberately fall back to the "user" tag.
            tag = role_to_tag.get(msg["role"], "user")
            prompt += f"<|start_header_id|>{tag}<|end_header_id|>\n\n{msg['content']}<|eot_id|>"

        # Cue the model to answer when the last turn is not the assistant's.
        if messages[-1]['role'] in {"user", "observation", "tool"}:
            prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"

        return prompt

    def apply_template(self, messages):
        """Convert project-format chat history into the model's text prompt.

        Each entry is ``{'from': 'human'|'gpt', 'value': ...}`` plus optional
        'RAG' / 'Cloud' / 'Loc' (human side) or 'RAG_KEYWORDS' / 'Next'
        (gpt side) fields. History is truncated at the most recent
        session-start marker before rendering via build_prompt().
        """
        # NOTE: the blank lines inside these templates are intentional parts
        # of the prompt format — do not reflow them.
        gpt_out_template = """{RAG_Keywords}<Next step>{intent};


<Text Output>{output}


"""
        human_in_template = """{RAG_data}{Cloud_data}<User Query>{query}{LOC}


"""
        # Keep only turns after the latest "session start" marker: a 7-char
        # human message like "旅運MM-DD開始" / "市民MM-DD開始" whose date part
        # sorts after any marker seen so far.
        short_messages = []
        pre = ""
        for conv in messages:
            if conv['from'] == 'human' and ("旅運" in conv['value'] or "市民" in conv['value']) \
                    and "開始" in conv['value'] and len(conv['value']) == 7:
                marker = conv['value'].replace("旅運", '').replace("市民", '').replace("開始", '')
                if '-' in marker and pre < marker:
                    pre = marker
                    short_messages = []  # drop everything before this session
            short_messages.append(conv)

        msg = []
        for c in short_messages:
            obj = {'role': None}
            if c['from'] == 'human':
                rag_data = ""
                cloud_data = ""
                if 'RAG' in c and c['RAG']:
                    rag_data = "<RAG data>"
                    for k in c['RAG']:
                        if c['RAG'][k] is None:
                            continue
                        rag_data += k + ":" + c['RAG'][k]['name'] + ", content:" + str(c['RAG'][k]['content']) + ';\n'
                if 'Cloud' in c and c['Cloud']:
                    cloud_data = "<Cloud data>" + c['Cloud'] + ';\n'
                loc_info = ""
                if "Loc" in c and c["Loc"]:
                    loc_info = ", current location:" + c["Loc"]
                content = human_in_template.format(RAG_data=rag_data, Cloud_data=cloud_data,
                                                   query=c['value'], LOC=loc_info)
                obj.update({'role': 'user', 'content': content})
            elif c['from'] == 'gpt':
                rag_in = "<keywords>" + str(c['RAG_KEYWORDS']) + ";\n" if 'RAG_KEYWORDS' in c and c['RAG_KEYWORDS'] else ""
                # 'Next' defaults to True when absent; only an explicit falsy
                # value turns it off.
                intent = "False" if 'Next' in c and not c['Next'] else "True"
                content = gpt_out_template.format(RAG_Keywords=rag_in, intent=intent, output=c['value'])
                obj.update({'role': 'assistant', 'content': content})
            msg.append(obj)

        return self.build_prompt(msg)

    def generate(self, chat_history):
        '''
        input:
            chat_history : list of {'from': ..., 'value': ...} dicts
        return:
            chat_history with the model reply appended as a 'gpt' entry
        '''
        # Intent routing is currently hard-wired to the SLM branch (see
        # call_bert_class); the quick-reply branch below is kept for when a
        # real classifier is plugged in.
        bert_class = "SLM"
        print('Bert out : ', bert_class)

        if bert_class == "SLM":
            model_output = self.call_gpt(chat_history)
            # Post-process known model quirks: tag confusion, duplicated
            # words, and de-spaced place names.
            model_output = model_output.replace("<RAG data>", "<keywords>").replace('晚上晚上', '晚上').replace("ShoushanLover'sObservatory", "Shoushan Lover's Observatory").replace('SizihwanBay', 'Sizihwan Bay')

            print(model_output)
            tmp = {'from': 'gpt', 'Next': False}
            # Expected shape: "<keywords>...;\n<Next step>...;\n<Text Output>..."
            outs = model_output.split(';\n')
            print(outs)
            if '<keywords>' in outs[0]:
                tmp['RAG_KEYWORDS'] = outs[0].replace('<keywords>', '').replace(' ', '')
                # FIX: guard the outs[1] access — a malformed one-segment
                # reply used to raise IndexError here.
                if len(outs) > 1 and '<Next step>' in outs[1]:
                    tmp['Next'] = 'True' in outs[1]
            elif '<Next step>' in outs[0]:
                tmp['Next'] = 'True' in outs[0]
            tmp['value'] = outs[-1].replace('<Text Output>', '').replace('\n<|eot_id|>', '')

            chat_history.append(tmp)
        else:
            # Non-SLM intents get an immediate empty "quick" reply.
            chat_history.append({'from': 'gpt', 'value': "", "Next": True, 'quick': True})
        return chat_history
|
|
|
|
|
|
|
|
|
|
|
# Single shared inference pipeline, created once at import time so the
# OpenAI client connection is reused across requests.
pipeline = InferenceClass()


app = FastAPI(


    title="Audio LLM API",


    description="An API that accepts an audio file and a list of dictionaries.",


)
|
|
|
|
|
|
|
|
@app.post("/service_LLM/")
async def generate(
    data: str = Form(..., description="A JSON string representing a list of chat history dictionaries.")
) -> List[Dict[str, Any]]:
    """Run one LLM turn over a JSON-encoded chat history.

    The form field ``data`` must decode to a list of dicts; anything else is
    rejected with HTTP 422. The (mutated) history, with the model's reply
    appended, is returned as the response body.
    """
    try:
        parsed = json.loads(data)
    except json.JSONDecodeError:
        raise HTTPException(
            status_code=422,
            detail="Invalid JSON format for 'data' field. Please provide a valid JSON string."
        )

    # Shape check: top level must be a list whose every element is a dict.
    is_list_of_dicts = isinstance(parsed, list) and all(isinstance(entry, dict) for entry in parsed)
    if not is_list_of_dicts:
        raise HTTPException(
            status_code=422,
            detail="The provided data is not a list of dictionaries."
        )

    result = pipeline.generate(parsed)
    print(result)
    return result
|
|
|
|
|
|
|
|
if __name__ == "__main__":


    # Development entry point: bind on all interfaces; reload=True is a
    # dev-only convenience and should be disabled in production.
    uvicorn.run("main:app", host="0.0.0.0", port=8088, reload=True)