Vaish6 committed on
Commit
7560863
·
verified ·
1 Parent(s): 7fa5ff0

Upload 2 files

Browse files
Files changed (2) hide show
  1. Dockerfile +2 -2
  2. app.py +156 -102
Dockerfile CHANGED
@@ -22,5 +22,5 @@ ENV GRADIO_SERVER_NAME="0.0.0.0"
22
  ENV GRADIO_SERVER_PORT="7860"
23
  EXPOSE 7860
24
 
25
- # Run inference evaluation (Hackathon graders evaluate stdout)
26
- CMD ["python", "inference.py"]
 
22
  ENV GRADIO_SERVER_PORT="7860"
23
  EXPOSE 7860
24
 
25
+ # Run the persistent web dashboard (keeps Space alive)
26
+ CMD ["python", "app.py"]
app.py CHANGED
@@ -1,116 +1,170 @@
1
- import gradio as gr
 
 
 
 
 
 
2
  import asyncio
3
- from env import MetaOCTEnv, Action
 
 
 
 
 
 
 
 
 
4
 
5
- def run_async(coro):
6
- """Helper to run async code synchronously for Gradio callbacks."""
7
- try:
8
- loop = asyncio.get_event_loop()
9
- except RuntimeError:
10
- loop = asyncio.new_event_loop()
11
- asyncio.set_event_loop(loop)
12
- return loop.run_until_complete(coro)
13
 
14
- async def init_env(difficulty="medium"):
15
- env = MetaOCTEnv(difficulty=difficulty)
16
- obs = await env.reset()
17
- budget_str = f"💲 Remaining Budget: ${obs.available_budget:.2f}"
18
- log_text = "\n\n".join(obs.tool_outputs)
19
- return env, budget_str, log_text, None, "", "Start interacting..."
20
 
21
- def ui_initialize(difficulty):
22
- return run_async(init_env(difficulty))
23
 
24
- async def take_step(env, action_name, diagnosis_input="NORMAL"):
25
- if env is None:
26
- return None, "Error: Start a new patient first!", "", None, "", ""
27
-
28
- params = {}
29
- if action_name == "submit_diagnosis":
30
- # Simplified parameters for UI
31
- params = {
32
- "diagnosis": diagnosis_input,
33
- "heatmap_coordinates": [[80,80], [150,150]],
34
- "reasoning": "Human judge overriding UI input."
35
- }
36
-
37
- result = await env.step(Action(tool_name=action_name, parameters=params))
38
-
39
- # Extract states
40
- obs = result.observation
41
- budget_str = f"💲 Remaining Budget: ${obs.available_budget:.2f}"
42
- log_text = "\n\n".join(obs.tool_outputs)
43
-
44
- img_path = None
45
- if len(obs.acquired_scans) > 0:
46
- img_path = obs.acquired_scans[-1]
47
-
48
- # Reward tracking
49
- if result.done:
50
- status = f"✅ DIAGNOSIS COMPLETE! FINAL GRADE: {result.reward:.3f} / 1.0"
51
- else:
52
- # Micro reward or penalty
53
- if result.reward < 0:
54
- status = f"⚠️ Penalty! Score: {result.reward}"
55
- else:
56
- status = f"🔄 Valid Move. Cost applied."
57
-
58
- return env, budget_str, log_text, img_path, status
59
 
60
- def ui_step_scan(env): return run_async(take_step(env, "request_oct_scan"))
61
- def ui_step_enhance(env): return run_async(take_step(env, "enhance_contrast"))
62
- def ui_step_measure(env): return run_async(take_step(env, "measure_fluid_thickness"))
63
- def ui_step_diagnose(env, diag): return run_async(take_step(env, "submit_diagnosis", diag))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
- # Construct Gradio Theme
66
- custom_theme = gr.themes.Soft(
67
- primary_hue="blue",
68
- secondary_hue="indigo",
69
- neutral_hue="slate"
70
- )
 
 
 
 
 
 
 
 
 
71
 
72
- with gr.Blocks(theme=custom_theme, title="MetaOCT Virtual Clinic") as demo:
73
- gr.Markdown("# 👁️ MetaOCT: Virtual Medical Clinic (POMDP)")
74
- gr.Markdown("Prove your diagnostic efficiency. You have a limited budget. Perform necessary scans before extracting the final diagnosis! Made for the Meta OpenEnv Challenge.")
 
 
75
 
76
- # Stores the environment instance
77
- env_state = gr.State(None)
 
78
 
