Update app.py
Browse files
app.py
CHANGED
|
@@ -19,7 +19,7 @@ from llama_cpp_agent.tools import WebSearchTool
|
|
| 19 |
from llama_cpp_agent.prompt_templates import web_search_system_prompt, research_system_prompt
|
| 20 |
from langchain_community.llms import HuggingFaceHub
|
| 21 |
from llama_cpp_agent.llm_output_settings import LlmStructuredOutputSettings, LlmStructuredOutputType
|
| 22 |
-
from pydantic import BaseModel
|
| 23 |
from llama_cpp_agent.llm_output_settings import LlmStructuredOutputType
|
| 24 |
|
| 25 |
print("Available LlmStructuredOutputType options:")
|
|
@@ -103,10 +103,11 @@ examples = [
|
|
| 103 |
|
| 104 |
class CustomLLMSettings(BaseModel):
|
| 105 |
structured_output: LlmStructuredOutputSettings
|
| 106 |
-
temperature: float
|
| 107 |
-
top_p: float
|
| 108 |
-
repetition_penalty: float
|
| 109 |
-
|
|
|
|
| 110 |
class HuggingFaceHubWrapper:
|
| 111 |
def __init__(self, repo_id, model_kwargs, huggingfacehub_api_token):
|
| 112 |
self.model = HuggingFaceHub(
|
|
@@ -117,7 +118,7 @@ class HuggingFaceHubWrapper:
|
|
| 117 |
self.temperature = model_kwargs.get('temperature', 0.7)
|
| 118 |
self.top_p = model_kwargs.get('top_p', 0.95)
|
| 119 |
self.repetition_penalty = model_kwargs.get('repetition_penalty', 1.1)
|
| 120 |
-
|
| 121 |
|
| 122 |
def get_provider_default_settings(self):
|
| 123 |
return CustomLLMSettings(
|
|
@@ -129,7 +130,8 @@ class HuggingFaceHubWrapper:
|
|
| 129 |
),
|
| 130 |
temperature=self.temperature,
|
| 131 |
top_p=self.top_p,
|
| 132 |
-
repetition_penalty=self.repetition_penalty
|
|
|
|
| 133 |
)
|
| 134 |
|
| 135 |
def get_provider_identifier(self):
|
|
@@ -170,13 +172,14 @@ class CitingSources(BaseModel):
|
|
| 170 |
)
|
| 171 |
|
| 172 |
# Model function
|
| 173 |
-
def get_model(temperature, top_p, repetition_penalty):
|
| 174 |
return HuggingFaceHubWrapper(
|
| 175 |
repo_id="mistralai/Mistral-7B-Instruct-v0.3",
|
| 176 |
model_kwargs={
|
| 177 |
"temperature": temperature,
|
| 178 |
"top_p": top_p,
|
| 179 |
"repetition_penalty": repetition_penalty,
|
|
|
|
| 180 |
"max_length": 1000
|
| 181 |
},
|
| 182 |
huggingfacehub_api_token=huggingface_token
|
|
@@ -204,8 +207,9 @@ def respond(
|
|
| 204 |
temperature,
|
| 205 |
top_p,
|
| 206 |
repeat_penalty,
|
|
|
|
| 207 |
):
|
| 208 |
-
model = get_model(temperature, top_p, repeat_penalty)
|
| 209 |
|
| 210 |
chat_template = MessagesFormatterType.MISTRAL
|
| 211 |
|
|
@@ -258,6 +262,7 @@ demo = gr.ChatInterface(
|
|
| 258 |
gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
|
| 259 |
gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
|
| 260 |
gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty"),
|
|
|
|
| 261 |
],
|
| 262 |
theme=gr.themes.Soft(
|
| 263 |
primary_hue="orange",
|
|
|
|
| 19 |
from llama_cpp_agent.prompt_templates import web_search_system_prompt, research_system_prompt
|
| 20 |
from langchain_community.llms import HuggingFaceHub
|
| 21 |
from llama_cpp_agent.llm_output_settings import LlmStructuredOutputSettings, LlmStructuredOutputType
|
| 22 |
+
from pydantic import BaseModel, Field
|
| 23 |
from llama_cpp_agent.llm_output_settings import LlmStructuredOutputType
|
| 24 |
|
| 25 |
print("Available LlmStructuredOutputType options:")
|
|
|
|
| 103 |
|
| 104 |
class CustomLLMSettings(BaseModel):
|
| 105 |
structured_output: LlmStructuredOutputSettings
|
| 106 |
+
temperature: float = Field(default=0.7)
|
| 107 |
+
top_p: float = Field(default=0.95)
|
| 108 |
+
repetition_penalty: float = Field(default=1.1)
|
| 109 |
+
top_k: int = Field(default=50) # Added top_k parameter
|
| 110 |
+
|
| 111 |
class HuggingFaceHubWrapper:
|
| 112 |
def __init__(self, repo_id, model_kwargs, huggingfacehub_api_token):
|
| 113 |
self.model = HuggingFaceHub(
|
|
|
|
| 118 |
self.temperature = model_kwargs.get('temperature', 0.7)
|
| 119 |
self.top_p = model_kwargs.get('top_p', 0.95)
|
| 120 |
self.repetition_penalty = model_kwargs.get('repetition_penalty', 1.1)
|
| 121 |
+
self.top_k = model_kwargs.get('top_k', 50) # Added top_k
|
| 122 |
|
| 123 |
def get_provider_default_settings(self):
|
| 124 |
return CustomLLMSettings(
|
|
|
|
| 130 |
),
|
| 131 |
temperature=self.temperature,
|
| 132 |
top_p=self.top_p,
|
| 133 |
+
repetition_penalty=self.repetition_penalty,
|
| 134 |
+
top_k=self.top_k # Added top_k
|
| 135 |
)
|
| 136 |
|
| 137 |
def get_provider_identifier(self):
|
|
|
|
| 172 |
)
|
| 173 |
|
| 174 |
# Model function
|
| 175 |
+
def get_model(temperature, top_p, repetition_penalty, top_k=50):
|
| 176 |
return HuggingFaceHubWrapper(
|
| 177 |
repo_id="mistralai/Mistral-7B-Instruct-v0.3",
|
| 178 |
model_kwargs={
|
| 179 |
"temperature": temperature,
|
| 180 |
"top_p": top_p,
|
| 181 |
"repetition_penalty": repetition_penalty,
|
| 182 |
+
"top_k": top_k,
|
| 183 |
"max_length": 1000
|
| 184 |
},
|
| 185 |
huggingfacehub_api_token=huggingface_token
|
|
|
|
| 207 |
temperature,
|
| 208 |
top_p,
|
| 209 |
repeat_penalty,
|
| 210 |
+
top_k=50, # Added top_k parameter
|
| 211 |
):
|
| 212 |
+
model = get_model(temperature, top_p, repeat_penalty, top_k)
|
| 213 |
|
| 214 |
chat_template = MessagesFormatterType.MISTRAL
|
| 215 |
|
|
|
|
| 262 |
gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
|
| 263 |
gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
|
| 264 |
gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty"),
|
| 265 |
+
gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"), # Added top_k slider
|
| 266 |
],
|
| 267 |
theme=gr.themes.Soft(
|
| 268 |
primary_hue="orange",
|