"""RAG chatbot demo with PDF and image support.

Serves a Gradio interface that answers questions with the Zephyr 7B chat
model. Text can be typed directly, extracted from an uploaded PDF (PyMuPDF),
or read from an uploaded image (EasyOCR). Context for each answer is
retrieved from a Chroma collection using sentence-transformer embeddings.
"""
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import fitz  # PyMuPDF
import easyocr
from PIL import Image
from sentence_transformers import SentenceTransformer
from chromadb import Client, Settings

# Hub id of the chat checkpoint; the tokenizer and the weights must be loaded
# from the same id, so it is named once here.
ZEPHYR_MODEL_ID = "HuggingFaceH4/zephyr-7b-alpha"

# Load Zephyr 7B (fine-tuned for chat)
zephyr_tokenizer = AutoTokenizer.from_pretrained(ZEPHYR_MODEL_ID)
zephyr_model = AutoModelForCausalLM.from_pretrained(
    ZEPHYR_MODEL_ID,
    torch_dtype=torch.float16,  # Use half-precision for faster inference
    device_map="auto",  # Automatically loads the model on GPU if available
)

# Load a sentence transformer model for embeddings
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Initialize Chroma client for RAG
chroma_client = Client(Settings())
# get_or_create_collection instead of create_collection: the latter raises
# when "knowledge_base" already exists, which happens as soon as this module
# is executed a second time against the same Chroma state.
collection = chroma_client.get_or_create_collection(name="knowledge_base")

def extract_text_from_pdf(pdf_path):
    """Return the text of every page of a PDF, concatenated in page order.

    Args:
        pdf_path: Location of the PDF file, passed straight to ``fitz.open``.

    Returns:
        A single string made of each page's ``get_text()`` output, with no
        separator inserted between pages.
    """
    # The context manager closes the document (and its file handle) even if
    # text extraction fails part-way through.
    with fitz.open(pdf_path) as doc:
        return "".join(page.get_text() for page in doc)

# OCR readers keyed by their language tuple, so each one is built only once.
_OCR_READERS = {}


def _get_ocr_reader(languages=("en",)):
    """Return the shared ``easyocr.Reader`` for *languages*, creating it once."""
    key = tuple(languages)
    reader = _OCR_READERS.get(key)
    if reader is None:
        reader = easyocr.Reader(list(key))
        _OCR_READERS[key] = reader
    return reader


def extract_text_from_image(image_path):
    """Run OCR on an image and return the recognised text.

    Args:
        image_path: Image to read, passed straight to ``Reader.readtext``.

    Returns:
        The recognised text fragments joined with single spaces, in the
        order the reader returned them. Empty string if nothing was found.
    """
    # Reuse one reader across calls instead of constructing a new
    # easyocr.Reader for every request.
    results = _get_ocr_reader().readtext(image_path)
    # Element 1 of each result entry is the recognised text.
    return " ".join(res[1] for res in results)

def generate_response(prompt, max_new_tokens=200):
    """Generate the assistant's reply to *prompt* with the Zephyr model.

    Args:
        prompt: User message; it is wrapped in the ``<|user|>`` /
            ``<|assistant|>`` markers before tokenisation.
        max_new_tokens: Upper bound on the number of tokens generated for
            the reply. The prompt length does not count against it.

    Returns:
        The decoded reply with surrounding whitespace stripped.
    """
    # Structure the input prompt for chat
    formatted_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"

    # Tokenize the input prompt and move it to the model's device
    inputs = zephyr_tokenizer(formatted_prompt, return_tensors="pt").to(zephyr_model.device)

    # Bound only the *new* tokens: a total-length cap (max_length) is used up
    # by the prompt itself once context and extracted document text are
    # included, leaving no room for an answer.
    outputs = zephyr_model.generate(**inputs, max_new_tokens=max_new_tokens)

    # The returned sequence starts with the prompt tokens; keep only what was
    # generated after them rather than splitting the decoded text on a marker
    # that may also occur inside the prompt or the reply.
    prompt_length = inputs["input_ids"].shape[-1]
    generated_tokens = outputs[0][prompt_length:]

    return zephyr_tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

