Spaces:
Sleeping
Sleeping
| import json | |
| import re | |
| import os | |
| from huggingface_hub import InferenceClient | |
| from schemas import TakeoffRequest | |
| from calculator import calculate_takeoff_roll_data | |
# --- AUTHENTICATION SETUP ---
# 1. Try to load environment variables from a .env file (for local dev).
#    python-dotenv is optional: on HF Spaces, secrets arrive directly via
#    the environment, so a missing package is not an error.
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass

# 2. Get the API token (robust check).
#    Accept either spelling of the secret name so a casing mismatch between
#    the code and the Space settings does not silently break authentication.
_TOKEN_ENV_NAMES = ("tokenApertus1", "TokenApertus")
hf_token = next(
    (os.environ[name] for name in _TOKEN_ENV_NAMES if os.environ.get(name)),
    None,
)

# Debugging: report token status without revealing the key itself.
if hf_token:
    print(f"✅ Environment variable found. (Length: {len(hf_token)})")
else:
    print("❌ Environment variable 'tokenApertus1' (or 'TokenApertus') NOT found.")
    # DEBUG AID: Check if there's a typo in the variable name.
    # Print any available environment variables containing "TOKEN" or
    # "APERTUS" so a slightly different name (e.g. all caps) is visible.
    candidates = [k for k in os.environ.keys() if "TOKEN" in k.upper() or "APERTUS" in k.upper()]
    if candidates:
        print(f" ⚠️ Found these similar variables: {candidates}")
        print(" Please update the code to match the exact name used in Settings.")
    else:
        print(" ⚠️ No variables containing 'TOKEN' or 'APERTUS' were found.")
        print(" Ensure you created a 'Secret' (not a Variable) in HF Space Settings and performed a Factory Reboot.")

# Initialize the inference client only when a token is available.
# `client` stays None otherwise; callers must check for that before use.
client = None
if hf_token:
    try:
        client = InferenceClient(token=hf_token)
    except Exception as e:
        print(f"⚠️ Error initializing client: {e}")
        client = None
else:
    print("⚠️ WARNING: API Token is NOT set.")
    print(" Please create a Secret named 'tokenApertus1' in your Hugging Face Space settings.")
    # client remains None
| def _extract_json_string(text): | |
| """ | |
| Robust helper to find JSON in mixed text. | |
| It looks for the first '{' and the last '}' to handle cases where | |
| the model is chatty (e.g. "Sure, here is the JSON: {...}") | |
| """ | |
| # 1. Try finding markdown block first (most reliable) | |
| pattern = r"```json\s*(.*?)\s*```" | |
| match = re.search(pattern, text, re.DOTALL) | |
| if match: | |
| return match.group(1) | |
| # 2. Fallback: Find the outermost curly braces | |
| start = text.find('{') | |
| end = text.rfind('}') | |
| if start != -1 and end != -1 and end > start: | |
| return text[start:end+1] | |
| # 3. Last resort: return original text | |
| return text | |
def call_apertus_llm(system_prompt, user_prompt):
    """
    Call the Apertus Instruct model via the Hugging Face Inference API.

    Parameters
    ----------
    system_prompt : str
        Instructions that frame the model's extraction task.
    user_prompt : str
        The end-user message to process.

    Returns
    -------
    str
        The JSON string extracted from the model reply, or "{}" when the
        client is unavailable or the API call fails. (Previously the two
        failure paths used json.dumps({}) and "{}" — now unified.)
    """
    # CRITICAL: the module-level client is None when no API token was found.
    if client is None:
        return "{}"

    model_id = "swiss-ai/Apertus-8B-Instruct-2509"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    try:
        response = client.chat_completion(
            model=model_id,
            messages=messages,
            max_tokens=500,
            temperature=0.1,  # near-deterministic: this is extraction, not creativity
            seed=42,          # reproducibility across identical requests
        )
        raw_content = response.choices[0].message.content
        # DEBUG: surface the raw reply on the console for troubleshooting.
        print(f"--- LLM Raw Response ---\n{raw_content}\n------------------------")
        return _extract_json_string(raw_content)
    except Exception as e:
        # Best-effort: degrade to an empty JSON object rather than crash the UI.
        print(f"API Error: {e}")
        return "{}"
