redhairedshanks1 committed on
Commit
d1fb649
·
verified ·
1 Parent(s): 331cdab

Upload 2 files

Browse files
Files changed (2) hide show
  1. services/extract_text.py +122 -0
  2. services/vector_store.py +76 -0
services/extract_text.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import fitz # PyMuPDF
4
+ import numpy as np
5
+ from PIL import Image
6
+ import tempfile
7
+ import cv2
8
+ import re
9
+
10
+ # OCR
11
+ from paddleocr import PaddleOCR
12
+
13
# Optional fallback OCR engine.
# NOTE: despite the "mistral" names, the predictor loaded here comes from the
# docTR package (`doctr`); the names are kept because the rest of the module
# refers to them.
try:
    from doctr.models import ocr_predictor
    from doctr.io import DocumentFile
    mistral_ocr = ocr_predictor(pretrained=True)
    use_mistral_ocr = True
except ImportError:
    # docTR is not installed: run without the fallback engine.
    mistral_ocr = None
    use_mistral_ocr = False
except Exception as e:
    # docTR is installed but the predictor could not be built. The fallback is
    # optional, so disable it instead of aborting the import of this module.
    logging.getLogger(__name__).warning("Optional docTR OCR unavailable: %s", e)
    mistral_ocr = None
    use_mistral_ocr = False

# Ensure OCR environment paths (only applied when the variables are unset)
os.environ.setdefault("HOME", "/app")
os.environ.setdefault("PADDLEOCR_HOME", "/app/.paddleocr")

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load PaddleOCR (primary OCR engine, English, with angle classification)
ocr = PaddleOCR(use_angle_cls=True, lang='en')
33
+
34
def clean_text(text):
    """Collapse each run of whitespace in *text* to one space and trim both ends."""
    collapsed = re.sub(r'\s+', ' ', text)
    return collapsed.strip()
36
+
37
def auto_rotate_image(pil_img):
    """Deskew a page image.

    The skew is estimated from the minimum-area rectangle around the "ink"
    pixels and the page is rotated about its centre to cancel it.

    Args:
        pil_img: PIL image of the page (any mode; it is converted to RGB).

    Returns:
        A PIL RGB image holding the grayscale page, rotated when a skew was
        detected and otherwise unrotated.
    """
    gray = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2GRAY)

    # Binarise with Otsu and invert so dark text on a light page is foreground.
    # A plain `gray > 0` mask would select the page background and make the
    # fitted rectangle the whole image.
    _, ink_mask = cv2.threshold(
        gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU
    )
    # float32 points: some OpenCV builds reject the int64 output of np.where.
    ink_coords = np.column_stack(np.where(ink_mask > 0)).astype(np.float32)

    angle = 0.0
    if ink_coords.size:
        raw_angle = cv2.minAreaRect(ink_coords)[-1]
        # A rectangle's orientation is only defined modulo 90 degrees, and the
        # range reported by minAreaRect differs between OpenCV releases. Fold
        # it into (-45, 45] so the correction is the same on all of them.
        skew = raw_angle % 90.0
        if skew > 45.0:
            skew -= 90.0
        angle = -skew

    if angle == 0:
        # Nothing to correct (or no ink found): skip the resampling step.
        return Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))

    (h, w) = gray.shape[:2]
    M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1.0)
    rotated = cv2.warpAffine(
        gray, M, (w, h), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_REPLICATE
    )
    return Image.fromarray(cv2.cvtColor(rotated, cv2.COLOR_GRAY2RGB))
46
+
47
def extract_images_with_fitz(pdf_path):
    """Render every page of a PDF to a PIL RGB image at 2x zoom.

    Args:
        pdf_path: Path of the PDF file to open with PyMuPDF.

    Returns:
        A list with one PIL image per page, in page order.
    """
    zoom = fitz.Matrix(2, 2)  # built once; 2x scale on both axes
    images = []
    # The context manager closes the document even if rendering a page raises.
    with fitz.open(pdf_path) as doc:
        for page in doc:
            pix = page.get_pixmap(matrix=zoom)
            images.append(
                Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
            )
    return images
