Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -105,7 +105,7 @@ from huggingface_hub import InferenceClient
|
|
| 105 |
|
| 106 |
# NLTK Resource Download
|
| 107 |
def download_nltk_resources():
|
| 108 |
-
resources = ['punkt', 'stopwords', 'snowball_data']
|
| 109 |
for resource in resources:
|
| 110 |
try:
|
| 111 |
nltk.download(resource, quiet=False)
|
|
@@ -337,7 +337,7 @@ def optimize_query(
|
|
| 337 |
vector_store_type: str, # Added to match your signature
|
| 338 |
search_type: str, # Added to match your signature
|
| 339 |
top_k: int = 3,
|
| 340 |
-
use_gpu: bool =
|
| 341 |
) -> str:
|
| 342 |
"""
|
| 343 |
CPU-optimized version of query expansion using a small language model.
|
|
@@ -354,7 +354,7 @@ def optimize_query(
|
|
| 354 |
|
| 355 |
Returns:
|
| 356 |
Expanded query string
|
| 357 |
-
"""
|
| 358 |
try:
|
| 359 |
# Set device
|
| 360 |
device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
|
|
@@ -372,74 +372,60 @@ def optimize_query(
|
|
| 372 |
expanded_terms.update([lemma.name() for lemma in syn.lemmas()[:2]])
|
| 373 |
|
| 374 |
# 3. Use provided model with reduced complexity
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
max_length=64,
|
| 398 |
-
truncation=True,
|
| 399 |
-
padding=True
|
| 400 |
-
)
|
| 401 |
-
|
| 402 |
-
# Generate with minimal parameters
|
| 403 |
-
with torch.no_grad():
|
| 404 |
-
outputs = model.generate(
|
| 405 |
-
inputs.input_ids.to(device),
|
| 406 |
-
max_length=32,
|
| 407 |
num_return_sequences=1,
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
)
|
| 412 |
|
| 413 |
-
|
|
|
|
|
|
|
| 414 |
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
|
| 419 |
-
except Exception as model_error:
|
| 420 |
-
print(f"Model-based expansion failed: {str(model_error)}")
|
| 421 |
-
enhanced_query = query
|
| 422 |
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
|
| 429 |
# 5. Remove stopwords and select top_k most relevant terms
|
| 430 |
stopwords = set(['the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to'])
|
| 431 |
final_terms = [term for term in final_terms if term not in stopwords]
|
| 432 |
|
| 433 |
# Combine with original query
|
| 434 |
-
|
| 435 |
-
|
| 436 |
# Clean up
|
| 437 |
-
|
| 438 |
-
del tokenizer
|
| 439 |
-
if device == "cuda":
|
| 440 |
-
torch.cuda.empty_cache()
|
| 441 |
|
| 442 |
-
return
|
| 443 |
|
| 444 |
except Exception as e:
|
| 445 |
print(f"Query optimization failed: {str(e)}")
|
|
@@ -1073,6 +1059,7 @@ def analyze_results(stats_df):
|
|
| 1073 |
return recommendations
|
| 1074 |
|
| 1075 |
####
|
|
|
|
| 1076 |
|
| 1077 |
def get_llm_suggested_settings(file, num_chunks=1):
|
| 1078 |
if not file:
|
|
@@ -1092,7 +1079,7 @@ def get_llm_suggested_settings(file, num_chunks=1):
|
|
| 1092 |
sample_chunks = random.sample(chunks, min(num_chunks, len(chunks)))
|
| 1093 |
|
| 1094 |
|
| 1095 |
-
llm_pipeline = pipeline(model="meta-llama/Llama-3.2-1B-Instruct", device='
|
| 1096 |
|
| 1097 |
|
| 1098 |
prompt=f'''
|
|
@@ -1155,17 +1142,16 @@ def get_llm_suggested_settings(file, num_chunks=1):
|
|
| 1155 |
max_new_tokens=1900, # Control the length of the output,
|
| 1156 |
truncation=True, # Enable truncation
|
| 1157 |
)
|
| 1158 |
-
|
| 1159 |
-
|
| 1160 |
-
#
|
| 1161 |
-
print("setting suggested")
|
| 1162 |
-
print(suggested_settings)
|
| 1163 |
-
# Parse the generated text to extract the dictionary
|
| 1164 |
try:
|
| 1165 |
-
|
|
|
|
|
|
|
| 1166 |
# Convert the settings to match the interface inputs
|
| 1167 |
return {
|
| 1168 |
-
"embedding_models":
|
| 1169 |
"split_strategy": settings_dict["split_strategy"],
|
| 1170 |
"chunk_size": settings_dict["chunk_size"],
|
| 1171 |
"overlap_size": settings_dict["overlap_size"],
|
|
@@ -1173,13 +1159,15 @@ def get_llm_suggested_settings(file, num_chunks=1):
|
|
| 1173 |
"search_type": settings_dict["search_type"],
|
| 1174 |
"top_k": settings_dict["top_k"],
|
| 1175 |
"apply_preprocessing": settings_dict["apply_preprocessing"],
|
| 1176 |
-
"optimize_vocab": settings_dict["
|
| 1177 |
-
"apply_phonetic": settings_dict["
|
| 1178 |
-
"phonetic_weight": 0.3 #
|
| 1179 |
}
|
| 1180 |
-
except:
|
|
|
|
| 1181 |
return {"error": "Failed to parse LLM suggestions"}
|
| 1182 |
|
|
|
|
| 1183 |
def update_inputs_with_llm_suggestions(suggestions):
|
| 1184 |
if suggestions is None or "error" in suggestions:
|
| 1185 |
return [gr.update() for _ in range(11)] # Return no updates if there's an error or None
|
|
|
|
| 105 |
|
| 106 |
# NLTK Resource Download
|
| 107 |
def download_nltk_resources():
|
| 108 |
+
resources = ['punkt', 'stopwords', 'snowball_data', 'wordnet']
|
| 109 |
for resource in resources:
|
| 110 |
try:
|
| 111 |
nltk.download(resource, quiet=False)
|
|
|
|
| 337 |
vector_store_type: str, # Added to match your signature
|
| 338 |
search_type: str, # Added to match your signature
|
| 339 |
top_k: int = 3,
|
| 340 |
+
use_gpu: bool = False
|
| 341 |
) -> str:
|
| 342 |
"""
|
| 343 |
CPU-optimized version of query expansion using a small language model.
|
|
|
|
| 354 |
|
| 355 |
Returns:
|
| 356 |
Expanded query string
|
| 357 |
+
"""
|
| 358 |
try:
|
| 359 |
# Set device
|
| 360 |
device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
|
|
|
|
| 372 |
expanded_terms.update([lemma.name() for lemma in syn.lemmas()[:2]])
|
| 373 |
|
| 374 |
# 3. Use provided model with reduced complexity
|
| 375 |
+
try:
|
| 376 |
+
# Initialize the pipeline with the chosen model
|
| 377 |
+
llm_pipeline = pipeline(model="meta-llama/Llama-3.2-1B-Instruct", device='cpu')
|
| 378 |
+
|
| 379 |
+
# Define prompt for the assistant, making it context-specific
|
| 380 |
+
prompt = f'''
|
| 381 |
+
<|start_header_id|>system<|end_header_id|>
|
| 382 |
+
You are an expert in enhancing user input for vector store retrieval.
|
| 383 |
+
Enhance the followinf search query with relevant terms.
|
| 384 |
+
|
| 385 |
+
show me just the new term. You SHOULD NOT include any other text in the response.
|
| 386 |
+
|
| 387 |
+
<|eot_id|><|start_header_id|>user<|end_header_id|>
|
| 388 |
+
{query}
|
| 389 |
+
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
| 390 |
+
'''
|
| 391 |
+
|
| 392 |
+
# Get suggested settings from the LLM
|
| 393 |
+
suggested_settings = llm_pipeline(
|
| 394 |
+
prompt,
|
| 395 |
+
do_sample=True,
|
| 396 |
+
top_k=10,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
num_return_sequences=1,
|
| 398 |
+
return_full_text=False,
|
| 399 |
+
max_new_tokens=1900, # Control the length of the output
|
| 400 |
+
truncation=True # Enable truncation
|
| 401 |
)
|
| 402 |
|
| 403 |
+
# Extract the settings from the generated response
|
| 404 |
+
generated_text = suggested_settings[0].get('generated_text', '')
|
| 405 |
+
print(generated_text) # For debugging, ensure text output is as expected
|
| 406 |
|
| 407 |
+
except Exception as model_error:
|
| 408 |
+
print(f"LLM-based expansion failed: {str(model_error)}")
|
| 409 |
+
generated_text = "Default settings could not be generated." # Fallback message or settings
|
| 410 |
|
|
|
|
|
|
|
|
|
|
| 411 |
|
| 412 |
+
# 4. Combine original and expanded terms
|
| 413 |
+
final_terms = set(tokens)
|
| 414 |
+
final_terms.update(expanded_terms)
|
| 415 |
+
if generated_text != query:
|
| 416 |
+
final_terms.update(word_tokenize(generated_text.lower()))
|
| 417 |
|
| 418 |
# 5. Remove stopwords and select top_k most relevant terms
|
| 419 |
stopwords = set(['the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to'])
|
| 420 |
final_terms = [term for term in final_terms if term not in stopwords]
|
| 421 |
|
| 422 |
# Combine with original query
|
| 423 |
+
generated_text = f"{query} {' '.join(list(final_terms)[:top_k])}"
|
| 424 |
+
print(generated_text)
|
| 425 |
# Clean up
|
| 426 |
+
# llm_pipeline = None
|
|
|
|
|
|
|
|
|
|
| 427 |
|
| 428 |
+
return generated_text.strip() #[Document(page_content=generated_text.strip())]
|
| 429 |
|
| 430 |
except Exception as e:
|
| 431 |
print(f"Query optimization failed: {str(e)}")
|
|
|
|
| 1059 |
return recommendations
|
| 1060 |
|
| 1061 |
####
|
| 1062 |
+
import ast
|
| 1063 |
|
| 1064 |
def get_llm_suggested_settings(file, num_chunks=1):
|
| 1065 |
if not file:
|
|
|
|
| 1079 |
sample_chunks = random.sample(chunks, min(num_chunks, len(chunks)))
|
| 1080 |
|
| 1081 |
|
| 1082 |
+
llm_pipeline = pipeline(model="meta-llama/Llama-3.2-1B-Instruct", device='cpu')
|
| 1083 |
|
| 1084 |
|
| 1085 |
prompt=f'''
|
|
|
|
| 1142 |
max_new_tokens=1900, # Control the length of the output,
|
| 1143 |
truncation=True, # Enable truncation
|
| 1144 |
)
|
| 1145 |
+
|
| 1146 |
+
print(suggested_settings[0]['generated_text'])
|
| 1147 |
+
# Safely parse the generated text to extract the dictionary
|
|
|
|
|
|
|
|
|
|
| 1148 |
try:
|
| 1149 |
+
# Using ast.literal_eval for safe parsing
|
| 1150 |
+
settings_dict = ast.literal_eval(suggested_settings[0]['generated_text'])
|
| 1151 |
+
|
| 1152 |
# Convert the settings to match the interface inputs
|
| 1153 |
return {
|
| 1154 |
+
"embedding_models": settings_dict["embedding_models"],
|
| 1155 |
"split_strategy": settings_dict["split_strategy"],
|
| 1156 |
"chunk_size": settings_dict["chunk_size"],
|
| 1157 |
"overlap_size": settings_dict["overlap_size"],
|
|
|
|
| 1159 |
"search_type": settings_dict["search_type"],
|
| 1160 |
"top_k": settings_dict["top_k"],
|
| 1161 |
"apply_preprocessing": settings_dict["apply_preprocessing"],
|
| 1162 |
+
"optimize_vocab": settings_dict["optimize_vocab"],
|
| 1163 |
+
"apply_phonetic": settings_dict["apply_phonetic"],
|
| 1164 |
+
"phonetic_weight": settings_dict.get("phonetic_weight", 0.3) # Set default if not provided
|
| 1165 |
}
|
| 1166 |
+
except Exception as e:
|
| 1167 |
+
print(f"Error parsing LLM suggestions: {e}")
|
| 1168 |
return {"error": "Failed to parse LLM suggestions"}
|
| 1169 |
|
| 1170 |
+
|
| 1171 |
def update_inputs_with_llm_suggestions(suggestions):
|
| 1172 |
if suggestions is None or "error" in suggestions:
|
| 1173 |
return [gr.update() for _ in range(11)] # Return no updates if there's an error or None
|