migrate to quantized models
- app.py +4 -4
- requirements.txt +2 -0
- utils.py +36 -3
app.py
CHANGED
@@ -22,8 +22,8 @@ AVAILABLE_MODELS = {
 # Initialize the social graph manager
 social_graph = SocialGraphManager("social_graph.json")

-# Initialize the suggestion generator with Gemma 3
-suggestion_generator = SuggestionGenerator("google/gemma-3-
+# Initialize the suggestion generator with Gemma 3 1B (default - smaller model to save memory)
+suggestion_generator = SuggestionGenerator("google/gemma-3-1b-it")

 # Test the model to make sure it's working
 test_result = suggestion_generator.test_model()
@@ -153,7 +153,7 @@ def generate_suggestions(
     user_input,
     suggestion_type,
     selected_topic=None,
-    model_name="google/gemma-3-
+    model_name="google/gemma-3-1b-it",
     temperature=0.7,
     mood=3,
     progress=gr.Progress(),
@@ -462,7 +462,7 @@ with gr.Blocks(title="Will's AAC Communication Aid", css="custom.css") as demo:
     with gr.Row():
         model_dropdown = gr.Dropdown(
             choices=list(AVAILABLE_MODELS.keys()),
-            value="google/gemma-3-
+            value="google/gemma-3-1b-it",
             label="Language Model",
             info="Select which AI model to use for generating responses",
         )
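The default model is now the 1B instruction-tuned Gemma. A quick smoke test of the new default outside Gradio (a minimal sketch, not part of the diff; the `from utils import SuggestionGenerator` path and printing the return value of `test_model()` are assumptions):

# Same initialization app.py now performs, runnable on its own.
from utils import SuggestionGenerator  # assumed import path for the class defined in utils.py

suggestion_generator = SuggestionGenerator("google/gemma-3-1b-it")
print(suggestion_generator.test_model())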
requirements.txt
CHANGED
@@ -4,3 +4,5 @@ sentence-transformers>=2.2.2
 torch>=2.0.0
 numpy>=1.24.0
 openai-whisper>=20231117
+bitsandbytes>=0.41.0
+accelerate>=0.21.0
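The two new packages back the quantized loading path added to utils.py below: bitsandbytes supplies the 4-bit kernels and accelerate is what makes device_map="auto" placement work. An optional sanity check (not part of the diff) that the dependencies resolve before launching the app:

import importlib.metadata as metadata

# Print installed versions of the packages the quantized path depends on.
for pkg in ("torch", "transformers", "bitsandbytes", "accelerate"):
    try:
        print(f"{pkg} {metadata.version(pkg)}")
    except metadata.PackageNotFoundError:
        print(f"{pkg} missing - run: pip install -r requirements.txt")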
utils.py
CHANGED
@@ -216,6 +216,8 @@ class SuggestionGenerator:
         if is_gated_model:
             # Try to get token from environment
             import os
+            import torch
+            from transformers import BitsAndBytesConfig

             token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get(
                 "HF_TOKEN"
@@ -231,14 +233,31 @@
                 from transformers import AutoTokenizer, AutoModelForCausalLM

                 try:
+                    # Configure 4-bit quantization to save memory
+                    quantization_config = BitsAndBytesConfig(
+                        load_in_4bit=True,
+                        bnb_4bit_compute_dtype=torch.float16,
+                        bnb_4bit_quant_type="nf4",
+                        bnb_4bit_use_double_quant=True,
+                    )
+
                     tokenizer = AutoTokenizer.from_pretrained(
                         model_name, token=token
                     )
+
+                    # Load model with quantization
                     model = AutoModelForCausalLM.from_pretrained(
-                        model_name,
+                        model_name,
+                        token=token,
+                        quantization_config=quantization_config,
+                        device_map="auto",
                     )
+
                     self.generator = pipeline(
-                        "text-generation",
+                        "text-generation",
+                        model=model,
+                        tokenizer=tokenizer,
+                        torch_dtype=torch.float16,
                     )
                 except Exception as e:
                     print(f"Error loading gated model with token: {e}")
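For reference, the same NF4 loading pattern can be exercised outside SuggestionGenerator. This is a minimal sketch under stated assumptions: a CUDA GPU (bitsandbytes needs one), the new requirements installed, and an HF token with access to the gated Gemma weights; the prompt string is made up for illustration.

import os
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    pipeline,
)

model_name = "google/gemma-3-1b-it"
token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")

# Same config as the diff: 4-bit NF4 weights, fp16 compute, double quantization.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=token,
    quantization_config=quantization_config,
    device_map="auto",  # accelerate places layers on the available devices
)
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

print(generator("Hello, how are you today?", max_new_tokens=30)[0]["generated_text"])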
@@ -248,7 +267,21 @@
                     print(
                         "Please visit the model page on Hugging Face Hub and accept the license."
                     )
-
+                    # Try loading without quantization as fallback
+                    try:
+                        print("Trying to load model without quantization...")
+                        tokenizer = AutoTokenizer.from_pretrained(
+                            model_name, token=token
+                        )
+                        model = AutoModelForCausalLM.from_pretrained(
+                            model_name, token=token
+                        )
+                        self.generator = pipeline(
+                            "text-generation", model=model, tokenizer=tokenizer
+                        )
+                    except Exception as e2:
+                        print(f"Fallback loading also failed: {e2}")
+                        raise e
                 else:
                     print("No Hugging Face token found in environment variables.")
                     print(
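Note that the fallback loads the full-precision weights, so memory use is much higher than on the quantized path. A quick way to confirm which path was taken after the app starts, sketched against the suggestion_generator from app.py (assumption: a recent transformers release, where quantized models carry is_loaded_in_4bit and every model exposes get_memory_footprint()):

# Inspect the model behind the pipeline to see which loading path succeeded.
model = suggestion_generator.generator.model
print("loaded in 4-bit:", getattr(model, "is_loaded_in_4bit", False))
print(f"weight footprint: {model.get_memory_footprint() / 1e9:.2f} GB")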