# heysho's picture
# Update app.py
# ab345c5 verified
import tiktoken
import traceback
import streamlit as st
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda
from langchain_text_splitters import RecursiveCharacterTextSplitter
# models
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from urllib.parse import urlparse
from langchain_community.document_loaders import YoutubeLoader  # for YouTube transcripts
###### Remove this block if you do not use dotenv ######
# Best-effort load of a local .env file; when python-dotenv is not
# installed, warn so the user knows to set environment variables
# (API keys for OpenAI/Anthropic/Google) manually.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    import warnings
    warnings.warn("dotenv not found. Please make sure to set your environment variables manually.", ImportWarning)
################################################
# Prompt template for summarization. The Japanese text is part of the
# runtime prompt sent to the LLM (it asks for a ~300-character,
# easy-to-understand Japanese summary of {content}) and must stay as-is.
SUMMARIZE_PROMPT = """以下のコンテンツについて、内容を300文字程度でわかりやすく要約してください。
========
{content}
========
日本語で書いてね!
"""
def init_page():
    """Configure the Streamlit page and render the header and sidebar title."""
    st.set_page_config(page_title="Youtube Summarizer", page_icon="🤗")
    st.header("Youtube Summarizer 🤗")
    st.sidebar.title("Options")
def select_model(temperature=0):
    """Let the user pick an LLM in the sidebar and return its chat instance.

    Args:
        temperature: sampling temperature forwarded to the chosen model.
    """
    # Map each sidebar label to a factory building the matching chat model.
    factories = {
        "GPT-3.5": lambda: ChatOpenAI(
            temperature=temperature,
            model_name="gpt-3.5-turbo",
        ),
        "GPT-4": lambda: ChatOpenAI(
            temperature=temperature,
            model_name="gpt-4o",
        ),
        "Claude 3.5 Sonnet": lambda: ChatAnthropic(
            temperature=temperature,
            model_name="claude-3-5-sonnet-20240620",
        ),
        "Gemini 1.5 Pro": lambda: ChatGoogleGenerativeAI(
            temperature=temperature,
            model="gemini-1.5-pro-latest",
        ),
    }
    choice = st.sidebar.radio("Choose a model:", tuple(factories))
    factory = factories.get(choice)
    # Like the original if/elif chain, fall through to None for an
    # unknown label (unreachable in practice: radio returns a listed option).
    return factory() if factory is not None else None
def init_summarize_chain():
    """Build the prompt → LLM → string-parser summarization pipeline."""
    summarize_prompt = ChatPromptTemplate.from_messages(
        [("user", SUMMARIZE_PROMPT)]
    )
    return summarize_prompt | select_model() | StrOutputParser()
def init_chain():
    """Build the summarization chain, routing long inputs through map-reduce.

    Returns:
        A runnable that takes ``{"content": str}``. Inputs above 16k
        tokens are split into chunks, each chunk summarized, the partial
        summaries concatenated and summarized once more (map-reduce);
        shorter inputs go straight to a single summarization pass.
    """
    summarize_chain = init_summarize_chain()
    # Token counting differs per model, so pin model_name explicitly.
    # NOTE: the count is not exact for Claude 3 models.
    text_splitter = \
        RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            model_name="gpt-3.5-turbo",
            # chunk size is measured in tokens, not characters
            chunk_size=16000,
            chunk_overlap=0,
        )
    text_split = RunnableLambda(
        lambda x: [
            {"content": doc} for doc
            in text_splitter.split_text(x['content'])
        ]
    )
    text_concat = RunnableLambda(
        lambda x: {"content": '\n'.join(x)})
    map_reduce_chain = (
        text_split
        | summarize_chain.map()
        | text_concat
        | summarize_chain
    )

    # Hoisted out of route(): building the tiktoken encoding is relatively
    # expensive and identical on every invocation, so do it once here.
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")

    def route(x):
        # Route oversized inputs through map-reduce; everything else direct.
        token_count = len(encoding.encode(x["content"]))
        if token_count > 16000:
            return map_reduce_chain
        else:
            return summarize_chain

    return RunnableLambda(route)
def validate_url(url):
    """Return True when *url* parses into both a scheme and a network location."""
    try:
        parts = urlparse(url)
    except ValueError:
        # urlparse raises for malformed input (e.g. a bad IPv6 literal).
        return False
    return bool(parts.scheme) and bool(parts.netloc)
def get_content(url):
    """Fetch a YouTube transcript and return it prefixed with the video title.

    Each loaded Document carries:
    - page_content: str (the transcript text)
    - metadata: dict
      - source: str
      - title: str
      - description: Optional[str]
      - view_count: int
      - thumbnail_url: Optional[str]
      - publish_date: str
      - length: int
      - author: str

    Returns:
        ``"Title: <title>\\n\\n<transcript>"`` on success, or None when no
        transcript is available or loading fails (in that case the
        traceback is rendered in the app instead of crashing it).
    """
    with st.spinner("Fetching Youtube ..."):
        loader = YoutubeLoader.from_youtube_url(
            url,
            add_video_info=True,   # also fetch title, view count, etc.
            language=['en', 'ja']  # prefer English captions, fall back to Japanese
        )
        try:
            # loader.load() does the network fetch and is the most likely
            # point of failure, so it belongs inside the try block
            # (previously it was outside and errors crashed the app).
            res = loader.load()  # list of `Document` (page_content, metadata)
            if res:
                content = res[0].page_content
                title = res[0].metadata['title']
                return f"Title: {title}\n\n{content}"
            else:
                return None
        except Exception:
            # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
            # still propagate; show the traceback in the UI for debugging.
            st.write(traceback.format_exc())
            return None
def main():
    """Streamlit entry point: read a URL, validate it, and summarize the video."""
    init_page()
    chain = init_chain()
    # Watch the text input; Streamlit reruns the script on each change.
    url = st.text_input("URL: ", key="input")
    if url:
        if not validate_url(url):
            st.write('Please input valid url')
        elif content := get_content(url):
            st.markdown("## Summary")
            st.write_stream(chain.stream({"content": content}))
            st.markdown("---")
            st.markdown("## Original Text")
            st.write(content)
    # To display costs, add the same implementation as in chapter 3:
    # calc_and_display_costs()


if __name__ == '__main__':
    main()