JDVariadic committed on
Commit
16fe227
·
1 Parent(s): 467a232

set max_length and top_k limits

Browse files
Files changed (1) hide show
  1. main.py +8 -5
main.py CHANGED
@@ -6,10 +6,13 @@ import torch
6
  #Credits to https://www.kaggle.com/datasets/fabiochiusano/medium-articles for the dataset
7
 
8
  app = FastAPI()
9
-
10
- async def generate_text(title, max_length=400, top_k=50, model_dir="./model/custom-gpt2-model", tokenizer_dir="./model/custom-gpt2-tokenizer"):
11
- if max_length > 400:
 
12
  return {"error": "Limit for max_length is 400 tokens."}
 
 
13
  model = AutoModelForCausalLM.from_pretrained(model_dir)
14
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
15
  input_text = f"[TITLE] {title} [/TITLE]"
@@ -27,8 +30,8 @@ async def generate_text(title, max_length=400, top_k=50, model_dir="./model/cust
27
 
28
  class RequestParams(BaseModel):
29
  title: str
30
- max_length: int = 400
31
- top_k: int = 50
32
 
33
  @app.post("/generate-article")
34
  async def handle_request(request: RequestParams):
 
6
  #Credits to https://www.kaggle.com/datasets/fabiochiusano/medium-articles for the dataset
7
 
8
  app = FastAPI()
9
+ TOKEN_LIMIT = 400
10
+ TOP_K_LIMIT = 50
11
+ async def generate_text(title, max_length=TOKEN_LIMIT, top_k=TOP_K_LIMIT, model_dir="./model/custom-gpt2-model", tokenizer_dir="./model/custom-gpt2-tokenizer"):
12
+ if max_length > TOKEN_LIMIT:
13
  return {"error": "Limit for max_length is 400 tokens."}
14
+ if top_k > TOP_K_LIMIT:
15
+ return {"error": "Limit for top_k is 50."}
16
  model = AutoModelForCausalLM.from_pretrained(model_dir)
17
  tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
18
  input_text = f"[TITLE] {title} [/TITLE]"
 
30
 
31
  class RequestParams(BaseModel):
32
  title: str
33
+ max_length: int = TOKEN_LIMIT
34
+ top_k: int = TOP_K_LIMIT
35
 
36
  @app.post("/generate-article")
37
  async def handle_request(request: RequestParams):