# SLMChatbot / app.py
# Author: Nihal2000 — commit dd191a9 ("First commit"), 10.3 kB
import os
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import PyPDF2
from pathlib import Path
# Ensure all required packages are installed.
# This is generally handled by requirements.txt on Hugging Face Spaces,
# but this is a fallback for local execution.
# NOTE(review): `faiss` and `PyPDF2` are also imported unconditionally in the
# import block at the top of the file; a missing package raises ImportError
# there before this fallback is reached. Those unconditional imports must go
# for this fallback to take effect.


def _ensure_package(module_name, pip_name):
    """Import *module_name*, pip-installing *pip_name* first if it is missing.

    pip is invoked as ``sys.executable -m pip`` so the package is installed
    into the interpreter running this script (a bare ``pip`` on PATH may
    belong to another Python). A failed install raises
    ``subprocess.CalledProcessError`` instead of being silently ignored.

    Returns:
        The imported module object.
    """
    import importlib
    import subprocess
    import sys

    try:
        return importlib.import_module(module_name)
    except ImportError:
        print(f"Installing {pip_name}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet", pip_name])
        # New site-packages entries are not visible until finder caches are reset.
        importlib.invalidate_caches()
        return importlib.import_module(module_name)


faiss = _ensure_package("faiss", "faiss-cpu")
PyPDF2 = _ensure_package("PyPDF2", "PyPDF2")
# --- Model Architecture (Copied from your provided code) ---
class EfficientMultiHeadAttention(nn.Module):
    """Multi-head self-attention built on a single fused Q/K/V projection.

    Args:
        d_model: Feature width of the input and output; must be divisible
            by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    The submodule names ``qkv_proj`` and ``out_proj`` appear in the
    checkpoint's ``state_dict`` keys and must not be renamed.
    """

    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        # One projection yields Q, K and V concatenated along the last axis.
        self.qkv_proj = nn.Linear(d_model, d_model * 3, bias=False)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        # Standard 1/sqrt(head_dim) scaled dot-product factor.
        self.scale = 1.0 / (self.head_dim ** 0.5)

    def forward(self, x, mask=None):
        """Self-attend over ``x`` of shape (batch, seq_len, d_model).

        ``mask`` is optional and must hold batch * seq_len elements. Key
        positions where it is 0 get -inf scores and therefore zero weight. A
        row with every key masked would produce NaN.

        Returns a tensor shaped like ``x``.
        """
        batch, seq_len, width = x.shape
        # (B, T, 3*C) -> (3, B, heads, T, head_dim), then split into Q, K, V.
        fused = self.qkv_proj(x).reshape(batch, seq_len, 3, self.n_heads, self.head_dim)
        query, key, value = fused.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        if mask is not None:
            # Broadcast the per-key mask across heads and query positions.
            key_mask = mask.view(batch, 1, 1, seq_len)
            scores = scores.masked_fill(key_mask == 0, float('-inf'))
        weights = self.dropout(scores.softmax(dim=-1))

        mixed = torch.matmul(weights, value)  # (B, heads, T, head_dim)
        merged = mixed.transpose(1, 2).reshape(batch, seq_len, width)
        return self.out_proj(merged)
class CompactFeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        d_model: Input and output feature width.
        d_ff: Hidden (expansion) width.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # `w1` / `w2` are checkpoint state_dict key prefixes; keep the names.
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.gelu(self.w1(x))
        hidden = self.dropout(hidden)
        return self.w2(hidden)
