Spaces:
Configuration error
Configuration error
File size: 1,547 Bytes
9ad5796 86f55a2 9ad5796 16fe227 467a232 16fe227 9ad5796 16fe227 9ad5796 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 | from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer
from pydantic import BaseModel
import torch
# Credits to https://www.kaggle.com/datasets/fabiochiusano/medium-articles for the dataset

# FastAPI application instance; the route handlers in this file are
# registered on it through its decorators.
app = FastAPI()
TOKEN_LIMIT = 400  # largest max_length a caller may request
TOP_K_LIMIT = 50  # largest top_k a caller may request

# Loaded (model, tokenizer) pairs keyed by (model_dir, tokenizer_dir), so the
# weights are read from disk once per process instead of once per call.
_MODEL_CACHE = {}


def _load_model_and_tokenizer(model_dir, tokenizer_dir):
    """Return the (model, tokenizer) pair for the given directories, loading it on first use."""
    key = (model_dir, tokenizer_dir)
    if key not in _MODEL_CACHE:
        model = AutoModelForCausalLM.from_pretrained(model_dir)
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
        _MODEL_CACHE[key] = (model, tokenizer)
    return _MODEL_CACHE[key]


async def generate_text(
    title,
    max_length=TOKEN_LIMIT,
    top_k=TOP_K_LIMIT,
    model_dir="./model/custom-gpt2-model",
    tokenizer_dir="./model/custom-gpt2-tokenizer",
):
    """Generate text from a title prompt using top-k sampling.

    Args:
        title: Text placed between the ``[TITLE]`` and ``[/TITLE]`` markers
            to form the prompt.
        max_length: Value forwarded to ``model.generate`` as ``max_length``;
            must not exceed ``TOKEN_LIMIT``.
        top_k: Value forwarded to ``model.generate`` as ``top_k``; must not
            exceed ``TOP_K_LIMIT``.
        model_dir: Directory passed to ``AutoModelForCausalLM.from_pretrained``.
        tokenizer_dir: Directory passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The decoded first generated sequence (special tokens skipped), or a
        ``{"error": <message>}`` dict when ``max_length`` or ``top_k`` exceeds
        its limit. The limits are checked before any model is loaded.
    """
    if max_length > TOKEN_LIMIT:
        return {"error": f"Limit for max_length is {TOKEN_LIMIT} tokens."}
    if top_k > TOP_K_LIMIT:
        return {"error": f"Limit for top_k is {TOP_K_LIMIT}."}

    model, tokenizer = _load_model_and_tokenizer(model_dir, tokenizer_dir)

    input_text = f"[TITLE] {title} [/TITLE]"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids

    # A tokenizer without a pad token reports pad_token_id as None; fall back
    # to the end-of-sequence id so generate() receives an explicit value.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # NOTE(review): generate() is called synchronously inside this coroutine,
    # so nothing else is awaited while it runs.
    with torch.no_grad():
        output_sequences = model.generate(
            input_ids=input_ids,
            pad_token_id=pad_token_id,
            max_length=max_length,
            do_sample=True,
            top_k=top_k,
        )

    generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
    return generated_text
class RequestParams(BaseModel):
    """Request body accepted by the ``/generate-article`` endpoint."""

    title: str  # required; passed to generate_text as the title
    max_length: int = TOKEN_LIMIT  # optional; defaults to TOKEN_LIMIT
    top_k: int = TOP_K_LIMIT  # optional; defaults to TOP_K_LIMIT
@app.post("/generate-article")
async def handle_request(request: RequestParams):
    """Generate an article for the posted title.

    Args:
        request: Parsed request body carrying ``title``, ``max_length`` and
            ``top_k``.

    Returns:
        ``{"generated_article": <text>}`` on success.

    Raises:
        HTTPException: With status 400 when ``generate_text`` reports that a
            requested limit was exceeded.
    """
    # Imported locally so this handler does not depend on a change to the
    # module-level import lines.
    from fastapi import HTTPException

    generated_article = await generate_text(request.title, request.max_length, request.top_k)

    # generate_text signals limit violations by returning an {"error": ...}
    # dict instead of raising; surface that as a client error rather than
    # nesting it inside a successful response.
    if isinstance(generated_article, dict) and "error" in generated_article:
        raise HTTPException(status_code=400, detail=generated_article["error"])

    return {"generated_article": generated_article}