File size: 3,582 Bytes
473affa
b6f944c
55b0fc7
b6f944c
473affa
fe0e877
473affa
55b0fc7
fe0e877
473affa
55b0fc7
473affa
 
 
 
 
 
b6f944c
473affa
 
 
 
fe0e877
473affa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6f944c
 
473affa
 
 
 
 
 
 
fe0e877
473affa
 
 
 
 
 
 
b6f944c
473affa
 
 
 
b6f944c
473affa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe0e877
473affa
 
fe0e877
473affa
b6f944c
473affa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93

import os
import fitz  # PyMuPDF
import faiss
import numpy as np
import pickle
import torch
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import gradio as gr

# On-disk cache locations for the FAISS vector index and its parallel,
# pickled list of text chunks. The two files must stay in sync: row i of
# the index corresponds to element i of the chunk list.
INDEX_FILE = "faiss_index.bin"
CHUNKS_FILE = "chunks.pkl"

# Sentence embedder used both to index chunks and to encode queries.
model = SentenceTransformer("all-MiniLM-L6-v2")
# Instruction-tuned LLM used for answering and quiz generation;
# device=-1 pins the transformers pipeline to CPU.
llm = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.2", tokenizer="mistralai/Mistral-7B-Instruct-v0.2", device=-1)

def load_pdf(file):
    """Extract the plain text of every page of a PDF.

    Args:
        file: Path to the PDF file (anything ``fitz.open`` accepts).

    Returns:
        str: The text of all pages joined with newlines.
    """
    # Use the context manager so the document (and its file handle) is
    # closed deterministically — the original leaked the open document.
    with fitz.open(file) as doc:
        return "\n".join(page.get_text() for page in doc)

