nexusbert commited on
Commit
d47b370
·
1 Parent(s): ff602f5

Refactor agent workflow and update documentation for Gemini-first implementation

Browse files

- Updated AGENT_WORKFLOW.md to reflect changes in API, architecture, and environment variables.
- Revised README.md to clarify the use of Gemini API for Task A and Task B, and removed references to local LLM inference.
- Refactored recommendation_pipeline.py to streamline ranking logic and remove local LLM dependencies.
- Simplified user_modeling.py by enforcing Gemini usage for generation and removing local LLM code.
- Cleaned up shared_models.py by removing unused local LLM functions and optimizing embedder initialization.

AGENT_WORKFLOW.md CHANGED
@@ -1,6 +1,6 @@
1
  # Agent workflow (Task A & Task B)
2
 
3
- This document matches the API `agent_steps`, the code layout, and environment variables. Use it as the backbone for the **solution paper** (architecture, experiments, limitations).
4
 
5
  ## Architecture overview
6
 
@@ -10,14 +10,14 @@ flowchart TB
10
  P[STARTUP_PREWARM=all]
11
  S[shared_models.warm_shared_weights]
12
  A[UserModelingService.warm - RAG index]
13
- B[RecommendationService.warm - catalog]
14
  P --> S --> A
15
  P --> B
16
  end
17
 
18
  subgraph taskA [POST /user-modeling]
19
  R1[yelp_rag_retrieve]
20
- G1[gemini_generate or local_hf]
21
  P1[parse_stars_review]
22
  R1 --> G1 --> P1
23
  end
@@ -25,8 +25,8 @@ flowchart TB
25
  subgraph taskB [POST /recommendation]
26
  E2[embedded_persona_context]
27
  V2[vector_retrieval_top_K]
28
- L2[gemini_rank or local_hf_rank]
29
- E2 --> V2 --> L2
30
  end
31
 
32
  S -.-> E2
@@ -35,11 +35,12 @@ flowchart TB
35
 
36
  | Layer | Role |
37
  |--------|------|
38
- | **`app/gemini_client.py`** | **Gemini API** text generation when `GEMINI_API_KEY` is set (`GENERATION_BACKEND=gemini` or `auto`). |
39
- | **`app/shared_models.py`** | Shared **SentenceTransformer**; optional local causal LM only when `GENERATION_BACKEND=local`. |
40
- | **`app/main.py`** | FastAPI routes; **`asyncio.to_thread`** for CPU/API work. |
41
- | **`scripts/docker_build_assets.py`** | Image build: embedder snapshot + JSONL indexes; skips Qwen download when `SKIP_LOCAL_LLM_HUB_DOWNLOAD=1`. |
42
- | **Runtime** | Embeddings from baked `/models/huggingface`; generation via **Gemini** by default. |
 
43
 
44
  ---
45
 
@@ -47,34 +48,33 @@ flowchart TB
47
 
48
  | Variable | Purpose |
49
  |----------|---------|
50
- | `GENERATION_BACKEND` | `gemini`, `local`, or `auto` (default: Gemini if `GEMINI_API_KEY` is set, else local Qwen). |
51
- | `GEMINI_API_KEY` / `GOOGLE_API_KEY` | Google AI Studio API key for Task A + Task B generation. |
52
  | `GEMINI_MODEL` | Model id (default `gemini-2.0-flash`). |
53
- | `SKIP_LOCAL_LLM_HUB_DOWNLOAD` | `1` at Docker build: do not bake Qwen weights (use with Gemini). |
54
- | `LOCAL_EMBEDDING_MODEL` | Shared embedder (default `all-MiniLM-L6-v2`) for Task A RAG queries and Task B retrieval. |
55
- | `LOCAL_LLM_MODEL` | Local causal LM when `GENERATION_BACKEND=local` (default `Qwen/Qwen2.5-1.5B-Instruct`). |
56
- | `TASK_A_REVIEWS_EMBEDDED` | Path to embedded review snippets JSONL for RAG. |
57
- | `TASK_A_RAG_TOP_K` | Snippets passed into the Task A prompt (default `5`). |
58
- | `TASK_A_MAX_TOKENS` / `TASK_A_TEMPERATURE` | Generation limits for Task A. |
59
  | `TASK_B_EMBEDDED_CATALOG` | Path to embedded business catalog JSONL. |
60
- | `TASK_B_LLM_CANDIDATE_CAP` | Max candidates sent to the LLM reranker (default `6`). |
61
- | `TASK_B_RANK_MODE` | `llm` (default) or `retrieval` (embedding order only, <1s). `TASK_B_FAST_RANK=1` forces retrieval. |
62
- | `STARTUP_PREWARM` | `all` (default): load shared weights + RAG + catalog before traffic. |
63
- | `SKIP_STARTUP_PREWARM` | Set to `1` to skip startup load (not recommended on Spaces). |
64
- | `HF_TOKEN` | Optional; pass into **Docker build** for Hub rate limits when downloading models. |
65
 
66
- Optional overrides (only if tasks use different models): `TASK_A_EMBEDDING_MODEL`, `TASK_A_LOCAL_LLM_MODEL`, `TASK_B_LOCAL_EMBEDDING_MODEL`, `TASK_B_LOCAL_LLM_MODEL`.
67
 
68
  ---
69
 
70
  ## Shared behaviour
71
 
72
- - **Generation:** **Gemini API** by default; **local Qwen** if `GENERATION_BACKEND=local`. Embeddings always **local** MiniLM.
73
- - **Nigerian English (competition bonus):** Task A reviews and Task B rationales are prompted for **natural Nigerian English** (see `user_modeling_prompt.py`, `recommendation_pipeline.py`).
74
- - **Task A point of view:** Reviews must be **first person** (`I` / `my`) as the user who visited — not third-person narration about the user.
75
- - **Failure handling:** Task A retries once with a **format + POV** nudge if parsing fails. Task B uses **popularity fallback** if reranker JSON is invalid.
76
- - **Deployment (2 vCPU / 16 GB):** Use **`uvicorn --workers 1`**. With Gemini, startup only loads the embedder (fast, low RAM). With local Qwen, one generation at a time via **`inference_lock`**.
77
- - **Swagger:** With **Gemini**, Task B usually finishes in a few seconds. Local CPU rerank may still exceed browser timeouts.
78
 
79
  ---
80
 
@@ -82,19 +82,19 @@ Optional overrides (only if tasks use different models): `TASK_A_EMBEDDING_MODEL
82
 
83
  **Endpoint:** `POST /user-modeling` (aliases `/task-1`, `/task_a`)
84
 
85
- **Input:** `persona` (multiline snapshot; optional `user_id: …` for RAG boost), `product` (business facts), `include_raw`.
86
 
87
- **Output:** `stars`, `review`, `parse_ok`, `rag_snippets_used`, `agent_steps`.
88
 
89
  | Step | `agent_steps` | Code | What happens |
90
- |------|---------------|------|----------------|
91
- | 1 | `yelp_rag_retrieve` *(if index exists)* | `UserModelingService._retrieve_examples` → `TaskAReviewRagIndex.retrieve` | Embed `persona + product` with `LOCAL_EMBEDDING_MODEL`; top`TASK_A_RAG_TOP_K` snippets; prefer same `user_id` when present in persona. |
92
- | 2 | `gemini_generate` or `local_hf_causal_lm` | `_generate` / `_generate_fix` | Gemini API or shared local Qwen; Nigerian voice + **first-person** rules; RAG + snapshot + business. |
93
- | 3 | `parse_stars_review` | `parse_model_output` | Parse `Stars:` and `Review:`; one retry if missing. |
94
 
95
- If `TASK_A_REVIEWS_EMBEDDED` is missing, step 1 is skipped (`rag_snippets_used: 0`).
96
 
97
- **Build index:** `scripts/build_task_a_review_rag.py` (Yelp `review.json` + `business.json`).
98
 
99
  ---
100
 
@@ -102,47 +102,52 @@ If `TASK_A_REVIEWS_EMBEDDED` is missing, step 1 is skipped (`rag_snippets_used:
102
 
