psychologyphd commited on
Commit
dfe8fbc
·
verified ·
1 Parent(s): 876125b

Delete how_to_use.ipynb

Browse files
Files changed (1) hide show
  1. how_to_use.ipynb +0 -244
how_to_use.ipynb DELETED
@@ -1,244 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 1,
6
- "id": "95945cf4",
7
- "metadata": {},
8
- "outputs": [],
9
- "source": [
10
- "# Cell 1: 安装依赖 (如果还没装)\n",
11
- "# !pip install -q transformers torch accelerate"
12
- ]
13
- },
14
- {
15
- "cell_type": "code",
16
- "execution_count": 2,
17
- "id": "f19c9969",
18
- "metadata": {},
19
- "outputs": [
20
- {
21
- "name": "stdout",
22
- "output_type": "stream",
23
- "text": [
24
- "🚀 正在使用设备: mps\n",
25
- "⏳ 正在加载完整模型: psychologyphd/CodeLlama-7b-Text-to-SQL-mps-finetuned ...\n"
26
- ]
27
- },
28
- {
29
- "data": {
30
- "application/vnd.jupyter.widget-view+json": {
31
- "model_id": "2335b0f37e43423b85663aec1af82f0c",
32
- "version_major": 2,
33
- "version_minor": 0
34
- },
35
- "text/plain": [
36
- "config.json: 0%| | 0.00/726 [00:00<?, ?B/s]"
37
- ]
38
- },
39
- "metadata": {},
40
- "output_type": "display_data"
41
- },
42
- {
43
- "data": {
44
- "application/vnd.jupyter.widget-view+json": {
45
- "model_id": "de531a01c38e4f2f9220814189e2110e",
46
- "version_major": 2,
47
- "version_minor": 0
48
- },
49
- "text/plain": [
50
- "model.safetensors.index.json: 0.00B [00:00, ?B/s]"
51
- ]
52
- },
53
- "metadata": {},
54
- "output_type": "display_data"
55
- },
56
- {
57
- "data": {
58
- "application/vnd.jupyter.widget-view+json": {
59
- "model_id": "e1458e002c884af6bc17dc101b779880",
60
- "version_major": 2,
61
- "version_minor": 0
62
- },
63
- "text/plain": [
64
- "Downloading shards: 0%| | 0/3 [00:00<?, ?it/s]"
65
- ]
66
- },
67
- "metadata": {},
68
- "output_type": "display_data"
69
- },
70
- {
71
- "data": {
72
- "application/vnd.jupyter.widget-view+json": {
73
- "model_id": "aee236b369eb400ca5dc4604f0817926",
74
- "version_major": 2,
75
- "version_minor": 0
76
- },
77
- "text/plain": [
78
- "model-00001-of-00003.safetensors: 0%| | 0.00/4.94G [00:00<?, ?B/s]"
79
- ]
80
- },
81
- "metadata": {},
82
- "output_type": "display_data"
83
- },
84
- {
85
- "ename": "KeyboardInterrupt",
86
- "evalue": "",
87
- "output_type": "error",
88
- "traceback": [
89
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
90
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
91
- "Cell \u001b[0;32mIn[2], line 22\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m⏳ 正在加载完整模型: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m ...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m# 3. 直接加载完整模型 (不需要 PeftModel,因为你已经 Merge 过了)\u001b[39;00m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;66;03m# 如果网络不好下载断了,可以加 force_download=True 重试,或者手动下载后填本地路径\u001b[39;00m\n\u001b[0;32m---> 22\u001b[0m model \u001b[38;5;241m=\u001b[39m AutoModelForCausalLM\u001b[38;5;241m.\u001b[39mfrom_pretrained(\n\u001b[1;32m 23\u001b[0m model_id,\n\u001b[1;32m 24\u001b[0m device_map\u001b[38;5;241m=\u001b[39m{\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m: device}, \n\u001b[1;32m 25\u001b[0m torch_dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat16 \u001b[38;5;28;01mif\u001b[39;00m device \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mfloat32,\n\u001b[1;32m 26\u001b[0m )\n\u001b[1;32m 28\u001b[0m \u001b[38;5;66;03m# 4. 加载 Tokenizer\u001b[39;00m\n\u001b[1;32m 29\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m AutoTokenizer\u001b[38;5;241m.\u001b[39mfrom_pretrained(model_id)\n",
92
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py:564\u001b[0m, in \u001b[0;36m_BaseAutoModelClass.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 562\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(config) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_model_mapping\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[1;32m 563\u001b[0m model_class \u001b[38;5;241m=\u001b[39m _get_model_class(config, \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_model_mapping)\n\u001b[0;32m--> 564\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m model_class\u001b[38;5;241m.\u001b[39mfrom_pretrained(\n\u001b[1;32m 565\u001b[0m pretrained_model_name_or_path, \u001b[38;5;241m*\u001b[39mmodel_args, config\u001b[38;5;241m=\u001b[39mconfig, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mhub_kwargs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m 566\u001b[0m )\n\u001b[1;32m 567\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 568\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnrecognized configuration class \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconfig\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m for this kind of AutoModel: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 569\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModel type should be one of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(c\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mfor\u001b[39;00m\u001b[38;5;250m \u001b[39mc\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01min\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_model_mapping\u001b[38;5;241m.\u001b[39mkeys())\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 570\u001b[0m )\n",
93
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/site-packages/transformers/modeling_utils.py:3944\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 3941\u001b[0m \u001b[38;5;66;03m# We'll need to download and cache each checkpoint shard if the checkpoint is sharded.\u001b[39;00m\n\u001b[1;32m 3942\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_sharded:\n\u001b[1;32m 3943\u001b[0m \u001b[38;5;66;03m# resolved_archive_file becomes a list of files that point to the different checkpoint shards in this case.\u001b[39;00m\n\u001b[0;32m-> 3944\u001b[0m resolved_archive_file, sharded_metadata \u001b[38;5;241m=\u001b[39m get_checkpoint_shard_files(\n\u001b[1;32m 3945\u001b[0m pretrained_model_name_or_path,\n\u001b[1;32m 3946\u001b[0m resolved_archive_file,\n\u001b[1;32m 3947\u001b[0m cache_dir\u001b[38;5;241m=\u001b[39mcache_dir,\n\u001b[1;32m 3948\u001b[0m force_download\u001b[38;5;241m=\u001b[39mforce_download,\n\u001b[1;32m 3949\u001b[0m proxies\u001b[38;5;241m=\u001b[39mproxies,\n\u001b[1;32m 3950\u001b[0m resume_download\u001b[38;5;241m=\u001b[39mresume_download,\n\u001b[1;32m 3951\u001b[0m local_files_only\u001b[38;5;241m=\u001b[39mlocal_files_only,\n\u001b[1;32m 3952\u001b[0m token\u001b[38;5;241m=\u001b[39mtoken,\n\u001b[1;32m 3953\u001b[0m user_agent\u001b[38;5;241m=\u001b[39muser_agent,\n\u001b[1;32m 3954\u001b[0m revision\u001b[38;5;241m=\u001b[39mrevision,\n\u001b[1;32m 3955\u001b[0m subfolder\u001b[38;5;241m=\u001b[39msubfolder,\n\u001b[1;32m 3956\u001b[0m _commit_hash\u001b[38;5;241m=\u001b[39mcommit_hash,\n\u001b[1;32m 3957\u001b[0m )\n\u001b[1;32m 3959\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 3960\u001b[0m is_safetensors_available()\n\u001b[1;32m 3961\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(resolved_archive_file, \u001b[38;5;28mstr\u001b[39m)\n\u001b[1;32m 3962\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m resolved_archive_file\u001b[38;5;241m.\u001b[39mendswith(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.safetensors\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 3963\u001b[0m ):\n\u001b[1;32m 3964\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m safe_open(resolved_archive_file, framework\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n",
94
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/site-packages/transformers/utils/hub.py:1098\u001b[0m, in \u001b[0;36mget_checkpoint_shard_files\u001b[0;34m(pretrained_model_name_or_path, index_filename, cache_dir, force_download, proxies, resume_download, local_files_only, token, user_agent, revision, subfolder, _commit_hash, **deprecated_kwargs)\u001b[0m\n\u001b[1;32m 1095\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m shard_filename \u001b[38;5;129;01min\u001b[39;00m tqdm(shard_filenames, desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDownloading shards\u001b[39m\u001b[38;5;124m\"\u001b[39m, disable\u001b[38;5;241m=\u001b[39m\u001b[38;5;129;01mnot\u001b[39;00m show_progress_bar):\n\u001b[1;32m 1096\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1097\u001b[0m \u001b[38;5;66;03m# Load from URL\u001b[39;00m\n\u001b[0;32m-> 1098\u001b[0m cached_filename \u001b[38;5;241m=\u001b[39m cached_file(\n\u001b[1;32m 1099\u001b[0m pretrained_model_name_or_path,\n\u001b[1;32m 1100\u001b[0m shard_filename,\n\u001b[1;32m 1101\u001b[0m cache_dir\u001b[38;5;241m=\u001b[39mcache_dir,\n\u001b[1;32m 1102\u001b[0m force_download\u001b[38;5;241m=\u001b[39mforce_download,\n\u001b[1;32m 1103\u001b[0m proxies\u001b[38;5;241m=\u001b[39mproxies,\n\u001b[1;32m 1104\u001b[0m resume_download\u001b[38;5;241m=\u001b[39mresume_download,\n\u001b[1;32m 1105\u001b[0m local_files_only\u001b[38;5;241m=\u001b[39mlocal_files_only,\n\u001b[1;32m 1106\u001b[0m token\u001b[38;5;241m=\u001b[39mtoken,\n\u001b[1;32m 1107\u001b[0m user_agent\u001b[38;5;241m=\u001b[39muser_agent,\n\u001b[1;32m 1108\u001b[0m revision\u001b[38;5;241m=\u001b[39mrevision,\n\u001b[1;32m 1109\u001b[0m subfolder\u001b[38;5;241m=\u001b[39msubfolder,\n\u001b[1;32m 1110\u001b[0m _commit_hash\u001b[38;5;241m=\u001b[39m_commit_hash,\n\u001b[1;32m 1111\u001b[0m )\n\u001b[1;32m 1112\u001b[0m \u001b[38;5;66;03m# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so\u001b[39;00m\n\u001b[1;32m 1113\u001b[0m \u001b[38;5;66;03m# we don't have to catch them here.\u001b[39;00m\n\u001b[1;32m 1114\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m EntryNotFoundError:\n",
95
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/site-packages/transformers/utils/hub.py:403\u001b[0m, in \u001b[0;36mcached_file\u001b[0;34m(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_gated_repo, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)\u001b[0m\n\u001b[1;32m 400\u001b[0m user_agent \u001b[38;5;241m=\u001b[39m http_user_agent(user_agent)\n\u001b[1;32m 401\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 402\u001b[0m \u001b[38;5;66;03m# Load from URL or cache if already cached\u001b[39;00m\n\u001b[0;32m--> 403\u001b[0m resolved_file \u001b[38;5;241m=\u001b[39m hf_hub_download(\n\u001b[1;32m 404\u001b[0m path_or_repo_id,\n\u001b[1;32m 405\u001b[0m filename,\n\u001b[1;32m 406\u001b[0m subfolder\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(subfolder) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m subfolder,\n\u001b[1;32m 407\u001b[0m repo_type\u001b[38;5;241m=\u001b[39mrepo_type,\n\u001b[1;32m 408\u001b[0m revision\u001b[38;5;241m=\u001b[39mrevision,\n\u001b[1;32m 409\u001b[0m cache_dir\u001b[38;5;241m=\u001b[39mcache_dir,\n\u001b[1;32m 410\u001b[0m user_agent\u001b[38;5;241m=\u001b[39muser_agent,\n\u001b[1;32m 411\u001b[0m force_download\u001b[38;5;241m=\u001b[39mforce_download,\n\u001b[1;32m 412\u001b[0m proxies\u001b[38;5;241m=\u001b[39mproxies,\n\u001b[1;32m 413\u001b[0m resume_download\u001b[38;5;241m=\u001b[39mresume_download,\n\u001b[1;32m 414\u001b[0m token\u001b[38;5;241m=\u001b[39mtoken,\n\u001b[1;32m 415\u001b[0m local_files_only\u001b[38;5;241m=\u001b[39mlocal_files_only,\n\u001b[1;32m 416\u001b[0m )\n\u001b[1;32m 417\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m GatedRepoError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 418\u001b[0m resolved_file \u001b[38;5;241m=\u001b[39m _get_cache_file_to_return(path_or_repo_id, full_filename, cache_dir, revision)\n",
96
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/site-packages/huggingface_hub-0.27.1-py3.8.egg/huggingface_hub/utils/_validators.py:114\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_use_auth_token:\n\u001b[1;32m 112\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[38;5;241m=\u001b[39mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, has_token\u001b[38;5;241m=\u001b[39mhas_token, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[0;32m--> 114\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fn(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
97
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/site-packages/huggingface_hub-0.27.1-py3.8.egg/huggingface_hub/file_download.py:860\u001b[0m, in \u001b[0;36mhf_hub_download\u001b[0;34m(repo_id, filename, subfolder, repo_type, revision, library_name, library_version, cache_dir, local_dir, user_agent, force_download, proxies, etag_timeout, token, local_files_only, headers, endpoint, resume_download, force_filename, local_dir_use_symlinks)\u001b[0m\n\u001b[1;32m 840\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _hf_hub_download_to_local_dir(\n\u001b[1;32m 841\u001b[0m \u001b[38;5;66;03m# Destination\u001b[39;00m\n\u001b[1;32m 842\u001b[0m local_dir\u001b[38;5;241m=\u001b[39mlocal_dir,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 857\u001b[0m local_files_only\u001b[38;5;241m=\u001b[39mlocal_files_only,\n\u001b[1;32m 858\u001b[0m )\n\u001b[1;32m 859\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 860\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _hf_hub_download_to_cache_dir(\n\u001b[1;32m 861\u001b[0m \u001b[38;5;66;03m# Destination\u001b[39;00m\n\u001b[1;32m 862\u001b[0m cache_dir\u001b[38;5;241m=\u001b[39mcache_dir,\n\u001b[1;32m 863\u001b[0m \u001b[38;5;66;03m# File info\u001b[39;00m\n\u001b[1;32m 864\u001b[0m repo_id\u001b[38;5;241m=\u001b[39mrepo_id,\n\u001b[1;32m 865\u001b[0m filename\u001b[38;5;241m=\u001b[39mfilename,\n\u001b[1;32m 866\u001b[0m repo_type\u001b[38;5;241m=\u001b[39mrepo_type,\n\u001b[1;32m 867\u001b[0m revision\u001b[38;5;241m=\u001b[39mrevision,\n\u001b[1;32m 868\u001b[0m \u001b[38;5;66;03m# HTTP info\u001b[39;00m\n\u001b[1;32m 869\u001b[0m endpoint\u001b[38;5;241m=\u001b[39mendpoint,\n\u001b[1;32m 870\u001b[0m etag_timeout\u001b[38;5;241m=\u001b[39metag_timeout,\n\u001b[1;32m 871\u001b[0m headers\u001b[38;5;241m=\u001b[39mhf_headers,\n\u001b[1;32m 872\u001b[0m proxies\u001b[38;5;241m=\u001b[39mproxies,\n\u001b[1;32m 873\u001b[0m token\u001b[38;5;241m=\u001b[39mtoken,\n\u001b[1;32m 874\u001b[0m \u001b[38;5;66;03m# Additional options\u001b[39;00m\n\u001b[1;32m 875\u001b[0m local_files_only\u001b[38;5;241m=\u001b[39mlocal_files_only,\n\u001b[1;32m 876\u001b[0m force_download\u001b[38;5;241m=\u001b[39mforce_download,\n\u001b[1;32m 877\u001b[0m )\n",
98
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/site-packages/huggingface_hub-0.27.1-py3.8.egg/huggingface_hub/file_download.py:1009\u001b[0m, in \u001b[0;36m_hf_hub_download_to_cache_dir\u001b[0;34m(cache_dir, repo_id, filename, repo_type, revision, endpoint, etag_timeout, headers, proxies, token, local_files_only, force_download)\u001b[0m\n\u001b[1;32m 1007\u001b[0m Path(lock_path)\u001b[38;5;241m.\u001b[39mparent\u001b[38;5;241m.\u001b[39mmkdir(parents\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, exist_ok\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 1008\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m WeakFileLock(lock_path):\n\u001b[0;32m-> 1009\u001b[0m _download_to_tmp_and_move(\n\u001b[1;32m 1010\u001b[0m incomplete_path\u001b[38;5;241m=\u001b[39mPath(blob_path \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.incomplete\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 1011\u001b[0m destination_path\u001b[38;5;241m=\u001b[39mPath(blob_path),\n\u001b[1;32m 1012\u001b[0m url_to_download\u001b[38;5;241m=\u001b[39murl_to_download,\n\u001b[1;32m 1013\u001b[0m proxies\u001b[38;5;241m=\u001b[39mproxies,\n\u001b[1;32m 1014\u001b[0m headers\u001b[38;5;241m=\u001b[39mheaders,\n\u001b[1;32m 1015\u001b[0m expected_size\u001b[38;5;241m=\u001b[39mexpected_size,\n\u001b[1;32m 1016\u001b[0m filename\u001b[38;5;241m=\u001b[39mfilename,\n\u001b[1;32m 1017\u001b[0m force_download\u001b[38;5;241m=\u001b[39mforce_download,\n\u001b[1;32m 1018\u001b[0m )\n\u001b[1;32m 1019\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mexists(pointer_path):\n\u001b[1;32m 1020\u001b[0m _create_symlink(blob_path, pointer_path, new_blob\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
99
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/site-packages/huggingface_hub-0.27.1-py3.8.egg/huggingface_hub/file_download.py:1543\u001b[0m, in \u001b[0;36m_download_to_tmp_and_move\u001b[0;34m(incomplete_path, destination_path, url_to_download, proxies, headers, expected_size, filename, force_download)\u001b[0m\n\u001b[1;32m 1540\u001b[0m _check_disk_space(expected_size, incomplete_path\u001b[38;5;241m.\u001b[39mparent)\n\u001b[1;32m 1541\u001b[0m _check_disk_space(expected_size, destination_path\u001b[38;5;241m.\u001b[39mparent)\n\u001b[0;32m-> 1543\u001b[0m http_get(\n\u001b[1;32m 1544\u001b[0m url_to_download,\n\u001b[1;32m 1545\u001b[0m f,\n\u001b[1;32m 1546\u001b[0m proxies\u001b[38;5;241m=\u001b[39mproxies,\n\u001b[1;32m 1547\u001b[0m resume_size\u001b[38;5;241m=\u001b[39mresume_size,\n\u001b[1;32m 1548\u001b[0m headers\u001b[38;5;241m=\u001b[39mheaders,\n\u001b[1;32m 1549\u001b[0m expected_size\u001b[38;5;241m=\u001b[39mexpected_size,\n\u001b[1;32m 1550\u001b[0m )\n\u001b[1;32m 1552\u001b[0m logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDownload complete. Moving file to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdestination_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1553\u001b[0m _chmod_and_move(incomplete_path, destination_path)\n",
100
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/site-packages/huggingface_hub-0.27.1-py3.8.egg/huggingface_hub/file_download.py:452\u001b[0m, in \u001b[0;36mhttp_get\u001b[0;34m(url, temp_file, proxies, resume_size, headers, expected_size, displayed_filename, _nb_retries, _tqdm_bar)\u001b[0m\n\u001b[1;32m 450\u001b[0m new_resume_size \u001b[38;5;241m=\u001b[39m resume_size\n\u001b[1;32m 451\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 452\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m chunk \u001b[38;5;129;01min\u001b[39;00m r\u001b[38;5;241m.\u001b[39miter_content(chunk_size\u001b[38;5;241m=\u001b[39mconstants\u001b[38;5;241m.\u001b[39mDOWNLOAD_CHUNK_SIZE):\n\u001b[1;32m 453\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m chunk: \u001b[38;5;66;03m# filter out keep-alive new chunks\u001b[39;00m\n\u001b[1;32m 454\u001b[0m progress\u001b[38;5;241m.\u001b[39mupdate(\u001b[38;5;28mlen\u001b[39m(chunk))\n",
101
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/site-packages/requests/models.py:820\u001b[0m, in \u001b[0;36mResponse.iter_content.<locals>.generate\u001b[0;34m()\u001b[0m\n\u001b[1;32m 818\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mraw, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstream\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 819\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 820\u001b[0m \u001b[38;5;28;01myield from\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mraw\u001b[38;5;241m.\u001b[39mstream(chunk_size, decode_content\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 821\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m ProtocolError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 822\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m ChunkedEncodingError(e)\n",
102
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/site-packages/urllib3/response.py:1066\u001b[0m, in \u001b[0;36mHTTPResponse.stream\u001b[0;34m(self, amt, decode_content)\u001b[0m\n\u001b[1;32m 1064\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1065\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_fp_closed(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_fp) \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_decoded_buffer) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m-> 1066\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mread(amt\u001b[38;5;241m=\u001b[39mamt, decode_content\u001b[38;5;241m=\u001b[39mdecode_content)\n\u001b[1;32m 1068\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m data:\n\u001b[1;32m 1069\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m data\n",
103
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/site-packages/urllib3/response.py:955\u001b[0m, in \u001b[0;36mHTTPResponse.read\u001b[0;34m(self, amt, decode_content, cache_content)\u001b[0m\n\u001b[1;32m 952\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_decoded_buffer) \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m amt:\n\u001b[1;32m 953\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_decoded_buffer\u001b[38;5;241m.\u001b[39mget(amt)\n\u001b[0;32m--> 955\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_raw_read(amt)\n\u001b[1;32m 957\u001b[0m flush_decoder \u001b[38;5;241m=\u001b[39m amt \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m (amt \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m data)\n\u001b[1;32m 959\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m data \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_decoded_buffer) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
104
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/site-packages/urllib3/response.py:879\u001b[0m, in \u001b[0;36mHTTPResponse._raw_read\u001b[0;34m(self, amt, read1)\u001b[0m\n\u001b[1;32m 876\u001b[0m fp_closed \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_fp, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mclosed\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 878\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_error_catcher():\n\u001b[0;32m--> 879\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_fp_read(amt, read1\u001b[38;5;241m=\u001b[39mread1) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m fp_closed \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124mb\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 880\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m amt \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m amt \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m data:\n\u001b[1;32m 881\u001b[0m \u001b[38;5;66;03m# Platform-specific: Buggy versions of Python.\u001b[39;00m\n\u001b[1;32m 882\u001b[0m \u001b[38;5;66;03m# Close the connection when no data is returned\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 887\u001b[0m \u001b[38;5;66;03m# not properly close the connection in all cases. There is\u001b[39;00m\n\u001b[1;32m 888\u001b[0m \u001b[38;5;66;03m# no harm in redundantly calling close.\u001b[39;00m\n\u001b[1;32m 889\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_fp\u001b[38;5;241m.\u001b[39mclose()\n",
105
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/site-packages/urllib3/response.py:862\u001b[0m, in \u001b[0;36mHTTPResponse._fp_read\u001b[0;34m(self, amt, read1)\u001b[0m\n\u001b[1;32m 859\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_fp\u001b[38;5;241m.\u001b[39mread1(amt) \u001b[38;5;28;01mif\u001b[39;00m amt \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_fp\u001b[38;5;241m.\u001b[39mread1()\n\u001b[1;32m 860\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 861\u001b[0m \u001b[38;5;66;03m# StringIO doesn't like amt=None\u001b[39;00m\n\u001b[0;32m--> 862\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_fp\u001b[38;5;241m.\u001b[39mread(amt) \u001b[38;5;28;01mif\u001b[39;00m amt \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_fp\u001b[38;5;241m.\u001b[39mread()\n",
106
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/http/client.py:479\u001b[0m, in \u001b[0;36mHTTPResponse.read\u001b[0;34m(self, amt)\u001b[0m\n\u001b[1;32m 476\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlength \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m amt \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlength:\n\u001b[1;32m 477\u001b[0m \u001b[38;5;66;03m# clip the read to the \"end of response\"\u001b[39;00m\n\u001b[1;32m 478\u001b[0m amt \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlength\n\u001b[0;32m--> 479\u001b[0m s \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfp\u001b[38;5;241m.\u001b[39mread(amt)\n\u001b[1;32m 480\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m s \u001b[38;5;129;01mand\u001b[39;00m amt:\n\u001b[1;32m 481\u001b[0m \u001b[38;5;66;03m# Ideally, we would raise IncompleteRead if the content-length\u001b[39;00m\n\u001b[1;32m 482\u001b[0m \u001b[38;5;66;03m# wasn't satisfied, but it might break compatibility.\u001b[39;00m\n\u001b[1;32m 483\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_close_conn()\n",
107
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/socket.py:707\u001b[0m, in \u001b[0;36mSocketIO.readinto\u001b[0;34m(self, b)\u001b[0m\n\u001b[1;32m 705\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[1;32m 706\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 707\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sock\u001b[38;5;241m.\u001b[39mrecv_into(b)\n\u001b[1;32m 708\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m timeout:\n\u001b[1;32m 709\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_timeout_occurred \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n",
108
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/ssl.py:1252\u001b[0m, in \u001b[0;36mSSLSocket.recv_into\u001b[0;34m(self, buffer, nbytes, flags)\u001b[0m\n\u001b[1;32m 1248\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m flags \u001b[38;5;241m!=\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 1249\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 1250\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnon-zero flags not allowed in calls to recv_into() on \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m\n\u001b[1;32m 1251\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m)\n\u001b[0;32m-> 1252\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mread(nbytes, buffer)\n\u001b[1;32m 1253\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1254\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39mrecv_into(buffer, nbytes, flags)\n",
109
- "File \u001b[0;32m~/miniconda3/envs/py312/lib/python3.12/ssl.py:1104\u001b[0m, in \u001b[0;36mSSLSocket.read\u001b[0;34m(self, len, buffer)\u001b[0m\n\u001b[1;32m 1102\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1103\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m buffer \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m-> 1104\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sslobj\u001b[38;5;241m.\u001b[39mread(\u001b[38;5;28mlen\u001b[39m, buffer)\n\u001b[1;32m 1105\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sslobj\u001b[38;5;241m.\u001b[39mread(\u001b[38;5;28mlen\u001b[39m)\n",
110
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
111
- ]
112
- }
113
- ],
114
- "source": [
115
- "# Cell 2: 加载你的模型 (Merged Model)\n",
116
- "\n",
117
- "import torch\n",
118
- "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
119
- "\n",
120
- "# 1. 自动检测设备\n",
121
- "if torch.cuda.is_available():\n",
122
- " device = \"cuda\"\n",
123
- "elif torch.backends.mps.is_available():\n",
124
- " device = \"mps\"\n",
125
- "else:\n",
126
- " device = \"cpu\"\n",
127
- "print(f\"🚀 正在使用设备: {device}\")\n",
128
- "\n",
129
- "# 2. 你的 Hugging Face 模型 ID (已更新)\n",
130
- "model_id = \"psychologyphd/CodeLlama-7b-Text-to-SQL-mps-finetuned\"\n",
131
- "\n",
132
- "print(f\"⏳ 正在加载完整模型: {model_id} ...\")\n",
133
- "\n",
134
- "# 3. 直接加载完整模型 (不需要 PeftModel,因为你已经 Merge 过了)\n",
135
- "# 如果网络不好下载断了,可以加 force_download=True 重试,或者手动下载后填本地路径\n",
136
- "model = AutoModelForCausalLM.from_pretrained(\n",
137
- " model_id,\n",
138
- " device_map={\"\": device}, \n",
139
- " torch_dtype=torch.float16 if device != \"cpu\" else torch.float32,\n",
140
- ")\n",
141
- "\n",
142
- "# 4. 加载 Tokenizer\n",
143
- "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
144
- "tokenizer.padding_side = 'right'\n",
145
- "if tokenizer.pad_token is None:\n",
146
- " tokenizer.pad_token = tokenizer.eos_token\n",
147
- "\n",
148
- "# =====================================================\n",
149
- "# 🛠️ 必须手动注入 Chat Template 🛠️\n",
150
- "# 因为 Base Model 没有自带这个,不加这个效果会很差\n",
151
- "# =====================================================\n",
152
- "chat_template = \"\"\"{% for message in messages %}\n",
153
- "{%- if message['role'] == 'system' -%}\n",
154
- "### System: {{ message['content'] }}\n",
155
- "{%- elif message['role'] == 'user' -%}\n",
156
- "### Human: {{ message['content'] }}\n",
157
- "{%- else -%}\n",
158
- "### Assistant: {{ message['content'] }}\n",
159
- "{%- endif %}\n",
160
- "{% endfor %}\"\"\"\n",
161
- "\n",
162
- "tokenizer.chat_template = chat_template\n",
163
- "# =====================================================\n",
164
- "\n",
165
- "print(\"✅ 模型加载完成!\")"
166
- ]
167
- },
168
- {
169
- "cell_type": "code",
170
- "execution_count": null,
171
- "id": "0ab4bec1",
172
- "metadata": {},
173
- "outputs": [],
174
- "source": [
175
- "# Cell 3: 定义推理函数\n",
176
- "\n",
177
- "def generate_sql(question, schema):\n",
178
- " \n",
179
- " # 构造符合训练格式的输入\n",
180
- " messages = [\n",
181
- " {\n",
182
- " \"role\": \"system\", \n",
183
- " \"content\": f\"You are an text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.\\nSCHEMA:\\n{schema}\"\n",
184
- " },\n",
185
- " {\n",
186
- " \"role\": \"user\", \n",
187
- " \"content\": question\n",
188
- " }\n",
189
- " ]\n",
190
- " \n",
191
- " # 应用模板\n",
192
- " prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
193
- " \n",
194
- " print(f\"\\n🔍 Prompt:\\n{prompt.strip()}\\n{'='*20}\")\n",
195
- " \n",
196
- " inputs = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
197
- " \n",
198
- " with torch.no_grad():\n",
199
- " outputs = model.generate(\n",
200
- " **inputs,\n",
201
- " max_new_tokens=100, \n",
202
- " do_sample=False, # 贪婪搜索 (Greedy Search)\n",
203
- " pad_token_id=tokenizer.pad_token_id,\n",
204
- " eos_token_id=tokenizer.eos_token_id\n",
205
- " )\n",
206
- " \n",
207
- " full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
208
- " # 提取 \"### Assistant: \" 之后的内容\n",
209
- " sql_result = full_text[len(prompt):].strip()\n",
210
- " \n",
211
- " return sql_result\n",
212
- "\n",
213
- "# === 测试 ===\n",
214
- "print(\"\\n🤖 开始测试...\")\n",
215
- "test_schema = \"CREATE TABLE employees (id INTEGER, name VARCHAR, department VARCHAR, salary INTEGER)\"\n",
216
- "test_question = \"Find the names of employees in the 'Sales' department who earn more than 50000.\"\n",
217
- "\n",
218
- "sql = generate_sql(test_question, test_schema)\n",
219
- "print(f\"\\n💡 生成的 SQL:\\n{sql}\")"
220
- ]
221
- }
222
- ],
223
- "metadata": {
224
- "kernelspec": {
225
- "display_name": "Python 3.12.2",
226
- "language": "python",
227
- "name": "py312"
228
- },
229
- "language_info": {
230
- "codemirror_mode": {
231
- "name": "ipython",
232
- "version": 3
233
- },
234
- "file_extension": ".py",
235
- "mimetype": "text/x-python",
236
- "name": "python",
237
- "nbconvert_exporter": "python",
238
- "pygments_lexer": "ipython3",
239
- "version": "3.12.2"
240
- }
241
- },
242
- "nbformat": 4,
243
- "nbformat_minor": 5
244
- }