File size: 5,434 Bytes
9a000fe
bc8a612
e769917
bc8a612
9a000fe
bc8a612
 
 
 
9a000fe
 
 
 
 
 
6b9a057
 
bc8a612
6fca0b0
61e6651
 
 
 
 
bc8a612
 
e769917
 
 
bc8a612
61e6651
bc8a612
e769917
bc8a612
9a000fe
e769917
 
61e6651
d1af85f
61e6651
6fca0b0
61e6651
 
9a000fe
6fca0b0
 
 
 
 
 
9a000fe
6fca0b0
bc8a612
d1af85f
bc8a612
 
9a000fe
d1af85f
e769917
6b9a057
bc8a612
6fca0b0
 
 
 
 
 
 
 
 
 
 
bc8a612
61e6651
bc8a612
6fca0b0
 
61e6651
bc8a612
6fca0b0
 
bc8a612
 
 
 
6fca0b0
 
bc8a612
 
e769917
6fca0b0
 
61e6651
e769917
6b9a057
bc8a612
61e6651
6fca0b0
 
 
 
 
 
 
 
 
 
 
61e6651
9a000fe
6fca0b0
 
61e6651
9a000fe
 
6fca0b0
 
61e6651
9a000fe
bc8a612
 
 
 
 
9a000fe
bc8a612
9a000fe
bc8a612
 
 
 
9a000fe
 
61e6651
9a000fe
6fca0b0
 
61e6651
9a000fe
bc8a612
9a000fe
bc8a612
61e6651
 
 
9a000fe
 
 
61e6651
 
 
9a000fe
61e6651
9a000fe
 
 
61e6651
 
 
9a000fe
 
6fca0b0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# app.py

import io
import logging
import os
import uuid

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Parsing and AI libraries
import fitz
from PIL import Image
import pytesseract
from sentence_transformers import SentenceTransformer
from ctransformers import AutoModelForCausalLM  # ✅ FIXED import

# Import from our core modules
from core.chunking import semantic_chunker
from core.vector_store import create_faiss_index, deserialize_faiss_index

# --- Tesseract OCR binary location ---
# pytesseract shells out to the `tesseract` executable. Point it at an explicit
# path: the TESSERACT_CMD environment variable wins when set (useful on macOS,
# Windows, or custom installs); otherwise fall back to the usual Linux location.
pytesseract.pytesseract.tesseract_cmd = os.environ.get("TESSERACT_CMD", r"/usr/bin/tesseract")
# ------------------------------------

# --- 1. INITIAL SETUP & MODEL LOADING ---

# Configure root logging at INFO once, then obtain this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Optimized Universal Data AI",
    version="3.1.0",
)

# NOTE(review): wildcard origins combined with allow_credentials=True is a risky
# CORS configuration (Starlette's docs advise against wildcards when credentials
# are allowed). Nothing visible in this file uses cookies or auth, so consider
# allow_credentials=False or an explicit origin list before exposing publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Load Optimized Models ---
# Each model is loaded in its own try block so that one failing (e.g. a failed
# download) does not also disable the other: /upload only needs the embedding
# model, while /query needs both. A model that fails to load stays None and the
# endpoints that need it answer 503.
embedding_model = None
llm = None

logger.info("Loading optimized AI models...")

try:
    # Using a smaller, but still powerful, BGE model
    embedding_model = SentenceTransformer('BAAI/bge-base-en-v1.5')
except Exception:
    # Deliberate degraded mode at import time: record the full traceback and
    # keep the service up rather than crashing the process.
    logger.critical("Fatal error: Could not load embedding model.", exc_info=True)

try:
    # Load TinyLlama in GGUF format using ctransformers
    llm = AutoModelForCausalLM.from_pretrained(
        "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
        model_file="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf",
        model_type="llama",   # Tell ctransformers the model family
        gpu_layers=0,         # For CPU-only environment
    )
except Exception:
    logger.critical("Fatal error: Could not load LLM.", exc_info=True)

