# Source: go / deploy / main.py
# Repository page header: jva96160 · "Upload 32 files" · 4c1ba5a (verified)
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from typing import List, Dict, Any, Optional
import json
import os
from transformers import AutoProcessor, AutoModel
import torch, torchaudio
import os
import copy
from rapidfuzz import process, fuzz
from pypinyin import pinyin, Style
def correct_sentence_with_pinyin(user_input_sentence: str, location_dict: List[str], score_cutoff: float = 50) -> str:
    """Replace the part of a sentence that sounds like a known place name.

    The sentence and every candidate name are converted to tone-less pinyin
    and fuzzy-matched with ``fuzz.token_set_ratio``. If a candidate is
    selected, the substring of the sentence (2 to 15 characters long) that
    is most similar to that candidate (``fuzz.ratio``) is replaced by it.

    Args:
        user_input_sentence: Text that may contain a misrecognised name.
        location_dict: Candidate names. Despite the parameter name, the
            callers in this file pass a plain list of strings; only
            iteration is required.
        score_cutoff: Threshold passed to ``process.extractOne`` for the
            pinyin match; the substring similarity must be strictly greater
            than this value for the replacement to happen.

    Returns:
        The sentence with the first occurrence of the best-matching
        substring replaced, or the sentence unchanged when nothing matches.
    """
    # Nothing to correct, or nothing to correct against: skip the pinyin
    # conversion and fuzzy search entirely.
    if not user_input_sentence or not location_dict:
        return user_input_sentence

    def to_pinyin(text):
        # pinyin() yields one list of readings per character; keep the first.
        return ''.join(readings[0] for readings in pinyin(text, style=Style.NORMAL))

    # Names sharing the same pinyin spelling collapse to the last one seen.
    pinyin_dict = {to_pinyin(location): location for location in location_dict}

    best_match_pinyin = process.extractOne(
        query=to_pinyin(user_input_sentence),
        choices=list(pinyin_dict),  # search over the pinyin spellings
        scorer=fuzz.token_set_ratio,
        score_cutoff=score_cutoff,
    )
    if not best_match_pinyin:
        return user_input_sentence

    corrected_location_name = pinyin_dict[best_match_pinyin[0]]

    # Scan every window of 2..15 characters and keep the first one with the
    # highest similarity to the chosen name.
    best_user_substring = None
    max_substring_score = 0
    sentence_length = len(user_input_sentence)
    for start in range(sentence_length):
        for end in range(start + 2, min(start + 16, sentence_length + 1)):
            substring = user_input_sentence[start:end]
            score = fuzz.ratio(substring, corrected_location_name)
            if score > max_substring_score:
                max_substring_score = score
                best_user_substring = substring

    if best_user_substring and max_substring_score > score_cutoff:
        return user_input_sentence.replace(best_user_substring, corrected_location_name, 1)
    return user_input_sentence
class InferenceClass:
    """Run one chat turn of an audio-capable model and post-process its output.

    The model and processor are loaded with ``trust_remote_code=True`` from
    ``model_id``. ``generate`` is the entry point: it takes a chat history
    (a list of ``{'from': ..., 'value': ...}`` messages), optionally an
    audio file for the newest human turn, and returns the extended history.
    """

    # Reply appended when the model output does not follow the expected
    # "User transcribe is : ...;\nGPT output is : ..." layout.
    _RETRY_PROMPT = '抱歉我聽不清楚 能麻煩您再說一次嗎'
    # Every one of these must appear in a well-formed model output.
    _REQUIRED_MARKERS = (';\n', 'User transcribe is :', 'GPT output is :')
    # Fragments stripped, in this order, by remove_words_signs().
    _STRIPPED_FRAGMENTS = (
        'User transcribe is :', 'GPT output is :', '\n', ' ',
        '?', '?', '!', '。', '!',
    )

    def __init__(self, model_id):
        """Load the model (CUDA, bfloat16, eager attention, eval mode) and its processor."""
        self.model = AutoModel.from_pretrained(
            model_id, device_map="cuda",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            attn_implementation="eager"
        ).eval()
        self.processor = AutoProcessor.from_pretrained(
            model_id, trust_remote_code=True
        )

    def remove_words_signs(self, x):
        """Return ``x`` without the output markers, newlines, spaces and end punctuation."""
        for fragment in self._STRIPPED_FRAGMENTS:
            x = x.replace(fragment, '')
        return x

    def call_gpt(self, inputs_tensor):
        """Generate a reply for already-processed inputs and return it as text.

        Args:
            inputs_tensor: Mapping of tensors produced by the processor; it
                must contain ``input_ids``.

        Returns:
            The decoded text of the first sequence, prompt tokens excluded.
        """
        with torch.inference_mode():
            inputs = {k: inputs_tensor[k].to('cuda') for k in inputs_tensor}
            # Greedy decoding, capped at 128 new tokens.
            generate_ids = self.model.generate(**inputs, max_new_tokens=128, do_sample=False)
            # Drop the prompt so only the newly generated tokens are decoded.
            generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
            model_output = self.processor.batch_decode(
                generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )[0]
        return model_output

    def call_function_fake(self, messages=None, obs=""):
        """Append an ``observation`` message carrying ``obs`` and return the list.

        A list passed in is extended in place; when ``messages`` is omitted a
        fresh list is created for each call.
        """
        if messages is None:
            messages = []
        messages.append({'from': 'observation', 'value': obs})
        return messages

    @staticmethod
    def _collect_poi_names(chat_history):
        """Return the ``name`` of every POI listed in successful lookup observations."""
        names = []
        for hist in chat_history:
            if hist['from'] == 'observation' and '地點查詢成功' in hist['value'] and 'poi' in hist['value']:
                names.extend(poi['name'] for poi in json.loads(hist['value'])['poi'])
        return names

    @staticmethod
    def _parse_arguments(arguments):
        """Turn a stored ``arguments`` value back into a Python object."""
        import ast

        if not isinstance(arguments, str):
            return arguments
        # SECURITY: chat history is supplied by the caller, so the stored
        # string is untrusted. It used to go through eval(); literal_eval
        # accepts the same dict literals (generate() stores them with str())
        # without executing arbitrary expressions.
        return ast.literal_eval(arguments)

    def _correct_history(self, chat_history, words_from_poi):
        """Rewrite human text and function-call keywords against the POI names, in place."""
        for hist in chat_history:
            if hist['from'] == 'human' and isinstance(hist['value'], str):
                hist['value'] = correct_sentence_with_pinyin(hist['value'], words_from_poi)
            elif hist['from'] == 'function_call' and "arguments" in hist['value'] and 'keyword' in hist['value']["arguments"]:
                arguments = self._parse_arguments(hist['value']["arguments"])
                if 'keyword' in arguments:
                    arguments['keyword'] = correct_sentence_with_pinyin(arguments['keyword'], words_from_poi)
                # Stored back as a string, the same form generate() emits.
                hist['value']["arguments"] = str(arguments)

    def generate(self, chat_history, tools="", audio_path=None):
        """Run the model on the conversation and append its reply.

        Args:
            chat_history: List of message dicts with ``from`` and ``value``
                keys. It is deep-copied; the caller's list is not modified.
            tools: JSON string describing the available tools (an already
                parsed object is passed through unchanged). An empty string
                means no tools.
            audio_path: Optional path of an audio file; when given, a human
                audio message pointing at it is appended before inference.

        Returns:
            The copied history, extended with one of:
            a ``gpt`` message (when the last message was an observation, or
            for a plain answer), a ``function_call`` message (when the
            output contains ``Action:`` and ``ActionInput:``), or a
            placeholder human turn plus a retry prompt when the output
            cannot be parsed.
        """
        chat_history = copy.deepcopy(chat_history)
        if audio_path is not None:
            chat_history.append({'from': 'human',
                                 'value': [{'type': 'audio',
                                            'audio': audio_path}]})

        words_from_poi = self._collect_poi_names(chat_history)
        self._correct_history(chat_history, words_from_poi)

        # json.loads("") raises, so the default empty string maps to "no tools".
        if isinstance(tools, str):
            parsed_tools = json.loads(tools) if tools else None
        else:
            parsed_tools = tools

        inputs_text = self.processor.apply_chat_template(
            chat_history, add_generation_prompt=True, tokenize=False,
            return_dict=True, return_tensors="pt", tools=parsed_tools
        )
        audio = [torchaudio.load(audio_path)[0]] if audio_path is not None else None
        inputs_tensor = self.processor(text=inputs_text,
                                       audio=audio,
                                       add_special_tokens=False,
                                       return_tensors='pt'
                                       )
        model_output = self.call_gpt(inputs_tensor)

        # After an observation the raw output is used directly as the answer.
        if chat_history[-1]['from'] == 'observation':
            chat_history.append({'from': 'gpt', 'value': correct_sentence_with_pinyin(model_output, words_from_poi)})
            return chat_history

        # Malformed output: ask the user to repeat, unless the history
        # already ends with that request.
        if not all(marker in model_output for marker in self._REQUIRED_MARKERS):
            if chat_history[-1]['value'] != self._RETRY_PROMPT:
                chat_history.append({'from': 'human',
                                     'value': 'HUMAN_VOICE_IS_NOT_RECOGNIZED'})
                chat_history.append({'from': 'gpt', 'value': self._RETRY_PROMPT})
            return chat_history

        # First segment is the transcript, second is the model's answer.
        output_t, output_o = (self.remove_words_signs(part) for part in model_output.split(';\n')[:2])
        # The last message's value is overwritten with the corrected transcript.
        chat_history[-1]['value'] = correct_sentence_with_pinyin(output_t, words_from_poi)

        if 'Action:' in output_o and 'ActionInput:' in output_o:  # function calling
            # maxsplit=1 keeps the unpacking valid if the marker repeats.
            function_name, function_arg = output_o.split('ActionInput:', 1)
            function_name = function_name.replace('Action:', '')
            if "keyword" in function_arg:
                function_arg = json.loads(function_arg)
                if "keyword" in function_arg:
                    function_arg["keyword"] = correct_sentence_with_pinyin(function_arg["keyword"], words_from_poi)
            chat_history.append({'from': 'function_call', 'value': {"name": function_name, "arguments": str(function_arg)}})
        else:  # gpt response
            chat_history.append({'from': 'gpt', 'value': correct_sentence_with_pinyin(output_o, words_from_poi)})
        return chat_history
