|
|
import ast
import copy
import json
import os
from typing import List, Dict, Any, Optional

import torch, torchaudio
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from pypinyin import pinyin, Style
from rapidfuzz import process, fuzz
from transformers import AutoProcessor, AutoModel
|
|
|
|
|
def correct_sentence_with_pinyin(user_input_sentence, location_dict, score_cutoff=50):
    """Replace a misrecognised place name in a sentence with a known location.

    The sentence and every candidate location are converted to pinyin
    (``Style.NORMAL``) and fuzzy-matched against each other. When a candidate
    passes ``score_cutoff``, the sentence is scanned for the 2- to
    15-character substring that is most similar to that candidate
    (``fuzz.ratio`` on the original characters) and the first occurrence of
    that substring is replaced by the candidate.

    Args:
        user_input_sentence: Sentence to correct.
        location_dict: Collection of known location names. It is only
            iterated, so a list works as well as a dict keyed by name.
        score_cutoff: Threshold handed to ``process.extractOne`` for the
            pinyin match; the best substring must additionally score strictly
            above it for the replacement to happen.

    Returns:
        The corrected sentence, or ``user_input_sentence`` unchanged when
        either input is empty, no candidate matches, or no substring is
        similar enough.
    """
    # Nothing to correct, or nothing to correct against: skip the pinyin
    # conversion and the substring scan entirely.
    if not user_input_sentence or not location_dict:
        return user_input_sentence

    def to_pinyin(text):
        # pinyin() yields one list of readings per character; keep the first.
        return ''.join(readings[0] for readings in pinyin(text, style=Style.NORMAL))

    # When two locations share the same pinyin, the later one wins.
    pinyin_to_location = {to_pinyin(name): name for name in location_dict}

    best_match = process.extractOne(
        query=to_pinyin(user_input_sentence),
        choices=list(pinyin_to_location),
        scorer=fuzz.token_set_ratio,
        score_cutoff=score_cutoff,
    )
    if not best_match:
        return user_input_sentence

    corrected_location_name = pinyin_to_location[best_match[0]]

    # Find the substring of the sentence that looks most like the location.
    best_substring = None
    best_score = 0
    sentence_length = len(user_input_sentence)
    for start in range(sentence_length):
        for end in range(start + 2, min(start + 16, sentence_length + 1)):
            candidate = user_input_sentence[start:end]
            score = fuzz.ratio(candidate, corrected_location_name)
            if score > best_score:
                best_score = score
                best_substring = candidate

    if best_substring and best_score > score_cutoff:
        return user_input_sentence.replace(best_substring, corrected_location_name, 1)
    return user_input_sentence
