# Author: agrd
# Last commit: "Update agent.py"
# Commit: 6e11ff2 (verified)
import json
import re
import os
from huggingface_hub import InferenceClient
from schemas import TakeoffRequest
from calculator import calculate_takeoff_roll_data
# --- AUTHENTICATION SETUP ---
# Name of the Hugging Face Space secret / environment variable that holds the
# API token. Kept in one constant so the lookup and every user-facing hint
# always name the same variable.
TOKEN_ENV_VAR = "tokenApertus1"

# 1. Try to load environment variables from a .env file (for local dev).
#    python-dotenv is optional; if it is not installed, only the real
#    process environment is used.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass

# 2. Read the token.
hf_token = os.environ.get(TOKEN_ENV_VAR)

# Debugging: report the status without revealing the key itself.
if hf_token:
    print(f"✅ Environment variable found. (Length: {len(hf_token)})")
else:
    print(f"❌ Environment variable '{TOKEN_ENV_VAR}' NOT found.")
    # DEBUG AID: list the *names* (never the values) of environment variables
    # that contain "TOKEN" or "APERTUS", to catch a secret that was created
    # under a slightly different name (e.g. different capitalisation).
    candidates = [k for k in os.environ if "TOKEN" in k.upper() or "APERTUS" in k.upper()]
    if candidates:
        print(f" ⚠️ Found these similar variables: {candidates}")
        print(" Please update the code to match the exact name used in Settings.")
    else:
        print(" ⚠️ No variables containing 'TOKEN' or 'APERTUS' were found.")
        print(" Ensure you created a 'Secret' (not a Variable) in HF Space Settings and performed a Factory Reboot.")

# Initialize the client only when a token is available. `client` stays None
# otherwise (or if construction fails); downstream code checks for None.
client = None
if hf_token:
    try:
        client = InferenceClient(token=hf_token)
    except Exception as e:
        print(f"⚠️ Error initializing client: {e}")
        client = None
else:
    print("⚠️ WARNING: API Token is NOT set.")
    print(f" Please create a Secret named '{TOKEN_ENV_VAR}' in your Hugging Face Space settings.")
def _extract_json_string(text):
    """
    Pull the JSON payload out of a possibly chatty LLM reply.

    Tried in order:
    1. The body of the first ```json fenced block (surrounding whitespace
       trimmed).
    2. The span from the first '{' through the last '}', when both braces
       exist and the closing one comes after the opening one.
    3. The input text, unchanged.

    The result is not validated; callers must still parse it.
    """
    fenced = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL)
    if fenced is not None:
        return fenced.group(1)

    first_brace = text.find('{')
    last_brace = text.rfind('}')
    # last_brace > first_brace >= 0 implies both braces were found.
    has_span = first_brace != -1 and last_brace > first_brace
    return text[first_brace:last_brace + 1] if has_span else text
def call_apertus_llm(
    system_prompt,
    user_prompt,
    *,
    model_id="swiss-ai/Apertus-8B-Instruct-2509",
    max_tokens=500,
    temperature=0.1,
    seed=42,
):
    """
    Call the Apertus Instruct model via the Hugging Face Inference API and
    return the JSON portion of its reply as a string.

    Args:
        system_prompt: Text sent as the system message.
        user_prompt: Text sent as the user message.
        model_id: Hugging Face model repo id to query.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature (kept low for extraction).
        seed: Sampling seed passed through to the API.

    Returns:
        A string expected, but not guaranteed, to be valid JSON. Returns "{}"
        in three cases: no client is configured, the model returns an empty
        message, or the API call or extraction raises. Those errors are
        printed, not propagated.
    """
    # Without a configured client there is nothing to call; "{}" parses to an
    # empty dict, i.e. "nothing extracted".
    if client is None:
        return "{}"

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    try:
        response = client.chat_completion(
            model=model_id,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            seed=seed,
        )
        raw_content = response.choices[0].message.content
        # DEBUG: print exactly what the model said.
        print(f"--- LLM Raw Response ---\n{raw_content}\n------------------------")
        # The message content may be None/empty. Handle it explicitly instead of
        # letting re.search fail with a TypeError that would be misreported
        # as an API error.
        if not raw_content:
            print("⚠️ LLM returned an empty response")
            return "{}"
        return _extract_json_string(raw_content)
    except Exception as e:
        print(f"API Error: {e}")
        return "{}"
