# sambanova-scribe / src/scribe.py
# Author: Daryl Lim (commit 9a231cf, "Update scribe.py")
import argparse
import base64
import concurrent.futures
import copy
import os
import sys
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union
import yaml
import yt_dlp
from dotenv import load_dotenv
from langchain.output_parsers import OutputFixingParser
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.llms import LLM
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.output_parsers import PydanticOutputParser, StrOutputParser
from langchain_core.prompts import load_prompt
from pydantic import BaseModel, Field
# Resolve the repository layout relative to this file so imports and data
# paths work regardless of the current working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
kit_dir = os.path.dirname(current_dir)
repo_dir = os.path.dirname(kit_dir)

# Make the kit and repo roots importable (needed for utils.model_wrappers).
sys.path.append(kit_dir)
sys.path.append(repo_dir)

from utils.model_wrappers.api_gateway import APIGateway

# Load API keys and credentials from the repository-level .env file.
load_dotenv(os.path.join(repo_dir, '.env'))

CONFIG_PATH = os.path.join(kit_dir, 'config.yaml')
MAX_FILE_SIZE = 25 * 1024 * 1024  # 25 MB in bytes
class FileSizeExceededError(Exception):
    """Raised when a YouTube download exceeds the configured maximum file size."""

    pass
class Transcript(BaseModel):
    """Schema the audio model's JSON transcription output is parsed into."""

    # Plain-text transcription extracted from the audio.
    transcript: str = Field(description='audio transcription')
