aniketkumar1106 commited on
Commit
4db3151
·
verified ·
1 Parent(s): 84ae78a

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +41 -11
server.py CHANGED
@@ -7,6 +7,7 @@ import unicodedata
7
  import threading
8
  import anyio
9
  from fastapi import FastAPI, File, UploadFile, Form, HTTPException
 
10
  from fastapi.middleware.cors import CORSMiddleware
11
  from fastapi.staticfiles import StaticFiles
12
  from huggingface_hub import snapshot_download
@@ -17,7 +18,7 @@ logger = logging.getLogger(__name__)
17
 
18
  DATASET_REPO = "aniketkumar1106/orbit-data"
19
  IMAGE_DIR = "Productimages"
20
- DB_TARGET = "orbiitt.db"
21
  MIN_CONFIDENCE_THRESHOLD = 0.1
22
 
23
  # BOOTSTRAP: Mandatory folder creation
@@ -48,6 +49,7 @@ def background_sync():
48
  global engine, loading_status
49
  token = os.environ.get("HF_TOKEN")
50
 
 
51
  for f in os.listdir(IMAGE_DIR):
52
  p = os.path.join(IMAGE_DIR, f)
53
  try:
@@ -60,18 +62,33 @@ def background_sync():
60
  snapshot_download(repo_id=DATASET_REPO, repo_type="dataset", token=token, local_dir=".")
61
 
62
  if os.path.exists("orbiitt_db.zip"):
 
63
  with zipfile.ZipFile("orbiitt_db.zip", 'r') as z:
64
  z.extractall("temp_extract")
65
 
66
- for root, _, files in os.walk("temp_extract"):
 
 
 
 
 
 
 
 
 
 
 
67
  for f in files:
68
- src = os.path.join(root, f)
69
  if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
 
70
  clean_name = normalize_filename(f)
71
  shutil.copy(src, os.path.join(IMAGE_DIR, clean_name))
72
- elif "orbiitt" in f.lower() and f.endswith(".db"):
73
- shutil.copy(src, f"./{DB_TARGET}")
74
  shutil.rmtree("temp_extract")
 
 
 
 
75
 
76
  # LOGGING FILE COUNT FOR VALIDATION
77
  final_count = len(os.listdir(IMAGE_DIR))
@@ -80,21 +97,31 @@ def background_sync():
80
  loading_status = "Loading AI Engine..."
81
  try:
82
  from orbiitt_engine import OrbiittEngine
83
- engine = OrbiittEngine()
 
84
  loading_status = "Ready"
85
  logger.info(">>> ENGINE ONLINE <<<")
86
  except Exception as e:
87
  loading_status = f"Engine Error: {str(e)}"
 
 
88
  except Exception as e:
89
  loading_status = f"Sync Error: {str(e)}"
 
90
 
91
  @app.on_event("startup")
92
  async def startup_event():
93
  thread = threading.Thread(target=background_sync, daemon=True)
94
  thread.start()
95
 
 
96
  app.mount("/Productimages", StaticFiles(directory=IMAGE_DIR), name="Productimages")
97
 
 
 
 
 
 
98
  @app.get("/health")
99
  def health():
100
  return {"status": loading_status, "ready": engine is not None}
