File size: 3,500 Bytes
b7d6ae5
 
96071c0
b7d6ae5
 
 
 
 
 
85d1fed
b7d6ae5
96071c0
b7d6ae5
b16ae29
 
 
 
 
b7d6ae5
 
 
 
 
 
 
 
 
 
 
 
 
 
b16ae29
b7d6ae5
 
c979465
b7d6ae5
b16ae29
da20df8
b7d6ae5
 
 
 
 
 
 
 
 
 
 
 
b16ae29
 
da20df8
 
 
 
 
 
 
b7d6ae5
 
 
b16ae29
b7d6ae5
da20df8
 
 
 
 
b7d6ae5
 
 
 
 
 
 
b16ae29
 
 
b7d6ae5
 
 
 
5f20a58
3314548
5f20a58
3314548
 
b16ae29
b7d6ae5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader

# Model and Tokenizer
# NOTE: these module-level statements download/load the model at import time.
MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
# Prefer GPU when available; inputs are moved here before generation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# device_map="auto" lets accelerate place the weights (GPU/CPU/offload).
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")

# Ensure PAD token is set correctly
# (many causal-LM tokenizers ship without one; reuse EOS so padding works)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load embedding model for RAG
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# Global FAISS index; populated by process_pdf(), read by query_rag().
vector_store = None

# Function to process PDF and create vector database
def process_pdf(pdf_path):
    """Load a PDF, split it into overlapping chunks, and index them in FAISS.

    Rebinds the module-level ``vector_store`` so subsequent queries search
    the newly uploaded document. Returns a status string for the UI.
    """
    global vector_store

    pages = PyPDFLoader(pdf_path).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
    chunks = splitter.split_documents(pages)

    vector_store = FAISS.from_documents(chunks, embedding_model)
    return "PDF successfully processed and indexed."

# RAG Query Function
def query_rag(message, temperature, max_new_tokens, top_k, repetition_penalty, top_p, system_prompt, history=None):
    """Answer *message* with retrieval-augmented generation over the indexed PDF.

    Retrieves the 3 most similar chunks from the global ``vector_store``,
    builds a ChatML prompt, and samples a completion from the global ``model``.

    Args:
        message: The user's question.
        temperature / max_new_tokens / top_k / repetition_penalty / top_p:
            Sampling parameters; coerced to the proper numeric types because
            Gradio may deliver them as strings.
        system_prompt: System-role instruction prepended to the prompt.
        history: Chat history from gr.ChatInterface (currently unused).

    Returns:
        The generated assistant reply only (prompt tokens are stripped),
        or a status message if no PDF has been indexed yet.
    """
    if vector_store is None:
        return "Please upload and process a PDF first."

    # Retrieve relevant chunks
    docs = vector_store.similarity_search(message, k=3)
    context = "\n".join(doc.page_content for doc in docs)

    # Construct prompt. FIX: the context used to be emitted *after* the
    # system turn's <|im_end|>, leaving it outside any role block; it now
    # lives inside the system turn so the model treats it as grounding.
    instruction = f"<|im_start|>system\n{system_prompt}\nRelevant context:\n{context}\n<|im_end|>\n"
    instruction += f"<|im_start|>user\n{message}\n<|im_end|>\n<|im_start|>assistant\n"

    # Tokenization
    enc = tokenizer(instruction, return_tensors="pt", padding=True, truncation=True)
    input_ids = enc.input_ids.to(device)
    attention_mask = enc.attention_mask.to(device)

    # Ensure correct parameter types
    temperature = float(temperature)
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)
    repetition_penalty = float(repetition_penalty)
    top_p = float(top_p)

    # Generate response
    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,  # explicit mask avoids pad/eos ambiguity
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        pad_token_id=tokenizer.pad_token_id,  # silence "pad_token_id not set" fallback
    )
    # FIX: decode only the newly generated continuation. Decoding
    # output_ids[0] wholesale echoed the entire system/context/user prompt
    # back to the user as part of the "response".
    new_tokens = output_ids[0][input_ids.shape[1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return response

# Gradio Interface
def launch_interface():
    """Build the Gradio UI (PDF upload + chat) and launch it with a public link."""
    with gr.Blocks() as demo:
        gr.Markdown("## 🤖 RAG Chatbot with DeepSeek")

        uploader = gr.File(label="Upload PDF", type="filepath")
        run_button = gr.Button("Process PDF")
        status_box = gr.Textbox(label="Processing Status", interactive=False)

        run_button.click(process_pdf, inputs=[uploader], outputs=[status_box])

        # Chat handler with fixed sampling defaults for the widget.
        def _chat(message, history):
            return query_rag(message, 0.7, 10000, 40, 1.1, 0.95, "You are a helpful assistant.", history)

        gr.ChatInterface(_chat)

    demo.launch(share=True)  # share=True exposes a public gradio.live URL

# Entry point: start the UI only when run as a script, not on import.
if __name__ == "__main__":
    launch_interface()