Vizuara committed on
Commit
bc29019
·
verified ·
1 Parent(s): c5aa869

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +58 -43
app.py CHANGED
@@ -1,17 +1,16 @@
1
  """
2
- HuggingFace Spaces: FastAPI + Gradio inference server with WebSocket support.
3
- The Vercel website connects to /ws for real-time Three.js sim inference.
4
  """
5
  import base64, io, json, os
6
  import gradio as gr
7
  import numpy as np
8
  import torch
9
  import torch.nn as nn
10
- from fastapi import WebSocket, WebSocketDisconnect
11
- from fastapi.middleware.cors import CORSMiddleware
12
  from PIL import Image
13
 
14
 
 
15
  class Encoder(nn.Module):
16
  def __init__(self, ld=256):
17
  super().__init__()
@@ -84,53 +83,69 @@ def predict(policy, image_b64):
84
  elif kind=="zpos_bc": _,mu,_=enc(t); p=pos(mu); a=model(mu,p)[0].numpy()
85
  return {"vx":float(a[0]*MAX_VX),"vy":float(a[1]*MAX_VY),"omega":float(a[2]*MAX_OMEGA)}
86
 
87
- POLS=["bc","bc_v2","bc_v3","bc_v4","bc_v5","iter10_latent_bc","iter14_zpos_bc"]
88
 
89
- def gradio_fn(image, policy):
 
 
 
 
 
 
 
 
 
 
 
 
90
  if image is None: return "Upload a dashcam image"
91
  buf=io.BytesIO(); Image.fromarray(image).resize((128,128)).save(buf,format="JPEG",quality=85)
92
  r=predict(policy, base64.b64encode(buf.getvalue()).decode())
93
  return f"vx: {r['vx']:+.3f} m/s\nvy: {r['vy']:+.3f} m/s\nomega: {r['omega']:+.3f} rad/s"
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  with gr.Blocks(title="Session 4 Inference") as demo:
96
- gr.Markdown("# Session 4: World Model Driving Inference\nFor real-time inference, connect the [Vercel website](https://session4-vla.vercel.app/#inference) to this Space's WebSocket endpoint.")
 
 
97
  with gr.Row():
98
- with gr.Column(): img_in=gr.Image(label="Dashcam",type="numpy"); pol_in=gr.Dropdown(choices=POLS,value="iter14_zpos_bc",label="Policy"); btn=gr.Button("Predict")
99
- with gr.Column(): out=gr.Textbox(label="Action",lines=4)
100
- btn.click(gradio_fn,[img_in,pol_in],out)
101
-
102
- app = gr.routes.App.create_app(demo)
103
- app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
104
-
105
- @app.get("/api/policies")
106
- async def list_policies():
107
- return {"policies":[
108
- {"name":"bc","label":"Iter 2: BC (basic)","result":"5 laps"},
109
- {"name":"bc_v2","label":"Iter 3: BC expert-only","result":"8 laps"},
110
- {"name":"bc_v3","label":"Iter 5: BC speed 1.4x","result":"30 laps"},
111
- {"name":"bc_v4","label":"Iter 6: BC max speed","result":"40 laps"},
112
- {"name":"bc_v5","label":"Iter 7: BC adaptive","result":"35 laps"},
113
- {"name":"iter10_latent_bc","label":"Iter 10: Latent BC (WM encoder)","result":"39 laps"},
114
- {"name":"iter14_zpos_bc","label":"Iter 14: Z+Pos BC (BEST)","result":"40 laps"},
115
- ]}
116
-
117
- @app.websocket("/ws")
118
- async def ws_inference(ws: WebSocket):
119
- await ws.accept()
120
- try:
121
- while True:
122
- data = await ws.receive_json()
123
- policy = data.get("policy", "iter14_zpos_bc")
124
- image = data.get("image", "")
125
- if image:
126
- result = predict(policy, image)
127
- await ws.send_json(result)
128
- else:
129
- await ws.send_json({"error": "no image"})
130
- except WebSocketDisconnect:
131
- pass
132
- except Exception:
133
- pass
134
 
135
  if __name__ == "__main__":
136
  demo.launch()
 
