"""
FastAPI server for the Contextual Similarity Engine.

Endpoints:
  /api/train/*       — Train/adapt models (3 strategies)
  /api/init          — Load a model into the engine
  /api/documents     — Add documents to the corpus
  /api/index/build   — Build FAISS index
  /api/query         — Semantic search
  /api/compare       — Compare two texts
  /api/analyze/*     — Keyword analysis
  /api/match         — Keyword meaning matching
  /api/eval/*        — Evaluation metrics
  /api/w2v/*         — Word2Vec baseline comparison
  /api/dataset/*     — HuggingFace dataset loading (Epstein Files)
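
A sketch of the typical client flow (illustrative values; assumes a local
server on port 8000 and the `requests` package on the client side):

    import requests

    base = "http://localhost:8000"
    requests.post(f"{base}/api/init", json={"model_name": "all-MiniLM-L6-v2"})
    requests.post(f"{base}/api/documents",
                  json={"doc_id": "memo-1", "text": "..."})
    requests.post(f"{base}/api/index/build")
    hits = requests.post(f"{base}/api/query",
                         json={"text": "example query", "top_k": 5}).json()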
"""

import asyncio
import logging
import os
import shutil
import threading
import time
from collections import deque
from pathlib import Path
from typing import Literal, Optional

from fastapi import FastAPI, HTTPException, Query, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field

from contextual_similarity import ContextualSimilarityEngine
from evaluation import Evaluator, GroundTruthEntry
from training import CorpusTrainer
from word2vec_baseline import Word2VecEngine
from data_loader import load_raw_dataset, load_raw_to_engine, import_chromadb_to_engine, get_dataset_info

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# ------------------------------------------------------------------ #
#  Log streaming buffer
# ------------------------------------------------------------------ #

class LogBuffer(logging.Handler):
    """Thread-safe log handler that buffers recent messages for SSE streaming."""

    def __init__(self, max_lines: int = 500):
        super().__init__()
        self._buffer: deque[str] = deque(maxlen=max_lines)
        self._total = 0  # monotonic count of lines ever emitted
        self._lock = threading.Lock()
        self._event = threading.Event()

    def emit(self, record: logging.LogRecord) -> None:
        msg = self.format(record)
        with self._lock:
            self._buffer.append(msg)
            self._total += 1
        self._event.set()

    def get_new_lines(self, after: int) -> tuple[list[str], int]:
        """Return lines emitted after cursor `after`, and the new cursor.

        The cursor counts every line ever emitted rather than indexing into
        the buffer, so it stays valid after the bounded deque has discarded
        its oldest lines. (A positional cursor stalls once the buffer wraps,
        because len(buffer) stops growing.)
        """
        with self._lock:
            dropped = self._total - len(self._buffer)
            start = max(after - dropped, 0)
            return list(self._buffer)[start:], self._total


log_buffer = LogBuffer()
log_buffer.setFormatter(logging.Formatter("%(asctime)s %(name)s %(message)s", datefmt="%H:%M:%S"))
log_buffer.setLevel(logging.INFO)
# Attach to root logger so all modules' logs are captured
logging.getLogger().addHandler(log_buffer)


# ------------------------------------------------------------------ #
#  Security constants & validation helpers
# ------------------------------------------------------------------ #

ALLOWED_MODELS = frozenset({
    "all-MiniLM-L6-v2",
    "all-mpnet-base-v2",
    "BAAI/bge-large-en-v1.5",
})
ALLOWED_SOURCE_FILTERS = frozenset({"TEXT-", "IMAGES-"})
MAX_UPLOAD_BYTES = 10 * 1024 * 1024  # 10 MB
BASE_DIR = Path(__file__).parent.resolve()


def _validate_model_name(name: str) -> str:
    """Allow known HuggingFace models or local paths within project dir."""
    if name in ALLOWED_MODELS:
        return name
    # Treat as a local model path — must be within the project directory
    _validate_safe_path(name)
    return name


def _validate_safe_path(path_str: str) -> str:
    """Reject paths that resolve outside the project directory.

    Relative paths resolve against the process working directory, so this
    check assumes the server is launched from the project root.
    """
    resolved = Path(path_str).resolve()
    if not resolved.is_relative_to(BASE_DIR):
        raise HTTPException(400, "Path must be within the project directory.")
    return path_str
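
# e.g. (when the server is launched from the project root): "./trained_model"
# resolves inside BASE_DIR and passes; "../elsewhere" raises a 400.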


def _to_native(obj):
    """Recursively convert numpy types to native Python types for JSON serialization."""
    import numpy as np
    if isinstance(obj, dict):
        return {k: _to_native(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_native(v) for v in obj]
    if isinstance(obj, (np.integer,)):
        return int(obj)
    if isinstance(obj, (np.floating,)):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj
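
# e.g. _to_native({"score": np.float32(0.5), "ids": np.array([1, 2])})
# returns {"score": 0.5, "ids": [1, 2]}, which the JSON encoder accepts.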


# ------------------------------------------------------------------ #
#  App setup
# ------------------------------------------------------------------ #

app = FastAPI(
    title="Contextual Similarity API",
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    # allow_origins entries are matched exactly; *.hf.space subdomains are
    # covered by allow_origin_regex below (CORSMiddleware does not expand
    # wildcards inside allow_origins).
    allow_origins=["http://localhost:5173", "http://localhost:3000",
                   "https://huggingface.co"],
    allow_origin_regex=r"https://.*\.hf\.space",
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["Content-Type", "Authorization"],
)

# Global instances
engine: Optional[ContextualSimilarityEngine] = None
evaluator: Optional[Evaluator] = None
w2v_engine: Optional[Word2VecEngine] = None

ENGINE_SAVE_DIR = Path(os.environ.get("ENGINE_STATE_DIR", str(BASE_DIR / "engine_state")))
W2V_SAVE_DIR = Path(os.environ.get("W2V_STATE_DIR", str(BASE_DIR / "w2v_state")))


@app.on_event("startup")
def _auto_restore():
    """Restore engine and W2V state from disk if previous saves exist."""
    global engine, evaluator, w2v_engine
    if (ENGINE_SAVE_DIR / "meta.json").is_file():
        try:
            engine = ContextualSimilarityEngine.load(str(ENGINE_SAVE_DIR))
            if engine.index is not None:
                evaluator = Evaluator(engine)
            logger.info("Auto-restored engine: %d chunks, %d docs",
                        len(engine.chunks), len(engine._doc_ids))
        except Exception:
            logger.exception("Failed to auto-restore engine state — starting fresh")
    if Word2VecEngine.has_saved_state(str(W2V_SAVE_DIR)):
        try:
            w2v_engine = Word2VecEngine.load(str(W2V_SAVE_DIR))
            logger.info("Auto-restored Word2Vec: %d sentences, %d vocab",
                        len(w2v_engine.sentences), len(w2v_engine.model.wv))
        except Exception:
            logger.exception("Failed to auto-restore Word2Vec state — starting fresh")


@app.get("/api/logs/stream")
async def stream_logs():
    """SSE endpoint: streams server log lines in real-time."""
    async def event_generator():
        # Send a snapshot of recent history first, then poll for new lines.
        lines, cursor = log_buffer.get_new_lines(0)
        for line in lines[-20:]:  # last 20 lines on connect
            yield f"data: {line}\n\n"
        while True:
            await asyncio.sleep(0.5)
            lines, cursor = log_buffer.get_new_lines(cursor)
            for line in lines:
                yield f"data: {line}\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


@app.get("/api/logs/poll")
def poll_logs(cursor: int = Query(default=0, ge=0)):
    """Polling fallback for log streaming (works through HF Spaces proxy)."""
    lines, new_cursor = log_buffer.get_new_lines(cursor)
    return {"lines": lines, "cursor": new_cursor}


# ------------------------------------------------------------------ #
#  Request models (with input validation)
# ------------------------------------------------------------------ #

class TrainRequest(BaseModel):
    corpus_texts: list[str] = Field(max_length=10_000)
    base_model: str = "all-MiniLM-L6-v2"
    output_path: str = "./trained_model"
    epochs: int = Field(default=5, ge=1, le=50)
    batch_size: int = Field(default=16, ge=1, le=256)


class TrainKeywordsRequest(TrainRequest):
    keyword_meanings: dict[str, str]


class TrainEvalRequest(BaseModel):
    test_pairs: list[dict] = Field(max_length=1_000)
    trained_model_path: str = "./trained_model"
    base_model: str = "all-MiniLM-L6-v2"
    corpus_texts: list[str] = Field(default=[], max_length=10_000)


class InitRequest(BaseModel):
    model_name: str = "all-MiniLM-L6-v2"
    chunk_size: int = Field(default=512, ge=50, le=8192)
    chunk_overlap: int = Field(default=128, ge=0, le=4096)
    batch_size: int = Field(default=64, ge=1, le=512)


class DocumentRequest(BaseModel):
    doc_id: str = Field(max_length=200)
    text: str = Field(max_length=10_000_000)


class QueryRequest(BaseModel):
    text: str = Field(max_length=10_000)
    top_k: int = Field(default=10, ge=1, le=100)


class CompareRequest(BaseModel):
    text_a: str = Field(max_length=50_000)
    text_b: str = Field(max_length=50_000)


class KeywordAnalysisRequest(BaseModel):
    keyword: str = Field(max_length=200)
    top_k: int = Field(default=10, ge=1, le=100)
    cluster_threshold: float = Field(default=0.35, ge=0.01, le=1.0)


class BatchAnalysisRequest(BaseModel):
    keywords: list[str] = Field(max_length=50)
    top_k: int = Field(default=10, ge=1, le=100)
    cluster_threshold: float = Field(default=0.35, ge=0.01, le=1.0)
    compare_across: bool = True


class KeywordMatchRequest(BaseModel):
    keyword: str = Field(max_length=200)
    candidate_meanings: list[str] = Field(max_length=50)


class EvalDisambiguationRequest(BaseModel):
    ground_truth: list[dict] = Field(max_length=10_000)
    candidate_meanings: dict[str, list[str]]


class EvalRetrievalRequest(BaseModel):
    queries: list[dict] = Field(max_length=1_000)
    k_values: list[int] = Field(default=[1, 3, 5, 10], max_length=10)


class W2VInitRequest(BaseModel):
    corpus_texts: list[str] = Field(max_length=10_000)
    vector_size: int = Field(default=100, ge=50, le=500)
    window: int = Field(default=5, ge=1, le=20)
    epochs: int = Field(default=50, ge=1, le=200)


class W2VCompareRequest(BaseModel):
    text_a: str = Field(max_length=50_000)
    text_b: str = Field(max_length=50_000)


class W2VQueryRequest(BaseModel):
    text: str = Field(max_length=10_000)
    top_k: int = Field(default=10, ge=1, le=100)


class W2VWordRequest(BaseModel):
    word: str = Field(max_length=200)
    top_k: int = Field(default=10, ge=1, le=100)


class ContextAnalysisRequest(BaseModel):
    keyword: str = Field(max_length=200)
    cluster_threshold: float = Field(default=0.35, ge=0.01, le=1.0)
    top_words: int = Field(default=8, ge=1, le=30)


class DatasetLoadRequest(BaseModel):
    source: Literal["raw", "embeddings"] = "raw"
    max_docs: int = Field(default=500, ge=1, le=100_000)
    min_text_length: int = Field(default=100, ge=0, le=100_000)
    source_filter: Optional[str] = None
    build_index: bool = True


# ------------------------------------------------------------------ #
#  Training endpoints
# ------------------------------------------------------------------ #

def _run_training(req: TrainRequest, strategy: str, train_fn):
    """Common wrapper for all training endpoints: validate, log, time, train."""
    _validate_model_name(req.base_model)
    _validate_safe_path(req.output_path)
    logger.info("Training (%s): model=%s, corpus=%d texts, epochs=%d, batch=%d",
                strategy, req.base_model, len(req.corpus_texts), req.epochs, req.batch_size)
    t0 = time.time()
    trainer = CorpusTrainer(req.corpus_texts, req.base_model)
    result = train_fn(trainer)
    logger.info("Training (%s) complete in %.1fs → %s", strategy, time.time() - t0, req.output_path)
    return result


@app.post("/api/train/unsupervised")
def train_unsupervised(req: TrainRequest):
    """Soft-label domain adaptation. No labels needed."""
    return _run_training(req, "unsupervised",
                         lambda t: t.train_unsupervised(req.output_path, req.epochs, req.batch_size))


@app.post("/api/train/contrastive")
def train_contrastive(req: TrainRequest):
    """Contrastive: learns from corpus structure (adjacent sentences = similar)."""
    return _run_training(req, "contrastive",
                         lambda t: t.train_contrastive(req.output_path, req.epochs, req.batch_size))


@app.post("/api/train/keywords")
def train_keywords(req: TrainKeywordsRequest):
    """Keyword-supervised: provide keyword→meaning map, pairs auto-generated."""
    return _run_training(req, "keyword-supervised",
                         lambda t: t.train_with_keywords(req.keyword_meanings, req.output_path, req.epochs, req.batch_size))
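
# Illustrative request body for /api/train/keywords (example values only;
# see TrainKeywordsRequest for the schema and field limits):
#
#     {
#       "corpus_texts": ["first document text ...", "second document text ..."],
#       "keyword_meanings": {"bank": "financial institution",
#                            "charge": "criminal accusation"},
#       "base_model": "all-MiniLM-L6-v2",
#       "output_path": "./trained_model",
#       "epochs": 5,
#       "batch_size": 16
#     }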


@app.post("/api/train/evaluate")
def train_evaluate(req: TrainEvalRequest):
    """Compare base model vs trained model on test pairs."""
    _validate_model_name(req.base_model)
    _validate_model_name(req.trained_model_path)
    logger.info("Evaluating: base=%s vs trained=%s, %d test pairs",
                req.base_model, req.trained_model_path, len(req.test_pairs))
    corpus = req.corpus_texts or ["placeholder text for initialization."]
    trainer = CorpusTrainer(corpus, req.base_model)
    test_pairs = [
        (p["text_a"], p["text_b"], p.get("expected", p.get("score", 0.5)))
        for p in req.test_pairs
    ]
    result = trainer.evaluate_model(test_pairs, req.trained_model_path)
    logger.info("Evaluation complete: %d pairs evaluated", len(test_pairs))
    return result


# ------------------------------------------------------------------ #
#  Engine endpoints
# ------------------------------------------------------------------ #

@app.post("/api/init")
def init_engine(req: InitRequest):
    """Initialize the similarity engine with a model (pretrained or trained)."""
    _validate_model_name(req.model_name)
    if req.chunk_overlap >= req.chunk_size:
        raise HTTPException(400, "chunk_overlap must be less than chunk_size.")
    global engine, evaluator
    logger.info("Initializing engine: model=%s, chunk_size=%d, overlap=%d, batch=%d",
                req.model_name, req.chunk_size, req.chunk_overlap, req.batch_size)
    t0 = time.time()
    engine = ContextualSimilarityEngine(
        model_name=req.model_name,
        chunk_size=req.chunk_size,
        chunk_overlap=req.chunk_overlap,
        batch_size=req.batch_size,
    )
    evaluator = None
    elapsed = round(time.time() - t0, 2)
    logger.info("Engine initialized in %.2fs (model=%s)", elapsed, req.model_name)
    return {"status": "ok", "model": req.model_name, "load_time_seconds": elapsed}


@app.post("/api/documents")
def add_document(req: DocumentRequest):
    _ensure_engine()
    logger.info("Adding document: id=%s, text_length=%d", req.doc_id, len(req.text))
    try:
        chunks = engine.add_document(req.doc_id, req.text)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    logger.info("Document '%s' added: %d chunks", req.doc_id, len(chunks))
    return {
        "status": "ok", "doc_id": req.doc_id, "num_chunks": len(chunks),
        "chunks_preview": [{"index": c.chunk_index, "text": c.text[:150]} for c in chunks[:5]],
    }


@app.post("/api/documents/upload")
async def upload_document(file: UploadFile = File(...), doc_id: Optional[str] = Form(None)):
    _ensure_engine()
    contents = await file.read()
    if len(contents) > MAX_UPLOAD_BYTES:
        raise HTTPException(413, f"File too large. Maximum size is {MAX_UPLOAD_BYTES // (1024 * 1024)}MB.")
    try:
        text = contents.decode("utf-8")
    except UnicodeDecodeError:
        raise HTTPException(400, "File must be valid UTF-8 text.")
    d_id = doc_id or Path(file.filename or "upload").stem
    try:
        chunks = engine.add_document(d_id, text)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return {"status": "ok", "doc_id": d_id, "num_chunks": len(chunks)}


@app.post("/api/index/build")
def build_index():
    _ensure_engine()
    logger.info("Building FAISS index (%d documents in corpus)...", len(engine.corpus))
    t0 = time.time()
    try:
        engine.build_index(show_progress=True)
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e))
    global evaluator
    evaluator = Evaluator(engine)
    elapsed = round(time.time() - t0, 2)
    logger.info("FAISS index built: %d vectors (dim=%d) in %.2fs",
                engine.index.ntotal, engine.embedding_dim, elapsed)
    # Auto-save so data persists across restarts
    try:
        engine.save(str(ENGINE_SAVE_DIR))
    except Exception:
        logger.warning("Auto-save after index build failed", exc_info=True)
    return {
        "status": "ok", "total_chunks": engine.index.ntotal,
        "embedding_dim": engine.embedding_dim, "build_time_seconds": elapsed,
    }


@app.post("/api/query")
def query_similar(req: QueryRequest):
    _ensure_engine(); _ensure_index()
    logger.info("Query: text='%s...' top_k=%d", req.text[:80], req.top_k)
    results = engine.query(req.text, top_k=req.top_k)
    return {"query": req.text, "results": [
        {"rank": r.rank, "score": round(r.score, 4), "doc_id": r.chunk.doc_id,
         "chunk_index": r.chunk.chunk_index, "text": r.chunk.text}
        for r in results
    ]}


@app.post("/api/compare")
def compare_texts(req: CompareRequest):
    _ensure_engine()
    logger.info("Compare: text_a='%s...' vs text_b='%s...'", req.text_a[:60], req.text_b[:60])
    return {"text_a": req.text_a, "text_b": req.text_b, "similarity": round(engine.compare_texts(req.text_a, req.text_b), 4)}


@app.post("/api/analyze/keyword")
def analyze_keyword(req: KeywordAnalysisRequest):
    _ensure_engine(); _ensure_index()
    logger.info("Keyword analysis: keyword='%s', top_k=%d, threshold=%.2f",
                req.keyword, req.top_k, req.cluster_threshold)
    return _serialize_analysis(engine.analyze_keyword(req.keyword, req.top_k, req.cluster_threshold))


@app.post("/api/analyze/similar-words")
def analyze_similar_words(req: W2VWordRequest):
    """Find words that appear in similar contexts using transformer embeddings."""
    _ensure_engine(); _ensure_index()
    logger.info("Similar words (transformer): word='%s', top_k=%d", req.word, req.top_k)
    results = engine.similar_words(req.word, req.top_k)
    return {"word": req.word, "similar": results}


@app.post("/api/analyze/context")
def analyze_context(req: ContextAnalysisRequest):
    """Infer what a keyword likely means from its surrounding context words."""
    _ensure_engine(); _ensure_index()
    logger.info("Context analysis: keyword='%s', threshold=%.2f, top_words=%d",
                req.keyword, req.cluster_threshold, req.top_words)
    return engine.infer_keyword_meanings(
        req.keyword,
        top_words=req.top_words,
        cluster_threshold=req.cluster_threshold,
    )


@app.post("/api/analyze/batch")
def batch_analyze(req: BatchAnalysisRequest):
    _ensure_engine(); _ensure_index()
    logger.info("Batch analysis: %d keywords=%s, top_k=%d",
                len(req.keywords), req.keywords[:5], req.top_k)
    results = engine.batch_analyze_keywords(req.keywords, req.top_k, req.cluster_threshold, req.compare_across)
    return {kw: _serialize_analysis(a) for kw, a in results.items()}


@app.post("/api/match")
def match_keyword(req: KeywordMatchRequest):
    _ensure_engine(); _ensure_index()
    results = engine.match_keyword_to_meaning(req.keyword, req.candidate_meanings)
    return {"keyword": req.keyword, "candidate_meanings": req.candidate_meanings, "matches": [
        {"doc_id": r["chunk"].doc_id, "chunk_index": r["chunk"].chunk_index, "text": r["chunk"].text[:300],
         "best_match": r["best_match"], "best_score": round(r["best_score"], 4),
         "all_scores": {k: round(v, 4) for k, v in r["all_scores"].items()}}
        for r in results
    ]}


# ------------------------------------------------------------------ #
#  Evaluation endpoints
# ------------------------------------------------------------------ #

@app.post("/api/eval/disambiguation")
def evaluate_disambiguation(req: EvalDisambiguationRequest):
    _ensure_evaluator()
    gt = [GroundTruthEntry(keyword=e["keyword"], text=e["text"], true_meaning=e["true_meaning"]) for e in req.ground_truth]
    metrics = evaluator.evaluate_disambiguation(gt, req.candidate_meanings)
    return _to_native({"metrics": [
        {"keyword": m.keyword, "accuracy": m.accuracy, "weighted_f1": m.weighted_f1,
         "per_meaning_precision": m.per_meaning_precision, "per_meaning_recall": m.per_meaning_recall,
         "per_meaning_f1": m.per_meaning_f1, "confusion_matrix": m.confusion, "total_samples": m.total_samples}
        for m in metrics
    ]})
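
# Each ground_truth entry must carry the three GroundTruthEntry fields, e.g.
#     {"keyword": "bank", "text": "...a sentence using the word...",
#      "true_meaning": "financial institution"}
# and candidate_meanings maps each keyword to its possible senses, e.g.
#     {"bank": ["financial institution", "river bank"]}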


@app.post("/api/eval/retrieval")
def evaluate_retrieval(req: EvalRetrievalRequest):
    _ensure_evaluator()
    metrics = evaluator.evaluate_retrieval(req.queries, req.k_values)
    return _to_native({"metrics": [
        {"query": m.query, "mrr": round(float(m.mrr), 4),
         "precision_at_k": {str(k): round(float(v), 4) for k, v in m.precision_at_k.items()},
         "recall_at_k": {str(k): round(float(v), 4) for k, v in m.recall_at_k.items()},
         "ndcg_at_k": {str(k): round(float(v), 4) for k, v in m.ndcg_at_k.items()},
         "avg_similarity": round(float(m.avg_similarity), 4), "top_score": round(float(m.top_score), 4)}
        for m in metrics
    ]})


@app.get("/api/eval/similarity-distribution")
def similarity_distribution():
    _ensure_evaluator()
    return _to_native(evaluator.analyze_similarity_distribution())


@app.get("/api/eval/report")
def get_eval_report():
    _ensure_evaluator()
    return _to_native(evaluator.get_report().summary())


@app.get("/api/stats")
def get_stats():
    _ensure_engine()
    return engine.get_stats()


@app.get("/api/corpus/texts")
def get_corpus_texts(max_docs: int = Query(default=500, ge=1, le=10_000)):
    """Return loaded document texts grouped by doc_id (for use as training corpus)."""
    _ensure_engine()
    # Group chunks by doc_id
    docs: dict[str, list[str]] = {}
    for chunk in engine.chunks:
        docs.setdefault(chunk.doc_id, []).append(chunk.text)
    # Combine chunks per document
    result = []
    for doc_id in sorted(docs.keys()):
        if len(result) >= max_docs:
            break
        result.append({"doc_id": doc_id, "text": "\n".join(docs[doc_id])})
    return {"documents": result, "count": len(result)}


@app.get("/api/documents/{doc_id}")
def get_document(doc_id: str):
    """Return the full text of a document by reconstructing its chunks."""
    _ensure_engine()
    chunks = [c for c in engine.chunks if c.doc_id == doc_id]
    if not chunks:
        raise HTTPException(404, f"Document '{doc_id}' not found.")
    chunks.sort(key=lambda c: c.chunk_index)
    full_text = "\n".join(c.text for c in chunks)
    return {"doc_id": doc_id, "text": full_text, "num_chunks": len(chunks)}


@app.post("/api/engine/save")
def save_engine():
    """Save current engine state to disk for later restore."""
    _ensure_engine()
    result = engine.save(str(ENGINE_SAVE_DIR))
    return {"status": "ok", **result}


@app.post("/api/engine/load")
def load_engine_state():
    """Load a previously saved engine state from disk."""
    global engine, evaluator
    if not (ENGINE_SAVE_DIR / "meta.json").is_file():
        raise HTTPException(400, "No saved engine state found.")
    engine = ContextualSimilarityEngine.load(str(ENGINE_SAVE_DIR))
    evaluator = Evaluator(engine) if engine.index is not None else None
    return {"status": "ok", **engine.get_stats()}


@app.get("/api/engine/has-saved-state")
def has_saved_state():
    """Check if a saved engine state exists on disk."""
    exists = (ENGINE_SAVE_DIR / "meta.json").is_file()
    return {"exists": exists}


# ------------------------------------------------------------------ #
#  Word2Vec baseline endpoints
# ------------------------------------------------------------------ #

@app.post("/api/w2v/init")
def w2v_init(req: W2VInitRequest):
    """Train Word2Vec on corpus for comparison."""
    global w2v_engine
    logger.info("Word2Vec init: %d texts, vector_size=%d, window=%d, epochs=%d",
                len(req.corpus_texts), req.vector_size, req.window, req.epochs)
    t0 = time.time()
    w2v_engine = Word2VecEngine(vector_size=req.vector_size, window=req.window, epochs=req.epochs)
    for i, text in enumerate(req.corpus_texts):
        w2v_engine.add_document(f"doc_{i}", text)
    stats = w2v_engine.build_index()
    elapsed = round(time.time() - t0, 2)
    logger.info("Word2Vec ready: %s in %.2fs", stats, elapsed)
    # Auto-save so data persists across restarts
    try:
        w2v_engine.save(str(W2V_SAVE_DIR))
    except Exception:
        logger.warning("Auto-save W2V after init failed", exc_info=True)
    return {**stats, "seconds": elapsed}


@app.post("/api/w2v/init-from-engine")
def w2v_init_from_engine(
    vector_size: int = Query(default=100, ge=50, le=500),
    window: int = Query(default=5, ge=1, le=20),
    epochs: int = Query(default=50, ge=1, le=200),
):
    """Train Word2Vec directly from all documents already loaded in the engine.

    This avoids the round-trip through the frontend and uses ALL engine docs.
    """
    global w2v_engine
    _ensure_engine()
    if not engine.chunks:
        raise HTTPException(400, "No documents in the engine. Load a dataset first.")

    # Group chunks by doc_id to reconstruct full documents
    docs: dict[str, list[str]] = {}
    for chunk in engine.chunks:
        docs.setdefault(chunk.doc_id, []).append(chunk.text)

    logger.info("Word2Vec init from engine: %d documents, vector_size=%d, window=%d, epochs=%d",
                len(docs), vector_size, window, epochs)
    t0 = time.time()
    w2v_engine = Word2VecEngine(vector_size=vector_size, window=window, epochs=epochs)
    for doc_id, chunks_list in docs.items():
        w2v_engine.add_document(doc_id, "\n".join(chunks_list))
    stats = w2v_engine.build_index()
    elapsed = round(time.time() - t0, 2)
    logger.info("Word2Vec ready: %s in %.2fs", stats, elapsed)
    # Auto-save
    try:
        w2v_engine.save(str(W2V_SAVE_DIR))
    except Exception:
        logger.warning("Auto-save W2V after init failed", exc_info=True)
    return {**stats, "seconds": elapsed, "documents_used": len(docs)}


@app.post("/api/w2v/compare")
def w2v_compare(req: W2VCompareRequest):
    _ensure_w2v()
    return {"text_a": req.text_a, "text_b": req.text_b,
            "similarity": round(w2v_engine.compare_texts(req.text_a, req.text_b), 4)}


@app.post("/api/w2v/query")
def w2v_query(req: W2VQueryRequest):
    _ensure_w2v()
    results = w2v_engine.query(req.text, top_k=req.top_k)
    return {"query": req.text, "results": [
        {"rank": r.rank, "score": round(r.score, 4), "doc_id": r.doc_id, "text": r.text}
        for r in results
    ]}


@app.post("/api/w2v/similar-words")
def w2v_similar_words(req: W2VWordRequest):
    _ensure_w2v()
    similar = w2v_engine.most_similar_words(req.word, req.top_k)
    return {"word": req.word, "similar": [{"word": w, "score": round(s, 4)} for w, s in similar]}


@app.get("/api/w2v/status")
def w2v_status():
    """Check if Word2Vec is loaded (from training or restored from disk)."""
    if w2v_engine is not None and w2v_engine.model is not None:
        return {
            "ready": True,
            "vocab_size": len(w2v_engine.model.wv),
            "sentences": len(w2v_engine.sentences),
            "vector_size": w2v_engine.vector_size,
        }
    has_saved = Word2VecEngine.has_saved_state(str(W2V_SAVE_DIR))
    return {"ready": False, "has_saved_state": has_saved}


@app.post("/api/w2v/reset")
def w2v_reset():
    """Delete saved Word2Vec state and clear the in-memory model."""
    global w2v_engine
    w2v_engine = None
    if W2V_SAVE_DIR.is_dir():
        shutil.rmtree(str(W2V_SAVE_DIR))
        logger.info("Word2Vec state deleted from %s", W2V_SAVE_DIR)
    return {"status": "ok", "message": "Word2Vec state cleared. You can retrain now."}


# ------------------------------------------------------------------ #
#  Dataset endpoints (HuggingFace Epstein Files)
# ------------------------------------------------------------------ #

@app.get("/api/dataset/info")
def dataset_info():
    """Get metadata about available HuggingFace datasets."""
    return get_dataset_info()


@app.post("/api/dataset/load")
def dataset_load(req: DatasetLoadRequest):
    """Load Epstein Files dataset from HuggingFace into the engine."""
    global engine, evaluator
    if engine is None:
        logger.info("Engine not initialized — auto-initializing with default model...")
        engine = ContextualSimilarityEngine(
            model_name="all-MiniLM-L6-v2",
            chunk_size=512,
            chunk_overlap=128,
            batch_size=64,
        )
        evaluator = None
        logger.info("Engine auto-initialized with all-MiniLM-L6-v2")
    if req.source_filter and req.source_filter not in ALLOWED_SOURCE_FILTERS:
        raise HTTPException(400, f"source_filter must be one of: {sorted(ALLOWED_SOURCE_FILTERS)}")
    logger.info("Dataset load: source=%s, max_docs=%d, min_text=%d, filter=%s, build_index=%s",
                req.source, req.max_docs, req.min_text_length, req.source_filter, req.build_index)
    t0 = time.time()
    try:
        if req.source == "embeddings":
            result = import_chromadb_to_engine(engine, max_chunks=req.max_docs * 10)
        else:
            result = load_raw_to_engine(
                engine,
                max_docs=req.max_docs,
                min_text_length=req.min_text_length,
                source_filter=req.source_filter,
                build_index=req.build_index,
            )
        logger.info("Dataset loaded in %.1fs", time.time() - t0)
        # Auto-save so data persists across restarts
        try:
            engine.save(str(ENGINE_SAVE_DIR))
        except Exception:
            logger.warning("Auto-save after dataset load failed", exc_info=True)
        return result
    except Exception:
        logger.exception("Dataset load failed")
        raise HTTPException(500, "Dataset load failed. Check server logs for details.")


@app.post("/api/dataset/preview")
def dataset_preview(
    max_docs: int = Query(default=10, ge=1, le=100),
    min_text_length: int = Query(default=100, ge=0, le=100_000),
    source_filter: Optional[str] = Query(default=None),
):
    """Preview a few documents from the raw dataset without loading into engine."""
    if source_filter and source_filter not in ALLOWED_SOURCE_FILTERS:
        raise HTTPException(400, f"source_filter must be one of: {sorted(ALLOWED_SOURCE_FILTERS)}")
    try:
        docs = load_raw_dataset(
            max_docs=max_docs,
            min_text_length=min_text_length,
            source_filter=source_filter,
        )
        return {
            "count": len(docs),
            "documents": [
                {"doc_id": d["doc_id"], "filename": d["filename"],
                 "text_preview": d["text"][:500], "text_length": len(d["text"])}
                for d in docs
            ],
        }
    except Exception:
        logger.exception("Dataset preview failed")
        raise HTTPException(500, "Dataset preview failed. Check server logs for details.")


# ------------------------------------------------------------------ #
#  Helpers
# ------------------------------------------------------------------ #

def _ensure_engine():
    if engine is None:
        raise HTTPException(400, "Engine not initialized. POST /api/init first.")

def _ensure_index():
    if engine.index is None:
        raise HTTPException(400, "Index not built. POST /api/index/build first.")

def _ensure_evaluator():
    global evaluator
    if evaluator is None:
        _ensure_engine(); _ensure_index()
        evaluator = Evaluator(engine)

def _ensure_w2v():
    if w2v_engine is None:
        raise HTTPException(400, "Word2Vec not initialized. POST /api/w2v/init first.")

def _serialize_analysis(analysis):
    return {
        "keyword": analysis.keyword,
        "total_occurrences": analysis.total_occurrences,
        "meaning_clusters": [{
            "cluster_id": c["cluster_id"], "size": c["size"],
            "representative_text": c["representative_text"],
            "contexts": [{"doc_id": ctx.chunk.doc_id, "chunk_index": ctx.chunk.chunk_index,
                          "text": ctx.chunk.text[:300], "highlight_positions": ctx.highlight_positions}
                         for ctx in c["contexts"]],
            "similar_passages": [{"rank": s.rank, "score": round(s.score, 4),
                                  "doc_id": s.chunk.doc_id, "text": s.chunk.text[:200]}
                                 for s in c["similar_passages"]],
        } for c in analysis.meaning_clusters],
        "cross_keyword_similarities": {k: round(v, 4) for k, v in analysis.cross_keyword_similarities.items()},
    }

# ------------------------------------------------------------------ #
#  Static frontend (production build served from /frontend/dist)
# ------------------------------------------------------------------ #

_FRONTEND_DIR = BASE_DIR / "frontend" / "dist"

@app.get("/api/debug/frontend")
async def debug_frontend_files():
    """List all files in the frontend dist directory (for debugging deploys)."""
    if not _FRONTEND_DIR.is_dir():
        return {"exists": False, "path": str(_FRONTEND_DIR)}
    files = []
    for root, _dirs, filenames in os.walk(str(_FRONTEND_DIR)):
        for f in filenames:
            full = os.path.join(root, f)
            rel = os.path.relpath(full, str(_FRONTEND_DIR))
            files.append({"path": rel, "size": os.path.getsize(full)})
    return {"exists": True, "path": str(_FRONTEND_DIR), "files": files}


if _FRONTEND_DIR.is_dir():
    @app.get("/{full_path:path}")
    async def serve_frontend(full_path: str):
        """Serve the React SPA — static files or index.html fallback."""
        if full_path:
            file_path = (_FRONTEND_DIR / full_path).resolve()
            if file_path.is_file() and file_path.is_relative_to(_FRONTEND_DIR):
                return FileResponse(file_path)
        return FileResponse(_FRONTEND_DIR / "index.html")

    logger.info("Frontend serving enabled from %s", _FRONTEND_DIR)


if __name__ == "__main__":
    import uvicorn
    host = os.environ.get("HOST", "127.0.0.1")
    port = int(os.environ.get("PORT", "8000"))
    has_frontend = _FRONTEND_DIR.is_dir()
    logger.info("=" * 60)
    logger.info("Contextual Similarity API starting")
    logger.info("  Server:   http://%s:%d", host, port)
    if has_frontend:
        logger.info("  Frontend: http://%s:%d (built-in)", host, port)
    else:
        logger.info("  Frontend: http://localhost:5173 (dev server)")
    logger.info("  API Docs: http://%s:%d/docs", host, port)
    logger.info("=" * 60)
    uvicorn.run(app, host=host, port=port)