import ast
import copy
import json
import os
from typing import List, Dict, Any, Optional

import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from pypinyin import pinyin, Style
from rapidfuzz import process, fuzz
from transformers import AutoProcessor, AutoModel

def correct_sentence_with_pinyin(user_input_sentence, location_names, score_cutoff=50):
    """Replace a likely mis-transcribed place name in `user_input_sentence`
    with the closest match from `location_names`, comparing by pinyin so that
    homophone errors from ASR can be corrected."""
    # Map each location's pinyin spelling back to the original name.
    pinyin_dict = {}
    for location in location_names:
        pinyin_name = ''.join([item[0] for item in pinyin(location, style=Style.NORMAL)])
        pinyin_dict[pinyin_name] = location

    user_pinyin_sentence = ''.join([item[0] for item in pinyin(user_input_sentence, style=Style.NORMAL)])

    # Fuzzy-search the pinyin spellings against the whole sentence's pinyin.
    best_match_pinyin = process.extractOne(
        query=user_pinyin_sentence,
        choices=list(pinyin_dict.keys()),
        scorer=fuzz.token_set_ratio,
        score_cutoff=score_cutoff
    )

    if best_match_pinyin:
        best_pinyin_name = best_match_pinyin[0]
        corrected_location_name = pinyin_dict[best_pinyin_name]

        # Find the substring of the original sentence (2 to 15 characters)
        # that most resembles the matched location name.
        best_user_substring = None
        max_substring_score = 0

        for i in range(len(user_input_sentence)):
            for j in range(i + 2, min(i + 16, len(user_input_sentence) + 1)):
                substring = user_input_sentence[i:j]
                score = fuzz.ratio(substring, corrected_location_name)
                if score > max_substring_score:
                    max_substring_score = score
                    best_user_substring = substring

        # Only rewrite the sentence when the substring match is convincing.
        if best_user_substring and max_substring_score > score_cutoff:
            return user_input_sentence.replace(best_user_substring, corrected_location_name, 1)
        return user_input_sentence
    return user_input_sentence
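
# A quick sanity check of the corrector (hypothetical values): "站" and "展"
# share the pinyin "zhan", so the mis-heard name is mapped back to the POI.
#   correct_sentence_with_pinyin("導航到南港站覽館", ["南港展覽館"])
#   -> "導航到南港展覽館"   # "navigate to Nangang Exhibition Hall"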

class InferenceClass:
    def __init__(self, model_id):
        self.model = AutoModel.from_pretrained(
            model_id, device_map="cuda",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            attn_implementation="eager"
        ).eval()

        self.processor = AutoProcessor.from_pretrained(
            model_id, trust_remote_code=True
        )
        # Strip the role markers plus whitespace and sentence punctuation
        # (half- and full-width) from a line of model output.
        self.remove_words_signs = lambda x: ''.join(
            c for c in x.replace('User transcribe is :', '').replace('GPT output is :', '')
            if c not in '\n ?？!！。'
        )
    def call_gpt(self, inputs_tensor):
        with torch.inference_mode():
            # Move every input tensor to the GPU before generation.
            inputs = {k: inputs_tensor[k].to('cuda') for k in inputs_tensor}
            generate_ids = self.model.generate(**inputs, max_new_tokens=128, do_sample=False)
            # Keep only the newly generated tokens, dropping the prompt.
            generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
            model_output = self.processor.batch_decode(
                generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )[0]
        return model_output
    def call_function_fake(self, messages=None, obs=""):
        # Append a tool observation to the history; a None default avoids the
        # shared-mutable-default pitfall of `messages=[]`.
        if messages is None:
            messages = []
        messages.append({'from': 'observation', 'value': obs})
        return messages
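    # e.g. (hypothetical observation payload, mirroring the format that
    # generate() parses below):
    #   history = pipeline.call_function_fake(history,
    #       obs='{"status": "地點查詢成功", "poi": [{"name": "南港展覽館"}]}')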
    def generate(self, chat_history, tools="", audio_path=None):
        '''
        input:
            chat_history : list of message dicts
            tools : str (JSON-encoded tool schema)
            audio_path : str or None
        return:
            chat_history : list of message dicts, extended with the model's turn
        '''
        chat_history = copy.deepcopy(chat_history)
        if audio_path is not None:
            chat_history.append({'from': 'human',
                                 'value': [{'type': 'audio',
                                            'audio': audio_path}]})
        # Collect POI names returned by earlier tool calls; they serve as the
        # vocabulary for pinyin-based correction of mis-transcribed places.
        # ('地點查詢成功' = "location lookup succeeded")
        words_from_poi = []
        for hist in chat_history:
            if hist['from'] == 'observation' and '地點查詢成功' in hist['value'] and 'poi' in hist['value']:
                tmp = json.loads(hist['value'])
                for poi in tmp['poi']:
                    words_from_poi.append(poi['name'])
        # Normalize place names in past human turns and function-call keywords.
        for hist in chat_history:
            if hist['from'] == 'human' and isinstance(hist['value'], str):
                hist['value'] = correct_sentence_with_pinyin(hist['value'], words_from_poi)
            elif hist['from'] == 'function_call' and "arguments" in hist['value'] and 'keyword' in hist['value']["arguments"]:
                # The arguments are stored as a Python-dict string (see str()
                # below); ast.literal_eval parses it without the risks of eval().
                hist['value']["arguments"] = ast.literal_eval(hist['value']["arguments"])
                if 'keyword' in hist['value']["arguments"]:
                    hist['value']["arguments"]['keyword'] = correct_sentence_with_pinyin(hist['value']["arguments"]['keyword'], words_from_poi)
                hist['value']["arguments"] = str(hist['value']["arguments"])
        # model_input_history = copy.deepcopy(chat_history)
        # num2ch = {1:'一',2:'二',3:'三',4:'四',5:'五',6:'六'}
        # for hist in model_input_history:
        #     if hist['from']=='observation' and '地點查詢成功' in hist['value'] and 'poi' in hist['value']:
        #         tmp = json.loads(hist['value'])
        #         new_poi = []
        #         for i,poi in enumerate(tmp['poi']):
        #             new_poi.append('第{}個 : '.format(num2ch[i+1])+str(poi))
        #         tmp['poi'] = new_poi
        #         hist['value'] = json.dumps(tmp, ensure_ascii=False)

        # apply_chat_template with tokenize=False returns the prompt string;
        # tokenization happens in the processor call below.
        inputs_text = self.processor.apply_chat_template(
            chat_history, add_generation_prompt=True, tokenize=False,
            return_dict=True, return_tensors="pt", tools=json.loads(tools)
        )
        inputs_tensor = self.processor(
            text=inputs_text,
            audio=[torchaudio.load(audio_path)[0]] if audio_path is not None else None,
            add_special_tokens=False,
            return_tensors='pt'
        )
        model_output = self.call_gpt(inputs_tensor)
        # After a tool observation the model replies directly to the user.
        if chat_history[-1]['from'] == 'observation':
            chat_history.append({'from': 'gpt', 'value': correct_sentence_with_pinyin(model_output, words_from_poi)})
            return chat_history
        # Otherwise the output must contain both a transcription and a response,
        # separated by ';\n'; anything else is treated as unintelligible audio.
        # ('抱歉我聽不清楚 能麻煩您再說一次嗎' = "Sorry, I couldn't hear you
        # clearly; could you say that again?")
        if (';\n' not in model_output or 'User transcribe is :' not in model_output
                or 'GPT output is :' not in model_output or len(model_output.split(';\n')) < 2):
            if chat_history[-1]['value'] != "抱歉我聽不清楚 能麻煩您再說一次嗎":
                chat_history.append({'from': 'human',
                                     'value': 'HUMAN_VOICE_IS_NOT_RECOGNIZED'})
                chat_history.append({'from': 'gpt', 'value': '抱歉我聽不清楚 能麻煩您再說一次嗎'})
            return chat_history
        output_t, output_o = model_output.split(';\n')[:2]
        output_t, output_o = self.remove_words_signs(output_t), self.remove_words_signs(output_o)
        # Replace the raw audio turn with the corrected transcription.
        chat_history[-1]['value'] = correct_sentence_with_pinyin(output_t, words_from_poi)
        if 'Action:' in output_o and 'ActionInput:' in output_o:  # function calling
            function_name, function_arg = output_o.split('ActionInput:')
            function_name = function_name.replace('Action:', '')
            if "keyword" in function_arg:
                function_arg = json.loads(function_arg)
                if "keyword" in function_arg:
                    function_arg["keyword"] = correct_sentence_with_pinyin(function_arg["keyword"], words_from_poi)
            chat_history.append({'from': 'function_call', 'value': {"name": function_name, "arguments": str(function_arg)}})
        else:  # plain assistant response
            chat_history.append({'from': 'gpt', 'value': correct_sentence_with_pinyin(output_o, words_from_poi)})
        return chat_history
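
# Minimal usage sketch (hypothetical history; whether the model handles a
# text-only human turn this way is an assumption, since the server below
# usually supplies an audio path):
#   history = [{'from': 'human', 'value': '附近有什麼餐廳'}]  # "any restaurants nearby?"
#   history = pipeline.generate(history, tools=tools, audio_path=None)
#   history[-1] is either {'from': 'function_call', 'value': {...}}
#   or {'from': 'gpt', 'value': '...'}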



model_id = "/home/jeff/jeff/codes/llm/InCar/gemma-3-4b-it-omni"
pipeline = InferenceClass(model_id)
app = FastAPI(
    title="Audio LLM API",
    description="An API that accepts an audio file and a list of chat-history dictionaries.",
)
# The tool schema is taken from the first sample of the test set.
with open('/home/jeff/jeff/codes/llm/InCar/data/test_data/nav_0730_noisy.json') as f:
    dataset = json.load(f)
tools = dataset[0]['tools']

@app.post("/audio_llm/")
async def process_audio_and_data(
    audio_file: Optional[UploadFile] = File(None, description="The audio file to be processed."),
    data: str = Form(..., description="A JSON string representing a list of chat-history dictionaries.")
) -> List[Dict[str, Any]]:

    try:
        input_data_list = json.loads(data)
        if not isinstance(input_data_list, list) or not all(isinstance(item, dict) for item in input_data_list):
            raise ValueError("The provided data is not a list of dictionaries.")

    except json.JSONDecodeError:
        raise HTTPException(
            status_code=422,
            detail="Invalid JSON format for 'data' field. Please provide a valid JSON string."
        )
    except ValueError as e:
        raise HTTPException(
            status_code=422,
            detail=str(e)
        )
    temp_file_path = None
    if audio_file:
        # Persist the upload so torchaudio can load it from disk; make sure
        # the target directory exists first.
        os.makedirs("./audio_path", exist_ok=True)
        temp_file_path = f"./audio_path/temp_{audio_file.filename}"
        with open(temp_file_path, "wb") as buffer:
            buffer.write(await audio_file.read())
        print(f"Audio file saved to {temp_file_path}")

    output_data = pipeline.generate(input_data_list, tools=tools, audio_path=temp_file_path)
    print(output_data)
    return output_data

# uvicorn main:app --host 0.0.0.0 --port 8087 --log-level info --workers 1 >> ./log.txt
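# Example request (hypothetical audio file and history; the field names match
# the endpoint above, and the port matches the __main__ runner below):
# curl -X POST http://localhost:8000/audio_llm/ \
#      -F "audio_file=@question.wav" \
#      -F 'data=[{"from": "human", "value": "你好"}]'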
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)