Kalpokoch commited on
Commit
d1af85f
·
1 Parent(s): 6b9a057

sentence transformer added

Browse files
Files changed (2) hide show
  1. app.py +65 -64
  2. requirements.txt +4 -1
app.py CHANGED
@@ -3,98 +3,99 @@ from fastapi.middleware.cors import CORSMiddleware
3
  import uuid
4
  import logging
5
  import io
6
-
7
- # NEW: Import parsing libraries
8
- import fitz # PyMuPDF
9
  from PIL import Image
10
  import pytesseract
 
 
 
 
 
11
 
12
- # Configure logging
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
-
16
  app = FastAPI()
17
-
18
- # CORS Middleware
19
  app.add_middleware(
20
- CORSMiddleware,
21
- allow_origins=["*"],
22
- allow_credentials=True,
23
- allow_methods=["*"],
24
- allow_headers=["*"],
25
  )
26
 
 
 
 
 
 
 
 
 
 
 
 
27
  # In-memory session store
28
  SESSION_DATA = {}
29
- logger.info("Session store initialized.")
30
-
31
- # --- NEW: Parsing Functions ---
32
 
33
- def parse_pdf(content: bytes) -> str:
34
- """Extracts text from a PDF file's bytes."""
35
- try:
36
- doc = fitz.open(stream=content, filetype="pdf")
37
- text = ""
38
- for page in doc:
39
- text += page.get_text()
40
- logger.info(f"Successfully parsed PDF, extracted {len(text)} characters.")
41
- return text
42
- except Exception as e:
43
- logger.error(f"PDF parsing failed: {e}")
44
- return ""
45
 
46
- def parse_image(content: bytes) -> str:
47
- """Extracts text from an image file's bytes using OCR."""
48
- try:
49
- image = Image.open(io.BytesIO(content))
50
- text = pytesseract.image_to_string(image)
51
- logger.info(f"Successfully parsed image, extracted {len(text)} characters.")
52
- return text
53
- except Exception as e:
54
- logger.error(f"Image parsing failed: {e}")
55
- return ""
 
56
 
57
  # --- MODIFIED: The /upload Endpoint ---
58
-
59
  @app.post("/upload")
60
  async def upload_file(file: UploadFile = File(...)):
61
- """
62
- Accepts a file, detects its type, parses it, and stores the extracted text.
63
- """
64
  session_id = str(uuid.uuid4())
65
  logger.info(f"New upload '{file.filename}'. Creating session_id: {session_id}")
66
-
67
  content = await file.read()
 
 
68
  extracted_text = ""
 
 
 
 
69
 
70
- # Simple dispatcher based on file's content type
71
- if file.content_type == "application/pdf":
72
- extracted_text = parse_pdf(content)
73
- elif file.content_type and file.content_type.startswith("image/"):
74
- extracted_text = parse_image(content)
75
- elif file.content_type == "text/plain":
76
- extracted_text = content.decode("utf-8")
77
- else:
78
- raise HTTPException(status_code=400, detail=f"Unsupported file type: {file.content_type}")
79
-
80
- if not extracted_text:
81
  raise HTTPException(status_code=400, detail="Could not extract any text from the file.")
82
 
83
- # Store the EXTRACTED TEXT, not the raw file content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  SESSION_DATA[session_id] = {
85
  "filename": file.filename,
86
- "text": extracted_text
 
87
  }
88
 
89
  return {
90
  "session_id": session_id,
91
  "filename": file.filename,
92
- "chars_extracted": len(extracted_text)
93
- }
94
-
95
- # This endpoint is useful for debugging
96
- @app.get("/session/{session_id}/text")
97
- def get_session_text(session_id: str):
98
- if session_id not in SESSION_DATA:
99
- raise HTTPException(status_code=404, detail="Session not found.")
100
- return {"text": SESSION_DATA[session_id].get("text", "")}
 
3
  import uuid
4
  import logging
5
  import io
6
+ import fitz
 
 
7
  from PIL import Image
8
  import pytesseract
9
+ import numpy as np
10
+
11
+ # NEW: Import AI and search libraries
12
+ from sentence_transformers import SentenceTransformer
13
+ import faiss
14
 
15
+ # --- Basic Setup (Logging, FastAPI, CORS) ---
16
  logging.basicConfig(level=logging.INFO)