103
  **Endpoint:** `POST /recommendation` (aliases `/task-2`, `/task_b`)
104
 
105
- **Input:** `persona`, optional `city` / `state`, `chat_history`, `top_k_retrieval` (default **20**), `top_n_final` (default **3**).
106
 
107
- **Output:** `recommendations[]` with `business_id`, `rank`, `rationale`, plus `candidates_considered`, `agent_steps`.
108
 
109
  | Step | `agent_steps` | Code | What happens |
110
- |------|---------------|------|----------------|
111
- | 1 | `embedded_persona_context` | `_embed_persona_local` `get_embedder()` | Persona + recent chat encoded with `LOCAL_EMBEDDING_MODEL`. |
112
- | 2 | `vector_retrieval_top_{K}` | `CatalogIndex.retrieve` | Cosine search on `TASK_B_EMBEDDED_CATALOG`; optional city/state filter. |
113
- | 3 | `gemini_reason_and_rank`, `local_hf_llm_reason_and_rank`, or `retrieval_score_rank` | `chat_rank_gemini`, `chat_rank_local_hf`, or `retrieval_rank` | LLM rerank ≤6 candidates; retrieval mode skips API. |
114
 
115
- **Build index:** `scripts/build_business_catalog.py` → `scripts/embed_catalog.py` (same embedding model as runtime).
116
 
117
  ---
118
 
119
  ## Docker build vs container start
120
 
121
  | Phase | What runs | What you get |
122
- |-------|-----------|----------------|
123
- | **`docker build`** | `snapshot_download` for embedder + LLM weights; stub or Yelp JSONL; `DOCKER_BUILD_SKIP_LLM_WARM=1` by default | Model **files** on disk in the image; no second full Qwen load during build (prevents exit 137 OOM). |
124
- | **Container start** | `STARTUP_PREWARM=all` → `warm_shared_weights()` + RAG + catalog load | Models loaded **once** into RAM (~1–2 min on CPU); then requests are much faster. |
125
- | **Each request** | `asyncio.to_thread` + shared models | One inference at a time under `inference_lock`; `/health` and `/docs` stay responsive. |
126
 
127
  ---
128
 
129
  ## Reproducibility checklist
130
 
131
- 1. `cp env.example .env` set `HF_TOKEN` if needed for builds.
132
- 2. Build indexes locally (optional if using Docker stubs):
133
- - `python scripts/build_task_a_review_rag.py`
134
- - `python scripts/build_business_catalog.py` then `python scripts/embed_catalog.py …`
135
- 3. `docker build` / Space rebuild with build-time `HF_TOKEN` if Hub rate-limits.
136
- 4. `docker compose up` or `uvicorn app.main:app --host 0.0.0.0 --port 8080`
137
- 5. Wait for logs: `Startup prewarm complete.`
138
- 6. Smoke test: `GET /health`, then `POST /user-modeling`, then `POST /recommendation` with `top_k_retrieval: 15`, `top_n_final: 5`.
 
 
 
 
 
139
 
140
  ---
141
 
142
  ## Paper pointers
143
 
144
- - **Why RAG for Task A:** Calibrate stars and voice from real Yelp snippets; `user_id` match surfaces same-user style when available.
145
- - **Why two-stage Task B:** Vector retrieval scales; LLM reranking adds persona-conditioned explanations.
146
- - **Why shared models:** One embedder + one LM for both tasks — fits 16 GB RAM with a single worker.
147
- - **Nigerian English + first-person Task A:** Prompt-level design; ablate with/without locale or POV blocks.
148
- - **Limitations:** CPU latency, single-worker queue, stub catalog when Yelp JSON is not baked into the image, small LM vs fine-tuned Azure baseline.
 
1
  # Agent workflow (Task A & Task B)
2
 
3
+ This document matches the current API `agent_steps`, application implementation, and environment variables. Use it as the backbone for architecture notes, evaluation, and deployment.
4
 
5
  ## Architecture overview
6
 
 
10
  P[STARTUP_PREWARM=all]
11
  S[shared_models.warm_shared_weights]
12
  A[UserModelingService.warm - RAG index]
13
+ B[RecommendationService.ensure_catalog]
14
  P --> S --> A
15
  P --> B
16
  end
17
 
18
  subgraph taskA [POST /user-modeling]
19
  R1[yelp_rag_retrieve]
20
+ G1[gemini_generate]
21
  P1[parse_stars_review]
22
  R1 --> G1 --> P1
23
  end
 
25
  subgraph taskB [POST /recommendation]
26
  E2[embedded_persona_context]
27
  V2[vector_retrieval_top_K]
28
+ G2[gemini_reason_and_rank]
29
+ E2 --> V2 --> G2
30
  end
31
 
32
  S -.-> E2
 
35
 
36
  | Layer | Role |
37
  |--------|------|
38
+ | **`app/gemini_client.py`** | Gemini API generation for Task A and Task B. |
39
+ | **`app/shared_models.py`** | Shared **SentenceTransformer** embedder for RAG and retrieval. |
40
+ | **`app/main.py`** | FastAPI routes and lifecycle with `asyncio.to_thread`. |
41
+ | **`app/user_modeling.py`** | Task A user simulation: RAG + Gemini generation + output parsing. |
42
+ | **`app/recommendation_pipeline.py`** | Task B recommendation: persona retrieval + Gemini reranking. |
43
+ | **`scripts/docker_build_assets.py`** | Builds embedding snapshots and JSONL indexes; honors `SKIP_LOCAL_LLM_HUB_DOWNLOAD` for Gemini-first deploys. |
44
 
45
  ---
46
 
 
48
 
49
  | Variable | Purpose |
50
  |----------|---------|
51
+ | `GENERATION_BACKEND` | `gemini` or `auto` (auto picks Gemini when `GEMINI_API_KEY` / `GOOGLE_API_KEY` is set). |
52
+ | `GEMINI_API_KEY` / `GOOGLE_API_KEY` | Required for Gemini API generation. |
53
  | `GEMINI_MODEL` | Model id (default `gemini-2.0-flash`). |
54
+ | `SKIP_LOCAL_LLM_HUB_DOWNLOAD` | `1` at Docker build: skip local Qwen download and keep the image Gemini-focused. |
55
+ | `LOCAL_EMBEDDING_MODEL` | Shared embedder (default `all-MiniLM-L6-v2`) for Task A RAG and Task B retrieval. |
56
+ | `TASK_A_REVIEWS_EMBEDDED` | Path to Task A RAG snippet JSONL. |
57
+ | `TASK_A_RAG_TOP_K` | Number of snippets retrieved for Task A (default `5`). |
58
+ | `TASK_A_MAX_TOKENS` | Token limit for Task A generation. |
59
+ | `TASK_A_TEMPERATURE` | Temperature for Task A Gemini generation. |
60
  | `TASK_B_EMBEDDED_CATALOG` | Path to embedded business catalog JSONL. |
61
+ | `TASK_B_MAX_OUTPUT_TOKENS` | Max tokens for Task B Gemini rank output. |
62
+ | `STARTUP_PREWARM` | `all` (default): warm shared embedder and indexes at startup. |
63
+ | `SKIP_STARTUP_PREWARM` | Set to `1` to skip startup prewarm. |
64
+ | `HF_TOKEN` | Optional token for Docker build-time Hub downloads. |
 
65
 
66
+ > Note: current app code is Gemini-first. Local causal LLM inference is not used by the active Task A/B code paths.
67
 
68
  ---
69
 
70
  ## Shared behaviour
71
 
