halfacupoftea committed on
Commit
2c5464d
·
1 Parent(s): 3314397

Update spaces to use ZeroGPU

Browse files
Files changed (2) hide show
  1. app.py +52 -21
  2. requirements.txt +9 -7
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import os
2
- from dotenv import load_dotenv
3
  import gradio as gr
4
- import torch
5
- from PIL import Image
6
  from transformers import pipeline
 
 
 
7
 
8
  load_dotenv()
9
  hf_token = os.getenv("HF_GEMMA_TOKEN")
@@ -11,16 +11,18 @@ hf_token = os.getenv("HF_GEMMA_TOKEN")
11
  pipe = pipeline(
12
  "image-text-to-text",
13
  model="google/gemma-3-4b-it",
14
- device_map="cpu",
15
- torch_dtype=torch.float32,
16
- token=hf_token
17
  )
18
 
19
- def analyze_image(image):
20
- image = image.convert('RGB')
 
 
 
21
 
22
  # Define the prompt
23
- system_prompt = "You are a helpful assistant."
24
  user_prompt = '''Analyze this food image and provide detailed nutritional information in JSON format.
25
  Identify the specific vegetarian food items, estimate portion sizes, and provide nutritional breakdown.
26
  This app focuses on vegetarian foods only, so analyze from that perspective.
@@ -45,11 +47,11 @@ def analyze_image(image):
45
  }
46
 
47
  Return ONLY the JSON without any explanations or markdown formatting.'''
48
-
49
  messages = [
50
  {
51
  "role": "system",
52
- "content": [{"type": "text", "text": system_prompt}]
53
  },
54
  {
55
  "role": "user",
@@ -60,19 +62,48 @@ def analyze_image(image):
60
  }
61
  ]
62
 
63
- output = pipe(text=messages, max_new_tokens=500, return_full_text=False)
 
 
 
 
 
 
 
64
 
65
- return output[0]["generated_text"]
66
 
 
 
 
67
 
68
- # Gradio interface
69
- demo = gr.Interface(
70
- fn=analyze_image,
71
- inputs=gr.Image(type="pil"),
72
- outputs=gr.Textbox(),
73
- title="Gemma Powered Calorie Tracker",
74
- description="Upload an image of your food to get detailed nutritional information."
75
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  if __name__ == "__main__":
78
  demo.launch()
 
1
  import os
 
2
  import gradio as gr
 
 
3
  from transformers import pipeline
4
+ import torch
5
+ from dotenv import load_dotenv
6
+ import spaces
7
 
8
  load_dotenv()
9
  hf_token = os.getenv("HF_GEMMA_TOKEN")
 
11
  pipe = pipeline(
12
  "image-text-to-text",
13
  model="google/gemma-3-4b-it",
14
+ token=hf_token,
15
+ device="cuda",
16
+ torch_dtype=torch.bfloat16,
17
  )
18
 
19
+ @spaces.GPU()
20
+ def get_response(chat_history, image):
21
+ if image is None:
22
+ chat_history.append(("Please upload an image (required)", ""))
23
+ return chat_history
24
 
25
  # Define the prompt
 
26
  user_prompt = '''Analyze this food image and provide detailed nutritional information in JSON format.
27
  Identify the specific vegetarian food items, estimate portion sizes, and provide nutritional breakdown.
28
  This app focuses on vegetarian foods only, so analyze from that perspective.
 
47
  }
48
 
49
  Return ONLY the JSON without any explanations or markdown formatting.'''
50
+
51
  messages = [
52
  {
53
  "role": "system",
54
+ "content": [{"type": "text", "text": "You are a helpful assistant."}]
55
  },
56
  {
57
  "role": "user",
 
62
  }
63
  ]
64
 
65
+ output = pipe(text=messages, max_new_tokens=200)
66
+
67
+ try:
68
+ response = output[0]["generated_text"][-1]["content"]
69
+ chat_history.append((user_prompt, response))
70
+ except (KeyError, IndexError, TypeError) as e:
71
+ error_message = f"Error processing the response: {str(e)}"
72
+ chat_history.append((user_prompt, error_message))
73
 
74
+ return chat_history
75
 
76
+ with gr.Blocks() as demo:
77
+ gr.Markdown("# Gemma Powered Calorie Tracker")
78
+ gr.Markdown("Upload an image to get detailed nutritional information.")
79
 
80
+ chatbot = gr.Chatbot()
81
+
82
+ with gr.Row():
83
+ img = gr.Image(
84
+ type="pil",
85
+ label="Upload image (required)",
86
+ scale=1
87
+ )
88
+
89
+ submit_btn = gr.Button("Send")
90
+
91
+ clear_btn = gr.Button("Clear")
92
+
93
+ def clear_interface():
94
+ return [], None
95
+
96
+ submit_btn.click(
97
+ get_response,
98
+ inputs=[chatbot, img],
99
+ outputs=chatbot
100
+ )
101
+
102
+ clear_btn.click(
103
+ clear_interface,
104
+ inputs=None,
105
+ outputs=[chatbot, img]
106
+ )
107
 
108
  if __name__ == "__main__":
109
  demo.launch()
requirements.txt CHANGED
@@ -1,7 +1,9 @@
1
- transformers
2
- torch
3
- Pillow
4
- gradio
5
- accelerate
6
- transformers @ git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3
7
- python-dotenv
 
 
 
1
+ git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3
2
+ gradio>=4.0.0
3
+ torch>=2.0.0
4
+ torchvision>=0.15.0
5
+ pillow>=9.0.0
6
+ requests>=2.28.0
7
+ numpy>=1.22.0
8
+ python-dotenv
9
+ spaces