Monimoy committed on
Commit
6fceecb
·
verified ·
1 Parent(s): 64fd6a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -1,4 +1,5 @@
1
  # app.py
 
2
  import os
3
  import gradio as gr
4
  import torch
@@ -35,7 +36,8 @@ image_model_name = 'resnet50'
35
  image_embed_dim = 512
36
  siglip_pretrained_path = "image_encoder.pth" # Path to your pretrained SigLIP model
37
 
38
- device = torch.device("cpu") # Force CPU
 
39
  print(f"Using device: {device}")
40
 
41
  # Load Tokenizer (using a compatible tokenizer)
@@ -71,9 +73,10 @@ image_encoder.eval() # Set to evaluation mode
71
  #)
72
 
73
  base_model_name="microsoft/Phi-3-mini-4k-instruct"
74
- device = "cuda"
75
 
76
- base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float32, device_map={"": device})
 
77
 
78
 
79
  # Load and merge
@@ -81,6 +84,8 @@ model = PeftModel.from_pretrained(base_model, peft_model_path, offload_dir='./of
81
  model = model.merge_and_unload()
82
  print("phi-3 model loaded sucessfully")
83
  # 3. Inference Function
 
 
84
  def predict(image, question):
85
  """
86
  Takes an image and a question as input and returns an answer.
 
1
  # app.py
2
+ import spaces
3
  import os
4
  import gradio as gr
5
  import torch
 
36
  image_embed_dim = 512
37
  siglip_pretrained_path = "image_encoder.pth" # Path to your pretrained SigLIP model
38
 
39
+ #device = torch.device("cpu") # Force CPU
40
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
  print(f"Using device: {device}")
42
 
43
  # Load Tokenizer (using a compatible tokenizer)
 
73
  #)
74
 
75
  base_model_name="microsoft/Phi-3-mini-4k-instruct"
76
+ #device = "cuda"
77
 
78
+ #base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float32, device_map={"": device})
79
+ base_model = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float32, device_map="auto")
80
 
81
 
82
  # Load and merge
 
84
  model = model.merge_and_unload()
85
  print("phi-3 model loaded sucessfully")
86
  # 3. Inference Function
87
+
88
+ @spaces.GPU
89
  def predict(image, question):
90
  """
91
  Takes an image and a question as input and returns an answer.