# Hugging Face Space upload artifact (user "rewgwrth", commit 04dec4e) — page residue, not code.
import logging

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Fail fast on NumPy 2.x, which breaks older torch/transformers builds.
# Explicit check instead of `assert`: asserts are stripped under `python -O`,
# which would silently skip this guard in optimized deployments.
if not np.__version__.startswith('1.'):
    raise RuntimeError(f"Несовместимая версия NumPy: {np.__version__}")

app = FastAPI()
class RequestData(BaseModel):
    """Request body for POST /generate: a prompt plus a generation-length cap."""
    # Text used as the generation prefix for the language model.
    prompt: str
    # Requested maximum number of new tokens; the endpoint caps this at 10000.
    max_tokens: int = 256
# MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # larger chat model, kept for reference
MODEL_NAME = "gpt2"

# Load tokenizer/model/pipeline once at import time. On failure the app still
# starts, but `generator` stays None and /generate answers 503 until restart.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,  # full precision — safe on CPU-only hosts
        device_map="auto",
        low_cpu_mem_usage=True,
    )
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )
except Exception as e:
    # Log with the full traceback instead of print(): failures must reach the
    # application logs, and print() loses the stack entirely.
    logging.getLogger(__name__).exception("Ошибка загрузки модели: %s", e)
    generator = None
@app.post("/generate")
async def generate_text(request: RequestData):
    """Generate a continuation of `request.prompt`.

    Returns {"response": <text>}; per the transformers text-generation
    pipeline, the returned text includes the prompt prefix.
    Raises HTTP 503 when the model failed to load and HTTP 500 on
    generation errors.
    """
    if not generator:
        raise HTTPException(status_code=503, detail="Модель не загружена")
    try:
        output = generator(
            request.prompt,
            # Clamp to [1, 10000]: a non-positive client value would crash the
            # pipeline, and an unbounded one could stall the worker.
            max_new_tokens=max(1, min(request.max_tokens, 10000)),
            do_sample=False,  # deterministic greedy decoding
            num_beams=1,
        )
        return {"response": output[0]["generated_text"]}
    except Exception as e:
        # Chain the cause so the original traceback survives for server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
async def health_check():
    """Liveness probe: reports whether the generation pipeline is loaded."""
    if generator:
        status = "ok"
    else:
        status = "unavailable"
    return {"status": status}