79
- with gr.Row():
80
- # LEFT COLUMN (Visuals & Economics)
81
- with gr.Column(scale=1):
82
- difficulty_radio = gr.Radio(["easy", "medium", "hard"], value="medium", label="Task Difficulty")
83
- btn_start = gr.Button("🏥 Accept New Patient", variant="primary")
84
-
85
- budget_display = gr.Markdown("### 💲 Remaining Budget: --")
86
- scan_image = gr.Image(label="Optical Coherence Tomography (OCT)", type="filepath", interactive=False)
87
- status_box = gr.Textbox(label="Evaluation Status", interactive=False)
88
-
89
- # RIGHT COLUMN (Interactions & Output)
90
- with gr.Column(scale=1):
91
- gr.Markdown("### 🛠️ Clinical Tools")
92
- btn_tool_1 = gr.Button("🔍 Tool: Request Scan (-$150)")
93
- btn_tool_2 = gr.Button(" Tool: Enhance Contrast (-$50)")
94
- btn_tool_3 = gr.Button("📏 Tool: Measure Fluid Thickness (-$200)")
95
-
96
- gr.Markdown("### 📋 Final Diagnosis (Terminal State)")
97
- diagnosis_dropdown = gr.Dropdown(["NORMAL", "CNV", "DME", "DRUSEN"], label="Select Pathogen", value="NORMAL")
98
- btn_diagnose = gr.Button("📝 Submit Final Diagnosis ($0)", variant="stop")
99
-
100
- clinical_log = gr.Textbox(label="Secure Clinical Record", lines=10, interactive=False)
101
-
102
- # Wiring Buttons
103
- btn_start.click(
104
- fn=ui_initialize,
105
- inputs=[difficulty_radio],
106
- outputs=[env_state, budget_display, clinical_log, scan_image, status_box, status_box]
107
- )
108
 
