Spaces:
Sleeping
Sleeping
File size: 1,135 Bytes
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import requests
# ===== Application =====
# ASGI application object; the route decorator further down registers on it.
app = FastAPI(title="LuxAI GPT-2 Backend")

# ===== GPT-2 =====
# Checkpoint identifier handed to both from_pretrained() calls below.
MODEL_NAME = "distilgpt2"

# Tokenizer and model are created once at import time and reused by every
# request instead of being reloaded per call.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# nn.Module.eval() puts the module in evaluation mode and returns the module
# itself, so chaining it onto the load leaves `model` bound to the same object.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).eval()
# ===== Request =====
# Request body accepted by POST /generate.
class GenerateRequest(BaseModel):
    # Text written by the user; it is embedded verbatim into the prompt.
    user_input: str

    # Name of the requested model. Optional in the payload; the endpoint
    # only accepts the default value "gpt2".
    model: str = "gpt2"
@app.post("/generate")
def generate(req: GenerateRequest):
    """Generate a reply to the user's message.

    The message is wrapped in a ``User: ...\\nBot:`` prompt, the model
    samples a continuation of up to 120 new tokens, and only that
    continuation is returned.

    Args:
        req: Request body with the user's text and the requested model name.

    Returns:
        A dict ``{"response": <generated text>}`` with surrounding
        whitespace stripped.

    Raises:
        HTTPException: 400 when ``req.model`` is not ``"gpt2"``, or when the
            prompt is too long to leave room for the generated tokens.
    """
    if req.model != "gpt2":
        raise HTTPException(400, "Tento backend podporuje pouze gpt2")

    max_new_tokens = 120
    prompt = f"User: {req.user_input}\nBot:"
    inputs = tokenizer(prompt, return_tensors="pt")
    prompt_token_count = inputs["input_ids"].shape[-1]

    # Reject prompts that cannot fit in the model's context window together
    # with the generated tokens; otherwise generate() fails with an internal
    # error that surfaces as an opaque 500.
    max_positions = getattr(model.config, "max_position_embeddings", None)
    if (
        max_positions is not None
        and prompt_token_count + max_new_tokens > max_positions
    ):
        raise HTTPException(400, "Vstup je příliš dlouhý pro tento model")

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            do_sample=True,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )

    # The returned sequence starts with the prompt tokens. Drop them by token
    # count before decoding: slicing the decoded string by len(prompt) is
    # unreliable because decoding may not reproduce the prompt
    # character-for-character.
    generated_tokens = output[0][prompt_token_count:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
    return {"response": response}
|