class TakeoffAgent:
    """Chat agent for the PA-28 takeoff calculator.

    Accumulates flight parameters extracted from free-text user messages
    and, once every required field is known, runs the takeoff-roll
    calculation and formats the results as markdown.
    """

    def __init__(self):
        # Parameters gathered so far across the conversation.
        self.current_state = TakeoffRequest()

    def process_message(self, user_message: str):
        """Handle one user turn and return a markdown response string."""
        # 1. System prompt: pin the model to strict JSON extraction.
        system_instruction = """
You are an extraction assistant for a PA-28 Flight Calculator.
Extract parameters from the user text into JSON format matching these keys:
- altitude_ft (float)
- qnh_hpa (float)
- temperature_c (float)
- weight_kg (float)
- wind_type ("Headwind" or "Tailwind")
- wind_speed_kt (float)
- safety_factor (float)
If a value is not provided, do not include the key in the JSON.
Return ONLY valid JSON.
"""
        # 2. Call LLM
        llm_response_json = call_apertus_llm(system_instruction, user_message)

        # 3. Update Pydantic model (state management)
        try:
            new_data = json.loads(llm_response_json)
            # DEBUG: print what data we extracted
            print(f"Extracted Data: {new_data}")
            # FIX: drop explicit nulls. Models frequently emit "key": null
            # for missing values despite the prompt, which would clobber
            # previously collected fields or fail Pydantic validation.
            new_data = {k: v for k, v in new_data.items() if v is not None}
            # Merge: keep previously provided fields, overlay new extractions.
            updated_fields = self.current_state.model_dump(exclude_defaults=True)
            updated_fields.update(new_data)
            self.current_state = TakeoffRequest(**updated_fields)
        except json.JSONDecodeError:
            print("⚠️ Failed to parse JSON from LLM response")
        except Exception as e:
            print(f"⚠️ Error updating state: {e}")

        # 4. Completeness guardrail: ask for whatever is still missing.
        if not self.current_state.is_complete():
            missing = self.current_state.get_missing_fields()
            response_text = f"I updated your flight parameters. I still need: **{', '.join(missing)}**.\n\n"
            response_text += f"**Current State:**\n{self._format_state_summary()}"
            if client is None:
                # Name matches the secret the auth setup actually reads.
                response_text += "\n\n⚠️ **System Alert:** API Key missing. Please set 'tokenApertus1' in your Hugging Face Space settings."
            return response_text

        # 5. Sanity checks on plausible-but-suspicious values.
        warnings = []
        if self.current_state.temperature_c > 45:
            warnings.append("⚠️ Warning: Temperature is extremely high.")
        if self.current_state.weight_kg > 1160:
            warnings.append("⚠️ Warning: Weight exceeds typical MTOW.")

        # 6. Run the performance calculation.
        try:
            result = calculate_takeoff_roll_data(
                indicated_altitude_ft=self.current_state.altitude_ft,
                qnh_hpa=self.current_state.qnh_hpa,
                temperature_c=self.current_state.temperature_c,
                weight_kg=self.current_state.weight_kg,
                wind_type=self.current_state.wind_type,
                wind_speed=self.current_state.wind_speed_kt,
                safety_factor=self.current_state.safety_factor,
            )
            # 7. Format the results as markdown.
            response = "### ✅ Takeoff Performance Calculated\n\n"
            if warnings:
                response += "**Alerts:**\n" + "\n".join(warnings) + "\n\n"
            response += f"**Environmental:**\n"
            response += f"- Pressure Alt: {result['pressure_altitude']:.0f} ft\n"
            response += f"- Density Alt: {result['density_altitude']:.0f} ft\n\n"
            response += f"**Ground Roll:**\n"
            response += f"- Base: {result['ground_roll']['base']:.1f} ft\n"
            response += f"- Corrections: Weight {result['ground_roll']['weight_adj']:.1f}, Wind {result['ground_roll']['wind_adj']:.1f}\n"
            response += f"- **Final: {result['ground_roll']['final_m']:.0f} meters** ({result['ground_roll']['final_ft']:.0f} ft)\n\n"
            response += f"**50ft Obstacle:**\n"
            response += f"- **Final: {result['obstacle_50ft']['final_m']:.0f} meters** ({result['obstacle_50ft']['final_ft']:.0f} ft)\n"
            return response
        except Exception as e:
            # Surface calculation failures to the user rather than crash.
            return f"Error in calculation: {str(e)}"

    def _format_state_summary(self):
        """One-line-per-field summary of the currently known parameters."""
        s = self.current_state
        return (f"- Alt: {s.altitude_ft} ft\n- Temp: {s.temperature_c} C\n"
                f"- Weight: {s.weight_kg} kg\n- Wind: {s.wind_speed_kt} kt ({s.wind_type})")