56
+
57
def _ocr_pdf_page(page, page_num):
    """Rasterise one PDF page and OCR it; return the raw text ('' if none found)."""
    try:
        pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
        img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
        img_np = np.array(auto_rotate_image(img))

        ocr_result = ocr.ocr(img_np, cls=True)
        # PaddleOCR may return [None] (or an empty list) when nothing is found.
        lines = ocr_result[0] if ocr_result else None
        ocr_text = "\n".join(line[1][0] for line in lines) if lines else ""

        if not ocr_text and use_mistral_ocr:
            # The docTR predictor takes a list of page arrays; a PIL image is
            # not a valid input for DocumentFile.from_images.
            ocr_text = mistral_ocr([img_np]).render()
    except Exception as e:
        # Best effort: one bad page must not abort the whole document.
        logger.warning("OCR error on page %s: %s", page_num, e)
        ocr_text = "[OCR Error]"
    return ocr_text


def _extract_pdf_text(path, start_page, end_page):
    """Return the text of the selected 1-based page range of a PDF, OCRing image-only pages."""
    result = []
    # The context manager closes the document even if a page raises.
    with fitz.open(path) as doc:
        total_pages = len(doc)
        start = max(start_page or 1, 1)
        end = min(end_page or total_pages, total_pages)

        for page_num in range(start, end + 1):
            page = doc[page_num - 1]
            text = page.get_text()
            if text.strip():
                result.append(f"Page {page_num} (Extracted):\n{clean_text(text)}")
            else:
                # Only pages without a text layer are rasterised and OCRed.
                ocr_text = _ocr_pdf_page(page, page_num)
                result.append(
                    f"Page {page_num} (OCR):\n{clean_text(ocr_text) or '[No OCR Text]'}"
                )
    return "\n\n".join(result)


def _extract_docx_text(path, start_page, end_page):
    """Return the text of a DOCX file, optionally limited to a range of pseudo-pages."""
    from docx.api import Document

    doc = Document(path)
    paras = [p.text for p in doc.paragraphs if p.text.strip()]

    # This code groups paragraphs into "pages" of 500 paragraphs each; the
    # grouping is its own approximation, not read from the document.
    page_size = 500
    page_texts = [
        "\n".join(paras[i:i + page_size])
        for i in range(0, len(paras), page_size)
    ]

    # Each bound is optional and applied independently, like the PDF branch.
    if start_page or end_page:
        first = max(start_page or 1, 1)
        last = max(min(end_page or len(page_texts), len(page_texts)), 0)
        page_texts = page_texts[first - 1:last]
    return clean_text("\n\n".join(page_texts))


def extract_text_from_file(file, start_page=None, end_page=None, filename=None):
    """Extract text from an uploaded PDF, DOCX, CSV or Excel file.

    Args:
        file: Object whose ``name`` attribute is the path of the file on disk.
        start_page: First page to include (1-based). ``None`` means the first
            page. Used for PDF and DOCX only.
        end_page: Last page to include (1-based, inclusive). ``None`` means
            the last page. Used for PDF and DOCX only.
        filename: Name used to detect the file type from its extension. When
            omitted, ``file.name`` is used instead.

    Returns:
        The extracted text, or the string "Unsupported file type" when the
        extension is not handled.
    """
    source_name = filename or str(getattr(file, "name", "") or "")
    ext = os.path.splitext(source_name)[-1].lower()

    if ext == ".pdf":
        return _extract_pdf_text(file.name, start_page, end_page)

    if ext == ".docx":
        return _extract_docx_text(file.name, start_page, end_page)

    if ext == ".csv":
        import pandas as pd
        return pd.read_csv(file.name).to_string(index=False)

    if ext in (".xls", ".xlsx"):
        import pandas as pd
        # Close the workbook handle once every sheet has been rendered.
        with pd.ExcelFile(file.name) as xl:
            return "\n\n".join(
                f"Sheet: {s}\n{xl.parse(s).to_string(index=False)}"
                for s in xl.sheet_names
            )

    return "Unsupported file type"