109
- btn_tool_1.click(fn=ui_step_scan, inputs=[env_state], outputs=[env_state, budget_display, clinical_log, scan_image, status_box])
110
- btn_tool_2.click(fn=ui_step_enhance, inputs=[env_state], outputs=[env_state, budget_display, clinical_log, scan_image, status_box])
111
- btn_tool_3.click(fn=ui_step_measure, inputs=[env_state], outputs=[env_state, budget_display, clinical_log, scan_image, status_box])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
- btn_diagnose.click(fn=ui_step_diagnose, inputs=[env_state, diagnosis_dropdown], outputs=[env_state, budget_display, clinical_log, scan_image, status_box])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  if __name__ == "__main__":
116
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MetaOCT Hackathon Web App
3
+ Runs inference.py in background thread and serves results via a minimal web server.
4
+ This keeps the HuggingFace Space alive permanently.
5
+ """
6
+
7
+ import threading
8
  import asyncio
9
+ import os
10
+ import json
11
+ from http.server import HTTPServer, BaseHTTPRequestHandler
12
+ from dotenv import load_dotenv
13
+ from env import MetaOCTEnv, Action, Observation
14
+ from openai import OpenAI
15
+ import torch
16
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
17
+ from PIL import Image
18
+ from typing import List, Optional
19
 
20
load_dotenv()

# Endpoint and model used for the LLM reasoning call; both can be overridden
# through environment variables of the same name.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/hf-inference/v1/")
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Meta-Llama-3-8B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")

# Credential precedence: OPENAI_API_KEY, then HF_TOKEN, then API_KEY.
API_KEY = os.getenv("OPENAI_API_KEY") or HF_TOKEN or os.getenv("API_KEY")
# Warn only when none of the three sources produced a key, so the warning
# agrees with the fallback chain above.
if not API_KEY:
    print("[WARNING] Required API keys missing.", flush=True)

# Global results store
# Written by the inference code and read by the HTTP handler:
#   status  - "running" until the run finishes
#   logs    - list of emitted log lines
#   score   - final averaged reward, None until the run finishes
#   success - final pass/fail flag, None until the run finishes
results = {"status": "running", "logs": [], "score": None, "success": None}
31
 
32
# Vision Model
# Both objects are loaded once at import time. They are assigned together:
# if either load raises, neither name keeps a partially loaded value and both
# end up as None so that callers can detect that no model is available.
print("[DEBUG] Loading Vision Model...", flush=True)
try:
    processor, hf_model = (
        AutoImageProcessor.from_pretrained("octava/image_classification"),
        AutoModelForImageClassification.from_pretrained(
            "octava/image_classification", output_attentions=True
        ),
    )
except Exception as e:
    print(f"[DEBUG] Vision model warning: {e}", flush=True)
    processor = hf_model = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
def get_vision_prediction(image_path: str):
    """Classify a scan image and box the patch the model attends to most.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        A ``(diagnosis, heatmap)`` tuple. ``diagnosis`` is ``"CNV"``,
        ``"DME"`` or ``"DRUSEN"`` when that text appears in the predicted
        label, otherwise ``"NORMAL"``. ``heatmap`` is
        ``[[x0, y0], [x1, y1]]``, clamped to the range 0..224.

    When no vision model is loaded, or when any step fails, the values
    computed so far are returned; the defaults are ``"NORMAL"`` and
    ``[[0, 0], [0, 0]]``. Failures are printed, never raised.
    """
    diagnosis = "NORMAL"
    heatmap = [[0, 0], [0, 0]]
    if hf_model is None or processor is None:
        return diagnosis, heatmap

    grid_side = 14    # attention is reshaped to a grid_side x grid_side patch grid
    patch_px = 16     # pixels per patch used to scale grid cells to coordinates
    image_px = 224    # upper clamp for the box coordinates

    try:
        # The context manager releases the file handle once the pixels have
        # been copied into the converted RGB image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = hf_model(**inputs)

        idx = outputs.logits.argmax(-1).item()
        label = hf_model.config.id2label[idx].upper()
        for candidate in ("CNV", "DME", "DRUSEN"):
            if candidate in label:
                diagnosis = candidate
                break

        # Average the last attention layer over dim 1, drop dim 0, then take
        # row 0 without its first entry.
        # NOTE(review): assumes batch size 1 and that token 0 is a CLS token
        # followed by grid_side**2 patch tokens - confirm for this checkpoint.
        avg_attention = outputs.attentions[-1].mean(dim=1).squeeze(0)
        cls_attention = avg_attention[0, 1:]
        grid = cls_attention.reshape(grid_side, grid_side)
        y, x = divmod(torch.argmax(grid).item(), grid_side)

        # Box spans one patch before to two patches after the peak cell.
        heatmap = [
            [max(0, (x - 1) * patch_px), max(0, (y - 1) * patch_px)],
            [min(image_px, (x + 2) * patch_px), min(image_px, (y + 2) * patch_px)],
        ]
    except Exception as e:
        # Best effort: keep whatever was determined before the failure.
        print(f"[DEBUG] Vision error: {e}", flush=True)
    return diagnosis, heatmap
67
 
68
def get_heuristic_action(step: int, obs, client) -> Action:
    """Pick the next tool call for an episode from a fixed script.

    Steps 1, 2 and 3 request a scan, enhance contrast and measure fluid
    thickness, in that order. Any other step submits a diagnosis built from
    the vision prediction plus a one-sentence LLM justification.

    Args:
        step: Episode step counter; the values 1, 2 and 3 select the
            scripted tools.
        obs: Latest observation; only ``obs.acquired_scans`` is read.
        client: Object exposing ``chat.completions.create``.

    Returns:
        The ``Action`` to execute next.
    """
    scripted_tools = {
        1: "request_oct_scan",
        2: "enhance_contrast",
        3: "measure_fluid_thickness",
    }
    tool_name = scripted_tools.get(step)
    if tool_name is not None:
        return Action(tool_name=tool_name, parameters={})

    # Use the most recent scan; the placeholder path is passed on when the
    # observation lists no scans.
    image_path = obs.acquired_scans[-1] if obs.acquired_scans else "dummy.jpg"
    diagnosis, heatmap = get_vision_prediction(image_path)

    # Fallback justification, kept unless the LLM returns usable text.
    reasoning = "Clinical biomarkers align with diagnosis based on retinal morphology."
    try:
        prompt = f"You are an expert ophthalmologist. Diagnose: {diagnosis}. Give 1-sentence reasoning."
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=80,
        )
        # A missing or whitespace-only reply must not replace the fallback
        # with an empty string.
        llm_text = (completion.choices[0].message.content or "").strip()
        if llm_text:
            reasoning = llm_text
    except Exception as e:
        print(f"[DEBUG] LLM error: {e}", flush=True)

    return Action(
        tool_name="submit_diagnosis",
        parameters={
            "diagnosis": diagnosis,
            "heatmap_coordinates": heatmap,
            "reasoning": reasoning,
        },
    )
83
 
84
async def run_inference():
    """Run the scripted agent over every difficulty and publish the outcome.

    For each of "easy", "medium" and "hard" an environment is created and up
    to three episodes (fewer if ``env.max_patients`` is smaller) are played
    with ``get_heuristic_action``. Every log line is printed and appended to
    ``results["logs"]``. On completion ``results`` receives the final status,
    the mean per-step reward as the score, and a success flag that is true
    when the score is at least 0.7.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    global_rewards = []
    global_steps = 0

    def record(line):
        # Each log line goes to stdout and to the shared results store.
        print(line, flush=True)
        results["logs"].append(line)

    record(f"[START] task=MetaOCT_POMDP env=meta_oct model={MODEL_NAME}")

    for diff in ("easy", "medium", "hard"):
        env = MetaOCTEnv(difficulty=diff)
        try:
            for _ in range(min(env.max_patients, 3)):
                obs = await env.reset()
                episode_step = 0
                done = False
                while not done:
                    episode_step += 1
                    global_steps += 1
                    action_obj = get_heuristic_action(episode_step, obs, client)
                    result = await env.step(action_obj)
                    # A missing reward counts as zero.
                    reward = result.reward or 0.0
                    done = result.done
                    obs = result.observation
                    global_rewards.append(reward)
                    record(
                        f"[STEP] step={global_steps} action=Tool({action_obj.tool_name}) "
                        f"reward={reward:.2f} done={str(done).lower()} error=null"
                    )
        finally:
            # Close the environment even when an episode raises.
            await env.close()

    # Score is the mean reward over all steps taken; 0.0 when no step ran.
    max_total = float(len(global_rewards))
    total_score = sum(global_rewards) / max_total if max_total > 0 else 0.0
    success = total_score >= 0.7
    record(
        f"[END] success={str(success).lower()} steps={global_steps} "
        f"score={total_score:.3f} rewards={','.join(f'{r:.2f}' for r in global_rewards)}"
    )
    results["status"] = "complete"
    results["score"] = total_score
    results["success"] = success
