JDVariadic committed on
Commit
9ad5796
·
0 Parent(s):

add main api file

Browse files
Files changed (1) hide show
  1. main.py +43 -0
main.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Here’s the exam:
3
+ 1. Select a Causal language Model
4
+ 2. ⁠You can freely train/fine-tune/or use it outside the box into what use-case you prefer
5
+ 3. ⁠Deploy that to heroku, render, or any free deployment platforms (free only) using Fast API.
6
+ 4. ⁠Must be able to do post requests remotely.
7
+ 5. Upload it to github with a short readme on how to install and infer on your endpoint
8
+ """
9
+ from fastapi import FastAPI
10
+ from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, BitsAndBytesConfig
11
+ from pydantic import BaseModel
12
+ import torch
13
+
14
# Credits to https://www.kaggle.com/datasets/fabiochiusano/medium-articles for the dataset

# FastAPI application instance; the deployment platform's ASGI server imports this.
app = FastAPI()
17
+
18
# Module-level cache so the model and tokenizer are read from disk only once
# per (model_dir, tokenizer_dir) pair instead of on every single request.
_MODEL_CACHE = {}


async def generate_text(title, max_length=1000, top_k=50, model_dir="./model/custom-gpt2-model", tokenizer_dir="./model/custom-gpt2-tokenizer"):
    """Generate an article for *title* with the fine-tuned causal language model.

    Args:
        title: Article title used to build the ``[TITLE] ... [/TITLE]`` prompt.
        max_length: Maximum total token length of the generated sequence.
        top_k: Top-k sampling cutoff passed to ``model.generate``.
        model_dir: Path to the saved model weights.
        tokenizer_dir: Path to the saved tokenizer files.

    Returns:
        The decoded generated text, with special tokens stripped.
    """
    key = (model_dir, tokenizer_dir)
    if key not in _MODEL_CACHE:
        _MODEL_CACHE[key] = (
            AutoModelForCausalLM.from_pretrained(model_dir),
            AutoTokenizer.from_pretrained(tokenizer_dir),
        )
    model, tokenizer = _MODEL_CACHE[key]

    input_text = f"[TITLE] {title} [/TITLE]"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids

    # GPT-2 tokenizers ship without a pad token, so pad_token_id may be None;
    # fall back to EOS as transformers recommends for open-ended generation.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        output_sequences = model.generate(
            input_ids=input_ids,
            pad_token_id=pad_id,
            max_length=max_length,
            do_sample=True,
            top_k=top_k,
            # NOTE: the original passed early_stopping=True, which only applies
            # to beam search and was a no-op (warning) with do_sample=True.
        )
    generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
    return generated_text
34
+
35
class RequestParams(BaseModel):
    """Request body schema for the /generate-article endpoint."""

    # Article title to condition generation on (required).
    title: str
    # Maximum total token length of the generated sequence.
    max_length: int = 1000
    # Top-k sampling cutoff forwarded to model.generate.
    top_k: int = 50
39
+
40
@app.post("/generate-article")
async def handle_request(request: RequestParams):
    """Handle a POST request: generate an article for the supplied title.

    The validated request body supplies the title plus sampling parameters,
    which are forwarded verbatim to ``generate_text``.
    """
    article = await generate_text(
        request.title,
        request.max_length,
        request.top_k,
    )
    return {"generated_article": article}