import os
import warnings

import gradio as gr
import torch
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Suppress warnings
warnings.filterwarnings("ignore")
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

# Model Configuration
MODEL_NAME = "gpt2"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

# GPT-2's context window (1024 tokens) must hold the prompt AND the generated
# tokens. With three ~1000-character retrieved chunks plus instructions, the
# previous 512-token answer budget could overflow it, so QA answers get less.
QA_MAX_NEW_TOKENS = 256
DRAFT_MAX_NEW_TOKENS = 500
# Number of chunks the retriever stuffs into the QA prompt.
RETRIEVER_K = 3
# Used when the tokenizer does not report a usable context length.
DEFAULT_CONTEXT_TOKENS = 1024

QA_TEMPLATE = """You are a helpful assistant answering questions based on the provided context.
Use only the information given in the context below to answer the question.
If the context doesn't contain the answer, say "The provided context does not contain the answer to this question."
Be concise.

Context: {context}

Question: {question}

Answer:"""


def initialize_models():
    """Initialize the language model and the embedding model.

    Returns:
        A ``(llm, embedding_model, model, tokenizer)`` tuple. ``llm`` is a
        LangChain ``HuggingFacePipeline`` wrapping a text-generation pipeline
        built on ``model``/``tokenizer``. On any failure the error is printed
        and ``(None, None, None, None)`` is returned so the UI can still start
        and report the problem from each handler.
    """
    try:
        # Determine device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        # Load model and tokenizer
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

        text_generation_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            # Without this the model stayed on CPU even when CUDA was reported.
            device=device,
            max_new_tokens=QA_MAX_NEW_TOKENS,
            # `temperature` is ignored by greedy decoding; sampling must be on.
            do_sample=True,
            temperature=0.7,
            repetition_penalty=1.1,
            # GPT-2 defines no pad token; reuse EOS instead of warning per call.
            pad_token_id=tokenizer.eos_token_id,
            # Return only the completion, never the echoed prompt.
            return_full_text=False,
            # If prompt + max_new_tokens exceeds the context window, drop
            # tokens from the left of the prompt instead of crashing.
            handle_long_generation="hole",
        )

        # Langchain LLM
        llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

        # Embedding model
        embedding_model = HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL,
            model_kwargs={"device": str(device)},
        )

        return llm, embedding_model, model, tokenizer
    except Exception as e:
        print(f"Model initialization error: {e}")
        return None, None, None, None


# Initialize models once at import time; every handler checks for None.
llm, embedding_model, model, tokenizer = initialize_models()

# Global variables for RAG state. They live at module level, so they are
# shared by every browser session connected to this app.
rag_retriever = None
document_loaded = False
loaded_doc_name = "No document loaded"

# Built once; the template has no per-request state.
QA_CHAIN_PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=QA_TEMPLATE,
)


def _context_window():
    """Return the number of tokens the model can attend to (prompt + output)."""
    limit = getattr(tokenizer, "model_max_length", None)
    # Tokenizers without a configured limit report a very large sentinel.
    if not isinstance(limit, int) or limit > 100_000:
        return DEFAULT_CONTEXT_TOKENS
    return limit


def _fit_prompt(prefix, body, max_new_tokens):
    """Return ``prefix + body`` with ``body``'s tail trimmed to fit the context.

    Leaves room for ``max_new_tokens`` generated tokens. Trimming the user
    text here (instead of relying only on the pipeline's left truncation)
    keeps the instruction in ``prefix`` intact.
    """
    budget = max(_context_window() - max_new_tokens - len(tokenizer.encode(prefix)), 0)
    body_ids = tokenizer.encode(body)
    if len(body_ids) > budget:
        body = tokenizer.decode(body_ids[:budget])
    return prefix + body


def _generate(prompt, max_new_tokens, temperature, repetition_penalty=1.0):
    """Sample a completion for ``prompt`` with the shared pipeline.

    Reuses the pipeline wrapped by ``llm`` rather than building a new one per
    request; call-time arguments override the pipeline's construction-time
    settings. Returns only the newly generated text, stripped.
    """
    outputs = llm.pipeline(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()


def setup_rag_pipeline(doc_text, chunk_size=1000, chunk_overlap=150):
    """Chunk ``doc_text``, embed the chunks into FAISS, and install a retriever.

    Updates the module-level ``rag_retriever``, ``document_loaded`` and
    ``loaded_doc_name``. Returns a human-readable status string (a message
    starting with "Error" on failure) for display in the UI.
    """
    global rag_retriever, document_loaded, loaded_doc_name

    if embedding_model is None:
        return "Error: Models not initialized properly."
    if not isinstance(doc_text, str) or not doc_text.strip():
        return "Error: No text provided or invalid input."

    try:
        # Text splitting
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
        )
        docs = text_splitter.split_text(doc_text)
        if not docs:
            return "Error: Text splitting resulted in no documents."

        # Create embeddings and FAISS index
        vector_store = FAISS.from_texts(docs, embedding_model)
        rag_retriever = vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})
        document_loaded = True
        loaded_doc_name = f"Document processed ({len(doc_text)} chars, {len(docs)} chunks)."
        return loaded_doc_name
    except Exception as e:
        # Reset all RAG state so a failed load never leaves a stale status.
        document_loaded = False
        rag_retriever = None
        loaded_doc_name = "No document loaded"
        return f"Error processing document: {e}"


