# champ / app.py — Hugging Face Space (author: amritn8, commit e1982cd)
import os
import tempfile

import PyPDF2
import streamlit as st
import torch
import whisper
from docx import Document
from torch.nn.functional import softmax
from transformers import BertTokenizerFast, BertForQuestionAnswering, pipeline
# Select GPU when available; the BERT QA model and its inputs are moved here.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load all heavyweight models a single time per Streamlit session.
@st.cache_resource
def load_models():
    """Build and cache the QA model, its tokenizer, the summarizer and Whisper.

    Returns:
        tuple: (qa_model, tokenizer, summarizer, whisper_model); the BERT QA
        model is already moved to the active device.
    """
    checkpoint = "deepset/bert-base-cased-squad2"
    bert_qa = BertForQuestionAnswering.from_pretrained(checkpoint).to(device)
    bert_tokenizer = BertTokenizerFast.from_pretrained(checkpoint)
    bart_summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
    speech_model = whisper.load_model("base")
    return bert_qa, bert_tokenizer, bart_summarizer, speech_model
qa_model, tokenizer, summarizer, whisper_model = load_models()
def extract_text(file_obj):
    """Extract plain text from an uploaded .pdf, .docx or .txt file.

    Args:
        file_obj: A file-like object with a ``name`` attribute (Streamlit
            UploadedFile); ``.txt`` uploads must also expose ``getvalue()``.

    Returns:
        str: The extracted text, or "" for unsupported extensions.
    """
    ext = os.path.splitext(file_obj.name)[1].lower()
    if ext == ".pdf":
        reader = PyPDF2.PdfReader(file_obj)
        # Walrus avoids calling extract_text() twice per page — each call
        # re-parses the page's content stream.
        return "\n".join(text for p in reader.pages if (text := p.extract_text()))
    elif ext == ".docx":
        doc = Document(file_obj)
        return "\n".join(p.text for p in doc.paragraphs)
    elif ext == ".txt":
        # Replace undecodable bytes instead of crashing on non-UTF-8 uploads.
        return file_obj.getvalue().decode("utf-8", errors="replace")
    return ""
def summarize_text(text):
    """Summarize *text* with the BART pipeline.

    Inputs shorter than 50 characters return a notice string; inputs longer
    than 1000 characters are truncated to the first 1000 before summarizing.
    """
    if len(text) < 50:
        return "Text too short to summarize."
    truncated = text if len(text) <= 1000 else text[:1000]
    result = summarizer(truncated, max_length=120, min_length=30, do_sample=False)
    return result[0]['summary_text']
def ask_question(question, context):
    """Answer *question* from *context* via extractive BERT QA.

    Returns a formatted string holding the decoded answer span and a
    confidence percentage (product of the start- and end-position softmax
    probabilities).

    NOTE(review): the SQuAD2 "no answer" case (span collapsing onto [CLS])
    gets no special handling, and end_idx is not forced to follow start_idx,
    so degenerate spans can yield an empty/odd answer — confirm acceptable.
    """
    # Question + context packed into one sequence, truncated to BERT's
    # 512-token window, and moved to the active device.
    inputs = tokenizer.encode_plus(question, context, return_tensors="pt", truncation=True, max_length=512).to(device)
    with torch.no_grad():
        outputs = qa_model(**inputs)
    # Most likely start/end token positions; end_idx is exclusive for slicing.
    start_idx = torch.argmax(outputs.start_logits)
    end_idx = torch.argmax(outputs.end_logits) + 1
    # Confidence = P(start) * P(end) under independent softmaxes over positions.
    score = softmax(outputs.start_logits, dim=1)[0][start_idx] * softmax(outputs.end_logits, dim=1)[0][end_idx - 1]
    answer = tokenizer.decode(inputs["input_ids"][0][start_idx:end_idx])
    return f"Answer: {answer.strip()}\nConfidence: {round(score.item()*100, 2)}%"
def transcribe(audio_path):
    """Run Whisper speech-to-text on the audio file at *audio_path*."""
    transcription = whisper_model.transcribe(audio_path)
    return transcription["text"]
# ---- UI: two tabs, Question Answering and Summarization ----
st.title("🎙️📄 LexPilot: Voice + Document Q&A Assistant")
st.write("Upload a document or paste content. Ask questions by typing or speaking.")
tab = st.tabs(["Question Answering", "Summarization"])

with tab[0]:
    uploaded_file = st.file_uploader("Upload .pdf / .docx / .txt", type=["pdf", "docx", "txt"])
    pasted_text = st.text_area("Or paste text manually", height=150)
    typed_question = st.text_input("Type your question")
    audio_input = st.file_uploader("Or upload audio file (wav, mp3, m4a)", type=["wav", "mp3", "m4a"])
    if st.button("Get Answer"):
        # Resolve the context: an uploaded document wins over pasted text.
        context = ""
        if uploaded_file:
            context = extract_text(uploaded_file)
        elif pasted_text.strip():
            context = pasted_text.strip()
        else:
            st.warning("❗ Please upload or paste content.")
            st.stop()
        # Resolve the question: typed text wins over an audio upload.
        if typed_question.strip():
            question = typed_question.strip()
        elif audio_input:
            # Whisper needs a real file path. Write the upload to a unique
            # temp file — keeping the original extension so ffmpeg picks the
            # right demuxer — and always remove it afterwards. A fixed name
            # like "temp_audio" would collide between concurrent sessions
            # and accumulate on disk.
            suffix = os.path.splitext(audio_input.name)[1] or ".wav"
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
            try:
                tmp.write(audio_input.getbuffer())
                tmp.close()
                question = transcribe(tmp.name)
            finally:
                os.unlink(tmp.name)
            st.write(f"Transcribed question: {question}")
        else:
            st.warning("❗ Please type or upload an audio question.")
            st.stop()
        answer = ask_question(question, context)
        st.text_area("Answer and Confidence", value=answer, height=100)

with tab[1]:
    sum_file = st.file_uploader("Upload .pdf / .docx / .txt to summarize", type=["pdf", "docx", "txt"])
    sum_text = st.text_area("Or paste content to summarize", height=150)
    if st.button("Summarize"):
        # Same precedence as the QA tab: file upload beats pasted text.
        context = ""
        if sum_file:
            context = extract_text(sum_file)
        elif sum_text.strip():
            context = sum_text.strip()
        else:
            st.warning("❗ Please upload or paste content to summarize.")
            st.stop()
        summary = summarize_text(context)
        st.text_area("Summary", value=summary, height=150)