# NOTE: the three lines below are Hugging Face page residue (author/commit
# banner) that was pasted into the source; kept as comments so the file parses.
# abhishekjoel's picture
# Update app.py
# bc5cfc4 verified
import os
import openai
import streamlit as st
import io
from pydub import AudioSegment
from youtube_transcript_api import YouTubeTranscriptApi, TranscriptsDisabled, NoTranscriptFound
import fitz # PyMuPDF
import tiktoken # For token counting
import traceback # For detailed error logging
# --- Configuration ---
# Models chosen for speed and capability balance
TRANSCRIPTION_MODEL = "whisper-1"  # OpenAI Whisper speech-to-text model
LANGUAGE_MODEL = "gpt-3.5-turbo"   # Chat model used for summaries and Q&A
# Approximate context window limit for the language model (input tokens)
MAX_TOKENS_FOR_SUMMARY_INPUT = 3500
MAX_TOKENS_FOR_CHAT_INPUT = 3500 # Context + Question
AUDIO_SIZE_LIMIT_MB = 25 # OpenAI API limit
# --- Helper Functions ---
# Initialize tiktoken encoder globally.
# `encoding` is shared by count_tokens()/truncate_text_by_tokens(); when it is
# None (initialisation failed), those helpers fall back to word counts.
try:
    encoding = tiktoken.encoding_for_model(LANGUAGE_MODEL)
except Exception as e:
    st.warning(f"Could not initialize token encoder for {LANGUAGE_MODEL}: {e}. Using word count fallback.")
    encoding = None
def count_tokens(text):
    """Return the token count of ``text``.

    Uses the module-level tiktoken ``encoding`` when available; if encoding
    is unavailable or raises, the whitespace-separated word count is used
    as a rough estimate instead.
    """
    if not text:
        return 0
    if not encoding:
        # tiktoken never initialised -- estimate via word count.
        return len(text.split())
    try:
        return len(encoding.encode(text))
    except Exception as e:
        st.warning(f"Token encoding failed: {e}. Falling back to word count.")
        return len(text.split()) # Fallback if encoding fails
def truncate_text_by_tokens(text, max_tokens):
    """Trim ``text`` so it fits within ``max_tokens`` tokens.

    Prefers exact tiktoken-based truncation via the module-level
    ``encoding``; when tiktoken is unavailable or fails, keeps roughly
    0.7 words per allowed token as a conservative approximation.
    """
    if not text:
        return ""
    if encoding:
        try:
            token_ids = encoding.encode(text)
            if len(token_ids) <= max_tokens:
                # Already within budget -- return unchanged.
                return text
            return encoding.decode(token_ids[:max_tokens])
        except Exception as e:
            st.warning(f"Token encoding/decoding failed during truncation: {e}. Using word count fallback.")
    # Word-count fallback: ~0.7 words per token is a rough heuristic.
    words = text.split()
    return " ".join(words[:int(max_tokens * 0.7)])
# --- Core Functions ---
def initialize_openai():
    """Configure the module-level OpenAI client from Streamlit secrets.

    Reads ``OPENAI_API_KEY`` from ``st.secrets``. Shows a Streamlit error
    and returns False when the key is missing, empty, or any other failure
    occurs; returns True on success.
    """
    no_key_message = "OpenAI API Key not found in Secrets. Please add 'OPENAI_API_KEY' to your Hugging Face Space secrets."
    try:
        secret_value = st.secrets["OPENAI_API_KEY"]
        if not secret_value:
            # Key present but blank -- treat the same as missing.
            st.error(no_key_message)
            return False
        openai.api_key = secret_value
        return True
    except KeyError:
        st.error(no_key_message)
        return False
    except Exception as e:
        st.error(f"Error initializing OpenAI: {e}")
        return False
