|
|
import os |
|
|
import torch |
|
|
import whisper |
|
|
import PyPDF2 |
|
|
from transformers import BertTokenizerFast, BertForQuestionAnswering, pipeline |
|
|
from torch.nn.functional import softmax |
|
|
from docx import Document |
|
|
import streamlit as st |
|
|
|
|
|
# Pick the compute device once; the QA model (and ideally every torch module)
# is moved onto it.
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
@st.cache_resource
def load_models():
    """Load and cache every ML model the app uses.

    Cached by ``st.cache_resource`` so Streamlit re-runs reuse the same
    instances instead of reloading weights on every interaction.

    Returns:
        tuple: ``(qa_model, tokenizer, summarizer, whisper_model)`` where
            - qa_model: BERT fine-tuned on SQuAD2 for extractive QA, on ``device``;
            - tokenizer: fast tokenizer matching the QA checkpoint;
            - summarizer: BART-large-CNN summarization pipeline;
            - whisper_model: Whisper "base" speech-to-text model.
    """
    qa_model = BertForQuestionAnswering.from_pretrained("deepset/bert-base-cased-squad2").to(device)
    tokenizer = BertTokenizerFast.from_pretrained("deepset/bert-base-cased-squad2")
    # Run the summarizer on the same device as the QA model (0 = first CUDA
    # device, -1 = CPU). Previously no device was passed, so the pipeline
    # silently stayed on CPU even when a GPU was available.
    summarizer = pipeline(
        "summarization",
        model="facebook/bart-large-cnn",
        device=0 if device == "cuda" else -1,
    )
    whisper_model = whisper.load_model("base")
    return qa_model, tokenizer, summarizer, whisper_model
|
|
|
|
|
# Instantiate all models once at import time; st.cache_resource makes this a
# cheap cache hit on subsequent Streamlit re-runs.
qa_model, tokenizer, summarizer, whisper_model = load_models()
|
|
|
|
|
def extract_text(file_obj):
    """Extract plain text from an uploaded .pdf, .docx, or .txt file.

    Args:
        file_obj: File-like object with a ``name`` attribute (e.g. a
            Streamlit ``UploadedFile``). The extension of ``name`` selects
            the parser.

    Returns:
        str: Extracted text, or "" for unsupported extensions.
    """
    ext = os.path.splitext(file_obj.name)[1].lower()
    if ext == ".pdf":
        reader = PyPDF2.PdfReader(file_obj)
        # Extract each page exactly once (the original called extract_text()
        # twice per page) and skip pages that yield no text, e.g. scans.
        page_texts = (page.extract_text() for page in reader.pages)
        return "\n".join(text for text in page_texts if text)
    elif ext == ".docx":
        doc = Document(file_obj)
        return "\n".join(paragraph.text for paragraph in doc.paragraphs)
    elif ext == ".txt":
        # Streamlit uploads are bytes; the app assumes UTF-8 text files.
        return file_obj.getvalue().decode("utf-8")
    return ""
|
|
|
|
|
def summarize_text(text, max_chars=1000):
    """Summarize *text* with the BART summarization pipeline.

    Args:
        text (str): Text to summarize.
        max_chars (int): Input is truncated to this many characters before
            summarization to stay within the model's context budget.
            Defaults to 1000 (the previously hard-coded limit).

    Returns:
        str: The generated summary, or a notice when the input is shorter
        than 50 characters.
    """
    if len(text) < 50:
        return "Text too short to summarize."
    # Truncate long inputs; BART cannot take arbitrarily long sequences.
    if len(text) > max_chars:
        text = text[:max_chars]
    summary = summarizer(text, max_length=120, min_length=30, do_sample=False)
    return summary[0]['summary_text']
