NotRev committed on
Commit
45fc6e6
·
verified ·
1 Parent(s): da762f1

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +13 -7
src/streamlit_app.py CHANGED
@@ -1,26 +1,32 @@
1
  import json, re, ast, streamlit as st
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
  import torch
 
4
 
5
- # SWITCHED MODEL: From Mistral-7B to the much smaller Gemma-2B-Instruct
6
  model_id = "google/gemma-2b-it"
7
 
8
- tok = AutoTokenizer.from_pretrained(model_id)
 
9
 
10
- # Simplified Model Loading: Removed BitsAndBytesConfig
11
- # This smaller model might load cleanly without 4-bit quantization, resolving the dependency issues.
 
12
  try:
 
13
  model = AutoModelForCausalLM.from_pretrained(
14
  model_id,
15
  torch_dtype=torch.bfloat16,
16
- device_map="auto"
 
17
  )
18
  except Exception:
19
- # Fallback to float16 if bfloat16 causes issues
20
  model = AutoModelForCausalLM.from_pretrained(
21
  model_id,
22
  torch_dtype=torch.float16,
23
- device_map="auto"
 
24
  )
25
 
26
  gen = pipeline("text-generation", model=model, tokenizer=tok,
 
1
import json, re, ast, streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import os  # needed to read the HF_TOKEN secret from the environment

# Model ID for the small, structured Gemma model.
model_id = "google/gemma-2b-it"

# Hugging Face access token, injected via the Space's secrets.
# None when unset; gated models such as Gemma require a valid token.
HF_TOKEN = os.environ.get("HF_TOKEN")


@st.cache_resource
def _load_model_and_tokenizer():
    """Load the tokenizer and model exactly once per process.

    Streamlit re-executes the entire script on every user interaction;
    without ``st.cache_resource`` the multi-GB model would be re-fetched
    and re-instantiated on each rerun. Caching the loaded objects keeps
    reruns fast and memory usage bounded.

    Returns:
        tuple: ``(tokenizer, model)`` for ``model_id``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
    try:
        # Prefer bfloat16: float32's dynamic range at half the memory.
        loaded = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            token=HF_TOKEN,
        )
    except Exception:
        # Fallback to float16 on hardware/builds without bfloat16 support.
        # Deliberately broad: any bfloat16 load failure triggers the retry.
        loaded = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            token=HF_TOKEN,
        )
    return tokenizer, loaded


# Keep the original module-level names so downstream code (e.g. the
# `pipeline(...)` call below) continues to work unchanged.
tok, model = _load_model_and_tokenizer()
32
  gen = pipeline("text-generation", model=model, tokenizer=tok,