Spaces:
Sleeping
Sleeping
Patryk Studzinski committed on
Commit ·
db4996d
1
Parent(s): d9b1571
increase context size and improve message handling in LlamaCppModel
Browse files
- app/models/llama_cpp_model.py +16 -28
app/models/llama_cpp_model.py
CHANGED
|
@@ -21,7 +21,7 @@ class LlamaCppModel(BaseLLM):
|
|
| 21 |
Provides significant speedups on CPU compared to Transformers.
|
| 22 |
"""
|
| 23 |
|
| 24 |
-
def __init__(self, name: str, model_id: str, model_path: str = None, n_ctx: int =
|
| 25 |
super().__init__(name, model_id)
|
| 26 |
self.model_path = model_path
|
| 27 |
self.n_ctx = n_ctx
|
|
@@ -55,7 +55,7 @@ class LlamaCppModel(BaseLLM):
|
|
| 55 |
)
|
| 56 |
|
| 57 |
self._initialized = True
|
| 58 |
-
print(f"[{self.name}] GGUF Model loaded successfully")
|
| 59 |
|
| 60 |
except Exception as e:
|
| 61 |
print(f"[{self.name}] Failed to load GGUF model: {e}")
|
|
@@ -75,43 +75,31 @@ class LlamaCppModel(BaseLLM):
|
|
| 75 |
if not self._initialized or self.llm is None:
|
| 76 |
raise RuntimeError(f"[{self.name}] Model not initialized")
|
| 77 |
|
| 78 |
-
#
|
| 79 |
-
|
| 80 |
-
if
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
role = msg.get("role", "user")
|
| 85 |
-
content = msg.get("content", "")
|
| 86 |
-
if role == "system":
|
| 87 |
-
formatted_prompt += f"{content}\n\n"
|
| 88 |
-
elif role == "user":
|
| 89 |
-
formatted_prompt += f"User: {content}\n"
|
| 90 |
-
elif role == "assistant":
|
| 91 |
-
formatted_prompt += f"Assistant: {content}\n"
|
| 92 |
-
formatted_prompt += "Assistant:"
|
| 93 |
-
elif prompt:
|
| 94 |
-
formatted_prompt = prompt
|
| 95 |
-
else:
|
| 96 |
raise ValueError("Either prompt or chat_messages required")
|
| 97 |
|
| 98 |
-
# Cache Check
|
| 99 |
-
|
|
|
|
| 100 |
if cache_key in self._response_cache:
|
| 101 |
return self._response_cache[cache_key]
|
| 102 |
|
| 103 |
-
# Generate
|
| 104 |
output = await asyncio.to_thread(
|
| 105 |
-
self.llm.
|
| 106 |
-
|
| 107 |
max_tokens=max_new_tokens,
|
| 108 |
temperature=temperature,
|
| 109 |
top_p=top_p,
|
| 110 |
-
stop
|
| 111 |
-
echo=False
|
| 112 |
)
|
| 113 |
|
| 114 |
-
response_text = output['choices'][0]['
|
| 115 |
|
| 116 |
# Cache Store
|
| 117 |
if len(self._response_cache) >= self._max_cache_size:
|
|
|
|
| 21 |
Provides significant speedups on CPU compared to Transformers.
|
| 22 |
"""
|
| 23 |
|
| 24 |
+
def __init__(self, name: str, model_id: str, model_path: str = None, n_ctx: int = 8192):
|
| 25 |
super().__init__(name, model_id)
|
| 26 |
self.model_path = model_path
|
| 27 |
self.n_ctx = n_ctx
|
|
|
|
| 55 |
)
|
| 56 |
|
| 57 |
self._initialized = True
|
| 58 |
+
print(f"[{self.name}] GGUF Model loaded successfully (n_ctx={self.n_ctx})")
|
| 59 |
|
| 60 |
except Exception as e:
|
| 61 |
print(f"[{self.name}] Failed to load GGUF model: {e}")
|
|
|
|
| 75 |
if not self._initialized or self.llm is None:
|
| 76 |
raise RuntimeError(f"[{self.name}] Model not initialized")
|
| 77 |
|
| 78 |
+
# Ensure we have a list of messages
|
| 79 |
+
messages = chat_messages
|
| 80 |
+
if not messages and prompt:
|
| 81 |
+
messages = [{"role": "user", "content": prompt}]
|
| 82 |
+
|
| 83 |
+
if not messages:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
raise ValueError("Either prompt or chat_messages required")
|
| 85 |
|
| 86 |
+
# Cache Check - using stringified messages for the key
|
| 87 |
+
import json
|
| 88 |
+
cache_key = f"{json.dumps(messages)}_{max_new_tokens}_{temperature}_{top_p}"
|
| 89 |
if cache_key in self._response_cache:
|
| 90 |
return self._response_cache[cache_key]
|
| 91 |
|
| 92 |
+
# Generate using chat completion to leverage internal templates
|
| 93 |
output = await asyncio.to_thread(
|
| 94 |
+
self.llm.create_chat_completion,
|
| 95 |
+
messages=messages,
|
| 96 |
max_tokens=max_new_tokens,
|
| 97 |
temperature=temperature,
|
| 98 |
top_p=top_p,
|
| 99 |
+
# No manual stop tokens needed usually as template handles them
|
|
|
|
| 100 |
)
|
| 101 |
|
| 102 |
+
response_text = output['choices'][0]['message']['content'].strip()
|
| 103 |
|
| 104 |
# Cache Store
|
| 105 |
if len(self._response_cache) >= self._max_cache_size:
|