Spaces:
Build error
Build error
| import gradio as gr | |
| import os | |
| import time | |
| import pdfplumber | |
| from dotenv import load_dotenv | |
| import torch | |
| from transformers import ( | |
| BertJapaneseTokenizer, | |
| BertModel, | |
| AutoTokenizer, | |
| AutoModelForCausalLM, | |
| pipeline, | |
| BitsAndBytesConfig | |
| ) | |
| from langchain_community.vectorstores import FAISS # 修正 | |
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain.memory import ConversationBufferMemory | |
| from langchain_community.llms import HuggingFacePipeline # 修正 | |
| from langchain_community.embeddings import HuggingFaceEmbeddings # 修正 | |
| from langchain_huggingface import HuggingFaceEndpoint | |
# Suppress pydantic's "protected namespace" warning emitted for the
# `model_name` field of HuggingFaceInferenceAPIEmbeddings.
import warnings

_PYDANTIC_MODEL_NAME_WARNING = r"Field \"model_name\" in HuggingFaceInferenceAPIEmbeddings has conflict with protected namespace"
warnings.filterwarnings("ignore", message=_PYDANTIC_MODEL_NAME_WARNING)
# Populate os.environ from a .env file so HF_TOKEN (read via os.getenv in
# initialize_llmchain) can be supplied locally.
load_dotenv()
# Hugging Face repository IDs offered in the UI. The model radio button uses
# type="index", so callbacks receive a position in this list.
list_llm = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "rinna/llama-3-youko-8b",
]
# Display labels: each repository ID without its organisation prefix.
list_llm_simple = list(map(os.path.basename, list_llm))
# Text extraction from (Japanese) PDFs.
def extract_text_from_pdf(file_path):
    """Return the text of every page of a PDF, joined with single spaces.

    Args:
        file_path: Path of the PDF file to open with pdfplumber.

    Returns:
        The concatenated page text. A page with no extractable text (for
        example a scanned image) contributes an empty string instead of
        breaking the join.
    """
    with pdfplumber.open(file_path) as pdf:
        # Some pdfplumber versions return None from extract_text() for pages
        # without characters; str.join would raise TypeError on that.
        return " ".join(page.extract_text() or "" for page in pdf.pages)