class TakeoffAgent:
    """
    Conversational agent that gathers PA-28 takeoff parameters over several
    turns and runs the takeoff-roll calculation once all are known.

    The accumulated parameters live in ``self.current_state`` (a
    ``TakeoffRequest``). Each user message goes to the LLM for extraction,
    and whatever it returns is merged into that state.
    """

    # System prompt for the extraction LLM. It is the same on every turn, so
    # it is built once here instead of inside process_message.
    SYSTEM_INSTRUCTION = (
        "You are an extraction assistant for a PA-28 Flight Calculator.\n"
        "Extract parameters from the user text into JSON format matching these keys:\n"
        "- altitude_ft (float)\n"
        "- qnh_hpa (float)\n"
        "- temperature_c (float)\n"
        "- weight_kg (float)\n"
        '- wind_type ("Headwind" or "Tailwind")\n'
        "- wind_speed_kt (float)\n"
        "- safety_factor (float)\n"
        "If a value is not provided, do not include the key in the JSON.\n"
        "Return ONLY valid JSON.\n"
    )

    # Thresholds for non-blocking sanity alerts shown alongside the result.
    HIGH_TEMPERATURE_C = 45
    TYPICAL_MTOW_KG = 1160

    def __init__(self):
        self.current_state = TakeoffRequest()

    def process_message(self, user_message: str) -> str:
        """
        Handle one user turn and return a markdown reply.

        First extracts parameters via the LLM and merges them into the
        running state. If required fields are still missing, asks for them.
        Otherwise runs the takeoff calculation and formats its result.

        Args:
            user_message: Free-text message from the user.

        Returns:
            One of three markdown strings:
            - a "still need ..." prompt with the current state;
            - the formatted takeoff performance;
            - "Error in calculation: ..." if the calculation or its
              formatting raised.
        """
        llm_response_json = call_apertus_llm(self.SYSTEM_INSTRUCTION, user_message)
        self._merge_extraction(llm_response_json)

        # Guardrail: do not calculate until every required field is known.
        if not self.current_state.is_complete():
            return self._format_missing_prompt()

        alerts = self._collect_alerts()
        s = self.current_state
        try:
            result = calculate_takeoff_roll_data(
                indicated_altitude_ft=s.altitude_ft,
                qnh_hpa=s.qnh_hpa,
                temperature_c=s.temperature_c,
                weight_kg=s.weight_kg,
                wind_type=s.wind_type,
                wind_speed=s.wind_speed_kt,
                safety_factor=s.safety_factor,
            )
            # Formatting stays inside the try, so a result missing an expected
            # key is reported as a calculation error rather than raised.
            return self._format_result(result, alerts)
        except Exception as e:
            return f"Error in calculation: {e}"

    def _merge_extraction(self, llm_response_json):
        """
        Merge the LLM's JSON extraction into ``self.current_state``.

        Only keys with non-null values are applied, so an LLM that emits
        ``null`` cannot erase values collected on earlier turns. On a parse
        or validation failure, the previous state is kept and a warning is
        printed.
        """
        try:
            new_data = json.loads(llm_response_json)
        except json.JSONDecodeError:
            print("⚠️ Failed to parse JSON from LLM response")
            return
        except TypeError as e:
            print(f"⚠️ Error updating state: {e}")
            return

        # DEBUG: show what data was extracted.
        print(f"Extracted Data: {new_data}")

        if not isinstance(new_data, dict):
            print(f"⚠️ Error updating state: expected a JSON object, got {type(new_data).__name__}")
            return

        provided = {k: v for k, v in new_data.items() if v is not None}
        if not provided:
            return

        try:
            # Start from explicitly set fields only, so unset fields keep
            # their model defaults after re-validation.
            merged = self.current_state.model_dump(exclude_defaults=True)
            merged.update(provided)
            self.current_state = TakeoffRequest(**merged)
        except Exception as e:
            print(f"⚠️ Error updating state: {e}")

    def _format_missing_prompt(self):
        """Build the reply that lists the still-missing fields and the current state."""
        missing = self.current_state.get_missing_fields()
        text = (
            f"I updated your flight parameters. I still need: **{', '.join(missing)}**.\n\n"
            f"**Current State:**\n{self._format_state_summary()}"
        )
        if client is None:
            # The secret name must match the environment variable read when
            # the module-level client is created.
            text += (
                "\n\n⚠️ **System Alert:** API Key missing. "
                "Please set 'tokenApertus1' in your Hugging Face Space settings."
            )
        return text

    def _collect_alerts(self):
        """Return non-blocking sanity warnings for the current (complete) state."""
        s = self.current_state
        alerts = []
        if s.temperature_c > self.HIGH_TEMPERATURE_C:
            alerts.append("⚠️ Warning: Temperature is extremely high.")
        if s.weight_kg > self.TYPICAL_MTOW_KG:
            alerts.append("⚠️ Warning: Weight exceeds typical MTOW.")
        return alerts

    @staticmethod
    def _format_result(result, alerts):
        """
        Render the calculator output, plus any alerts, as the markdown reply.

        ``result`` must provide 'pressure_altitude', 'density_altitude', a
        'ground_roll' mapping and an 'obstacle_50ft' mapping with the keys
        read below; a missing key raises KeyError.
        """
        parts = ["### ✅ Takeoff Performance Calculated\n\n"]
        if alerts:
            parts.append("**Alerts:**\n" + "\n".join(alerts) + "\n\n")
        parts += [
            "**Environmental:**\n",
            f"- Pressure Alt: {result['pressure_altitude']:.0f} ft\n",
            f"- Density Alt: {result['density_altitude']:.0f} ft\n\n",
        ]
        roll = result['ground_roll']
        parts += [
            "**Ground Roll:**\n",
            f"- Base: {roll['base']:.1f} ft\n",
            f"- Corrections: Weight {roll['weight_adj']:.1f}, Wind {roll['wind_adj']:.1f}\n",
            f"- **Final: {roll['final_m']:.0f} meters** ({roll['final_ft']:.0f} ft)\n\n",
        ]
        obstacle = result['obstacle_50ft']
        parts += [
            "**50ft Obstacle:**\n",
            f"- **Final: {obstacle['final_m']:.0f} meters** ({obstacle['final_ft']:.0f} ft)\n",
        ]
        return "".join(parts)

    def _format_state_summary(self):
        """Return a bullet list of every extracted parameter (None if not yet known)."""
        s = self.current_state
        return (
            f"- Alt: {s.altitude_ft} ft\n- QNH: {s.qnh_hpa} hPa\n- Temp: {s.temperature_c} C\n"
            f"- Weight: {s.weight_kg} kg\n- Wind: {s.wind_speed_kt} kt ({s.wind_type})\n"
            f"- Safety factor: {s.safety_factor}"
        )