rufatronics committed on
Commit
72c6b5a
·
verified ·
1 Parent(s): d74a673

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -16
app.py CHANGED
@@ -3,34 +3,31 @@ from transformers import AutoModelForImageTextToText, AutoProcessor
3
  import torch
4
  import os
5
 
6
- # 1. Setup Model & Token
7
  model_id = "google/gemma-3n-E2B-it"
8
  hf_token = os.getenv("HF_TOKEN")
9
  device = "cpu"
10
 
11
- print("Loading Gemma 3n (10GB)... This takes a few minutes.")
12
 
13
- # We add low_cpu_mem_usage=True to prevent crashing on load
 
14
  processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
15
  model = AutoModelForImageTextToText.from_pretrained(
16
  model_id,
17
  token=hf_token,
18
- torch_dtype=torch.float32,
19
- low_cpu_mem_usage=True,
20
  device_map="auto"
21
  )
22
 
23
  def chat_function(message, history):
24
- # Prepare history for the model
25
  msgs = []
26
  for user_msg, assistant_msg in history:
27
  if user_msg: msgs.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
28
  if assistant_msg: msgs.append({"role": "model", "content": [{"type": "text", "text": assistant_msg}]})
29
 
30
- # Add new message
31
  msgs.append({"role": "user", "content": [{"type": "text", "text": message}]})
32
 
33
- # Apply template
34
  inputs = processor.apply_chat_template(
35
  msgs,
36
  add_generation_prompt=True,
@@ -38,8 +35,8 @@ def chat_function(message, history):
38
  return_tensors="pt"
39
  ).to(device)
40
 
41
- # Generate
42
- with torch.no_grad(): # Saves memory during generation
43
  outputs = model.generate(
44
  **inputs,
45
  max_new_tokens=400,
@@ -50,12 +47,7 @@ def chat_function(message, history):
50
  response = processor.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
51
  return response
52
 
53
- # 5. The Interface
54
- demo = gr.ChatInterface(
55
- fn=chat_function,
56
- title="Gemma 3n E2B (Fixed)",
57
- description="Now with 'timm' installed and optimized for 16GB RAM!",
58
- )
59
 
60
  if __name__ == "__main__":
61
  demo.launch()
 
3
  import torch
4
  import os
5
 
 
6
# --- Model & token setup -------------------------------------------------
# Gemma 3n (image+text) loaded for CPU inference with two memory guards:
#   * bfloat16 weights   -> roughly halves the RAM footprint vs float32
#   * low_cpu_mem_usage  -> avoids materializing a second full copy of the
#                           state dict while loading ("double loading")
model_id = "google/gemma-3n-E2B-it"
hf_token = os.getenv("HF_TOKEN")  # None when unset; gated repo then fails at download
device = "cpu"

print("Loading Gemma 3n with Memory Optimizations...")

processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    token=hf_token,
    torch_dtype=torch.bfloat16,  # KEY FIX: Half-precision for CPU
    low_cpu_mem_usage=True,      # KEY FIX: Don't use double RAM on load
    device_map="auto",
)
22
 
23
  def chat_function(message, history):
 
24
  msgs = []
25
  for user_msg, assistant_msg in history:
26
  if user_msg: msgs.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
27
  if assistant_msg: msgs.append({"role": "model", "content": [{"type": "text", "text": assistant_msg}]})
28
 
 
29
  msgs.append({"role": "user", "content": [{"type": "text", "text": message}]})
30
 
 
31
  inputs = processor.apply_chat_template(
32
  msgs,
33
  add_generation_prompt=True,
 
35
  return_tensors="pt"
36
  ).to(device)
37
 
38
+ # Note: Inference on CPU with bfloat16 is much safer for RAM
39
+ with torch.no_grad():
40
  outputs = model.generate(
41
  **inputs,
42
  max_new_tokens=400,
 
47
  response = processor.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
48
  return response
49
 
50
# Gradio chat UI wired to the model's chat handler; the server only starts
# when this file is executed directly (not when imported).
demo = gr.ChatInterface(fn=chat_function, title="Gemma 3n E2B (RAM Optimized)")

if __name__ == "__main__":
    demo.launch()