class TransformerBlock(nn.Module):
    """Pre-LayerNorm encoder layer: residual attention, then residual MLP.

    Args:
        d_model: Feature width.
        n_heads: Attention heads.
        d_ff: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability, used inside both sub-layers and again on
            each residual branch.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # Attribute names and creation order match the checkpoint layout.
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = EfficientMultiHeadAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = CompactFeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out = self.attn(self.ln1(x), mask)
        x = x + self.dropout(attn_out)
        mlp_out = self.mlp(self.ln2(x))
        return x + self.dropout(mlp_out)
class EdgeOptimizedSLM(nn.Module):
    """Compact transformer encoder with an extractive-QA (span) head.

    Args:
        vocab_size: Number of token ids the embedding table covers.
        d_model: Feature width.
        n_heads: Attention heads per layer.
        n_layers: Number of stacked ``TransformerBlock`` layers.
        d_ff: Hidden width of each feed-forward sub-layer.
        max_length: Size of the learned positional-embedding table, i.e. the
            longest sequence ``forward`` accepts.
        dropout: Dropout probability used throughout.

    Submodule names and creation order form the checkpoint's state_dict
    layout; do not rename or reorder them.
    """

    def __init__(self, vocab_size, d_model=320, n_heads=8, n_layers=4, d_ff=1280, max_length=512, dropout=0.1):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_length, d_model)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([TransformerBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(d_model)
        # Span head: shared bottleneck, then one scalar per token for start/end.
        self.qa_proj = nn.Linear(d_model, d_model // 2)
        self.qa_start = nn.Linear(d_model // 2, 1)
        self.qa_end = nn.Linear(d_model // 2, 1)

    def forward(self, input_ids, attention_mask=None):
        """Compute per-token answer-start and answer-end logits.

        Args:
            input_ids: Token ids of shape (batch, seq_len).
            attention_mask: Optional mask with batch * seq_len elements; 0
                marks key positions the attention layers must ignore.

        Returns:
            dict with ``"start_logits"`` and ``"end_logits"``, each of shape
            (batch, seq_len).

        Raises:
            ValueError: if seq_len exceeds the positional-embedding table size.
                Without this check, the embedding lookup fails with an opaque
                index error.
        """
        _, seq_len = input_ids.size()
        max_len = self.pos_emb.num_embeddings
        if seq_len > max_len:
            raise ValueError(f"Sequence length {seq_len} exceeds the model's max_length {max_len}")
        positions = torch.arange(seq_len, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        x = self.drop(self.tok_emb(input_ids) + self.pos_emb(positions))
        for block in self.blocks:
            x = block(x, attention_mask)
        x = self.ln_f(x)
        qa_hidden = F.gelu(self.qa_proj(x))
        start_logits = self.qa_start(qa_hidden).squeeze(-1)
        end_logits = self.qa_end(qa_hidden).squeeze(-1)
        return {"start_logits": start_logits, "end_logits": end_logits}
# --- Global Variables and Model Loading ---
MODEL_PATH = "edge_deployment_package/models/model_dynamic_quantized_int8.pt"
TOKENIZER_NAME = "bert-base-uncased"
EMBEDDING_MODEL_NAME = 'all-MiniLM-L6-v2'
DEVICE = torch.device('cpu')
# Create a dummy model if the actual model is not found.
# The placeholder is written to MODEL_PATH, so every later run finds a file
# there and would otherwise load random weights as if they were the real
# model. The 'is_dummy' flag records that this checkpoint is untrained, so a
# loader can recognise it.
if not os.path.exists(MODEL_PATH):
    print(f"Warning: Model not found at {MODEL_PATH}. Creating a dummy model for demonstration.")
    model_dir = os.path.dirname(MODEL_PATH)
    if model_dir:  # os.makedirs("") raises, so skip when the path has no directory part
        os.makedirs(model_dir, exist_ok=True)
    dummy_tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    dummy_config = {
        'vocab_size': dummy_tokenizer.vocab_size, 'd_model': 320, 'n_heads': 8, 'n_layers': 4, 'd_ff': 1280, 'max_length': 512
    }
    dummy_model = EdgeOptimizedSLM(**dummy_config)
    torch.save({
        'config': dummy_config,
        'model_state_dict': dummy_model.state_dict(),
        # NOTE(review): label only — these float weights are not quantized.
        'quantization': 'dynamic_int8',
        'is_dummy': True,
    }, MODEL_PATH)
def load_custom_model(model_path, device=None):
    """Load an EdgeOptimizedSLM checkpoint and return it ready for inference.

    The checkpoint must be a dict with ``'config'`` (keyword arguments for
    ``EdgeOptimizedSLM``) and ``'model_state_dict'``. If it also carries a
    truthy ``'is_dummy'`` entry, a warning is printed because the weights are
    untrained.

    Args:
        model_path: Path to the ``torch.save``-d checkpoint.
        device: Target device; defaults to the module-level ``DEVICE``.

    Returns:
        ``(model, config)``: the model in eval mode on ``device``, and the
        config dict it was built from.

    Raises:
        KeyError: if ``'config'`` or ``'model_state_dict'`` is missing.
    """
    if device is None:
        device = DEVICE
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    config = checkpoint['config']
    if checkpoint.get('is_dummy'):
        print(f"Warning: {model_path} holds a randomly initialised placeholder model; answers will be meaningless.")
    model = EdgeOptimizedSLM(**config)
    # NOTE(review): loading is strict; a dynamically-quantized state_dict (as the
    # default MODEL_PATH name suggests) may not match this float architecture's
    # keys. Confirm how the real checkpoint was saved.
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    return model, config
def _load_all_models():
    """Load the QA model, its tokenizer and the sentence-embedding model.

    Returns:
        ``(inference_model, model_config, tokenizer, embedding_model)``.
    """
    print("Loading models...")
    qa_model, qa_config = load_custom_model(MODEL_PATH)
    qa_tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    embedder = SentenceTransformer(EMBEDDING_MODEL_NAME, device=DEVICE)
    print("Models loaded successfully.")
    return qa_model, qa_config, qa_tokenizer, embedder


inference_model, model_config, tokenizer, embedding_model = _load_all_models()
# --- RAG and PDF Processing Logic ---
class RAGPipeline:
    """Chunk a PDF, embed the chunks, and retrieve the ones nearest a query.

    ``text_chunks`` (list of chunk strings) and ``vector_store`` (a FAISS L2
    index whose row i embeds ``text_chunks[i]``) are always replaced together,
    so they describe the same document.
    """

    # Chunking parameters, in words. A new chunk starts every
    # chunk_size - chunk_overlap words.
    chunk_size = 200
    chunk_overlap = 30

    def __init__(self, embedding_model):
        self.text_chunks = []
        self.vector_store = None
        self.embedding_model = embedding_model
        self.raw_embeddings_path = "document_embeddings.raw"

    def process_pdf(self, pdf_file_obj):
        """Extract, chunk and embed a PDF, replacing any previous document.

        Args:
            pdf_file_obj: A file object exposing the path as ``.name`` (as
                Gradio's File component provides), or a plain path string.

        Returns:
            ``(status_message, raw_embeddings_path_or_None)``. The path points
            to the embeddings written as raw float32 bytes. On any failure the
            second item is None and no document remains loaded.
        """
        if pdf_file_obj is None:
            return "Please upload a PDF file first.", None
        pdf_path = getattr(pdf_file_obj, "name", pdf_file_obj)
        print(f"Processing PDF: {pdf_path}")
        # Drop the previous document before anything can fail. Otherwise an old
        # index could end up paired with new (or empty) chunks.
        self.text_chunks = []
        self.vector_store = None
        try:
            text = self._extract_text(pdf_path)
            if not text:
                return "Could not extract text from the PDF.", None
            chunks = self._split_into_chunks(text.split(), self.chunk_size, self.chunk_overlap)
            if not chunks:
                return "Text extracted but could not be split into chunks.", None
            print(f"Generating embeddings for {len(chunks)} chunks...")
            embeddings = self.embedding_model.encode(chunks, convert_to_tensor=False, show_progress_bar=True)
            # FAISS requires a C-contiguous float32 matrix.
            embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
            with open(self.raw_embeddings_path, 'wb') as f:
                f.write(embeddings.tobytes())
            index = faiss.IndexFlatL2(embeddings.shape[1])
            index.add(embeddings)
        except Exception as e:  # UI boundary: report any failure to the user
            print(f"Error processing PDF: {e}")
            return f"Error processing PDF: {e}", None
        self.text_chunks = chunks
        self.vector_store = index
        print("PDF processing complete.")
        status_message = f"Successfully processed '{Path(pdf_path).name}'. Ready for questions."
        return status_message, self.raw_embeddings_path

    @staticmethod
    def _extract_text(pdf_path):
        """Return the text of all pages that yield any, one page per line block."""
        reader = PyPDF2.PdfReader(pdf_path)
        # Extract each page once. Join with a newline so the last word of one
        # page never fuses with the first word of the next.
        page_texts = (page.extract_text() for page in reader.pages)
        return "\n".join(t for t in page_texts if t)

    @staticmethod
    def _split_into_chunks(words, chunk_size, overlap):
        """Group *words* into overlapping windows of at most *chunk_size* words."""
        step = max(1, chunk_size - overlap)
        chunks = []
        for start in range(0, len(words), step):
            chunks.append(" ".join(words[start:start + chunk_size]))
            # Once a window reaches the end, later windows would only be
            # suffixes of it and would add duplicate context.
            if start + chunk_size >= len(words):
                break
        return chunks

    def retrieve_context(self, query, top_k=3):
        """Return up to ``top_k`` chunks nearest to ``query``, joined by spaces.

        Returns "" when no document is loaded or ``top_k`` is not positive.
        """
        if self.vector_store is None or not self.text_chunks:
            return ""
        k = min(top_k, len(self.text_chunks))
        if k <= 0:
            return ""
        query_embedding = np.ascontiguousarray(self.embedding_model.encode([query]), dtype=np.float32)
        _, indices = self.vector_store.search(query_embedding, k)
        # FAISS pads missing results with -1; never let that wrap to the last chunk.
        return " ".join(self.text_chunks[i] for i in indices[0] if 0 <= i < len(self.text_chunks))
# Single shared retrieval pipeline backed by the sentence-embedding model.
rag_pipeline = RAGPipeline(embedding_model=embedding_model)
# --- Chatbot Inference Logic ---
def _context_token_mask(inputs):
    """Return a bool tensor (seq_len,) that is True only on context-segment tokens."""
    try:
        # Fast tokenizers label each position: 0 = question, 1 = context,
        # None = special or padding token.
        sequence_ids = inputs.sequence_ids(0)
    except ValueError:
        # Slow tokenizers lack sequence_ids(); use segment ids, the attention
        # mask and the tokenizer's special-token list instead.
        mask = inputs['attention_mask'][0].bool()
        token_type_ids = inputs.get('token_type_ids')
        if token_type_ids is not None:
            mask = mask & token_type_ids[0].bool()
        special = tokenizer.get_special_tokens_mask(
            inputs['input_ids'][0].tolist(), already_has_special_tokens=True
        )
        return mask & ~torch.tensor(special, dtype=torch.bool)
    return torch.tensor([sid == 1 for sid in sequence_ids], dtype=torch.bool)


def get_answer(question, context, max_answer_len=30):
    """Answer *question* by extracting a span of *context* with the QA model.

    Args:
        question: The user's question.
        context: Retrieved document text; an empty string means nothing was found.
        max_answer_len: Longest answer span considered, in tokens.

    Returns:
        The decoded answer span, or one of the fixed fallback messages.
    """
    if not context:
        return "I could not find relevant information in the document to answer that question."
    inputs = tokenizer(
        question,
        context,
        return_tensors='pt',
        max_length=model_config.get('max_length', 512),
        truncation=True,
        padding='max_length',
    )
    input_ids = inputs['input_ids'].to(DEVICE)
    attention_mask = inputs['attention_mask'].to(DEVICE)
    with torch.no_grad():
        outputs = inference_model(input_ids, attention_mask)
    start_logits = outputs['start_logits'][0]
    end_logits = outputs['end_logits'][0]

    # Both span ends must lie on context tokens. The answer can then never be
    # a piece of the question or a [CLS]/[SEP]/[PAD] token.
    candidate = _context_token_mask(inputs).to(start_logits.device)
    if not candidate.any():
        return "I found relevant information, but I'm having trouble formulating a clear answer."

    # Score all (start, end) pairs jointly instead of taking two independent
    # argmaxes, which can give end < start. A valid pair satisfies
    # start <= end < start + max_answer_len.
    seq_len = start_logits.size(0)
    scores = start_logits[:, None] + end_logits[None, :]
    positions = torch.arange(seq_len, device=scores.device)
    span_len = positions[None, :] - positions[:, None]  # end - start
    valid = (span_len >= 0) & (span_len < max(1, max_answer_len))
    valid &= candidate[:, None] & candidate[None, :]
    scores = scores.masked_fill(~valid, float('-inf'))
    start_index, end_index = divmod(int(torch.argmax(scores)), seq_len)

    answer_ids = input_ids[0][start_index:end_index + 1]
    answer = tokenizer.decode(answer_ids, skip_special_tokens=True)
    return answer if answer.strip() else "I found a relevant section, but could not extract a precise answer."
# --- Gradio Interface ---
def add_text(history, text):
    """Append *text* as a new user turn with no reply yet, and clear the input box.

    Returns a new history list (the caller's list is not mutated), plus ""
    to reset the textbox.
    """
    new_turn = (text, None)
    return history + [new_turn], ""
def bot(history):
    """Answer the most recent user turn of *history* and return the history.

    Args:
        history: Tuple-style chat history, a list of (user, bot) pairs whose
            last pair has no reply yet.

    Returns:
        The same list with its last pair replaced by ``[question, answer]``.
        An empty or None history is returned unchanged.
    """
    if not history:
        return history
    question = history[-1][0]
    context = rag_pipeline.retrieve_context(question)
    answer = get_answer(question, context)
    # Replace the pair rather than assigning into it. Pairs may be tuples
    # (add_text builds `(text, None)`), which do not support item assignment.
    history[-1] = [question, answer]
    return history
# Build the Gradio UI. Components are laid out in the order they are created
# inside each `with` context, so the statements below are order-sensitive.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Chat with your PDF using a Custom Edge SLM")
    gr.Markdown("1. Upload a PDF. 2. Wait for it to be processed. 3. Ask questions about its content.")
    with gr.Row():
        # Left column: PDF upload, processing status, and raw-embeddings download.
        with gr.Column(scale=1):
            pdf_upload = gr.File(label="Upload PDF")
            upload_status = gr.Textbox(label="PDF Status", interactive=False)
            download_embeddings = gr.File(label="Download Raw Embeddings", interactive=False)
        # Right column (scale 2 relative to the left): chat transcript and question input.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Chat History", height=500)
            question_box = gr.Textbox(label="Your Question", placeholder="Ask something about the document...")
    # Event Handlers
    # Submitting a question first appends it to the chat and clears the box
    # (add_text), then fills in the model's reply (bot).
    question_box.submit(add_text, [chatbot, question_box], [chatbot, question_box]).then(
        bot, chatbot, chatbot
    )
    # Uploading a PDF runs chunking + embedding. process_pdf returns
    # (status message, embeddings file path or None), routed to the two outputs.
    pdf_upload.upload(
        fn=rag_pipeline.process_pdf,
        inputs=[pdf_upload],
        outputs=[upload_status, download_embeddings]
    )
if __name__ == "__main__":
    # Honour Gradio's standard GRADIO_SERVER_NAME / GRADIO_SERVER_PORT
    # environment variables (hosting platforms may set them). Otherwise fall
    # back to 0.0.0.0:7830.
    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7830")),
    )