Ashendilantha commited on
Commit
6eac30f
·
verified ·
1 Parent(s): d21afd0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -4
app.py CHANGED
@@ -72,16 +72,28 @@ def classify_csv(file):
72
  except Exception as e:
73
  return None, f"Error: {str(e)}"
74
 
75
- def chatbot_response(history, user_input, source):
76
  # Always use the user input as a question for the QA pipeline
77
  user_input = user_input.lower() # Optionally make it lowercase
78
 
79
- # Get the context from the source (single article or bulk content)
80
- context = context_storage["context"] if source == "Single Article" else context_storage["bulk_context"]
81
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  if context:
83
  with st.spinner("Finding answer..."):
84
- # Pass the user's question and the content from the source (context) to the QA pipeline
85
  result = qa_pipeline(question=user_input, context=context)
86
  answer = result["answer"]
87
 
 
72
  except Exception as e:
73
  return None, f"Error: {str(e)}"
74
 
75
+ def chatbot_response(history, user_input, source, text_input=None, file_input=None):
76
  # Always use the user input as a question for the QA pipeline
77
  user_input = user_input.lower() # Optionally make it lowercase
78
 
79
+ # Determine context based on the source
80
+ context = ""
81
 
82
+ if source == "Single Article":
83
+ # If the source is single article, use the text_input (single article content)
84
+ context = text_input if text_input else ""
85
+
86
+ elif source == "Bulk Classification":
87
+ # If the source is bulk articles, use the file input (context for multiple articles)
88
+ if file_input:
89
+ # Assuming `classify_csv` method processes the file and sets the bulk context
90
+ df, _ = classify_csv(file_input) # This function should return the dataframe and context
91
+ context = context_storage["bulk_context"]
92
+
93
+ # Make sure there's context for QA pipeline to run
94
  if context:
95
  with st.spinner("Finding answer..."):
96
+ # Run QA pipeline with the user's question and context from either text_input or file_input
97
  result = qa_pipeline(question=user_input, context=context)
98
  answer = result["answer"]
99