ehejin commited on
Commit
a40c8e4
·
1 Parent(s): 9ad9cc9

model_name error fix

Browse files
Files changed (1) hide show
  1. src/ui/screens_shared.py +12 -12
src/ui/screens_shared.py CHANGED
@@ -307,8 +307,9 @@ def screen_chat(s: dict, cfg: dict) -> None:
307
  messages.append({"role": t["role"], "content": t["content"]})
308
  messages.append({"role": "user", "content": user_msg})
309
 
 
 
310
  with st.spinner("AI is responding…"):
311
- item_cfg = {**cfg, "model_name": item["model_name"]}
312
  ai_reply = call_model(messages, item_cfg)
313
 
314
  now = time.time()
@@ -324,9 +325,9 @@ def screen_chat(s: dict, cfg: dict) -> None:
324
  "role": "assistant",
325
  "content": ai_reply,
326
  "timestamp": now,
327
- "model": cfg["model_name"],
328
  })
329
- conv["num_turns"] = num_turns + 1
330
  s["items"][idx]["conversation"] = conv
331
  st.rerun()
332
 
@@ -364,7 +365,7 @@ def screen_post_rating(s: dict, cfg: dict) -> None:
364
  render_single_card(item["product"])
365
  question = "How **likely** are you to purchase this product now?"
366
 
367
- choices = rating_choices(study_type)
368
  post_val = st.radio(question, choices, index=None, key=f"post_rating_{idx}")
369
 
370
  if st.button("Next →", type="primary", use_container_width=True):
@@ -436,7 +437,6 @@ def screen_reflection(s: dict, cfg: dict) -> None:
436
  s["meta"] = {
437
  "submission_time": end_time,
438
  "duration_seconds": round(end_time - s.get("start_time", end_time), 1),
439
- "model": cfg["model_name"],
440
  "study_type": cfg["study_type"],
441
  }
442
  with st.spinner("Saving your responses…"):
@@ -473,19 +473,19 @@ def screen_done(s: dict, cfg: dict) -> None:
473
 
474
  if study_type == "preference":
475
  rows.append({
476
- "#": i + 1,
477
- "Category": cat,
478
- "Product A": (item.get("product_a", {}).get("title", "") or "")[:45] + "…",
479
- "Product B": (item.get("product_b", {}).get("title", "") or "")[:45] + "…",
480
  "Pre-rating": labels.get(pre, str(pre) if pre is not None else "—"),
481
  "Post-rating": labels.get(post, str(post) if post is not None else "—"),
482
  "Shift": f"{arrow} {delta:+d}" if isinstance(delta, int) else "—",
483
  })
484
  else:
485
  rows.append({
486
- "#": i + 1,
487
- "Category": cat,
488
- "Product": (item.get("product", {}).get("title", "") or "")[:65] + "…",
489
  "Pre-rating": labels.get(pre, str(pre) if pre is not None else "—"),
490
  "Post-rating": labels.get(post, str(post) if post is not None else "—"),
491
  "Shift": f"{arrow} {delta:+d}" if isinstance(delta, int) else "—",
 
307
  messages.append({"role": t["role"], "content": t["content"]})
308
  messages.append({"role": "user", "content": user_msg})
309
 
310
+ # Use per-item model name
311
+ item_cfg = {**cfg, "model_name": item.get("model_name", "")}
312
  with st.spinner("AI is responding…"):
 
313
  ai_reply = call_model(messages, item_cfg)
314
 
315
  now = time.time()
 
325
  "role": "assistant",
326
  "content": ai_reply,
327
  "timestamp": now,
328
+ "model": item.get("model_name", ""), # ← per-item, not cfg
329
  })
330
+ conv["num_turns"] = num_turns + 1
331
  s["items"][idx]["conversation"] = conv
332
  st.rerun()
333
 
 
365
  render_single_card(item["product"])
366
  question = "How **likely** are you to purchase this product now?"
367
 
368
+ choices = rating_choices(study_type)
369
  post_val = st.radio(question, choices, index=None, key=f"post_rating_{idx}")
370
 
371
  if st.button("Next →", type="primary", use_container_width=True):
 
437
  s["meta"] = {
438
  "submission_time": end_time,
439
  "duration_seconds": round(end_time - s.get("start_time", end_time), 1),
 
440
  "study_type": cfg["study_type"],
441
  }
442
  with st.spinner("Saving your responses…"):
 
473
 
474
  if study_type == "preference":
475
  rows.append({
476
+ "#": i + 1,
477
+ "Category": cat,
478
+ "Product A": (item.get("product_a", {}).get("title", "") or "")[:45] + "…",
479
+ "Product B": (item.get("product_b", {}).get("title", "") or "")[:45] + "…",
480
  "Pre-rating": labels.get(pre, str(pre) if pre is not None else "—"),
481
  "Post-rating": labels.get(post, str(post) if post is not None else "—"),
482
  "Shift": f"{arrow} {delta:+d}" if isinstance(delta, int) else "—",
483
  })
484
  else:
485
  rows.append({
486
+ "#": i + 1,
487
+ "Category": cat,
488
+ "Product": (item.get("product", {}).get("title", "") or "")[:65] + "…",
489
  "Pre-rating": labels.get(pre, str(pre) if pre is not None else "—"),
490
  "Post-rating": labels.get(post, str(post) if post is not None else "—"),
491
  "Shift": f"{arrow} {delta:+d}" if isinstance(delta, int) else "—",