def transcribe_audio(audio_file):
    """Transcribes audio using OpenAI Whisper API.

    Args:
        audio_file: an uploaded file object exposing ``.size`` and readable
            bytes (Streamlit UploadedFile).

    Returns:
        A string with one "[start-end] text" line per Whisper segment,
        or None on any failure (error is surfaced via st.error).
    """
    # Whisper rejects uploads above the size limit; fail fast client-side.
    if audio_file.size > AUDIO_SIZE_LIMIT_MB * 1024 * 1024:
        st.error(f"Audio file size exceeds {AUDIO_SIZE_LIMIT_MB}MB limit.")
        return None
    try:
        # Re-encode whatever format was uploaded to WAV, fully in memory.
        # NOTE(review): WAV re-encoding can grow the payload past the API's
        # own 25MB limit even when the upload passed the check above -- confirm.
        audio = AudioSegment.from_file(audio_file)
        buffer = io.BytesIO()
        audio.export(buffer, format="wav")
        buffer.seek(0)
        buffer.name = "audio.wav" # Required by OpenAI API
        # verbose_json response includes per-segment start/end timestamps.
        response = openai.Audio.transcribe(
            model=TRANSCRIPTION_MODEL,
            file=buffer,
            response_format="verbose_json"
        )
        # Render each segment as "[start-end] text".
        transcription_text = "\n".join(
            [f"[{seg['start']:.2f}-{seg['end']:.2f}] {seg['text']}" for seg in response['segments']]
        )
        return transcription_text
    except openai.error.AuthenticationError:
        st.error("Authentication Error: Invalid OpenAI API Key provided in Secrets.")
        return None
    except openai.error.RateLimitError:
        st.error("OpenAI API Rate Limit Exceeded. Please check your usage or wait.")
        return None
    except Exception as e:
        st.error(f"Error during audio transcription: {str(e)}")
        print(f"Transcription Error Traceback:\n{traceback.format_exc()}")
        return None
def extract_text_from_pdf(pdf_file):
    """Pull plain text out of an uploaded PDF via PyMuPDF.

    Returns the concatenated page text, "" (with a warning) when nothing
    extractable was found (e.g. a scanned, image-only PDF), or None when
    the file cannot be read at all.
    """
    try:
        document = fitz.open(stream=pdf_file.getvalue(), filetype="pdf")
        # One newline-terminated chunk per page, joined afterwards.
        page_chunks = [page.get_text() + "\n" for page in document]
        document.close()
        text = "".join(page_chunks)
        if not text.strip():
            st.warning("No text could be extracted. The PDF might be image-based (scanned) or empty.")
            return ""
        return text
    except Exception as e:
        st.error(f"Error reading PDF: {str(e)}")
        print(f"PDF Extraction Error Traceback:\n{traceback.format_exc()}")
        return None
def get_youtube_transcript(url):
    """Gets English transcript from a YouTube video.

    Supports standard ``watch?v=`` URLs, ``youtu.be`` short links, and a
    last-resort heuristic that treats an 11-character final path segment
    as the video ID.

    Fix: removed two unreachable ``elif`` branches that re-tested
    "youtu.be/" (and "youtu.be//") after the first "youtu.be/" test had
    already matched -- dead code with no behavioral effect.

    Returns:
        A string with one "[start-end] text" line per caption entry,
        or None on any failure (errors surfaced via st.error/st.warning).
    """
    try:
        # --- Extract the video ID from the URL ---
        video_id = None
        if "watch?v=" in url:
            video_id = url.split("watch?v=")[1].split("&")[0]
        elif "youtu.be/" in url:
            video_id = url.split("youtu.be/")[1].split("?")[0]
        else:
            # Basic check for other potential valid IDs (e.g., embed-style links)
            parts = url.split("/")
            potential_id = parts[-1].split("?")[0]
            if len(potential_id) == 11: # Common length for YouTube IDs
                video_id = potential_id
            else:
                st.error("Could not automatically determine Video ID from URL. Please use standard 'watch?v=' URL.")
                return None
        if not video_id:
            st.error("Failed to extract video ID.")
            return None
        # --- Fetch transcript, preferring human-made English captions ---
        transcript_list = YouTubeTranscriptApi.list_transcripts(video_id)
        try:
            # Prioritize manual transcripts, fallback to generated
            transcript = transcript_list.find_manually_created_transcript(['en'])
        except NoTranscriptFound:
            try:
                transcript = transcript_list.find_generated_transcript(['en'])
                st.info("Using auto-generated English transcript.")
            except NoTranscriptFound:
                st.warning(f"No English transcript (manual or generated) found for video: {url}")
                return None
        transcript_data = transcript.fetch()
        # Render each caption entry as "[start-end] text".
        transcription_text = "\n".join(
            [f"[{entry['start']:.2f}-{entry['start']+entry['duration']:.2f}] {entry['text']}" for entry in transcript_data]
        )
        return transcription_text
    except TranscriptsDisabled:
        st.error(f"Transcripts are disabled for video: {url}")
        return None
    except Exception as e:
        st.error(f"Error fetching YouTube transcript: {str(e)}")
        print(f"YouTube Transcript Error Traceback:\n{traceback.format_exc()}")
        return None