|
|
|
|
|
class InferenceClass:
    """Run an audio-capable chat model and post-process its replies.

    The model and its processor are loaded from ``model_id`` onto CUDA.
    ``generate`` takes a chat history (plus an optional audio file), queries
    the model, and returns the history extended with the transcription and
    either a ``gpt`` reply or a ``function_call`` entry. Place names seen in
    earlier POI lookups are used to correct misrecognised names in the text.
    """

    # Markers the model is expected to emit in front of its two output parts.
    _TRANSCRIBE_PREFIX = 'User transcribe is :'
    _REPLY_PREFIX = 'GPT output is :'
    # Separator between the transcription part and the reply part.
    _PART_SEPARATOR = ';\n'
    # Placeholder turn and canned reply used when the output is unusable.
    _NOT_RECOGNIZED = 'HUMAN_VOICE_IS_NOT_RECOGNIZED'
    _APOLOGY = '抱歉我聽不清楚 能麻煩您再說一次嗎'
    # Characters stripped from model output.
    # NOTE(review): the original chained two identical replace() calls for '?'
    # and for '!', which looks like full-width variants lost to text
    # normalization; the full-width forms are stripped here too -- confirm.
    _STRIP_TABLE = str.maketrans('', '', '\n ?!。？！')

    def __init__(self, model_id):
        """Load the model (bfloat16, eager attention, eval mode) and processor.

        Args:
            model_id: Model identifier or local path passed to
                ``from_pretrained``.
        """
        self.model = AutoModel.from_pretrained(
            model_id, device_map="cuda",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            attn_implementation="eager"
        ).eval()

        self.processor = AutoProcessor.from_pretrained(
            model_id, trust_remote_code=True
        )

    @staticmethod
    def remove_words_signs(x):
        """Strip the output markers, newlines, spaces and punctuation from ``x``."""
        # The markers contain spaces, so they must go before the characters do.
        for marker in (InferenceClass._TRANSCRIBE_PREFIX, InferenceClass._REPLY_PREFIX):
            x = x.replace(marker, '')
        return x.translate(InferenceClass._STRIP_TABLE)

    def call_gpt(self, inputs_tensor):
        """Generate greedily (up to 128 new tokens) and decode the first sequence.

        Args:
            inputs_tensor: Mapping of processor outputs; every value is moved
                to CUDA and must include ``input_ids``.

        Returns:
            The decoded text of the newly generated tokens only.
        """
        with torch.inference_mode():
            inputs = {key: value.to('cuda') for key, value in inputs_tensor.items()}
            generate_ids = self.model.generate(**inputs, max_new_tokens=128, do_sample=False)
            # Drop the prompt tokens so only the continuation is decoded.
            generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
            model_output = self.processor.batch_decode(
                generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )[0]
        return model_output

    def call_function_fake(self, messages=None, obs=""):
        """Append an ``observation`` entry carrying ``obs`` to ``messages``.

        Args:
            messages: List to extend in place; a fresh list is created when
                omitted, so calls no longer share one default list.
            obs: Observation text to record.

        Returns:
            The extended ``messages`` list.
        """
        if messages is None:
            messages = []
        messages.append({'from': 'observation', 'value': obs})
        return messages

    @staticmethod
    def _collect_poi_names(chat_history):
        """Return the POI names found in successful location-lookup observations."""
        names = []
        for hist in chat_history:
            if hist['from'] == 'observation' and '地點查詢成功' in hist['value'] and 'poi' in hist['value']:
                names.extend(poi['name'] for poi in json.loads(hist['value'])['poi'])
        return names

    @staticmethod
    def _correct_history(chat_history, poi_names):
        """Correct place names in text turns and function-call keywords in place."""
        for hist in chat_history:
            if hist['from'] == 'human' and isinstance(hist['value'], str):
                hist['value'] = correct_sentence_with_pinyin(hist['value'], poi_names)
            elif (hist['from'] == 'function_call' and "arguments" in hist['value']
                    and 'keyword' in hist['value']["arguments"]):
                raw_arguments = hist['value']["arguments"]
                # The history comes from the client, so the stringified dict is
                # parsed with literal_eval rather than eval: literals only, no
                # arbitrary code execution.
                if isinstance(raw_arguments, str):
                    arguments = ast.literal_eval(raw_arguments)
                else:
                    arguments = raw_arguments
                if isinstance(arguments, dict) and 'keyword' in arguments:
                    arguments['keyword'] = correct_sentence_with_pinyin(arguments['keyword'], poi_names)
                hist['value']["arguments"] = str(arguments)

    @staticmethod
    def _build_reply(output_o, poi_names):
        """Turn the cleaned reply part into a ``function_call`` or ``gpt`` entry."""
        if 'Action:' in output_o and 'ActionInput:' in output_o:
            # Split on the first marker only, so a repeated marker inside the
            # arguments does not break the unpacking.
            function_name, function_arg = output_o.split('ActionInput:', 1)
            function_name = function_name.replace('Action:', '')
            if "keyword" in function_arg:
                function_arg = json.loads(function_arg)
                if isinstance(function_arg, dict) and "keyword" in function_arg:
                    function_arg["keyword"] = correct_sentence_with_pinyin(function_arg["keyword"], poi_names)
            return {'from': 'function_call', 'value': {"name": function_name, "arguments": str(function_arg)}}
        return {'from': 'gpt', 'value': correct_sentence_with_pinyin(output_o, poi_names)}

    def generate(self, chat_history, tools="", audio_path=None):
        """Query the model with the chat history and return the extended history.

        Args:
            chat_history: List of ``{'from': ..., 'value': ...}`` entries. It
                is deep-copied; the caller's list is not modified.
            tools: Tool definitions for the chat template, as a JSON string
                or an already-parsed object. An empty string or ``None``
                means no tools are passed.
            audio_path: Optional path of an audio file. When given, a human
                turn referencing it is appended and the waveform is fed to
                the processor.

        Returns:
            The copied history with the new entries appended:
            * after an ``observation`` turn, one ``gpt`` entry with the raw
              model output (place names corrected);
            * when the output lacks the expected markers, a placeholder human
              turn and an apology, unless the last value already is that
              apology;
            * otherwise the last entry's value is replaced by the
              transcription and a ``function_call`` or ``gpt`` entry follows.
        """
        chat_history = copy.deepcopy(chat_history)
        if audio_path is not None:
            chat_history.append({'from': 'human',
                                 'value': [{'type': 'audio',
                                            'audio': audio_path}]})

        words_from_poi = self._collect_poi_names(chat_history)
        self._correct_history(chat_history, words_from_poi)

        # json.loads("") raises, so the default empty string means "no tools".
        if isinstance(tools, str):
            parsed_tools = json.loads(tools) if tools.strip() else None
        else:
            parsed_tools = tools
        template_kwargs = {} if parsed_tools is None else {'tools': parsed_tools}

        inputs_text = self.processor.apply_chat_template(
            chat_history, add_generation_prompt=True, tokenize=False,
            return_dict=True, return_tensors="pt", **template_kwargs
        )
        inputs_tensor = self.processor(
            text=inputs_text,
            audio=[torchaudio.load(audio_path)[0]] if audio_path is not None else None,
            add_special_tokens=False,
            return_tensors='pt'
        )
        model_output = self.call_gpt(inputs_tensor)

        # A reply to a tool observation is used as-is, without marker parsing.
        if chat_history[-1]['from'] == 'observation':
            chat_history.append({'from': 'gpt', 'value': correct_sentence_with_pinyin(model_output, words_from_poi)})
            return chat_history

        required_markers = (self._PART_SEPARATOR, self._TRANSCRIBE_PREFIX, self._REPLY_PREFIX)
        if not all(marker in model_output for marker in required_markers):
            # NOTE(review): the source indentation was lost; both appends are
            # assumed to be guarded by this check -- confirm.
            if chat_history[-1]['value'] != self._APOLOGY:
                chat_history.append({'from': 'human', 'value': self._NOT_RECOGNIZED})
                chat_history.append({'from': 'gpt', 'value': self._APOLOGY})
            return chat_history

        output_t, output_o = (
            self.remove_words_signs(part)
            for part in model_output.split(self._PART_SEPARATOR)[:2]
        )
        # The last turn is overwritten with the transcription.
        chat_history[-1]['value'] = correct_sentence_with_pinyin(output_t, words_from_poi)
        chat_history.append(self._build_reply(output_o, words_from_poi))
        return chat_history