72
+ - **Generation:** Gemini API only in active Task A/B code paths.
73
+ - **Embeddings:** Local SentenceTransformer embeddings are used for Task A RAG and Task B retrieval.
74
+ - **Nigerian English:** Prompt design enforces natural Nigerian English in Task A reviews and Task B rationales.
75
+ - **Task A POV:** Reviews must be first-person from the user’s perspective (`I`, `my`, `me`).
76
+ - **Failure handling:** Task A retries once with a stricter format/POV fix prompt if parsing fails. Task B falls back to retrieval order with safe rationale text if Gemini output cannot be parsed.
77
+ - **Startup preload:** `STARTUP_PREWARM=all` warms the shared embedder and loads indexes before traffic.
78
 
79
  ---
80
 
 
82
 
83
  **Endpoint:** `POST /user-modeling` (aliases `/task-1`, `/task_a`)
84
 
85
+ **Input:** `persona`, `product`, `include_raw`.
86
 
87
+ **Output:** `task`, `agent_steps`, `rag_snippets_used`, `stars`, `review`, `parse_ok`, `raw` (optional).
88
 
89
  | Step | `agent_steps` | Code | What happens |
90
+ |------|---------------|------|-------------|
91
+ | 1 | `yelp_rag_retrieve` *(if index exists)* | `UserModelingService._retrieve_examples` → `TaskAReviewRagIndex.retrieve` | Embed `persona + product` with `LOCAL_EMBEDDING_MODEL`; retrieve top `TASK_A_RAG_TOP_K` snippets for style calibration. |
92
+ | 2 | `gemini_generate` | `UserModelingService._generate` | Gemini generates a Nigerian English first-person review and star rating. |
93
+ | 3 | `parse_stars_review` | `parse_model_output` | Extract `Stars:` and `Review:`; retry once if parse is incomplete. |
94
 
95
+ If the RAG index file is missing, step 1 is skipped and `rag_snippets_used` is `0`.
96
 
97
+ **Build index:** `scripts/build_task_a_review_rag.py`.
98
 
99
  ---
100
 
 
102
 
103
  **Endpoint:** `POST /recommendation` (aliases `/task-2`, `/task_b`)
104
 
105
+ **Input:** `persona`, optional `city` / `state`, `chat_history`, `top_k_retrieval` (default `20`), `top_n_final` (default `5`).
106
 
107
+ **Output:** `task`, `agent_steps`, `candidates_considered`, `recommendations[]`.
108
 
109
  | Step | `agent_steps` | Code | What happens |
110
+ |------|---------------|------|-------------|
111
+ | 1 | `embedded_persona_context` | `build_query_text` | Build a combined persona + recent chat query string. |
112
+ | 2 | `vector_retrieval_top_{K}` | `CatalogIndex.retrieve` | Cosine similarity retrieval from `TASK_B_EMBEDDED_CATALOG` using local embeddings; optional city/state filter. |
113
+ | 3 | `gemini_reason_and_rank` | `chat_rank_gemini` | Gemini ranks selected candidates and returns conversational rationales. |
114
 
115
+ **Build index:** `scripts/build_business_catalog.py` → `scripts/embed_catalog.py`.
116
 
117
  ---
118
 
119
  ## Docker build vs container start
120
 
121
  | Phase | What runs | What you get |
122
+ |-------|-----------|-------------|
123
+ | **`docker build`** | Embedder snapshot and JSONL index creation; `SKIP_LOCAL_LLM_HUB_DOWNLOAD=1` keeps the build Gemini-first. | Model files and indexes ready in the image. |
124
+ | **Container start** | `STARTUP_PREWARM=all` → warm shared embedder, Task A RAG index, and Task B catalog index. | Startup load happens once, then requests are faster. |
125
+ | **Each request** | `asyncio.to_thread` + shared embeddings + Gemini API | Task A/B remain responsive; no local LLM load. |
126
 
127
  ---
128
 
129
  ## Reproducibility checklist
130
 
131
+ 1. `cp env.example .env` and set `GEMINI_API_KEY` or `GOOGLE_API_KEY`.
132
+ 2. Build indexes locally if needed:
133
+ - `python scripts/build_task_a_review_rag.py`
134
+ - `python scripts/build_business_catalog.py`
135
+ - `python scripts/embed_catalog.py`
136
+ 3. Run the app:
137
+ - `docker compose up`
138
+ - or `uvicorn app.main:app --host 0.0.0.0 --port 8080`
139
+ 4. Confirm startup logs: `Startup prewarm complete.`
140
+ 5. Smoke test:
141
+ - `GET /health`
142
+ - `POST /user-modeling`
143
+ - `POST /recommendation` with `top_k_retrieval: 15`, `top_n_final: 5`.
144
 
145
  ---
146
 
147
  ## Paper pointers
148
 
149
+ - **Gemini-first pipeline:** active code paths call Gemini for both Task A and Task B.
150
+ - **RAG for Task A:** retrieve real review snippets for style calibration and rating behavior.
151
+ - **Two-stage recommendation:** local retrieval scales, followed by persona-conditioned Gemini reranking.
152
+ - **Nigerian English:** prompt-level design enforces tone and first-person review voice.
153
+ - **Limitations:** API dependency, embedding latency, single worker, and potential query/catalog model mismatch.
README.md CHANGED
@@ -13,7 +13,7 @@ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-
13
 
14
  This Space is configured as **`sdk: docker`**. The image builds from `Dockerfile` (CPU-only PyTorch so CUDA wheels don’t OOM the builder). During **`docker build`**, models are **`snapshot_download`**’d into `/models/huggingface` **without loading the full LLM into RAM**; **`SentenceTransformer`** embeds a **stub** or Yelp-derived catalog plus **`data/task_a_reviews_embedded.jsonl`** (review RAG for Task A). See `scripts/docker_build_assets.py`.
15
 
16
- Task **A**: persona + product → rating/review via **Gemini API** (default) or optional **local** Qwen, plus **retrieved Yelp review snippets** from the baked JSONL. Task **B**: **local** sentence-transformer retrieval over businesses plus **Gemini** (or local) reranking.
17
 
18
  **Secrets (Hugging Face Space):** **`GEMINI_API_KEY`** (or `GOOGLE_API_KEY`) — required for generation when `GENERATION_BACKEND=gemini`. Optional **`HF_TOKEN`** for **Docker build** only (embedder download). Never commit keys in the repo.
19
 
@@ -67,9 +67,9 @@ python scripts/build_task_a_review_rag.py \
67
 
68
  Use the same `TASK_B_LOCAL_EMBEDDING_MODEL` (or `TASK_A_EMBEDDING_MODEL`) at build and runtime. Omit the file only for quick tests (generation runs without RAG).
69
 
70
- **Generation:** set `GEMINI_API_KEY` in `.env` (see `env.example`). With `GENERATION_BACKEND=gemini` (default), Task A and Task B use **`GEMINI_MODEL`** (default `gemini-2.0-flash`). Set `GENERATION_BACKEND=local` to use on-device Qwen instead.
71
 
72
- **Task B** reranking uses Gemini when configured; embeddings stay local (`LOCAL_EMBEDDING_MODEL`).
73
 
74
  **Recommendation index** (needs Yelp `business.json` on your machine, e.g. `../yelp_dataset/extracted/` from a parent workspace):
75
 
