Sixparticle committed on
Commit
b51e859
·
1 Parent(s): a983386

Sanitize added_tokens before tokenizer load

Browse files
Files changed (2) hide show
  1. app.py +30 -6
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,15 +1,39 @@
1
  import gradio as gr
 
 
 
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
 
5
  # 加载 CodeT5+ 模型
6
  model_name = "Salesforce/codet5p-220m"
7
- try:
8
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)
9
- except TypeError:
10
- # Some tokenizer repos expose added_tokens metadata that breaks fast tokenizer init.
11
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)
12
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  def generate_code(prompt: str, max_length: int = 128) -> str:
15
  """代码生成/补全"""
 
1
  import gradio as gr
2
+ import json
3
+ import os
4
+ from huggingface_hub import snapshot_download
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
  import torch
7
 
8
  # 加载 CodeT5+ 模型
9
  model_name = "Salesforce/codet5p-220m"
10
+
11
+
12
def _tokens_in_id_order(mapping: dict) -> list:
    """Return the keys of a token -> id mapping ordered by ascending id."""
    try:
        # sorted() is stable, so tokens that share an id keep their file order.
        return sorted(mapping, key=mapping.__getitem__)
    except TypeError:
        # Ids that cannot be compared with each other: fall back to file order.
        return list(mapping)


def _normalize_added_tokens(data) -> list:
    """Convert a decoded ``added_tokens.json`` payload into a plain token list."""
    if isinstance(data, dict):
        return _tokens_in_id_order(data)
    if isinstance(data, list):
        return [str(item) for item in data]
    # Any other JSON value (string, number, null, ...) carries no token list.
    return []


def prepare_local_model(repo_id: str, local_dir: str = "./model_cache") -> str:
    """Download a model snapshot and sanitize its ``added_tokens.json``.

    The repository is downloaded into ``local_dir`` with ``snapshot_download``.
    If the snapshot contains an ``added_tokens.json`` file, its content is
    rewritten as a plain JSON list of token strings:

    * a dict is replaced by its keys, ordered by their values (the token ids)
      so the list order matches the ids recorded in the file;
    * a list has every item converted with ``str``;
    * any other JSON value is replaced by an empty list.

    The intent is to give the tokenizer loader a plain token list
    (NOTE(review): confirm the installed ``transformers`` version accepts a
    list here).

    Args:
        repo_id: Hub repository id passed to ``snapshot_download``.
        local_dir: Directory the snapshot is downloaded into.

    Returns:
        ``local_dir``, unchanged, so it can be passed on to ``from_pretrained``.

    Raises:
        json.JSONDecodeError: If ``added_tokens.json`` exists but is not valid
            JSON.
    """
    snapshot_download(repo_id=repo_id, local_dir=local_dir)

    added_tokens_file = os.path.join(local_dir, "added_tokens.json")
    if not os.path.exists(added_tokens_file):
        return local_dir

    with open(added_tokens_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    normalized = _normalize_added_tokens(data)
    if normalized == data:
        # Already a plain list of strings: avoid touching the file again.
        return local_dir

    # Write to a sibling file and swap it in, so an interrupted write can
    # never leave a truncated added_tokens.json behind.
    tmp_file = added_tokens_file + ".tmp"
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(normalized, f, ensure_ascii=False)
    os.replace(tmp_file, added_tokens_file)

    return local_dir
32
+
33
+
34
+ local_model_dir = prepare_local_model(model_name)
35
+ tokenizer = AutoTokenizer.from_pretrained(local_model_dir, use_fast=True, trust_remote_code=True)
36
+ model = AutoModelForSeq2SeqLM.from_pretrained(local_model_dir, trust_remote_code=True)
37
 
38
  def generate_code(prompt: str, max_length: int = 128) -> str:
39
  """代码生成/补全"""
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  transformers>=4.40.0
 
2
  torch>=2.0.0
3
  sentencepiece>=0.1.96
4
  accelerate>=0.20.0
 
1
  transformers>=4.40.0
2
+ huggingface_hub>=0.23.0
3
  torch>=2.0.0
4
  sentencepiece>=0.1.96
5
  accelerate>=0.20.0