bh4vay commited on
Commit
cbd648d
·
verified ·
1 Parent(s): 2a03292

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -7,30 +7,42 @@ from diffusers import StableDiffusionPipeline
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
  st.write(f"Using device: {device}") # Debug message
9
 
10
- # Load text model (TinyLlama) with error handling
11
  @st.cache_resource
12
  def load_text_model():
13
  try:
14
  st.write("⏳ Loading text model...")
15
  model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
- model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
 
 
 
 
 
 
 
18
  st.write("✅ Text model loaded successfully!")
19
- return pipeline("text-generation", model=model, tokenizer=tokenizer)
 
20
  except Exception as e:
21
  st.error(f"❌ Error loading text model: {e}")
22
  return None
23
 
24
  story_generator = load_text_model()
25
 
26
- # Load image model (Stable Diffusion) with error handling
27
  @st.cache_resource
28
  def load_image_model():
29
  try:
30
  st.write("⏳ Loading image model...")
31
  model_id = "runwayml/stable-diffusion-v1-5"
32
- model = StableDiffusionPipeline.from_pretrained(model_id).to(device)
33
- model.enable_attention_slicing() # Optimize memory usage
 
 
 
34
  st.write("✅ Image model loaded successfully!")
35
  return model
36
  except Exception as e:
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
  st.write(f"Using device: {device}") # Debug message
9
 
10
# Load text model (TinyLlama) with optimizations
@st.cache_resource
def load_text_model():
    """Load TinyLlama and wrap it in a Hugging Face text-generation pipeline.

    Returns:
        A ``transformers`` text-generation pipeline on success, or ``None``
        when loading fails (the error is surfaced in the Streamlit UI
        rather than raised, so the app keeps running).
    """
    try:
        st.write("⏳ Loading text model...")
        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # FP16 on CUDA halves VRAM usage; low_cpu_mem_usage streams the
        # weights in instead of materializing a full fp32 copy first.
        dtype = torch.float16 if device == "cuda" else torch.float32
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        ).to(device)

        st.write("✅ Text model loaded successfully!")
        # pipeline() expects a device index: 0 = first GPU, -1 = CPU.
        pipeline_device = 0 if device == "cuda" else -1
        return pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=pipeline_device,
        )
    except Exception as e:
        # Best-effort: report in the UI and signal failure with None so
        # the caller can degrade gracefully instead of crashing the app.
        st.error(f"❌ Error loading text model: {e}")
        return None
32
 
33
  story_generator = load_text_model()
34
 
35
+ # Load image model (Stable Diffusion) with optimizations
36
  @st.cache_resource
37
  def load_image_model():
38
  try:
39
  st.write("⏳ Loading image model...")
40
  model_id = "runwayml/stable-diffusion-v1-5"
41
+ model = StableDiffusionPipeline.from_pretrained(
42
+ model_id,
43
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32 # Reduce VRAM usage
44
+ ).to(device)
45
+ model.enable_attention_slicing() # Optimize GPU memory
46
  st.write("✅ Image model loaded successfully!")
47
  return model
48
  except Exception as e: