# app.py — TinyLLM RAG NLP API (Hugging Face Space "Rdj1", commit 5a79cf4)
import os
import secrets
from typing import Optional

import torch
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# -----------------------
# App
# -----------------------
app = FastAPI()

# 🔐 API key: read from the environment so the secret need not live in
# source control; the hard-coded fallback keeps existing deployments working.
API_KEY = os.getenv("API_KEY", "sk-tinyllm-9f3a2c7e8b4d1a6c0e52f91d")

# ✅ Lightweight CPU model (NLP engine only)
MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.float32,  # full precision — this service runs on CPU
)
model.eval()  # inference mode (disables dropout etc.)
# -----------------------
# Request schema
# -----------------------
class Prompt(BaseModel):
    """Request body for POST /chat: the raw user message to answer."""
    # message: the user's prompt text, forwarded verbatim to the model
    message: str
# -----------------------
# API key verification
# -----------------------
def check_api_key(authorization: Optional[str]) -> None:
    """Validate a `Bearer <token>` Authorization header value.

    Args:
        authorization: Raw Authorization header, or None when absent.

    Raises:
        HTTPException: 401 when the header is missing, not in
            "Bearer <token>" form, or the token does not match API_KEY.
    """
    if authorization is None:
        raise HTTPException(status_code=401, detail="Missing API key")
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid API key format")
    # removeprefix strips only the leading scheme; str.replace would also
    # mangle a token that happens to contain "Bearer " internally.
    token = authorization.removeprefix("Bearer ").strip()
    # Constant-time comparison so the key cannot be guessed via timing.
    if not secrets.compare_digest(token, API_KEY):
        raise HTTPException(status_code=401, detail="Invalid API key")
# -----------------------
# Health check
# -----------------------
@app.get("/")
def root():
    """Liveness probe: confirms the API process is up and serving."""
    status_payload = {"status": "TinyLLM RAG NLP API running"}
    return status_payload
# -----------------------
# Chat endpoint (RAG-safe)
# -----------------------
@app.post("/chat")
def chat(
    prompt: Prompt,
    authorization: Optional[str] = Header(None)
):
    """Generate a deterministic completion for `prompt.message`.

    Requires an `Authorization: Bearer <key>` header (see check_api_key).
    Returns {"response": <generated text>} with only the newly generated
    tokens, never the echoed prompt.
    """
    check_api_key(authorization)

    # 🚫 IMPORTANT:
    # DO NOT inject system identity here.
    # Your RAG prompt already contains ALL rules.
    messages = [
        {
            "role": "user",
            "content": prompt.message
        }
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True
    )

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=220,  # controlled output length
            # Greedy decoding: deterministic output. Sampling knobs
            # (temperature/top_p/top_k) are intentionally omitted — they are
            # ignored when do_sample=False and only trigger transformers
            # warnings (temperature=0.0 is not even a valid sampling value).
            do_sample=False,
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
            # Explicit pad token silences the "pad_token_id not set" warning.
            pad_token_id=tokenizer.eos_token_id,
        )

    # Slice off the prompt tokens so only the model's answer is decoded.
    response = tokenizer.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True
    ).strip()

    return {
        "response": response
    }