zhu-mingye commited on
Commit
769cc52
·
1 Parent(s): 7d46282

Switch to CodeT5+ for better transformers compatibility

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -2,10 +2,10 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
 
5
- # 加载 CodeT5 模型
6
- model_name = "Salesforce/codet5-base"
7
- tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
8
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
9
 
10
  def generate_code(prompt: str, max_length: int = 128) -> str:
11
  """代码生成/补全"""
@@ -36,8 +36,8 @@ demo = gr.Interface(
36
  gr.Slider(32, 512, value=128, step=32, label="Max Length")
37
  ],
38
  outputs=gr.Textbox(label="Generated Code", lines=10),
39
- title="CodeT5 Code Generation",
40
- description="基于 Salesforce CodeT5 的代码生成模型。支持代码补全、代码生成等任务。",
41
  examples=[
42
  ["def fibonacci(n):", 128],
43
  ["# Python function to calculate factorial", 128],
 
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
 
5
+ # 加载 CodeT5+ 模型(兼容性更好)
6
+ model_name = "Salesforce/codet5p-220m"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
8
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
9
 
10
  def generate_code(prompt: str, max_length: int = 128) -> str:
11
  """代码生成/补全"""
 
36
  gr.Slider(32, 512, value=128, step=32, label="Max Length")
37
  ],
38
  outputs=gr.Textbox(label="Generated Code", lines=10),
39
+ title="CodeT5+ Code Generation",
40
+ description="基于 Salesforce CodeT5+ (220M) 的代码生成模型。支持代码补全、代码生成等任务。",
41
  examples=[
42
  ["def fibonacci(n):", 128],
43
  ["# Python function to calculate factorial", 128],