|
|
import asyncio
import hashlib
import json
import os
import pickle

import gradio as gr
from langchain.text_splitter import SentenceTransformersTokenTextSplitter
from langchain_community.document_loaders import CSVLoader, JSONLoader, PDFMinerLoader
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
|
|
|
|
|
# Generation model loaded once at import time (module-level side effect).
MODEL_NAME = "TheBloke/Llama-2-7B-GPTQ"

# NOTE(review): this is a GPTQ-quantized checkpoint but it is mapped to CPU;
# GPTQ kernels generally require a GPU — confirm this loads and runs on CPU.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="cpu")

# Hugging Face text-generation pipeline used by query_files() to produce the
# final answer from the retrieval-augmented prompt.
text_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer
)
|
|
|
|
|
|
|
|
# Llama-2 chat prompt template ([INST]/<<SYS>> markers).  The blank lines
# inside the triple-quoted string are part of the literal prompt text sent to
# the model — do not reformat them.
template = """


<s>[INST] <<SYS>>


Use the following information to answer the question at the end.


<</SYS>>





{context}





{question} [/INST]


"""

# {context} is filled with retrieved chunks, {question} with the user query.
prompt = PromptTemplate(template=template, input_variables=["context", "question"])
|
|
|
|
|
|
|
|
# Directory for pickled processing results (split documents + FAISS index).
CACHE_DIR = 'cache'
os.makedirs(CACHE_DIR, exist_ok=True)
|
|
|
|
|
def save_cache(filename, data):
    """Serialize ``data`` with pickle into CACHE_DIR under ``filename``."""
    destination = os.path.join(CACHE_DIR, filename)
    with open(destination, 'wb') as handle:
        pickle.dump(data, handle)
|
|
|
|
|
def load_cache(filename):
    """Return the object pickled under ``filename`` in CACHE_DIR.

    Returns None on a cache miss.  A truncated or corrupt cache file
    (EOFError / UnpicklingError) is treated as a miss as well, so a bad
    cache entry degrades to a rebuild instead of crashing the app.
    """
    try:
        with open(os.path.join(CACHE_DIR, filename), 'rb') as f:
            return pickle.load(f)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError):
        return None
|
|
|
|
|
async def process_files(file_paths):
    """Load ``file_paths``, split them into chunks, embed them and build a
    FAISS index.

    Returns ``(texts, db, embeddings)`` on success, or
    ``(None, None, error_message)`` on failure — callers detect errors via
    ``db is None`` and read the message from the third slot.
    """
    try:
        # Cache per file-set: the original saved everything under one fixed
        # key ('cache_key') and never read it back, so identical uploads were
        # re-embedded on every query.
        digest = hashlib.md5("|".join(sorted(file_paths)).encode("utf-8")).hexdigest()
        cache_name = f"index_{digest}.pkl"
        cached = load_cache(cache_name)
        if cached is not None:
            print("Loaded index from cache.")
            return cached

        print("Processing files...")
        docs = []
        for file_path in file_paths:
            if file_path.endswith('.pdf'):
                loader = PDFMinerLoader(file_path)
                # PDF parsing is CPU-heavy; run it off the event loop.
                docs.extend(await asyncio.to_thread(loader.load))
            elif file_path.endswith('.csv'):
                loader = CSVLoader(file_path)
                docs.extend(await asyncio.to_thread(loader.load))
            elif file_path.endswith('.json'):
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                # split_documents() requires Document objects; appending the
                # raw parsed JSON (dict/list) crashed downstream.
                docs.append(Document(
                    page_content=json.dumps(data, ensure_ascii=False),
                    metadata={"source": file_path},
                ))

        print("Files loaded.")
        text_splitter = SentenceTransformersTokenTextSplitter(chunk_size=1024, chunk_overlap=64)
        texts = text_splitter.split_documents(docs)
        print("Text split into chunks.")

        embeddings = HuggingFaceEmbeddings(
            model_name="thenlper/gte-large",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )

        db = FAISS.from_documents(texts, embeddings)
        print("FAISS index created.")
        # Best effort only: the FAISS store wraps a SWIG-backed index that may
        # not be picklable; a cache write failure must not turn a successful
        # build into an error return.
        try:
            save_cache(cache_name, (texts, db, embeddings))
        except Exception as cache_err:
            print(f"Warning: could not cache index: {cache_err}")
        return texts, db, embeddings
    except Exception as e:
        print(f"Error: {e}")
        return None, None, str(e)
|
|
|
|
|
async def query_files(files, question):
    """Answer ``question`` from the uploaded ``files`` via RAG.

    Returns the generated answer string, or a human-readable error message
    when inputs are missing or processing fails.
    """
    if not files or not question.strip():
        return "Please upload valid files and enter a question."

    print("Starting query processing...")
    # gr.File(type="filepath") delivers plain path strings; older Gradio
    # versions deliver tempfile wrappers exposing .name — accept both.
    # (The original unconditionally read file.name and broke on strings.)
    file_paths = [f if isinstance(f, str) else f.name for f in files]

    texts, db, embeddings = await process_files(file_paths)

    if db is None:
        print("Error during processing.")
        # On failure process_files returns the error message in the third slot.
        return f"Error processing files: {embeddings}"

    print("Processing complete.")
    results = db.similarity_search(question, k=5)
    context = " ".join(result.page_content for result in results)

    prompt_text = prompt.format(context=context, question=question)

    # return_full_text=False strips the echoed prompt so the UI shows only
    # the model's continuation, not the whole [INST] template.
    generated_text = text_pipeline(prompt_text, return_full_text=False)[0]['generated_text']

    return generated_text
|
|
|
|
|
def process_and_query(files, question):
    """Synchronous bridge between the Gradio click handler and the async
    RAG pipeline: runs query_files to completion and returns its answer."""
    pending = query_files(files, question)
    return asyncio.run(pending)
|
|
|
|
|
with gr.Blocks() as rag_app:
    # Header and usage instructions.
    gr.Markdown("### Retrieval Augmented Generation (RAG) for LLM Local Trial")
    gr.Markdown(
        "Upload multiple files (PDF, CSV, JSON) and ask a question. The app will generate the answer based on the content of the input files.")

    # Question box and file picker side by side.
    with gr.Row():
        question_box = gr.Textbox(label="Enter your question", lines=3)
        file_picker = gr.File(label="Upload Files", type="filepath", file_count="multiple")

    ask_button = gr.Button("Submit")
    answer_box = gr.Textbox(label="LLM Response", lines=8)
    # Wire the button to the synchronous wrapper around the async pipeline.
    ask_button.click(process_and_query, inputs=[file_picker, question_box], outputs=answer_box)

if __name__ == "__main__":
    rag_app.launch()
|
|
|