Tim Luka Horstmann committed on
Commit
6f6e59d
·
1 Parent(s): 58d2235

Updated to use history

Browse files
Files changed (1) hide show
  1. app.py +41 -41
app.py CHANGED
@@ -31,7 +31,6 @@ login(token=hf_token)
31
 
32
  # Models Configuration
33
  sentence_transformer_model = "all-MiniLM-L6-v2"
34
- # Using the 8B model with Q4_K_M quantization
35
  repo_id = "bartowski/deepcogito_cogito-v1-preview-llama-8B-GGUF"
36
  filename = "deepcogito_cogito-v1-preview-llama-8B-Q4_K_M.gguf"
37
 
@@ -68,7 +67,7 @@ try:
68
  faq_embeddings = embedder.encode(faq_questions, convert_to_numpy=True).astype("float32")
69
  faiss.normalize_L2(faq_embeddings)
70
 
71
- # Load the 8B Cogito model
72
  logger.info(f"Loading {filename} model")
73
  model_path = hf_hub_download(
74
  repo_id=repo_id,
@@ -76,13 +75,13 @@ try:
76
  local_dir="/app/cache" if os.getenv("HF_HOME") else None,
77
  token=hf_token,
78
  )
79
- # Use n_batch=256 for lower first-token latency on CPU
80
  generator = Llama(
81
  model_path=model_path,
82
- n_ctx=2048,
83
  n_threads=2,
84
- n_batch=256, # Reduced from 512 to improve streaming responsiveness
85
  n_gpu_layers=0,
 
86
  verbose=True,
87
  )
88
  logger.info(f"{filename} model loaded")
@@ -106,42 +105,42 @@ def retrieve_context(query, top_k=2):
106
  with open("cv_text.txt", "r", encoding="utf-8") as f:
107
  full_cv_text = f.read()
108
 
109
- async def stream_response(query):
110
  logger.info(f"Processing query: {query}")
111
  start_time = time.time()
112
  first_token_logged = False
113
 
114
  current_date = datetime.now().strftime("%Y-%m-%d")
115
 
116
- # FAQ check first (keep this as it's fast)
117
- # query_embedding = embedder.encode(query, convert_to_numpy=True).astype("float32")
118
- # query_embedding = query_embedding.reshape(1, -1)
119
- # faiss.normalize_L2(query_embedding)
120
- # similarities = np.dot(faq_embeddings, query_embedding.T).flatten()
121
- # max_sim = np.max(similarities)
122
- # if max_sim > 0.9:
123
- # idx = np.argmax(similarities)
124
- # yield f"data: {faqs[idx]['answer']}\n\n"
125
- # yield "data: [DONE]\n\n"
126
- # return
127
-
128
- # Use full CV instead of retrieved chunks
129
- messages = [
130
- {
131
- "role": "system",
132
- "content": (
133
- "You are Tim Luka Horstmann, a Computer Scientist. A user is asking you a question. Respond as yourself, using the first person, in a friendly and concise manner. "
134
- "For questions about your CV, base your answer *exclusively* on the provided CV information below and do not add any details not explicitly stated. "
135
- "For casual questions not covered by the CV, respond naturally but limit answers to general truths about yourself (e.g., your current location is Paris, France, or your field is AI) "
136
- "and say 'I don't have specific details to share about that' if pressed for specifics beyond the CV or FAQs. Do not invent facts, experiences, or opinions not supported by the CV or FAQs. "
137
- f"Today’s date is {current_date}. "
138
- f"CV: {full_cv_text}"
139
- )
140
- },
141
- {"role": "user", "content": query}
142
- ]
143
-
144
- # Acquire lock to ensure exclusive model access
145
  async with model_lock:
