Spaces: Runtime error
```python
# ── DIAGNOSTICS & SHIM (must come before any BERTopic import) ─────────────
import pkgutil, sentence_transformers, bertopic, sys, json, os, uuid

# 1) Print versions & model list
print("ST version:", sentence_transformers.__version__)
print("BERTopic version:", bertopic.__version__)
models = [m.name for m in pkgutil.iter_modules(sentence_transformers.models.__path__)]
print("ST models:", models)
sys.stdout.flush()

# 2) If StaticEmbedding is missing, alias Transformer → StaticEmbedding
if "StaticEmbedding" not in models:
    from sentence_transformers.models import Transformer
    import sentence_transformers.models as _st_mod
    setattr(_st_mod, "StaticEmbedding", Transformer)
    print("🔧 Shim applied: StaticEmbedding → Transformer")
    sys.stdout.flush()
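# Caveat on the shim above: the alias only satisfies imports such as
# `from sentence_transformers.models import StaticEmbedding` on
# sentence-transformers releases that predate StaticEmbedding. Transformer
# does not share StaticEmbedding's constructor signature, so any code that
# actually instantiates StaticEmbedding would still misbehave; pinning a
# sentence-transformers version that ships StaticEmbedding is the cleaner fix.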
# ──────────────────────────────────────────────────────────────────────────────

# ── REST OF APP.PY ───────────────────────────────────────────────────────────
from typing import List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from bertopic import BERTopic
from sentence_transformers import SentenceTransformer
from umap import UMAP
from hdbscan import HDBSCAN
from sklearn.feature_extraction.text import CountVectorizer
from stop_words import get_stop_words
# 0) Quick env dump (first ten variables only, to keep the log short)
print("ENV-snapshot:", json.dumps({k: os.environ[k] for k in list(os.environ)[:10]}))
sys.stdout.flush()

# 1) Tidy numba cache (UMAP's numba JIT wants a writable cache dir;
# /tmp is a safe choice on Spaces)
os.environ.setdefault("NUMBA_CACHE_DIR", "/tmp/numba_cache")
os.makedirs(os.environ["NUMBA_CACHE_DIR"], exist_ok=True)
os.environ["NUMBA_DISABLE_CACHE"] = "1"
# 2) Config from ENV
# Read the embedding model name from the environment; the default is a Czech
# SimCSE model. (HF repo IDs are case-sensitive, so the name is kept as-is.)
MODEL_NAME = os.getenv("EMBED_MODEL", "Seznam/simcse-small-e-czech")
MIN_TOPIC = int(os.getenv("MIN_TOPIC_SIZE", "10"))
MAX_DOCS = int(os.getenv("MAX_DOCS", "5000"))
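# NOTE: MIN_TOPIC is read here but currently unused; the debugging
# configuration further down hard-codes min_topic_size=2 when constructing
# BERTopic.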
# 3) Point the HF cache envs at a writeable folder (once, at startup)
cache_dir = "/tmp/hfcache"
os.makedirs(cache_dir, exist_ok=True)
import stat
os.chmod(cache_dir, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
os.environ["HF_HOME"] = cache_dir
os.environ["TRANSFORMERS_CACHE"] = cache_dir
os.environ["SENTENCE_TRANSFORMERS_HOME"] = cache_dir
# 4) Initialise embeddings once
# Use huggingface_hub to snapshot-download the model locally
from huggingface_hub import snapshot_download

print(f"Downloading model {MODEL_NAME} to {cache_dir}...")
sys.stdout.flush()
local_model_path = snapshot_download(repo_id=MODEL_NAME, cache_dir=cache_dir)

# Load SentenceTransformer from the local path
embeddings = SentenceTransformer(local_model_path, cache_folder=cache_dir)
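# Optional smoke test (illustrative, not in the original flow): encode()
# returns an array of shape (n_sentences, embedding_dim), so a broken model
# surfaces here rather than on the first request.
_probe = embeddings.encode(["Testovací věta."])
print("Embedding dim:", _probe.shape[-1])
sys.stdout.flush()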
# Pre-initialize fallback global models for small-batch debugging

# Global UMAP: 2 neighbors, cosine space, random init
global_umap = UMAP(
    n_neighbors=2,
    metric="cosine",
    init="random",
    random_state=42
)

# Global HDBSCAN: min cluster size 2, min_samples 1.
# NOTE: hdbscan has no fast tree support for metric="cosine" and cannot build
# prediction data for it, which calculate_probabilities=True relies on;
# euclidean distance over the UMAP output is the usual BERTopic pairing, so
# it is used here instead.
global_hdbscan = HDBSCAN(
    min_cluster_size=2,
    min_samples=1,
    metric="euclidean",
    cluster_selection_method="eom",
    prediction_data=True
)

# Global Czech vectorizer: stopwords + bigrams
global_vectorizer = CountVectorizer(
    stop_words=get_stop_words("czech"),
    ngram_range=(1, 2)
)
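# These values are deliberately tiny so that a handful of sentences can still
# form clusters; for production-sized inputs, values closer to the defaults
# (UMAP n_neighbors=15, BERTopic min_topic_size=10) would be more appropriate.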
# 5) FastAPI schemas and app
class Sentence(BaseModel):
    text: str
    start: float
    end: float
    speaker: str | None = None
    chunk_index: int | None = None

class Segment(BaseModel):
    topic_id: int
    label: str | None
    keywords: List[str]
    start: float
    end: float
    probability: float | None
    sentences: List[int]

class SegmentationResponse(BaseModel):
    run_id: str
    segments: List[Segment]
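# Illustrative request body for POST /segment (a JSON array of Sentence
# objects; the values are made up):
# [
#   {"text": "Dobrý den, vítejte u podcastu.", "start": 0.0, "end": 2.4,
#    "speaker": "A", "chunk_index": 0},
#   {"text": "Dnes probereme novinky.", "start": 2.4, "end": 4.1,
#    "speaker": "A", "chunk_index": 1}
# ]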
app = FastAPI(title="CZ Topic Segmenter", version="1.0")

@app.get("/")
async def root():
    return {"message": "CZ Topic Segmenter is running."}
@app.post("/segment", response_model=SegmentationResponse)
def segment(sentences: List[Sentence]):
    print(f"[segment] Received {len(sentences)} sentences, chunk_indices={[s.chunk_index for s in sentences]}")
    sys.stdout.flush()
    if len(sentences) > MAX_DOCS:
        raise HTTPException(413, f"Too many sentences ({len(sentences)} > {MAX_DOCS})")

    # Restore original transcript order by chunk_index before clustering
    sorted_sent = sorted(
        sentences,
        key=lambda s: s.chunk_index if s.chunk_index is not None else 0
    )
    docs = [s.text for s in sorted_sent]
    # Use the global UMAP/HDBSCAN/vectorizer instances for debugging
    umap_model = global_umap
    hdbscan_model = global_hdbscan
    vectorizer_model = global_vectorizer

    # Instantiate BERTopic per request with the global components
    topic_model = BERTopic(
        embedding_model=embeddings,
        umap_model=umap_model,
        hdbscan_model=hdbscan_model,
        vectorizer_model=vectorizer_model,
        min_topic_size=2,
        calculate_probabilities=True
    )
    topics, probs = topic_model.fit_transform(docs)

    # Group consecutive sentences that share a topic into segments.
    # With calculate_probabilities=True, `probs` is an (n_docs, n_topics)
    # matrix, so each row's maximum serves as the segment confidence
    # (the original `float(prob or 0)` fails on array rows).
    segments, cur = [], None
    for idx, (t_id, prob) in enumerate(zip(topics, probs)):
        orig_idx = sorted_sent[idx].chunk_index if sorted_sent[idx].chunk_index is not None else idx
        if cur is None or t_id != cur["topic_id"]:
            if cur is not None:
                segments.append(cur)  # close the previous segment (this append was missing)
            words = [w for w, _ in topic_model.get_topic(t_id)[:5]]
            cur = dict(
                topic_id=t_id,
                label=" ".join(words) if t_id != -1 else None,
                keywords=words,
                start=sorted_sent[idx].start,
                end=sorted_sent[idx].end,
                probability=float(prob.max()) if prob is not None else None,
                sentences=[orig_idx],
            )
        else:
            cur["end"] = sorted_sent[idx].end
            cur["sentences"].append(orig_idx)
    if cur:
        segments.append(cur)

    print(f"[segment] Returning {len(segments)} segments: {segments}")
    sys.stdout.flush()
    return {"run_id": str(uuid.uuid4()), "segments": segments}
```