@@ -107,7 +107,7 @@ Default compose maps **`7860:7860`**. The image bakes **`/code/data/business_cat
107
 
108
  The Docker image sets **`HF_HUB_OFFLINE=1`** and **`TRANSFORMERS_OFFLINE=1`** so the running container does not call the Hugging Face Hub. During **`docker build`**, **`snapshot_download`** copies model **files** into `/models/huggingface` (and stub JSONL is embedded). Loading weights **into RAM** during build was disabled by default (**`DOCKER_BUILD_SKIP_LLM_WARM=1`**) because HF build VMs often **OOM (exit 137)** when loading Qwen; that RAM would not stay in the final image anyway.
109
 
110
- At **container start**, **`STARTUP_PREWARM=all`** (default) loads **one shared** embedding model and **one shared** causal LM (`app/shared_models.py`), then Task A RAG + Task B catalog — so **`/task-2`** does not pay a second full Qwen load. Expect **~1–2 minutes** on CPU after deploy while logs show `Loading shared …`; then both endpoints stay fast. Disable with **`SKIP_STARTUP_PREWARM=1`** (not recommended on Spaces).
111
 
112
  ### Smoke checks
113
 
@@ -119,7 +119,7 @@ OpenAPI: `http://localhost:7860/docs` when using Docker (port **7860**). Local `
119
  |------|------|
120
  | `app/main.py` | FastAPI routes |
121
  | [`AGENT_WORKFLOW.md`](AGENT_WORKFLOW.md) | Agent steps, reproducibility, paper hooks (Nigerian English, fallbacks) |
122
- | `app/user_modeling.py`, `app/user_modeling_prompt.py`, `app/task_a_rag.py` | Task 1 local LLM + Yelp review RAG |
123
  | `app/recommendation_pipeline.py` | Task 2 retrieval + rerank |
124
  | `scripts/build_business_catalog.py` | Yelp → catalog JSONL |
125
  | `scripts/embed_catalog.py` | Embed catalog (local sentence-transformers) |
 
13
 
14
  This Space is configured as **`sdk: docker`**. The image builds from `Dockerfile` (CPU-only PyTorch so CUDA wheels don’t OOM the builder). During **`docker build`**, models are **`snapshot_download`**’d into `/models/huggingface` **without loading the full LLM into RAM**; **`SentenceTransformer`** embeds a **stub** or Yelp-derived catalog plus **`data/task_a_reviews_embedded.jsonl`** (review RAG for Task A). See `scripts/docker_build_assets.py`.
15
 
16
+ Task **A**: persona + product → rating/review via **Gemini API** and retrieved Yelp review snippets from the baked JSONL. Task **B**: local sentence-transformer retrieval over businesses plus **Gemini** reranking.
17
 
18
  **Secrets (Hugging Face Space):** **`GEMINI_API_KEY`** (or `GOOGLE_API_KEY`) — required for generation when `GENERATION_BACKEND=gemini`. Optional **`HF_TOKEN`** for **Docker build** only (embedder download). Never commit keys in the repo.
19
 
 
67
 
68
  Use the same `TASK_B_LOCAL_EMBEDDING_MODEL` (or `TASK_A_EMBEDDING_MODEL`) at build and runtime. Omit the file only for quick tests (generation runs without RAG).
69
 
70
+ **Generation:** set `GEMINI_API_KEY` in `.env` (see `env.example`). With `GENERATION_BACKEND=gemini` or `auto` (default), Task A and Task B both use **Gemini**. Local causal LLM inference is not used by current runtime code.
71
 
72
+ **Task B** reranking uses Gemini; embeddings stay local (`LOCAL_EMBEDDING_MODEL`).
73
 
74
  **Recommendation index** (needs Yelp `business.json` on your machine, e.g. `../yelp_dataset/extracted/` from a parent workspace):
75
 
 
107
 
108
  The Docker image sets **`HF_HUB_OFFLINE=1`** and **`TRANSFORMERS_OFFLINE=1`** so the running container does not call the Hugging Face Hub. During **`docker build`**, **`snapshot_download`** copies model **files** into `/models/huggingface` (and stub JSONL is embedded). Loading weights **into RAM** during build was disabled by default (**`DOCKER_BUILD_SKIP_LLM_WARM=1`**) because HF build VMs often **OOM (exit 137)** when loading Qwen; that RAM would not stay in the final image anyway.
109
 
110
+ At **container start**, **`STARTUP_PREWARM=all`** (default) loads the shared embedding model and preloads Task A RAG + Task B catalog indexes. Expect **~1–2 minutes** on CPU after deploy while logs show `Loading shared …`; then both endpoints stay fast. Disable with **`SKIP_STARTUP_PREWARM=1`** (not recommended on Spaces).
111
 
112
  ### Smoke checks
113
 
 
119
  |------|------|
120
  | `app/main.py` | FastAPI routes |
121
  | [`AGENT_WORKFLOW.md`](AGENT_WORKFLOW.md) | Agent steps, reproducibility, paper hooks (Nigerian English, fallbacks) |
122
+ | `app/user_modeling.py`, `app/user_modeling_prompt.py`, `app/task_a_rag.py` | Task 1 Gemini generation + Yelp review RAG |
123
  | `app/recommendation_pipeline.py` | Task 2 retrieval + rerank |
124
  | `scripts/build_business_catalog.py` | Yelp → catalog JSONL |
125
  | `scripts/embed_catalog.py` | Embed catalog (local sentence-transformers) |
app/recommendation_pipeline.py CHANGED
@@ -2,30 +2,19 @@ from __future__ import annotations
2
 
3
  import json
4
  import logging
5
- import math
6
  import os
7
  import re
8
- import threading
9
  import time
10
  from pathlib import Path
11
  from typing import Any
12
 
13
  import numpy as np
14
- #modules
15
  from app._paths import submission_root
16
  from app.gemini_client import gemini_generate_text, use_gemini
17
- from app.shared_models import (
18
- causal_lm_model_id_task_b,
19
- embedding_model_name_task_b,
20
- get_causal_lm,
21
- get_embedder,
22
- inference_lock,
23
- )
24
 
25
  logger = logging.getLogger(__name__)
26
 
27
- _recommend_inflight = threading.Lock()
28
-
29
 
30
  def _resolve_catalog_path(raw: str) -> Path:
31
  p = Path(raw)
@@ -83,7 +72,7 @@ class CatalogIndex:
83
  )
84
  nq = np.linalg.norm(q)
85
  if nq == 0:
86
- q = np.ones_like(q) / math.sqrt(len(q))
87
  else:
88
  q = q / nq
89
 
@@ -121,239 +110,74 @@ class CatalogIndex:
121
  ]
122
 
123
 
124
- def _task_b_rank_mode() -> str:
125
- """llm (default) or retrieval (no causal LM; sub-second, Swagger-safe)."""
126
- if os.environ.get("TASK_B_FAST_RANK", "").strip().lower() in ("1", "true", "yes"):
127
- return "retrieval"
128
- mode = os.environ.get("TASK_B_RANK_MODE", "llm").strip().lower()
129
- if mode in ("retrieval", "retrieval_only", "fast", "embed"):
130
- return "retrieval"
131
- return "llm"
132
-
133
-
134
- def retrieval_rank(candidates: list[dict[str, Any]], top_n: int) -> list[dict[str, Any]]:
135
- ordered = sorted(
136
- candidates,
137
- key=lambda c: float(c.get("retrieval_score", 0)),
138
- reverse=True,
139
- )[:top_n]
140
- out: list[dict[str, Any]] = []
141
- for i, c in enumerate(ordered, start=1):
142
- cats = (c.get("categories") or "similar venues").strip()
143
- if len(cats) > 55:
144
- cats = cats[:52] + "…"
145
- out.append(
146
- {
147
- "business_id": c["business_id"],
148
- "rank": i,
149
- "rationale": f"Good semantic fit for {cats}; aligns with persona signals.",
150
- }
151
- )
152
- return out
153
-
154
-
155
- def popularity_fallback(candidates: list[dict[str, Any]], top_n: int) -> list[dict[str, Any]]:
156
- ranked = sorted(
157
- candidates,
158
- key=lambda c: (float(c.get("review_count", 0)), float(c.get("stars", 0))),
159
- reverse=True,
160
- )
161
- out = []
162
- for i, c in enumerate(ranked[:top_n], start=1):
163
- out.append(
164
- {
165
- "business_id": c["business_id"],
166
- "rank": i,
167
- "rationale": "High activity and average rating on Yelp (popularity prior).",
168
- }
169
- )
170
- return out
171
-
172
-
173
- _RANK_SYSTEM = "Return valid JSON only. No markdown."
174
 
