Tim Luka Horstmann
committed on
Commit
·
6f6e59d
1
Parent(s):
58d2235
Updated to use history
Browse files
app.py
CHANGED
|
@@ -31,7 +31,6 @@ login(token=hf_token)
|
|
| 31 |
|
| 32 |
# Models Configuration
|
| 33 |
sentence_transformer_model = "all-MiniLM-L6-v2"
|
| 34 |
-
# Using the 8B model with Q4_K_M quantization
|
| 35 |
repo_id = "bartowski/deepcogito_cogito-v1-preview-llama-8B-GGUF"
|
| 36 |
filename = "deepcogito_cogito-v1-preview-llama-8B-Q4_K_M.gguf"
|
| 37 |
|
|
@@ -68,7 +67,7 @@ try:
|
|
| 68 |
faq_embeddings = embedder.encode(faq_questions, convert_to_numpy=True).astype("float32")
|
| 69 |
faiss.normalize_L2(faq_embeddings)
|
| 70 |
|
| 71 |
-
# Load the 8B Cogito model
|
| 72 |
logger.info(f"Loading {filename} model")
|
| 73 |
model_path = hf_hub_download(
|
| 74 |
repo_id=repo_id,
|
|
@@ -76,13 +75,13 @@ try:
|
|
| 76 |
local_dir="/app/cache" if os.getenv("HF_HOME") else None,
|
| 77 |
token=hf_token,
|
| 78 |
)
|
| 79 |
-
# Use n_batch=256 for lower first-token latency on CPU
|
| 80 |
generator = Llama(
|
| 81 |
model_path=model_path,
|
| 82 |
-
n_ctx=
|
| 83 |
n_threads=2,
|
| 84 |
-
n_batch=
|
| 85 |
n_gpu_layers=0,
|
|
|
|
| 86 |
verbose=True,
|
| 87 |
)
|
| 88 |
logger.info(f"{filename} model loaded")
|
|
@@ -106,42 +105,42 @@ def retrieve_context(query, top_k=2):
|
|
| 106 |
with open("cv_text.txt", "r", encoding="utf-8") as f:
|
| 107 |
full_cv_text = f.read()
|
| 108 |
|
| 109 |
-
async def stream_response(query):
|
| 110 |
logger.info(f"Processing query: {query}")
|
| 111 |
start_time = time.time()
|
| 112 |
first_token_logged = False
|
| 113 |
|
| 114 |
current_date = datetime.now().strftime("%Y-%m-%d")
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
#
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
#
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
]
|
| 143 |
-
|
| 144 |
-
#
|
| 145 |
async with model_lock:
|
| 146 |
for chunk in generator.create_chat_completion(
|
| 147 |
messages=messages,
|
|
@@ -160,14 +159,14 @@ async def stream_response(query):
|
|
| 160 |
yield "data: [DONE]\n\n"
|
| 161 |
|
| 162 |
class QueryRequest(BaseModel):
|
| 163 |
-
|
|
|
|
| 164 |
|
| 165 |
@app.post("/api/predict")
|
| 166 |
async def predict(request: QueryRequest):
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
query =
|
| 170 |
-
return StreamingResponse(stream_response(query), media_type="text/event-stream")
|
| 171 |
|
| 172 |
@app.get("/health")
|
| 173 |
async def health_check():
|
|
@@ -188,6 +187,7 @@ async def model_info():
|
|
| 188 |
async def warm_up_model():
|
| 189 |
logger.info("Warming up the model...")
|
| 190 |
dummy_query = "Hello"
|
| 191 |
-
|
|
|
|
| 192 |
pass
|
| 193 |
logger.info("Model warm-up completed.")
|
|
|
|
| 31 |
|
| 32 |
# Models Configuration
|
| 33 |
sentence_transformer_model = "all-MiniLM-L6-v2"
|
|
|
|
| 34 |
repo_id = "bartowski/deepcogito_cogito-v1-preview-llama-8B-GGUF"
|
| 35 |
filename = "deepcogito_cogito-v1-preview-llama-8B-Q4_K_M.gguf"
|
| 36 |
|
|
|
|
| 67 |
faq_embeddings = embedder.encode(faq_questions, convert_to_numpy=True).astype("float32")
|
| 68 |
faiss.normalize_L2(faq_embeddings)
|
| 69 |
|
| 70 |
+
# Load the 8B Cogito model with optimized parameters
|
| 71 |
logger.info(f"Loading {filename} model")
|
| 72 |
model_path = hf_hub_download(
|
| 73 |
repo_id=repo_id,
|
|
|
|
| 75 |
local_dir="/app/cache" if os.getenv("HF_HOME") else None,
|
| 76 |
token=hf_token,
|
| 77 |
)
|
|
|
|
| 78 |
generator = Llama(
|
| 79 |
model_path=model_path,
|
| 80 |
+
n_ctx=3072,
|
| 81 |
n_threads=2,
|
| 82 |
+
n_batch=128,
|
| 83 |
n_gpu_layers=0,
|
| 84 |
+
f16_kv=True,
|
| 85 |
verbose=True,
|
| 86 |
)
|
| 87 |
logger.info(f"{filename} model loaded")
|
|
|
|
| 105 |
with open("cv_text.txt", "r", encoding="utf-8") as f:
|
| 106 |
full_cv_text = f.read()
|
| 107 |
|
| 108 |
+
async def stream_response(query, history):
|
| 109 |
logger.info(f"Processing query: {query}")
|
| 110 |
start_time = time.time()
|
| 111 |
first_token_logged = False
|
| 112 |
|
| 113 |
current_date = datetime.now().strftime("%Y-%m-%d")
|
| 114 |
|
| 115 |
+
system_prompt = (
|
| 116 |
+
"You are Tim Luka Horstmann, a Computer Scientist. A user is asking you a question. Respond as yourself, using the first person, in a friendly and concise manner. "
|
| 117 |
+
"For questions about your CV, base your answer *exclusively* on the provided CV information below and do not add any details not explicitly stated. "
|
| 118 |
+
"For casual questions not covered by the CV, respond naturally but limit answers to general truths about yourself (e.g., your current location is Paris, France, or your field is AI) "
|
| 119 |
+
"and say 'I don't have specific details to share about that' if pressed for specifics beyond the CV or FAQs. Do not invent facts, experiences, or opinions not supported by the CV or FAQs. "
|
| 120 |
+
f"Today’s date is {current_date}. "
|
| 121 |
+
f"CV: {full_cv_text}"
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
# Combine system prompt, history, and current query
|
| 125 |
+
messages = [{"role": "system", "content": system_prompt}] + history + [{"role": "user", "content": query}]
|
| 126 |
+
|
| 127 |
+
# Estimate token counts and truncate history if necessary
|
| 128 |
+
system_tokens = len(generator.tokenize(system_prompt))
|
| 129 |
+
query_tokens = len(generator.tokenize(query))
|
| 130 |
+
history_tokens = [len(generator.tokenize(msg["content"])) for msg in history]
|
| 131 |
+
total_tokens = system_tokens + query_tokens + sum(history_tokens) + len(history) * 10 + 10 # Rough estimate for formatting
|
| 132 |
+
|
| 133 |
+
max_allowed_tokens = generator.n_ctx - 512 - 100 # max_tokens=512, safety_margin=100
|
| 134 |
+
|
| 135 |
+
while total_tokens > max_allowed_tokens and history:
|
| 136 |
+
removed_msg = history.pop(0)
|
| 137 |
+
removed_tokens = len(generator.tokenize(removed_msg["content"]))
|
| 138 |
+
total_tokens -= (removed_tokens + 10)
|
| 139 |
+
|
| 140 |
+
# Reconstruct messages after possible truncation
|
| 141 |
+
messages = [{"role": "system", "content": system_prompt}] + history + [{"role": "user", "content": query}]
|
| 142 |
+
|
| 143 |
+
# Generate response with lock
|
| 144 |
async with model_lock:
|
| 145 |
for chunk in generator.create_chat_completion(
|
| 146 |
messages=messages,
|
|
|
|
| 159 |
yield "data: [DONE]\n\n"
|
| 160 |
|
| 161 |
class QueryRequest(BaseModel):
    """Request body accepted by ``POST /api/predict``.

    Attributes are documented inline. ``history`` defaults to an empty
    conversation so clients that only send ``query`` keep working.
    """

    # The user's current question.
    query: str
    # Earlier chat messages, oldest first. stream_response reads each
    # entry's "content" key and places the entries between the system
    # prompt and the current query; presumably each dict also carries a
    # "role" key in chat-completion format -- confirm against the client.
    # Pydantic copies this default per instance, so the list is not shared.
    history: list[dict] = []
|
| 164 |
|
| 165 |
@app.post("/api/predict")
async def predict(request: QueryRequest):
    """Stream the answer for ``request.query`` as a ``text/event-stream`` response.

    The query and the accompanying chat history are handed unchanged to
    ``stream_response``, whose output is forwarded to the client as it is
    produced.
    """
    event_stream = stream_response(request.query, request.history)
    return StreamingResponse(event_stream, media_type="text/event-stream")
|
|
|
|
| 170 |
|
| 171 |
@app.get("/health")
|
| 172 |
async def health_check():
|
|
|
|
| 187 |
async def warm_up_model():
    """Run one throwaway generation through ``stream_response``.

    Sends the query "Hello" with an empty history and discards every
    streamed chunk; only the start and end of the run are logged.
    """
    logger.info("Warming up the model...")
    # Drain the async generator completely; the output itself is not needed.
    async for _chunk in stream_response("Hello", []):
        pass
    logger.info("Model warm-up completed.")
|