File size: 2,493 Bytes
72d67bb
454c6b3
 
f7cfbba
72d67bb
454c6b3
286d07a
 
 
454c6b3
 
433f86f
72d67bb
 
5a79cf4
 
454c6b3
 
f7cfbba
 
5a79cf4
f7cfbba
 
454c6b3
286d07a
 
 
454c6b3
 
 
286d07a
 
 
72d67bb
 
 
286d07a
72d67bb
 
 
 
 
 
 
286d07a
 
 
454c6b3
 
433f86f
454c6b3
286d07a
433f86f
286d07a
454c6b3
72d67bb
 
 
 
 
 
5a79cf4
 
 
433f86f
 
 
 
 
 
 
5a79cf4
 
 
 
 
f7cfbba
433f86f
72d67bb
433f86f
5a79cf4
 
 
 
 
72d67bb
433f86f
f7cfbba
 
72d67bb
433f86f
72d67bb
286d07a
f7cfbba
ed3a83e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import hmac
from typing import Optional

import torch
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# -----------------------
# App
# -----------------------
app = FastAPI()

# API key expected in the Authorization header (Bearer scheme).
# NOTE(review): secret is hardcoded in source — move it to an environment
# variable or secret store before deploying/sharing this file.
API_KEY = "sk-tinyllm-9f3a2c7e8b4d1a6c0e52f91d"

# Lightweight chat model intended to run on CPU.
MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"

# Tokenizer and weights are loaded once at import time and stay resident
# for the life of the process (first run downloads from the HF hub).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    # NOTE(review): older transformers releases expect `torch_dtype=` here;
    # `dtype=` is only accepted by recent versions — confirm the pinned
    # transformers version supports it.
    dtype=torch.float32
)
model.eval()  # inference mode: disables dropout/training-only behavior

# -----------------------
# Request schema
# -----------------------
class Prompt(BaseModel):
    """Request body for POST /chat: a single user message."""
    # Raw prompt text (the caller's full RAG prompt), forwarded verbatim
    # to the model as one user turn.
    message: str

# -----------------------
# API key verification
# -----------------------
def check_api_key(authorization: Optional[str]) -> None:
    """Validate a ``Bearer <token>`` Authorization header against API_KEY.

    Args:
        authorization: Raw Authorization header value, or None when absent.

    Raises:
        HTTPException: 401 when the header is missing, not Bearer-formatted,
            or the token does not match API_KEY.
    """
    if authorization is None:
        raise HTTPException(status_code=401, detail="Missing API key")

    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid API key format")

    # removeprefix strips only the leading scheme; the previous
    # str.replace() would have deleted "Bearer " anywhere inside the token.
    token = authorization.removeprefix("Bearer ").strip()

    # Constant-time comparison: avoids leaking key prefixes via timing.
    if not hmac.compare_digest(token, API_KEY):
        raise HTTPException(status_code=401, detail="Invalid API key")

# -----------------------
# Health check
# -----------------------
@app.get("/")
def root():
    """Liveness probe: reports that the API process is up."""
    payload = {"status": "TinyLLM RAG NLP API running"}
    return payload

# -----------------------
# Chat endpoint (RAG-safe)
# -----------------------
@app.post("/chat")
def chat(
    prompt: Prompt,
    authorization: Optional[str] = Header(None)
):
    """Generate a deterministic (greedy) chat completion for the prompt.

    Args:
        prompt: Request body carrying the user/RAG message.
        authorization: Bearer token header, validated against API_KEY.

    Returns:
        dict: ``{"response": <generated text>}``.

    Raises:
        HTTPException: 401 when the API key is missing or invalid.
    """
    check_api_key(authorization)

    # No system message is injected on purpose: the caller's RAG prompt is
    # expected to carry all identity/rules itself.
    messages = [
        {
            "role": "user",
            "content": prompt.message
        }
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True
    )

    with torch.no_grad():
        # Greedy decoding (do_sample=False) for deterministic output.
        # The former temperature=0.0 / top_p / top_k were removed: they are
        # ignored under greedy decoding, and temperature=0.0 is rejected as
        # invalid by newer transformers generation-config validation.
        output_ids = model.generate(
            input_ids,
            max_new_tokens=220,                    # bound response length
            do_sample=False,
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
            # Explicit pad_token_id silences the "pad_token_id not set"
            # warning; no padding actually occurs for a single sequence.
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens, skipping the prompt portion.
    response = tokenizer.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True
    ).strip()

    return {
        "response": response
    }