Update app.py
Browse files
app.py
CHANGED
|
@@ -10,7 +10,7 @@ import os
|
|
| 10 |
import tempfile
|
| 11 |
import gradio.themes as gr_themes
|
| 12 |
import csv
|
| 13 |
-
from transformers import pipeline
|
| 14 |
|
| 15 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
MODEL_NAME="nvidia/parakeet-tdt-0.6b-v2"
|
|
@@ -21,7 +21,7 @@ model.eval()
|
|
| 21 |
|
| 22 |
# Load the summarization model once at startup
|
| 23 |
#summarizer = pipeline("summarization", model="Falconsai/text_summarization", device="cpu")
|
| 24 |
-
|
| 25 |
|
| 26 |
def get_audio_segment(audio_path, start_second, end_second):
|
| 27 |
"""
|
|
@@ -236,35 +236,33 @@ def get_full_transcript(vis_data):
|
|
| 236 |
return ""
|
| 237 |
return " ".join([row[2] for row in vis_data if len(row) == 3])
|
| 238 |
|
| 239 |
-
# Simple summary function (replace with a real model if needed)
|
| 240 |
-
# Replace the old summarize_transcript function with this one
|
| 241 |
@spaces.GPU
|
| 242 |
-
|
| 243 |
# """
|
| 244 |
# Summarizes the transcript using the sshleifer/distilbart-cnn-12-6 model.
|
| 245 |
# """
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
|
| 269 |
# Apply the custom theme
|
| 270 |
|
|
|
|
| 10 |
import tempfile
|
| 11 |
import gradio.themes as gr_themes
|
| 12 |
import csv
|
| 13 |
+
from transformers.pipelines import pipeline
|
| 14 |
|
| 15 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
MODEL_NAME="nvidia/parakeet-tdt-0.6b-v2"
|
|
|
|
| 21 |
|
| 22 |
# Load the summarization model once at startup
|
| 23 |
#summarizer = pipeline("summarization", model="Falconsai/text_summarization", device="cpu")
|
| 24 |
+
summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
|
| 25 |
|
| 26 |
def get_audio_segment(audio_path, start_second, end_second):
|
| 27 |
"""
|
|
|
|
| 236 |
return ""
|
| 237 |
return " ".join([row[2] for row in vis_data if len(row) == 3])
|
| 238 |
|
|
|
|
|
|
|
| 239 |
@spaces.GPU
|
| 240 |
+
def summarize_transcript(transcript: str) -> str:
|
| 241 |
# """
|
| 242 |
# Summarizes the transcript using the sshleifer/distilbart-cnn-12-6 model.
|
| 243 |
# """
|
| 244 |
+
# Check for empty or whitespace-only input
|
| 245 |
+
if not transcript or not transcript.strip():
|
| 246 |
+
return "No transcript available to summarize."
|
| 247 |
+
|
| 248 |
+
try:
|
| 249 |
+
gr.Info("Generating summary...", duration=2)
|
| 250 |
+
# Use the pre-loaded summarizer object to generate the summary
|
| 251 |
+
result = summarizer(
|
| 252 |
+
transcript,
|
| 253 |
+
max_length=250,
|
| 254 |
+
min_length=50,
|
| 255 |
+
num_beams=4,
|
| 256 |
+
early_stopping=True
|
| 257 |
+
)
|
| 258 |
+
# Extract the summary text from the result
|
| 259 |
+
summary = result[0]['summary_text']
|
| 260 |
+
return summary
|
| 261 |
+
except Exception as e:
|
| 262 |
+
error_message = f"An error occurred during summarization: {e}"
|
| 263 |
+
print(error_message) # Log the error to the console for debugging
|
| 264 |
+
gr.Warning("Sorry, the summary could not be generated at this time.")
|
| 265 |
+
return "" # Return an empty string on failure
|
| 266 |
|
| 267 |
# Apply the custom theme
|
| 268 |
|