CooLLaMACEO committed on
Commit
c828ba8
·
verified ·
1 Parent(s): 2257bba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -24
app.py CHANGED
@@ -3,15 +3,15 @@ from fastapi import FastAPI, Request, HTTPException, Depends
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
5
  from fastapi.responses import JSONResponse
6
- from llama_cpp import Llama
 
7
  import uvicorn
8
 
9
  # -------------------------------
10
  # FastAPI setup
11
  # -------------------------------
12
- app = FastAPI(title="ChatMPT API")
13
 
14
- # Enable CORS for all origins (adjust for production)
15
  app.add_middleware(
16
  CORSMiddleware,
17
  allow_origins=["*"],
@@ -19,9 +19,8 @@ app.add_middleware(
19
  allow_headers=["*"],
20
  )
21
 
22
- # API key security
23
  security = HTTPBearer()
24
- MY_API_KEY = os.environ.get("API_KEY", "my-secret-key-456") # can override with env variable
25
 
26
  def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
27
  if credentials.credentials != MY_API_KEY:
@@ -29,14 +28,18 @@ def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
29
  return credentials.credentials
30
 
31
  # -------------------------------
32
- # Load MPT-7B Model
33
  # -------------------------------
34
- llm = Llama(
35
- model_path="./mpt-7b-q2.gguf", # downloaded in Dockerfile
36
- n_ctx=2048,
37
- n_threads=4, # adjust for CPU cores
38
- n_gpu_layers=0 # force CPU (change if GPU available)
 
 
 
39
  )
 
40
 
41
  # -------------------------------
42
  # Chat Endpoint
@@ -46,23 +49,13 @@ async def chat(request: Request, _ = Depends(verify_token)):
46
  try:
47
  data = await request.json()
48
  user_input = data.get("prompt", "").strip()
49
-
50
  if not user_input:
51
  return JSONResponse(status_code=400, content={"error": "No prompt provided"})
52
 
53
- # Format for MPT-Chat
54
- prompt = f"<|im_start|>user\n{user_input}<|im_end|>\n<|im_start|>assistant\n"
55
-
56
  # Generate response
57
- output = llm(
58
- prompt,
59
- max_tokens=512,
60
- temperature=0.7,
61
- stop=["<|im_end|>", "<|im_start|>"],
62
- echo=False
63
- )
64
 
65
- reply = output["choices"][0]["text"].strip()
66
  return JSONResponse(content={"reply": reply})
67
 
68
  except Exception as e:
@@ -79,5 +72,5 @@ async def health():
79
  # Run app
80
  # -------------------------------
81
  if __name__ == "__main__":
82
- port = int(os.environ.get("PORT", 8080)) # Hugging Face Spaces sets PORT
83
  uvicorn.run(app, host="0.0.0.0", port=port)
 
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
5
  from fastapi.responses import JSONResponse
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
+ import torch
8
  import uvicorn
9
 
10
  # -------------------------------
11
  # FastAPI setup
12
  # -------------------------------
13
+ app = FastAPI(title="ChatMPT API (Transformers)")
14
 
 
15
  app.add_middleware(
16
  CORSMiddleware,
17
  allow_origins=["*"],
 
19
  allow_headers=["*"],
20
  )
21
 
 
22
  security = HTTPBearer()
23
+ MY_API_KEY = os.environ.get("API_KEY", "my-secret-key-456")
24
 
25
  def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
26
  if credentials.credentials != MY_API_KEY:
 
28
  return credentials.credentials
29
 
30
  # -------------------------------
31
+ # Load model with Transformers
32
  # -------------------------------
33
+ MODEL_PATH = "./mpt-7b-q2.gguf" # path to downloaded model
34
+
35
+ print("Loading tokenizer and model...")
36
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
37
+ model = AutoModelForCausalLM.from_pretrained(
38
+ MODEL_PATH,
39
+ device_map="auto", # will use GPU if available, CPU otherwise
40
+ torch_dtype=torch.float16 # use float16 if possible for efficiency
41
  )
42
+ generator = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
43
 
44
  # -------------------------------
45
  # Chat Endpoint
 
49
  try:
50
  data = await request.json()
51
  user_input = data.get("prompt", "").strip()
 
52
  if not user_input:
53
  return JSONResponse(status_code=400, content={"error": "No prompt provided"})
54
 
 
 
 
55
  # Generate response
56
+ output = generator(user_input, do_sample=True, temperature=0.7)
57
+ reply = output[0]["generated_text"]
 
 
 
 
 
58
 
 
59
  return JSONResponse(content={"reply": reply})
60
 
61
  except Exception as e:
 
72
  # Run app
73
  # -------------------------------
74
  if __name__ == "__main__":
75
+ port = int(os.environ.get("PORT", 8080))
76
  uvicorn.run(app, host="0.0.0.0", port=port)