17
  logger = logging.getLogger(__name__)
 
18
  app = FastAPI()
 
 
19
  app.add_middleware(
20
+ CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
 
 
 
 
21
  )
22
 
23
+ # --- AI MODEL LOADING ---
24
+ # This happens only once when the app starts.
25
+ # 'all-MiniLM-L6-v2' is a great, lightweight model for CPU.
26
+ try:
27
+ logger.info("Loading sentence-transformer model...")
28
+ model = SentenceTransformer('all-MiniLM-L6-v2')
29
+ logger.info("Model loaded successfully.")
30
+ except Exception as e:
31
+ logger.error(f"Failed to load sentence-transformer model: {e}")
32
+ model = None
33
+
34
  # In-memory session store
35
  SESSION_DATA = {}
 
 
 
36
 
37
+ # --- Parsing Functions (parse_pdf, parse_image - keep these as they are) ---
38
+ def parse_pdf(content: bytes) -> str: # ... your existing function ...
39
+ def parse_image(content: bytes) -> str: # ... your existing function ...
 
 
 
 
 
 
 
 
 
40
 
41
+ # --- NEW: Helper function for chunking text ---
42
+ def chunk_text(text: str, chunk_size: int = 256, overlap: int = 32) -> list[str]:
43
+ """Splits text into overlapping chunks of words."""
44
+ words = text.split()
45
+ if not words:
46
+ return []
47
+ chunks = []
48
+ for i in range(0, len(words), chunk_size - overlap):
49
+ chunk = " ".join(words[i:i + chunk_size])
50
+ chunks.append(chunk)
51
+ return chunks
52
 
53
  # --- MODIFIED: The /upload Endpoint ---
 
54
  @app.post("/upload")
55
  async def upload_file(file: UploadFile = File(...)):
56
+ if not model:
57
+ raise HTTPException(status_code=503, detail="AI model is not available.")
58
+
59
  session_id = str(uuid.uuid4())
60
  logger.info(f"New upload '{file.filename}'. Creating session_id: {session_id}")
 
61
  content = await file.read()
62
+
63
+ # 1. PARSE (This part is the same as before)
64
  extracted_text = ""
65
+ if file.content_type == "application/pdf": extracted_text = parse_pdf(content)
66
+ elif file.content_type and file.content_type.startswith("image/"): extracted_text = parse_image(content)
67
+ elif file.content_type == "text/plain": extracted_text = content.decode("utf-8")
68
+ else: raise HTTPException(status_code=400, detail=f"Unsupported file type: {file.content_type}")
69
 
70
+ if not extracted_text.strip():
 
 
 
 
 
 
 
 
 
 
71
  raise HTTPException(status_code=400, detail="Could not extract any text from the file.")
72
 
73
+ # 2. CHUNK
74
+ text_chunks = chunk_text(extracted_text)
75
+ logger.info(f"Text chunked into {len(text_chunks)} pieces.")
76
+ if not text_chunks:
77
+ raise HTTPException(status_code=400, detail="Document is empty or too short to be chunked.")
78
+
79
+ # 3. EMBED
80
+ logger.info("Generating embeddings for text chunks...")
81
+ embeddings = model.encode(text_chunks, convert_to_numpy=True)
82
+ logger.info(f"Embeddings generated with shape: {embeddings.shape}")
83
+
84
+ # 4. INDEX
85
+ d = embeddings.shape[1] # Dimension of embeddings
86
+ index = faiss.IndexFlatL2(d)
87
+ index.add(embeddings.astype('float32')) # FAISS requires float32
88
+ logger.info(f"FAISS index created with {index.ntotal} vectors.")
89
+
90
+ # Store the index AND the original text chunks in the session
91
  SESSION_DATA[session_id] = {
92
  "filename": file.filename,
93
+ "chunks": text_chunks,
94
+ "index": index.serialize() # Serialize the index for storage
95
  }
96
 
97
  return {
98
  "session_id": session_id,
99
  "filename": file.filename,
100
+ "chunks_created": len(text_chunks)
101
+ }
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -4,4 +4,7 @@ python-multipart
4
 
5
  PyMuPDF
6
  Pillow
7
- pytesseract
 
 
 
 
4
 
5
  PyMuPDF
6
  Pillow
7
+ pytesseract
8
+
9
+ sentence-transformers
10
+ faiss-cpu