Batrdj committed on
Commit
72d67bb
·
verified ·
1 Parent(s): f7cfbba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -22
app.py CHANGED
@@ -1,52 +1,77 @@
1
- from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  import torch
 
5
 
6
  app = FastAPI()
7
 
8
- # Ultra-tiny model (SAFE for free CPU)
9
- MODEL_NAME = "sshleifer/tiny-gpt2"
 
 
 
10
 
11
- # Load tokenizer & model once at startup
12
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
  model = AutoModelForCausalLM.from_pretrained(
14
  MODEL_NAME,
15
- torch_dtype=torch.float32
 
16
  )
17
  model.eval()
18
 
19
- # Request schema
20
  class Prompt(BaseModel):
21
  message: str
22
 
23
- # Health check
 
 
 
 
 
 
 
 
 
 
 
24
  @app.get("/")
25
  def root():
26
  return {"status": "TinyLLM API is running"}
27
 
28
- # Chat endpoint
29
  @app.post("/chat")
30
- def chat(prompt: Prompt):
31
- inputs = tokenizer(
32
- prompt.message,
33
- return_tensors="pt",
34
- truncation=True,
35
- max_length=128
 
 
 
 
 
 
 
 
36
  )
37
 
38
  with torch.no_grad():
39
- outputs = model.generate(
40
- **inputs,
41
- max_new_tokens=50,
42
- do_sample=True,
43
  temperature=0.7,
44
- top_p=0.9
 
 
45
  )
46
 
47
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
48
 
49
  return {
50
- "input": prompt.message,
51
- "response": response
52
  }
 
1
from fastapi import FastAPI, Header, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import Optional
import os

app = FastAPI()

# 🔐 API key: prefer the environment so the secret never has to live in the
# repository; the literal is only a development fallback.
# NOTE(review): rotate this fallback value and set API_KEY in the deployment
# environment — a key committed to source control must be treated as leaked.
API_KEY = os.environ.get("API_KEY", "sk-tinyllm-9f3a2c7e8b4d1a6c0e52f91d")

# ✅ Best FREE CPU chat model
MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"

# Load tokenizer and model once at import time so every request reuses them.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # full precision: safest on CPU-only hosts
    device_map="cpu"
)
model.eval()  # inference mode (disables dropout etc.)
22
 
 
23
# Request schema for POST /chat.
class Prompt(BaseModel):
    message: str  # raw user prompt text
25
 
26
# 🔐 API KEY CHECK (OpenAI style)
def check_api_key(authorization: Optional[str]) -> None:
    """Validate an ``Authorization: Bearer <key>`` header against API_KEY.

    Raises:
        HTTPException: 401 when the header is missing, not Bearer-formatted,
            or the presented key does not match API_KEY.
    Returns None on success.
    """
    import hmac  # stdlib; local import, used only for the comparison below

    if authorization is None:
        raise HTTPException(status_code=401, detail="Missing API key")

    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid API key format")

    # removeprefix (not str.replace) so only the leading scheme is stripped —
    # replace() would also delete "Bearer " occurring inside the token itself.
    token = authorization.removeprefix("Bearer ").strip()

    # Constant-time comparison: `!=` short-circuits on the first differing
    # byte, which leaks key-prefix information via response timing.
    if not hmac.compare_digest(token, API_KEY):
        raise HTTPException(status_code=401, detail="Invalid API key")
37
+
38
@app.get("/")
def root():
    """Health-check endpoint: confirms the service is up and responding."""
    payload = {"status": "TinyLLM API is running"}
    return payload
41
 
 
42
@app.post("/chat")
def chat(
    prompt: Prompt,
    authorization: Optional[str] = Header(None)
):
    """Chat endpoint: authenticate the caller, then generate a model reply.

    Returns a dict with the serving model name and the stripped completion.
    """
    check_api_key(authorization)

    # Wrap the user message in a minimal system+user conversation.
    conversation = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": prompt.message},
    ]

    # Render the conversation through the model's own chat template.
    prompt_ids = tokenizer.apply_chat_template(
        conversation,
        return_tensors="pt",
    )

    with torch.no_grad():
        generated = model.generate(
            prompt_ids,
            max_new_tokens=200,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            do_sample=True,
        )

    # Decode only the newly generated tail; the prompt tokens are echoed at
    # the front of the output, so slice them off before decoding.
    completion = tokenizer.decode(
        generated[0][prompt_ids.shape[-1]:],
        skip_special_tokens=True,
    )

    return {
        "model": MODEL_NAME,
        "response": completion.strip(),
    }