def add_to_knowledge_base(text_chunks):
    """Embed text chunks and store them in the knowledge-base collection.

    Args:
        text_chunks: Iterable of strings to index. A single string is
            treated as one chunk. Empty input is a no-op.

    Returns:
        None.
    """
    if isinstance(text_chunks, str):
        chunks = [text_chunks]
    else:
        chunks = list(text_chunks)
    if not chunks:
        return

    embeddings = embedding_model.encode(chunks)

    # Continue numbering after what is already stored so that a second call
    # does not reuse the ids "0", "1", ... of the first one.
    # NOTE(review): count-based ids can still collide if entries are ever
    # deleted from the collection - confirm that never happens.
    first_id = collection.count()
    collection.add(
        documents=chunks,
        embeddings=[embedding.tolist() for embedding in embeddings],
        ids=[str(first_id + offset) for offset in range(len(chunks))],
    )

def retrieve_relevant_chunks(query, top_k=3):
    """Return the stored chunks most similar to *query*.

    Args:
        query: Text to embed and search with.
        top_k: Maximum number of chunks to return.

    Returns:
        A list of at most ``top_k`` document strings. The list is empty when
        ``top_k`` is not positive or the collection holds no documents.
    """
    if top_k <= 0:
        return []

    # Never ask for more results than are stored, and skip the query entirely
    # while the collection is empty.
    available = collection.count()
    if available == 0:
        return []

    query_embedding = embedding_model.encode(query)
    results = collection.query(
        query_embeddings=[query_embedding.tolist()],
        n_results=min(top_k, available),
    )
    # One query was sent, so the first (only) result list is the answer.
    return results["documents"][0]

def chatbot(input_type, text_input, pdf_input, image_input):
    """Answer a question from typed text, an uploaded PDF, or an uploaded image.

    Args:
        input_type: One of "Text", "PDF" or "Image".
        text_input: The typed text / question. May be empty or None.
        pdf_input: Uploaded PDF, either a path or an object exposing the
            path as ``.name``. Only used when ``input_type`` is "PDF".
        image_input: Path of the uploaded image. Only used when
            ``input_type`` is "Image".

    Returns:
        The model's answer, or a short message explaining which input is
        missing or that ``input_type`` is not recognised.
    """
    # Treat a missing or whitespace-only text box as empty, so it never shows
    # up as the literal "None" in the prompt.
    question = (text_input or "").strip()

    if input_type == "Text":
        if not question:
            return "Please enter some text."
        query = question
    elif input_type == "PDF":
        if pdf_input is None:
            return "Please upload a PDF file."
        # The file component may hand over a temp-file wrapper instead of a
        # plain path; in that case the path is its .name attribute.
        pdf_path = getattr(pdf_input, "name", pdf_input)
        pdf_text = extract_text_from_pdf(pdf_path)
        query = f"Extracted text from PDF:\n{pdf_text}\n\nQuestion: {question}"
    elif input_type == "Image":
        if image_input is None:
            return "Please upload an image file."
        image_text = extract_text_from_image(image_input)
        query = f"Extracted text from image:\n{image_text}\n\nQuestion: {question}"
    else:
        return "Invalid input type."

    # Retrieve relevant chunks from the knowledge base
    relevant_chunks = retrieve_relevant_chunks(query)
    context = "\n\n".join(relevant_chunks)

    # Generate response using the model
    prompt = f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
    return generate_response(prompt)

# Gradio interface: the component order must match the parameter order of
# chatbot(input_type, text_input, pdf_input, image_input).
input_components = [
    # Default to "Text" so a submit without touching the dropdown does not
    # arrive as an unset input type and get rejected as invalid.
    gr.Dropdown(choices=["Text", "PDF", "Image"], value="Text", label="Input Type"),
    gr.Textbox(lines=2, placeholder="Enter text...", label="Text Input"),
    gr.File(label="Upload PDF", file_types=[".pdf"]),
    gr.Image(label="Upload Image", type="filepath"),
]

# Create the Gradio interface
interface = gr.Interface(
    fn=chatbot,
    inputs=input_components,
    outputs="text",
    title="RAG Chatbot with PDF and Image Support",
    description="Select the input type (Text, PDF, or Image) and provide your input.",
)

# Launch the app
interface.launch()