if embedding_model is not None and llm is not None:
    logger.info("AI models loaded successfully.")

# In-memory session store, keyed by the uuid4 string issued by /upload. Each
# value is {"chunks": <text chunks from semantic_chunker>,
#           "index": <serialized index from create_faiss_index>}.
# NOTE(review): nothing in this file removes entries, so memory grows with every
# upload. The dict is also per-process: sessions are lost on restart and are not
# shared across multiple worker processes.
SESSION_DATA = {}

# --- 2. DATA MODELS ---
class QueryRequest(BaseModel):
    """Request body for POST /query/{session_id}."""

    question: str  # Natural-language question to answer from the session's document.

class UploadResponse(BaseModel):
    """Response body for POST /upload."""

    session_id: str  # ID to pass to /query/{session_id} for this document.
    filename: str  # Name of the uploaded file as reported by the client.
    chunks_created: int  # Number of text chunks indexed for retrieval.

class QueryResponse(BaseModel):
    """Response body for POST /query/{session_id}."""

    answer: str  # Text generated by the LLM, whitespace-stripped.
    context: str  # Retrieved chunks (newline-joined) that were given to the LLM.

# --- 3. HELPER FUNCTIONS ---
def parse_pdf(content: bytes) -> str:
    """Extract the text of every page of an in-memory PDF.

    Args:
        content: Raw bytes of a PDF file.

    Returns:
        The text of all pages, concatenated in page order with no extra
        separator. May be empty if the PDF contains no extractable text.
    """
    # The context manager closes the MuPDF document (and frees its native
    # resources) even if text extraction raises part-way through.
    with fitz.open(stream=content, filetype="pdf") as doc:
        return "".join(page.get_text() for page in doc)

def parse_image(content: bytes) -> str:
    """OCR an in-memory image with Tesseract and return the recognised text.

    Args:
        content: Raw bytes of an image file in any format PIL can open.

    Returns:
        The text produced by ``pytesseract.image_to_string``.
    """
    # Close the PIL image once OCR is done so its decoder and buffers are
    # released promptly instead of waiting for garbage collection.
    with Image.open(io.BytesIO(content)) as image:
        return pytesseract.image_to_string(image)

# --- 4. API ENDPOINTS ---

@app.get("/")
def read_root():
    """Return a fixed status payload; doubles as a simple liveness check."""
    payload = {
        "status": "ok",
        "message": "Welcome to the Optimized Universal Data AI",
    }
    return payload

@app.post("/upload", response_model=UploadResponse)
async def upload_file(file: UploadFile = File(...)):
    """Ingest an uploaded document and open a new query session for it.

    Text is extracted according to the upload's type: PDFs with PyMuPDF, images
    with Tesseract OCR, and ``.txt`` / ``.md`` files (matched case-insensitively)
    decoded as UTF-8. The text is split with ``semantic_chunker``, embedded, and
    indexed. The chunks and serialized index are stored in ``SESSION_DATA``
    under a fresh UUID.

    Raises:
        HTTPException: 503 if the embedding model is unavailable; 400 for an
            unsupported type, a text file that is not valid UTF-8, or a document
            that yields no text or no chunks; 500 if the index cannot be built.
    """
    if not embedding_model:
        raise HTTPException(status_code=503, detail="Embedding model not available.")

    content = await file.read()
    content_type = file.content_type
    # UploadFile.filename may be None; normalise it so the suffix check cannot
    # crash and the response still satisfies UploadResponse.filename: str.
    filename = file.filename or ""

    if content_type == "application/pdf":
        text = parse_pdf(content)
    elif content_type and content_type.startswith("image/"):
        text = parse_image(content)
    elif filename.lower().endswith(('.txt', '.md')):
        try:
            # "utf-8-sig" decodes plain UTF-8 identically but also drops a
            # leading byte-order mark, which would otherwise pollute chunk 1.
            text = content.decode("utf-8-sig")
        except UnicodeDecodeError:
            # Bad client input: report 400 instead of an unhandled 500.
            raise HTTPException(status_code=400, detail="Text file is not valid UTF-8.") from None
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported file type: {content_type}")

    if not text.strip():
        raise HTTPException(status_code=400, detail="No text could be extracted.")

    text_chunks = semantic_chunker(text, embedding_model)
    if not text_chunks:
        raise HTTPException(status_code=400, detail="Document too short to be processed.")

    embeddings = embedding_model.encode(text_chunks, convert_to_numpy=True)
    serialized_index = create_faiss_index(embeddings)
    if not serialized_index:
        raise HTTPException(status_code=500, detail="Failed to create document index.")

    # Only issue a session ID once the document has been fully processed.
    session_id = str(uuid.uuid4())
    SESSION_DATA[session_id] = {"chunks": text_chunks, "index": serialized_index}
    logger.info("Session %s created with %d chunks.", session_id, len(text_chunks))
    return {"session_id": session_id, "filename": filename, "chunks_created": len(text_chunks)}

