Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -34,28 +34,28 @@ nltk.download("punkt", quiet=True)
|
|
| 34 |
|
| 35 |
# Zero-Shot Classification Model (Topic Detection)
|
| 36 |
ZSC_MODEL_NAME = "facebook/bart-large-mnli"
|
| 37 |
-
zsc_tokenizer = AutoTokenizer.from_pretrained(ZSC_MODEL_NAME)
|
| 38 |
-
zsc_model = AutoModelForSequenceClassification.from_pretrained(ZSC_MODEL_NAME)
|
| 39 |
zero_shot_classifier = pipeline("zero-shot-classification", model=zsc_model, tokenizer=zsc_tokenizer)
|
| 40 |
|
| 41 |
# Summarization Model (Chunk-based Summaries)
|
| 42 |
SUM_MODEL_NAME = "facebook/bart-large-cnn"
|
| 43 |
-
sum_tokenizer = AutoTokenizer.from_pretrained(SUM_MODEL_NAME)
|
| 44 |
-
sum_model = AutoModelForSeq2SeqLM.from_pretrained(SUM_MODEL_NAME)
|
| 45 |
summarizer = pipeline("summarization", model=sum_model, tokenizer=sum_tokenizer)
|
| 46 |
|
| 47 |
# QA Model (Chunk-based QA)
|
| 48 |
QA_MODEL_NAME = "deepset/roberta-base-squad2"
|
| 49 |
-
qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
|
| 50 |
-
qa_model = AutoModelForQuestionAnswering.from_pretrained(QA_MODEL_NAME)
|
| 51 |
qa_pipeline = pipeline("question-answering", model=qa_model, tokenizer=qa_tokenizer)
|
| 52 |
|
| 53 |
# Speech-to-Text (STT) with tiny Whisper
|
| 54 |
WHISPER_MODEL_NAME = "openai/whisper-tiny"
|
| 55 |
-
whisper_processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_NAME)
|
| 56 |
-
whisper_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_NAME)
|
| 57 |
|
| 58 |
-
# For real-time token usage,
|
| 59 |
encoding = tiktoken.get_encoding("cl100k_base")
|
| 60 |
|
| 61 |
###############################################################################
|
|
@@ -304,17 +304,19 @@ def answer_question(session_id, question):
|
|
| 304 |
def chat(user_input, chat_history, session_id):
|
| 305 |
if session_id not in SESSIONS:
|
| 306 |
SESSIONS[session_id] = {"files": {}, "chat_history": []}
|
|
|
|
| 307 |
if user_input.lower().startswith("ref:"):
|
| 308 |
query = user_input[4:].strip()
|
| 309 |
result = find_reference(session_id, query)
|
| 310 |
-
chat_history.append(
|
| 311 |
return "", chat_history
|
|
|
|
| 312 |
answer = answer_question(session_id, user_input)
|
| 313 |
question_tokens = approximate_tokens(user_input)
|
| 314 |
answer_tokens = approximate_tokens(answer)
|
| 315 |
usage_str = f"Tokens: Q={question_tokens}, A={answer_tokens}, Total={question_tokens + answer_tokens}"
|
| 316 |
full_answer = f"{answer}\n\n({usage_str})"
|
| 317 |
-
chat_history.append(
|
| 318 |
return "", chat_history
|
| 319 |
|
| 320 |
###############################################################################
|
|
@@ -368,27 +370,27 @@ with gr.Blocks() as demo:
|
|
| 368 |
new_session_btn.click(fn=reset_session, outputs=[session_id, new_session_out])
|
| 369 |
|
| 370 |
gr.Markdown("### 2. Voice Input (STT Only)")
|
| 371 |
-
# Removed the 'source' parameter because it is not supported in this version.
|
| 372 |
audio_in = gr.Audio(type="filepath", label="Speak your question")
|
| 373 |
stt_btn = gr.Button("Transcribe")
|
| 374 |
stt_output = gr.Textbox(label="Transcribed Text")
|
| 375 |
stt_btn.click(fn=transcribe_audio, inputs=[audio_in], outputs=[stt_output])
|
| 376 |
|
| 377 |
gr.Markdown("### 3. Chat / Q&A (Enter text below)")
|
|
|
|
| 378 |
chatbot = gr.Chatbot(label="Chat History", type="messages")
|
| 379 |
user_input = gr.Textbox(label="Your question (or 'ref: <term>' for reference search)", lines=2)
|
| 380 |
send_btn = gr.Button("Send")
|
| 381 |
|
| 382 |
def user_message(user_msg, history):
|
| 383 |
-
history = history + [
|
| 384 |
return "", history
|
| 385 |
send_btn.click(fn=user_message, inputs=[user_input, chatbot], outputs=[user_input, chatbot], queue=False)
|
| 386 |
|
| 387 |
def bot_message(history, sid):
|
| 388 |
-
# Check if history is empty
|
| 389 |
if not history:
|
| 390 |
return []
|
| 391 |
-
|
|
|
|
| 392 |
_, updated_history = chat(user_msg, history, sid)
|
| 393 |
return updated_history
|
| 394 |
send_btn.click(fn=bot_message, inputs=[chatbot, session_id], outputs=[chatbot])
|
|
|
|
| 34 |
|
| 35 |
# -----------------------------------------------------------------------------
# Model initialization (runs once at import time; downloads on first run only).
#
# NOTE(review): a previous revision passed force_download=True to every
# from_pretrained() call. That forces a full re-download of every model on
# each process start (several GB), which makes cold starts extremely slow and
# wastes bandwidth. If the local Hugging Face cache was corrupted, delete it
# once (e.g. ~/.cache/huggingface) instead of re-downloading on every launch.
# -----------------------------------------------------------------------------

# Zero-Shot Classification Model (Topic Detection)
ZSC_MODEL_NAME = "facebook/bart-large-mnli"
zsc_tokenizer = AutoTokenizer.from_pretrained(ZSC_MODEL_NAME)
zsc_model = AutoModelForSequenceClassification.from_pretrained(ZSC_MODEL_NAME)
zero_shot_classifier = pipeline("zero-shot-classification", model=zsc_model, tokenizer=zsc_tokenizer)

# Summarization Model (Chunk-based Summaries)
SUM_MODEL_NAME = "facebook/bart-large-cnn"
sum_tokenizer = AutoTokenizer.from_pretrained(SUM_MODEL_NAME)
sum_model = AutoModelForSeq2SeqLM.from_pretrained(SUM_MODEL_NAME)
summarizer = pipeline("summarization", model=sum_model, tokenizer=sum_tokenizer)

# QA Model (Chunk-based QA)
QA_MODEL_NAME = "deepset/roberta-base-squad2"
qa_tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
qa_model = AutoModelForQuestionAnswering.from_pretrained(QA_MODEL_NAME)
qa_pipeline = pipeline("question-answering", model=qa_model, tokenizer=qa_tokenizer)

# Speech-to-Text (STT) with tiny Whisper
WHISPER_MODEL_NAME = "openai/whisper-tiny"
whisper_processor = WhisperProcessor.from_pretrained(WHISPER_MODEL_NAME)
whisper_model = WhisperForConditionalGeneration.from_pretrained(WHISPER_MODEL_NAME)

# For real-time token usage, we'll use tiktoken (GPT-3.5 style tokenizer)
encoding = tiktoken.get_encoding("cl100k_base")
|
| 60 |
|
| 61 |
###############################################################################
|
|
|
|
| 304 |
def chat(user_input, chat_history, session_id):
    """Handle one chat turn: either a reference search or a QA answer.

    A message beginning with 'ref:' triggers a reference lookup; anything
    else is answered via the QA pipeline with an approximate token-usage
    footer appended. Returns ("", chat_history) so the caller can clear
    the input textbox while showing the updated history.
    """
    # Lazily create session state the first time this session_id is seen.
    if session_id not in SESSIONS:
        SESSIONS[session_id] = {"files": {}, "chat_history": []}

    # Reference-search mode: everything after the 'ref:' prefix is the query.
    if user_input.lower().startswith("ref:"):
        reference = find_reference(session_id, user_input[4:].strip())
        chat_history.append({"role": "assistant", "content": reference})
        return "", chat_history

    # QA mode: answer the question, then report approximate token usage.
    answer = answer_question(session_id, user_input)
    question_tokens = approximate_tokens(user_input)
    answer_tokens = approximate_tokens(answer)
    usage_str = f"Tokens: Q={question_tokens}, A={answer_tokens}, Total={question_tokens + answer_tokens}"
    full_answer = f"{answer}\n\n({usage_str})"
    chat_history.append({"role": "assistant", "content": full_answer})
    return "", chat_history
|
| 321 |
|
| 322 |
###############################################################################
|
|
|
|
# Wire the "new session" button: resets state and reports the fresh session id.
new_session_btn.click(fn=reset_session, outputs=[session_id, new_session_out])

# --- Section 2: voice input (speech-to-text only) ---
gr.Markdown("### 2. Voice Input (STT Only)")
# No 'source=' argument: not supported by the Gradio version in use.
audio_in = gr.Audio(type="filepath", label="Speak your question")
stt_btn = gr.Button("Transcribe")
stt_output = gr.Textbox(label="Transcribed Text")
stt_btn.click(fn=transcribe_audio, inputs=[audio_in], outputs=[stt_output])

# --- Section 3: text chat / QA ---
gr.Markdown("### 3. Chat / Q&A (Enter text below)")
# type="messages": history is a list of {"role": ..., "content": ...} dicts
# (openai-style), which is the format chat()/user_message()/bot_message() use.
chatbot = gr.Chatbot(label="Chat History", type="messages")
user_input = gr.Textbox(label="Your question (or 'ref: <term>' for reference search)", lines=2)
send_btn = gr.Button("Send")
|
| 383 |
|
| 384 |
def user_message(user_msg, history):
    """Append the typed message to the chat history as a 'user' turn.

    Returns ("", new_history): the empty string clears the input textbox.
    A new list is built rather than mutating the incoming history in place.
    """
    new_turn = {"role": "user", "content": user_msg}
    return "", history + [new_turn]
|
| 387 |
# queue=False: presumably so this handler runs immediately (unqueued) and the
# user's message renders before the bot reply is computed — TODO confirm.
send_btn.click(fn=user_message, inputs=[user_input, chatbot], outputs=[user_input, chatbot], queue=False)
|
| 388 |
|
| 389 |
def bot_message(history, sid):
    """Produce the assistant's reply for the newest user turn.

    Expects messages-format history whose final entry is the user message
    just appended by user_message(); returns the updated history.
    """
    if not history:
        # Nothing to answer yet.
        return []
    # The most recent entry should be the user's message.
    latest_text = history[-1]["content"]
    _, updated = chat(latest_text, history, sid)
    return updated
# Second click handler on the same button: fires after user_message has
# appended the user's turn, then fills in the assistant's reply.
send_btn.click(fn=bot_message, inputs=[chatbot, session_id], outputs=[chatbot])
|