mohamedachraf committed on
Commit
3e32b0f
·
1 Parent(s): eda06e0

Add application file

Browse files
Files changed (1) hide show
  1. app.py +74 -17
app.py CHANGED
@@ -29,23 +29,20 @@ import tempfile
29
 
30
 
31
  # Prompt template
32
- template = """Instruction:
33
- You are an AI assistant for answering questions about the provided context.
34
- You are given the following extracted parts of a long document and a question. Provide a detailed answer.
35
- If you don't know the answer, just say "Hmm, I'm not sure." Don't try to make up an answer.
36
- =======
37
- {context}
38
- =======
39
  Question: {question}
40
- Output:\n"""
 
41
 
42
  # Multi-query generation prompt
43
- multi_query_template = """You are an AI language model assistant. Your task is to generate 3
44
- different versions of the given user question to retrieve relevant documents from a vector
45
- database. By generating multiple perspectives on the user question, your goal is to help
46
- the user overcome some of the limitations of the distance-based similarity search.
47
- Provide these alternative questions separated by newlines.
48
- Original question: {question}"""
 
49
 
50
  QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])
51
  MULTI_QUERY_PROMPT = PromptTemplate(template=multi_query_template, input_variables=["question"])
@@ -54,6 +51,9 @@ MULTI_QUERY_PROMPT = PromptTemplate(template=multi_query_template, input_variabl
54
  model_id = "microsoft/phi-2"
55
 
56
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
 
 
 
57
  model = AutoModelForCausalLM.from_pretrained(
58
  model_id, torch_dtype=torch.float32, trust_remote_code=True
59
  )
@@ -66,6 +66,38 @@ embeddings = HuggingFaceEmbeddings(
66
  )
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  # Returns a faiss vector store retriever given a txt or pdf file
70
  def prepare_vector_store_retriever(filename):
71
  # Load data based on file extension
@@ -208,6 +240,10 @@ def generate(question, answer, text_file, max_new_tokens, use_multi_query, store
208
  max_new_tokens=max_new_tokens,
209
  pad_token_id=tokenizer.eos_token_id,
210
  eos_token_id=tokenizer.eos_token_id,
 
 
 
 
211
  streamer=streamer,
212
  )
213
 
@@ -245,7 +281,23 @@ def generate(question, answer, text_file, max_new_tokens, use_multi_query, store
245
  try:
246
  for token in streamer:
247
  response += token
248
- yield response.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  except Exception as e:
250
  yield f"Error during streaming: {str(e)}"
251
  return
@@ -259,8 +311,13 @@ def generate(question, answer, text_file, max_new_tokens, use_multi_query, store
259
  return
260
 
261
  # Store Q&A pair if requested and response is valid
262
- final_response = response.strip()
263
- if store_qa and final_response and "Error" not in final_response and len(final_response) > 0:
 
 
 
 
 
264
  try:
265
  store_qa_pair(question, final_response, vectorstore)
266
  except Exception as e:
 
29
 
30
 
31
  # Prompt template
32
+ template = """Context: {context}
33
+
 
 
 
 
 
34
  Question: {question}
35
+
36
+ Answer: Based on the provided context, """
37
 
38
  # Multi-query generation prompt
39
+ multi_query_template = """Generate 3 different ways to ask this question:
40
+
41
+ Original: {question}
42
+
43
+ Alternative 1:
44
+ Alternative 2:
45
+ Alternative 3:"""
46
 
47
  QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])
48
  MULTI_QUERY_PROMPT = PromptTemplate(template=multi_query_template, input_variables=["question"])
 
51
  model_id = "microsoft/phi-2"
52
 
53
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
54
+ if tokenizer.pad_token is None:
55
+ tokenizer.pad_token = tokenizer.eos_token
56
+
57
  model = AutoModelForCausalLM.from_pretrained(
58
  model_id, torch_dtype=torch.float32, trust_remote_code=True
59
  )
 
66
  )
67
 
68
 
69
def clean_response(text):
    """Normalize whitespace, cap runaway word repetition, and trim to complete sentences.

    Collapses all whitespace runs to single spaces, allows at most three
    consecutive copies of any word, then — when at least one sentence
    boundary ('.') exists — keeps only the complete sentences longer than
    five characters. Otherwise falls back to the first 500 characters.
    """
    # str.split() with no argument collapses every whitespace run
    # (spaces, tabs, newlines) in one pass.
    tokens = text.split()

    # Keep a word only while fewer than three identical copies directly
    # precede it, so degenerate model loops like "the the the the ..."
    # are capped at three repeats.
    kept = []
    for token in tokens:
        if len(kept) >= 3 and kept[-3:] == [token] * 3:
            continue
        kept.append(token)
    deduped = ' '.join(kept)

    # Prefer returning only complete sentences: split on '.', discard the
    # final (potentially unfinished) piece, and drop tiny fragments.
    parts = deduped.split('.')
    if len(parts) > 1:
        complete = [p.strip() for p in parts[:-1] if len(p.strip()) > 5]
        if complete:
            return '. '.join(complete) + '.'

    # No usable complete sentence: cap the length instead.
    return deduped[:500]
+
100
+
101
  # Returns a faiss vector store retriever given a txt or pdf file
102
  def prepare_vector_store_retriever(filename):
103
  # Load data based on file extension
 
240
  max_new_tokens=max_new_tokens,
241
  pad_token_id=tokenizer.eos_token_id,
242
  eos_token_id=tokenizer.eos_token_id,
243
+ do_sample=True,
244
+ temperature=0.7,
245
+ top_p=0.9,
246
+ repetition_penalty=1.1,
247
  streamer=streamer,
248
  )
249
 
 
281
  try:
282
  for token in streamer:
283
  response += token
284
+ # Clean up the response - stop at natural points
285
+ cleaned_response = response.strip()
286
+
287
+ # Stop if we hit repetitive patterns
288
+ words = cleaned_response.split()
289
+ if len(words) > 10:
290
+ # Check for repetitive patterns
291
+ last_words = words[-5:]
292
+ if len(set(last_words)) <= 2: # Too much repetition
293
+ break
294
+
295
+ # Stop at sentence endings if we have enough content
296
+ if len(cleaned_response) > 50 and cleaned_response.endswith(('.', '!', '?')):
297
+ yield cleaned_response
298
+ break
299
+
300
+ yield cleaned_response
301
  except Exception as e:
302
  yield f"Error during streaming: {str(e)}"
303
  return
 
311
  return
312
 
313
  # Store Q&A pair if requested and response is valid
314
+ final_response = clean_response(response.strip())
315
+
316
+ # Yield the final cleaned response
317
+ if final_response != response.strip():
318
+ yield final_response
319
+
320
+ if store_qa and final_response and "Error" not in final_response and len(final_response) > 10:
321
  try:
322
  store_qa_pair(question, final_response, vectorstore)
323
  except Exception as e: