import asyncio
import logging
import os
from datetime import datetime
from io import StringIO
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
import torch
from langfuse import Langfuse  # required by the client instantiation below
from langfuse.decorators import observe, langfuse_context

from config.config import GenerationConfig, ModelConfig, settings
from services.base_generator import BaseGenerator
from services.model_manager import ModelManager
from services.prompt_builder import LlamaPromptTemplate
from services.strategy import (
    BeamSearch,
    BestOfN,
    COT,
    DVT,
    DefaultStrategy,
    MajorityVotingStrategy,
    ReAct,
)

logger = logging.getLogger(__name__)

# Langfuse credentials are hardcoded for this Space; prefer injecting them via
# the environment in any shared deployment.
os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-04d2302a-aa5c-4870-9703-58ab64c3bcae"
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-d34ea200-feec-428e-a621-784fce93a5af"
os.environ["LANGFUSE_HOST"] = "https://chris4k-langfuse-template-space.hf.space"

try:
    langfuse = Langfuse()
except Exception as e:
    logger.warning("Langfuse offline: %s", e)


class LlamaGenerator(BaseGenerator):
    def __init__(
        self,
        llama_model_name: str,
        prm_model_path: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32,
    ):
        logger.info("Loading Llama model: %s", llama_model_name)
        logger.info("Loading PRM model: %s", prm_model_path)

        self.model_manager = ModelManager()
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = self.model_manager.load_tokenizer(llama_model_name)

        super().__init__(
            llama_model_name,
            self.device,  # pass the resolved device, not the possibly-None argument
            default_generation_config,
            model_config,
            cache_size,
            max_batch_size,
        )

        self.model_manager.load_model(
            "llama",
            llama_model_name,
            "llama",
            self.model_config,
        )
        self.model_manager.load_model(
            "prm",
            prm_model_path,
            "gguf",
            self.model_config,
        )

        self.model = self.model_manager.models.get("llama")
        if not self.model:
            raise ValueError(f"Failed to load model: {llama_model_name}")

        self.prm_model = self.model_manager.models.get("prm")

        self.prompt_builder = LlamaPromptTemplate()
        self._init_strategies()

    def _init_strategies(self):
        self.strategies = {
            "default": DefaultStrategy(),
            "majority_voting": MajorityVotingStrategy(),
            "best_of_n": BestOfN(),
            "beam_search": BeamSearch(),
            "dvts": DVT(),
        }
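
    # Note: COT and ReAct are imported above but not registered as strategies.
    # Assuming they expose the same generate() interface as the others (an
    # assumption, not verified here), they could be enabled by adding
    # "cot": COT() and "react": ReAct() to the dict above.
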
    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Get generation kwargs based on config."""
        return {
            key: getattr(config, key)
            for key in [
                "max_new_tokens",
                "temperature",
                "top_p",
                "top_k",
                "repetition_penalty",
                "length_penalty",
                "do_sample",
            ]
            if hasattr(config, key)
        }
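
    # Example (hypothetical config values): a GenerationConfig carrying
    # max_new_tokens=256, temperature=0.7, and do_sample=True yields
    # {"max_new_tokens": 256, "temperature": 0.7, "do_sample": True};
    # attributes the config does not define are simply omitted.
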
    @observe()
    def generate_stream(self):
        raise NotImplementedError("Streaming generation is not implemented yet.")

    @observe()
    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        **kwargs,
    ) -> str:
        """
        Generate text based on a given strategy.

        Args:
            prompt (str): Input prompt for text generation.
            model_kwargs (Dict[str, Any]): Additional arguments for model generation.
            strategy (str): The generation strategy to use (default: "default").
            **kwargs: Additional arguments passed to the strategy.

        Returns:
            str: Generated text response.

        Raises:
            ValueError: If the specified strategy is not available.
        """
        if strategy not in self.strategies:
            raise ValueError(
                f"Unknown strategy: {strategy}. "
                f"Available strategies are: {list(self.strategies.keys())}"
            )

        # Drop any caller-supplied generator; `self` is passed explicitly below.
        kwargs.pop("generator", None)

        return self.strategies[strategy].generate(
            generator=self,
            prompt=prompt,
            model_kwargs=model_kwargs,
            **kwargs,
        )

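    # Usage sketch for generate() (illustrative values; assumes BestOfN reads
    # num_samples from **kwargs, which is not verified here):
    #
    #     text = generator.generate(
    #         prompt="Explain beam search in one sentence.",
    #         model_kwargs={"max_new_tokens": 64},
    #         strategy="best_of_n",
    #         num_samples=4,
    #     )
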
    @observe()
    def generate_with_context(
        self,
        context: str,
        user_input: str,
        chat_history: List[Tuple[str, str]],
        model_kwargs: Dict[str, Any],
        max_history_turns: int = 3,
        strategy: str = "default",
        num_samples: int = 5,
        depth: int = 3,
        breadth: int = 2,
    ) -> str:
        """Generate a response using context and chat history.

        Args:
            context (str): Context for the conversation
            user_input (str): Current user input
            chat_history (List[Tuple[str, str]]): List of (user, assistant) message pairs
            model_kwargs (dict): Additional arguments for model.generate()
            max_history_turns (int): Maximum number of history turns to include
            strategy (str): Generation strategy
            num_samples (int): Number of samples for applicable strategies
            depth (int): Depth for DVTS strategy
            breadth (int): Breadth for DVTS strategy

        Returns:
            str: Generated response
        """
        prompt = self.prompt_builder.format(
            context,
            user_input,
            chat_history,
            max_history_turns,
        )
        return self.generate(
            prompt=prompt,
            model_kwargs=model_kwargs,
            strategy=strategy,
            num_samples=num_samples,
            depth=depth,
            breadth=breadth,
        )

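    # Usage sketch for generate_with_context() (illustrative values):
    #
    #     reply = generator.generate_with_context(
    #         context="You are a concise technical assistant.",
    #         user_input="What does the PRM model do here?",
    #         chat_history=[("hi", "Hello! How can I help?")],
    #         model_kwargs={"max_new_tokens": 128},
    #     )
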
    def check_health(self) -> str:
        """Check the health status of the generator."""
        # TODO: verify that the model, tokenizer, and PRM model are loaded.
        return "Health check not implemented yet."
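

# Minimal end-to-end sketch, assuming the model id and GGUF path below exist in
# your environment (both are illustrative placeholders, not project defaults).
if __name__ == "__main__":
    generator = LlamaGenerator(
        llama_model_name="meta-llama/Llama-3.2-1B-Instruct",  # hypothetical model id
        prm_model_path="models/prm.gguf",                     # hypothetical local path
    )
    print(generator.generate(
        prompt="Explain process reward models in one sentence.",
        model_kwargs={"max_new_tokens": 64},
        strategy="default",
    ))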