# Initialise a Japanese BERT tokenizer and encoder at import time.
# NOTE(review): neither tokenizer_bert nor model_bert is referenced anywhere
# else in the visible source (create_db uses HuggingFaceEmbeddings instead),
# so this only adds start-up time and memory -- consider removing it.
# NOTE(review): BertJapaneseTokenizer's MeCab word tokenizer needs extra
# packages (e.g. fugashi plus a dictionary such as unidic-lite/ipadic);
# confirm they are listed in the Space's requirements.
_BERT_JA_MODEL_ID = 'cl-tohoku/bert-base-japanese'
tokenizer_bert = BertJapaneseTokenizer.from_pretrained(
    _BERT_JA_MODEL_ID,
    clean_up_tokenization_spaces=True,
)
model_bert = BertModel.from_pretrained(_BERT_JA_MODEL_ID)
def split_text_simple(text, chunk_size=1024):
    """Split *text* into consecutive, non-overlapping fixed-size chunks.

    Args:
        text: The string to split.
        chunk_size: Maximum number of characters per chunk; must be >= 1.

    Returns:
        A list of substrings whose concatenation equals *text*. Every chunk
        has exactly ``chunk_size`` characters except possibly the last one.
        An empty string yields an empty list.

    Raises:
        ValueError: If ``chunk_size`` is less than 1. A zero step would crash
            ``range`` with an opaque message, and a negative one would
            silently drop all of the text.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
# Embedding models already constructed in this process, keyed by model name,
# so processing several PDFs does not reload the same model each time.
_EMBEDDINGS_CACHE = {}


def _get_embeddings(model_name):
    """Return a cached HuggingFaceEmbeddings instance for *model_name*."""
    embeddings = _EMBEDDINGS_CACHE.get(model_name)
    if embeddings is None:
        embeddings = HuggingFaceEmbeddings(model_name=model_name)
        _EMBEDDINGS_CACHE[model_name] = embeddings
    return embeddings


def create_db(splits, model_name='sonoisa/sentence-bert-base-ja-mean-tokens'):
    """Embed text chunks and index them in an in-memory FAISS vector store.

    Args:
        splits: Text chunks to index (e.g. the output of split_text_simple).
        model_name: Hugging Face sentence-embedding model used for the
            vectors. Defaults to the Japanese Sentence-BERT model the app
            has always used.

    Returns:
        The FAISS vector store built from the non-blank chunks.

    Raises:
        ValueError: If no chunk contains non-whitespace text (e.g. an
            image-only PDF). Without this check FAISS would fail with an
            obscure error.
    """
    texts = [chunk for chunk in splits if chunk.strip()]
    if not texts:
        raise ValueError("No extractable text found in the document.")
    return FAISS.from_texts(texts, _get_embeddings(model_name))
def _load_local_llm(llm_model, temperature, max_tokens, top_k):
    """Load *llm_model* in-process and wrap it as a LangChain HuggingFacePipeline."""
    if torch.cuda.is_available():
        # With a GPU, load the weights 4-bit (NF4) quantized via bitsandbytes.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
        model = AutoModelForCausalLM.from_pretrained(
            llm_model,
            device_map="auto",
            quantization_config=quantization_config,
        )
    else:
        # CPU only: load unquantized float32 weights.
        model = AutoModelForCausalLM.from_pretrained(
            llm_model,
            device_map={"": "cpu"},
            torch_dtype=torch.float32,
        )
    tokenizer = AutoTokenizer.from_pretrained(llm_model, use_fast=False)
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=max_tokens,
        # Without do_sample=True generation is greedy and transformers
        # ignores temperature/top_k, so the UI settings would have no effect.
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
    )
    return HuggingFacePipeline(pipeline=pipe)


def _load_endpoint_llm(llm_model, temperature, max_tokens, top_k):
    """Create an LLM served by the Hugging Face Inference API (token from HF_TOKEN)."""
    return HuggingFaceEndpoint(
        endpoint_url=f"https://api-inference.huggingface.co/models/{llm_model}",
        huggingfacehub_api_token=os.getenv("HF_TOKEN"),
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )


def _build_qa_chain(llm, vector_db):
    """Wire *llm* and a retriever over *vector_db* into a conversational QA chain."""
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True,
    )
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=vector_db.as_retriever(),
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )


def initialize_llmchain(
    llm_model,
    temperature,
    max_tokens,
    top_k,
    vector_db,
    retries=5,
    delay=5
):
    """Build a ConversationalRetrievalChain over *vector_db* using *llm_model*.

    Model IDs containing "rinna" are loaded locally: 4-bit on CUDA, float32
    on CPU. IDs containing "meta-llama" or "mistralai" are called through the
    Hugging Face Inference API.

    Args:
        llm_model: Hugging Face repository ID of the model.
        temperature: Sampling temperature.
        max_tokens: Maximum number of new tokens to generate.
        top_k: Top-k sampling cutoff.
        vector_db: Vector store providing ``as_retriever()``.
        retries: Number of attempts made while Hugging Face authentication
            fails.
        delay: Seconds to wait between such attempts.

    Returns:
        The configured ConversationalRetrievalChain, with buffer memory under
        the "chat_history" key.

    Raises:
        Exception: "Error initializing QA chain: ..." for any failure other
            than Hugging Face authentication (including unsupported model
            IDs), or "Failed to initialize after N attempts" once
            authentication retries are exhausted.
    """
    last_error = None
    for attempt in range(1, retries + 1):
        try:
            # Gradio sliders may deliver floats (e.g. 1024.0); the generation
            # APIs expect integer token counts.
            max_new = int(max_tokens)
            k = int(top_k)
            model_id = llm_model.lower()
            if "rinna" in model_id:
                llm = _load_local_llm(llm_model, temperature, max_new, k)
            elif "meta-llama" in model_id or "mistralai" in model_id:
                llm = _load_endpoint_llm(llm_model, temperature, max_new, k)
            else:
                # Add new model families here as needed.
                raise Exception(f"Unsupported model: {llm_model}")
            return _build_qa_chain(llm, vector_db)
        except Exception as e:
            if "Could not authenticate with huggingface_hub" not in str(e):
                raise Exception(f"Error initializing QA chain: {str(e)}") from e
            last_error = e
            # Only wait if another attempt will follow.
            if attempt < retries:
                time.sleep(delay)
    raise Exception(f"Failed to initialize after {retries} attempts") from last_error
def process_pdf(file):
    """Extract text from an uploaded PDF and index it in a vector store.

    Args:
        file: Value delivered by the ``gr.File`` component. This is either a
            filesystem path (str / os.PathLike, as newer Gradio versions pass
            by default) or a tempfile-like object exposing ``.name`` (older
            Gradio).

    Returns:
        A ``(vector_db, status_message)`` tuple; ``vector_db`` is None when
        no file was given or processing failed.
    """
    if file is None:
        return None, "Please upload a PDF file."
    try:
        # Accept both a plain path and a wrapper object with a .name path.
        # A Path's .name would be only the basename, so check paths first.
        file_path = file if isinstance(file, (str, os.PathLike)) else file.name
        text = extract_text_from_pdf(file_path)
        splits = split_text_simple(text)
        vdb = create_db(splits)
        return vdb, "PDF processed and vector database created."
    except Exception as e:
        return None, f"Error processing PDF: {str(e)}"
def initialize_qa_chain(
    llm_index,
    temperature,
    max_tokens,
    top_k,
    vector_db
):
    """Gradio callback: build the QA chain for the model selected in the UI.

    Args:
        llm_index: Position of the chosen model in ``list_llm``. It is None
            when nothing is selected in the radio button.
        temperature: Sampling temperature.
        max_tokens: Maximum number of new tokens.
        top_k: Top-k sampling cutoff.
        vector_db: Vector store from Step 1, or None if no PDF was processed.

    Returns:
        A ``(chain, status_message)`` tuple; ``chain`` is None on failure.
    """
    if vector_db is None:
        return None, "Please process a PDF first."
    # Without this check list_llm[None] fails with a cryptic TypeError.
    if llm_index is None:
        return None, "Please select an LLM model."
    try:
        llm_name = list_llm[llm_index]
        chain = initialize_llmchain(
            llm_name,
            temperature,
            max_tokens,
            top_k,
            vector_db
        )
        return chain, "QA Chatbot initialized with selected LLM."
    except Exception as e:
        return None, f"Error initializing QA chain: {str(e)}"
def update_chat(msg, history, chain):
    """Gradio callback: ask the QA chain *msg* and append the exchange.

    Args:
        msg: The user's question from the textbox.
        history: Current ``gr.Chatbot`` value, a list of ``(user, assistant)``
            message pairs, or None.
        chain: The QA chain built in Step 2, or None if not initialised.

    Returns:
        A new history list with one ``(msg, reply)`` pair appended. The reply
        is the chain's answer or a status/error message. A blank *msg* leaves
        the history unchanged.
    """
    history = list(history or [])
    if not (msg or "").strip():
        return history
    if chain is None:
        # gr.Chatbot renders each pair as (user message, bot message).
        return history + [(msg, "Please initialize the QA Chatbot first.")]
    try:
        # The chain from initialize_llmchain carries buffer memory keyed
        # "chat_history". Passing the pairs here as well keeps the call valid
        # for a chain without memory.
        response = chain.invoke({"question": msg, "chat_history": history})
        answer = response['answer']
    except Exception as e:
        answer = f"Error: {str(e)}"
    return history + [(msg, answer)]
def demo():
    """Build the three-step Gradio UI: process a PDF, pick an LLM, then chat.

    The vector store and the QA chain are kept in per-session ``gr.State``
    components and passed between the steps' callbacks.

    Returns:
        The configured ``gr.Blocks`` app; call ``.launch()`` to serve it.
    """
    with gr.Blocks() as app:
        vector_db = gr.State(value=None)
        qa_chain = gr.State(value=None)

        with gr.Tab("Step 1 - Upload and Process"):
            with gr.Row():
                # Gradio expects extensions with a leading dot (".pdf") or a
                # category such as "image"; a bare "pdf" is neither.
                document = gr.File(
                    label="Upload your Japanese PDF document",
                    file_types=[".pdf"],
                )
            with gr.Row():
                process_btn = gr.Button("Process PDF")
                process_output = gr.Textbox(label="Processing Output")

        with gr.Tab("Step 2 - Initialize QA Chatbot"):
            with gr.Row():
                # type="index" makes the callback receive a position in
                # list_llm. Preselect the first model so it is never None.
                llm_btn = gr.Radio(
                    list_llm_simple,
                    label="Select LLM Model",
                    type="index",
                    value=list_llm_simple[0],
                )
                llm_temperature = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="Temperature", value=0.7)
                max_tokens = gr.Slider(minimum=128, maximum=2048, step=128, label="Max Tokens", value=1024)
                top_k = gr.Slider(minimum=1, maximum=10, step=1, label="Top K", value=3)
            with gr.Row():
                init_qa_btn = gr.Button("Initialize QA Chatbot")
                init_output = gr.Textbox(label="Initialization Output")

        with gr.Tab("Step 3 - Chat with your Document"):
            chatbot = gr.Chatbot()
            message = gr.Textbox(label="Ask a question")
            with gr.Row():
                send_btn = gr.Button("Send")
                clear_chat_btn = gr.Button("Clear Chat")
                reset_all_btn = gr.Button("Reset All")

        process_btn.click(
            process_pdf,
            inputs=[document],
            outputs=[vector_db, process_output]
        )
        init_qa_btn.click(
            initialize_qa_chain,
            inputs=[llm_btn, llm_temperature, max_tokens, top_k, vector_db],
            outputs=[qa_chain, init_output]
        )

        # Send via the button or by pressing Enter, then empty the textbox.
        chat_inputs = [message, chatbot, qa_chain]
        send_btn.click(update_chat, inputs=chat_inputs, outputs=[chatbot]).then(
            lambda: "", outputs=[message]
        )
        message.submit(update_chat, inputs=chat_inputs, outputs=[chatbot]).then(
            lambda: "", outputs=[message]
        )

        def clear_chat(chain):
            """Clear the chat display and the chain's conversation memory."""
            # Without this, the chain's buffer memory would still condition
            # new answers on the exchange the user just cleared.
            memory = getattr(chain, "memory", None)
            if memory is not None:
                memory.clear()
            return None

        # Clear Chat: reset only the conversation (display + chain memory).
        clear_chat_btn.click(
            clear_chat,
            inputs=[qa_chain],
            outputs=[chatbot]
        )
        # Reset All: clear the chat history, the PDF vector store and the chain.
        reset_all_btn.click(
            lambda: (None, None, None),
            outputs=[chatbot, vector_db, qa_chain]
        )
    return app
if __name__ == "__main__":
    # Script entry point: build the Gradio app and start serving it.
    app = demo()
    app.launch()