175
 
176
- def _build_rank_user_prompt(
 
177
  persona: str,
178
  chat_history: list[dict[str, str]],
179
  candidates: list[dict[str, Any]],
180
  top_n: int,
181
- ) -> str:
182
- hist_txt = ""
183
- if chat_history:
184
- lines = []
185
- for turn in chat_history[-6:]:
186
- role = turn.get("role", "user")
187
- content = turn.get("content", "")
188
- lines.append(f"{role}: {content}")
189
- hist_txt = "\n".join(lines)
190
-
191
- cand_payload = [
192
- {
193
- "id": c["business_id"],
194
- "name": (c.get("name", "") or "")[:48],
195
- "cat": (c.get("categories", "") or "")[:56],
196
- }
197
- for c in candidates
198
- ]
199
-
200
- return f"""Rank the best {top_n} businesses for this user (Nigerian English rationales, third person, under 14 words each).
201
 
202
  Persona:
203
  {persona.strip()[:1200]}
204
 
205
- Chat:
206
- {hist_txt or "(none)"}
207
 
208
  Candidates:
209
- {json.dumps(cand_payload, ensure_ascii=False)}
210
 
211
  Output ONLY a JSON array of {top_n} objects: {{"business_id":"<id>","rank":1,"rationale":"..."}} — distinct ids from candidates, rank 1 best."""
212
-
213
-
214
- def _normalize_ranked_output(
215
- raw: str,
216
- candidates: list[dict[str, Any]],
217
- top_n: int,
218
- ) -> list[dict[str, Any]]:
219
- data = _parse_json_array(raw)
220
- if not data:
221
- logger.warning("Rank parse failed; using popularity fallback on candidates.")
222
- return popularity_fallback(candidates, top_n)
223
- seen: set[str] = set()
224
- cleaned = []
225
- for item in data:
226
- if not isinstance(item, dict):
227
- continue
228
- bid = item.get("business_id")
229
- if not bid or bid in seen:
230
- continue
231
- seen.add(str(bid))
232
- cleaned.append(
233
- {
234
- "business_id": str(bid),
235
- "rank": int(item.get("rank", len(cleaned) + 1)),
236
- "rationale": str(item.get("rationale", "")).strip()
237
- or "Matched persona and retrieval signals.",
238
- }
239
- )
240
- if len(cleaned) >= top_n:
241
- break
242
- if len(cleaned) < min(top_n, len(candidates)):
243
- for c in candidates:
244
- if len(cleaned) >= top_n:
245
- break
246
- bid = c["business_id"]
247
- if bid in seen:
248
- continue
249
- seen.add(bid)
250
- cleaned.append(
251
- {
252
- "business_id": bid,
253
- "rank": len(cleaned) + 1,
254
- "rationale": "Added by retrieval order after partial model output.",
255
- }
256
- )
257
- cleaned.sort(key=lambda x: x["rank"])
258
- return cleaned[:top_n]
259
-
260
-
261
- def chat_rank_gemini(
262
- *,
263
- persona: str,
264
- chat_history: list[dict[str, str]],
265
- candidates: list[dict[str, Any]],
266
- top_n: int,
267
- ) -> list[dict[str, Any]]:
268
- user_prompt = _build_rank_user_prompt(persona, chat_history, candidates, top_n)
269
- temp = float(os.environ.get("TASK_B_TEMPERATURE", "0.2"))
270
- max_out = min(512, int(os.environ.get("TASK_B_MAX_OUTPUT_TOKENS", "256")))
271
  raw = gemini_generate_text(
272
- system_instruction=_RANK_SYSTEM,
273
  user_text=user_prompt,
274
- temperature=temp,
275
- max_output_tokens=max_out,
276
  )
277
  return _normalize_ranked_output(raw, candidates, top_n)
278
 
279
 
280
- def chat_rank_local_hf(
281
- *,
282
- persona: str,
283
- chat_history: list[dict[str, str]],
284
  candidates: list[dict[str, Any]],
285
  top_n: int,
286
- tokenizer: Any,
287
- model: Any,
288
- device: str,
289
  ) -> list[dict[str, Any]]:
290
  try:
291
- import torch # type: ignore[import-untyped]
292
- except ImportError as e:
293
- raise RuntimeError("Local reranking needs PyTorch (install sentence-transformers or torch).") from e
294
-
295
- user_prompt = _build_rank_user_prompt(persona, chat_history, candidates, top_n)
296
- messages = [
297
- {"role": "system", "content": _RANK_SYSTEM},
298
- {"role": "user", "content": user_prompt},
299
- ]
300
- prompt_txt = tokenizer.apply_chat_template(
301
- messages,
302
- tokenize=False,
303
- add_generation_prompt=True,
304
- )
305
- inputs = tokenizer(
306
- prompt_txt,
307
- return_tensors="pt",
308
- truncation=True,
309
- max_length=1536,
310
- ).to(device)
311
- if tokenizer.pad_token_id is None:
312
- tokenizer.pad_token_id = tokenizer.eos_token_id
313
-
314
- max_new_tokens = min(200, 24 + top_n * 32)
315
- with inference_lock(), torch.no_grad():
316
- out = model.generate(
317
- **inputs,
318
- max_new_tokens=max_new_tokens,
319
- do_sample=False,
320
- pad_token_id=tokenizer.pad_token_id,
321
- )
322
- gen_ids = out[0][inputs["input_ids"].shape[1] :]
323
- raw = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
324
- return _normalize_ranked_output(raw, candidates, top_n)
325
-
326
-
327
- def _parse_json_array(raw: str) -> list[Any]:
328
- raw = raw.strip()
329
- try:
330
- val = json.loads(raw)
331
- return val if isinstance(val, list) else []
332
- except json.JSONDecodeError:
333
- pass
334
- m = re.search(r"\[[\s\S]*\]", raw)
335
- if m:
336
- try:
337
- val = json.loads(m.group(0))
338
- return val if isinstance(val, list) else []
339
- except json.JSONDecodeError:
340
- pass
341
- return []
342
-
343
-
344
- def build_query_text(persona: str, chat_history: list[dict[str, str]]) -> str:
345
- parts = [persona.strip()]
346
- for turn in chat_history[-4:]:
347
- c = turn.get("content", "").strip()
348
- if c:
349
- parts.append(c)
350
- return "\n".join(parts)
351
 
352
 
353
  class RecommendationService:
354
  def __init__(self) -> None:
355
- self._local_llm_model_id = causal_lm_model_id_task_b()
356
- self._local_model_name = embedding_model_name_task_b()
357
  catalog_raw = os.environ.get(
358
  "TASK_B_EMBEDDED_CATALOG", "data/business_catalog_embedded.jsonl"
359
  )
@@ -361,29 +185,13 @@ class RecommendationService:
361
  self.index = CatalogIndex(self.catalog_path)
362
  self._loaded = False
363
 
364
- def _ensure_local_embedder(self) -> Any:
365
- return get_embedder(self._local_model_name)
366
-
367
- def _embed_persona_local(self, text: str) -> list[float]:
368
- model = self._ensure_local_embedder()
369
- t = text.replace("\n", " ")[:8000]
370
- vec = model.encode([t], convert_to_numpy=True, normalize_embeddings=False)[0]
371
- return vec.astype(float).tolist()
372
-
373
- def _ensure_local_rank_llm(self) -> tuple[Any, Any, str]:
374
- tok, mdl, dev = get_causal_lm(self._local_llm_model_id)
375
- return tok, mdl, str(dev)
376
-
377
  def ensure_catalog(self) -> None:
