Springboardmen commited on
Commit
9ffd02f
·
verified ·
1 Parent(s): ad523fb

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +8 -6
src/streamlit_app.py CHANGED
@@ -7,8 +7,8 @@ st.set_page_config(page_title="FitPlan AI", layout="centered")
7
  # LOAD MODEL
8
  @st.cache_resource
9
  def load_model():
10
- tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
11
- model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
12
  return tokenizer, model
13
 
14
  tokenizer, model = load_model()
@@ -182,10 +182,12 @@ Equipment: {equipment_list}
182
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
183
 
184
  outputs = model.generate(
185
- **inputs,
186
- max_new_tokens=300,
187
- temperature=0.7,
188
- do_sample=True
 
 
189
  )
190
 
191
  result = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
 
7
  # LOAD MODEL
8
  @st.cache_resource
9
  def load_model():
10
+ tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
11
+ model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
12
  return tokenizer, model
13
 
14
  tokenizer, model = load_model()
 
182
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
183
 
184
  outputs = model.generate(
185
+ inputs,
186
+ max_new_tokens=600,
187
+ temperature=0.3,
188
+ do_sample=False,
189
+ top_p=0.9,
190
+ repetition_penalty=1.2
191
  )
192
 
193
  result = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()