| import uvicorn |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException |
| from typing import List, Dict, Any, Optional |
| import json |
| import os |
| from transformers import AutoProcessor, AutoModel, AutoTokenizer,AutoModelForSequenceClassification,AutoModelForCausalLM |
| import torch, torchaudio |
| import os |
| import openai |
| import copy |
| import numpy as np |
| from rapidfuzz import process, fuzz |
| from pypinyin import pinyin, Style |
| from copy import deepcopy |
def get_sentence_with_pinyin(user_input_sentence, location_dict, score_cutoff=20):
    """Find the location name whose pinyin best matches the user's sentence.

    Both the candidate location names and the user's sentence are romanized
    to pinyin, then compared with a fuzzy token-set ratio so that word order
    and extra words in the sentence do not matter.

    Parameters
    ----------
    user_input_sentence : str
        Free-form (Chinese) user utterance to search within.
    location_dict : iterable of str
        Candidate location names (only the keys/elements are used).
    score_cutoff : int, optional
        Minimum fuzz score (0-100) for a match to count. Default 20.

    Returns
    -------
    str
        The matched original location name, or "" when nothing clears the
        cutoff.
    """
    def _romanize(text):
        # Flatten pypinyin's per-character result into one string.
        return ''.join(syllable[0] for syllable in pinyin(text, style=Style.NORMAL))

    # Map pinyin form -> original name (later duplicates overwrite earlier ones,
    # same as the original loop behavior).
    romanized = {_romanize(name): name for name in location_dict}

    query_pinyin = _romanize(user_input_sentence)

    best = process.extractOne(
        query=query_pinyin,
        choices=list(romanized.keys()),
        scorer=fuzz.token_set_ratio,
        score_cutoff=score_cutoff,
    )

    if best and best[0] in romanized:
        return romanized[best[0]]
    return ""
