| |
| |
import os
from abc import ABC, abstractmethod
from typing import List, Tuple, Dict, Any

from langfuse import Langfuse
from langfuse.decorators import observe, langfuse_context
|
|
# Langfuse observability configuration.
# SECURITY NOTE(review): these credentials were previously hard-coded and
# committed to source — they should be rotated and supplied via the real
# environment / a secrets manager. `setdefault` keeps the old fallback
# behaviour while letting an externally-provided value take precedence.
os.environ.setdefault("LANGFUSE_PUBLIC_KEY", "pk-lf-04d2302a-aa5c-4870-9703-58ab64c3bcae")
os.environ.setdefault("LANGFUSE_SECRET_KEY", "sk-lf-d34ea200-feec-428e-a621-784fce93a5af")
os.environ.setdefault("LANGFUSE_HOST", "https://chris4k-langfuse-template-space.hf.space")

try:
    langfuse = Langfuse()
except Exception as e:
    # Telemetry is best-effort: keep the app usable without Langfuse, but
    # always bind the name so later `langfuse`-checks don't raise NameError,
    # and surface *why* the client failed to initialise.
    langfuse = None
    print(f"Langfuse Offline: {e}")
|
class GenerationStrategy(ABC):
    """Abstract interface for text-generation strategies.

    Concrete subclasses implement :meth:`generate`, producing a completion
    for *prompt* using the models/tokenizers held by *generator*.
    """

    @abstractmethod
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Return generated text for *prompt*; *model_kwargs* are forwarded to the model."""
        ...
|
|
|
|
class DefaultStrategy(GenerationStrategy):
    """Plain single-pass generation with the 'llama' model."""

    @observe()
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Tokenize *prompt*, run one `generate` call and decode the first sequence.

        Previously `tokenizer`/`model` were assigned but never used while every
        line re-did the dict lookup; the locals are now actually used.
        """
        tokenizer = generator.tokenizers["llama"]
        model = generator.models["llama"]

        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        output = model.generate(input_ids, **model_kwargs)
        return tokenizer.decode(output[0], skip_special_tokens=True)
|
|
|
|
class MajorityVotingStrategy(GenerationStrategy):
    """Sample *num_samples* completions and return the most frequent one."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Majority vote over repeated samples.

        NOTE(review): voting only makes sense when *model_kwargs* enables
        sampling (e.g. ``do_sample=True``); with greedy decoding all samples
        are identical — confirm with callers.
        """
        tokenizer = generator.tokenizers["llama"]
        model = generator.models["llama"]
        # The prompt is loop-invariant: tokenize it once, not once per sample.
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)

        outputs = []
        for _ in range(num_samples):
            output = model.generate(input_ids, **model_kwargs)
            outputs.append(tokenizer.decode(output[0], skip_special_tokens=True))
        return max(set(outputs), key=outputs.count)
|
|
|
|
class BestOfN(GenerationStrategy):
    """Sample *num_samples* completions, score each with the PRM, return the best."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Best-of-N selection.

        Each sample is scored by `generator.prm_model` (mean of the logits is
        used as a scalar reward — presumably a process-reward model; confirm).
        """
        tokenizer = generator.tokenizers["llama"]
        model = generator.models["llama"]
        # Hoist the loop-invariant prompt tokenization out of the sampling loop.
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)

        scored_outputs = []
        for _ in range(num_samples):
            output = model.generate(input_ids, **model_kwargs)
            response = tokenizer.decode(output[0], skip_special_tokens=True)
            score = generator.prm_model(**tokenizer(response, return_tensors="pt").to(generator.device)).logits.mean().item()
            scored_outputs.append((response, score))
        return max(scored_outputs, key=lambda x: x[1])[0]
|
|
|
|
class BeamSearch(GenerationStrategy):
    """Beam search decoding returning all beams."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> List[str]:
        """Run beam search with *num_samples* beams and return every decoded beam.

        Fixed the return annotation: this method returns a list of strings,
        not a single string as previously annotated. NOTE(review): callers
        expecting `str` (per the base-class contract) should take element 0.
        """
        tokenizer = generator.tokenizers["llama"]
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        outputs = generator.models["llama"].generate(
            input_ids,
            num_beams=num_samples,
            num_return_sequences=num_samples,
            **model_kwargs
        )
        return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
|
|
|
|
class DVT(GenerationStrategy):
    """Diverse-verifier tree search: expand the best candidates over several levels."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, breadth: int = 3, depth: int = 2, **kwargs) -> str:
        """Breadth-first expansion scored by the PRM.

        BUG FIX: `breadth` and `depth` were undefined free names, so every call
        raised NameError. They are now keyword parameters with defaults
        (backward-compatible — existing positional callers are unaffected).
        """
        tokenizer = generator.tokenizers["llama"]
        model = generator.models["llama"]

        def _score(text: str) -> float:
            # Scalar reward: mean of the PRM logits for the candidate text.
            enc = tokenizer(text, return_tensors="pt").to(generator.device)
            return generator.prm_model(**enc).logits.mean().item()

        # Level 0: sample `breadth` initial candidates from the prompt.
        results = []
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        for _ in range(breadth):
            output = model.generate(input_ids, **model_kwargs)
            response = tokenizer.decode(output[0], skip_special_tokens=True)
            results.append((response, _score(response)))

        # Levels 1..depth-1: extend the current best `breadth` candidates.
        for _ in range(depth - 1):
            best_responses = sorted(results, key=lambda x: x[1], reverse=True)[:breadth]
            for response, _unused in best_responses:
                ids = tokenizer(response, return_tensors="pt").input_ids.to(generator.device)
                output = model.generate(ids, **model_kwargs)
                extended_response = tokenizer.decode(output[0], skip_special_tokens=True)
                results.append((extended_response, _score(extended_response)))
        return max(results, key=lambda x: x[1])[0]
| |
|
|
class COT(GenerationStrategy):
    """Chain-of-thought strategy — placeholder, not yet implemented."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Stub: returns a fixed marker string until CoT prompting is implemented."""
        # TODO: implement chain-of-thought prompting + answer extraction.
        return "Not implemented yet"
|
|
|
|
class ReAct(GenerationStrategy):
    """ReAct (reason + act) strategy — placeholder, not yet implemented."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Stub: returns a fixed marker string until the ReAct loop is implemented."""
        # TODO: implement the thought/action/observation loop.
        return "Not implemented yet"
| |
|
|