def split_text(text, chunk_size=500, overlap=0):
    """Split *text* into word-based chunks for embedding.

    Args:
        text: Source text; tokenized on whitespace with ``str.split``.
        chunk_size: Maximum number of words per chunk (must be positive).
        overlap: Number of words shared between consecutive chunks
            (``0 <= overlap < chunk_size``). The default 0 preserves the
            original non-overlapping behavior; a small overlap improves
            retrieval recall across chunk boundaries.

    Returns:
        list[str]: The chunks; an empty list for empty/whitespace-only text.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is
            out of range.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
    words = text.split()
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(words), step):
        chunks.append(" ".join(words[start:start + chunk_size]))
        # Stop once a chunk reaches the end of the text, so overlap never
        # produces a trailing chunk fully contained in the previous one.
        if start + chunk_size >= len(words):
            break
    return chunks

def create_or_load_index(chunks):
    """Return a FAISS index and the chunk list its rows correspond to.

    Fixes a stale-cache bug in the original: whenever the cache files
    existed, the old code returned the cached index and *ignored* the
    chunks passed in, so PDFs uploaded after the first run had no effect.
    The cache is now reused only when it matches the supplied chunks;
    otherwise the index is rebuilt and the cache overwritten.

    Args:
        chunks: Text chunks to index (row i of the index == chunks[i]).

    Returns:
        tuple[faiss.Index, list[str]]: The index and its backing chunks.
    """
    if os.path.exists(INDEX_FILE) and os.path.exists(CHUNKS_FILE):
        with open(CHUNKS_FILE, "rb") as f:
            # NOTE(review): pickle.load on a local cache file — safe only
            # while the file cannot be attacker-controlled.
            cached_chunks = pickle.load(f)
        if not chunks or cached_chunks == chunks:
            # Cache matches the request (or no new chunks were supplied):
            # reuse it and skip re-embedding.
            return faiss.read_index(INDEX_FILE), cached_chunks
    # Build a fresh index from the supplied chunks and refresh the cache.
    embeddings = model.encode(chunks)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(np.array(embeddings))
    faiss.write_index(index, INDEX_FILE)
    with open(CHUNKS_FILE, "wb") as f:
        pickle.dump(chunks, f)
    return index, chunks

def retrieve_context(query, index, chunks, top_k=3):
    """Return the *top_k* chunks most similar to *query*, blank-line joined.

    Args:
        query: Natural-language search string.
        index: FAISS index whose rows align with *chunks*.
        chunks: Text chunks, where chunks[i] backs index row i.
        top_k: Number of nearest neighbours to retrieve.

    Returns:
        str: The matching chunks separated by blank lines (may be empty).
    """
    query_emb = model.encode([query])
    distances, indices = index.search(np.array(query_emb), top_k)
    # FAISS pads the result with -1 when the index holds fewer than top_k
    # vectors; the original let -1 silently select chunks[-1]. Skip any
    # id outside the valid range instead.
    hits = [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]
    return "\n\n".join(hits)

def answer_question(query, index, chunks):
    """Answer *query* grounded in context retrieved from the indexed chunks.

    Builds a context-plus-question prompt, runs the LLM with greedy
    decoding (deterministic), and returns only the text following the
    final "Answer:" marker.
    """
    retrieved = retrieve_context(query, index, chunks)
    prompt = f"Context:\n{retrieved}\n\nQuestion: {query}\nAnswer:"
    completions = llm(prompt, max_new_tokens=256, do_sample=False)
    full_text = completions[0]["generated_text"]
    # Everything after the last "Answer:" occurrence is the model's reply.
    answer = full_text.rpartition("Answer:")[2]
    return answer.strip()

def generate_quiz(index, chunks):
    """Generate three multiple-choice quiz questions from the indexed material.

    Retrieves context with a generic probe query, prompts the LLM with
    greedy decoding, and returns the text after the last "Questions:"
    marker.
    """
    probe = "generate quiz questions"
    context = retrieve_context(probe, index, chunks)
    prompt = (
        "Based on the following context, generate 3 quiz questions "
        f"with multiple choice answers:\n\n{context}\n\nQuestions:"
    )
    output = llm(prompt, max_new_tokens=512, do_sample=False)
    generated = output[0]["generated_text"]
    return generated.rpartition("Questions:")[2].strip()

# Gradio UI wiring: PDF upload -> index build, then Q&A and quiz actions.
with gr.Blocks() as demo:
    # Mutable module-level state holding the FAISS index and chunk list.
    # NOTE(review): this dict is shared by every browser session of the
    # app — concurrent users would overwrite each other's uploads;
    # consider gr.State for per-session storage.
    state = {"index": None, "chunks": []}

    gr.Markdown("# 📘 AI Revision Assistant")

    with gr.Row():
        file_input = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload your revision PDFs")
        status_output = gr.Textbox(label="Status", interactive=False)

    def process(files):
        """Extract, chunk, and index every uploaded PDF; update shared state."""
        all_chunks = []
        for file in files:
            # `file` is a gradio upload wrapper; .name is the temp path on disk.
            text = load_pdf(file.name)
            chunks = split_text(text)
            all_chunks.extend(chunks)
        index, chunks = create_or_load_index(all_chunks)
        state["index"] = index
        state["chunks"] = chunks
        return f"Processed {len(files)} files. You can now ask questions or generate quizzes."

    # Re-run indexing whenever the upload selection changes.
    file_input.change(fn=process, inputs=file_input, outputs=status_output)

    question_input = gr.Textbox(label="Ask a revision question")
    answer_output = gr.Textbox(label="Answer", lines=5)

    # Enter in the question box runs retrieval + LLM answering; guarded so
    # asking before any upload returns a message instead of crashing on a
    # None index.
    question_input.submit(fn=lambda q: answer_question(q, state["index"], state["chunks"]) if state["index"] else "Please upload files first.", inputs=question_input, outputs=answer_output)

    quiz_btn = gr.Button("Quiz Me")
    quiz_output = gr.Textbox(label="Generated Quiz Questions", lines=6)

    # Same None-index guard for quiz generation.
    quiz_btn.click(fn=lambda: generate_quiz(state["index"], state["chunks"]) if state["index"] else "Please upload files first.", outputs=quiz_output)

# debug=True surfaces tracebacks in the UI; launches the local web server.
demo.launch(debug=True)