# HuggingFace Spaces page header (scrape artifact — Space status: Sleeping)
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| from sentence_transformers import SentenceTransformer | |
| from PyPDF2 import PdfReader | |
| import numpy as np | |
| import torch | |
class RAGChatbot:
    """Retrieval-augmented chatbot.

    Splits uploaded documents into fixed-size text chunks, embeds them with a
    SentenceTransformer, retrieves the chunks most similar to a query by
    cosine similarity, and answers with a causal language model conditioned
    on that retrieved context.
    """

    def __init__(self,
                 model_name="facebook/opt-350m",
                 embedding_model="all-MiniLM-L6-v2"):
        """Load the generation model/tokenizer and the sentence embedder.

        Args:
            model_name: HuggingFace hub id of the causal LM used for answers.
            embedding_model: SentenceTransformer model id used for retrieval.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.embedding_model = SentenceTransformer(embedding_model)
        # Parallel containers: documents[i] is the chunk for embeddings[i].
        self.documents = []
        self.embeddings = []

    def extract_text_from_pdf(self, pdf_path):
        """Return the concatenated text of every page in the PDF at *pdf_path*."""
        reader = PdfReader(pdf_path)
        text = ""
        for page in reader.pages:
            # extract_text() can return None (e.g. image-only pages);
            # guard so concatenation never raises TypeError.
            text += (page.extract_text() or "") + "\n"
        return text

    def load_documents(self, file_paths):
        """Chunk the given files into 500-character pieces and embed them.

        Replaces any previously loaded corpus. PDF files are parsed with
        PyPDF2; everything else is read as UTF-8 text.

        Args:
            file_paths: iterable of file paths (or None/empty for no files).

        Returns:
            A human-readable status string for the UI.
        """
        self.documents = []
        self.embeddings = []
        # Gradio passes None when no files were uploaded.
        if not file_paths:
            return "Loaded 0 text chunks from 0 files"
        for file_path in file_paths:
            if file_path.endswith('.pdf'):
                text = self.extract_text_from_pdf(file_path)
            else:
                with open(file_path, 'r', encoding='utf-8') as f:
                    text = f.read()
            # Fixed-size character chunks; 500 chars keeps prompts short.
            chunks = [text[i:i + 500] for i in range(0, len(text), 500)]
            self.documents.extend(chunks)
        self.embeddings = self.embedding_model.encode(self.documents)
        return f"Loaded {len(self.documents)} text chunks from {len(file_paths)} files"

    def retrieve_relevant_context(self, query, top_k=3):
        """Return the *top_k* chunks most cosine-similar to *query*, joined by spaces."""
        if not self.documents:
            return "No documents loaded"
        query_embedding = self.embedding_model.encode([query])[0]
        doc_matrix = np.asarray(self.embeddings)
        # Cosine similarity between the query and every chunk embedding.
        similarities = np.dot(doc_matrix, query_embedding) / (
            np.linalg.norm(doc_matrix, axis=1) * np.linalg.norm(query_embedding)
        )
        # Indices of the top_k scores, highest similarity first.
        top_indices = similarities.argsort()[-top_k:][::-1]
        return " ".join([self.documents[i] for i in top_indices])

    def generate_response(self, query, context):
        """Generate an answer to *query* grounded in the retrieved *context*."""
        # Cap the context at ~100 words so the prompt fits the model window.
        truncated_context = " ".join(context.split()[:100])
        full_prompt = f"Context: {truncated_context}\n\nQuestion: {query}\n\nAnswer:"
        # Single .to(device) on the BatchEncoding moves all tensors at once
        # (the original moved input_ids twice, redundantly).
        tokens = self.tokenizer(
            full_prompt, return_tensors="pt", padding=True, truncation=True
        ).to(self.model.device)
        outputs = self.model.generate(
            tokens.input_ids,
            attention_mask=tokens.attention_mask,
            max_new_tokens=128,
            pad_token_id=self.tokenizer.eos_token_id,
            repetition_penalty=1.0,
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # The decoded text echoes the whole prompt; keep only the answer part.
        return response.split("Answer:")[-1].strip()

    def chat(self, query, history):
        """Gradio event handler: answer *query* and append to messages-format history.

        Args:
            query: the user's question (may be empty).
            history: list of {"role", "content"} dicts, or None after a clear.

        Returns:
            (updated_history, "") — the empty string clears the input box.
        """
        # The Chatbot value can be None (e.g. right after "Clear Chat");
        # normalize so concatenation below never raises TypeError.
        history = history or []
        if not query:
            return history, ""
        try:
            context = self.retrieve_relevant_context(query)
            response = self.generate_response(query, context)
        except Exception as e:
            # Surface failures inside the chat instead of crashing the UI.
            response = f"An error occurred: {str(e)}"
        updated_history = history + [
            {"role": "user", "content": query},
            {"role": "assistant", "content": response},
        ]
        return updated_history, ""
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI wired to a fresh RAGChatbot instance.

    Returns:
        The gr.Blocks demo, ready to .launch().
    """
    rag_chatbot = RAGChatbot()
    with gr.Blocks() as demo:
        # Fixed typo in the visible heading ("PDf" -> "PDF").
        gr.Markdown("# Ask your PDF!")
        with gr.Row():
            file_input = gr.File(label="Upload Documents", file_count="multiple", type="filepath")
            load_btn = gr.Button("Load Documents")
        status_output = gr.Textbox(label="Load Status")
        # messages format: history is a list of {"role", "content"} dicts.
        chatbot = gr.Chatbot(type="messages")
        msg = gr.Textbox(label="Enter your query")
        submit_btn = gr.Button("Send")
        clear_btn = gr.Button("Clear Chat")

        # Event handlers
        load_btn.click(
            rag_chatbot.load_documents,
            inputs=[file_input],
            outputs=[status_output],
        )
        # Both the Send button and pressing Enter submit the query.
        submit_btn.click(
            rag_chatbot.chat,
            inputs=[msg, chatbot],
            outputs=[chatbot, msg],
        )
        msg.submit(
            rag_chatbot.chat,
            inputs=[msg, chatbot],
            outputs=[chatbot, msg],
        )
        # Reset to an empty list (not None) so subsequent chat() calls can
        # safely concatenate onto the history.
        clear_btn.click(lambda: ([], ""), None, [chatbot, msg])
    return demo
# Entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    create_interface().launch()