|
|
|
|
|
|
|
|
|
|
|
# Local checkpoint directory of the chat model.
model_id = "/home/jeff/jeff/codes/llm/InCar/gemma-3-4b-it-omni"
# Built once at import time; the endpoint below reuses this instance.
pipeline = InferenceClass(model_id)
app = FastAPI(
    title="Audio LLM API",
    description="An API that accepts an audio file and a list of dictionaries.",
)
import json

# The tool definitions are taken from the first record of the test dataset.
# The file is opened in a context manager so the handle is closed, and read
# as UTF-8 instead of the locale default.
with open('/home/jeff/jeff/codes/llm/InCar/data/test_data/nav_0730_noisy.json', encoding='utf-8') as dataset_file:
    dataset = json.load(dataset_file)
tools = dataset[0]['tools']
|
|
|
|
|
@app.post("/audio_llm/")
async def process_audio_and_data(
    audio_file: Optional[UploadFile] = File(None, description="The audio file to be processed."),
    data: str = Form(..., description="A JSON string representing a list of chat history dictionaries.")
) -> List[Dict[str, Any]]:
    """Run the pipeline on a chat history and an optional uploaded audio file.

    Args:
        audio_file: Optional upload. It is written to
            ``./audio_path/temp_<name>`` and that path is handed to the
            pipeline.
        data: JSON string that must decode to a list of dictionaries.

    Returns:
        The chat history returned by ``pipeline.generate``.

    Raises:
        HTTPException: 422 when ``data`` is not valid JSON or is not a list
            of dictionaries.
    """
    try:
        input_data_list = json.loads(data)
    except json.JSONDecodeError as exc:
        raise HTTPException(
            status_code=422,
            detail="Invalid JSON format for 'data' field. Please provide a valid JSON string."
        ) from exc

    if not isinstance(input_data_list, list) or not all(isinstance(item, dict) for item in input_data_list):
        raise HTTPException(
            status_code=422,
            detail="The provided data is not a list of dictionaries."
        )

    temp_file_path = None
    if audio_file:
        upload_dir = "./audio_path"
        # The directory is not guaranteed to exist on a fresh checkout.
        os.makedirs(upload_dir, exist_ok=True)
        # The filename is client-controlled: keep only its final component so
        # it cannot point outside the upload directory.
        client_name = (audio_file.filename or "").replace("\\", "/")
        temp_file_path = f"{upload_dir}/temp_{os.path.basename(client_name)}"
        with open(temp_file_path, "wb") as buffer:
            buffer.write(await audio_file.read())
        print(f"Audio file saved to {temp_file_path}")

    # NOTE: generate() is a synchronous call made directly inside this coroutine.
    output_data = pipeline.generate(input_data_list, tools=tools, audio_path=temp_file_path)
    print(output_data)
    return output_data
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the server when this module is executed directly.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("main:app", **server_options)