chris1nexus commited on
Commit
4f37062
·
1 Parent(s): 1e6b325

Update default category

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +11 -1
src/streamlit_app.py CHANGED
@@ -51,6 +51,7 @@ class LLMadapter(BaseAdapter):
51
  #assert model_name in model_list, f'{model_name} not found for provider: {provider}\nAvailable models:\n{model_list}'
52
  self.adapter = LLMadapter.get_provider_class(provider)(model_name)
53
  self.system = system
 
54
  def generate(self, prompt, image):
55
  out = self.adapter.generate(prompt=prompt, image=image, system=self.system)
56
  return out
@@ -267,9 +268,18 @@ with st.sidebar:
267
 
268
  target_mode = st.selectbox("Target category mode", ["Pick specific", "Random each time"], index=0)
269
  if target_mode == "Pick specific":
 
 
 
 
 
 
 
 
270
  target_category = st.selectbox(
271
  "Target category",
272
- st.session_state.categories if st.session_state.categories else ["(load TSV first)"]
 
273
  )
274
  chosen_target = target_category if st.session_state.categories else None
275
  else:
 
51
  #assert model_name in model_list, f'{model_name} not found for provider: {provider}\nAvailable models:\n{model_list}'
52
  self.adapter = LLMadapter.get_provider_class(provider)(model_name)
53
  self.system = system
54
+
55
  def generate(self, prompt, image):
56
  out = self.adapter.generate(prompt=prompt, image=image, system=self.system)
57
  return out
 
268
 
269
  target_mode = st.selectbox("Target category mode", ["Pick specific", "Random each time"], index=0)
270
  if target_mode == "Pick specific":
271
+
272
+ DEFAULT_CAT = "bus" # normalized label
273
+
274
+ if cats and DEFAULT_CAT in cats:
275
+ default_idx = cats.index(DEFAULT_CAT)
276
+ else:
277
+ default_idx = 0 # fallback
278
+
279
  target_category = st.selectbox(
280
  "Target category",
281
+ cats,
282
+ index=default_idx,
283
  )
284
  chosen_target = target_category if st.session_state.categories else None
285
  else: