| """ | |
| 1. 这里用公网Qwen API来代替本地的ChatGLM模型。为了在Huggingface上演示。 | |
| 1. 使用确定了列名作为SQL语句的变量名,可以有效解决模型生成的SQL语句中变量名准确的问题。 | |
| """ | |
| ##TODO: | |
| import requests | |
| import os | |
| from rich import print | |
| import os | |
| import sys | |
| import time | |
| import pandas as pd | |
| import numpy as np | |
| import sys | |
| import time | |
| from typing import Any | |
| import requests | |
| import csv | |
| import os | |
| from rich import print | |
| import pandas | |
| import io | |
| from io import StringIO | |
| import re | |
| from langchain.llms.utils import enforce_stop_tokens | |
| import json | |
| from transformers import AutoModel, AutoTokenizer | |
| import mdtex2html | |
| import qwen_response | |
| ''' Start: Environment settings. ''' | |
| os.environ['SENTENCE_TRANSFORMERS_HOME'] = '/Users/yunshi/Downloads/chatGLM/My_LocalKB_Project/' | |
| os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" | |
| import torch | |
| mps_device = torch.device("mps") ## 在mac机器上需要加上这句。必须要有这句,否则会报错。 | |
| ### 在langchain中定义chatGLM作为LLM。 | |
| from typing import Any, List, Mapping, Optional | |
| from langchain.callbacks.manager import CallbackManagerForLLMRun | |
| from langchain.llms.base import LLM | |
| from transformers import AutoTokenizer, AutoModel | |
| # llm_filepath = str("/Users/yunshi/Downloads/chatGLM/ChatGLM3-6B/6B") ## 第三代chatGLM 6B W/ code-interpreter | |
| # ## API模式启动ChatGLM | |
| # ## 配置ChatGLM的类与后端api server对应。 | |
| # class ChatGLM(LLM): | |
| # max_token: int = 2048 | |
| # temperature: float = 0.1 | |
| # top_p = 0.9 | |
| # history = [] | |
| # def __init__(self): | |
| # super().__init__() | |
| # @property | |
| # def _llm_type(self) -> str: | |
| # return "ChatGLM" | |
| # def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: | |
| # # headers中添加上content-type这个参数,指定为json格式 | |
| # headers = {'Content-Type': 'application/json'} | |
| # data=json.dumps({ | |
| # 'prompt':prompt, | |
| # 'temperature':self.temperature, | |
| # 'history':self.history, | |
| # 'max_length':self.max_token | |
| # }) | |
| # print("ChatGLM prompt:",prompt) | |
| # # 调用api | |
| # # response = requests.post("http://0.0.0.0:8000",headers=headers,data=data) ##working。 | |
| # response = requests.post("http://127.0.0.1:8000",headers=headers,data=data) ##working。 | |
| # print("ChatGLM resp:", response) | |
| # if response.status_code!=200: | |
| # return "查询结果错误" | |
| # resp = response.json() | |
| # if stop is not None: | |
| # response = enforce_stop_tokens(response, stop) | |
| # self.history = self.history+[[None, resp['response']]] ##original | |
| # return resp['response'] ##original. | |
| # llm = ChatGLM() ## 启动一个实例。orignal working。 | |
| # import asyncio | |
| # llm = ChatGLM() ## 启动一个实例。 | |
| ''' End: Environment settings. ''' | |
| ### 我会用中文或者英文双引号(即:“ ”," ")来告知你变量的名称。 长度","宽度","价格","产品ID","比率","类别","*" | |
| ### 用ChatGLM构建一个只返回SQL语句的模型。 | |
| def main(prompt): | |
| full_reponse = [] | |
| sys_prompt = """ | |
| 1. 你是一个将文字转换成SQL语句的人工智能。 | |
| 2. 你需要注意:你只需要用纯文本回复代码的内容,即你不允许回复代码以外的任何信息。 | |
| 3. SQL变量默认是中文,而且只能从如下的名称列表中选择,你不可以使用这些名字以外的变量名:"长度","宽度","价格","产品ID","比率","类别","*" | |
| 4. 你不能写IF, THEN的SQL语句,需要使用CASE。 | |
| 5. 我需要你转换的文字如下:""" | |
| total_prompt = sys_prompt + "在数据表格table01中," + prompt | |
| print('total prompt now:',total_prompt) | |
| # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(prompt)): ## 这里保留了所有的chat history在input_prompt中。 | |
| # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(prompt)): ## 这里保留了所有的chat history在input_prompt中。 | |
| # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(total_prompt)): ## 这里保留了所有的chat history在input_prompt中。 | |
| # # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(input_prompt[-1][0])): ## 从用langchain的自定义方式来做。 | |
| # # for response, history in chatglm.model.stream_chat(chatglm.tokenizer, query=str(input_prompt[-1][0]), history=input_prompt, max_length=max_tokens, top_p=top_p, temperature=temperature): ## 从用langchain的自定义方式来做。 | |
| # if response != "<br>": | |
| # # print('response of model:', response) | |
| # # input_prompt[-1][1] = response ## working. | |
| # # input_prompt[-1][1] = response | |
| # # yield input_prompt | |
| # full_reponse.append(response) | |
| # ## 得到一个非stream格式的答复。非API模式。 | |
| # response, history = chatglm.model.chat(chatglm.tokenizer, query=str(total_prompt), temperature=0.1) ## 这里保留了所有的chat history在input_prompt中。 | |
| ###TODO:API模式,需要先启动API服务器。 | |
| # llm = ChatGLM() ##!! 重要说明:每次都需要实例化一次!!!否则会报错content error。实际上是应该在每次函数调用的时候都要实例化一次! | |
| # response = llm(total_prompt) ## 这里是本地的ChatGLM来作为大模型输出基座。 | |
| response = qwen_response.call_with_messages(total_prompt) | |
| print('response of model:', response) | |
| ## 用regex来提取纯SQL语句。需要构建多个正则式pattern | |
| pattern_1 = r"(?:`sql\n|\n`)" | |
| pattern_2 = r"(?:```|``)" | |
| pattern_3 = r"(?s)(.*?SQL语句示例.*?:).*?\n" | |
| pattern_4 = r"(?:`{3}|`{2}|`)" | |
| # pattern_5 = r"[\u4e00-\u9FFF]" ## 匹配中文。 | |
| # pattern_6 = r"^[\u4e00-\u9fa5]{5,}" ## 首行中包含5个中文汉字的。 | |
| pattern_7 = r"^.{0,2}([\u4e00-\u9fa5]{5,}).*" ## 首行中包含5个中文汉字的。 | |
| pattern_8 = r'^"|"$' ## 去除一句话开始或者末尾的英文双引号 | |
| pattern_list = [pattern_1, pattern_2, pattern_3, pattern_4, pattern_7, pattern_8] | |
| ## 遍历所有的pattern,逐个去除。 | |
| full_reponse = response | |
| for p in pattern_list: | |
| full_reponse = re.sub(p, "", full_reponse) | |
| # final_response = re.sub(pattern_1, "", response) ## 逐步匹配。 | |
| # final_response = re.sub(pattern_1, "", response) ## 逐步匹配。 | |
| # final_response = re.sub(pattern_2, "", final_response) ## 逐步匹配。 | |
| return full_reponse | |
| # prompt = "你给我一段复杂的SQL语句示例。" | |
| # prompt = "你给我一段SQL语句,用来完成如下工作:查询年龄大于30岁,男性,收入超过2万元的员工。" | |
| # res = main(prompt=prompt) | |
| # print(res) | |