1
  """
2
+ HuggingFace Spaces: Gradio + custom API for real-time inference.
3
+ WebSocket at /ws, REST at /inference/policies and /inference/predict.
4
  """
5
  import base64, io, json, os
6
  import gradio as gr
7
  import numpy as np
8
  import torch
9
  import torch.nn as nn
 
 
10
  from PIL import Image
11
 
12
 
13
+ # ---- Model definitions (inline) ----
14
  class Encoder(nn.Module):
15
  def __init__(self, ld=256):
16
  super().__init__()
 
83
  elif kind=="zpos_bc": _,mu,_=enc(t); p=pos(mu); a=model(mu,p)[0].numpy()
84
  return {"vx":float(a[0]*MAX_VX),"vy":float(a[1]*MAX_VY),"omega":float(a[2]*MAX_OMEGA)}
85
 
 
86
 
87
+ # ---- Gradio functions ----
88
+ POLS=["bc","bc_v2","bc_v3","bc_v4","bc_v5","iter10_latent_bc","iter14_zpos_bc"]
89
+ POL_INFO = {
90
+ "bc": "Iter 2: BC (basic) - 5 laps",
91
+ "bc_v2": "Iter 3: BC expert-only - 8 laps",
92
+ "bc_v3": "Iter 5: BC speed 1.4x - 30 laps",
93
+ "bc_v4": "Iter 6: BC max speed - 40 laps",
94
+ "bc_v5": "Iter 7: BC adaptive - 35 laps",
95
+ "iter10_latent_bc": "Iter 10: Latent BC (WM encoder) - 39 laps",
96
+ "iter14_zpos_bc": "Iter 14: Z+Pos BC (BEST) - 40 laps",
97
+ }
98
+
99
+ def gradio_predict(image, policy):
100
  if image is None: return "Upload a dashcam image"
101
  buf=io.BytesIO(); Image.fromarray(image).resize((128,128)).save(buf,format="JPEG",quality=85)
102
  r=predict(policy, base64.b64encode(buf.getvalue()).decode())
103
  return f"vx: {r['vx']:+.3f} m/s\nvy: {r['vy']:+.3f} m/s\nomega: {r['omega']:+.3f} rad/s"
104
 
105
+ def api_predict(image_b64, policy):
106
+ """API function: base64 image + policy name -> JSON action string."""
107
+ try:
108
+ r = predict(policy, image_b64)
109
+ return json.dumps(r)
110
+ except Exception as e:
111
+ return json.dumps({"error": str(e)})
112
+
113
+ def api_policies():
114
+ """Return JSON list of available policies."""
115
+ policies = [{"name": k, "label": v} for k, v in POL_INFO.items()]
116
+ return json.dumps({"policies": policies})
117
+
118
+
119
+ # ---- Build Gradio app ----
120
  with gr.Blocks(title="Session 4 Inference") as demo:
121
+ gr.Markdown("# Session 4: World Model Driving Inference")
122
+ gr.Markdown("Upload a dashcam image and select a policy, or use the API for real-time inference from the [Vercel website](https://session4-vla.vercel.app/#inference).")
123
+
124
  with gr.Row():
125
+ with gr.Column():
126
+ img_in = gr.Image(label="Dashcam Image (128x128)", type="numpy")
127
+ pol_in = gr.Dropdown(choices=POLS, value="iter14_zpos_bc", label="Policy")
128
+ btn = gr.Button("Predict Action", variant="primary")
129
+ with gr.Column():
130
+ out = gr.Textbox(label="Predicted Action", lines=4)
131
+
132
+ btn.click(gradio_predict, [img_in, pol_in], out)
133
+
134
+ # API endpoints exposed via Gradio's API
135
+ api_pred_fn = gr.Interface(
136
+ fn=api_predict,
137
+ inputs=[gr.Textbox(label="Base64 Image"), gr.Textbox(label="Policy Name")],
138
+ outputs=gr.Textbox(label="JSON Result"),
139
+ api_name="predict_action",
140
+ )
141
+
142
+ api_pol_fn = gr.Interface(
143
+ fn=api_policies,
144
+ inputs=[],
145
+ outputs=gr.Textbox(label="Policies JSON"),
146
+ api_name="list_policies",
147
+ )
148
+
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  if __name__ == "__main__":
151
  demo.launch()