Tin113 committed on
Commit
8f08c02
verified
1 Parent(s): 6c39b64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -19
app.py CHANGED
@@ -186,9 +186,8 @@ transform = transforms.Compose([
186
  transforms.ToTensor(),
187
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
188
  ])
189
-
190
  def create_interface():
191
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
192
 
193
  try:
194
  model, word_to_idx, idx_to_word = load_model(
@@ -198,29 +197,38 @@ def create_interface():
198
  device
199
  )
200
 
201
- def vqa_interface(image, question):
202
- return predict(image, question, model, word_to_idx, idx_to_word, device)
203
-
204
- examples = [
205
- ["example1.jpg", "What color is the animal?"],
206
- ["example2.jpg", "Is this a cat or a dog?"]
207
- ]
 
 
 
 
 
 
208
 
209
- return gr.Interface(
210
- fn=vqa_interface,
211
  inputs=[
212
- gr.Image(type="pil", label="Upload an image"),
213
- gr.Textbox(label="Ask a question about the image")
214
  ],
215
  outputs=gr.Textbox(label="Answer"),
216
- examples=examples,
217
- title="Visual Question Answering System",
218
- description="Upload an image and ask a question about it. The model will try to answer."
219
  )
 
220
  except Exception as e:
221
- print(f"Interface creation failed: {e}")
222
- return gr.Interface(lambda x: "Model loading failed", "text", "text")
223
 
224
  if __name__ == "__main__":
225
  iface = create_interface()
226
- iface.launch()
 
 
 
 
 
186
  transforms.ToTensor(),
187
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
188
  ])
 
189
  def create_interface():
190
+ device = 'cpu' # Luôn dùng CPU trên Spaces
191
 
192
  try:
193
  model, word_to_idx, idx_to_word = load_model(
 
197
  device
198
  )
199
 
200
+ def predict(image, question):
201
+ try:
202
+ transform = transforms.Compose([
203
+ transforms.Resize((224, 224)),
204
+ transforms.ToTensor(),
205
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
206
+ std=[0.229, 0.224, 0.225])
207
+ ])
208
+ image = transform(image).unsqueeze(0).to(device)
209
+ answer = model.predict(image, question, word_to_idx, idx_to_word, device)
210
+ return answer
211
+ except Exception as e:
212
+ return f"Error: {str(e)}"
213
 
214
+ iface = gr.Interface(
215
+ fn=predict,
216
  inputs=[
217
+ gr.Image(type="pil", label="Upload Image"),
218
+ gr.Textbox(label="Question")
219
  ],
220
  outputs=gr.Textbox(label="Answer"),
221
+ title="VQA System",
222
+ description="Upload an image and ask questions about it"
 
223
  )
224
+ return iface
225
  except Exception as e:
226
+ return gr.Interface(lambda: "Model failed to load", None, "text")
 
227
 
228
  if __name__ == "__main__":
229
  iface = create_interface()
230
+ iface.launch(
231
+ server_name="0.0.0.0",
232
+ server_port=7860
233
+ )
234
+