CooLLaMACEO committed on
Commit
5ec5b09
·
verified ·
1 Parent(s): 33f722d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -11
app.py CHANGED
@@ -1,15 +1,27 @@
1
  import os
2
  import torch
3
- import gradio as gr
 
 
 
 
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
5
 
6
- # Point to the local folder created in the Dockerfile
7
  MODEL_PATH = "/app/model"
 
 
8
 
9
- print("Loading Overflow-111.7B from Local Docker Storage...")
 
 
10
 
11
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
12
 
 
 
 
13
  model = AutoModelForCausalLM.from_pretrained(
14
  MODEL_PATH,
15
  trust_remote_code=True,
@@ -18,12 +30,45 @@ model = AutoModelForCausalLM.from_pretrained(
18
  low_cpu_mem_usage=True
19
  )
20
 
21
def respond(message, history):
    """Gradio ChatInterface callback: generate a short completion for `message`.

    `history` is required by the ChatInterface signature but is ignored —
    each turn is answered from the bare message with no conversation context.
    """
    inputs = tokenizer(message, return_tensors="pt")
    # Inference only — disable autograd bookkeeping to save memory.
    with torch.no_grad():
        output_tokens = model.generate(**inputs, max_new_tokens=30)
    # NOTE(review): decode covers the full sequence, so the returned text
    # echoes the prompt followed by the completion — confirm intended.
    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)
 
 
26
 
27
# Wire the callback into a chat UI and serve it on all interfaces at the
# port Hugging Face Spaces expects (7860).
demo = gr.ChatInterface(respond)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import os
2
  import torch
3
+ import secrets
4
+ import time
5
+ from fastapi import FastAPI, HTTPException, Security, Depends
6
+ from fastapi.security.api_key import APIKeyHeader
7
+ from pydantic import BaseModel
8
  from transformers import AutoModelForCausalLM, AutoTokenizer
9
+ from starlette.status import HTTP_403_FORBIDDEN
10
 
11
# --- CONFIGURATION ---
MODEL_PATH = "/app/model"  # local model folder created by the Dockerfile
API_KEY_NAME = "X-API-Key"  # header clients must send on protected routes
# auto_error=False: a missing header resolves to None instead of raising an
# automatic 403, so get_api_key() can return the custom error message.
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)

# In-memory storage for keys.
# Note: These will reset if the Space restarts unless you use Persistent Storage.
generated_keys = {}

app = FastAPI(title="Overflow-111.7B API Manager")

# --- MODEL LOADING ---
# Loaded once at import time so every request reuses the same weights.
print("Loading Overflow-111.7B Engine...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
25
  model = AutoModelForCausalLM.from_pretrained(
26
  MODEL_PATH,
27
  trust_remote_code=True,
 
30
  low_cpu_mem_usage=True
31
  )
32
 
33
class Query(BaseModel):
    """Request body for POST /v1/generate."""

    # Text fed verbatim to the tokenizer/model.
    prompt: str
    # NOTE(review): unbounded — a client can request an arbitrarily large
    # generation; consider capping this server-side. TODO confirm limits.
    max_tokens: int = 50
36
+
37
# --- API KEY GENERATION ---
@app.get("/api/generate")
async def create_new_key():
    """Mint a fresh `of_sk-` API key, register it in memory, and return it.

    NOTE(review): this endpoint is itself unauthenticated, so anyone can
    mint a valid key — confirm that is intended for this deployment.
    """
    # 12 random bytes -> 24 hex characters of entropy for the key suffix.
    suffix = secrets.token_hex(12)
    new_key = f"of_sk-{suffix}"

    # Record the issue time alongside the key (in-memory only).
    generated_keys[new_key] = {"created_at": time.time()}

    return {
        "status": "success",
        "api_key": new_key,
        "instructions": f"Include this key in your request header as '{API_KEY_NAME}'",
    }
53
+
54
# --- SECURITY CHECK ---
async def get_api_key(api_key_header: str = Depends(api_key_header)):
    """FastAPI dependency: admit the request only if its key was issued here.

    A missing header arrives as None (APIKeyHeader was built with
    auto_error=False), which simply fails the membership test below.
    """
    if api_key_header not in generated_keys:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="Invalid or expired API Key. Generate one at /api/generate"
        )
    return api_key_header
62
+
63
@app.post("/v1/generate")
def generate(query: Query, api_key: str = Depends(get_api_key)):
    """Generate up to `query.max_tokens` new tokens from `query.prompt`.

    Requires a valid key (see get_api_key) in the X-API-Key header.

    Declared as a plain `def` (not `async def`) on purpose: model.generate
    is blocking, CPU-bound work, and inside an `async def` it would freeze
    the entire event loop for the duration of generation. FastAPI runs sync
    path functions in its threadpool, so other requests (including key
    generation) stay responsive. The HTTP interface is unchanged.
    """
    inputs = tokenizer(query.prompt, return_tensors="pt")
    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        output_tokens = model.generate(**inputs, max_new_tokens=query.max_tokens)

    # NOTE(review): decoding the full sequence echoes the prompt before the
    # completion in `text` — confirm clients expect that.
    response = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
    return {"text": response}
71
 
72
@app.get("/")
def home():
    """Landing route: point new users at the key-generation endpoint."""
    welcome = "Welcome to Overflow-111.7B. Go to /api/generate to get a key."
    return {"message": welcome}