Spaces:
Configuration error
Configuration error
Commit ·
16fe227
1
Parent(s): 467a232
set max_length and top_k limits
Browse files
main.py
CHANGED
|
@@ -6,10 +6,13 @@ import torch
|
|
| 6 |
#Credits to https://www.kaggle.com/datasets/fabiochiusano/medium-articles for the dataset
|
| 7 |
|
| 8 |
app = FastAPI()
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
| 12 |
return {"error": "Limit for max_length is 400 tokens."}
|
|
|
|
|
|
|
| 13 |
model = AutoModelForCausalLM.from_pretrained(model_dir)
|
| 14 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
|
| 15 |
input_text = f"[TITLE] {title} [/TITLE]"
|
|
@@ -27,8 +30,8 @@ async def generate_text(title, max_length=400, top_k=50, model_dir="./model/cust
|
|
| 27 |
|
| 28 |
class RequestParams(BaseModel):
|
| 29 |
title: str
|
| 30 |
-
max_length: int =
|
| 31 |
-
top_k: int =
|
| 32 |
|
| 33 |
@app.post("/generate-article")
|
| 34 |
async def handle_request(request: RequestParams):
|
|
|
|
| 6 |
#Credits to https://www.kaggle.com/datasets/fabiochiusano/medium-articles for the dataset
|
| 7 |
|
| 8 |
app = FastAPI()
|
| 9 |
+
TOKEN_LIMIT = 400
|
| 10 |
+
TOP_K_LIMIT = 50
|
| 11 |
+
async def generate_text(title, max_length=TOKEN_LIMIT, top_k=TOP_K_LIMIT, model_dir="./model/custom-gpt2-model", tokenizer_dir="./model/custom-gpt2-tokenizer"):
|
| 12 |
+
if max_length > TOKEN_LIMIT:
|
| 13 |
return {"error": "Limit for max_length is 400 tokens."}
|
| 14 |
+
if top_k > TOP_K_LIMIT:
|
| 15 |
+
return {"error": "Limit for top_k is 50."}
|
| 16 |
model = AutoModelForCausalLM.from_pretrained(model_dir)
|
| 17 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
|
| 18 |
input_text = f"[TITLE] {title} [/TITLE]"
|
|
|
|
| 30 |
|
| 31 |
class RequestParams(BaseModel):
|
| 32 |
title: str
|
| 33 |
+
max_length: int = TOKEN_LIMIT
|
| 34 |
+
top_k: int = TOP_K_LIMIT
|
| 35 |
|
| 36 |
@app.post("/generate-article")
|
| 37 |
async def handle_request(request: RequestParams):
|