"""LuxAI GPT-2 backend: a minimal FastAPI service for text generation.

A small causal language model is loaded once at import time; ``POST /generate``
wraps the caller's text in a ``User: ...\\nBot:`` prompt and returns the
sampled continuation.
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import requests

app = FastAPI(title="LuxAI GPT-2 Backend")

# ===== GPT-2 =====
# NOTE: the checkpoint actually served is distilgpt2, although the public API
# (GenerateRequest.model) identifies it as "gpt2".
MODEL_NAME = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model.eval()  # inference only

# Number of tokens sampled for each reply.
MAX_NEW_TOKENS = 120


# ===== Request =====
class GenerateRequest(BaseModel):
    """JSON body accepted by ``POST /generate``."""

    # Free-form user text; embedded verbatim into the prompt.
    user_input: str
    # Requested model identifier; only "gpt2" is accepted by this backend.
    model: str = "gpt2"


@app.post("/generate")
def generate(req: GenerateRequest):
    """Generate a bot reply to ``req.user_input``.

    Args:
        req: The request body.

    Returns:
        ``{"response": str}`` -- the generated continuation after ``Bot:``,
        stripped of surrounding whitespace.

    Raises:
        HTTPException: 400 if ``req.model`` is not ``"gpt2"``, or if the
            prompt plus the generation budget exceeds the model's context
            window.
    """
    if req.model != "gpt2":
        # Message (Czech): "This backend supports only gpt2".
        raise HTTPException(400, "Tento backend podporuje pouze gpt2")

    prompt = f"User: {req.user_input}\nBot:"
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_len = inputs["input_ids"].shape[1]

    # Reject inputs that cannot fit in the model's positional window; otherwise
    # generation fails inside the model and surfaces as a 500 error.
    context_size = getattr(model.config, "max_position_embeddings", None)
    if context_size is not None and prompt_len + MAX_NEW_TOKENS > context_size:
        # Message (Czech): "Input is too long for the model's context".
        raise HTTPException(400, "Vstup je příliš dlouhý pro kontext modelu")

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=0.8,
            do_sample=True,
            top_p=0.95,
            # Reuse EOS as the padding id (GPT-2 tokenizers ship without a
            # dedicated pad token).
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated token ids. Slicing the fully decoded
    # text by len(prompt) is unreliable: decoding may normalise whitespace
    # (clean_up_tokenization_spaces), which shifts the character offset.
    new_tokens = output[0][prompt_len:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return {"response": response}