123
+
124
def run_inference_thread():
    """Thread target that drives ``run_inference`` on its own event loop.

    A new event loop is created for the calling thread, used to run the
    coroutine to completion, and closed afterwards. If the coroutine raises,
    the failure is printed, appended to ``results["logs"]`` and
    ``results["status"]`` is set to ``"error"`` so the status no longer
    reads "running" after the run has died.
    """
    import traceback  # local import: only needed on the failure path

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(run_inference())
    except Exception as e:
        error_log = f"[DEBUG] Inference failed: {type(e).__name__}: {e}"
        print(error_log, flush=True)
        traceback.print_exc()
        results["logs"].append(error_log)
        results["status"] = "error"
    finally:
        # Detach and close the loop so its resources are released.
        asyncio.set_event_loop(None)
        loop.close()
128
+
129
+ # Simple Web Server - Serves the results as HTML
130
class ResultsHandler(BaseHTTPRequestHandler):
    """Serve the current contents of ``results`` as an auto-refreshing page."""

    def log_message(self, format, *args):
        """Suppress the default per-request access log."""
        pass

    def do_GET(self):
        """Answer every GET path with the HTML dashboard.

        The page shows the run status, the score (or "Running..." while it is
        still ``None``) and the last 30 log lines, and asks the browser to
        reload every 5 seconds.
        """
        from html import escape  # local import keeps this class self-contained

        status = str(results["status"])
        status_color = "#00ff88" if status == "complete" else "#ffaa00"
        score = results["score"]
        score_display = f"{score:.3f}" if score is not None else "Running..."
        # Log lines can carry arbitrary text (model names, error messages), so
        # escape them; newlines are rendered as line breaks inside <pre>.
        log_block = "\n".join(escape(str(line)) for line in results["logs"][-30:])

        page = f"""<!DOCTYPE html>
<html><head>
<meta charset="utf-8">
<title>MetaOCT Virtual Clinic</title>
<meta http-equiv="refresh" content="5">
<style>
body{{background:#0d1117;color:#e6edf3;font-family:monospace;padding:40px;}}
h1{{color:#58a6ff;}} .status{{color:{status_color};font-size:1.4em;}}
pre{{background:#161b22;padding:20px;border-radius:8px;overflow-x:auto;font-size:13px;max-height:500px;overflow-y:auto;}}
.score{{font-size:2em;color:#f0883e;}} .badge{{background:#238636;padding:4px 12px;border-radius:20px;}}
</style></head>
<body>
<h1>👁️ MetaOCT: Virtual Medical Clinic (POMDP)</h1>
<p>Multi-Step Reinforcement Learning Environment | Meta OpenEnv Hackathon</p>
<p class="status">Status: {escape(status.upper())}</p>
<p class="score">Score: {score_display}</p>
<pre>{log_block}</pre>
<p><span class="badge">OpenEnv Compliant</span> &nbsp; Built with PyTorch + LLaMA-3 + OctaVA Vision</p>
</body></html>"""
        body = page.encode("utf-8")

        self.send_response(200)
        # Declare the encoding: the page contains non-ASCII characters and is
        # sent as UTF-8 bytes.
        self.send_header("Content-Type", "text/html; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)
160
 
161
if __name__ == "__main__":
    # Kick off the evaluation on a daemon thread so it never blocks the
    # web server and does not keep the process alive on its own.
    worker = threading.Thread(target=run_inference_thread, daemon=True)
    worker.start()

    # Serve the dashboard on all interfaces; the port is read from
    # GRADIO_SERVER_PORT and defaults to 7860.
    port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
    print(f"[INFO] Starting MetaOCT Dashboard on port {port}", flush=True)
    dashboard = HTTPServer(("0.0.0.0", port), ResultsHandler)
    dashboard.serve_forever()