sourize committed on
Commit
a6691ab
·
verified ·
1 Parent(s): 0d312d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -14
app.py CHANGED
@@ -1,19 +1,16 @@
1
  import os
2
  import streamlit as st
 
3
  from transformers import (
4
- pipeline, AutoTokenizer, AutoModelForCausalLM
 
 
 
5
  )
6
  from peft import PeftModel
7
  from supabase import create_client
8
  from sentence_transformers import SentenceTransformer
9
 
10
- # Try to import bitsandbytes for 4-bit; fall back if missing
11
- try:
12
- from transformers import BitsAndBytesConfig
13
- BNB_AVAILABLE = True
14
- except ImportError:
15
- BNB_AVAILABLE = False
16
-
17
  # ── Supabase setup ─────────────────────────────────────────────────────────
18
  SUPA_URL = os.getenv("SUPABASE_URL")
19
  SUPA_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY")
@@ -28,14 +25,14 @@ embedder = get_embedder()
28
 
29
  @st.cache_data(show_spinner=False)
30
  def fetch_mems(query, k=3):
31
- vec = embedder.encode(query).astype('float32').tolist()
32
  return supabase.rpc(
33
- "match_memories",
34
  {"query_embedding": vec, "match_count": k}
35
  ).execute().data
36
 
37
  def add_mem(speaker, text):
38
- vec = embedder.encode(text).astype('float32').tolist()
39
  supabase.table("memories").insert({
40
  "speaker": speaker, "text": text, "embedding": vec
41
  }).execute()
@@ -44,6 +41,7 @@ def add_mem(speaker, text):
44
  @st.cache_resource(show_spinner=False)
45
  def load_generator():
46
  REPO = "sourize/phi2-memory-lora"
 
47
  # 1) Tokenizer
48
  tokenizer = AutoTokenizer.from_pretrained(
49
  REPO, trust_remote_code=True, padding_side="left"
@@ -51,8 +49,9 @@ def load_generator():
51
  if tokenizer.pad_token_id is None:
52
  tokenizer.add_special_tokens({"pad_token": "[PAD]"})
53
 
54
- # 2) Base model load (with or without 4-bit)
55
- if BNB_AVAILABLE:
 
56
  bnb_config = BitsAndBytesConfig(
57
  load_in_4bit=True,
58
  bnb_4bit_quant_type="nf4",
@@ -66,10 +65,12 @@ def load_generator():
66
  device_map="auto"
67
  )
68
  else:
 
 
69
  base = AutoModelForCausalLM.from_pretrained(
70
  "microsoft/phi-2",
71
  trust_remote_code=True,
72
- torch_dtype="auto",
73
  device_map="auto"
74
  )
75
 
 
1
  import os
2
  import streamlit as st
3
+ import torch
4
  from transformers import (
5
+ pipeline,
6
+ AutoTokenizer,
7
+ AutoModelForCausalLM,
8
+ BitsAndBytesConfig,
9
  )
10
  from peft import PeftModel
11
  from supabase import create_client
12
  from sentence_transformers import SentenceTransformer
13
 
 
 
 
 
 
 
 
14
  # ── Supabase setup ─────────────────────────────────────────────────────────
15
  SUPA_URL = os.getenv("SUPABASE_URL")
16
  SUPA_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY")
 
25
 
26
  @st.cache_data(show_spinner=False)
27
  def fetch_mems(query, k=3):
28
+ vec = embedder.encode(query).astype("float32").tolist()
29
  return supabase.rpc(
30
+ "match_memories",
31
  {"query_embedding": vec, "match_count": k}
32
  ).execute().data
33
 
34
  def add_mem(speaker, text):
35
+ vec = embedder.encode(text).astype("float32").tolist()
36
  supabase.table("memories").insert({
37
  "speaker": speaker, "text": text, "embedding": vec
38
  }).execute()
 
41
  @st.cache_resource(show_spinner=False)
42
  def load_generator():
43
  REPO = "sourize/phi2-memory-lora"
44
+
45
  # 1) Tokenizer
46
  tokenizer = AutoTokenizer.from_pretrained(
47
  REPO, trust_remote_code=True, padding_side="left"
 
49
  if tokenizer.pad_token_id is None:
50
  tokenizer.add_special_tokens({"pad_token": "[PAD]"})
51
 
52
+ # 2) Decide quantization vs. fp16/fp32
53
+ use_4bit = torch.cuda.is_available()
54
+ if use_4bit:
55
  bnb_config = BitsAndBytesConfig(
56
  load_in_4bit=True,
57
  bnb_4bit_quant_type="nf4",
 
65
  device_map="auto"
66
  )
67
  else:
68
+ # CPU or no CUDA: use fp16 if available, else fp32
69
+ dtype = torch.float16 if torch.cuda.is_available() or torch.cuda.device_count()>0 else torch.float32
70
  base = AutoModelForCausalLM.from_pretrained(
71
  "microsoft/phi-2",
72
  trust_remote_code=True,
73
+ torch_dtype=dtype,
74
  device_map="auto"
75
  )
76