class Scribe:
    """Downloading, transcription, question answering and summarization class."""

    def __init__(
        self,
        sambanova_api_key: Optional[str] = None,
    ) -> None:
        """
        Create a new Scribe instance.

        Args:
            sambanova_api_key (str, optional): SambaNova Cloud API key. When None,
                the API gateway falls back to its own environment configuration.
        """
        config = self.get_config_info()
        self.llm_info = config[0]  # 'llm' section of config.yaml
        self.audio_model_info = config[1]  # 'audio_model' section of config.yaml
        self.prod_mode = config[2]  # 'prod_mode' flag of config.yaml
        self.sambanova_api_key: Optional[str] = sambanova_api_key
        self.audio_model = self.set_audio_model()
        self.llm = self.set_llm()
        # Start with a clean multimodal conversation history.
        self.reset_query_audio_conversation()

    def get_config_info(self) -> Tuple[Dict[str, Any], Dict[str, Any], bool]:
        """
        Load the YAML config file.

        Returns:
            tuple: the 'llm' section, the 'audio_model' section and the
                'prod_mode' flag from config.yaml.
        """
        with open(CONFIG_PATH, 'r') as yaml_file:
            config = yaml.safe_load(yaml_file)
        llm_info = config['llm']
        audio_model_info = config['audio_model']
        prod_mode = config['prod_mode']
        return llm_info, audio_model_info, prod_mode

    def set_audio_model(self) -> BaseChatModel:
        """
        Instantiate the multimodal (audio-capable) chat model.

        Returns:
            audio_model: the chat model built from the 'audio_model' config section.
        """
        audio_model = APIGateway.load_chat(
            type=self.audio_model_info['type'],
            streaming=False,
            max_tokens=self.audio_model_info['max_tokens'],
            temperature=self.audio_model_info['temperature'],
            model=self.audio_model_info['model'],
            sambanova_api_key=self.sambanova_api_key,
        )
        return audio_model

    def set_llm(self) -> Union[LLM, BaseChatModel]:
        """
        Instantiate the sncloud or sambastudio LLM based on the llm type attribute.

        Returns:
            LLM: The SambaNova Cloud or SambaStudio LangChain chat model.
        """
        llm = APIGateway.load_chat(
            type=self.llm_info['type'],
            streaming=False,
            do_sample=self.llm_info['do_sample'],
            max_tokens=self.llm_info['max_tokens'],
            temperature=self.llm_info['temperature'],
            model=self.llm_info['model'],
            process_prompt=False,
            sambanova_api_key=self.sambanova_api_key,
        )
        return llm

    def summarize(self, text: str, num: int = 5) -> str:
        """
        Create a bullet-point summary of the text input.

        Args:
            text (str): The text to summarize.
            num (int, optional): The number of bullet points to generate. Defaults to 5.

        Returns:
            str: The bullet-point summary of the text.
        """
        prompt_template = load_prompt(os.path.join(kit_dir, 'prompts', 'summary.yaml'))
        chain = prompt_template | self.llm | StrOutputParser()
        summary = chain.invoke({'text': text, 'num': num})
        return summary

    def reset_query_audio_conversation(self) -> None:
        """Reset the audio-query conversation history to its initial persona message."""
        self.query_audio_conversation: List[Union[HumanMessage, AIMessage, SystemMessage]] = [
            AIMessage(
                'You are helpful assistant called Scribe developed by SambaNova Systems, '
                'you are helping users in general purpose tasks'
            )
        ]

    def query_audio(self, audio: Optional[Union[BytesIO, str]] = None, query: Optional[str] = None) -> str:
        """
        Continue the running conversation, optionally attaching an audio clip
        and/or a textual query, and return the audio model's answer.

        Args:
            audio (BytesIO or str, optional): audio buffer or file path to attach
                as base64 content.
            query (str, optional): user question to append to the conversation.

        Returns:
            str: the model response (also appended to the conversation history).
        """
        if audio is not None:
            b64_audio = self.load_encode_audio(audio)
            # NOTE(review): the mime type is hard-coded to mp3 regardless of the
            # actual input format — confirm callers only pass mp3 audio.
            self.query_audio_conversation.append(
                HumanMessage(
                    content=[
                        {'type': 'audio_content', 'audio_content': {'content': f'data:audio/mp3;base64,{b64_audio}'}}
                    ]
                )
            )
        if query is not None:
            self.query_audio_conversation.append(HumanMessage(f'{query}, explain your response'))
        chain = self.audio_model | StrOutputParser()
        response = chain.invoke(self.query_audio_conversation)
        self.query_audio_conversation.append(AIMessage(response))
        return response

    def query_audio_pipeline(self, audio: Union[BytesIO, str], query: str) -> str:
        """
        Answer a query about an audio by combining a transcription and a direct
        audio-model answer (produced concurrently) into a final LLM response.

        Args:
            audio (BytesIO or str): audio buffer or file path.
            query (str): user question about the audio.

        Returns:
            str: the final answer generated by the text LLM.
        """
        # Transcription and the direct audio query are independent network-bound
        # calls, so run them in parallel threads.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            transcription_future = executor.submit(self.transcribe_audio, audio)
            audio_result_future = executor.submit(self.query_audio, audio, query)
            transcription = transcription_future.result()
            audio_result = audio_result_future.result()
        conversation = [
            SystemMessage("""
            You are helpful assistant called Scribe developed by SambaNova Systems.
            the user will ask information about an audio.
            You will get the user query, the audio transcription and an intermediate response generated by a model capable of listening the audio
            With those give a final response to the user query
            """),  # noqa: E501
            HumanMessage(
                f"""
            Transcript: {transcription}
            Intermediate Audio Response: {audio_result}
            Query: {query}
            """
            ),
        ]
        chain = self.llm | StrOutputParser()
        response = chain.invoke(conversation)
        return response

    def encode_to_base64(self, content: bytes) -> str:
        """Encode raw audio bytes to a base64 UTF-8 string."""
        return base64.b64encode(content).decode('utf-8')

    def load_encode_audio(self, audio: Union[BytesIO, str]) -> str:
        """
        Read audio from a file path or an in-memory buffer and base64-encode it.

        Args:
            audio (BytesIO or str): file path, or a BytesIO buffer.

        Returns:
            str: base64-encoded audio bytes.
        """
        if isinstance(audio, str):
            with open(audio, 'rb') as file:
                audio_bytes = file.read()
        else:
            # make a copy given BytesIO object is not thread safe
            audio_copy = copy.deepcopy(audio)
            audio_bytes = audio_copy.read()
        b64_audio = self.encode_to_base64(content=audio_bytes)
        return b64_audio

    def transcribe_audio(self, audio_file: Union[BytesIO, str]) -> str:
        """
        Transcribe an audio file with the audio model.

        Args:
            audio_file (BytesIO or str): audio buffer or file path.

        Returns:
            str: the transcription text.
        """
        b64_audio = self.load_encode_audio(audio_file)
        conversation = [
            AIMessage('You are Automatic Speech Recognition tool'),
            HumanMessage(
                content=[{'type': 'audio_content', 'audio_content': {'content': f'data:audio/mp3;base64,{b64_audio}'}}]
            ),
            HumanMessage(
                f"""Please transcribe the previous audio in the following format
                ```
                {{
                    "transcript":"<audio transcription>"
                }}
                ```
                Always return your response enclosed by ``` and using double quotes
                """
            ),
        ]
        # Parse the model output into a Transcript object; when parsing fails,
        # OutputFixingParser asks the text LLM to repair the malformed output.
        parser = PydanticOutputParser(pydantic_object=Transcript)
        autofix_parser = OutputFixingParser.from_llm(parser=parser, llm=self.llm)
        chain = self.audio_model | autofix_parser
        return chain.invoke(conversation).transcript  # type: ignore

    def download_youtube_audio(
        self, url: str, output_path: Optional[str] = None, max_filesize: int = MAX_FILE_SIZE
    ) -> Optional[str]:
        """
        Downloads the audio from a YouTube URL and saves it to the specified output path.

        Args:
            url (str): The YouTube URL to download.
            output_path (str, optional): The path to save the downloaded audio.
                Defaults to <kit_dir>/data.
            max_filesize (int, optional): The maximum file size in bytes. Defaults to MAX_FILE_SIZE.

        Returns:
            str: The path of the downloaded audio if successful, otherwise None.
        """
        if output_path is None:
            output_path = os.path.join(kit_dir, 'data')

        downloaded_filename = None

        def progress_hook(d: Dict[str, Any]) -> None:
            # Track the final filename and abort oversize downloads mid-flight.
            nonlocal downloaded_filename
            if d['status'] == 'finished':
                downloaded_filename = d['filename']
            elif d['status'] == 'downloading':
                if 'total_bytes' in d and d['total_bytes'] > max_filesize:
                    if 'tmpfilename' in d:
                        try:
                            os.remove(d['tmpfilename'])
                            print(f"Deleted temporary file: {d['tmpfilename']}")
                        except OSError as e:
                            print(f'Error deleting temporary file: {e}')
                    raise FileSizeExceededError(f'File size exceeds {max_filesize/1024/1024:.2f} MB limit')

        ydl_opts = {
            'format': 'bestaudio/best',
            'postprocessors': [
                {
                    'key': 'FFmpegExtractAudio',
                    'preferredcodec': 'mp3',
                    'preferredquality': '192',
                }
            ],
            'outtmpl': output_path + '/%(title)s.%(ext)s',
            'progress_hooks': [progress_hook],
            'cookiefile': os.path.join(kit_dir, 'data', 'sample_yt_cookies.txt'),
            'username': os.environ.get('YOUTUBE_USERNAME'),
            'password': os.environ.get('YOUTUBE_PASSWORD'),
            'verbose': True,
        }
        try:
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                ydl.download([url])
                print(f'Successfully downloaded audio from: {url}')
                # The hook reports the pre-conversion filename; the FFmpeg
                # postprocessor re-encodes to mp3, so prefer the .mp3 sibling
                # when it exists on disk.
                if downloaded_filename and not downloaded_filename.endswith('.mp3'):
                    new_filename = os.path.splitext(downloaded_filename)[0] + '.mp3'
                    if os.path.exists(new_filename):
                        downloaded_filename = new_filename
                return downloaded_filename
        except FileSizeExceededError as e:
            print(f'Skipped downloading {url}: {str(e)}')
        except yt_dlp.utils.DownloadError as e:
            print(f'An error occurred while downloading {url}: {str(e)}')
        except Exception as e:
            print(f'An unexpected error occurred while downloading {url}: {str(e)}')
        return None

    def delete_downloaded_file(self, file_path: str) -> None:
        """
        Deletes the specified file.

        Args:
            file_path (str): The path of the file to delete.
        """
        try:
            # Check if the file exists
            if os.path.exists(file_path):
                os.remove(file_path)
                print(f"File '{file_path}' successfully deleted.")
            else:
                print(f"Error: File '{file_path}' does not exist.")
        except PermissionError:
            print(f"PermissionError: You do not have permission to delete the file '{file_path}'.")
        except OSError as e:
            print(f"OSError: Failed to delete the file '{file_path}' due to: {e.strerror}.")
        except Exception as e:
            print(f'An unexpected error occurred: {str(e)}')
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Transcribe, summarize, or query audio')
    parser.add_argument(
        '--task', type=str, choices=['transcribe', 'summarize', 'query'], help='Task to perform', default='transcribe'
    )
    parser.add_argument('--source', type=str, choices=['youtube', 'local'], help='Source of the audio', default='local')
    parser.add_argument('--audio', type=str, help='Path or YouTube link to the audio', required=True)
    parser.add_argument('--query', type=str, help='query when using query audio task', default=None)
    args = parser.parse_args()

    scribe = Scribe()

    # Resolve the audio source to a local file path.
    if args.source == 'youtube':
        audio_path = scribe.download_youtube_audio(args.audio)
    else:
        audio_path = args.audio

    # `assert` is stripped when Python runs with -O, so validate explicitly:
    # download_youtube_audio returns None on failure.
    if audio_path is None:
        sys.exit('Error: could not obtain an audio file to process')

    if args.task == 'transcribe':
        transcript = scribe.transcribe_audio(audio_path)
        print(f'Transcript: {transcript}')
    elif args.task == 'summarize':
        transcript = scribe.transcribe_audio(audio_path)
        summary = scribe.summarize(transcript)
        print(f'Summary: {summary}')
    elif args.task == 'query':
        if args.query is None:
            # No explicit question: send only the audio to the model.
            response = scribe.query_audio(audio=audio_path)
            print(f'Response: {response}')
        else:
            response = scribe.query_audio_pipeline(audio=audio_path, query=args.query)
            print(f'Response: {response}')

    # Clean up audio downloaded from YouTube once processing is done.
    if args.source == 'youtube':
        scribe.delete_downloaded_file(audio_path)