|
|
import secrets
from typing import Optional

import torch
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# FastAPI application exposing a small chat-completion API backed by a
# locally loaded Hugging Face causal language model.
app = FastAPI()


# NOTE(review): hardcoded bearer token, compared against the Authorization
# header in check_api_key(). A secret committed to source control should be
# rotated and loaded from the environment / a secret store instead.
API_KEY = "sk-tinyllm-9f3a2c7e8b4d1a6c0e52f91d"


# Small instruction-tuned chat model (~0.5B parameters) — presumably chosen
# so the service can run on CPU; no device placement is done below.
MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"

# Tokenizer and model are loaded once at import time and shared by all
# requests (downloads weights on first run).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


model = AutoModelForCausalLM.from_pretrained(


    MODEL_NAME,

    # full-precision weights; keeps behavior deterministic on CPU
    dtype=torch.float32


)

# Inference-only mode: disables dropout and other training-time behavior.
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Prompt(BaseModel):
    """Request body for POST /chat: a single user message."""

    # Raw user text, forwarded verbatim as the content of the chat turn.
    message: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_api_key(authorization: Optional[str]) -> None:
    """Validate an ``Authorization: Bearer <key>`` header value.

    Args:
        authorization: Raw Authorization header, or ``None`` when absent.

    Raises:
        HTTPException: 401 when the header is missing, is not a Bearer
            credential, or the token does not match ``API_KEY``.
    """
    if authorization is None:
        raise HTTPException(status_code=401, detail="Missing API key")

    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid API key format")

    # removeprefix strips only the leading scheme; the previous
    # str.replace("Bearer ", "") removed EVERY occurrence, corrupting any
    # token that happened to contain the substring "Bearer ".
    token = authorization.removeprefix("Bearer ").strip()

    # Constant-time comparison so response timing does not leak how many
    # leading characters of the key an attacker has guessed correctly.
    if not secrets.compare_digest(token, API_KEY):
        raise HTTPException(status_code=401, detail="Invalid API key")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
def root():
    """Health-check endpoint: confirms the service is up and reachable."""
    status_payload = {"status": "TinyLLM RAG NLP API running"}
    return status_payload
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/chat")
def chat(
    prompt: Prompt,
    authorization: Optional[str] = Header(None)
):
    """Generate a single-turn chat completion for the given message.

    Args:
        prompt: Request body carrying the user's message.
        authorization: ``Bearer <key>`` header, validated before inference.

    Returns:
        dict: ``{"response": <generated assistant text>}``.

    Raises:
        HTTPException: 401 (via check_api_key) on a missing/invalid key.
    """
    check_api_key(authorization)

    # Single-turn conversation: one user message, no system prompt, no
    # history is kept between requests.
    messages = [
        {
            "role": "user",
            "content": prompt.message
        }
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        # Append the assistant-turn prefix so the model produces a reply
        # rather than continuing the user's text.
        add_generation_prompt=True
    )

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=220,
            # Greedy decoding. The previous temperature=0.0 / top_p / top_k
            # arguments were contradictory with do_sample=False: transformers
            # ignores them (with a per-request warning), and temperature=0.0
            # is not even a valid sampling value. Dropping them changes no
            # output, only removes misleading dead configuration.
            do_sample=False,
            repetition_penalty=1.1,  # still applied under greedy search
            eos_token_id=tokenizer.eos_token_id,
            # Explicit pad token silences the "Setting pad_token_id to
            # eos_token_id" warning emitted on every generate() call.
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens, skipping the echoed prompt.
    response = tokenizer.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True
    ).strip()

    return {
        "response": response
    }
|
|
|