admin08077 committed
Commit 6bfc829 · verified · 1 Parent(s): 0ca4cd3

Update app.py

Files changed (1)
app.py +17 -15
app.py CHANGED
@@ -34,28 +34,28 @@ nltk.download("punkt", quiet=True)
 
 # Zero-Shot Classification Model (Topic Detection)
 ZSC_MODEL_NAME = "facebook/bart-large-mnli"
-zsc_tokenizer = AutoTokenizer.from_pretrained(ZSC_MODEL_NAME)
-zsc_model = AutoModelForSequenceClassification.from_pretrained(ZSC_MODEL_NAME)
+zsc_tokenizer = AutoTokenizer.from_pretrained(ZSC_MODEL_NAME, force_download=True)
+zsc_model = AutoModelForSequenceClassification.from_pretrained(ZSC_MODEL_NAME, force_download=True)
 zero_shot_classifier = pipeline("zero-shot-classification", model=zsc_model, tokenizer=zsc_tokenizer)
 
 # Summarization Model (Chunk-based Summaries)
 SUM_MODEL_NAME = "facebook/bart-large-cnn"
-sum_tokenizer = AutoTokenizer.from_pretrained(SUM_MODEL_NAME)
-sum_model = AutoModelForSeq2SeqLM.from_pretrained(SUM_MODEL_NAME)
+sum_tokenizer = AutoTokenizer.from_pretrained(SUM_MODEL_NAME, force_download=True)
+sum_model = AutoModelForSeq2SeqLM.from_pretrained(SUM_MODEL_NAME, force_download=True)
 summarizer = pipeline("summarization", model=sum_model, tokenizer=sum_tokenizer)
 
 # QA Model (Chunk-based QA)
 QA_MODEL_NAME = "deepset/roberta-base-squad2"
-qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
-qa_model = AutoModelForQuestionAnswering.from_pretrained(QA_MODEL_NAME)
+qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME, force_download=True)
+qa_model = AutoModelForQuestionAnswering.from_pretrained(QA_MODEL_NAME, force_download=True)
 qa_pipeline = pipeline("question-answering", model=qa_model, tokenizer=qa_tokenizer)
 
 # Speech-to-Text (STT) with tiny Whisper
 WHISPER_MODEL_NAME = "openai/whisper-tiny"
-whisper_processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_NAME)
-whisper_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_NAME)
+whisper_processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_NAME, force_download=True)
+whisper_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_NAME, force_download=True)
 
-# For real-time token usage, let's use tiktoken (GPT-3.5 style tokenizer)
+# For real-time token usage, we'll use tiktoken (GPT-3.5 style tokenizer)
 encoding = tiktoken.get_encoding("cl100k_base")
 
 ###############################################################################
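A note on the tiktoken line kept above: the `approximate_tokens` helper used by `chat()` in the next hunk is defined outside this diff. A minimal sketch of how it could sit on top of the `cl100k_base` encoding, assuming it simply counts encoded tokens (cl100k_base is a GPT-3.5-style tokenizer, so the counts only approximate what the BART/RoBERTa tokenizers loaded above would report):

```python
import tiktoken

# GPT-3.5-style encoding, as loaded in the hunk above.
encoding = tiktoken.get_encoding("cl100k_base")

def approximate_tokens(text: str) -> int:
    # Assumed implementation: token count = number of cl100k_base tokens.
    return len(encoding.encode(text))
```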
@@ -304,17 +304,19 @@ def answer_question(session_id, question):
 def chat(user_input, chat_history, session_id):
     if session_id not in SESSIONS:
         SESSIONS[session_id] = {"files": {}, "chat_history": []}
+    # If the user wants to search for a reference:
     if user_input.lower().startswith("ref:"):
         query = user_input[4:].strip()
         result = find_reference(session_id, query)
-        chat_history.append((user_input, result))
+        chat_history.append({"role": "assistant", "content": result})
         return "", chat_history
+    # Process the question using QA:
     answer = answer_question(session_id, user_input)
     question_tokens = approximate_tokens(user_input)
     answer_tokens = approximate_tokens(answer)
     usage_str = f"Tokens: Q={question_tokens}, A={answer_tokens}, Total={question_tokens + answer_tokens}"
     full_answer = f"{answer}\n\n({usage_str})"
-    chat_history.append((user_input, full_answer))
+    chat_history.append({"role": "assistant", "content": full_answer})
     return "", chat_history
 
 ###############################################################################
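The two `append` changes match the Chatbot switch to `type="messages"` in the next hunk: in that mode Gradio stores history as openai-style role/content dicts rather than `(user, bot)` tuples. An illustrative shape only (the sample strings are made up):

```python
# History as a type="messages" Chatbot stores it:
chat_history = [
    {"role": "user", "content": "ref: attention"},
    {"role": "assistant", "content": "Found it in chunk 2 ..."},
]

# chat() now appends only the assistant turn; the user turn was already
# appended by the user_message callback in the UI hunk below.
chat_history.append({"role": "assistant", "content": "Answer\n\n(Tokens: Q=3, A=12, Total=15)"})
```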
@@ -368,27 +370,27 @@ with gr.Blocks() as demo:
     new_session_btn.click(fn=reset_session, outputs=[session_id, new_session_out])
 
     gr.Markdown("### 2. Voice Input (STT Only)")
-    # Removed the 'source' parameter because it is not supported in this version.
     audio_in = gr.Audio(type="filepath", label="Speak your question")
     stt_btn = gr.Button("Transcribe")
     stt_output = gr.Textbox(label="Transcribed Text")
     stt_btn.click(fn=transcribe_audio, inputs=[audio_in], outputs=[stt_output])
 
     gr.Markdown("### 3. Chat / Q&A (Enter text below)")
+    # Set type="messages" for openai-style chat messages
     chatbot = gr.Chatbot(label="Chat History", type="messages")
     user_input = gr.Textbox(label="Your question (or 'ref: <term>' for reference search)", lines=2)
     send_btn = gr.Button("Send")
 
     def user_message(user_msg, history):
-        history = history + [[user_msg, None]]
+        history = history + [{"role": "user", "content": user_msg}]
         return "", history
     send_btn.click(fn=user_message, inputs=[user_input, chatbot], outputs=[user_input, chatbot], queue=False)
 
     def bot_message(history, sid):
-        # Check if history is empty
         if not history:
             return []
-        user_msg = history[-1][0]
+        # The most recent message should be from the user.
+        user_msg = history[-1]["content"]
         _, updated_history = chat(user_msg, history, sid)
         return updated_history
     send_btn.click(fn=bot_message, inputs=[chatbot, session_id], outputs=[chatbot])
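With both callbacks registered on the same button, one Send click first runs `user_message` (append the user dict, clear the textbox), then `bot_message` (read `history[-1]["content"]` and let `chat()` append the assistant dict). A rough trace of one turn, assuming a session id of "demo" (values illustrative):

```python
history = []
_, history = user_message("ref: attention", history)
# history == [{"role": "user", "content": "ref: attention"}]
history = bot_message(history, "demo")
# history[-1] == {"role": "assistant", "content": ...}  # the find_reference result
```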
 