fugthchat committed on
Commit
5ae1757
·
verified ·
1 Parent(s): 14f464e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -63
app.py CHANGED
@@ -1,66 +1,168 @@
1
- from flask import Flask, request, jsonify
2
- from llama_cpp import Llama
3
  import os
4
-
5
- app = Flask(__name__)
6
-
7
- MODEL_URLS = {
8
- "light": "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q3_K_S.gguf",
9
- "medium": "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q4_K_M.gguf",
10
- "heavy": "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q5_0.gguf"
11
- }
12
-
13
- MODEL_PATHS = {
14
- k: f"{k}.gguf" for k in MODEL_URLS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  }
16
 
17
- current_model = None
18
- llm = None
19
-
20
- def ensure_model(model_choice):
21
- global llm, current_model
22
- model_path = MODEL_PATHS[model_choice]
23
- url = MODEL_URLS[model_choice]
24
-
25
- if not os.path.exists(model_path):
26
- print(f"Downloading {model_choice} model...")
27
- os.system(f"wget -O {model_path} {url}")
28
-
29
- if current_model != model_choice:
30
- print(f"Loading {model_choice} model...")
31
- llm = Llama(model_path=model_path, n_ctx=2048, n_threads=4, use_mlock=False)
32
- current_model = model_choice
33
- return llm
34
-
35
-
36
- @app.route("/status")
37
- def status():
38
- return jsonify({"status": "ok" if llm else "not_loaded", "model": current_model})
39
-
40
-
41
- @app.route("/generate", methods=["POST"])
42
- def generate():
43
- data = request.get_json(force=True)
44
- model_choice = data.get("model_choice", "light")
45
- prompt = data.get("prompt", "")
46
- story_memory = data.get("story_memory", "")
47
- feedback = data.get("feedback", "")
48
-
49
- llm = ensure_model(model_choice)
50
-
51
- full_prompt = story_memory + "\n\n" + prompt
52
- if feedback:
53
- full_prompt += f"\n\nUser feedback: {feedback}\n"
54
-
55
- result = llm(full_prompt, max_tokens=512, temperature=0.8)
56
- text = result["choices"][0]["text"].strip()
57
- return jsonify({"response": text})
58
-
59
-
60
- @app.route("/")
61
- def root():
62
- return "StableLM Zephyr GGUF API running!"
63
-
64
-
65
- if __name__ == "__main__":
66
- app.run(host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from fastapi import FastAPI, Request
3
+ from fastapi.responses import JSONResponse
4
+ from pydantic import BaseModel
5
+ from llama_cpp import Llama
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from huggingface_hub import hf_hub_download
8
+ import logging
9
+ import threading
10
+
11
# Configure root logging so model download/load progress shows in Space logs.
logging.basicConfig(level=logging.INFO)

# --- MODEL MAP ---
# Maps the "light" / "medium" / "heavy" keys sent by the frontend to
# concrete GGUF files on the Hugging Face Hub. All three quantisations
# live in the same repository, so only the filename differs.
_ZEPHYR_REPO = "TheBloke/stablelm-zephyr-3b-GGUF"

MODEL_MAP = {
    "light": {
        "repo_id": _ZEPHYR_REPO,
        "filename": "stablelm-zephyr-3b.Q3_K_S.gguf",  # 1.25 GB
    },
    "medium": {
        "repo_id": _ZEPHYR_REPO,
        "filename": "stablelm-zephyr-3b.Q4_K_M.gguf",  # 1.71 GB
    },
    "heavy": {
        "repo_id": _ZEPHYR_REPO,
        "filename": "stablelm-zephyr-3b.Q5_K_M.gguf",  # 2.03 GB
    },
}
31
 
32
# --- GLOBAL CACHE & LOCK ---
# At most one model is kept resident at a time: llm_cache maps the
# currently-active choice key -> its Llama instance.
llm_cache = {}
# Serialises every load/inference; the llama.cpp handle is not safe to
# share between concurrent requests.
model_lock = threading.Lock()

app = FastAPI()

# --- CORS ---
# Wide-open CORS so the static frontend (e.g. GitHub Pages) can reach us.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
46
+
47
# --- Helper Function to Load Model ---
def get_llm_instance(choice: str) -> Llama:
    """
    Return a ready-to-use Llama instance for *choice* ("light"/"medium"/"heavy").

    Downloads the GGUF file via hf_hub_download on first use, loads it
    CPU-only, and caches it. Only one model is kept in memory at a time.
    Returns None (callers must check) on an invalid choice or on any
    download/load failure.
    """
    if choice not in MODEL_MAP:
        logging.error(f"Invalid model choice: {choice}")
        return None

    # Fast path: requested model is already resident.
    if choice in llm_cache:
        logging.info(f"Using cached model: {choice}")
        return llm_cache[choice]

    model_info = MODEL_MAP[choice]
    repo_id = model_info["repo_id"]
    filename = model_info["filename"]

    try:
        # BUG FIX: these two log messages previously printed the literal
        # text "(unknown)" instead of the model filename.
        logging.info(f"Downloading model: {filename} from {repo_id}")
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        logging.info(f"Model downloaded to: {model_path}")

        logging.info("Loading model into memory...")
        llm = Llama(
            model_path=model_path,
            n_ctx=4096,       # max context window
            n_threads=2,      # sized for the free HF CPU tier
            n_gpu_layers=0,   # force CPU-only inference
            verbose=True
        )

        # Evict any previously loaded model before caching the new one —
        # the free tier cannot hold two 3B models in RAM at once.
        llm_cache.clear()
        llm_cache[choice] = llm
        logging.info(f"Model {choice} loaded successfully.")
        return llm

    except Exception as e:
        logging.critical(f"CRITICAL ERROR: Failed to download/load model {filename}. Error: {e}", exc_info=True)
        return None
88
+
89
# --- API Data Models ---
class StoryPrompt(BaseModel):
    """Request body for POST /generate."""
    # The passage the user just wrote or wants continued.
    prompt: str
    # Free-form guidance for the next chapter (may be an empty string).
    feedback: str
    # The accumulated story so far.
    story_memory: str
    # One of the MODEL_MAP keys: "light", "medium", or "heavy".
    model_choice: str
95
+
96
# --- API Endpoints ---

@app.on_event("startup")
async def startup_event():
    """
    Pre-warm the service when the Space boots: download and load the
    'light' model under the shared lock, so the first real request does
    not pay the full cold-start cost.
    """
    logging.info("Server starting up... Acquiring lock to pre-load model.")
    with model_lock:
        get_llm_instance("light")
    logging.info("Server is ready and 'light' model is loaded.")
108
+
109
+ @app.get("/")
110
+ def get_status():
111
+ """
112
+ Health check endpoint.
113
+ This is what your frontend pings.
114
+ """
115
+ loaded_model = list(llm_cache.keys())[0] if llm_cache else "None"
116
+ return {
117
+ "status": "AI server is online",
118
+ "model_loaded": loaded_model,
119
+ "models": list(MODEL_MAP.keys()) # <-- This is the CRUCIAL line for your frontend
120
+ }
121
+
122
+ @app.post("/generate")
123
+ async def generate_story(prompt: StoryPrompt):
124
+ """
125
+ Main generation endpoint.
126
+ Uses the thread lock to ensure stability.
127
+ """
128
+ logging.info("Request received. Waiting to acquire model lock...")
129
+ with model_lock:
130
+ logging.info("Lock acquired. Processing request.")
131
+ try:
132
+ llm = get_llm_instance(prompt.model_choice)
133
+ if llm is None:
134
+ logging.error(f"Failed to get model for choice: {prompt.model_choice}")
135
+ return JSONResponse(status_code=503, content={"error": "The AI model is not available or failed to load."})
136
+
137
+ # Format the prompt (Zephyr/ChatML format)
138
+ final_prompt = f"""<|user|>
139
+ Here is the story so far:
140
+ {prompt.story_memory}
141
+
142
+ Here is the part I just wrote or want to continue from:
143
+ {prompt.prompt}
144
+
145
+ Please use this feedback to guide the next chapter:
146
+ {prompt.feedback}
147
+
148
+ Generate the next part of the story.<|endoftext|>
149
+ <|assistant|>"""
150
+
151
+ logging.info(f"Generating with {prompt.model_choice}...")
152
+ output = llm(
153
+ final_prompt,
154
+ max_tokens=512,
155
+ stop=["<|user|>", "<|endoftext|>"],
156
+ echo=False
157
+ )
158
+
159
+ generated_text = output["choices"][0]["text"].strip()
160
+ logging.info("Generation complete.")
161
+
162
+ return {"story_text": generated_text}
163
+
164
+ except Exception as e:
165
+ logging.error(f"An internal error occurred during generation: {e}", exc_info=True)
166
+ return JSONResponse(status_code=500, content={"error": "An unexpected error occurred."})
167
+ finally:
168
+ logging.info("Releasing model lock.")