Spaces:
Sleeping
Sleeping
| import tiktoken | |
| import traceback | |
| import streamlit as st | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.runnables import RunnableLambda | |
| from langchain_text_splitters import RecursiveCharacterTextSplitter | |
| # models | |
| from langchain_openai import ChatOpenAI | |
| from langchain_anthropic import ChatAnthropic | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from urllib.parse import urlparse | |
| from langchain_community.document_loaders import YoutubeLoader # Youtube用 | |
# Optionally load environment variables (e.g. API keys) from a .env file.
# Remove this block if you don't use python-dotenv.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    import warnings
    # ImportWarning is silenced by Python's default warning filters, so the
    # message would never reach the user; UserWarning is shown by default.
    warnings.warn(
        "dotenv not found. Please make sure to set your environment variables manually.",
        UserWarning,
    )
# Prompt template for a single summarization call. In English it reads:
# "Summarize the following content clearly in about 300 characters ...
# Write it in Japanese!" The only placeholder is {content}, which the
# ChatPromptTemplate fills from the {"content": ...} input dict.
SUMMARIZE_PROMPT = """以下のコンテンツについて、内容を300文字程度でわかりやすく要約してください。
========
{content}
========
日本語で書いてね!
"""
def init_page():
    """Set up the Streamlit page config, main header, and sidebar title."""
    page_settings = {
        "page_title": "Youtube Summarizer",
        "page_icon": "🤗",
    }
    st.set_page_config(**page_settings)
    st.header("Youtube Summarizer 🤗")
    sidebar = st.sidebar
    sidebar.title("Options")
def select_model(temperature=0):
    """Let the user pick an LLM in the sidebar and return a configured client.

    Args:
        temperature: Sampling temperature passed to whichever chat model is
            built. Defaults to 0.

    Returns:
        A LangChain chat model for the selected option, or None if the radio
        value matches none of the known labels.
    """
    # Sidebar label -> zero-argument factory. Factories keep construction
    # lazy so only the chosen client is instantiated.
    # NOTE: the "GPT-4" option actually builds "gpt-4o".
    builders = {
        "GPT-3.5": lambda: ChatOpenAI(
            temperature=temperature, model_name="gpt-3.5-turbo"
        ),
        "GPT-4": lambda: ChatOpenAI(
            temperature=temperature, model_name="gpt-4o"
        ),
        "Claude 3.5 Sonnet": lambda: ChatAnthropic(
            temperature=temperature, model_name="claude-3-5-sonnet-20240620"
        ),
        "Gemini 1.5 Pro": lambda: ChatGoogleGenerativeAI(
            temperature=temperature, model="gemini-1.5-pro-latest"
        ),
    }
    # Dict insertion order determines the order of the radio options.
    choice = st.sidebar.radio("Choose a model:", tuple(builders))
    build = builders.get(choice)
    return None if build is None else build()
def init_summarize_chain():
    """Return the single-pass summarizer: prompt | chat model | string parser.

    The returned runnable takes ``{"content": str}`` (the only placeholder in
    SUMMARIZE_PROMPT) and produces the model's reply as a plain string.
    """
    chat_model = select_model()  # also renders the model picker in the sidebar
    template = ChatPromptTemplate.from_messages([("user", SUMMARIZE_PROMPT)])
    parser = StrOutputParser()
    return template | chat_model | parser
