JDVariadic committed on
Commit
467a232
·
1 Parent(s): b819dd2

add token limit

Browse files
Files changed (1) hide show
  1. main.py +4 -2
main.py CHANGED
@@ -7,7 +7,9 @@ import torch
7
 
8
  app = FastAPI()
9
 
10
- async def generate_text(title, max_length=1000, top_k=50, model_dir="./model/custom-gpt2-model", tokenizer_dir="./model/custom-gpt2-tokenizer"):
 
 
11
  model = AutoModelForCausalLM.from_pretrained(model_dir)
12
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
13
  input_text = f"[TITLE] {title} [/TITLE]"
@@ -25,7 +27,7 @@ async def generate_text(title, max_length=1000, top_k=50, model_dir="./model/cus
25
 
26
  class RequestParams(BaseModel):
27
  title: str
28
- max_length: int = 1000
29
  top_k: int = 50
30
 
31
  @app.post("/generate-article")
 
7
 
8
  app = FastAPI()
9
 
10
+ async def generate_text(title, max_length=400, top_k=50, model_dir="./model/custom-gpt2-model", tokenizer_dir="./model/custom-gpt2-tokenizer"):
11
+ if max_length > 400:
12
+ return {"error": "Limit for max_length is 400 tokens."}
13
  model = AutoModelForCausalLM.from_pretrained(model_dir)
14
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
15
  input_text = f"[TITLE] {title} [/TITLE]"
 
27
 
28
  class RequestParams(BaseModel):
29
  title: str
30
+ max_length: int = 400
31
  top_k: int = 50
32
 
33
  @app.post("/generate-article")