mohamedachraf committed on
Commit
b88f075
·
1 Parent(s): aa52fc9

modify the pipeline

Browse files
Files changed (1) hide show
  1. app.py +16 -54
app.py CHANGED
@@ -31,10 +31,10 @@ import tempfile
31
  # Prompt template
32
  template = """Context: {context}
33
 
34
- Question: {query}
35
 
36
  Answer:"""
37
- QA_PROMPT = PromptTemplate(template=template, input_variables=["query", "context"])
38
 
39
 
40
  # Load Phi-2 model from hugging face hub
@@ -143,9 +143,7 @@ def generate(question, answer, text_file, max_new_tokens):
143
  return
144
 
145
  try:
146
- streamer = TextIteratorStreamer(
147
- tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=300.0
148
- )
149
  phi2_pipeline = pipeline(
150
  "text-generation",
151
  tokenizer=tokenizer,
@@ -157,7 +155,6 @@ def generate(question, answer, text_file, max_new_tokens):
157
  temperature=0.7,
158
  top_p=0.9,
159
  repetition_penalty=1.1,
160
- streamer=streamer,
161
  )
162
 
163
  hf_model = HuggingFacePipeline(pipeline=phi2_pipeline)
@@ -169,58 +166,23 @@ def generate(question, answer, text_file, max_new_tokens):
169
  yield "Your question is too long! Please shorten it."
170
  return
171
 
172
- # Run the chain in a separate thread
173
- result_container = {"result": None, "error": None}
174
-
175
- def run_chain():
176
- try:
177
- result_container["result"] = qa_chain.invoke({"query": query})
178
- except Exception as e:
179
- result_container["error"] = str(e)
180
-
181
- thread = Thread(target=run_chain)
182
- thread.start()
183
-
184
- # Stream the response
185
- response = ""
186
  try:
187
- for token in streamer:
188
- response += token
189
- # Clean up the response - stop at natural points
190
- cleaned_response = response.strip()
191
-
192
- # Stop if we hit repetitive patterns
193
- words = cleaned_response.split()
194
- if len(words) > 10:
195
- # Check for repetitive patterns
196
- last_words = words[-5:]
197
- if len(set(last_words)) <= 2: # Too much repetition
198
- break
199
 
200
- # Stop at sentence endings if we have enough content
201
- if len(cleaned_response) > 50 and cleaned_response.endswith(('.', '!', '?')):
202
- yield cleaned_response
203
- break
204
-
205
- yield cleaned_response
206
  except Exception as e:
207
- yield f"Error during streaming: {str(e)}"
208
  return
209
-
210
- # Wait for thread to complete
211
- thread.join()
212
-
213
- # Check for errors
214
- if result_container["error"]:
215
- yield f"Error: {result_container['error']}"
216
- return
217
-
218
- # Final cleanup of response
219
- final_response = clean_response(response.strip())
220
-
221
- # Yield the final cleaned response
222
- if final_response != response.strip():
223
- yield final_response
224
 
225
  except Exception as e:
226
  yield f"Error: {str(e)}"
 
31
  # Prompt template
32
  template = """Context: {context}
33
 
34
+ Question: {question}
35
 
36
  Answer:"""
37
+ QA_PROMPT = PromptTemplate(template=template, input_variables=["question", "context"])
38
 
39
 
40
  # Load Phi-2 model from hugging face hub
 
143
  return
144
 
145
  try:
146
+ # Create the pipeline without a streamer for now, to test plain (non-streaming) generation first
 
 
147
  phi2_pipeline = pipeline(
148
  "text-generation",
149
  tokenizer=tokenizer,
 
155
  temperature=0.7,
156
  top_p=0.9,
157
  repetition_penalty=1.1,
 
158
  )
159
 
160
  hf_model = HuggingFacePipeline(pipeline=phi2_pipeline)
 
166
  yield "Your question is too long! Please shorten it."
167
  return
168
 
169
+ # Get the response directly without streaming first
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  try:
171
+ result = qa_chain.invoke({"query": query})
172
+
173
+ # Extract the answer from the result
174
+ if isinstance(result, dict):
175
+ response = result.get('result', str(result))
176
+ else:
177
+ response = str(result)
 
 
 
 
 
178
 
179
+ # Clean the response
180
+ cleaned_response = clean_response(response)
181
+ yield cleaned_response
182
+
 
 
183
  except Exception as e:
184
+ yield f"Error during generation: {str(e)}"
185
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
  except Exception as e:
188
  yield f"Error: {str(e)}"