File size: 1,530 Bytes
f439658
 
 
 
 
 
 
 
 
 
 
 
 
c6f59a4
f439658
 
a40b465
04dec4e
f439658
 
 
 
 
c31212e
f439658
cee8daf
f439658
 
 
 
 
 
 
2c01a1c
f439658
 
 
 
 
 
 
 
 
 
 
 
 
311a3c8
f439658
a700966
f439658
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import logging

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

assert np.__version__.startswith('1.'), f"Несовместимая версия NumPy: {np.__version__}"

app = FastAPI()


class RequestData(BaseModel):
    """Request body schema for the /generate endpoint."""

    # Text prompt fed to the language model.
    prompt: str
    # Requested number of new tokens; the /generate handler caps it at 10000.
    max_tokens: int = 256


# MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
MODEL_NAME = "gpt2"

logger = logging.getLogger(__name__)

try:
    # Load tokenizer and model once at import time so all requests reuse them.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,   # full-precision weights (CPU-friendly)
        device_map="auto",           # let HF place weights on available devices
        low_cpu_mem_usage=True,
    )

    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )

except Exception:
    # Best-effort startup: keep the app alive and let endpoints answer 503.
    # logger.exception records the full traceback, unlike the previous
    # print(), so the failure cause is visible in production logs.
    logger.exception("Ошибка загрузки модели")
    generator = None


@app.post("/generate")
async def generate_text(request: RequestData):
    """Generate a greedy completion for the supplied prompt.

    Returns ``{"response": <generated text>}``. Replies 503 when the model
    failed to load at startup and 500 when generation itself raises.
    """
    if not generator:
        raise HTTPException(status_code=503, detail="Модель не загружена")

    # Cap the client-requested budget so a single call stays bounded.
    token_budget = min(request.max_tokens, 10000)

    try:
        results = generator(
            request.prompt,
            max_new_tokens=token_budget,
            do_sample=False,  # deterministic greedy decoding
            num_beams=1,
        )
        return {"response": results[0]["generated_text"]}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))


@app.get("/health")
async def health_check():
    """Liveness probe: reports whether the generation pipeline is available."""
    if generator:
        status = "ok"
    else:
        status = "unavailable"
    return {"status": status}