CooLLaMACEO committed on
Commit
3320b3e
·
verified ·
1 Parent(s): ad08817

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -42
app.py CHANGED
@@ -1,17 +1,13 @@
1
- import os
2
  from fastapi import FastAPI, Request, HTTPException, Depends
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
5
  from fastapi.responses import JSONResponse
6
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
- import torch
8
  import uvicorn
9
 
10
- # -------------------------------
11
- # FastAPI setup
12
- # -------------------------------
13
- app = FastAPI(title="ChatMPT API (Transformers)")
14
 
 
15
  app.add_middleware(
16
  CORSMiddleware,
17
  allow_origins=["*"],
@@ -19,58 +15,47 @@ app.add_middleware(
19
  allow_headers=["*"],
20
  )
21
 
 
22
  security = HTTPBearer()
23
- MY_API_KEY = os.environ.get("API_KEY", "my-secret-key-456")
 
 
 
 
 
 
 
 
24
 
25
  def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
26
  if credentials.credentials != MY_API_KEY:
27
  raise HTTPException(status_code=403, detail="Unauthorized")
28
  return credentials.credentials
29
 
30
- # -------------------------------
31
- # Load model with Transformers
32
- # -------------------------------
33
- MODEL_PATH = "./mpt-7b-q2.gguf" # path to downloaded model
34
-
35
- print("Loading tokenizer and model...")
36
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
37
- model = AutoModelForCausalLM.from_pretrained(
38
- MODEL_PATH,
39
- device_map="auto", # will use GPU if available, CPU otherwise
40
- torch_dtype=torch.float16 # use float16 if possible for efficiency
41
- )
42
- generator = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
43
-
44
- # -------------------------------
45
- # Chat Endpoint
46
- # -------------------------------
47
  @app.post("/v1/chat")
48
  async def chat(request: Request, _ = Depends(verify_token)):
49
  try:
50
  data = await request.json()
51
- user_input = data.get("prompt", "").strip()
52
- if not user_input:
53
  return JSONResponse(status_code=400, content={"error": "No prompt provided"})
54
 
55
- # Generate response
56
- output = generator(user_input, do_sample=True, temperature=0.7)
57
- reply = output[0]["generated_text"]
 
 
 
 
 
 
 
58
 
 
59
  return JSONResponse(content={"reply": reply})
60
 
61
  except Exception as e:
62
  return JSONResponse(status_code=500, content={"error": str(e)})
63
 
64
- # -------------------------------
65
- # Health Check
66
- # -------------------------------
67
- @app.get("/health")
68
- async def health():
69
- return {"status": "ok"}
70
-
71
- # -------------------------------
72
- # Run app
73
- # -------------------------------
74
  if __name__ == "__main__":
75
- port = int(os.environ.get("PORT", 8080))
76
- uvicorn.run(app, host="0.0.0.0", port=port)
 
 
import asyncio
import os

import uvicorn
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from llama_cpp import Llama
7
 
8
+ app = FastAPI()
 
 
 
9
 
10
+ # Allow all origins (for frontend access)
11
  app.add_middleware(
12
  CORSMiddleware,
13
  allow_origins=["*"],
 
15
  allow_headers=["*"],
16
  )
17
 
18
+ # Simple API key auth
19
  security = HTTPBearer()
20
+ MY_API_KEY = "my-secret-key-456"
21
+
22
+ # Load GGUF model (CPU only, small threads for Spaces)
23
+ llm = Llama(
24
+ model_path="./mpt-7b-chat.gguf", # Make sure this is a tokenizer-included GGUF
25
+ n_ctx=2048,
26
+ n_threads=2, # Reduce for free tier
27
+ n_gpu_layers=0 # Force CPU
28
+ )
29
 
30
def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the bearer token against MY_API_KEY; raise 403 on mismatch."""
    token = credentials.credentials
    if token == MY_API_KEY:
        return token
    raise HTTPException(status_code=403, detail="Unauthorized")
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
# Serialize access to the model: llama_cpp's Llama object is not
# thread-safe, so only one generation may run at a time.
_llm_lock = asyncio.Lock()


@app.post("/v1/chat")
async def chat(request: Request, _ = Depends(verify_token)):
    """Generate a chat completion for a JSON body of the form {"prompt": "..."}.

    Returns {"reply": "..."} on success, 400 when the prompt is missing,
    empty, or not a string, and 500 with the error text on any failure.
    """
    try:
        data = await request.json()
        # Guard against non-string prompts (e.g. numbers): they would
        # otherwise crash on .strip() and surface as a 500 instead of 400.
        raw_prompt = data.get("prompt")
        user_prompt = raw_prompt.strip() if isinstance(raw_prompt, str) else ""
        if not user_prompt:
            return JSONResponse(status_code=400, content={"error": "No prompt provided"})

        # MPT chat format (ChatML-style markers).
        prompt = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"

        # Run the blocking generation in a worker thread so the event loop
        # stays responsive; the lock keeps model calls serialized.
        async with _llm_lock:
            output = await asyncio.to_thread(
                llm,
                prompt,
                max_tokens=512,
                temperature=0.7,
                stop=["<|im_end|>", "<|im_start|>"],
                echo=False,
            )

        reply = output["choices"][0]["text"].strip()
        return JSONResponse(content={"reply": reply})

    except Exception as e:
        # NOTE(review): str(e) can leak internal details to clients —
        # acceptable for a demo Space, tighten before production use.
        return JSONResponse(status_code=500, content={"error": str(e)})
59
 
 
 
 
 
 
 
 
 
 
 
60
# -------------------------------
# Run app
# -------------------------------
if __name__ == "__main__":
    # Honor the PORT env var (as the previous revision did); default to
    # 7860, the port Hugging Face Spaces expects.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))