fugthchat committed on
Commit
c8b68fa
·
verified ·
1 Parent(s): cd32c82

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -70
app.py CHANGED
@@ -1,102 +1,160 @@
1
- from fastapi import FastAPI, Request
 
 
2
  from fastapi.responses import JSONResponse
3
  from pydantic import BaseModel
4
  from llama_cpp import Llama
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from huggingface_hub import hf_hub_download
7
  import logging
8
- import re
9
  import threading
10
 
11
- # Set up logging to get more detailed output
12
  logging.basicConfig(level=logging.INFO)
 
 
 
13
 
14
- # --- STABLE MODEL CONFIGURATION ---
15
- MODEL_REPO_ID = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
16
- MODEL_FILENAME = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
17
-
18
- # --- GLOBAL MODEL OBJECT & THREAD LOCK ---
19
- llm = None
20
- model_lock = threading.Lock()
21
-
22
- # --- SERVER STARTUP LOGIC ---
23
- logging.info("Server starting...")
24
-
25
- try:
26
- logging.info(f"Downloading single model: {MODEL_FILENAME} from {MODEL_REPO_ID}")
27
- model_path = hf_hub_download(repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
28
- logging.info(f"Model downloaded to: {model_path}")
29
-
30
- logging.info("Loading model from local path with optimized settings...")
31
- llm = Llama(
32
- model_path=model_path,
33
- n_ctx=1024,
34
- n_threads=2,
35
- n_gpu_layers=0,
36
- verbose=True
37
- )
38
- logging.info("Model loaded successfully! AI server is ready.")
39
- except Exception as e:
40
- logging.critical(f"CRITICAL ERROR: Failed to load the model. Server will be non-functional. Error: {e}", exc_info=True)
41
 
42
- # --- FASTAPI APP SETUP ---
43
- app = FastAPI()
44
  app.add_middleware(
45
  CORSMiddleware,
46
- allow_origins=["*"],
47
  allow_credentials=True,
48
  allow_methods=["*"],
49
  allow_headers=["*"],
50
  )
51
 
52
- # --- API ENDPOINTS ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  @app.get("/")
54
  def get_status():
55
- """Endpoint to check if the server and model are online."""
 
56
  return {
57
  "status": "AI server is online",
58
- "model_loaded": llm is not None
59
  }
60
 
61
- @app.post("/chat")
62
- async def chat_endpoint(request: Request): # Changed to use the raw Request object
63
- """
64
- Main chat endpoint. This version manually parses the JSON body to
65
- bypass the Pydantic 422 validation error.
66
- """
67
  with model_lock:
68
- if not llm:
69
- logging.error("Chat request received but model is not loaded.")
70
- return JSONResponse(status_code=503, content={"response": "The AI model is not available. Please contact support."})
71
-
72
  try:
73
- # Manually parse the JSON from the request body
74
- data = await request.json()
75
- prompt = data.get("prompt")
76
- quality = data.get("quality", "lite")
 
 
 
 
 
 
 
 
77
 
78
- if not prompt:
79
- return JSONResponse(status_code=400, content={"response": "Error: No prompt was provided in the request."})
80
 
81
- if quality == "high":
82
- max_tokens = 512
83
- logging.info(f"Handling request with HIGH quality setting (max_tokens={max_tokens}).")
84
- else:
85
- max_tokens = 200
86
- logging.info(f"Handling request with LITE quality setting (max_tokens={max_tokens}).")
87
 
88
- output = llm.create_completion(
89
- prompt=prompt,
90
- max_tokens=max_tokens,
91
- temperature=0.7,
92
- stop=["</s>", "<|user|>", "<|system|>"],
93
- stream=False
 
94
  )
95
 
96
- response_text = output['choices'][0]['text'].strip()
97
- logging.info("Successfully generated response.")
98
- return {"response": response_text}
99
-
 
 
100
  except Exception as e:
101
- logging.error(f"An internal error occurred during chat completion: {e}", exc_info=True)
102
- return JSONResponse(status_code=500, content={"response": "An unexpected error occurred while processing your request."})
 
1
+ import os
2
+ import uvicorn
3
+ from fastapi import FastAPI
4
  from fastapi.responses import JSONResponse
5
  from pydantic import BaseModel
6
  from llama_cpp import Llama
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from huggingface_hub import hf_hub_download
9
  import logging
 
10
  import threading
11
 
12
# --- Setup ---
logging.basicConfig(level=logging.INFO)  # verbose logs help debug Space startup
app = FastAPI()
# Serializes model loading and inference: a single lock guards both the
# llm_cache and the Llama instance, so only one request runs at a time.
model_lock = threading.Lock()  # From your old app, this is great for stability
llm_cache = {}  # choice name -> loaded Llama instance; at most one entry is kept
17
 
18
# --- Model Map (With CORRECT URLs) ---
# Your frontend can request "light", "medium", or "heavy".
# All three tiers point at the same base repo; they differ only in GGUF
# quantization, trading download size / RAM for output quality.
MODEL_MAP = {
    "light": {
        "repo_id": "TheBloke/stablelm-zephyr-3b-GGUF",
        "filename": "stablelm-zephyr-3b.Q3_K_S.gguf"  # 1.25 GB
    },
    "medium": {
        "repo_id": "TheBloke/stablelm-zephyr-3b-GGUF",
        "filename": "stablelm-zephyr-3b.Q4_K_M.gguf"  # 1.71 GB
    },
    "heavy": {
        "repo_id": "TheBloke/stablelm-zephyr-3b-GGUF",
        "filename": "stablelm-zephyr-3b.Q5_K_M.gguf"  # 2.03 GB
    }
}
 
 
 
 
 
 
 
 
 
 
 
34
 
35
# --- CORS ---
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# disallowed by the CORS spec; Starlette works around it by echoing the
# request Origin. Consider pinning the frontend origin(s) — TODO confirm.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow your GitHub Pages frontend
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
43
 
44
# --- Model Loading Logic ---
def get_llm_instance(choice: str) -> Llama:
    """
    Return a loaded, cached Llama instance for *choice*, or None.

    Looks *choice* up in MODEL_MAP; on first use, downloads the GGUF file
    and loads it, evicting any previously cached model so only one model
    occupies RAM at a time. Returns None for an unknown choice or if the
    download/load fails. Callers must hold model_lock — this function does
    no locking of its own.
    """
    if choice not in MODEL_MAP:
        logging.error(f"Invalid model choice: {choice}")
        return None

    # Fast path: the requested model is already resident in memory.
    if choice in llm_cache:
        logging.info(f"Using cached model: {choice}")
        return llm_cache[choice]

    # Not cached: resolve the repo/file and load it.
    model_info = MODEL_MAP[choice]
    repo_id = model_info["repo_id"]
    filename = model_info["filename"]

    try:
        # BUG FIX: these messages previously logged the literal text
        # "(unknown)" instead of the actual file name, making failed
        # downloads impossible to diagnose from the logs.
        logging.info(f"Downloading model: {filename} from {repo_id}...")
        # Use hf_hub_download (from your old app)
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        logging.info(f"Model downloaded to: {model_path}")

        logging.info("Loading model into memory...")
        llm = Llama(
            model_path=model_path,
            n_ctx=4096,       # Max context
            n_threads=2,      # Free HF CPU has 2 cores
            n_gpu_layers=0,   # Force CPU
            verbose=True
        )

        # Keep at most one model resident: evict before caching the new one.
        llm_cache.clear()
        llm_cache[choice] = llm
        logging.info(f"Model {choice} loaded successfully.")
        return llm

    except Exception as e:
        logging.critical(f"Failed to download/load model {filename}. Error: {e}", exc_info=True)
        return None
87
+
88
# --- API Request Model ---
class StoryPrompt(BaseModel):
    """Request body for POST /generate. All four fields are required strings."""
    prompt: str        # the user's new instruction / story fragment
    feedback: str      # feedback the model should apply to its next output
    story_memory: str  # the story so far, replayed to the model as context
    model_choice: str  # a MODEL_MAP key: "light", "medium", or "heavy"
94
+
95
# --- App Startup Event ---
# NOTE(review): @app.on_event is deprecated in recent FastAPI versions in
# favor of lifespan handlers — worth migrating when convenient.
@app.on_event("startup")
async def startup_event():
    """
    On startup, we acquire the lock and pre-load the default 'light' model.
    This is what runs *after* the build.
    """
    logging.info("Server starting... Acquiring lock to pre-load 'light' model.")
    # Holding the lock makes any early /generate request wait until the
    # default model has finished downloading and loading.
    with model_lock:
        get_llm_instance("light")
    logging.info("Server is ready and 'light' model is loaded.")
106
+
107
# --- API Endpoints ---
@app.get("/")
def get_status():
    """Health check endpoint: reports which model (if any) is resident."""
    # llm_cache holds at most one entry; fall back to "None" when empty.
    loaded_model = next(iter(llm_cache), "None")
    return {
        "status": "AI server is online",
        "model_loaded": loaded_model
    }
116
 
117
@app.post("/generate")
async def generate_story(prompt: StoryPrompt):
    """
    Main generation endpoint. It's thread-safe.

    Serializes all work behind model_lock, fetches (and lazily loads) the
    model named by prompt.model_choice, builds a Zephyr-format prompt from
    the story memory / instruction / feedback, and returns the completion
    as {"story_text": ...}. Failures surface as 503 (model unavailable)
    or 500 (generation error) JSON responses.
    """
    logging.info("Request received. Waiting for model lock...")
    with model_lock:
        logging.info("Lock acquired. Processing.")
        try:
            # 1. Get the correct LLM (load if needed)
            llm = get_llm_instance(prompt.model_choice)
            if llm is None:
                return JSONResponse(status_code=503, content={"error": "Model failed to load."})

            # 2. Format the prompt (Zephyr/ChatML format).
            # The embedded lines are deliberately flush-left: triple-quoted
            # content is sent to the model verbatim, indentation included.
            final_prompt = f"""<|user|>
Story so far:
{prompt.story_memory}

My new part/instruction:
{prompt.prompt}

Feedback to apply:
{prompt.feedback}

Generate the next part of the story.<|endoftext|>
<|assistant|>"""

            # 3. Generate (blocking CPU inference; the lock is held throughout)
            logging.info(f"Generating with {prompt.model_choice}...")
            output = llm(
                final_prompt,
                max_tokens=512,
                stop=["<|user|>", "<|endoftext|>"],
                echo=False
            )

            generated_text = output["choices"][0]["text"].strip()
            logging.info("Generation complete.")

            # This matches the key your frontend expects
            return {"story_text": generated_text}

        except Exception as e:
            logging.error(f"Generation error: {e}", exc_info=True)
            return JSONResponse(status_code=500, content={"error": "An unexpected error occurred."})