| """ | |
| LLM Agent Module for DReamMachine | |
| Handles all interactions with HuggingFace models via Inference API | |
| """ | |
| import os | |
| import json | |
| import time | |
| import logging | |
| from typing import Dict, List, Optional, Any | |
| from huggingface_hub import InferenceClient | |
| import yaml | |
| logger = logging.getLogger(__name__) | |
| class LLMAgent: | |
| """Manages LLM API calls to HuggingFace models""" | |
| def __init__(self, config_path: str = "config.yaml", hf_token: Optional[str] = None): | |
| """ | |
| Initialize LLM Agent | |
| Args: | |
| config_path: Path to configuration YAML file | |
| hf_token: HuggingFace API token (if not provided, uses HF_TOKEN env var) | |
| """ | |
| # Load configuration | |
| with open(config_path, 'r') as f: | |
| self.config = yaml.safe_load(f) | |
| # Get HuggingFace token | |
| self.hf_token = hf_token or os.getenv('HF_TOKEN') | |
| if not self.hf_token: | |
| raise ValueError("HuggingFace token required. Set HF_TOKEN environment variable.") | |
| # Initialize Inference Client | |
| self.client = InferenceClient(token=self.hf_token) | |
| # Load model configurations | |
| self.models = self.config.get('models', {}) | |
| self.use_zero_gpu = self.config.get('huggingface', {}).get('use_zero_gpu', False) | |
| logger.info("LLMAgent initialized successfully") | |

    def call_hf_model(
        self,
        model_id: str,
        system_prompt: str,
        user_prompt: str,
        temperature: float = 0.7,
        max_tokens: int = 1000,
        retries: int = 3
    ) -> str:
        """
        Standard function for HuggingFace model API calls

        Args:
            model_id: HuggingFace model identifier
            system_prompt: System-level instructions for the model
            user_prompt: User prompt/query
            temperature: Sampling temperature (higher = more creative)
            max_tokens: Maximum tokens to generate
            retries: Number of retry attempts on failure

        Returns:
            Generated text response
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        for attempt in range(retries):
            try:
                logger.info(f"Calling {model_id} (attempt {attempt + 1}/{retries})")
                response = self.client.chat_completion(
                    model=model_id,
                    messages=messages,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    stream=False
                )

                # Extract generated text
                result = response.choices[0].message.content
                logger.info(f"Successfully received response from {model_id}")
                return result

            except Exception as e:
                logger.warning(f"Error calling {model_id}: {str(e)}")
                if attempt < retries - 1:
                    wait_time = (attempt + 1) * 2  # Linear backoff: 2s, 4s, ...
| logger.info(f"Retrying in {wait_time} seconds...") | |
| time.sleep(wait_time) | |
| else: | |
| logger.error(f"Failed to call {model_id} after {retries} attempts") | |
| raise | |
| return "" # Should not reach here | |

    def get_dreamer_output(
        self,
        prompt: str,
        model_config: Optional[Dict[str, Any]] = None,
        model_index: int = 0
    ) -> str:
        """
        Specialized wrapper for Dreamer LLM calls (high creativity)

        Args:
            prompt: The dream prompt
            model_config: Optional model configuration override
            model_index: Which dreamer model to use (0-2)

        Returns:
            Creative dream output
        """
        if model_config is None:
            dreamers = self.models.get('dreamers', [])
            if not dreamers:
                raise ValueError("No dreamer models configured under 'models.dreamers'.")
            if model_index >= len(dreamers):
                model_index = 0
            model_config = dreamers[model_index]
| system_prompt = """You are a creative genius and visionary inventor. Your purpose is to | |
| imagine breakthrough innovations that could change the world. Think freely, boldly, and without | |
| conventional limitations. This is a controlled creative hallucination - let your imagination soar | |
| while staying grounded in the realm of physical possibility.""" | |
| return self.call_hf_model( | |
| model_id=model_config['model_id'], | |
| system_prompt=system_prompt, | |
| user_prompt=prompt, | |
| temperature=model_config.get('temperature', 0.9), | |
| max_tokens=model_config.get('max_tokens', 1000) | |
| ) | |

    def get_writer_output(self, prompt: str) -> str:
        """Specialized wrapper for Writer LLM (narrative creation)"""
        model_config = self.models.get('writer', {})

        system_prompt = """You are an expert technical storyteller. You transform complex
        innovations into compelling narratives that inspire and educate. Write with clarity,
        emotion, and vision."""

        return self.call_hf_model(
            model_id=model_config.get('model_id', 'mistralai/Mistral-7B-Instruct-v0.2'),
            system_prompt=system_prompt,
            user_prompt=prompt,
            temperature=model_config.get('temperature', 0.6),
            max_tokens=model_config.get('max_tokens', 1200)
        )

    def get_logger_output(self, prompt: str) -> str:
        """Specialized wrapper for Logger LLM (technical extraction)"""
        model_config = self.models.get('logger', {})

        system_prompt = """You are a technical analyst. Extract and organize technical
        specifications with precision and clarity. Focus on concrete details and requirements."""

        return self.call_hf_model(
            model_id=model_config.get('model_id', 'mistralai/Mistral-7B-Instruct-v0.2'),
            system_prompt=system_prompt,
            user_prompt=prompt,
            temperature=model_config.get('temperature', 0.4),
            max_tokens=model_config.get('max_tokens', 800)
        )

    def get_narrator_output(self, prompt: str) -> str:
        """Specialized wrapper for Narrator LLM (presentation)"""
        model_config = self.models.get('narrator', {})

        system_prompt = """You are a world-class presenter and communicator. Create
        engaging, inspiring presentations that connect with audiences emotionally while
        conveying complex ideas clearly."""

        return self.call_hf_model(
            model_id=model_config.get('model_id', 'NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO'),
            system_prompt=system_prompt,
            user_prompt=prompt,
            temperature=model_config.get('temperature', 0.5),
            max_tokens=model_config.get('max_tokens', 1000)
        )

    def get_deep_thinker_output(self, prompt: str) -> str:
        """Specialized wrapper for Deep Thinker LLM (feasibility analysis)"""
        model_config = self.models.get('deep_thinker', {})

        system_prompt = """You are a senior research scientist with expertise across physics,
        engineering, chemistry, and materials science. Analyze proposals with rigorous scientific
        thinking. Be honest about challenges while remaining constructive."""

        return self.call_hf_model(
            model_id=model_config.get('model_id', 'meta-llama/Meta-Llama-3-70B-Instruct'),
            system_prompt=system_prompt,
            user_prompt=prompt,
            temperature=model_config.get('temperature', 0.3),
            max_tokens=model_config.get('max_tokens', 1500)
        )

    def get_curator_score(self, prompt: str) -> Dict[str, Any]:
        """
        Specialized wrapper for Curator LLM (evaluation & scoring)
        Forces JSON output for scoring

        Args:
            prompt: Curator evaluation prompt

        Returns:
            Dictionary containing scorecard data
        """
        model_config = self.models.get('curator', {})

        system_prompt = """You are a rigorous innovation evaluator. You assess breakthrough
        ideas across multiple dimensions and provide structured scoring. You MUST respond with
        valid JSON only, following the exact schema provided in the prompt."""

        response_text = self.call_hf_model(
            model_id=model_config.get('model_id', 'meta-llama/Meta-Llama-3-70B-Instruct'),
            system_prompt=system_prompt,
            user_prompt=prompt,
            temperature=model_config.get('temperature', 0.2),
            max_tokens=model_config.get('max_tokens', 800)
        )

        # Parse JSON response
        try:
            # Try to extract JSON from response
            response_text = response_text.strip()

            # Handle potential markdown code blocks (```json ... ```)
            if response_text.startswith('```'):
                # Drop the opening and closing fence lines
                lines = response_text.split('\n')
                response_text = '\n'.join(lines[1:-1]) if len(lines) > 2 else response_text
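
            # A stricter fallback (an assumption, not part of the original
            # flow) would slice out the outermost JSON object before parsing,
            # which also tolerates stray prose around the scorecard:
            #   start, end = response_text.find('{'), response_text.rfind('}')
            #   if start != -1 and end > start:
            #       response_text = response_text[start:end + 1]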
            scorecard = json.loads(response_text)
            logger.info("Successfully parsed curator scorecard")
            return scorecard

        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse curator JSON response: {str(e)}")
            logger.error(f"Raw response: {response_text}")

            # Return a default scorecard
            return {
                "originality": 5,
                "originality_reasoning": "Failed to parse response",
                "feasibility": 5,
                "feasibility_reasoning": "Failed to parse response",
                "global_impact": 5,
                "global_impact_reasoning": "Failed to parse response",
                "narrative_coherence": 5,
                "narrative_coherence_reasoning": "Failed to parse response",
                "reforge_flag": False,
                "reforge_reasoning": "Failed to parse curator response",
                "overall_assessment": f"Error parsing response: {str(e)}",
                "next_steps": "Retry curation step"
            }

    def run_parallel_dreamers(
        self,
        prompt: str,
        num_dreamers: int = 3
    ) -> List[str]:
        """
        Run multiple dreamer models (sequentially for now; see the threaded
        sketch below for a truly parallel variant)

        Args:
            prompt: Dream prompt to send to all dreamers
            num_dreamers: Number of dreamer outputs to generate

        Returns:
            List of dream outputs
        """
        dreams = []
        dreamers = self.models.get('dreamers', [])

        for i in range(min(num_dreamers, len(dreamers))):
            logger.info(f"Running Dreamer {i + 1}/{num_dreamers}")
            try:
                dream = self.get_dreamer_output(prompt, model_index=i)
                dreams.append(dream)
            except Exception as e:
                logger.error(f"Dreamer {i + 1} failed: {str(e)}")
                dreams.append(f"[Dreamer {i + 1} failed: {str(e)}]")

        return dreams
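
    # Illustrative sketch only (an addition, not part of the original module):
    # a truly parallel counterpart to run_parallel_dreamers. chat_completion
    # calls are I/O-bound, so a thread pool is a reasonable fit; the method
    # name is an assumption.
    def run_parallel_dreamers_threaded(
        self,
        prompt: str,
        num_dreamers: int = 3
    ) -> List[str]:
        """Thread-pool variant of run_parallel_dreamers (sketch)."""
        from concurrent.futures import ThreadPoolExecutor

        dreamers = self.models.get('dreamers', [])
        count = min(num_dreamers, len(dreamers))
        if count == 0:
            return []

        with ThreadPoolExecutor(max_workers=count) as pool:
            futures = [
                pool.submit(self.get_dreamer_output, prompt, model_index=i)
                for i in range(count)
            ]
            # Collect results in dreamer order, surfacing failures as
            # placeholder strings to mirror the sequential version above.
            dreams = []
            for i, future in enumerate(futures):
                try:
                    dreams.append(future.result())
                except Exception as e:
                    dreams.append(f"[Dreamer {i + 1} failed: {str(e)}]")

        return dreams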


# Convenience function
def create_agent(config_path: str = "config.yaml", hf_token: Optional[str] = None) -> LLMAgent:
    """Create and return a configured LLM Agent"""
    return LLMAgent(config_path, hf_token)
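

# Minimal usage sketch, assuming HF_TOKEN is set in the environment and a
# config.yaml shaped like the outline near the top of this module:
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    agent = create_agent("config.yaml")
    dreams = agent.run_parallel_dreamers("Invent a new way to desalinate seawater.")
    for i, dream in enumerate(dreams, start=1):
        print(f"--- Dream {i} ---\n{dream}\n")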