Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- Dockerfile +2 -2
- app.py +156 -102
Dockerfile
CHANGED
|
@@ -22,5 +22,5 @@ ENV GRADIO_SERVER_NAME="0.0.0.0"
|
|
| 22 |
ENV GRADIO_SERVER_PORT="7860"
|
| 23 |
EXPOSE 7860
|
| 24 |
|
| 25 |
-
# Run
|
| 26 |
-
CMD ["python", "
|
|
|
|
| 22 |
ENV GRADIO_SERVER_PORT="7860"
|
| 23 |
EXPOSE 7860
|
| 24 |
|
| 25 |
+
# Run the persistent web dashboard (keeps Space alive)
|
| 26 |
+
CMD ["python", "app.py"]
|
app.py
CHANGED
|
@@ -1,116 +1,170 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import asyncio
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
|
| 6 |
-
"""Helper to run async code synchronously for Gradio callbacks."""
|
| 7 |
-
try:
|
| 8 |
-
loop = asyncio.get_event_loop()
|
| 9 |
-
except RuntimeError:
|
| 10 |
-
loop = asyncio.new_event_loop()
|
| 11 |
-
asyncio.set_event_loop(loop)
|
| 12 |
-
return loop.run_until_complete(coro)
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
|
| 21 |
-
|
| 22 |
-
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
"heatmap_coordinates": [[80,80], [150,150]],
|
| 34 |
-
"reasoning": "Human judge overriding UI input."
|
| 35 |
-
}
|
| 36 |
-
|
| 37 |
-
result = await env.step(Action(tool_name=action_name, parameters=params))
|
| 38 |
-
|
| 39 |
-
# Extract states
|
| 40 |
-
obs = result.observation
|
| 41 |
-
budget_str = f"💲 Remaining Budget: ${obs.available_budget:.2f}"
|
| 42 |
-
log_text = "\n\n".join(obs.tool_outputs)
|
| 43 |
-
|
| 44 |
-
img_path = None
|
| 45 |
-
if len(obs.acquired_scans) > 0:
|
| 46 |
-
img_path = obs.acquired_scans[-1]
|
| 47 |
-
|
| 48 |
-
# Reward tracking
|
| 49 |
-
if result.done:
|
| 50 |
-
status = f"✅ DIAGNOSIS COMPLETE! FINAL GRADE: {result.reward:.3f} / 1.0"
|
| 51 |
-
else:
|
| 52 |
-
# Micro reward or penalty
|
| 53 |
-
if result.reward < 0:
|
| 54 |
-
status = f"⚠️ Penalty! Score: {result.reward}"
|
| 55 |
-
else:
|
| 56 |
-
status = f"🔄 Valid Move. Cost applied."
|
| 57 |
-
|
| 58 |
-
return env, budget_str, log_text, img_path, status
|
| 59 |
|
| 60 |
-
def
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
| 75 |
|
| 76 |
-
|
| 77 |
-
|
|
|
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
btn_diagnose = gr.Button("📝 Submit Final Diagnosis ($0)", variant="stop")
|
| 99 |
-
|
| 100 |
-
clinical_log = gr.Textbox(label="Secure Clinical Record", lines=10, interactive=False)
|
| 101 |
-
|
| 102 |
-
# Wiring Buttons
|
| 103 |
-
btn_start.click(
|
| 104 |
-
fn=ui_initialize,
|
| 105 |
-
inputs=[difficulty_radio],
|
| 106 |
-
outputs=[env_state, budget_display, clinical_log, scan_image, status_box, status_box]
|
| 107 |
-
)
|
| 108 |
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
if __name__ == "__main__":
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MetaOCT Hackathon Web App
|
| 3 |
+
Runs inference.py in background thread and serves results via a minimal web server.
|
| 4 |
+
This keeps the HuggingFace Space alive permanently.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import threading
|
| 8 |
import asyncio
|
| 9 |
+
import os
|
| 10 |
+
import json
|
| 11 |
+
from http.server import HTTPServer, BaseHTTPRequestHandler
|
| 12 |
+
from dotenv import load_dotenv
|
| 13 |
+
from env import MetaOCTEnv, Action, Observation
|
| 14 |
+
from openai import OpenAI
|
| 15 |
+
import torch
|
| 16 |
+
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
| 17 |
+
from PIL import Image
|
| 18 |
+
from typing import List, Optional
|
| 19 |
|
| 20 |
+
load_dotenv()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/hf-inference/v1/")
|
| 23 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Meta-Llama-3-8B-Instruct")
|
| 24 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 25 |
+
if HF_TOKEN is None and os.getenv("OPENAI_API_KEY") is None:
|
| 26 |
+
print("[WARNING] Required API keys missing.", flush=True)
|
| 27 |
+
API_KEY = os.getenv("OPENAI_API_KEY") or HF_TOKEN or os.getenv("API_KEY")
|
| 28 |
|
| 29 |
+
# Global results store
|
| 30 |
+
results = {"status": "running", "logs": [], "score": None, "success": None}
|
| 31 |
|
| 32 |
+
# Vision Model
|
| 33 |
+
print("[DEBUG] Loading Vision Model...", flush=True)
|
| 34 |
+
try:
|
| 35 |
+
processor = AutoImageProcessor.from_pretrained("octava/image_classification")
|
| 36 |
+
hf_model = AutoModelForImageClassification.from_pretrained("octava/image_classification", output_attentions=True)
|
| 37 |
+
except Exception as e:
|
| 38 |
+
print(f"[DEBUG] Vision model warning: {e}", flush=True)
|
| 39 |
+
processor = None
|
| 40 |
+
hf_model = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
+
def get_vision_prediction(image_path: str):
|
| 43 |
+
diagnosis = "NORMAL"
|
| 44 |
+
heatmap = [[0,0],[0,0]]
|
| 45 |
+
if hf_model is not None:
|
| 46 |
+
try:
|
| 47 |
+
image = Image.open(image_path).convert("RGB")
|
| 48 |
+
inputs = processor(images=image, return_tensors="pt")
|
| 49 |
+
with torch.no_grad():
|
| 50 |
+
outputs = hf_model(**inputs)
|
| 51 |
+
idx = outputs.logits.argmax(-1).item()
|
| 52 |
+
label = hf_model.config.id2label[idx].upper()
|
| 53 |
+
if "CNV" in label: diagnosis = "CNV"
|
| 54 |
+
elif "DME" in label: diagnosis = "DME"
|
| 55 |
+
elif "DRUSEN" in label: diagnosis = "DRUSEN"
|
| 56 |
+
attentions = outputs.attentions
|
| 57 |
+
avg_attention = attentions[-1].mean(dim=1).squeeze(0)
|
| 58 |
+
cls_attention = avg_attention[0, 1:]
|
| 59 |
+
grid = cls_attention.reshape(14,14)
|
| 60 |
+
max_idx = torch.argmax(grid).item()
|
| 61 |
+
y, x = max_idx // 14, max_idx % 14
|
| 62 |
+
p = 16
|
| 63 |
+
heatmap = [[max(0,(x-1)*p), max(0,(y-1)*p)], [min(224,(x+2)*p), min(224,(y+2)*p)]]
|
| 64 |
+
except Exception as e:
|
| 65 |
+
print(f"[DEBUG] Vision error: {e}", flush=True)
|
| 66 |
+
return diagnosis, heatmap
|
| 67 |
|
| 68 |
+
def get_heuristic_action(step: int, obs, client) -> Action:
|
| 69 |
+
if step == 1: return Action(tool_name="request_oct_scan", parameters={})
|
| 70 |
+
elif step == 2: return Action(tool_name="enhance_contrast", parameters={})
|
| 71 |
+
elif step == 3: return Action(tool_name="measure_fluid_thickness", parameters={})
|
| 72 |
+
else:
|
| 73 |
+
image_path = obs.acquired_scans[-1] if obs.acquired_scans else "dummy.jpg"
|
| 74 |
+
diagnosis, heatmap = get_vision_prediction(image_path)
|
| 75 |
+
reasoning = "Clinical biomarkers align with diagnosis based on retinal morphology."
|
| 76 |
+
try:
|
| 77 |
+
prompt = f"You are an expert ophthalmologist. Diagnose: {diagnosis}. Give 1-sentence reasoning."
|
| 78 |
+
completion = client.chat.completions.create(model=MODEL_NAME, messages=[{"role":"user","content":prompt}], max_tokens=80)
|
| 79 |
+
reasoning = completion.choices[0].message.content.strip()
|
| 80 |
+
except Exception as e:
|
| 81 |
+
print(f"[DEBUG] LLM error: {e}", flush=True)
|
| 82 |
+
return Action(tool_name="submit_diagnosis", parameters={"diagnosis": diagnosis, "heatmap_coordinates": heatmap, "reasoning": reasoning})
|
| 83 |
|
| 84 |
+
async def run_inference():
|
| 85 |
+
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 86 |
+
global_rewards = []
|
| 87 |
+
global_steps = 0
|
| 88 |
+
difficulties = ["easy", "medium", "hard"]
|
| 89 |
|
| 90 |
+
log_line = f"[START] task=MetaOCT_POMDP env=meta_oct model={MODEL_NAME}"
|
| 91 |
+
print(log_line, flush=True)
|
| 92 |
+
results["logs"].append(log_line)
|
| 93 |
|
| 94 |
+
for diff in difficulties:
|
| 95 |
+
env = MetaOCTEnv(difficulty=diff)
|
| 96 |
+
for _ in range(min(env.max_patients, 3)):
|
| 97 |
+
obs = await env.reset()
|
| 98 |
+
episode_step = 0
|
| 99 |
+
while True:
|
| 100 |
+
episode_step += 1
|
| 101 |
+
global_steps += 1
|
| 102 |
+
action_obj = get_heuristic_action(episode_step, obs, client)
|
| 103 |
+
result = await env.step(action_obj)
|
| 104 |
+
reward = result.reward or 0.0
|
| 105 |
+
done = result.done
|
| 106 |
+
obs = result.observation
|
| 107 |
+
global_rewards.append(reward)
|
| 108 |
+
step_log = f"[STEP] step={global_steps} action=Tool({action_obj.tool_name}) reward={reward:.2f} done={str(done).lower()} error=null"
|
| 109 |
+
print(step_log, flush=True)
|
| 110 |
+
results["logs"].append(step_log)
|
| 111 |
+
if done: break
|
| 112 |
+
await env.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
+
max_total = float(len(global_rewards))
|
| 115 |
+
total_score = sum(global_rewards) / max_total if max_total > 0 else 0.0
|
| 116 |
+
success = total_score >= 0.7
|
| 117 |
+
end_log = f"[END] success={str(success).lower()} steps={global_steps} score={total_score:.3f} rewards={','.join(f'{r:.2f}' for r in global_rewards)}"
|
| 118 |
+
print(end_log, flush=True)
|
| 119 |
+
results["logs"].append(end_log)
|
| 120 |
+
results["status"] = "complete"
|
| 121 |
+
results["score"] = total_score
|
| 122 |
+
results["success"] = success
|
| 123 |
+
|
| 124 |
+
def run_inference_thread():
|
| 125 |
+
loop = asyncio.new_event_loop()
|
| 126 |
+
asyncio.set_event_loop(loop)
|
| 127 |
+
loop.run_until_complete(run_inference())
|
| 128 |
+
|
| 129 |
+
# Simple Web Server - Serves the results as HTML
|
| 130 |
+
class ResultsHandler(BaseHTTPRequestHandler):
|
| 131 |
+
def log_message(self, format, *args): pass # Suppress access logs
|
| 132 |
|
| 133 |
+
def do_GET(self):
|
| 134 |
+
self.send_response(200)
|
| 135 |
+
self.send_header("Content-type", "text/html")
|
| 136 |
+
self.end_headers()
|
| 137 |
+
|
| 138 |
+
status_color = "#00ff88" if results["status"] == "complete" else "#ffaa00"
|
| 139 |
+
score_display = f"{results['score']:.3f}" if results["score"] is not None else "Running..."
|
| 140 |
+
|
| 141 |
+
html = f"""<!DOCTYPE html>
|
| 142 |
+
<html><head>
|
| 143 |
+
<title>MetaOCT Virtual Clinic</title>
|
| 144 |
+
<meta http-equiv="refresh" content="5">
|
| 145 |
+
<style>
|
| 146 |
+
body{{background:#0d1117;color:#e6edf3;font-family:monospace;padding:40px;}}
|
| 147 |
+
h1{{color:#58a6ff;}} .status{{color:{status_color};font-size:1.4em;}}
|
| 148 |
+
pre{{background:#161b22;padding:20px;border-radius:8px;overflow-x:auto;font-size:13px;max-height:500px;overflow-y:auto;}}
|
| 149 |
+
.score{{font-size:2em;color:#f0883e;}} .badge{{background:#238636;padding:4px 12px;border-radius:20px;}}
|
| 150 |
+
</style></head>
|
| 151 |
+
<body>
|
| 152 |
+
<h1>👁️ MetaOCT: Virtual Medical Clinic (POMDP)</h1>
|
| 153 |
+
<p>Multi-Step Reinforcement Learning Environment | Meta OpenEnv Hackathon</p>
|
| 154 |
+
<p class="status">Status: {results["status"].upper()}</p>
|
| 155 |
+
<p class="score">Score: {score_display}</p>
|
| 156 |
+
<pre>{"<br>".join(results["logs"][-30:])}</pre>
|
| 157 |
+
<p><span class="badge">OpenEnv Compliant</span> Built with PyTorch + LLaMA-3 + OctaVA Vision</p>
|
| 158 |
+
</body></html>"""
|
| 159 |
+
self.wfile.write(html.encode())
|
| 160 |
|
| 161 |
if __name__ == "__main__":
|
| 162 |
+
# Start inference in background
|
| 163 |
+
thread = threading.Thread(target=run_inference_thread, daemon=True)
|
| 164 |
+
thread.start()
|
| 165 |
+
|
| 166 |
+
# Start web server on port 7860 (HuggingFace required)
|
| 167 |
+
port = int(os.getenv("GRADIO_SERVER_PORT", "7860"))
|
| 168 |
+
print(f"[INFO] Starting MetaOCT Dashboard on port {port}", flush=True)
|
| 169 |
+
server = HTTPServer(("0.0.0.0", port), ResultsHandler)
|
| 170 |
+
server.serve_forever()
|