@@ -117,13 +144,13 @@ async def search(text: str = Form(None), weight: float = Form(0.5), file: Upload
117
  async with await anyio.open_file(t_path, "wb") as f:
118
  await f.write(content)
119
 
120
- # THE FIX: Use lambda to ensure keyword arguments are passed correctly
121
  results = await anyio.to_thread.run_sync(
122
  lambda: engine.search(
123
  text_query=text,
124
  image_file=t_path,
125
- text_weight=actual_weight,
126
- top_k=50
127
  )
128
  )
129
 
@@ -138,13 +165,15 @@ async def search(text: str = Form(None), weight: float = Form(0.5), file: Upload
138
  if score < MIN_CONFIDENCE_THRESHOLD or pid in seen_ids:
139
  continue
140
 
141
- raw_path = r.get('url') or r.get('path') or ""
142
- fname = normalize_filename(os.path.basename(raw_path))
 
143
 
144
  match = None
145
  if fname in all_files:
146
  match = fname
147
  else:
 
148
  for disk_f in all_files:
149
  if fname[:15].lower() in disk_f.lower():
150
  match = disk_f
@@ -153,6 +182,7 @@ async def search(text: str = Form(None), weight: float = Form(0.5), file: Upload
153
  if match:
154
  final_list.append({
155
  "id": pid,
 
156
  "url": f"Productimages/{urllib.parse.quote(match)}",
157
  "score": round(float(score), 4)
158
  })
 
7
  import threading
8
  import anyio
9
  from fastapi import FastAPI, File, UploadFile, Form, HTTPException
10
+ from fastapi.responses import FileResponse
11
  from fastapi.middleware.cors import CORSMiddleware
12
  from fastapi.staticfiles import StaticFiles
13
  from huggingface_hub import snapshot_download
 
18
 
19
  DATASET_REPO = "aniketkumar1106/orbit-data"
20
  IMAGE_DIR = "Productimages"
21
+ DB_TARGET_FOLDER = "orbiitt_db" # The folder ChromaDB expects
22
  MIN_CONFIDENCE_THRESHOLD = 0.1
23
 
24
  # BOOTSTRAP: Mandatory folder creation
 
49
  global engine, loading_status
50
  token = os.environ.get("HF_TOKEN")
51
 
52
+ # Cleanup old images
53
  for f in os.listdir(IMAGE_DIR):
54
  p = os.path.join(IMAGE_DIR, f)
55
  try:
 
62
  snapshot_download(repo_id=DATASET_REPO, repo_type="dataset", token=token, local_dir=".")
63
 
64
  if os.path.exists("orbiitt_db.zip"):
65
+ loading_status = "Extracting Database..."
66
  with zipfile.ZipFile("orbiitt_db.zip", 'r') as z:
67
  z.extractall("temp_extract")
68
 
69
+ # Smart Extraction: Find the ChromaDB folder logic
70
+ db_found = False
71
+ for root, dirs, files in os.walk("temp_extract"):
72
+ # Identifying ChromaDB by its signature file
73
+ if "chroma.sqlite3" in files:
74
+ if os.path.exists(DB_TARGET_FOLDER):
75
+ shutil.rmtree(DB_TARGET_FOLDER)
76
+ # Move the directory containing the sqlite3 file to our target location
77
+ shutil.move(root, DB_TARGET_FOLDER)
78
+ db_found = True
79
+
80
+ # Move images
81
  for f in files:
 
82
  if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
83
+ src = os.path.join(root, f)
84
  clean_name = normalize_filename(f)
85
  shutil.copy(src, os.path.join(IMAGE_DIR, clean_name))
86
+
 
87
  shutil.rmtree("temp_extract")
88
+
89
+ # If no DB folder was moved (maybe zip structure was flat?), ensure folder exists
90
+ if not db_found and not os.path.exists(DB_TARGET_FOLDER):
91
+ os.makedirs(DB_TARGET_FOLDER, exist_ok=True)
92
 
93
  # LOGGING FILE COUNT FOR VALIDATION
94
  final_count = len(os.listdir(IMAGE_DIR))
 
97
  loading_status = "Loading AI Engine..."
98
  try:
99
  from orbiitt_engine import OrbiittEngine
100
+ # Initialize with the CORRECT folder path
101
+ engine = OrbiittEngine(db_path=f"./{DB_TARGET_FOLDER}")
102
  loading_status = "Ready"
103
  logger.info(">>> ENGINE ONLINE <<<")
104
  except Exception as e:
105
  loading_status = f"Engine Error: {str(e)}"
106
+ logger.error(f"Engine Failed: {e}")
107
+
108
  except Exception as e:
109
  loading_status = f"Sync Error: {str(e)}"
110
+ logger.error(f"Sync Failed: {e}")
111
 
112
  @app.on_event("startup")
113
  async def startup_event():
114
  thread = threading.Thread(target=background_sync, daemon=True)
115
  thread.start()
116
 
117
+ # Mount Static Images
118
  app.mount("/Productimages", StaticFiles(directory=IMAGE_DIR), name="Productimages")
119
 
120
+ # Serve UI (Root Endpoint)
121
+ @app.get("/")
122
+ async def read_index():
123
+ return FileResponse('index.html')
124
+
125
  @app.get("/health")
126
  def health():
127
  return {"status": loading_status, "ready": engine is not None}
 
144
  async with await anyio.open_file(t_path, "wb") as f:
145
  await f.write(content)
146
 
147
+ # CORRECTED: Calling engine with top_k
148
  results = await anyio.to_thread.run_sync(
149
  lambda: engine.search(
150
  text_query=text,
151
  image_file=t_path,
152
+ text_weight=actual_weight,
153
+ top_k=50
154
  )
155
  )
156
 
 
165
  if score < MIN_CONFIDENCE_THRESHOLD or pid in seen_ids:
166
  continue
167
 
168
+ # The engine returns a clean ID/path, let's match it to disk
169
+ fname_from_db = os.path.basename(r.get('id', ''))
170
+ fname = normalize_filename(fname_from_db)
171
 
172
  match = None
173
  if fname in all_files:
174
  match = fname
175
  else:
176
+ # Fuzzy fallback if exact match fails
177
  for disk_f in all_files:
178
  if fname[:15].lower() in disk_f.lower():
179
  match = disk_f
 
182
  if match:
183
  final_list.append({
184
  "id": pid,
185
+ # Ensure URL is properly encoded for web
186
  "url": f"Productimages/{urllib.parse.quote(match)}",
187
  "score": round(float(score), 4)
188
  })