zhu-mingye commited on
Commit
c2293c5
·
1 Parent(s): a073e71

Switch to CodeT5+ (codet5p-220m) model

Browse files
Files changed (1) hide show
  1. app.py +8 -9
app.py CHANGED
@@ -1,11 +1,11 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
4
 
5
- # 加载 CodeGen 模型(GPT-2架构,兼容性好)
6
- model_name = "Salesforce/codegen-350M-mono"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForCausalLM.from_pretrained(model_name)
9
 
10
  def generate_code(prompt: str, max_length: int = 128) -> str:
11
  """代码生成/补全"""
@@ -19,8 +19,7 @@ def generate_code(prompt: str, max_length: int = 128) -> str:
19
  **inputs,
20
  max_length=max_length,
21
  num_beams=4,
22
- early_stopping=True,
23
- pad_token_id=tokenizer.eos_token_id
24
  )
25
 
26
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -37,8 +36,8 @@ demo = gr.Interface(
37
  gr.Slider(32, 512, value=128, step=32, label="Max Length")
38
  ],
39
  outputs=gr.Textbox(label="Generated Code", lines=10),
40
- title="CodeGen Code Generation",
41
- description="基于 Salesforce CodeGen (350M) 的代码生成模型。支持代码补全、代码生成等任务。",
42
  examples=[
43
  ["def fibonacci(n):", 128],
44
  ["# Python function to calculate factorial", 128],
@@ -47,4 +46,4 @@ demo = gr.Interface(
47
  )
48
 
49
  if __name__ == "__main__":
50
- demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
 
5
+ # 加载 CodeT5+ 模型
6
+ model_name = "Salesforce/codet5p-220m"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
9
 
10
  def generate_code(prompt: str, max_length: int = 128) -> str:
11
  """代码生成/补全"""
 
19
  **inputs,
20
  max_length=max_length,
21
  num_beams=4,
22
+ early_stopping=True
 
23
  )
24
 
25
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
36
  gr.Slider(32, 512, value=128, step=32, label="Max Length")
37
  ],
38
  outputs=gr.Textbox(label="Generated Code", lines=10),
39
+ title="CodeT5+ Code Generation",
40
+ description="基于 Salesforce CodeT5+ (220M) 的代码生成模型。支持代码补全、代码生成等任务。",
41
  examples=[
42
  ["def fibonacci(n):", 128],
43
  ["# Python function to calculate factorial", 128],
 
46
  )
47
 
48
  if __name__ == "__main__":
49
+ demo.launch()