Spaces:
Runtime error
Runtime error
remove extra download
Browse files
app.py
CHANGED
|
@@ -3,37 +3,364 @@ import torch
|
|
| 3 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 4 |
from peft import PeftModel
|
| 5 |
from modelscope.hub.snapshot_download import snapshot_download
|
|
|
|
| 6 |
|
| 7 |
-
#
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
LORA_MODEL_ID = "risemds/UniVectorSQL-7B-LoRA-all_steps_1030_new_data"
|
| 12 |
|
| 13 |
-
#
|
| 14 |
-
#
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
# 4. 加载基础模型和 Tokenizer
|
| 18 |
-
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
|
| 19 |
-
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
|
| 20 |
|
| 21 |
-
#
|
| 22 |
-
|
| 23 |
-
model =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
-
# (
|
| 26 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
model.eval()
|
| 29 |
|
| 30 |
-
#
|
| 31 |
def inference(text_input):
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
| 39 |
iface.launch()
|
|
|
|
| 3 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 4 |
from peft import PeftModel
|
| 5 |
from modelscope.hub.snapshot_download import snapshot_download
|
| 6 |
+
import os
|
| 7 |
|
| 8 |
+
# --- 自动检测设备 ---
|
| 9 |
+
# 检查是否有可用的 GPU (NVIDIA CUDA)
|
| 10 |
+
if torch.cuda.is_available():
|
| 11 |
+
device = "cuda"
|
| 12 |
+
torch_dtype = torch.float16 # GPU 使用 float16
|
| 13 |
+
device_map = "auto" # 自动分配到 GPU
|
| 14 |
+
print("检测到 CUDA (GPU),将使用 GPU。")
|
| 15 |
+
# 检查是否有可用的 Apple Silicon (MPS)
|
| 16 |
+
elif torch.backends.mps.is_available():
|
| 17 |
+
device = "mps"
|
| 18 |
+
torch_dtype = torch.float16 # Apple Silicon GPU 也可以用 float16
|
| 19 |
+
device_map = "mps" # 明确指定 mps
|
| 20 |
+
print("检测到 MPS (Apple Silicon),将使用 MPS。")
|
| 21 |
+
# 否则回退到 CPU
|
| 22 |
+
else:
|
| 23 |
+
device = "cpu"
|
| 24 |
+
torch_dtype = torch.float32 # CPU 使用 float32 (float16 在 CPU 上很慢或不支持)
|
| 25 |
+
device_map = "cpu" # 明确指定 cpu
|
| 26 |
+
print("未检测到 GPU,将使用 CPU。")
|
| 27 |
+
# ---------------------
|
| 28 |
|
| 29 |
+
|
| 30 |
+
# --- 关键配置 ---
|
| 31 |
+
|
| 32 |
+
# 1. 您的 LoRA 模型 ID
|
| 33 |
+
# LORA_MODEL_ID = "/mnt/DataFlow/ydw/model/risemds/UniVectorSQL-7B-LoRA-all_steps_1030_new_data/"
|
| 34 |
+
# lora_model_dir = LORA_MODEL_ID
|
| 35 |
LORA_MODEL_ID = "risemds/UniVectorSQL-7B-LoRA-all_steps_1030_new_data"
|
| 36 |
|
| 37 |
+
# 2. 基础模型 ID (!!!)
|
| 38 |
+
# 您必须找到这个 LoRA 对应的基础模型是什么。
|
| 39 |
+
# BASE_MODEL_ID = "/mnt/DataFlow/ydw/model/seeklhy/OmniSQL-7B/"
|
| 40 |
+
BASE_MODEL_ID = "seeklhy/OmniSQL-7B"
|
| 41 |
+
|
| 42 |
+
# -----------------
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# 使用 ignore_patterns 避免下载不需要的 checkpoint/文件
|
| 46 |
+
# 我们将忽略 .md 文件和 training_args.bin 文件
|
| 47 |
+
print(f"开始下载 LoRA 适配器: {LORA_MODEL_ID}")
|
| 48 |
+
lora_model_dir = snapshot_download(
|
| 49 |
+
LORA_MODEL_ID,
|
| 50 |
+
revision='master',
|
| 51 |
+
ignore_patterns=["*.md", "training_args.bin", "checkpoint-*"]
|
| 52 |
+
)
|
| 53 |
+
print(f"LoRA 适配器下载完成,路径: {lora_model_dir}")
|
| 54 |
+
|
| 55 |
+
# 1. 加载 Tokenizer
|
| 56 |
+
# 优先使用 LoRA 仓库中的 Tokenizer,因为它可能已更新(例如添加了新 token)
|
| 57 |
+
try:
|
| 58 |
+
tokenizer = AutoTokenizer.from_pretrained(lora_model_dir, trust_remote_code=True)
|
| 59 |
+
print("从 LoRA 目录加载 Tokenizer 成功。")
|
| 60 |
+
except Exception as e:
|
| 61 |
+
print(f"从 LoRA 目录加载 Tokenizer 失败 ({e}), 尝试从基础模型加载。")
|
| 62 |
+
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, trust_remote_code=True)
|
| 63 |
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
+
# 2. 加载基础模型
|
| 66 |
+
print(f"开始加载基础模型: {BASE_MODEL_ID}")
|
| 67 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 68 |
+
BASE_MODEL_ID,
|
| 69 |
+
device_map=device_map, # <--- 修改:使用自动检测的 device_map
|
| 70 |
+
torch_dtype=torch_dtype, # <--- 修改:使用自动检测的 torch_dtype
|
| 71 |
+
trust_remote_code=True
|
| 72 |
+
)
|
| 73 |
+
print("基础模型加载完成。")
|
| 74 |
|
| 75 |
+
# --- (新增) 修复 Qwen 模型的 pad_token_id ---
|
| 76 |
+
# Qwen1.5 基础模型可能没有设置 pad_token_id,这会在推理时产生警告
|
| 77 |
+
# 我们将其设置为 eos_token_id
|
| 78 |
+
if tokenizer.pad_token_id is None:
|
| 79 |
+
print("Tokenizer 未设置 pad_token_id,将其设置为 eos_token_id。")
|
| 80 |
+
tokenizer.pad_token_id = tokenizer.eos_token_id
|
| 81 |
+
|
| 82 |
+
if model.config.pad_token_id is None:
|
| 83 |
+
print("Model config 未设置 pad_token_id,将其设置为 eos_token_id。")
|
| 84 |
+
model.config.pad_token_id = tokenizer.eos_token_id
|
| 85 |
+
# ----------------------------------------
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# 3. 加载并融合 LoRA 适配器
|
| 89 |
+
print(f"开始加载 LoRA 适配器到基础模型...")
|
| 90 |
+
model = PeftModel.from_pretrained(
|
| 91 |
+
model,
|
| 92 |
+
lora_model_dir,
|
| 93 |
+
device_map=device_map # <--- 修改:传递自动检测的 device_map
|
| 94 |
+
)
|
| 95 |
+
print("LoRA 适配器加载成功。")
|
| 96 |
+
|
| 97 |
+
# (可选) 合并权重以加快推理速度,但这会占用更多内存
|
| 98 |
+
# 注意:在 CPU 上合并可能非常慢
|
| 99 |
+
print("正在合并 LoRA 权重...")
|
| 100 |
+
model = model.merge_and_unload()
|
| 101 |
+
print("权重合并完成。")
|
| 102 |
|
| 103 |
model.eval()
|
| 104 |
|
| 105 |
+
# 4. 定义推理函数 (*** 已修改为使用 Qwen 对话模板 ***)
|
| 106 |
def inference(text_input):
|
| 107 |
+
print(f"收到输入: {text_input}")
|
| 108 |
+
try:
|
| 109 |
+
# 1. 构建 Qwen (ChatML) 对话格式
|
| 110 |
+
# 假设 text_input 是完整的用户提示 (包含指令、Schema等)
|
| 111 |
+
messages = [
|
| 112 |
+
{"role": "user", "content": text_input}
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
# 2. 应用对话模板
|
| 116 |
+
# tokenizer.apply_chat_template 会自动处理特殊 tokens (例如 <|im_start|>)
|
| 117 |
+
# 并且 add_generation_prompt=True 会添加 <|im_start|>assistant\n
|
| 118 |
+
# 这会告诉模型开始生成回复
|
| 119 |
+
inputs = tokenizer.apply_chat_template(
|
| 120 |
+
messages,
|
| 121 |
+
return_tensors="pt",
|
| 122 |
+
add_generation_prompt=True
|
| 123 |
+
).to(model.device)
|
| 124 |
+
|
| 125 |
+
# 3. 获取输入 token 的长度
|
| 126 |
+
# 这里的 inputs 是 token IDs tensor, shape [1, sequence_length]
|
| 127 |
+
input_len = inputs.shape[1]
|
| 128 |
+
|
| 129 |
+
# 4. 获取 EOT (End of Text) token IDs
|
| 130 |
+
# Qwen 系列使用 <|im_end|> (151645) 和/或 <|endoftext|> (151643)
|
| 131 |
+
# tokenizer.eos_token_id 应该已经正确设置 (来自 Qwen1.5 基础模型)
|
| 132 |
+
eot_ids = [tokenizer.eos_token_id]
|
| 133 |
+
|
| 134 |
+
# 额外检查 <|im_end|>,确保它也在停止列表中
|
| 135 |
+
im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
| 136 |
+
if im_end_token_id != tokenizer.unk_token_id and im_end_token_id not in eot_ids:
|
| 137 |
+
eot_ids.append(im_end_token_id)
|
| 138 |
+
|
| 139 |
+
print(f"使用 EOT token IDs: {eot_ids} (eos_token_id: {tokenizer.eos_token_id})")
|
| 140 |
+
|
| 141 |
+
# 5. 生成
|
| 142 |
+
# 使用 eos_token_id 列表来确保模型在 <|im_end|> 处停止
|
| 143 |
+
outputs = model.generate(
|
| 144 |
+
input_ids=inputs,
|
| 145 |
+
max_new_tokens=2048,
|
| 146 |
+
eos_token_id=eot_ids, # <--- 关键:告诉模型何时停止
|
| 147 |
+
pad_token_id=tokenizer.eos_token_id # 避免 HF warning
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
# 6. 从输出的 token 序列中,只选择新生成的部分
|
| 151 |
+
# outputs[0] 包含 (输入 + 生成) 的所有 tokens
|
| 152 |
+
new_tokens = outputs[0][input_len:]
|
| 153 |
+
|
| 154 |
+
# 7. 只解码新生成的 token
|
| 155 |
+
# skip_special_tokens=True 会移除任何 <|im_end|>
|
| 156 |
+
result = tokenizer.decode(new_tokens, skip_special_tokens=True)
|
| 157 |
+
|
| 158 |
+
print(f"生成结果 (仅新内容): {result}")
|
| 159 |
+
return result
|
| 160 |
+
except Exception as e:
|
| 161 |
+
print(f"推理时出错: {e}")
|
| 162 |
+
return f"错误: {e}"
|
| 163 |
+
|
| 164 |
+
# ----------------------------------------------------
|
| 165 |
+
|
| 166 |
+
example = """You are a senior SQL engineer. Your task is to generate a single, correct, and executable SQL query to answer the user's question based on the provided database context.
|
| 167 |
+
|
| 168 |
+
## INSTRUCTIONS
|
| 169 |
+
1. **Backend Adherence**: The query MUST be written for the `clickhouse` database backend. This is a strict requirement.
|
| 170 |
+
2. **Follow Special Notes**: You MUST strictly follow all syntax, functions, or constraints described in the [Database Backend Notes]. Pay extremely close attention to this section, as it contains critical, non-standard rules.
|
| 171 |
+
3. **Schema Integrity**: The query MUST ONLY use the tables and columns provided in the [Database Schema]. Do not invent or guess table or column names.
|
| 172 |
+
4. **Answer the Question**: The query must directly and accurately answer the [Natural Language Question].
|
| 173 |
+
5. **Output Format**: Enclose the final SQL query in a single Markdown code block formatted for SQL (` ```sql ... ``` `).
|
| 174 |
+
6. **Embedding Match**: If the [EMBEDDING_MODEL_NAME] parameter is a valid string (e.g., 'intfloat/E5-Mistral-7B-Instruct'), you MUST generate a query that includes the WHERE [EMBEDDING_COLUMN_NAME] MATCH lembed(...) clause for vector search. Otherwise, if embedding model name below the [EMBEDDING MODEL NAME] is None, , you MUST generate a standard SQL query that OMITS the entire MATCH lembed(...) clause. The query should not perform any vector search.
|
| 175 |
+
7. **Embedding Name**: If a value is provided for the parameter `[EMBEDDING_MODEL_NAME]`, your generated query must contain a `lembed` function call. The first parameter to the `lembed` function MUST be the exact value of `[EMBEDDING_MODEL_NAME]`, formatted as a string literal (enclosed in single quotes). For example, if `[EMBEDDING_MODEL_NAME]` is `laion/CLIP-ViT-B-32-laion2B-s34B-b79K`, the generated SQL must include `MATCH lembed('laion/CLIP-ViT-B-32-laion2B-s34B-b79K', ...)`.
|
| 176 |
+
|
| 177 |
+
## DATABASE CONTEXT
|
| 178 |
+
|
| 179 |
+
[DATABASE BACKEND]:
|
| 180 |
+
clickhouse
|
| 181 |
+
|
| 182 |
+
[DATABASE SCHEMA]:
|
| 183 |
+
CREATE TABLE CAMPAIGN_RESULTS (
|
| 184 |
+
`result_id` Nullable(Int64),
|
| 185 |
+
`campaign_id` Nullable(Int64),
|
| 186 |
+
`territory_id` Nullable(Int64),
|
| 187 |
+
`menu_item_id` Nullable(Int64),
|
| 188 |
+
`sales_increase_percentage` Nullable(Float64),
|
| 189 |
+
`customer_engagement_score` Nullable(Float64),
|
| 190 |
+
`feedback_improvement` Nullable(Float64)
|
| 191 |
+
);
|
| 192 |
+
CREATE TABLE CUSTOMERS (
|
| 193 |
+
`customer_id` Nullable(Int64),
|
| 194 |
+
`customer_name` Nullable(String),
|
| 195 |
+
`email` Nullable(String),
|
| 196 |
+
`phone_number` Nullable(String),
|
| 197 |
+
`loyalty_points` Nullable(Int64)
|
| 198 |
+
);
|
| 199 |
+
CREATE TABLE CUSTOMER_FEEDBACK (
|
| 200 |
+
`feedback_id` Nullable(Int64),
|
| 201 |
+
`menu_item_id` Nullable(Int64),
|
| 202 |
+
`customer_id` Nullable(Int64),
|
| 203 |
+
`feedback_date` Nullable(String),
|
| 204 |
+
`rating` Nullable(Float64),
|
| 205 |
+
`comments` Nullable(String),
|
| 206 |
+
`feedback_type` Nullable(String),
|
| 207 |
+
`comments_embedding` Array(Float32)
|
| 208 |
+
);
|
| 209 |
+
CREATE TABLE INGREDIENTS (
|
| 210 |
+
`ingredient_id` Nullable(Int64),
|
| 211 |
+
`name` Nullable(String),
|
| 212 |
+
`description` Nullable(String),
|
| 213 |
+
`supplier_id` Nullable(Int64),
|
| 214 |
+
`cost_per_unit` Nullable(Float64),
|
| 215 |
+
`description_embedding` Array(Float32)
|
| 216 |
+
);
|
| 217 |
+
CREATE TABLE MARKETING_CAMPAIGNS (
|
| 218 |
+
`campaign_id` Nullable(Int64),
|
| 219 |
+
`campaign_name` Nullable(String),
|
| 220 |
+
`start_date` Nullable(String),
|
| 221 |
+
`end_date` Nullable(String),
|
| 222 |
+
`budget` Nullable(Float64),
|
| 223 |
+
`objective` Nullable(String),
|
| 224 |
+
`territory_id` Nullable(Int64)
|
| 225 |
+
);
|
| 226 |
+
CREATE TABLE MENU_CATEGORIES (
|
| 227 |
+
`category_id` Nullable(Int64),
|
| 228 |
+
`category_name` Nullable(String),
|
| 229 |
+
`description` Nullable(String),
|
| 230 |
+
`parent_category_id` Nullable(Int64),
|
| 231 |
+
`description_embedding` Array(Float32)
|
| 232 |
+
);
|
| 233 |
+
CREATE TABLE MENU_ITEMS (
|
| 234 |
+
`menu_item_id` Nullable(Int64),
|
| 235 |
+
`territory_id` Nullable(Int64),
|
| 236 |
+
`name` Nullable(String),
|
| 237 |
+
`price_inr` Nullable(Float64),
|
| 238 |
+
`price_usd` Nullable(Float64),
|
| 239 |
+
`price_eur` Nullable(Float64),
|
| 240 |
+
`category_id` Nullable(Int64),
|
| 241 |
+
`menu_type` Nullable(String),
|
| 242 |
+
`calories` Nullable(Int64),
|
| 243 |
+
`is_vegetarian` Nullable(Int64),
|
| 244 |
+
`promotion_id` Nullable(Int64)
|
| 245 |
+
);
|
| 246 |
+
CREATE TABLE MENU_ITEM_INGREDIENTS (
|
| 247 |
+
`menu_item_id` Nullable(Int64),
|
| 248 |
+
`ingredient_id` Nullable(Int64),
|
| 249 |
+
`quantity` Nullable(Float64),
|
| 250 |
+
`unit_of_measurement` Nullable(String)
|
| 251 |
+
);
|
| 252 |
+
CREATE TABLE PRICING_STRATEGIES (
|
| 253 |
+
`strategy_id` Nullable(Int64),
|
| 254 |
+
`strategy_name` Nullable(String),
|
| 255 |
+
`description` Nullable(String),
|
| 256 |
+
`territory_id` Nullable(Int64),
|
| 257 |
+
`effective_date` Nullable(String),
|
| 258 |
+
`end_date` Nullable(String),
|
| 259 |
+
`description_embedding` Array(Float32)
|
| 260 |
+
);
|
| 261 |
+
CREATE TABLE PROMOTIONS (
|
| 262 |
+
`promotion_id` Nullable(Int64),
|
| 263 |
+
`promotion_name` Nullable(String),
|
| 264 |
+
`start_date` Nullable(String),
|
| 265 |
+
`end_date` Nullable(String),
|
| 266 |
+
`discount_percentage` Nullable(Float64),
|
| 267 |
+
`category_id` Nullable(Int64),
|
| 268 |
+
`territory_id` Nullable(Int64)
|
| 269 |
+
);
|
| 270 |
+
CREATE TABLE SALES_DATA (
|
| 271 |
+
`sale_id` Nullable(Int64),
|
| 272 |
+
`menu_item_id` Nullable(Int64),
|
| 273 |
+
`territory_id` Nullable(Int64),
|
| 274 |
+
`sale_date` Nullable(String),
|
| 275 |
+
`quantity_sold` Nullable(Int64),
|
| 276 |
+
`total_revenue` Nullable(Float64),
|
| 277 |
+
`discount_applied` Nullable(Float64),
|
| 278 |
+
`customer_id` Nullable(Int64)
|
| 279 |
+
);
|
| 280 |
+
CREATE TABLE SALES_FORECAST (
|
| 281 |
+
`forecast_id` Nullable(Int64),
|
| 282 |
+
`menu_item_id` Nullable(Int64),
|
| 283 |
+
`territory_id` Nullable(Int64),
|
| 284 |
+
`forecast_date` Nullable(String),
|
| 285 |
+
`forecast_quantity` Nullable(Int64),
|
| 286 |
+
`forecast_revenue` Nullable(Float64),
|
| 287 |
+
`prediction_accuracy` Nullable(Float64)
|
| 288 |
+
);
|
| 289 |
+
CREATE TABLE SUPPLIERS (
|
| 290 |
+
`supplier_id` Nullable(Int64),
|
| 291 |
+
`supplier_name` Nullable(String),
|
| 292 |
+
`contact_email` Nullable(String),
|
| 293 |
+
`phone_number` Nullable(String),
|
| 294 |
+
`address` Nullable(String)
|
| 295 |
+
);
|
| 296 |
+
CREATE TABLE TERRITORIES (
|
| 297 |
+
`territory_id` Nullable(Int64),
|
| 298 |
+
`territory_name` Nullable(String),
|
| 299 |
+
`region` Nullable(String),
|
| 300 |
+
`contact_email` Nullable(String),
|
| 301 |
+
`local_tax_rate` Nullable(Float64),
|
| 302 |
+
`currency_code` Nullable(String)
|
| 303 |
+
);
|
| 304 |
+
CREATE TABLE USERS (
|
| 305 |
+
`user_id` Nullable(Int64),
|
| 306 |
+
`user_name` Nullable(String),
|
| 307 |
+
`email` Nullable(String),
|
| 308 |
+
`role_id` Nullable(Int64),
|
| 309 |
+
`territory_id` Nullable(Int64)
|
| 310 |
+
);
|
| 311 |
+
CREATE TABLE USER_ROLES (
|
| 312 |
+
`role_id` Nullable(Int64),
|
| 313 |
+
`role_name` Nullable(String),
|
| 314 |
+
`description` Nullable(String),
|
| 315 |
+
`permissions` Nullable(String),
|
| 316 |
+
`description_embedding` Array(Float32)
|
| 317 |
+
);
|
| 318 |
+
|
| 319 |
+
[DATABASE BACKEND NOTES]:
|
| 320 |
+
There are a few requirements you should comply with in addition:
|
| 321 |
+
1. When generating SQL queries, you should prioritize utilizing K-Nearest Neighbor (KNN) searches whenever contextually appropriate. However, you must avoid unnecessary/forced KNN implementations for:
|
| 322 |
+
-- Traditional relational data queries (especially for columns like: id, age, price).
|
| 323 |
+
-- Cases where standard SQL operators (equality, range, or aggregation functions) are more efficient and semantically appropriate.
|
| 324 |
+
2. Only columns with a vector type (like: Array(Float32)) support KNN queries. The names of these vector columns often end with "_embedding". You can perform KNN searches when the column name you need to query ends with "_embedding" or is otherwise identified as a vector column.
|
| 325 |
+
3. In ClickHouse, vector similarity search is performed using distance functions. You must explicitly calculate the distance in the SELECT clause using a function like L2Distance and give it an alias, typically "AS distance". This distance alias will not be implicitly generated.
|
| 326 |
+
4. The lembed function is used to transform a string into a semantic vector. This function should be used within a WITH clause to define the reference vector. The lembed function has two parameters: the first is the name of the embedding model used (default value: 'intfloat/E5-Mistral-7B-Instruct'), and the second is the string content to embed. The resulting vector should be given an alias in the WITH clause.
|
| 327 |
+
5. You must generate plausible and semantically relevant words or sentences for the second parameter of the lembed function based on the column's name, type, and comment. For example, if a column is named product_description_embedding and its comment is "Embedding of the product's features and marketing text", you could generate text like "durable and waterproof outdoor adventure camera".
|
| 328 |
+
6. Every KNN search query MUST conclude with "ORDER BY distance LIMIT N" to retrieve the top-N most similar results. The LIMIT clause is mandatory for performing a KNN search and ensuring predictable performance.
|
| 329 |
+
7. When combining a vector search with JOIN operations, the standard WHERE clause should be used to apply filters from any of the joined tables. The ORDER BY distance LIMIT N clause is applied after all filtering and joins are resolved.
|
| 330 |
+
8. A SELECT statement should typically be ordered by a single distance calculation to perform one primary KNN search. However, subqueries can perform their own independent KNN searches, each with its own WITH clause, distance calculation, and ORDER BY distance LIMIT N clause.
|
| 331 |
+
|
| 332 |
+
## Example of a ClickHouse KNN Query
|
| 333 |
+
DB Schema: Some table on articles with a column content_embedding Array(Float32).
|
| 334 |
+
Query Task: Identify the article ID of the single most relevant article discussing innovative algorithms in graph theory.
|
| 335 |
+
Generated SQL:
|
| 336 |
+
```sql
|
| 337 |
+
WITH\n lembed('intfloat/E5-Mistral-7B-Instruct', 'innovative algorithms in graph theory.') AS ref_vec_0\n\nSELECT id, L2Distance(articles.abstract_embedding, ref_vec_0) AS distance\nFROM articles\nORDER BY distance\nLIMIT 1;
|
| 338 |
+
```
|
| 339 |
+
|
| 340 |
+
[EMBEDDING MODEL NAME]:
|
| 341 |
+
intfloat/E5-Mistral-7B-Instruct
|
| 342 |
+
|
| 343 |
+
## NATURAL LANGUAGE QUESTION
|
| 344 |
+
In vector searches, the `MATCH` operator performs an approximate nearest neighbor (ANN) search, which identifies items based on their similarity to a given vector. The `lembed()` function is used to convert text phrases into vector representations using a specific model, in this case, `intfloat/E5-Mistral-7B-Instruct`. This helps in finding items that align closely with the concept of "Popular menu items based on sales." The parameter `k = 5` specifies that only the top 5 categories, which are most similar in terms of the embedding, should be considered. The similarity is determined by calculating the Euclidean distance between vectors, where a smaller distance indicates higher similarity.
|
| 345 |
+
Can you unveil the crown jewel of our vegetarian delights, the one that has soared to the top of the sales charts from the elite circle of our most cherished categories this year?
|
| 346 |
+
|
| 347 |
+
Let's think step by step!
|
| 348 |
+
"""
|
| 349 |
+
|
| 350 |
+
examples = [example]
|
| 351 |
+
|
| 352 |
+
# 6. 创建 Gradio 界面 (原第5步)
|
| 353 |
+
iface = gr.Interface(
|
| 354 |
+
fn=inference,
|
| 355 |
+
# 增加了行数以便容纳 Schema
|
| 356 |
+
inputs=gr.Textbox(lines=10, label="输入查询 (Input Query)"),
|
| 357 |
+
outputs=gr.Textbox(lines=10, label="模型输出 (Model Output)"),
|
| 358 |
+
title="UniVectorSQL-7B-LoRA 推理",
|
| 359 |
+
# description="这是一个 Text-to-SQL 模型。请输入您的问题 (Question) 和数据库模式 (Schema)。点击下方示例尝试。",
|
| 360 |
+
examples=examples # <--- 添加示例
|
| 361 |
+
)
|
| 362 |
+
# ---------------------
|
| 363 |
|
| 364 |
+
print("启动 Gradio 界面...")
|
| 365 |
+
# 在 Hugging Face 或 ModelScope Space 中,share=True 不是必需的
|
| 366 |
iface.launch()
|