Spaces:
Sleeping
Sleeping
Commit
·
3e32b0f
1
Parent(s):
eda06e0
Add application file
Browse files
app.py
CHANGED
|
@@ -29,23 +29,20 @@ import tempfile
|
|
| 29 |
|
| 30 |
|
| 31 |
# Prompt template
|
| 32 |
-
template = """
|
| 33 |
-
|
| 34 |
-
You are given the following extracted parts of a long document and a question. Provide a detailed answer.
|
| 35 |
-
If you don't know the answer, just say "Hmm, I'm not sure." Don't try to make up an answer.
|
| 36 |
-
=======
|
| 37 |
-
{context}
|
| 38 |
-
=======
|
| 39 |
Question: {question}
|
| 40 |
-
|
|
|
|
| 41 |
|
| 42 |
# Multi-query generation prompt
|
| 43 |
-
multi_query_template = """
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
| 49 |
|
| 50 |
QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])
|
| 51 |
MULTI_QUERY_PROMPT = PromptTemplate(template=multi_query_template, input_variables=["question"])
|
|
@@ -54,6 +51,9 @@ MULTI_QUERY_PROMPT = PromptTemplate(template=multi_query_template, input_variabl
|
|
| 54 |
model_id = "microsoft/phi-2"
|
| 55 |
|
| 56 |
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
|
|
|
|
|
|
|
|
|
| 57 |
model = AutoModelForCausalLM.from_pretrained(
|
| 58 |
model_id, torch_dtype=torch.float32, trust_remote_code=True
|
| 59 |
)
|
|
@@ -66,6 +66,38 @@ embeddings = HuggingFaceEmbeddings(
|
|
| 66 |
)
|
| 67 |
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
# Returns a faiss vector store retriever given a txt or pdf file
|
| 70 |
def prepare_vector_store_retriever(filename):
|
| 71 |
# Load data based on file extension
|
|
@@ -208,6 +240,10 @@ def generate(question, answer, text_file, max_new_tokens, use_multi_query, store
|
|
| 208 |
max_new_tokens=max_new_tokens,
|
| 209 |
pad_token_id=tokenizer.eos_token_id,
|
| 210 |
eos_token_id=tokenizer.eos_token_id,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
streamer=streamer,
|
| 212 |
)
|
| 213 |
|
|
@@ -245,7 +281,23 @@ def generate(question, answer, text_file, max_new_tokens, use_multi_query, store
|
|
| 245 |
try:
|
| 246 |
for token in streamer:
|
| 247 |
response += token
|
| 248 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
except Exception as e:
|
| 250 |
yield f"Error during streaming: {str(e)}"
|
| 251 |
return
|
|
@@ -259,8 +311,13 @@ def generate(question, answer, text_file, max_new_tokens, use_multi_query, store
|
|
| 259 |
return
|
| 260 |
|
| 261 |
# Store Q&A pair if requested and response is valid
|
| 262 |
-
final_response = response.strip()
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
try:
|
| 265 |
store_qa_pair(question, final_response, vectorstore)
|
| 266 |
except Exception as e:
|
|
|
|
| 29 |
|
| 30 |
|
| 31 |
# Prompt template
|
| 32 |
+
template = """Context: {context}
|
| 33 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
Question: {question}
|
| 35 |
+
|
| 36 |
+
Answer: Based on the provided context, """
|
| 37 |
|
| 38 |
# Multi-query generation prompt
|
| 39 |
+
multi_query_template = """Generate 3 different ways to ask this question:
|
| 40 |
+
|
| 41 |
+
Original: {question}
|
| 42 |
+
|
| 43 |
+
Alternative 1:
|
| 44 |
+
Alternative 2:
|
| 45 |
+
Alternative 3:"""
|
| 46 |
|
| 47 |
QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])
|
| 48 |
MULTI_QUERY_PROMPT = PromptTemplate(template=multi_query_template, input_variables=["question"])
|
|
|
|
| 51 |
model_id = "microsoft/phi-2"
|
| 52 |
|
| 53 |
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
| 54 |
+
if tokenizer.pad_token is None:
|
| 55 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 56 |
+
|
| 57 |
model = AutoModelForCausalLM.from_pretrained(
|
| 58 |
model_id, torch_dtype=torch.float32, trust_remote_code=True
|
| 59 |
)
|
|
|
|
| 66 |
)
|
| 67 |
|
| 68 |
|
| 69 |
+
def clean_response(text):
|
| 70 |
+
"""Clean up the generated response"""
|
| 71 |
+
# Remove excessive whitespace and newlines
|
| 72 |
+
text = ' '.join(text.split())
|
| 73 |
+
|
| 74 |
+
# Remove repetitive patterns
|
| 75 |
+
words = text.split()
|
| 76 |
+
cleaned_words = []
|
| 77 |
+
|
| 78 |
+
for word in words:
|
| 79 |
+
# Skip if the same word appears too many times consecutively
|
| 80 |
+
if len(cleaned_words) >= 3 and all(w == word for w in cleaned_words[-3:]):
|
| 81 |
+
continue
|
| 82 |
+
cleaned_words.append(word)
|
| 83 |
+
|
| 84 |
+
cleaned_text = ' '.join(cleaned_words)
|
| 85 |
+
|
| 86 |
+
# Truncate at natural stopping points
|
| 87 |
+
sentences = cleaned_text.split('.')
|
| 88 |
+
if len(sentences) > 1:
|
| 89 |
+
# Keep complete sentences
|
| 90 |
+
good_sentences = []
|
| 91 |
+
for sentence in sentences[:-1]: # Exclude last potentially incomplete sentence
|
| 92 |
+
if len(sentence.strip()) > 5: # Avoid very short fragments
|
| 93 |
+
good_sentences.append(sentence.strip())
|
| 94 |
+
|
| 95 |
+
if good_sentences:
|
| 96 |
+
return '. '.join(good_sentences) + '.'
|
| 97 |
+
|
| 98 |
+
return cleaned_text[:500] # Fallback: truncate to reasonable length
|
| 99 |
+
|
| 100 |
+
|
| 101 |
# Returns a faiss vector store retriever given a txt or pdf file
|
| 102 |
def prepare_vector_store_retriever(filename):
|
| 103 |
# Load data based on file extension
|
|
|
|
| 240 |
max_new_tokens=max_new_tokens,
|
| 241 |
pad_token_id=tokenizer.eos_token_id,
|
| 242 |
eos_token_id=tokenizer.eos_token_id,
|
| 243 |
+
do_sample=True,
|
| 244 |
+
temperature=0.7,
|
| 245 |
+
top_p=0.9,
|
| 246 |
+
repetition_penalty=1.1,
|
| 247 |
streamer=streamer,
|
| 248 |
)
|
| 249 |
|
|
|
|
| 281 |
try:
|
| 282 |
for token in streamer:
|
| 283 |
response += token
|
| 284 |
+
# Clean up the response - stop at natural points
|
| 285 |
+
cleaned_response = response.strip()
|
| 286 |
+
|
| 287 |
+
# Stop if we hit repetitive patterns
|
| 288 |
+
words = cleaned_response.split()
|
| 289 |
+
if len(words) > 10:
|
| 290 |
+
# Check for repetitive patterns
|
| 291 |
+
last_words = words[-5:]
|
| 292 |
+
if len(set(last_words)) <= 2: # Too much repetition
|
| 293 |
+
break
|
| 294 |
+
|
| 295 |
+
# Stop at sentence endings if we have enough content
|
| 296 |
+
if len(cleaned_response) > 50 and cleaned_response.endswith(('.', '!', '?')):
|
| 297 |
+
yield cleaned_response
|
| 298 |
+
break
|
| 299 |
+
|
| 300 |
+
yield cleaned_response
|
| 301 |
except Exception as e:
|
| 302 |
yield f"Error during streaming: {str(e)}"
|
| 303 |
return
|
|
|
|
| 311 |
return
|
| 312 |
|
| 313 |
# Store Q&A pair if requested and response is valid
|
| 314 |
+
final_response = clean_response(response.strip())
|
| 315 |
+
|
| 316 |
+
# Yield the final cleaned response
|
| 317 |
+
if final_response != response.strip():
|
| 318 |
+
yield final_response
|
| 319 |
+
|
| 320 |
+
if store_qa and final_response and "Error" not in final_response and len(final_response) > 10:
|
| 321 |
try:
|
| 322 |
store_qa_pair(question, final_response, vectorstore)
|
| 323 |
except Exception as e:
|