SallySims commited on
Commit
e80f308
·
verified ·
1 Parent(s): 38f3073

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -15
app.py CHANGED
@@ -16,20 +16,27 @@ st.set_page_config(page_title="AnthroBot", page_icon="🤖", layout="centered")
16
 
17
  # Load model & tokenizer
18
  @st.cache_resource
19
- def load_model():
20
- peft_config = PeftConfig.from_pretrained("SallySims/AnthroBot_Model_Lora")
21
- base_model = AutoModelForCausalLM.from_pretrained(
22
- peft_config.base_model_name_or_path,
23
- torch_dtype=torch.float16,
24
- device_map="auto"
25
- )
26
- model = PeftModel.from_pretrained(base_model, "SallySims/AnthroBot_Model_Lora")
27
- model.eval()
28
-
29
- tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
30
- tokenizer.pad_token = tokenizer.eos_token
31
- return model, tokenizer
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  model, tokenizer = load_model()
34
 
35
  # Prediction function
@@ -101,5 +108,3 @@ with tab2:
101
 
102
  csv_output = df.to_csv(index=False).encode("utf-8")
103
  st.download_button("📤 Download Predictions", data=csv_output, file_name="predictions.csv")
104
-
105
-
 
16
 
17
  # Load model & tokenizer
18
  @st.cache_resource
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
+ def load_model():
21
+ try:
22
+ peft_config = PeftConfig.from_pretrained("SallySims/AnthroBot_Model_Lora")
23
+ base_model = AutoModelForCausalLM.from_pretrained(
24
+ peft_config.base_model_name_or_path,
25
+ torch_dtype=torch.float16,
26
+ device_map="auto"
27
+ )
28
+ model = PeftModel.from_pretrained(base_model, "SallySims/AnthroBot_Model_Lora")
29
+ model.eval()
30
+
31
+ tokenizer = AutoTokenizer.from_pretrained(peft_config.base_model_name_or_path)
32
+ tokenizer.pad_token = tokenizer.eos_token
33
+
34
+ st.write("✅ Model and tokenizer loaded successfully.")
35
+ return model, tokenizer
36
+
37
+ except Exception as e:
38
+ st.error(f"Error loading model: {str(e)}")
39
+ raise e
40
  model, tokenizer = load_model()
41
 
42
  # Prediction function
 
108
 
109
  csv_output = df.to_csv(index=False).encode("utf-8")
110
  st.download_button("📤 Download Predictions", data=csv_output, file_name="predictions.csv")