def generate_summary(text_to_summarize, max_output_tokens=800):
    """Summarise ``text_to_summarize`` with the configured chat model.

    Input longer than MAX_TOKENS_FOR_SUMMARY_INPUT tokens is truncated
    first (with a visible warning). Returns the summary string, or None
    when the input is empty or the API call fails.
    """
    token_total = count_tokens(text_to_summarize)
    if token_total > MAX_TOKENS_FOR_SUMMARY_INPUT:
        st.warning(
            f"Input text ({token_total} tokens) exceeds the limit "
            f"({MAX_TOKENS_FOR_SUMMARY_INPUT} tokens) for the summarization model. Truncating input."
        )
        text_to_summarize = truncate_text_by_tokens(text_to_summarize, MAX_TOKENS_FOR_SUMMARY_INPUT)
        token_total = count_tokens(text_to_summarize)  # refresh after truncation
    if not text_to_summarize:
        st.error("Input text for summarization is empty.")
        return None
    prompt = (
        "Summarize the following text comprehensively, focusing on key points, "
        "concepts, and conclusions. Aim for a detailed summary but keep it "
        f"concise where possible:\n\n{text_to_summarize}"
    )
    try:
        completion = openai.ChatCompletion.create(
            model=LANGUAGE_MODEL,
            messages=[{'role': 'user', 'content': prompt}],
            max_tokens=max_output_tokens,
            temperature=0.5
        )
        return completion.choices[0].message.content.strip()
    except openai.error.AuthenticationError:
        st.error("Authentication Error: Invalid OpenAI API Key provided in Secrets.")
        return None
    except openai.error.RateLimitError:
        st.error("OpenAI API Rate Limit Exceeded during summarization.")
        return None
    except openai.error.InvalidRequestError as e:
        st.error(f"Invalid Request during summarization: {e}.")
        return None
    except Exception as e:
        st.error(f"Error during summary generation: {str(e)}")
        print(f"Summarization Error Traceback:\n{traceback.format_exc()}")
        return None
def chat_with_ai(question, context, max_output_tokens=500):
    """Answer ``question`` strictly from ``context`` via the chat model.

    Returns the model's answer string, or None when the question/context
    is missing, the combined prompt exceeds MAX_TOKENS_FOR_CHAT_INPUT
    tokens, or the API call fails.
    """
    if not question:
        st.warning("Please enter a question.")
        return None
    if not context:
        st.error("Cannot answer question: No context available.")
        return None
    # Instruct the model to stay grounded in the supplied content only.
    prompt = f"Based *only* on the following content:\n\n---\n{context}\n---\n\nAnswer the question: {question}"
    prompt_tokens = count_tokens(prompt)
    if prompt_tokens > MAX_TOKENS_FOR_CHAT_INPUT:
        st.error(
            f"The question and context combined ({prompt_tokens} tokens) exceed the model's "
            f"input limit ({MAX_TOKENS_FOR_CHAT_INPUT} tokens). Try using the summary as context "
            "or ask a shorter question."
        )
        return None
    try:
        reply = openai.ChatCompletion.create(
            model=LANGUAGE_MODEL,
            messages=[{'role': 'user', 'content': prompt}],
            max_tokens=max_output_tokens,
            temperature=0.3
        )
        return reply.choices[0].message.content.strip()
    except openai.error.AuthenticationError:
        st.error("Authentication Error: Invalid OpenAI API Key provided in Secrets.")
        return None
    except openai.error.RateLimitError:
        st.error("OpenAI API Rate Limit Exceeded during chat.")
        return None
    except openai.error.InvalidRequestError as e:
        st.error(f"Invalid Request during chat: {e}.")
        return None
    except Exception as e:
        st.error(f"Error during AI chat: {str(e)}")
        print(f"Chat Error Traceback:\n{traceback.format_exc()}")
        return None
# --- Streamlit App Main Function ---
def main():
    """Render the Streamlit page.

    Flow: configure page & CSS -> ensure OpenAI key is set -> collect input
    (audio / PDF / YouTube URL) -> on button press, extract text and
    summarise it -> show both side by side -> offer a chat grounded in
    either the summary or the (possibly truncated) full text.
    """
    st.set_page_config(layout="wide", page_title="AI Summarization Bot")
    # --- Styling (Restored Original CSS) ---
    st.markdown("""
    <style>
    .stApp {
        background: linear-gradient(180deg,
            rgba(64,224,208,0.7) 0%,
            rgba(32,112,104,0.4) 35%,
            rgba(0,0,0,0) 100%
        );
    }
    /* Attempt to make sidebar slightly transparent if needed */
    div[data-testid="stSidebarContent"] {
        background-color: rgba(255,255,255,0.1) !important; /* May need tweaking */
    }
    /* Style markdown text */
    .stMarkdown p, .stMarkdown li, .stText, .stAlert p {
        color: #ffffff !important; /* White text for markdown, etc. */
    }
    /* Text Area Styling */
    .stTextArea textarea {
        background-color: rgba(0, 0, 0, 0.6) !important; /* Darker transparent background */
        color: #ffffff !important; /* White text */
        border: 1px solid rgba(255, 255, 255, 0.3); /* Subtle border */
        max-height: 400px; /* Ensure scroll height */
        overflow-y: auto !important;
    }
    /* Input Text Styling */
    .stTextInput input {
        color: white !important;
        background-color: rgba(0, 0, 0, 0.5) !important;
        border: 1px solid rgba(255, 255, 255, 0.3);
    }
    /* Button Styling */
    .stButton button {
        background-color: #40E0D0; /* Turquoise */
        color: black;
        border: none;
        padding: 0.5rem 1rem;
        border-radius: 5px;
        font-weight: bold;
    }
    .stButton button:hover {
        background-color: #48D1CC; /* Slightly darker turquoise */
        color: black;
    }
    /* Headings */
    h1, h2, h3, h4, h5, h6 {
        color: white !important;
    }
    /* Specific text elements like radio buttons, selectbox labels */
    .stRadio label, .stSelectbox label, .stFileUploader label {
        color: white !important;
    }
    /* Sidebar Header */
    [data-testid="stSidebar"] [data-testid="stVerticalBlock"] {
        color: white !important;
    }
    [data-testid="stSidebar"] h1, [data-testid="stSidebar"] h2, [data-testid="stSidebar"] h3 {
        color: white !important;
    }
    [data-testid="stSidebar"] p, [data-testid="stSidebar"] li {
        color: white !important;
    }
    /* Make text areas scrollable if content exceeds max-height */
    div[data-baseweb="textarea"] > div > textarea {
        overflow-y: auto !important;
    }
    </style>
    """, unsafe_allow_html=True)
    st.markdown("<h1 style='text-align: center;'>AI Summarization Bot 🤖</h1>", unsafe_allow_html=True)
    # Removed redundant description paragraph as title is descriptive
    # Initialize OpenAI API Key once per session (cached in session state).
    if 'openai_initialized' not in st.session_state:
        st.session_state['openai_initialized'] = initialize_openai()
    if not st.session_state.get('openai_initialized'):
        st.warning("OpenAI initialization failed. Please ensure your API key is correctly set in Hugging Face secrets and refresh.")
        st.stop()
    # --- Sidebar for Inputs ---
    st.sidebar.header("Input Options")
    input_type = st.sidebar.selectbox("Select Input Type", ["Audio File", "PDF Document", "YouTube URL"], key="input_type_select")
    # Initialize session state variables:
    #   full_text / summary        -- processing results shown in the UI
    #   last_input_type            -- detects input-type switches (clears state)
    #   last_input_data_key        -- identifies the last *processed* input
    #   current_input_key          -- identifies the currently *selected* input
    if 'full_text' not in st.session_state:
        st.session_state['full_text'] = None
    if 'summary' not in st.session_state:
        st.session_state['summary'] = None
    if 'last_input_type' not in st.session_state:
        st.session_state['last_input_type'] = None
    if 'last_input_data_key' not in st.session_state:
        st.session_state['last_input_data_key'] = None
    if 'current_input_key' not in st.session_state:
        st.session_state['current_input_key'] = None
    # Clear results if input type changes
    if st.session_state['last_input_type'] != input_type:
        st.session_state['full_text'] = None
        st.session_state['summary'] = None
        st.session_state['last_input_data_key'] = None
        st.session_state['current_input_key'] = None # Reset current key too
        st.session_state['last_input_type'] = input_type
    # --- Input Elements ---
    uploaded_file = None
    youtube_url = None
    process_button_pressed = False
    if input_type == "Audio File":
        uploaded_file = st.sidebar.file_uploader("Upload audio file (Max 25MB)", type=["mp3", "wav", "m4a", "ogg", "webm"], key="audio_uploader")
        if uploaded_file:
            # Use file name and size as the key instead of non-existent .id
            st.session_state['current_input_key'] = f"{uploaded_file.name}-{uploaded_file.size}"
    elif input_type == "PDF Document":
        uploaded_file = st.sidebar.file_uploader("Upload PDF document", type=["pdf"], key="pdf_uploader")
        if uploaded_file:
            # Use file name and size as the key
            st.session_state['current_input_key'] = f"{uploaded_file.name}-{uploaded_file.size}"
    elif input_type == "YouTube URL":
        youtube_url = st.sidebar.text_input("Enter YouTube URL", key="youtube_input", placeholder="e.g., https://www.youtube.com/watch?v=...")
        if youtube_url:
            st.session_state['current_input_key'] = youtube_url # Use URL as key
    st.sidebar.markdown("---") # Separator
    st.sidebar.markdown("### Steps:")
    st.sidebar.markdown("1. Select input type & provide source.")
    st.sidebar.markdown("2. Click 'Generate Summary & Notes'.")
    st.sidebar.markdown("3. Review results and use chat if needed.")
    # Single "Generate" button
    if st.sidebar.button("Generate Summary & Notes", key="generate_button", use_container_width=True): # Make button wider
        current_key = st.session_state.get('current_input_key')
        # Check if input is provided for the selected type
        valid_input_provided = False
        if input_type == "Audio File" and uploaded_file:
            valid_input_provided = True
        elif input_type == "PDF Document" and uploaded_file:
            valid_input_provided = True
        elif input_type == "YouTube URL" and youtube_url:
            valid_input_provided = True
        if valid_input_provided:
            # Check if it's a *new* input compared to the last processed one
            if current_key != st.session_state.get('last_input_data_key'):
                st.session_state['full_text'] = None
                st.session_state['summary'] = None
                st.session_state['last_input_data_key'] = current_key
                process_button_pressed = True
            else:
                # Input hasn't changed, check if results already exist
                if st.session_state.get('full_text') or st.session_state.get('summary'):
                    st.info("Results for the current input are already displayed. Upload a new file or URL to generate again.")
                else: # Results don't exist for some reason, re-process
                    process_button_pressed = True
        else:
            st.warning("Please provide input (upload file or enter URL) before generating.")
    # --- Processing Logic ---
    # Runs only in the same script run where the button was pressed
    # (process_button_pressed is a plain local, not session state).
    if process_button_pressed:
        extracted_text = None
        input_valid = False # Re-check validity just before processing
        if input_type == "Audio File" and uploaded_file:
            input_valid = True
            with st.spinner('Transcribing audio... (this may take a while)'):
                extracted_text = transcribe_audio(uploaded_file)
        elif input_type == "PDF Document" and uploaded_file:
            input_valid = True
            with st.spinner('Extracting text from PDF...'):
                extracted_text = extract_text_from_pdf(uploaded_file)
        elif input_type == "YouTube URL" and youtube_url:
            input_valid = True
            with st.spinner('Fetching YouTube transcript...'):
                extracted_text = get_youtube_transcript(youtube_url)
        if input_valid and extracted_text is not None:
            st.session_state['full_text'] = extracted_text
            if extracted_text: # Only summarize if text extraction was successful
                with st.spinner('Generating summary...'):
                    summary_text = generate_summary(extracted_text)
                st.session_state['summary'] = summary_text
                if not summary_text:
                    st.error("Summary generation failed.") # Keep error message if summary is None
            else:
                st.warning("Text extraction resulted in empty content. Cannot generate summary.")
                st.session_state['summary'] = None
        elif input_valid and extracted_text is None:
            # Error already shown in extraction func OR warning shown if text was empty
            st.session_state['full_text'] = None
            st.session_state['summary'] = None
    # --- Display Results ---
    # Use columns only if there's something to display to avoid empty columns
    if st.session_state.get('full_text') or st.session_state.get('summary'):
        st.markdown("---") # Separator before results
        col1, col2 = st.columns([1, 1])
        with col1:
            st.markdown("<h3>Full Text / Transcription</h3>", unsafe_allow_html=True)
            full_text_content = st.session_state.get('full_text')
            if full_text_content:
                display_text = full_text_content
                # Simple truncation for display performance, not affecting summary/chat context
                if len(display_text) > 150000:
                    display_text = display_text[:150000] + "\n\n... (Text truncated for display performance)"
                st.text_area("Full Content:", display_text, height=400, key="full_text_area", label_visibility="collapsed")
            else:
                # Show placeholder only if generation was attempted but failed/empty
                if st.session_state.get('last_input_data_key') and process_button_pressed: # Check if process was triggered
                    st.info("No text extracted or transcribed.")
        with col2:
            st.markdown("<h3>Generated Summary</h3>", unsafe_allow_html=True)
            summary_content = st.session_state.get('summary')
            if summary_content:
                st.text_area("Summary:", summary_content, height=400, key="summary_area", label_visibility="collapsed")
            else:
                # Show placeholder only if generation was attempted but failed/empty
                if st.session_state.get('last_input_data_key') and process_button_pressed:
                    st.warning("Summary could not be generated.")
    # --- Chat Section ---
    st.markdown("---")
    st.markdown("<h3>Chat with AI about the Content</h3>", unsafe_allow_html=True)
    context_option = st.radio(
        "Use as chat context:",
        ('Generated Summary', 'Full Text'),
        key='chat_context_option',
        horizontal=True,
        label_visibility="collapsed" # Hide label for radio itself
    )
    chat_context = None
    context_name = ""
    if context_option == 'Generated Summary':
        if st.session_state.get('summary'):
            chat_context = st.session_state['summary']
            context_name = "Summary"
        else:
            st.warning("Summary not available for chat context.")
    else: # Full Text option
        if st.session_state.get('full_text'):
            full_text_for_chat = st.session_state['full_text']
            # Truncate context *before* passing to chat if needed
            # Estimate tokens needed for question + response buffer
            max_context_tokens = MAX_TOKENS_FOR_CHAT_INPUT - 500
            chat_context = truncate_text_by_tokens(full_text_for_chat, max_context_tokens)
            # Character-length comparison is only a proxy for "was truncated";
            # good enough for labelling purposes.
            if len(full_text_for_chat) > len(chat_context):
                context_name = "Full Text (Truncated for Chat)"
            else:
                context_name = "Full Text"
        else:
            st.warning("Full text not available for chat context.")
    if chat_context:
        # Display which context is being used subtly
        st.markdown(f"<small style='color: #cccccc;'>Chatting based on: **{context_name}**</small>", unsafe_allow_html=True)
        question = st.text_input("Ask a question:", key="chat_question", placeholder="Ask anything about the selected context...")
        if st.button("Ask AI", key="ask_ai_button", use_container_width=True):
            if question:
                with st.spinner("AI is thinking..."):
                    answer = chat_with_ai(question, chat_context)
                if answer:
                    st.markdown("**AI Answer:**")
                    # Use markdown for potentially better formatting of AI response
                    st.markdown(answer)
                else:
                    st.error("Failed to get an answer from the AI.")
            else:
                st.warning("Please enter a question first.")
    else:
        # Only show message if processing was attempted for current input
        if st.session_state.get('last_input_data_key'):
            st.markdown("_(Generate content or summary first to enable chat)_")
    # Add footer or instructions if desired
    st.sidebar.markdown("---")
    st.sidebar.info("Powered by OpenAI Whisper & GPT models.")
# Standard script entry point: run the Streamlit app when executed directly.
if __name__ == "__main__":
    main()