|
|
|
|
|
def ask_question(question, context):
    """Answer *question* extractively from *context* with the BERT QA model.

    Args:
        question (str): Natural-language question.
        context (str): Passage to extract the answer span from; truncated
            to the model's 512-token limit.

    Returns:
        str: ``"Answer: <span>\nConfidence: <pct>%"`` where the confidence is
        the product of the start- and end-position softmax probabilities.
    """
    inputs = tokenizer.encode_plus(question, context, return_tensors="pt", truncation=True, max_length=512).to(device)
    with torch.no_grad():
        outputs = qa_model(**inputs)
    start_idx = torch.argmax(outputs.start_logits)
    end_idx = torch.argmax(outputs.end_logits) + 1
    # Guard against a predicted end position before the start, which would
    # otherwise produce an empty slice below.
    if end_idx <= start_idx:
        end_idx = start_idx + 1
    # Confidence: P(start) * P(end) under independent per-position softmaxes.
    score = softmax(outputs.start_logits, dim=1)[0][start_idx] * softmax(outputs.end_logits, dim=1)[0][end_idx - 1]
    # skip_special_tokens prevents "[CLS]"/"[SEP]" from leaking into the
    # answer text (SQuAD2 models predict the [CLS] position for "no answer").
    answer = tokenizer.decode(inputs["input_ids"][0][start_idx:end_idx], skip_special_tokens=True)
    return f"Answer: {answer.strip()}\nConfidence: {round(score.item()*100, 2)}%"
|
|
|
|
|
def transcribe(audio_path):
    """Transcribe the audio file at *audio_path* with Whisper.

    Args:
        audio_path (str): Path to an audio file readable by Whisper/ffmpeg.

    Returns:
        str: The transcribed text.
    """
    return whisper_model.transcribe(audio_path)["text"]
|
|
|
|
|
# --- Streamlit UI -----------------------------------------------------------
st.title("ποΈπ LexPilot: Voice + Document Q&A Assistant")
st.write("Upload a document or paste content. Ask questions by typing or speaking.")

tab = st.tabs(["Question Answering", "Summarization"])

# Tab 1: extractive QA over an uploaded/pasted document; the question may be
# typed or spoken (audio upload transcribed by Whisper).
with tab[0]:
    uploaded_file = st.file_uploader("Upload .pdf / .docx / .txt", type=["pdf", "docx", "txt"])
    pasted_text = st.text_area("Or paste text manually", height=150)

    typed_question = st.text_input("Type your question")
    audio_input = st.file_uploader("Or upload audio file (wav, mp3, m4a)", type=["wav", "mp3", "m4a"])

    if st.button("Get Answer"):
        # Resolve the context: an uploaded file takes priority over pasted text.
        context = ""
        if uploaded_file:
            context = extract_text(uploaded_file)
        elif pasted_text.strip():
            context = pasted_text.strip()
        else:
            st.warning("β Please upload or paste content.")
            st.stop()

        # Resolve the question: typed text takes priority over the audio upload.
        if typed_question.strip():
            question = typed_question.strip()
        elif audio_input:
            # Write the upload to a unique temp file. The previous fixed
            # "temp_audio" name in the CWD collided across concurrent
            # sessions, dropped the extension ffmpeg uses to sniff the
            # format, and was never deleted.
            suffix = os.path.splitext(audio_input.name)[1] or ".wav"
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                tmp.write(audio_input.getbuffer())
                tmp_path = tmp.name
            try:
                question = transcribe(tmp_path)
            finally:
                os.remove(tmp_path)
            st.write(f"Transcribed question: {question}")
        else:
            st.warning("β Please type or upload an audio question.")
            st.stop()

        answer = ask_question(question, context)
        st.text_area("Answer and Confidence", value=answer, height=100)

# Tab 2: abstractive summarization of an uploaded/pasted document.
with tab[1]:
    sum_file = st.file_uploader("Upload .pdf / .docx / .txt to summarize", type=["pdf", "docx", "txt"])
    sum_text = st.text_area("Or paste content to summarize", height=150)

    if st.button("Summarize"):
        context = ""
        if sum_file:
            context = extract_text(sum_file)
        elif sum_text.strip():
            context = sum_text.strip()
        else:
            st.warning("β Please upload or paste content to summarize.")
            st.stop()

        summary = summarize_text(context)
        st.text_area("Summary", value=summary, height=150)