thliang01 commited on
Commit
627fa81
·
unverified ·
1 Parent(s): 801b1cf

chore: remove example

Browse files
Files changed (1) hide show
  1. app.py +28 -37
app.py CHANGED
@@ -2,9 +2,7 @@ import spaces
2
  import torch
3
  import gradio as gr
4
  import requests
5
- from PIL import Image
6
  from torchvision import transforms, models
7
- import os
8
 
9
  # Download human-readable labels for ImageNet.
10
  response = requests.get("https://git.io/JJkYN")
@@ -23,51 +21,44 @@ preprocess = transforms.Compose([
23
  ])
24
 
25
  @spaces.GPU(duration=60)
26
- def gpu_inference(input_tensor):
27
- """Isolated GPU function"""
28
- device = torch.device("cuda")
29
- model.to(device)
30
- input_tensor = input_tensor.to(device)
31
-
32
  with torch.no_grad():
 
 
 
 
 
 
33
  output = model(input_tensor)
34
- prediction = torch.nn.functional.softmax(output[0], dim=0)
35
-
36
- return prediction.cpu()
37
 
38
  def predict(inp):
39
- """Main prediction function"""
40
  if inp is None:
41
  return {}
42
 
43
- try:
44
- # Preprocess on CPU
45
- input_tensor = preprocess(inp).unsqueeze(0)
46
-
47
- # Run GPU inference
48
- prediction = gpu_inference(input_tensor)
49
-
50
- # Convert to confidences
51
- confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
52
- return confidences
53
- except Exception as e:
54
- print(f"Error in prediction: {e}")
55
- return {"error": str(e)}
56
-
57
- # Create example list only if example files exist
58
- example_files = ["lion.jpg", "cheetah.jpg", "cat.avif", "hot-dog.avif", "llama.jpg", "medieval_knight.jpg"]
59
- examples = [f for f in example_files if os.path.exists(f)]
60
 
61
- # Create Gradio interface
62
- iface = gr.Interface(
63
  fn=predict,
64
  inputs=gr.Image(type="pil"),
65
  outputs=gr.Label(num_top_classes=3),
66
- examples=examples if examples else None,
67
- cache_examples=False,
68
  title="Image Classifier with ZeroGPU",
 
69
  css=".footer{display:none !important}"
70
- )
71
-
72
- if __name__ == "__main__":
73
- iface.launch()
 
2
  import torch
3
  import gradio as gr
4
  import requests
 
5
  from torchvision import transforms, models
 
6
 
7
  # Download human-readable labels for ImageNet.
8
  response = requests.get("https://git.io/JJkYN")
 
21
  ])
22
 
23
  @spaces.GPU(duration=60)
24
+ def run_model_on_gpu(input_tensor):
25
+ """Pure GPU computation function - no Gradio context"""
 
 
 
 
26
  with torch.no_grad():
27
+ # Move everything to GPU
28
+ device = torch.device("cuda")
29
+ model.to(device)
30
+ input_tensor = input_tensor.to(device)
31
+
32
+ # Run inference
33
  output = model(input_tensor)
34
+
35
+ # Return CPU tensor
36
+ return output.cpu()
37
 
38
  def predict(inp):
39
+ """Main prediction function that handles all logic"""
40
  if inp is None:
41
  return {}
42
 
43
+ # Preprocess image on CPU
44
+ input_tensor = preprocess(inp).unsqueeze(0)
45
+
46
+ # Get model output via GPU function
47
+ output = run_model_on_gpu(input_tensor)
48
+
49
+ # Process predictions on CPU
50
+ prediction = torch.nn.functional.softmax(output[0], dim=0)
51
+ confidences = {labels[i]: float(prediction[i]) for i in range(1000)}
52
+
53
+ # Return top predictions
54
+ return confidences
 
 
 
 
 
55
 
56
+ # Create Gradio interface without examples first to test
57
+ gr.Interface(
58
  fn=predict,
59
  inputs=gr.Image(type="pil"),
60
  outputs=gr.Label(num_top_classes=3),
 
 
61
  title="Image Classifier with ZeroGPU",
62
+ description="Upload an image to classify it using ResNet-34",
63
  css=".footer{display:none !important}"
64
+ ).launch()