gpt2-backend / app.py
luxopes's picture
Create app.py
73dae50 verified
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import requests
# ===== GPT-2 model setup =====
# Checkpoint identifier shared by the tokenizer and the model loaders below.
MODEL_NAME = "distilgpt2"

# Both objects are created once at import time and reused by every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# .eval() puts the module in evaluation mode and returns the module itself.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).eval()

# ===== HTTP application =====
app = FastAPI(title="LuxAI GPT-2 Backend")
# ===== Request =====
class GenerateRequest(BaseModel):
    # Request body accepted by POST /generate.

    # Text inserted after "User: " when the prompt is built.
    user_input: str

    # Requested model name; the handler rejects anything other than "gpt2".
    model: str = "gpt2"
@app.post("/generate")
def generate(req: GenerateRequest):
    """Sample a reply to ``req.user_input`` from the loaded causal LM.

    The prompt is ``"User: <user_input>\\nBot:"``; only the text generated
    after that prompt is returned.

    Args:
        req: Request body carrying the user text and the requested model name.

    Returns:
        A dict ``{"response": <generated text, whitespace-stripped>}``.

    Raises:
        HTTPException: 400 when ``req.model`` is not ``"gpt2"``; 413 when the
            tokenized prompt plus the generation budget would not fit in the
            model's position limit.
    """
    if req.model != "gpt2":
        raise HTTPException(400, "Tento backend podporuje pouze gpt2")

    max_new_tokens = 120

    prompt = f"User: {req.user_input}\nBot:"
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[1]

    # Generating past the model's position limit fails inside the model and
    # would surface as an opaque 500; reject such requests explicitly instead.
    context_size = getattr(model.config, "max_position_embeddings", None)
    if context_size is not None and prompt_len + max_new_tokens > context_size:
        # Message is in Czech to match the other error returned by this endpoint.
        raise HTTPException(413, "Vstup je příliš dlouhý")

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            do_sample=True,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated ids. Slicing the decoded full text by
    # len(prompt) is unreliable because decoding does not always reproduce the
    # prompt character-for-character (e.g. tokenizer whitespace clean-up), so
    # the slice could drop the start of the reply or include prompt residue.
    new_token_ids = output[0][prompt_len:]
    response = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
    return {"response": response}