146
  for chunk in generator.create_chat_completion(
147
  messages=messages,
@@ -160,14 +159,14 @@ async def stream_response(query):
160
  yield "data: [DONE]\n\n"
161
 
162
  class QueryRequest(BaseModel):
163
- data: list
 
164
 
165
  @app.post("/api/predict")
166
  async def predict(request: QueryRequest):
167
- if not request.data or not isinstance(request.data, list) or len(request.data) < 1:
168
- raise HTTPException(status_code=400, detail="Invalid input: 'data' must be a non-empty list")
169
- query = request.data[0]
170
- return StreamingResponse(stream_response(query), media_type="text/event-stream")
171
 
172
  @app.get("/health")
173
  async def health_check():
@@ -188,6 +187,7 @@ async def model_info():
188
  async def warm_up_model():
189
  logger.info("Warming up the model...")
190
  dummy_query = "Hello"
191
- async for _ in stream_response(dummy_query):
 
192
  pass
193
  logger.info("Model warm-up completed.")
 
31
 
32
  # Models Configuration
33
  sentence_transformer_model = "all-MiniLM-L6-v2"
 
34
  repo_id = "bartowski/deepcogito_cogito-v1-preview-llama-8B-GGUF"
35
  filename = "deepcogito_cogito-v1-preview-llama-8B-Q4_K_M.gguf"
36
 
 
67
  faq_embeddings = embedder.encode(faq_questions, convert_to_numpy=True).astype("float32")
68
  faiss.normalize_L2(faq_embeddings)
69
 
70
+ # Load the 8B Cogito model with optimized parameters
71
  logger.info(f"Loading {filename} model")
72
  model_path = hf_hub_download(
73
  repo_id=repo_id,
 
75
  local_dir="/app/cache" if os.getenv("HF_HOME") else None,
76
  token=hf_token,
77
  )
 
78
  generator = Llama(
79
  model_path=model_path,
80
+ n_ctx=3072,
81
  n_threads=2,
82
+ n_batch=128,
83
  n_gpu_layers=0,
84
+ f16_kv=True,
85
  verbose=True,
86
  )
87
  logger.info(f"{filename} model loaded")
 
105
  with open("cv_text.txt", "r", encoding="utf-8") as f:
106
  full_cv_text = f.read()
107
 
108
+ async def stream_response(query, history):
109
  logger.info(f"Processing query: {query}")
110
  start_time = time.time()
111
  first_token_logged = False
112
 
113
  current_date = datetime.now().strftime("%Y-%m-%d")
114
 
115
+ system_prompt = (
116
+ "You are Tim Luka Horstmann, a Computer Scientist. A user is asking you a question. Respond as yourself, using the first person, in a friendly and concise manner. "
117
+ "For questions about your CV, base your answer *exclusively* on the provided CV information below and do not add any details not explicitly stated. "
118
+ "For casual questions not covered by the CV, respond naturally but limit answers to general truths about yourself (e.g., your current location is Paris, France, or your field is AI) "
119
+ "and say 'I don't have specific details to share about that' if pressed for specifics beyond the CV or FAQs. Do not invent facts, experiences, or opinions not supported by the CV or FAQs. "
120
+ f"Today’s date is {current_date}. "
121
+ f"CV: {full_cv_text}"
122
+ )
123
+
124
+ # Combine system prompt, history, and current query
125
+ messages = [{"role": "system", "content": system_prompt}] + history + [{"role": "user", "content": query}]
126
+
127
+ # Estimate token counts and truncate history if necessary
128
+ system_tokens = len(generator.tokenize(system_prompt))
129
+ query_tokens = len(generator.tokenize(query))
130
+ history_tokens = [len(generator.tokenize(msg["content"])) for msg in history]
131
+ total_tokens = system_tokens + query_tokens + sum(history_tokens) + len(history) * 10 + 10 # Rough estimate for formatting
132
+
133
+ max_allowed_tokens = generator.n_ctx - 512 - 100 # max_tokens=512, safety_margin=100
134
+
135
+ while total_tokens > max_allowed_tokens and history:
136
+ removed_msg = history.pop(0)
137
+ removed_tokens = len(generator.tokenize(removed_msg["content"]))
138
+ total_tokens -= (removed_tokens + 10)
139
+
140
+ # Reconstruct messages after possible truncation
141
+ messages = [{"role": "system", "content": system_prompt}] + history + [{"role": "user", "content": query}]
142
+
143
+ # Generate response with lock
144
  async with model_lock:
145
  for chunk in generator.create_chat_completion(
146
  messages=messages,
 
159
  yield "data: [DONE]\n\n"
160
 
161
  class QueryRequest(BaseModel):
162
+ query: str
163
+ history: list[dict]
164
 
165
  @app.post("/api/predict")
166
  async def predict(request: QueryRequest):
167
+ query = request.query
168
+ history = request.history
169
+ return StreamingResponse(stream_response(query, history), media_type="text/event-stream")
 
170
 
171
  @app.get("/health")
172
  async def health_check():
 
187
  async def warm_up_model():
188
  logger.info("Warming up the model...")
189
  dummy_query = "Hello"
190
+ dummy_history = []
191
+ async for _ in stream_response(dummy_query, dummy_history):
192
  pass
193
  logger.info("Model warm-up completed.")