services/vector_store.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # vector_store.py
2
+
3
+ import os
4
+ import hashlib
5
+ from qdrant_client import QdrantClient
6
+ from qdrant_client.http.models import PointStruct, VectorParams, Distance
7
+ from sentence_transformers import SentenceTransformer
8
+
9
# === CONFIGURATION ===
# All settings are read from the environment; only the collection name has a
# built-in fallback value.
COLLECTION_NAME = os.environ.get("QDRANT_COLLECTION", "document_store")
QDRANT_URL = os.environ.get("QDRANT_URL")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")

# === Initialize Qdrant client ===
client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)

# === Load embedding model ===
model = SentenceTransformer("all-MiniLM-L6-v2")  # 384-dimensional vector
22
+
23
+ # === Ensure collection exists in Qdrant ===
24
def init_collection():
    """Create the Qdrant collection when it is not already present.

    The collection is created with 384-dimensional vectors and cosine
    distance. ``create_collection`` is used instead of ``recreate_collection``
    so this function can never drop an existing collection: if another
    process creates it between the check and the create call, the call fails
    instead of wiping stored points.
    """
    existing_names = {col.name for col in client.get_collections().collections}
    if COLLECTION_NAME in existing_names:
        return

    client.create_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=384, distance=Distance.COSINE),
    )
31
+
32
+ # Call once on import to verify/initialize collection
33
+ init_collection()
34
+
35
+ # === Generate a consistent ID based on filename (MD5 hash → int) ===
36
def compute_id(filename):
    """Map *filename* to a stable integer ID.

    The ID is the first 8 bytes (16 hex digits) of the MD5 digest of the
    encoded filename, read as a big-endian unsigned integer.
    """
    digest = hashlib.md5(filename.encode()).digest()
    return int.from_bytes(digest[:8], "big")
38
+
39
+ # === Retrieve entry by filename ===
40
def get_entry(filename):
    """Return the payload stored for *filename*, or ``None`` when no point is found."""
    matches = client.retrieve(
        collection_name=COLLECTION_NAME,
        ids=[compute_id(filename)],
    )
    if not matches:
        return None
    return matches[0].payload
44
+
45
+ # === Add or update an entry (filename → vector + payload) ===
46
def upsert_entry(filename, **fields):
    """Insert or update the point stored for *filename*.

    The stored payload (if any) is merged with ``fields``; new values that are
    ``None`` are ignored so they never erase stored data. The merged ``text``
    field is embedded with the sentence-transformer model and written together
    with the payload under the ID derived from *filename*.

    Args:
        filename: Name of the file; determines the point ID and is always
            stored in the payload under ``"filename"``.
        **fields: Payload fields to add or overwrite.
    """
    init_collection()

    point_id = compute_id(filename)
    existing = get_entry(filename) or {}

    # Merge old and new fields; prefer non-null new values. "filename" is
    # written last so a stale value in the stored payload cannot override
    # the argument.
    updates = {k: v for k, v in fields.items() if v is not None}
    payload = {**existing, **updates, "filename": filename}

    # Extract text field and encode into a vector. A missing or null text is
    # embedded as the empty string rather than the literal "None".
    raw_text = payload.get("text")
    base_text = "" if raw_text is None else str(raw_text)

    try:
        vector = model.encode(base_text, normalize_embeddings=True).tolist()
    except Exception as e:
        # Best effort: still store the payload, with a placeholder vector.
        # NOTE(review): an all-zero vector has no direction under cosine
        # distance - confirm this fallback is acceptable for search.
        print(f"❌ Vector encoding failed for {filename}: {e}")
        vector = [0.0] * 384

    client.upsert(
        collection_name=COLLECTION_NAME,
        points=[PointStruct(id=point_id, vector=vector, payload=payload)],
    )
76
+