# Checkpoint directory handed to InferenceClass (and from there to
# from_pretrained).
model_id = "/home/jeff/jeff/codes/llm/InCar/gemma-3-4b-it-omni"
# Built once at import time and shared by every request.
pipeline = InferenceClass(model_id)
app = FastAPI(
    title="Audio LLM API",
    description="An API that accepts an audio file and a list of dictionaries.",
)
import json
# The tool definitions served to the model are taken from the first record
# of this dataset file. The file is closed as soon as it has been parsed.
with open('/home/jeff/jeff/codes/llm/InCar/data/test_data/nav_0730_noisy.json', encoding='utf-8') as dataset_file:
    dataset = json.load(dataset_file)
tools = dataset[0]['tools']
@app.post("/audio_llm/")
async def process_audio_and_data(
    audio_file: Optional[UploadFile] = File(None, description="The audio file to be processed."),
    data: str = Form(..., description="A JSON string representing a list of chat history dictionaries.")
) -> List[Dict[str, Any]]:
    """Run one model turn for the posted chat history and optional audio.

    Args:
        audio_file: Optional upload; when present it is written under
            ``./audio_path/`` and its path is passed to the pipeline.
        data: JSON string that must decode to a list of dictionaries.

    Returns:
        The chat history returned by ``pipeline.generate``.

    Raises:
        HTTPException: 422 when ``data`` is not valid JSON or is not a list
            of dictionaries.
    """
    try:
        input_data_list = json.loads(data)
    except json.JSONDecodeError as err:
        raise HTTPException(
            status_code=422,
            detail="Invalid JSON format for 'data' field. Please provide a valid JSON string."
        ) from err
    if not isinstance(input_data_list, list) or not all(isinstance(item, dict) for item in input_data_list):
        raise HTTPException(
            status_code=422,
            detail="The provided data is not a list of dictionaries."
        )

    temp_file_path = None
    if audio_file:
        # The filename is chosen by the client: keep only its last path
        # component so the upload cannot be written outside ./audio_path.
        safe_name = os.path.basename((audio_file.filename or "").replace("\\", "/"))
        os.makedirs("./audio_path", exist_ok=True)
        temp_file_path = f"./audio_path/temp_{safe_name}"
        with open(temp_file_path, "wb") as buffer:
            buffer.write(await audio_file.read())
        print(f"Audio file saved to {temp_file_path}")

    # NOTE(review): generate() is a synchronous call made inside this
    # coroutine, so the event loop is blocked while the model is running.
    output_data = pipeline.generate(input_data_list, tools=tools, audio_path=temp_file_path)
    print(output_data)
    return output_data
# Command-line alternative to running this file directly:
# uvicorn main:app --host 0.0.0.0 --port 8087 --log-level info --workers 1 >> ./log.txt
if __name__ == "__main__":
    # Direct execution serves the app on all interfaces, port 8000, with
    # auto-reload enabled.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("main:app", **server_options)