bditto committed on
Commit
e3cea63
·
verified ·
1 Parent(s): c31d961

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -40
app.py CHANGED
@@ -1,43 +1,21 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import (
4
- AutoModelForCausalLM,
5
- AutoTokenizer,
6
- TextIteratorStreamer,
7
- pipeline,
8
- BitsAndBytesConfig
9
- )
10
  from threading import Thread
11
  import random
12
 
13
- # Configuration 🛠
14
  model_name = "HuggingFaceH4/zephyr-7b-beta"
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
- # Quantization setup
18
- quantization_config = BitsAndBytesConfig(
19
- load_in_4bit=True,
20
- bnb_4bit_quant_type="nf4",
21
- bnb_4bit_compute_dtype=torch.float16,
22
- bnb_4bit_use_double_quant=True,
23
  )
24
 
25
- # Model loading with fallback
26
- try:
27
- model = AutoModelForCausalLM.from_pretrained(
28
- model_name,
29
- quantization_config=quantization_config if device == "cuda" else None,
30
- device_map="auto",
31
- torch_dtype=torch.float16 if device == "cuda" else torch.float32
32
- )
33
- except Exception as e:
34
- print(f"Error loading model with GPU: {e}")
35
- model = AutoModelForCausalLM.from_pretrained(
36
- model_name,
37
- device_map="cpu",
38
- torch_dtype=torch.float32
39
- )
40
-
41
  tokenizer = AutoTokenizer.from_pretrained(model_name)
42
 
43
  # Safety tools 🛡️
@@ -47,11 +25,7 @@ SAFE_IDEAS = [
47
  "Code a game about recycling ♻️",
48
  "Plan an AI tool for school safety 🚸"
49
  ]
50
- safety_checker = pipeline(
51
- "text-classification",
52
- model="unitary/toxic-bert",
53
- device=0 if device == "cuda" else -1
54
- )
55
 
56
  def is_safe(text):
57
  text = text.lower()
@@ -66,7 +40,7 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
66
 
67
  messages = [{"role": "system", "content": system_message}]
68
 
69
- for user_msg, bot_msg in history[-5:]:
70
  if user_msg:
71
  messages.append({"role": "user", "content": user_msg})
72
  if bot_msg:
@@ -82,7 +56,7 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
82
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
83
  generation_kwargs = {
84
  "inputs": inputs,
85
- "max_new_tokens": max_tokens,
86
  "temperature": temperature,
87
  "top_p": top_p,
88
  "streamer": streamer
@@ -102,9 +76,9 @@ with gr.Blocks() as demo:
102
  respond,
103
  additional_inputs=[
104
  gr.Textbox("You help students create ethical AI projects.", label="Guidelines"),
105
- gr.Slider(128, 1024, value=512, label="Max Response Length"),
106
- gr.Slider(0.1, 1.0, value=0.3, label="Creativity Level"),
107
- gr.Slider(0.7, 1.0, value=0.85, label="Focus Level")
108
  ],
109
  examples=[
110
  ["How to build a robot that plants trees?"],
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
 
 
 
 
 
 
4
  from threading import Thread
5
  import random
6
 
7
+ # Use CPU-friendly configuration 🖥
8
  model_name = "HuggingFaceH4/zephyr-7b-beta"
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
+ # Load model with CPU optimization
12
+ model = AutoModelForCausalLM.from_pretrained(
13
+ model_name,
14
+ device_map="auto",
15
+ torch_dtype=torch.float32,
16
+ low_cpu_mem_usage=True
17
  )
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  tokenizer = AutoTokenizer.from_pretrained(model_name)
20
 
21
  # Safety tools 🛡️
 
25
  "Code a game about recycling ♻️",
26
  "Plan an AI tool for school safety 🚸"
27
  ]
28
+ safety_checker = pipeline("text-classification", model="unitary/toxic-bert")
 
 
 
 
29
 
30
  def is_safe(text):
31
  text = text.lower()
 
40
 
41
  messages = [{"role": "system", "content": system_message}]
42
 
43
+ for user_msg, bot_msg in history[-3:]: # Reduce history length for CPU
44
  if user_msg:
45
  messages.append({"role": "user", "content": user_msg})
46
  if bot_msg:
 
56
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
57
  generation_kwargs = {
58
  "inputs": inputs,
59
+ "max_new_tokens": min(max_tokens, 256), # Limit tokens for CPU
60
  "temperature": temperature,
61
  "top_p": top_p,
62
  "streamer": streamer
 
76
  respond,
77
  additional_inputs=[
78
  gr.Textbox("You help students create ethical AI projects.", label="Guidelines"),
79
+ gr.Slider(64, 512, value=256, label="Max Response Length"),
80
+ gr.Slider(0.1, 1.0, value=0.5, label="Creativity Level"),
81
+ gr.Slider(0.7, 1.0, value=0.9, label="Focus Level")
82
  ],
83
  examples=[
84
  ["How to build a robot that plants trees?"],