def answer_question(question):
    """Answer ``question`` from the loaded document via a RetrievalQA chain.

    Returns the answer text, or a message starting with "Error" when models
    are missing, no document is loaded, the question is empty, or the chain
    raises.
    """
    if llm is None or embedding_model is None:
        return "Error: Models not initialized properly."
    if not document_loaded or rag_retriever is None:
        return "Error: Please load a document before asking questions."
    if not isinstance(question, str) or not question.strip():
        return "Error: Please enter a question."

    try:
        # "stuff" concatenates all retrieved chunks into a single prompt.
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=rag_retriever,
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
            return_source_documents=False,
        )
        result = qa_chain.invoke({"query": question})
        answer = result.get("result", str(result)) if isinstance(result, dict) else str(result)
        return answer.strip()
    except Exception as e:
        return f"Error answering question: {e}"


def summarize_text(text_to_summarize, max_length=150, min_length=30):
    """Summarize ``text_to_summarize`` with the LLM.

    ``max_length`` is also used as the generation budget in tokens; the
    word range is only requested in the prompt, not enforced. Returns the
    summary, or a message starting with "Error".
    """
    if llm is None:
        return "Error: Models not initialized properly."
    if not isinstance(text_to_summarize, str) or not text_to_summarize.strip():
        return "Error: Please enter text to summarize."

    try:
        # Gradio sliders may deliver floats; generation lengths must be ints.
        max_length = int(max_length)
        # The two sliders are independent, so keep the requested range valid.
        min_length = min(int(min_length), max_length)
        prefix = (
            f"Summarize the following text concisely, aiming for "
            f"{min_length} to {max_length} words:\n\n"
        )
        prompt = _fit_prompt(prefix, text_to_summarize, max_new_tokens=max_length)
        return _generate(prompt, max_new_tokens=max_length, temperature=0.5)
    except Exception as e:
        return f"Error summarizing text: {e}"


def draft_text(instructions):
    """Draft text following ``instructions`` with the LLM.

    Returns the generated draft, or a message starting with "Error".
    """
    if llm is None:
        return "Error: Models not initialized properly."
    if not isinstance(instructions, str) or not instructions.strip():
        return "Error: Please enter drafting instructions."

    try:
        prefix = "Write the following based on these instructions:\n\n"
        prompt = _fit_prompt(prefix, instructions, max_new_tokens=DRAFT_MAX_NEW_TOKENS)
        return _generate(prompt, max_new_tokens=DRAFT_MAX_NEW_TOKENS, temperature=0.7)
    except Exception as e:
        return f"Error drafting text: {e}"


# Gradio Interface
def create_gradio_interface():
    """Build the three-tab Gradio UI (Q&A, summarization, drafting)."""
    with gr.Blocks(theme=gr.themes.Soft()) as iface:
        gr.Markdown("# Workplace Assistant (GPT-2 Demo)")
        gr.Markdown("Powered by GPT-2 and Langchain")

        with gr.Tabs():
            # Document Q&A Tab
            with gr.TabItem("Document Q&A (RAG)"):
                gr.Markdown("Load text content from a document, then ask questions about it.")
                doc_input = gr.Textbox(
                    label="Paste Document Text Here",
                    lines=10,
                    placeholder="Paste the full text content you want to query...",
                )
                load_button = gr.Button("Process Document")
                status_output = gr.Textbox(label="Document Status", value=loaded_doc_name, interactive=False)

                question_input = gr.Textbox(label="Your Question", placeholder="Ask a question about the document...")
                ask_button = gr.Button("Ask Question")
                answer_output = gr.Textbox(label="Answer", lines=5, interactive=False)

                load_button.click(
                    fn=setup_rag_pipeline,
                    inputs=[doc_input],
                    outputs=[status_output],
                )
                ask_button.click(
                    fn=answer_question,
                    inputs=[question_input],
                    outputs=[answer_output],
                )

            # Summarization Tab
            with gr.TabItem("Summarization"):
                gr.Markdown("Paste text to get a concise summary.")
                summarize_input = gr.Textbox(label="Text to Summarize", lines=10, placeholder="Paste text here...")
                summarize_button = gr.Button("Summarize")
                summary_output = gr.Textbox(label="Summary", lines=5, interactive=False)

                with gr.Accordion("Advanced Options", open=False):
                    max_len_slider = gr.Slider(
                        minimum=20, maximum=300, value=150, step=10,
                        label="Max Summary Length (approx words)",
                    )
                    min_len_slider = gr.Slider(
                        minimum=10, maximum=100, value=30, step=5,
                        label="Min Summary Length (approx words)",
                    )

                summarize_button.click(
                    fn=summarize_text,
                    inputs=[summarize_input, max_len_slider, min_len_slider],
                    outputs=[summary_output],
                )

            # Drafting Tab
            with gr.TabItem("Drafting Assistant"):
                gr.Markdown("Provide instructions for the AI to draft text.")
                draft_instructions = gr.Textbox(
                    label="Drafting Instructions",
                    lines=5,
                    placeholder="e.g., Draft a short, friendly email to the team.",
                )
                draft_button = gr.Button("Generate Draft")
                draft_output = gr.Textbox(label="Generated Draft", lines=10, interactive=False)

                draft_button.click(
                    fn=draft_text,
                    inputs=[draft_instructions],
                    outputs=[draft_output],
                )

    return iface


# Launch the interface
if __name__ == "__main__":
    try:
        iface = create_gradio_interface()
        # share=True publishes a public Gradio link to this app.
        iface.launch(share=True, debug=True)
    except Exception as e:
        print(f"Error launching Gradio interface: {e}")