MaxSainz2000 committed on
Commit
2fd78a3
·
verified ·
1 Parent(s): 98fa1c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -28
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import os
 
 
2
  import datetime
3
  from fastapi import FastAPI
4
  from supabase import create_client, Client
@@ -6,44 +8,68 @@ from transformers import pipeline
6
  from huggingface_hub import login
7
  import torch
8
 
 
 
 
 
9
  app = FastAPI()
10
- url = os.environ.get("SUPABASE_URL")
11
- key = os.environ.get("SUPABASE_KEY")
12
- hf_token = os.environ.get("HF_TOKEN")
13
- supabase: Client = create_client(url, key)
14
 
15
- if hf_token:
16
- login(token=hf_token)
17
 
18
- # Load NVIDIA Nemotron-Mini-4B
19
  pipe = pipeline(
20
  "text-generation",
21
- model="nvidia/Llama-3.1-Minitron-4B-Width-Base", # Or nvidia/Llama-3.1-Minitron-4B-Width-Base
22
  model_kwargs={"torch_dtype": torch.bfloat16},
23
- device_map="auto",
24
- token=hf_token
25
  )
26
 
27
- AGENT_NAME = "Nemotron-Validator-Pod"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  @app.on_event("startup")
30
- def startup_event():
31
- agent_data = {
32
- "name": AGENT_NAME,
33
- "model_name": "nvidia/Minitron-4B",
34
- "status": "online",
35
- "last_seen": datetime.datetime.now().isoformat(),
36
- "capabilities": {"task": "validation", "specialty": "accuracy_check"}
37
- }
38
  supabase.table("agents").upsert(agent_data, on_conflict="name").execute()
39
 
40
  @app.get("/")
41
- def health():
42
- return {"status": "running", "agent": AGENT_NAME}
43
-
44
- @app.post("/validate")
45
- async def validate(original_prompt: str, ai_response: str):
46
- # Prompt engineering to make Nemotron act as a judge
47
- validation_prompt = f"Task: {original_prompt}\nResponse: {ai_response}\nIs this response accurate? Answer with YES or NO and a brief reason."
48
- outputs = pipe(validation_prompt, max_new_tokens=100, do_sample=False)
49
- return {"validation": outputs[0]["generated_text"]}
 
1
  import os
2
+ import threading
3
+ import time
4
  import datetime
5
  from fastapi import FastAPI
6
  from supabase import create_client, Client
 
8
  from huggingface_hub import login
9
  import torch
10
 
11
# --- CONFIGURATION ---
# Name this pod registers under and filters its task queue by.
AGENT_NAME = "Nemotron-Validator-Pod"
# Hugging Face model repository loaded into the text-generation pipeline.
MODEL_ID = "nvidia/Llama-3.1-Minitron-4B-Width-Base"

app = FastAPI()

# Supabase credentials come from the environment; both values are passed
# through unchanged (None when the variable is unset).
_supabase_url = os.environ.get("SUPABASE_URL")
_supabase_key = os.environ.get("SUPABASE_KEY")
supabase: Client = create_client(_supabase_url, _supabase_key)

# Authenticate with the Hugging Face Hub only when a token is configured.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

print(f"📦 Loading {MODEL_ID}...")
# Load the model once at import time; bfloat16 weights, automatic device
# placement.
_pipeline_options = {
    "model": MODEL_ID,
    "model_kwargs": {"torch_dtype": torch.bfloat16},
    "device_map": "auto",
}
pipe = pipeline("text-generation", **_pipeline_options)
28
 
29
def _validate_task(task):
    """Run the validator model on a single task row and store the verdict.

    Reads ``id``, ``input_text`` and ``output_data`` from *task*, asks the
    model whether the content under ``output_data['agent_gemma']`` answers
    the task correctly, then writes the verdict under
    ``output_data['agent_nemotron']`` and sets the row's status to
    ``val_completed``.

    Any exception (missing key, inference failure, Supabase error) is left
    to propagate so the caller can log it and move on to the next task.
    """
    task_id = task['id']
    # Ids are sliced for log output only; str() keeps this safe for
    # non-string ids (e.g. integer primary keys).
    short_id = str(task_id)[:8]
    original_prompt = task['input_text']

    # output_data may be NULL in the row; treat that as an empty mapping.
    # Work on a copy so the fetched row is not mutated in place.
    output_data = dict(task.get('output_data') or {})
    gemma_content = (output_data.get('agent_gemma') or {}).get('content', '')

    print(f"🔍 Validating task {short_id}...")

    val_prompt = (
        f"Task: {original_prompt}\n"
        f"Response: {gemma_content}\n"
        "Analyze if this is correct. Reply with VALID or INVALID and reason."
    )

    start_time = time.perf_counter()
    # return_full_text=False keeps only the newly generated text. By default
    # the pipeline echoes the prompt, so the stored "verdict" would always
    # contain the instruction words VALID/INVALID.
    outputs = pipe(val_prompt, max_new_tokens=150, return_full_text=False)
    val_result = outputs[0]["generated_text"]
    latency = round(time.perf_counter() - start_time, 2)

    # Merge the verdict into output_data and flip the status in one update.
    output_data['agent_nemotron'] = {"content": val_result, "latency": f"{latency}s"}
    supabase.table("tasks").update({
        "status": "val_completed",
        "output_data": output_data,
    }).eq("id", task_id).execute()
    print(f"✅ Validated task {short_id} in {latency}s")


def worker_loop():
    """Poll Supabase forever for tasks awaiting validation by this agent.

    Every two seconds, selects rows from ``tasks`` whose status is
    ``processing_val`` and whose ``assigned_to_name`` equals ``AGENT_NAME``,
    then validates each one. Errors are logged and never raised, so the
    thread keeps running. A failure on one task does not stop the remaining
    tasks in the same batch from being processed.

    NOTE(review): a task that fails keeps its ``processing_val`` status and
    is retried on every poll; confirm whether a failure status is wanted.
    """
    print(f"⚖️ {AGENT_NAME} Validator Loop Started.")
    while True:
        try:
            res = (
                supabase.table("tasks")
                .select("*")
                .eq("status", "processing_val")
                .eq("assigned_to_name", AGENT_NAME)
                .execute()
            )
            pending = res.data or []
        except Exception as e:
            print(f"⚠️ Validator Error: {e}")
            pending = []

        for task in pending:
            # Isolate failures per task so one bad row cannot block the rest.
            try:
                _validate_task(task)
            except Exception as e:
                print(f"⚠️ Validator Error: {e}")

        time.sleep(2)
66
+
67
+ threading.Thread(target=worker_loop, daemon=True).start()
68
 
69
  @app.on_event("startup")
70
def register():
    """Announce this agent as online in the Supabase ``agents`` table.

    Upserts a row keyed on ``name`` with status ``online`` and the current
    time as ``last_seen``. The timestamp is timezone-aware UTC, so it stays
    unambiguous whatever the host's local timezone is.
    """
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    agent_data = {
        "name": AGENT_NAME,
        "status": "online",
        "last_seen": now_utc.isoformat(),
    }
    supabase.table("agents").upsert(agent_data, on_conflict="name").execute()
73
 
74
  @app.get("/")
75
def health():
    """Liveness probe: report that the service is up and which agent it is."""
    payload = {"status": "alive", "worker": AGENT_NAME}
    return payload