378
  if self._loaded:
379
  return
380
- logger.info("Loading Task B catalog from %s", self.catalog_path)
381
  self.index.load()
382
  self._loaded = True
383
- logger.info("Task B catalog ready (%d businesses)", len(self.index._rows))
384
-
385
- def warm(self) -> None:
386
- self.ensure_catalog()
387
 
388
  def recommend(
389
  self,
@@ -395,98 +203,22 @@ class RecommendationService:
395
  top_k_retrieval: int = 20,
396
  top_n_final: int = 5,
397
  ) -> dict[str, Any]:
398
- # Local HF rerank is slow on CPU; serialize. Gemini can run concurrently.
399
- if not use_gemini():
400
- if not _recommend_inflight.acquire(blocking=False):
401
- raise RuntimeError(
402
- "Another recommendation is already running; wait for it to finish before retrying."
403
- )
404
- try:
405
- return self._recommend_impl(
406
- persona,
407
- city=city,
408
- state=state,
409
- chat_history=chat_history,
410
- top_k_retrieval=top_k_retrieval,
411
- top_n_final=top_n_final,
412
- )
413
- finally:
414
- _recommend_inflight.release()
415
- return self._recommend_impl(
416
- persona,
417
- city=city,
418
- state=state,
419
- chat_history=chat_history,
420
- top_k_retrieval=top_k_retrieval,
421
- top_n_final=top_n_final,
422
- )
423
-
424
- def _recommend_impl(
425
- self,
426
- persona: str,
427
- *,
428
- city: str | None = None,
429
- state: str | None = None,
430
- chat_history: list[dict[str, str]] | None = None,
431
- top_k_retrieval: int = 20,
432
- top_n_final: int = 5,
433
- ) -> dict[str, Any]:
434
- t0 = time.perf_counter()
435
- chat_history = chat_history or []
436
  self.ensure_catalog()
437
- rank_mode = _task_b_rank_mode()
438
-
439
- llm_cap = int(os.environ.get("TASK_B_LLM_CANDIDATE_CAP", "6"))
440
- top_k_retrieval = max(top_k_retrieval, top_n_final)
441
-
442
  qtext = build_query_text(persona, chat_history)
443
- qemb = self._embed_persona_local(qtext)
444
- candidates = self.index.retrieve(qemb, top_k_retrieval, city, state)
445
- logger.info("Task B retrieved %d candidates in %.2fs", len(candidates), time.perf_counter() - t0)
446
-
447
- if not candidates:
448
- raise RuntimeError("No retrieval candidates — check catalog filters and embeddings.")
449
-
450
- rank_pool = candidates[: min(len(candidates), llm_cap)]
451
- logger.info(
452
- "Task B %s rank on %d of %d candidates (top_n=%d) …",
453
- rank_mode,
454
- len(rank_pool),
455
- len(candidates),
456
- top_n_final,
457
  )
458
- t1 = time.perf_counter()
459
- if rank_mode == "retrieval":
460
- ranked = retrieval_rank(rank_pool, top_n_final)
461
- rank_step = "retrieval_score_rank"
462
- elif use_gemini():
463
- ranked = chat_rank_gemini(
464
- persona=persona,
465
- chat_history=chat_history,
466
- candidates=rank_pool,
467
- top_n=top_n_final,
468
- )
469
- rank_step = "gemini_reason_and_rank"
470
- else:
471
- tok, mdl, dev = self._ensure_local_rank_llm()
472
- ranked = chat_rank_local_hf(
473
- persona=persona,
474
- chat_history=chat_history,
475
- candidates=rank_pool,
476
- top_n=top_n_final,
477
- tokenizer=tok,
478
- model=mdl,
479
- device=dev,
480
- )
481
- rank_step = "local_hf_llm_reason_and_rank"
482
- logger.info("Task B rank done in %.2fs (total %.2fs)", time.perf_counter() - t1, time.perf_counter() - t0)
483
-
484
  return {
485
  "task": "2_recommendation",
486
  "agent_steps": [
487
  "embedded_persona_context",
488
  f"vector_retrieval_top_{top_k_retrieval}",
489
- rank_step,
490
  ],
491
  "candidates_considered": len(candidates),
492
  "recommendations": ranked,
 
2
 
3
  import json
4
  import logging
 
5
  import os
6
  import re
 
7
  import time
8
  from pathlib import Path
9
  from typing import Any
10
 
11
  import numpy as np
12
+
13
  from app._paths import submission_root
14
  from app.gemini_client import gemini_generate_text, use_gemini
 
 
 
 
 
 
 
15
 
16
  logger = logging.getLogger(__name__)
17
 
 
 
18
 
19
  def _resolve_catalog_path(raw: str) -> Path:
20
  p = Path(raw)
 
72
  )
73
  nq = np.linalg.norm(q)
74
  if nq == 0:
75
+ q = np.ones_like(q) / np.sqrt(len(q))
76
  else:
77
  q = q / nq
78
 
 
110
  ]
111
 
112
 
