 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
from typing import List
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import asyncio
import tempfile
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.chatmodel import ChatOpenAI

from app import (
    RetrievalAugmentedQAPipeline,
    process_file,
    system_role_prompt,
    user_role_prompt,
)

# FastAPI application instance; routes below are registered against it.
app = FastAPI()

# Update CORS middleware configuration
# NOTE(review): the "*" entry makes the explicit localhost origins redundant,
# and a wildcard origin combined with allow_credentials=True is unsafe outside
# development (it lets any site send credentialed cross-origin requests) —
# tighten to an explicit origin list before production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3001",    # Development React server
        "http://localhost:7860",    # Production nginx server
        "http://localhost",         # Just in case
        "*",                        # Allow all origins in development
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)

class ChatResponse(BaseModel):
    """Response body for /api/chat: the full generated answer plus the
    retrieved context that grounded it."""
    # Complete answer text (the token stream is drained server-side in chat()).
    response: str
    # Context entries returned by the pipeline; presumably (text, score)
    # pairs from the vector search — TODO confirm against arun_pipeline.
    context: List[tuple]

class ChatRequest(BaseModel):
    """Request body for /api/chat: the user's free-text question."""
    # Natural-language query to answer against the uploaded document.
    query: str

@app.post("/api/upload", response_model=dict)
async def upload_file(file: UploadFile = File(...)):
    """Accept an uploaded document and build a RAG pipeline over it.

    The upload is written to a temporary file, chunked via ``process_file``,
    embedded into an in-memory :class:`VectorDatabase`, and wrapped in a
    :class:`RetrievalAugmentedQAPipeline` stored on the app instance.

    Returns:
        dict with ``pipeline_id`` (key for /api/chat) and a status message.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    # Preserve the original extension so downstream format detection works.
    # splitext yields "" for extensionless names and tolerates filename=None
    # (the original split('.')[-1] mangled both cases).
    suffix = os.path.splitext(file.filename or "")[1]
    temp_path = None
    try:
        # Write the upload and close the handle before processing: unlinking
        # or re-opening a still-open NamedTemporaryFile fails on Windows.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_path = temp_file.name
            temp_file.write(await file.read())

        # Chunk the document using the existing helper.
        texts = process_file(temp_path, file.filename)

        # Embed the chunks into an in-memory vector store.
        vector_db = VectorDatabase()
        vector_db = await vector_db.abuild_from_list(texts)

        # Assemble the retrieval-augmented QA pipeline.
        pipeline = RetrievalAugmentedQAPipeline(
            vector_db_retriever=vector_db,
            llm=ChatOpenAI(),
        )

        # Store the pipeline in process memory (Note: not production-ready —
        # lost on restart and not shared across workers).
        if not hasattr(app, 'pipelines'):
            app.pipelines = {}
        pipeline_id = str(len(app.pipelines))
        app.pipelines[pipeline_id] = pipeline

        return {"pipeline_id": pipeline_id, "message": "File processed successfully"}

    except HTTPException:
        # Let deliberate HTTP errors propagate with their original status.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always remove the temp file — the original only unlinked on the
        # success path, leaking the file whenever processing raised.
        if temp_path and os.path.exists(temp_path):
            os.unlink(temp_path)

@app.post("/api/chat/{pipeline_id}", response_model=ChatResponse)
async def chat(pipeline_id: str, request: ChatRequest):
    """Answer a query against a previously uploaded document.

    Looks up the pipeline created by /api/upload, runs the query through it,
    and drains the streamed answer into a single string.

    Raises:
        HTTPException: 404 if the pipeline id is unknown; 500 on any
            failure while running the pipeline.
    """
    # Validate outside the try block: in the original, this 404 was caught
    # by `except Exception` below and re-raised as a 500, so clients never
    # actually received the intended 404.
    if not hasattr(app, 'pipelines') or pipeline_id not in app.pipelines:
        raise HTTPException(status_code=404, detail="Pipeline not found. Please upload a file first.")

    try:
        pipeline = app.pipelines[pipeline_id]
        result = await pipeline.arun_pipeline(request.query)

        # Drain the streaming response into one string.
        chunks = [chunk async for chunk in result["response"]]
        response_text = "".join(chunks)

        return ChatResponse(
            response=response_text,
            context=result["context"]
        )

    except HTTPException:
        # Preserve status codes of deliberate HTTP errors.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    # Development entry point: run an embedded uvicorn server on all
    # interfaces. Production deployments should launch uvicorn/gunicorn
    # externally against `app` instead.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)