Spaces:
Configuration error
Configuration error
Commit ·
467a232
1
Parent(s): b819dd2
add token limit
Browse files
main.py
CHANGED
|
@@ -7,7 +7,9 @@ import torch
|
|
| 7 |
|
| 8 |
app = FastAPI()
|
| 9 |
|
| 10 |
-
async def generate_text(title, max_length=
|
|
|
|
|
|
|
| 11 |
model = AutoModelForCausalLM.from_pretrained(model_dir)
|
| 12 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
|
| 13 |
input_text = f"[TITLE] {title} [/TITLE]"
|
|
@@ -25,7 +27,7 @@ async def generate_text(title, max_length=1000, top_k=50, model_dir="./model/cus
|
|
| 25 |
|
| 26 |
class RequestParams(BaseModel):
|
| 27 |
title: str
|
| 28 |
-
max_length: int =
|
| 29 |
top_k: int = 50
|
| 30 |
|
| 31 |
@app.post("/generate-article")
|
|
|
|
| 7 |
|
| 8 |
app = FastAPI()
|
| 9 |
|
| 10 |
+
async def generate_text(title, max_length=400, top_k=50, model_dir="./model/custom-gpt2-model", tokenizer_dir="./model/custom-gpt2-tokenizer"):
|
| 11 |
+
if max_length > 400:
|
| 12 |
+
return {"error": "Limit for max_length is 400 tokens."}
|
| 13 |
model = AutoModelForCausalLM.from_pretrained(model_dir)
|
| 14 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
|
| 15 |
input_text = f"[TITLE] {title} [/TITLE]"
|
|
|
|
| 27 |
|
| 28 |
class RequestParams(BaseModel):
|
| 29 |
title: str
|
| 30 |
+
max_length: int = 400
|
| 31 |
top_k: int = 50
|
| 32 |
|
| 33 |
@app.post("/generate-article")
|