from typing import Optional
import secrets

import torch
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# -----------------------
# App
# -----------------------
app = FastAPI()

# 🔐 API KEY (keep same)
# NOTE: hardcoded for simplicity; in production, load this from an
# environment variable or a secrets manager instead of source code.
API_KEY = "sk-tinyllm-9f3a2c7e8b4d1a6c0e52f91d"

# ✅ Lightweight CPU model (NLP engine only)
MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32  # full-precision weights for CPU inference
)
model.eval()

# -----------------------
# Request schema
# -----------------------
class Prompt(BaseModel):
    message: str

# -----------------------
# API key verification
# -----------------------
def check_api_key(authorization: Optional[str]):
    if authorization is None:
        raise HTTPException(status_code=401, detail="Missing API key")
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid API key format")
    token = authorization[len("Bearer "):].strip()
    # Constant-time comparison avoids leaking key information via timing.
    if not secrets.compare_digest(token, API_KEY):
        raise HTTPException(status_code=401, detail="Invalid API key")

# -----------------------
# Health check
# -----------------------
@app.get("/")
def root():
    return {"status": "TinyLLM RAG NLP API running"}

# -----------------------
# Chat endpoint (RAG-safe)
# -----------------------
@app.post("/chat")
def chat(
    prompt: Prompt,
    authorization: Optional[str] = Header(None)
):
    check_api_key(authorization)

    # 🚫 IMPORTANT:
    # DO NOT inject system identity here.
    # Your RAG prompt already contains ALL rules.
    messages = [
        {
            "role": "user",
            "content": prompt.message
        }
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
        add_generation_prompt=True
    )

    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=220,         # controlled output length
            do_sample=False,            # greedy decoding: deterministic output
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id
        )

    response = tokenizer.decode(
        output_ids[0][input_ids.shape[-1]:],
        skip_special_tokens=True
    ).strip()

    return {"response": response}
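
# -----------------------
# Local run + usage sketch
# -----------------------
# A minimal sketch of how to serve and call this API. It assumes the file is
# saved as main.py and that uvicorn is installed; both names are assumptions,
# not part of the original. Start the server with:
#
#   uvicorn main:app --host 127.0.0.1 --port 8000
#
# then send an authenticated request:
#
#   curl -X POST http://127.0.0.1:8000/chat \
#     -H "Authorization: Bearer sk-tinyllm-9f3a2c7e8b4d1a6c0e52f91d" \
#     -H "Content-Type: application/json" \
#     -d '{"message": "Summarize retrieval-augmented generation in one line."}'
if __name__ == "__main__":
    # Convenience entry point so `python main.py` also starts the server.
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)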