Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| import numpy as np | |
| import PyPDF2 | |
| from pathlib import Path | |
| import traceback # Import traceback for detailed error logging | |
| import sys | |
# --- Model Architecture ---
class EfficientMultiHeadAttention(nn.Module):
    """Multi-head self-attention computed from one fused Q/K/V projection.

    Args:
        d_model: Feature width of the input and output; must be divisible by
            ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        # Attribute names and creation order are load-bearing: they define the
        # state_dict keys and the order in which parameters draw from the RNG.
        self.qkv_proj = nn.Linear(d_model, d_model * 3, bias=False)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.scale = 1.0 / (self.head_dim ** 0.5)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (B, T, d_model) and return the same shape.

        ``mask`` (optional) is viewed as (B, 1, 1, T). Key positions where it
        equals 0 are filled with -inf before the softmax, so they receive no
        weight (assuming each row keeps at least one unmasked key).
        """
        batch, seq_len, width = x.shape
        fused = self.qkv_proj(x).reshape(batch, seq_len, 3, self.n_heads, self.head_dim)
        # (B, T, 3, H, Dh) -> (3, B, H, T, Dh), then split into Q, K, V.
        query, key, value = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (query @ key.transpose(-2, -1)) * self.scale
        if mask is not None:
            key_mask = mask.view(batch, 1, 1, seq_len)
            scores = scores.masked_fill(key_mask == 0, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        # Merge heads back: (B, H, T, Dh) -> (B, T, H*Dh).
        merged = (weights @ value).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.out_proj(merged)
class CompactFeedForward(nn.Module):
    """Position-wise MLP: Linear(d_model -> d_ff) -> GELU -> Dropout -> Linear(d_ff -> d_model)."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Attribute names/order define the checkpoint keys; keep them stable.
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to ``d_ff`` features, apply GELU and dropout, then project back."""
        expanded = F.gelu(self.w1(x))
        expanded = self.dropout(expanded)
        return self.w2(expanded)
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer layer with residual attention and MLP sub-layers."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # Registration order fixes both state_dict keys and RNG init order.
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = EfficientMultiHeadAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = CompactFeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Compute ``x + drop(attn(ln1(x)))``, then ``x + drop(mlp(ln2(x)))``.

        ``mask`` is passed unchanged to the attention sub-layer.
        """
        attended = self.attn(self.ln1(x), mask)
        x = x + self.dropout(attended)
        transformed = self.mlp(self.ln2(x))
        return x + self.dropout(transformed)
class EdgeOptimizedSLM(nn.Module):
    """Compact transformer encoder with an extractive-QA (span start/end) head.

    Args:
        vocab_size: Number of rows in the token embedding table.
        d_model, n_heads, n_layers, d_ff, dropout: Transformer hyper-parameters.
        max_length: Number of learned position embeddings, i.e. the longest
            sequence ``forward`` can accept.
    """

    def __init__(self, vocab_size, d_model=320, n_heads=8, n_layers=4, d_ff=1280, max_length=512, dropout=0.1):
        super().__init__()
        # Attribute names and registration order define the checkpoint keys
        # consumed by load_state_dict, so they must not change.
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_length, d_model)
        self.drop = nn.Dropout(dropout)
        layers = [TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]
        self.blocks = nn.ModuleList(layers)
        self.ln_f = nn.LayerNorm(d_model)
        self.qa_proj = nn.Linear(d_model, d_model // 2)
        self.qa_start = nn.Linear(d_model // 2, 1)
        self.qa_end = nn.Linear(d_model // 2, 1)

    def forward(self, input_ids, attention_mask=None):
        """Score every position as a possible answer start and end.

        Args:
            input_ids: 2-D integer tensor (B, T) with T <= max_length.
            attention_mask: Optional mask forwarded to every block; zeros mark
                key positions to ignore.

        Returns:
            dict with ``"start_logits"`` and ``"end_logits"``, each of shape (B, T).
        """
        _, seq_len = input_ids.size()
        positions = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)[None, :]
        hidden = self.drop(self.tok_emb(input_ids) + self.pos_emb(positions))
        for layer in self.blocks:
            hidden = layer(hidden, attention_mask)
        hidden = self.ln_f(hidden)

        span_features = F.gelu(self.qa_proj(hidden))
        return {
            "start_logits": self.qa_start(span_features).squeeze(-1),
            "end_logits": self.qa_end(span_features).squeeze(-1),
        }
# --- Global Variables and Model Loading ---
MODEL_PATH = "edge_deployment_package/models/model_dynamic_quantized_int8.pt"
TOKENIZER_NAME = "bert-base-uncased"
EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2'
DEVICE = torch.device('cpu')

# --- Robust Model Loading ---
# All three models are required. Any failure is logged with a full traceback
# and the process exits, because the app cannot serve requests without them.
# NOTE(review): the "β" / "π΄" characters in the log strings look like
# mis-decoded emoji; left as-is because they are runtime output.
try:
    print("--- Starting Application ---")

    # 1. Load Custom Inference Model
    print(f"Attempting to load custom model from: {MODEL_PATH}")
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"CRITICAL: Model file not found at '{MODEL_PATH}'. Please ensure the file exists in your repository.")
    # The checkpoint is read as a dict holding 'config' (keyword arguments for
    # EdgeOptimizedSLM) and 'model_state_dict'.
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
    config = checkpoint['config']
    inference_model = EdgeOptimizedSLM(**config)
    state_dict = checkpoint['model_state_dict']
    try:
        inference_model.load_state_dict(state_dict)
    except RuntimeError as float_load_error:
        # The file name says "dynamic_quantized_int8". A state_dict saved from
        # a dynamically quantized model stores each nn.Linear as packed int8
        # params under different keys, so a strict load into the float model
        # fails. Rebuild the same quantized structure and load again. A second
        # failure propagates to the handler below, with this error chained in
        # the traceback.
        from torch.ao.quantization import quantize_dynamic
        print(f"Float state_dict load failed ({float_load_error}); retrying as a dynamic int8 model.")
        inference_model = quantize_dynamic(inference_model, {nn.Linear}, dtype=torch.qint8)
        inference_model.load_state_dict(state_dict)
    inference_model.to(DEVICE)
    inference_model.eval()
    print("β Custom inference model loaded successfully.")

    # 2. Load Tokenizer
    print(f"Attempting to load tokenizer: {TOKENIZER_NAME}")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    print("β Tokenizer loaded successfully.")

    # 3. Load Embedding Model
    print(f"Attempting to load embedding model: {EMBEDDING_MODEL_NAME}")
    embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device=DEVICE)
    print("β Embedding model loaded successfully.")
except Exception as e:
    print("--- π΄ AN ERROR OCCURRED DURING STARTUP ---")
    print(f"Error Type: {type(e).__name__}")
    print(f"Error Details: {e}")
    print("------------------------------------------")
    traceback.print_exc()  # Print the full traceback for detailed debugging
    # We exit here because the app cannot run without the models.
    sys.exit("Exiting application due to critical startup error.")
# --- RAG and PDF Processing Logic ---
class RAGPipeline:
    """Chunk a PDF's text, embed the chunks, and retrieve the closest ones.

    State:
        text_chunks: chunk strings of the currently indexed document.
        vector_store: FAISS ``IndexFlatL2`` over the chunk embeddings, or None
            when no document is indexed.
        embedding_model: object with an ``encode`` method (a SentenceTransformer
            in this app) used for both chunks and queries.
        raw_embeddings_path: file that receives the raw float32 embedding bytes.
        chunk_size / overlap: words per chunk and words shared between
            consecutive chunks.
    """

    def __init__(self, embedding_model, chunk_size=200, overlap=30):
        if not 0 <= overlap < chunk_size:
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
        self.text_chunks = []
        self.vector_store = None
        self.embedding_model = embedding_model
        # NOTE(review): fixed relative path; every process_pdf call on any
        # instance overwrites the same file.
        self.raw_embeddings_path = "document_embeddings.raw"
        self.chunk_size = chunk_size
        self.overlap = overlap

    @staticmethod
    def _extract_text(pdf_reader):
        """Return the non-empty text of every page, one page per line."""
        # Call extract_text() once per page, and separate pages with a newline
        # so the last word of one page is not glued to the first of the next.
        page_texts = (page.extract_text() for page in pdf_reader.pages)
        return "\n".join(text for text in page_texts if text)

    @staticmethod
    def _split_into_chunks(words, chunk_size, overlap):
        """Return overlapping windows of ``chunk_size`` words, joined by spaces.

        Consecutive windows share ``overlap`` words. Iteration stops once a
        window reaches the end of ``words``, so no trailing window is a pure
        subset of the one before it.
        """
        step = chunk_size - overlap
        chunks = []
        for start in range(0, len(words), step):
            chunks.append(" ".join(words[start:start + chunk_size]))
            if start + chunk_size >= len(words):
                break
        return chunks

    def process_pdf(self, pdf_file_obj):
        """Index an uploaded PDF for retrieval.

        Args:
            pdf_file_obj: a path string or an object whose ``.name`` is the
                file path (as produced by a Gradio file upload), or None.

        Returns:
            ``(status_message, embeddings_path_or_None)``. The path is only
            returned on success.
        """
        if pdf_file_obj is None:
            return "Please upload a PDF file first.", None
        pdf_path = getattr(pdf_file_obj, "name", pdf_file_obj)
        print(f"Processing PDF: {pdf_path}")
        # Drop the previous document completely. Otherwise a failure below
        # would leave an old index whose ids point into an emptied chunk list.
        self.text_chunks = []
        self.vector_store = None
        try:
            pdf_reader = PyPDF2.PdfReader(pdf_path)
            text = self._extract_text(pdf_reader)
            if not text:
                return "Could not extract text from the PDF.", None
            chunks = self._split_into_chunks(text.split(), self.chunk_size, self.overlap)
            if not chunks:
                return "Text extracted but could not be split into chunks.", None
            raw = self.embedding_model.encode(chunks, convert_to_tensor=False, show_progress_bar=True)
            # FAISS requires a C-contiguous float32 matrix.
            embeddings = np.ascontiguousarray(raw, dtype=np.float32)
            with open(self.raw_embeddings_path, 'wb') as f:
                f.write(embeddings.tobytes())
            index = faiss.IndexFlatL2(embeddings.shape[1])
            index.add(embeddings)
            # Commit only after every step succeeded.
            self.text_chunks = chunks
            self.vector_store = index
            status_message = f"Successfully processed '{Path(pdf_path).name}'. Ready for questions."
            print("PDF processing complete.")
            return status_message, self.raw_embeddings_path
        except Exception as e:
            print(f"Error processing PDF: {e}")
            return f"Error processing PDF: {e}", None

    def retrieve_context(self, query, top_k=3):
        """Return the ``top_k`` chunks nearest to ``query``, joined by spaces.

        Returns "" when no document is indexed.
        """
        if self.vector_store is None or not self.text_chunks:
            return ""
        # FAISS pads results with -1 when k exceeds the number of stored
        # vectors, so never ask for more than exist.
        k = min(top_k, self.vector_store.ntotal)
        if k <= 0:
            return ""
        query_embedding = np.asarray(self.embedding_model.encode([query]), dtype=np.float32)
        _, indices = self.vector_store.search(query_embedding, k)
        return " ".join(self.text_chunks[i] for i in indices[0] if 0 <= i < len(self.text_chunks))
# One module-level pipeline, used by the upload handler and the chat callback.
rag_pipeline = RAGPipeline(embedding_model=embedding_model)
# --- Chatbot Inference Logic ---
def get_answer(question, context):
    """Extract an answer span for ``question`` from ``context`` with the QA model.

    Args:
        question: The user's question text.
        context: Retrieved document text; empty means nothing was retrieved.

    Returns:
        The decoded answer span, or a fallback message when no usable span
        is predicted.
    """
    if not context:
        return "I could not find relevant information in the document to answer that question."

    # The position-embedding table size is the longest sequence the model
    # accepts; longer inputs would index past it.
    max_length = inference_model.pos_emb.num_embeddings
    inputs = tokenizer(question, context, return_tensors='pt', max_length=max_length, truncation=True, padding='max_length')
    input_ids = inputs['input_ids'].to(DEVICE)
    attention_mask = inputs['attention_mask'].to(DEVICE)
    with torch.no_grad():
        outputs = inference_model(input_ids, attention_mask)

    # Restrict span boundaries to real tokens of the second segment (the
    # context). For BERT-style tokenizers, token_type_ids is 1 on that segment.
    # Position 0 ([CLS]) stays allowed, so a "no answer" prediction still
    # decodes to "" and yields the fallback message below.
    answer_mask = attention_mask.bool()
    token_type_ids = inputs.get('token_type_ids')
    if token_type_ids is not None:
        answer_mask = answer_mask & token_type_ids.to(DEVICE).bool()
    answer_mask[:, 0] = True
    start_logits = outputs['start_logits'].masked_fill(~answer_mask, float('-inf'))
    end_logits = outputs['end_logits'].masked_fill(~answer_mask, float('-inf'))

    start_index = torch.argmax(start_logits, dim=1).item()
    end_index = torch.argmax(end_logits, dim=1).item()
    if start_index <= end_index:
        answer_ids = input_ids[0][start_index:end_index + 1]
        answer = tokenizer.decode(answer_ids, skip_special_tokens=True)
        return answer if answer.strip() else "I found a relevant section, but could not extract a precise answer."
    return "I found relevant information, but I'm having trouble formulating a clear answer."
# --- Gradio Interface ---
def add_text(history, text):
    """Append the user's message to ``history`` in place and clear the input box.

    Returns the same ``history`` list together with an empty string for the
    question textbox.
    """
    user_message = {"role": "user", "content": text}
    history.append(user_message)
    cleared_input = ""
    return history, cleared_input
def bot(history):
    """Answer the latest message in ``history`` and append the reply in place.

    Retrieves context for the latest message's content, asks the QA model for
    an answer, and returns the same (mutated) ``history`` list.
    """
    latest_question = history[-1]["content"]
    retrieved = rag_pipeline.retrieve_context(latest_question)
    reply = get_answer(latest_question, retrieved)
    history.append({"role": "assistant", "content": reply})
    return history
print("--- Models loaded, building Gradio interface ---")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Chat with your PDF using a Custom Edge SLM")
    gr.Markdown("1. Upload a PDF. 2. Wait for it to be processed. 3. Ask questions about its content.")
    with gr.Row():
        # Left column: document upload, processing status, embeddings download.
        with gr.Column(scale=1):
            pdf_upload = gr.File(label="Upload PDF")
            upload_status = gr.Textbox(label="PDF Status", interactive=False)
            download_embeddings = gr.File(label="Download Raw Embeddings", interactive=False)
        # Right column: the conversation and the question input.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Chat History", height=500, type='messages')
            question_box = gr.Textbox(label="Your Question", placeholder="Ask something about the document...")

    # Submitting a question first echoes it into the chat and clears the box,
    # then runs retrieval + answering as a chained follow-up step.
    user_turn = question_box.submit(fn=add_text, inputs=[chatbot, question_box], outputs=[chatbot, question_box])
    user_turn.then(fn=bot, inputs=chatbot, outputs=chatbot)

    # Index a PDF as soon as it is uploaded.
    pdf_upload.upload(
        fn=rag_pipeline.process_pdf,
        inputs=[pdf_upload],
        outputs=[upload_status, download_embeddings],
    )
print("β Gradio interface built successfully.")

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)