Update app.py
app.py CHANGED
@@ -100,13 +100,22 @@ examples = [
     ["filetype:pdf intitle:python"]
 ]
 
-
 class CustomLLMSettings(BaseModel):
     structured_output: LlmStructuredOutputSettings
     temperature: float = Field(default=0.7)
     top_p: float = Field(default=0.95)
     repetition_penalty: float = Field(default=1.1)
-    top_k: int = Field(default=50)
+    top_k: int = Field(default=50)
+    max_tokens: int = Field(default=1000)
+    stop: list[str] = Field(default_factory=list)
+    echo: bool = Field(default=False)
+    stream: bool = Field(default=False)
+    logprobs: int = Field(default=None)
+    presence_penalty: float = Field(default=0.0)
+    frequency_penalty: float = Field(default=0.0)
+    best_of: int = Field(default=1)
+    logit_bias: dict = Field(default_factory=dict)
+    max_tokens_per_summary: int = Field(default=2048)
 
 class HuggingFaceHubWrapper:
     def __init__(self, repo_id, model_kwargs, huggingfacehub_api_token):
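The added fields follow the usual Pydantic pattern: anything not supplied at construction time falls back to its Field default, and the mutable defaults (stop, logit_bias) use default_factory so instances do not share state. A minimal sketch of that behaviour, using a simplified stand-in model since the real class also requires a structured_output value:

from pydantic import BaseModel, Field

# Simplified stand-in for CustomLLMSettings (structured_output omitted),
# only to illustrate how the new defaults behave.
class GenerationSettings(BaseModel):
    temperature: float = Field(default=0.7)
    top_k: int = Field(default=50)
    max_tokens: int = Field(default=1000)
    stop: list[str] = Field(default_factory=list)    # fresh list per instance
    logit_bias: dict = Field(default_factory=dict)   # fresh dict per instance
    max_tokens_per_summary: int = Field(default=2048)

settings = GenerationSettings(top_k=40)
print(settings.top_k, settings.max_tokens, settings.stop)  # 40 1000 []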
@@ -118,7 +127,9 @@ class HuggingFaceHubWrapper:
         self.temperature = model_kwargs.get('temperature', 0.7)
         self.top_p = model_kwargs.get('top_p', 0.95)
         self.repetition_penalty = model_kwargs.get('repetition_penalty', 1.1)
-        self.top_k = model_kwargs.get('top_k', 50)
+        self.top_k = model_kwargs.get('top_k', 50)
+        self.max_tokens = model_kwargs.get('max_length', 1000)
+        self.max_tokens_per_summary = model_kwargs.get('max_tokens_per_summary', 2048)
 
     def get_provider_default_settings(self):
         return CustomLLMSettings(
@@ -131,7 +142,9 @@ class HuggingFaceHubWrapper:
             temperature=self.temperature,
             top_p=self.top_p,
             repetition_penalty=self.repetition_penalty,
-            top_k=self.top_k
+            top_k=self.top_k,
+            max_tokens=self.max_tokens,
+            max_tokens_per_summary=self.max_tokens_per_summary
         )
 
     def get_provider_identifier(self):
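Taken together, the two hunks above route the new parameters through the wrapper: __init__ pulls them out of model_kwargs (note that max_tokens is read from the 'max_length' key, so callers must pass it under that name), and get_provider_default_settings hands them back inside the CustomLLMSettings object. A trimmed-down sketch of that flow, with the Hugging Face client and the structured-output field left out:

# Trimmed-down sketch of the wrapper's parameter flow; field names and
# defaults are taken from the diff, everything else is omitted.
class WrapperSketch:
    def __init__(self, model_kwargs):
        self.top_k = model_kwargs.get('top_k', 50)
        self.max_tokens = model_kwargs.get('max_length', 1000)  # note the key name
        self.max_tokens_per_summary = model_kwargs.get('max_tokens_per_summary', 2048)

    def provider_defaults(self):
        # stands in for get_provider_default_settings() returning CustomLLMSettings
        return {
            "top_k": self.top_k,
            "max_tokens": self.max_tokens,
            "max_tokens_per_summary": self.max_tokens_per_summary,
        }

print(WrapperSketch({"top_k": 40, "max_length": 1500}).provider_defaults())
# {'top_k': 40, 'max_tokens': 1500, 'max_tokens_per_summary': 2048}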
@@ -172,7 +185,7 @@ class CitingSources(BaseModel):
     )
 
 # Model function
-def get_model(temperature, top_p, repetition_penalty, top_k=50):
+def get_model(temperature, top_p, repetition_penalty, top_k=50, max_tokens=1000, max_tokens_per_summary=2048):
     return HuggingFaceHubWrapper(
         repo_id="mistralai/Mistral-7B-Instruct-v0.3",
         model_kwargs={
@@ -180,7 +193,8 @@ def get_model(temperature, top_p, repetition_penalty, top_k=50):
             "top_p": top_p,
             "repetition_penalty": repetition_penalty,
             "top_k": top_k,
-            "max_length":
+            "max_length": max_tokens,
+            "max_tokens_per_summary": max_tokens_per_summary
         },
         huggingfacehub_api_token=huggingface_token
     )
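An illustrative call with the extended signature (the values are arbitrary); the point is which keys end up in model_kwargs, in particular that max_tokens travels under the 'max_length' key the wrapper expects:

# Hypothetical usage; assumes huggingface_token is set as elsewhere in app.py.
model = get_model(
    temperature=0.7,
    top_p=0.95,
    repetition_penalty=1.1,
    top_k=40,
    max_tokens=1500,
    max_tokens_per_summary=1024,
)
# model_kwargs passed to HuggingFaceHubWrapper then contains (among others):
#   "top_k": 40, "max_length": 1500, "max_tokens_per_summary": 1024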
@@ -207,10 +221,10 @@ def respond(
     temperature,
     top_p,
     repeat_penalty,
-    top_k=50,
+    top_k=50,
+    max_tokens_per_summary=2048
 ):
-    model = get_model(temperature, top_p, repeat_penalty, top_k)
-
+    model = get_model(temperature, top_p, repeat_penalty, top_k, max_tokens, max_tokens_per_summary)
     chat_template = MessagesFormatterType.MISTRAL
 
     search_tool = WebSearchTool(
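For reference, a minimal sketch of how the forwarded arguments line up with get_model's extended signature; max_tokens does not appear among the parameters visible in this hunk, so it is assumed here to come from one of respond's earlier, unchanged parameters:

# Minimal sketch: the keyword defaults added to respond() mirror get_model()'s,
# so existing call sites that omit them keep working. max_tokens is assumed
# to already be in scope (e.g. an earlier parameter of respond).
def forward_settings(temperature, top_p, repeat_penalty,
                     top_k=50, max_tokens=1000, max_tokens_per_summary=2048):
    return get_model(temperature, top_p, repeat_penalty,
                     top_k, max_tokens, max_tokens_per_summary)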
@@ -262,7 +276,8 @@ demo = gr.ChatInterface(
         gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
         gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty"),
-        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
+        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
+        gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens per summary"),
     ],
     theme=gr.themes.Soft(
         primary_hue="orange",
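The two new sliders only take effect because gr.ChatInterface passes the values of additional_inputs to the chat function, in order, after the message and history, so the slider order has to match the order of respond's extra parameters. A minimal, self-contained illustration of that mechanism (the function and variable names here are made up for the example):

import gradio as gr

def echo_settings(message, history, top_k, max_tokens_per_summary):
    # additional_inputs arrive positionally after (message, history)
    return f"top_k={top_k}, max_tokens_per_summary={max_tokens_per_summary}"

sliders_demo = gr.ChatInterface(
    echo_settings,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
        gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens per summary"),
    ],
)

if __name__ == "__main__":
    sliders_demo.launch()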