def init_chain(max_tokens=16000):
    """Build a summarization chain that picks a strategy by input length.

    Inputs whose ``content`` is at most ``max_tokens`` tokens (counted with
    the gpt-3.5-turbo tokenizer) are summarized in one LLM call. Longer
    inputs go through map-reduce: the text is split into chunks of at most
    ``max_tokens`` tokens, each chunk is summarized independently, and the
    newline-joined partial summaries are summarized once more.

    Args:
        max_tokens: Token budget used both as the routing threshold and as
            the splitter's chunk size, so the two cannot drift apart.
            Defaults to 16000.
            NOTE(review): confirm that max_tokens plus the prompt and the
            completion fits the selected model's context window.

    Returns:
        A runnable that accepts ``{"content": str}`` and yields the summary.
    """
    # Token counting differs per model, so tiktoken needs a concrete model
    # name. When Claude or Gemini is selected these counts are only an
    # approximation, because those models use different tokenizers.
    tokenizer_model = "gpt-3.5-turbo"

    summarize_chain = init_summarize_chain()

    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        model_name=tokenizer_model,
        chunk_size=max_tokens,  # measured in tokens, not characters
        chunk_overlap=0,
    )

    def split_into_chunks(x):
        """Turn {"content": long_text} into one {"content": chunk} per chunk."""
        return [
            {"content": chunk} for chunk in text_splitter.split_text(x["content"])
        ]

    def join_summaries(partials):
        """Merge per-chunk summaries into a single input for the reduce step."""
        return {"content": "\n".join(partials)}

    map_reduce_chain = (
        RunnableLambda(split_into_chunks)
        | summarize_chain.map()
        | RunnableLambda(join_summaries)
        | summarize_chain
    )

    # Resolve the tokenizer once when the chain is built, not on every request.
    encoding = tiktoken.encoding_for_model(tokenizer_model)

    def route(x):
        """Choose single-pass or map-reduce summarization by token count."""
        token_count = len(encoding.encode(x["content"]))
        if token_count > max_tokens:
            return map_reduce_chain
        return summarize_chain

    return RunnableLambda(route)
def validate_url(url):
    """Report whether *url* parses with both a scheme and a network location.

    Returns False for strings that ``urlparse`` rejects with ValueError
    (e.g. a malformed IPv6 host). Otherwise returns True only when both the
    scheme and netloc parts are non-empty.
    """
    try:
        parsed = urlparse(url)
    except ValueError:
        return False
    return bool(parsed.scheme) and bool(parsed.netloc)
def get_content(url):
    """Fetch a YouTube video's transcript and title.

    Captions are requested with English preferred and Japanese as the
    fallback, and video info is enabled so the metadata includes the title.
    The loader returns a list of ``Document`` objects shaped like:

        Document:
            - page_content: str
            - metadata: dict
                - source: str
                - title: str
                - description: Optional[str],
                - view_count: int
                - thumbnail_url: Optional[str]
                - publish_date: str
                - length: int
                - author: str

    Args:
        url: Video URL entered by the user.

    Returns:
        A string made of ``Title: <title>``, a blank line, and the transcript
        of the first document. Returns None when nothing was loaded or when
        any step failed; on failure the traceback is written to the page.
    """
    try:
        with st.spinner("Fetching Youtube ..."):
            # Both building the loader and load() (the actual fetch) can
            # raise, so they belong inside the try, not before it.
            loader = YoutubeLoader.from_youtube_url(
                url,
                add_video_info=True,  # also fetch title, view count, etc.
                language=['en', 'ja'],  # caption priority: English, then Japanese
            )
            documents = loader.load()
        if not documents:
            return None
        content = documents[0].page_content
        title = documents[0].metadata['title']
    except Exception:
        # Catch Exception rather than using a bare except, so
        # BaseException-only signals (KeyboardInterrupt, SystemExit, ...)
        # still propagate. Show the error details to the user.
        st.write(traceback.format_exc())
        return None
    return f"Title: {title}\n\n{content}"
def main():
    """Render the app: read a URL, then stream a summary of its transcript."""
    init_page()
    # Building the chain also draws the model selector (via select_model).
    chain = init_chain()

    # Watch for user input; nothing more to do until a URL is entered.
    url = st.text_input("URL: ", key="input")
    if not url:
        return
    if not validate_url(url):
        st.write('Please input valid url')
        return

    content = get_content(url)
    if not content:
        return

    st.markdown("## Summary")
    st.write_stream(chain.stream({"content": content}))
    st.markdown("---")
    st.markdown("## Original Text")
    st.write(content)
    # To display costs, add the same implementation as in Chapter 3:
    # calc_and_display_costs()


if __name__ == '__main__':
    main()