import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain_community.llms import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
import warnings
import os

# Suppress warnings
warnings.filterwarnings("ignore")
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'

# Model Configuration
MODEL_NAME = "gpt2"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
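
# Note: GPT-2's context window is only 1024 tokens, shared between the prompt
# (including any retrieved context) and the generated continuation, so the
# chunk and retrieval sizes used below are kept small to stay within that budget.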

def initialize_models():
    """Initialize language model and embedding model."""
    try:
        # Determine device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        # Load model and tokenizer
        model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # GPT-2 has no pad token by default; reuse EOS to avoid pipeline warnings
        tokenizer.pad_token = tokenizer.eos_token

        # Create pipeline (do_sample=True so temperature actually takes effect)
        text_generation_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=device,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            repetition_penalty=1.1,
        )

        # LangChain wrapper around the transformers pipeline
        llm = HuggingFacePipeline(pipeline=text_generation_pipeline)

        # Embedding model
        embedding_model = HuggingFaceEmbeddings(
            model_name=EMBEDDING_MODEL,
            model_kwargs={'device': str(device)}
        )
        return llm, embedding_model, model, tokenizer
    except Exception as e:
        print(f"Model initialization error: {e}")
        return None, None, None, None

# Initialize models
llm, embedding_model, model, tokenizer = initialize_models()

# Global variables for RAG state
rag_retriever = None
document_loaded = False
loaded_doc_name = "No document loaded"

def setup_rag_pipeline(doc_text, chunk_size=400, chunk_overlap=50):
    """Loads text, chunks, embeds, creates FAISS index, and sets up retriever."""
    global rag_retriever, document_loaded, loaded_doc_name

    if embedding_model is None:
        return "Error: Models not initialized properly."
    if not doc_text or not isinstance(doc_text, str) or len(doc_text.strip()) == 0:
        return "Error: No text provided or invalid input."

    try:
        # Text splitting; small chunks keep the stuffed prompt (retrieved
        # context plus template) within GPT-2's 1024-token window
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
        )
        docs = text_splitter.split_text(doc_text)
        if not docs:
            return "Error: Text splitting resulted in no documents."

        # Create embeddings and FAISS index
        vector_store = FAISS.from_texts(docs, embedding_model)
        rag_retriever = vector_store.as_retriever(search_kwargs={"k": 2})
        document_loaded = True
        loaded_doc_name = f"Document processed ({len(doc_text)} chars, {len(docs)} chunks)."
        return loaded_doc_name
    except Exception as e:
        document_loaded = False
        rag_retriever = None
        return f"Error processing document: {e}"

def answer_question(question):
    """Answers a question using the loaded RAG pipeline."""
    if llm is None or embedding_model is None:
        return "Error: Models not initialized properly."
    if not document_loaded or rag_retriever is None:
        return "Error: Please load a document before asking questions."
    if not question or not isinstance(question, str) or len(question.strip()) == 0:
        return "Error: Please enter a question."

    try:
        # Define a prompt template
        template = """You are a helpful assistant answering questions based on the provided context.
Use only the information given in the context below to answer the question.
If the context doesn't contain the answer, say "The provided context does not contain the answer to this question."
Be concise.

Context:
{context}

Question: {question}

Answer:"""
        QA_CHAIN_PROMPT = PromptTemplate(
            input_variables=["context", "question"],
            template=template,
        )

        # Create RetrievalQA chain
        qa_chain = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=rag_retriever,
            chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
            return_source_documents=False
        )
        result = qa_chain.invoke({"query": question})
        answer = result.get("result", str(result)) if isinstance(result, dict) else str(result)
        return answer.strip()
    except Exception as e:
        return f"Error answering question: {e}"

def summarize_text(text_to_summarize, max_length=150, min_length=30):
    """Summarizes the provided text using the LLM."""
    if model is None or tokenizer is None:
        return "Error: Models not initialized properly."
    if not text_to_summarize or not isinstance(text_to_summarize, str) or len(text_to_summarize.strip()) == 0:
        return "Error: Please enter text to summarize."

    try:
        # Create a prompt for summarization
        prompt = (
            f"Summarize the following text concisely, aiming for "
            f"{min_length} to {max_length} words:\n\n{text_to_summarize}"
        )

        # Build a generation pipeline; max_new_tokens bounds only the summary,
        # unlike max_length, which would also count the prompt tokens and
        # conflicts with max_new_tokens if both are set
        summary_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=int(max_length),
            do_sample=True,
            temperature=0.5,
        )

        # Generate, then strip the echoed prompt from the output
        summary_result = summary_pipeline(prompt)[0]['generated_text']
        summary = summary_result.replace(prompt, '').strip()
        return summary
    except Exception as e:
        return f"Error summarizing text: {e}"

def draft_text(instructions):
    """Drafts text based on user instructions using the LLM."""
    if model is None or tokenizer is None:
        return "Error: Models not initialized properly."
    if not instructions or not isinstance(instructions, str) or len(instructions.strip()) == 0:
        return "Error: Please enter drafting instructions."

    try:
        # Drafting prompt
        prompt = f"Write the following based on these instructions:\n\n{instructions}"

        # Build a generation pipeline; max_new_tokens bounds only the draft
        draft_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=500,
            do_sample=True,
            temperature=0.7,
        )

        # Generate, then strip the echoed prompt from the output
        draft_result = draft_pipeline(prompt)[0]['generated_text']
        draft = draft_result.replace(prompt, '').strip()
        return draft
    except Exception as e:
        return f"Error drafting text: {e}"

# Gradio Interface
def create_gradio_interface():
    with gr.Blocks(theme=gr.themes.Soft()) as iface:
        gr.Markdown("# Workplace Assistant (GPT-2 Demo)")
        gr.Markdown("Powered by GPT-2 and Langchain")

        with gr.Tabs():
            # Document Q&A Tab
            with gr.TabItem("Document Q&A (RAG)"):
                gr.Markdown("Load text content from a document, then ask questions about it.")
                doc_input = gr.Textbox(label="Paste Document Text Here", lines=10, placeholder="Paste the full text content you want to query...")
                load_button = gr.Button("Process Document")
                status_output = gr.Textbox(label="Document Status", value=loaded_doc_name, interactive=False)
                question_input = gr.Textbox(label="Your Question", placeholder="Ask a question about the document...")
                ask_button = gr.Button("Ask Question")
                answer_output = gr.Textbox(label="Answer", lines=5, interactive=False)

                load_button.click(
                    fn=setup_rag_pipeline,
                    inputs=[doc_input],
                    outputs=[status_output]
                )
                ask_button.click(
                    fn=answer_question,
                    inputs=[question_input],
                    outputs=[answer_output]
                )

            # Summarization Tab
            with gr.TabItem("Summarization"):
                gr.Markdown("Paste text to get a concise summary.")
                summarize_input = gr.Textbox(label="Text to Summarize", lines=10, placeholder="Paste text here...")
                summarize_button = gr.Button("Summarize")
                summary_output = gr.Textbox(label="Summary", lines=5, interactive=False)

                with gr.Accordion("Advanced Options", open=False):
                    max_len_slider = gr.Slider(minimum=20, maximum=300, value=150, step=10, label="Max Summary Length (approx words)")
                    min_len_slider = gr.Slider(minimum=10, maximum=100, value=30, step=5, label="Min Summary Length (approx words)")

                summarize_button.click(
                    fn=summarize_text,
                    inputs=[summarize_input, max_len_slider, min_len_slider],
                    outputs=[summary_output]
                )

            # Drafting Tab
            with gr.TabItem("Drafting Assistant"):
                gr.Markdown("Provide instructions for the AI to draft text.")
                draft_instructions = gr.Textbox(label="Drafting Instructions", lines=5, placeholder="e.g., Draft a short, friendly email to the team.")
                draft_button = gr.Button("Generate Draft")
                draft_output = gr.Textbox(label="Generated Draft", lines=10, interactive=False)

                draft_button.click(
                    fn=draft_text,
                    inputs=[draft_instructions],
                    outputs=[draft_output]
                )

    return iface

# Launch the interface
if __name__ == "__main__":
    try:
        iface = create_gradio_interface()
        iface.launch(share=True, debug=True)
    except Exception as e:
        print(f"Error launching Gradio interface: {e}")