Maira-ghaffar committed on
Commit
ba4fc63
·
verified ·
1 Parent(s): 91b51ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -1
app.py CHANGED
@@ -1,9 +1,28 @@
1
  import streamlit as st
 
 
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
  st.title("📚 AI Adaptive Learning (Local LLaMA)")
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  # Load model & tokenizer
8
  MODEL_ID = "TheBloke/vicuna-7B-1.1-HF" # smaller public LLaMA-like model
9
  @st.cache_resource # caches model to avoid reload on every run
@@ -26,4 +45,5 @@ if st.button("Submit") and user_input:
26
  outputs = model.generate(**inputs, max_new_tokens=256)
27
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
28
  st.subheader("AI Answer:")
29
- st.write(answer)
 
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import torch
4
+ from huggingface_hub import loginimport streamlit as st
5
  from transformers import AutoModelForCausalLM, AutoTokenizer
6
  import torch
7
 
8
  st.title("📚 AI Adaptive Learning (Local LLaMA)")
9
 
10
+ # Hugging Face authentication for private models
11
+ HF_API_TOKEN = st.secrets["HF_API_TOKEN"]
12
+ login(token=HF_API_TOKEN)
13
+ MODEL_ID = "meta-llama/Llama-2-7b-chat-hf" # or your private model
14
+
15
+ @st.cache_resource
16
+ def load_model():
17
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_auth_token=HF_API_TOKEN)
18
+ model = AutoModelForCausalLM.from_pretrained(
19
+ MODEL_ID,
20
+ use_auth_token=HF_API_TOKEN,
21
+ device_map="auto", # uses GPU if available, CPU otherwise
22
+ )
23
+ return tokenizer, model
24
+
25
+ tokenizer, model = load_model()
26
  # Load model & tokenizer
27
  MODEL_ID = "TheBloke/vicuna-7B-1.1-HF" # smaller public LLaMA-like model
28
  @st.cache_resource # caches model to avoid reload on every run
 
45
  outputs = model.generate(**inputs, max_new_tokens=256)
46
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
47
  st.subheader("AI Answer:")
48
+ st.write(answer)
49
+ user_input = st.text_input("Ask a question:")