|
|
|
|
class InferenceClass:
    """Client-side pipeline for a locally served, OpenAI-compatible LLM.

    Responsibilities:
      * hold the completion client (``__init__``),
      * convert the project's chat-history format into a Llama-3-style
        prompt (``apply_template`` + ``build_prompt``),
      * call the backend and parse its structured reply (``call_gpt`` +
        ``generate``).
    """

    def __init__(self, base_url="http://localhost:8087/v1"):
        # Server-side identifier of the merged (fine-tuned) model.
        self.model_id = "/root/merged"
        # Local vLLM-style servers accept any key; "EMPTY" is a placeholder.
        self.client = openai.OpenAI(base_url=base_url, api_key="EMPTY")

    def call_bert_class(self, query):
        """Intent-classifier stub: currently always routes to the SLM path."""
        return "SLM"

    def call_gpt(self, messages):
        """Render *messages* to a prompt and run a raw completion request.

        Returns the generated text, or "" on any failure — callers treat an
        empty string as a signal to fall back to a quick reply.
        """
        prompt = self.apply_template(messages)
        print('input prompt is : ')
        print(prompt)
        print('------'*30)
        try:
            response = self.client.completions.create(
                model=self.model_id, prompt=prompt, max_tokens=8192, temperature=0.6,
                top_p=0.6,
            )
            return response.choices[0].text
        # Fixed: was a bare `except:` that also swallowed KeyboardInterrupt etc.
        except Exception as exc:
            print('call_gpt error', exc, 'input', messages)
            return ""

    def build_prompt(self, messages):
        """Serialize role/content messages into Llama-3 header/eot markup.

        If the last turn is from the user side (user/observation/tool), an
        open assistant header is appended so the model continues from there.
        """
        role_to_tag = {
            "user": "user",
            "assistant": "assistant",
            "function": "assistant",
            "observation": "ipython",
            # Fixed typo: was "ipyton", which emitted an invalid role tag.
            "tool": "ipython",
            "system": "system",
        }

        prompt = ""
        for msg in messages:
            # Unknown roles deliberately fall back to "user".
            tag = role_to_tag.get(msg["role"], "user")
            prompt += f"<|start_header_id|>{tag}<|end_header_id|>\n\n{msg['content']}<|eot_id|>"

        if messages[-1]['role'] in {"user", "observation", "tool"}:
            prompt += f"<|start_header_id|>assistant<|end_header_id|>\n\n"

        return prompt

    def apply_template(self, messages):
        """Convert project-format turns into role/content messages and render.

        First the history is truncated to the latest session: turns like
        "旅運A-1開始" / "市民A-1開始" (length 7, containing '-') mark a session
        start; the marker with the greatest string order wins and everything
        before it is dropped. Then each remaining turn is formatted with the
        human/gpt templates and serialized via ``build_prompt``.
        """
        gpt_out_template = """{RAG_Keywords}<Next step>{intent};
<Text Output>{output}
"""
        human_in_template = """{RAG_data}{Cloud_data}<User Query>{query}{LOC}
"""
        short_messages = []
        latest_marker = ""
        for conv in messages:
            if conv['from'] == 'human' and len(conv['value']) == 7:
                marker = conv['value'].replace("旅運", '').replace("市民", '').replace("開始", '')
                if (("旅運" in conv['value'] or "市民" in conv['value'])
                        and "開始" in conv['value']
                        and '-' in marker
                        and latest_marker < marker):
                    # Newer session marker: restart the window from here.
                    latest_marker = marker
                    short_messages = []
            short_messages.append(conv)

        msg = []
        for c in short_messages:
            obj = {'role': None}
            if c['from'] == 'human':
                rag_data = ""
                cloud_data = ""
                if 'RAG' in c and c['RAG']:
                    rag_data = "<RAG data>"
                    for k in c['RAG']:
                        if c['RAG'][k] is None:
                            continue
                        rag_data += k + ":" + c['RAG'][k]['name'] + ", content:" + str(c['RAG'][k]['content']) + ';\n'
                if 'Cloud' in c and c['Cloud']:
                    cloud_data = "<Cloud data>" + c['Cloud'] + ';\n'
                loc_info = ""
                if "Loc" in c and c["Loc"]:
                    loc_info = ", current location:" + c["Loc"]
                content = human_in_template.format(
                    RAG_data=rag_data, Cloud_data=cloud_data, query=c['value'], LOC=loc_info)
                obj.update({'role': 'user', 'content': content})
            elif c['from'] == 'gpt':
                rag_in = "<keywords>" + str(c['RAG_KEYWORDS']) + ";\n" if 'RAG_KEYWORDS' in c and c['RAG_KEYWORDS'] else ""
                # 'Next' explicitly False -> "False"; missing or truthy -> "True".
                intent = "False" if 'Next' in c and not c['Next'] else "True"
                content = gpt_out_template.format(RAG_Keywords=rag_in, intent=intent, output=c['value'])
                obj.update({'role': 'assistant', 'content': content})
            msg.append(obj)

        return self.build_prompt(msg)

    def generate(self, chat_history):
        """Run one assistant turn and append it to the history.

        Parameters
        ----------
        chat_history : list[dict]
            Conversation in the project format ({'from', 'value', ...}).
            NOTE: mutated in place — the new 'gpt' turn is appended.

        Returns
        -------
        list[dict]
            The same list, with the model's parsed reply appended.
        """
        # Was a hard-coded "SLM"; route through the classifier stub so the
        # dispatch point is explicit (behavior unchanged — stub returns "SLM").
        bert_class = self.call_bert_class(chat_history)
        print('Bert out : ', bert_class)

        if bert_class == "SLM":
            model_output = self.call_gpt(chat_history)
            # Post-hoc cleanup of known model-output artifacts.
            model_output = model_output.replace("<RAG data>", "<keywords>").replace('晚上晚上', '晚上').replace("ShoushanLover'sObservatory", "Shoushan Lover's Observatory").replace('SizihwanBay', 'Sizihwan Bay')
            print(model_output)

            tmp = {'from': 'gpt', 'Next': False}
            outs = model_output.split(';\n')
            print(outs)
            if '<keywords>' in outs[0]:
                tmp['RAG_KEYWORDS'] = outs[0].replace('<keywords>', '').replace(' ', '')
                # Guard added: a single-segment reply previously raised IndexError.
                if len(outs) > 1 and '<Next step>' in outs[1]:
                    tmp['Next'] = 'True' in outs[1]
            elif '<Next step>' in outs[0]:
                tmp['Next'] = 'True' in outs[0]
            tmp['value'] = outs[-1].replace('<Text Output>', '').replace('\n<|eot_id|>', '')

            chat_history.append(tmp)
        else:
            # Non-SLM route: emit an empty quick-reply placeholder turn.
            chat_history.append({'from': 'gpt', 'value': "", "Next": True, 'quick': True})
        return chat_history
|
|
|
|
|
|
# Module-level singleton: built at import time, so the OpenAI-compatible
# backend (default http://localhost:8087/v1, see InferenceClass.__init__)
# should be reachable when this module is loaded.
pipeline = InferenceClass()
# NOTE(review): the description mentions an audio file, but the only route in
# this file (/service_LLM/) accepts a JSON chat history — confirm intent.
app = FastAPI(
    title="Audio LLM API",
    description="An API that accepts an audio file and a list of dictionaries.",
)
|
|
|
|
| @app.post("/service_LLM/") |
| async def generate( |
| data: str = Form(..., description="A JSON string representing a list of chat history dictionaries.") |
| ) -> List[Dict[str, Any]]: |
|
|
| try: |
| input_data_list = json.loads(data) |
| if not isinstance(input_data_list, list) or not all(isinstance(item, dict) for item in input_data_list): |
| raise ValueError("The provided data is not a list of dictionaries.") |
|
|
| except json.JSONDecodeError: |
| raise HTTPException( |
| status_code=422, |
| detail="Invalid JSON format for 'data' field. Please provide a valid JSON string." |
| ) |
| except ValueError as e: |
| raise HTTPException( |
| status_code=422, |
| detail=str(e) |
| ) |
|
|
| output_data = pipeline.generate(input_data_list) |
| print(output_data) |
| return output_data |
|
|
| |
| if __name__ == "__main__": |
| uvicorn.run("main:app", host="0.0.0.0", port=8088, reload=True) |