File size: 2,330 Bytes
1505bbf
 
 
 
 
bf4791f
1505bbf
 
 
 
 
 
 
 
 
 
 
bf4791f
 
 
1505bbf
 
 
 
 
 
 
 
 
bf4791f
 
 
 
1505bbf
 
 
 
 
 
 
bf4791f
 
 
 
 
 
1505bbf
 
bf4791f
 
 
1505bbf
 
 
 
 
 
 
 
bf4791f
 
 
 
1505bbf
bf4791f
1505bbf
 
bf4791f
 
 
1505bbf
bf4791f
 
 
 
 
 
1505bbf
 
bf4791f
 
 
1505bbf
bf4791f
1505bbf
bf4791f
1505bbf
bf4791f
 
 
 
1505bbf
bf4791f
 
 
1505bbf
bf4791f
 
1505bbf
bf4791f
 
 
1505bbf
bf4791f
1505bbf
bf4791f
 
1505bbf
 
bf4791f
 
1505bbf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from fastapi import FastAPI, UploadFile, Form, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates

from ingestion import extract_pdf
from chunking import chunk_text
from retrieval_colbert import ColBERTRetriever
from reranker import rerank
from llm import generate_answer
from scraper import scrape_url

# FastAPI application; the route handlers below register on it via decorators.
app = FastAPI()

# Jinja2 templates loaded from a relative "templates" directory
# (NOTE(review): presumably resolved against the process working directory — confirm).
templates = Jinja2Templates(directory="templates")

# Single process-wide retriever shared by every request: /upload and /scrape
# rebuild its index, and /chat queries it.
retriever = ColBERTRetriever()

# conversation memory
# Process-wide list of {"role": ..., "content": ...} dicts appended to by /chat.
# There is no per-client/session separation: every caller shares this history.
chat_memory = []


@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Render the landing page from the ``index.html`` template."""
    # Jinja2Templates requires the incoming request in the template context.
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)


# -----------------------
# PDF Upload
# -----------------------

@app.post("/upload")
async def upload(file: UploadFile):
    """Extract text from an uploaded PDF, chunk it, and add it to the shared index.

    The chunks are labelled with the uploaded file's name. The retriever index
    is rebuilt over any previously indexed chunks plus the new ones.

    Returns:
        ``{"status": "PDF indexed"}`` on success, or a status explaining that
        nothing was indexed when the PDF produced no chunks (presumably no
        extractable text, e.g. a scanned document). In that case the existing
        index is left untouched.
    """
    text = extract_pdf(file.file)
    chunks = chunk_text(text, file.filename)

    # Nothing to add: skip the (expensive) index rebuild and don't claim success.
    if not chunks:
        return {"status": "No text could be extracted from the PDF"}

    # build_index appears to rebuild from the full list it is given (the code
    # passes old + new chunks), so existing chunks are carried over explicitly.
    if retriever.chunks:
        retriever.build_index(retriever.chunks + chunks)
    else:
        retriever.build_index(chunks)

    return {"status": "PDF indexed"}


# -----------------------
# Website Scraper
# -----------------------

@app.post("/scrape")
async def scrape(url: str = Form(...)):
    """Scrape ``url``, chunk the resulting text, and add it to the shared index.

    The chunks are labelled with the URL. The retriever index is rebuilt over
    any previously indexed chunks plus the new ones.

    Returns:
        ``{"status": "Website indexed"}`` on success, or a status explaining
        that nothing was indexed when the scrape produced no chunks. In that
        case the existing index is left untouched.
    """
    text = scrape_url(url)
    chunks = chunk_text(text, url)

    # Nothing to add: skip the (expensive) index rebuild and don't claim success.
    if not chunks:
        return {"status": "No text could be extracted from the URL"}

    # build_index appears to rebuild from the full list it is given (the code
    # passes old + new chunks), so existing chunks are carried over explicitly.
    if retriever.chunks:
        retriever.build_index(retriever.chunks + chunks)
    else:
        retriever.build_index(chunks)

    return {"status": "Website indexed"}


# -----------------------
# Chat endpoint
# -----------------------

@app.post("/chat")
async def chat(message: str = Form(...)):
    """Answer ``message`` using retrieved document context and recent chat history.

    Retrieves up to 25 candidate chunks, reranks them, and builds a prompt from
    the top 3 chunks plus the most recent conversation messages. That prompt
    and the raw message are then passed to ``generate_answer``.

    Returns:
        ``{"answer": ..., "chunks": [...]}`` with the chunks used as context,
        or ``{"answer": "Please index a document first."}`` when retrieval
        returns nothing.
    """
    # Only this many most-recent messages (user + assistant entries) are ever
    # fed back into the prompt, so nothing older needs to be retained.
    history_window = 4

    retrieved = retriever.query(message, k=25)
    if not retrieved:
        return {"answer": "Please index a document first."}

    reranked = rerank(message, retrieved)
    top_chunks = reranked[:3]
    context = "\n\n".join(c["text"] for c in top_chunks)

    # build conversation history
    # NOTE(review): chat_memory is module-global, so this history is shared by
    # every client of the server, not kept per user/session.
    history = "".join(
        f"{m['role']}: {m['content']}\n" for m in chat_memory[-history_window:]
    )

    prompt_context = f"""
Conversation history:
{history}

Context:
{context}

User question:
{message}
"""

    answer = generate_answer(prompt_context, message)

    chat_memory.append({"role": "user", "content": message})
    chat_memory.append({"role": "assistant", "content": answer})
    # Trim in place (no rebinding, so no `global` needed). Without this, the
    # list grew for the whole lifetime of the server process.
    del chat_memory[:-history_window]

    return {
        "answer": answer,
        "chunks": top_chunks,
    }