def _build_tinyllama_prompt(context: str, question: str) -> str:
    """Build a single-turn RAG prompt in TinyLlama-1.1B-Chat-v1.0's chat format.

    TinyLlama-Chat v1.0 uses the Zephyr template: ``<|user|>`` and
    ``<|assistant|>`` turns, each user turn terminated by ``</s>``. This replaces
    the ChatML ``<|im_start|>`` markers, which belong to earlier TinyLlama chat
    releases rather than the v1.0 model loaded by this service.
    """
    return (
        "<|user|>\n"
        "Use the following context to answer the question.\n"
        "Context:\n"
        f"{context}\n"
        "\n"
        f"Question: {question}</s>\n"
        "<|assistant|>\n"
    )


# Stop generation at end-of-sequence, or if the model starts a new user turn.
_TINYLLAMA_STOP = ["</s>", "<|user|>"]


@app.post("/query/{session_id}", response_model=QueryResponse)
async def query_session(session_id: str, request: QueryRequest):
    """Answer a question about a previously uploaded document.

    Embeds the question, retrieves up to 5 nearest chunks from the session's
    index, and asks the LLM to answer using those chunks as context.

    Raises:
        HTTPException: 503 if a model is unavailable; 404 for an unknown
            session; 400 for an empty question; 500 if the session index
            cannot be loaded.
    """
    if not llm or not embedding_model:
        raise HTTPException(status_code=503, detail="AI models are not available.")

    session = SESSION_DATA.get(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found.")

    question = request.question.strip()
    if not question:
        raise HTTPException(status_code=400, detail="Question must not be empty.")

    # BGE v1.5 retrieval convention: prefix queries (not passages) with this instruction.
    query_with_prefix = f"Represent this sentence for searching relevant passages: {question}"
    question_embedding = embedding_model.encode([query_with_prefix], convert_to_numpy=True).astype('float32')

    index = deserialize_faiss_index(session["index"])
    if not index:
        raise HTTPException(status_code=500, detail="Could not load session index.")

    chunks = session["chunks"]
    k = min(5, index.ntotal)
    if k > 0:
        _, indices = index.search(question_embedding, k)
        # FAISS pads missing results with -1. Without this filter, Python's
        # negative indexing would silently pull in the *last* chunk.
        context = "\n".join(chunks[i] for i in indices[0] if i >= 0)
    else:
        context = ""

    prompt = _build_tinyllama_prompt(context, question)

    # NOTE(review): encoding, search and generation are synchronous calls inside
    # an async handler, so a long generation blocks the event loop for all
    # requests. Offloading needs care if the ctransformers model is not safe
    # for concurrent calls.
    logger.info("Generating answer with TinyLlama...")
    answer = llm(
        prompt,
        max_new_tokens=256,
        temperature=0.3,
        stop=_TINYLLAMA_STOP,
    )

    return {"answer": answer.strip(), "context": context}