113
+ def build_query_text(persona: str, chat_history: list[dict[str, str]]) -> str:
114
+ parts = [persona.strip()]
115
+ for turn in chat_history[-4:]:
116
+ c = turn.get("content", "").strip()
117
+ if c:
118
+ parts.append(c)
119
+ return "\n".join(parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
 
122
+ def chat_rank_gemini(
123
+ *,
124
  persona: str,
125
  chat_history: list[dict[str, str]],
126
  candidates: list[dict[str, Any]],
127
  top_n: int,
128
+ ) -> list[dict[str, Any]]:
129
+ user_prompt = f"""Rank the best {top_n} businesses for this user with conversational rationales.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  Persona:
132
  {persona.strip()[:1200]}
133
 
134
+ Chat History:
135
+ {json.dumps(chat_history[-6:], ensure_ascii=False) if chat_history else "(none)"}
136
 
137
  Candidates:
138
+ {json.dumps(candidates, ensure_ascii=False)}
139
 
140
  Output ONLY a JSON array of {top_n} objects: {{"business_id":"<id>","rank":1,"rationale":"..."}} — distinct ids from candidates, rank 1 best."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  raw = gemini_generate_text(
142
+ system_instruction="Return valid JSON only.",
143
  user_text=user_prompt,
144
+ temperature=0.2,
145
+ max_output_tokens=512,
146
  )
147
  return _normalize_ranked_output(raw, candidates, top_n)
148
 
149
 
150
+ def _normalize_ranked_output(
151
+ raw: str,
 
 
152
  candidates: list[dict[str, Any]],
153
  top_n: int,
 
 
 
154
  ) -> list[dict[str, Any]]:
155
  try:
156
+ data = json.loads(raw)
157
+ if not isinstance(data, list):
158
+ raise ValueError("Invalid JSON format for ranked output.")
159
+ return [
160
+ {
161
+ "business_id": item["business_id"],
162
+ "rank": item["rank"],
163
+ "rationale": item["rationale"],
164
+ }
165
+ for item in data[:top_n]
166
+ ]
167
+ except (json.JSONDecodeError, KeyError, ValueError):
168
+ logger.warning("Failed to parse Gemini output; falling back to retrieval order.")
169
+ return [
170
+ {
171
+ "business_id": c["business_id"],
172
+ "rank": i + 1,
173
+ "rationale": "Fallback rationale due to parsing error.",
174
+ }
175
+ for i, c in enumerate(candidates[:top_n])
176
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
 
179
  class RecommendationService:
180
  def __init__(self) -> None:
 
 
181
  catalog_raw = os.environ.get(
182
  "TASK_B_EMBEDDED_CATALOG", "data/business_catalog_embedded.jsonl"
183
  )
 
185
  self.index = CatalogIndex(self.catalog_path)
186
  self._loaded = False
187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
  def ensure_catalog(self) -> None:
189
  if self._loaded:
190
  return
191
+ logger.info("Loading catalog from %s", self.catalog_path)
192
  self.index.load()
193
  self._loaded = True
194
+ logger.info("Catalog loaded with %d businesses", len(self.index._rows))
 
 
 
195
 
196
  def recommend(
197
  self,
 
203
  top_k_retrieval: int = 20,
204
  top_n_final: int = 5,
205
  ) -> dict[str, Any]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  self.ensure_catalog()
207
+ chat_history = chat_history or []
 
 
 
 
208
  qtext = build_query_text(persona, chat_history)
209
+ candidates = self.index.retrieve(qtext, top_k_retrieval, city, state)
210
+ ranked = chat_rank_gemini(
211
+ persona=persona,
212
+ chat_history=chat_history,
213
+ candidates=candidates,
214
+ top_n=top_n_final,
 
 
 
 
 
 
 
 
215
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  return {
217
  "task": "2_recommendation",
218
  "agent_steps": [
219
  "embedded_persona_context",
220
  f"vector_retrieval_top_{top_k_retrieval}",
221
+ "gemini_reason_and_rank",
222
  ],
223
  "candidates_considered": len(candidates),
224
  "recommendations": ranked,
app/shared_models.py CHANGED
@@ -5,12 +5,9 @@ import os
5
  import threading
6
  from typing import Any
7
 
8
- from app.gemini_client import use_gemini
9
-
10
  logger = logging.getLogger(__name__)
11
 
12
  _embedders: dict[str, Any] = {}
13
- _llm_cache: dict[str, tuple[Any, Any, Any]] = {}
14
  _inference_lock = threading.Lock()
15
 
16
 
@@ -23,15 +20,6 @@ def local_embedding_model() -> str:
23
  )
24
 
25
 
26
- def local_llm_model() -> str:
27
- return (
28
- os.environ.get("LOCAL_LLM_MODEL", "").strip()
29
- or os.environ.get("TASK_B_LOCAL_LLM_MODEL", "").strip()
30
- or os.environ.get("TASK_A_LOCAL_LLM_MODEL", "").strip()
31
- or "Qwen/Qwen2.5-1.5B-Instruct"
32
- )
33
-
34
-
35
  def embedding_model_name_task_a() -> str:
36
  override = os.environ.get("TASK_A_EMBEDDING_MODEL", "").strip()
37
  return override or local_embedding_model()
@@ -42,26 +30,11 @@ def embedding_model_name_task_b() -> str:
42
  return override or local_embedding_model()
43
 
44
 
45
- def causal_lm_model_id_task_a() -> str:
46
- override = os.environ.get("TASK_A_LOCAL_LLM_MODEL", "").strip()
47
- return override or local_llm_model()
48
-
49
-
50
- def causal_lm_model_id_task_b() -> str:
51
- override = os.environ.get("TASK_B_LOCAL_LLM_MODEL", "").strip()
52
- return override or local_llm_model()
53
-
54
-
55
  def unique_embedding_model_names() -> list[str]:
56
  names = {embedding_model_name_task_a(), embedding_model_name_task_b()}
57
  return sorted(names)
58
 
59
 
60
- def unique_llm_model_ids() -> list[str]:
61
- ids = {causal_lm_model_id_task_a(), causal_lm_model_id_task_b()}
62
- return sorted(ids)
63
-
64
-
65
  def get_embedder(model_name: str) -> Any:
66
  key = model_name.strip()
67
  if key not in _embedders:
@@ -78,44 +51,10 @@ def inference_lock() -> threading.Lock:
78
  return _inference_lock
79
 
80
 
81
- def get_causal_lm(model_id: str) -> tuple[Any, Any, Any]:
82
- mid = model_id.strip()
83
- if mid not in _llm_cache:
84
- try:
85
- import torch # type: ignore[import-untyped]
86
- from transformers import AutoModelForCausalLM, AutoTokenizer # type: ignore[import-untyped]
87
- except ImportError as e:
88
- raise RuntimeError("transformers and torch required") from e
89
- logger.info("Loading shared causal LM %s", mid)
90
- tok = AutoTokenizer.from_pretrained(mid, trust_remote_code=True)
91
- use_cuda = torch.cuda.is_available()
92
- device_obj = torch.device("cuda" if use_cuda else "cpu")
93
- dtype = torch.float16 if use_cuda else torch.float32
94
- mdl = AutoModelForCausalLM.from_pretrained(
95
- mid,
96
- torch_dtype=dtype,
97
- trust_remote_code=True,
98
- low_cpu_mem_usage=True,
99
- )
100
- mdl = mdl.to(device_obj)
101
- mdl.eval()
102
- _llm_cache[mid] = (tok, mdl, device_obj)
103
- return _llm_cache[mid]
104
-
105
-
106
  def warm_shared_weights() -> None:
107
  for name in unique_embedding_model_names():
108
  get_embedder(name)
109
- if use_gemini():
110
- logger.info(
111
- "Shared weights ready (%d embedder(s); generation via Gemini API)",
112
- len(_embedders),
113
- )
114
- return
115
- for mid in unique_llm_model_ids():
116
- get_causal_lm(mid)
117
  logger.info(
118
- "Shared weights ready (%d embedder(s), %d causal LM(s))",
119
  len(_embedders),
120
- len(_llm_cache),
121
  )
 
5
  import threading
6
  from typing import Any
7
 
 
 
8
  logger = logging.getLogger(__name__)
9
 
10
  _embedders: dict[str, Any] = {}
 
11
  _inference_lock = threading.Lock()
12
 
13
 
 
20
  )
21
 
22
 
 
 
 
 
 
 
 
 
 
23
  def embedding_model_name_task_a() -> str:
24
  override = os.environ.get("TASK_A_EMBEDDING_MODEL", "").strip()
25
  return override or local_embedding_model()
 
30
  return override or local_embedding_model()
31
 
32
 
 
 
 
 
 
 
 
 
 
 
33
  def unique_embedding_model_names() -> list[str]:
34
  names = {embedding_model_name_task_a(), embedding_model_name_task_b()}
35
  return sorted(names)
36
 
37
 
 
 
 
 
 
38
  def get_embedder(model_name: str) -> Any:
39
  key = model_name.strip()
40
  if key not in _embedders:
 
51
  return _inference_lock
52
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def warm_shared_weights() -> None:
55
  for name in unique_embedding_model_names():
56
  get_embedder(name)
 
 
 
 
 
 
 
 
57
  logger.info(
58
+ "Shared weights ready (%d embedder(s); generation via Gemini API)",
59
  len(_embedders),
 
60
  )
app/user_modeling.py CHANGED
@@ -10,11 +10,8 @@ from typing import Any
10
  from app._paths import submission_root
11
  from app.gemini_client import gemini_generate_chat, gemini_generate_text, use_gemini
12
  from app.shared_models import (
13
- causal_lm_model_id_task_a,
14
  embedding_model_name_task_a,
15
- get_causal_lm,
16
  get_embedder,
17
- inference_lock,
18
  )
19
  from app.task_a_rag import TaskAReviewRagIndex
20
  from app.user_modeling_prompt import build_prompt_parts_with_rag
@@ -45,14 +42,13 @@ def _resolve_path(raw: str) -> Path:
45
 
46
 
47
  def _task_a_gen_step() -> str:
48
- return "gemini_generate" if use_gemini() else "local_hf_causal_lm"
49
 
50
 
51
  class UserModelingService:
52
  def __init__(self) -> None:
53
  self._max_tokens = int(os.environ.get("TASK_A_MAX_TOKENS", "1024"))
54
  self._temperature = float(os.environ.get("TASK_A_TEMPERATURE", "0.35"))
55
- self._local_llm_model_id = causal_lm_model_id_task_a()
56
  self._embedding_model_name = embedding_model_name_task_a()
57
  rag_raw = os.environ.get(
58
  "TASK_A_REVIEWS_EMBEDDED",
@@ -74,9 +70,6 @@ class UserModelingService:
74
  if self._rag_path.is_file():
75
  self._rag().load()
76
 
77
- def _ensure_local_llm(self) -> tuple[Any, Any, Any]:
78
- return get_causal_lm(self._local_llm_model_id)
79
-
80
  def _retrieve_examples(self, persona: str, product: str) -> list[dict[str, Any]]:
81
  if not self._rag_path.is_file():
82
  return []
@@ -86,14 +79,14 @@ class UserModelingService:
86
 
87
  def _generate(self, persona: str, product: str, examples: list[dict[str, Any]]) -> str:
88
  inst, user_body = build_prompt_parts_with_rag(persona, product, examples)
89
- if use_gemini():
90
- return gemini_generate_text(
91
- system_instruction=inst,
92
- user_text=user_body,
93
- temperature=self._temperature,
94
- max_output_tokens=min(int(self._max_tokens), 1024),
95
- )
96
- return self._generate_local(persona, product, examples)
97
 
98
  def _generate_fix(
99
  self, persona: str, product: str, prior_raw: str, examples: list[dict[str, Any]]
@@ -103,88 +96,19 @@ class UserModelingService:
103
  "Your answer must follow exactly:\nStars: <1-5>\nReview:\n<text>\n\n"
104
  "The Review must be first person (I/my/me), as the user who visited — not third person. Fix strictly."
105
  )
106
- if use_gemini():
107
- return gemini_generate_chat(
108
- [
109
- {"role": "system", "content": inst},
110
- {"role": "user", "content": user_body},
111
- {"role": "assistant", "content": prior_raw},
112
- {"role": "user", "content": fix_user},
113
- ],
114
- temperature=0.2,
115
- max_output_tokens=min(int(self._max_tokens), 1024),
116
- )
117
- return self._generate_local_fix(persona, product, prior_raw, examples)
118
-
119
- def _generate_local(self, persona: str, product: str, examples: list[dict[str, Any]]) -> str:
120
- tok, mdl, device = self._ensure_local_llm()
121
- inst, user_body = build_prompt_parts_with_rag(persona, product, examples)
122
- messages = [
123
- {"role": "system", "content": inst},
124
- {"role": "user", "content": user_body},
125
- ]
126
- prompt_txt = tok.apply_chat_template(
127
- messages,
128
- tokenize=False,
129
- add_generation_prompt=True,
130
- )
131
- try:
132
- import torch # type: ignore[import-untyped]
133
- except ImportError as e:
134
- raise RuntimeError("Task 1 needs torch.") from e
135
-
136
- inputs = tok(prompt_txt, return_tensors="pt").to(device)
137
- if tok.pad_token_id is None:
138
- tok.pad_token_id = tok.eos_token_id
139
-
140
- max_new = min(int(self._max_tokens), 768)
141
- with inference_lock(), torch.no_grad():
142
- out = mdl.generate(
143
- **inputs,
144
- max_new_tokens=max_new,
145
- do_sample=True,
146
- temperature=self._temperature,
147
- top_p=0.9,
148
- pad_token_id=tok.pad_token_id,
149
- )
150
- gen_ids = out[0][inputs["input_ids"].shape[1] :]
151
- return tok.decode(gen_ids, skip_special_tokens=True).strip()
152
-
153
- def _generate_local_fix(
154
- self, persona: str, product: str, prior_raw: str, examples: list[dict[str, Any]]
155
- ) -> str:
156
- tok, mdl, device = self._ensure_local_llm()
157
- inst, user_body = build_prompt_parts_with_rag(persona, product, examples)
158
- fix_user = (
159
- "Your answer must follow exactly:\nStars: <1-5>\nReview:\n<text>\n\n"
160
- "The Review must be first person (I/my/me), as the user who visited — not third person. Fix strictly."
161
- )
162
- messages = [
163
- {"role": "system", "content": inst},
164
- {"role": "user", "content": user_body},
165
- {"role": "assistant", "content": prior_raw},
166
- {"role": "user", "content": fix_user},
167
- ]
168
- prompt_txt = tok.apply_chat_template(
169
- messages,
170
- tokenize=False,
171
- add_generation_prompt=True,
172
  )
173
- import torch # type: ignore[import-untyped]
174
-
175
- inputs = tok(prompt_txt, return_tensors="pt").to(device)
176
- if tok.pad_token_id is None:
177
- tok.pad_token_id = tok.eos_token_id
178
- max_new = min(int(self._max_tokens), 768)
179
- with inference_lock(), torch.no_grad():
180
- out = mdl.generate(
181
- **inputs,
182
- max_new_tokens=max_new,
183
- do_sample=False,
184
- pad_token_id=tok.pad_token_id,
185
- )
186
- gen_ids = out[0][inputs["input_ids"].shape[1] :]
187
- return tok.decode(gen_ids, skip_special_tokens=True).strip()
188
 
189
  def generate(self, persona: str, product: str, *, include_raw: bool = False) -> dict[str, Any]:
190
  t0 = time.perf_counter()
 
10
  from app._paths import submission_root
11
  from app.gemini_client import gemini_generate_chat, gemini_generate_text, use_gemini
12
  from app.shared_models import (
 
13
  embedding_model_name_task_a,
 
14
  get_embedder,
 
15
  )
16
  from app.task_a_rag import TaskAReviewRagIndex
17
  from app.user_modeling_prompt import build_prompt_parts_with_rag
 
42
 
43
 
44
  def _task_a_gen_step() -> str:
45
+ return "gemini_generate"
46
 
47
 
48
  class UserModelingService:
49
  def __init__(self) -> None:
50
  self._max_tokens = int(os.environ.get("TASK_A_MAX_TOKENS", "1024"))
51
  self._temperature = float(os.environ.get("TASK_A_TEMPERATURE", "0.35"))
 
52
  self._embedding_model_name = embedding_model_name_task_a()
53
  rag_raw = os.environ.get(
54
  "TASK_A_REVIEWS_EMBEDDED",
 
70
  if self._rag_path.is_file():
71
  self._rag().load()
72
 
 
 
 
73
  def _retrieve_examples(self, persona: str, product: str) -> list[dict[str, Any]]:
74
  if not self._rag_path.is_file():
75
  return []
 
79
 
80
  def _generate(self, persona: str, product: str, examples: list[dict[str, Any]]) -> str:
81
  inst, user_body = build_prompt_parts_with_rag(persona, product, examples)
82
+ if not use_gemini():
83
+ raise RuntimeError("Task 1 requires Gemini for generation.")
84
+ return gemini_generate_text(
85
+ system_instruction=inst,
86
+ user_text=user_body,
87
+ temperature=self._temperature,
88
+ max_output_tokens=min(int(self._max_tokens), 1024),
89
+ )
90
 
91
  def _generate_fix(
92
  self, persona: str, product: str, prior_raw: str, examples: list[dict[str, Any]]
 
96
  "Your answer must follow exactly:\nStars: <1-5>\nReview:\n<text>\n\n"
97
  "The Review must be first person (I/my/me), as the user who visited — not third person. Fix strictly."
98
  )
99
+ if not use_gemini():
100
+ raise RuntimeError("Task 1 requires Gemini for generation.")
101
+ return gemini_generate_chat(
102
+ [
103
+ {"role": "system", "content": inst},
104
+ {"role": "user", "content": user_body},
105
+ {"role": "assistant", "content": prior_raw},
106
+ {"role": "user", "content": fix_user},
107
+ ],
108
+ temperature=0.2,
109
+ max_output_tokens=min(int(self._max_tokens), 1024),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  )
111
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  def generate(self, persona: str, product: str, *, include_raw: bool = False) -> dict[str, Any]:
114
  t0 = time.perf_counter()