JDVariadic commited on
Commit
ccd6516
·
1 Parent(s): 66a7d6c

change token limit of output

Browse files
Files changed (1) hide show
  1. main.py +2 -2
main.py CHANGED
@@ -7,7 +7,7 @@ import torch
7
 
8
  app = FastAPI()
9
 
10
- async def generate_text(title, max_length=1000, top_k=50, model_dir="./model/custom-gpt2-model", tokenizer_dir="./model/custom-gpt2-tokenizer"):
11
  model = AutoModelForCausalLM.from_pretrained(model_dir)
12
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
13
  input_text = f"[TITLE] {title} [/TITLE]"
@@ -26,7 +26,7 @@ async def generate_text(title, max_length=1000, top_k=50, model_dir="./model/cus
26
 
27
  class RequestParams(BaseModel):
28
  title: str
29
- max_length: int = 1000
30
  top_k: int = 50
31
 
32
  @app.post("/generate-article")
 
7
 
8
  app = FastAPI()
9
 
10
+ async def generate_text(title, max_length=250, top_k=50, model_dir="./model/custom-gpt2-model", tokenizer_dir="./model/custom-gpt2-tokenizer"):
11
  model = AutoModelForCausalLM.from_pretrained(model_dir)
12
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
13
  input_text = f"[TITLE] {title} [/TITLE]"
 
26
 
27
  class RequestParams(BaseModel):
28
  title: str
29
+ max_length: int = 250
30
  top_k: int = 50
31
 
32
  @app.post("/generate-article")