Spaces:
Build error
Build error
| # strategy.py | |
| #TODO UPDATE Paths | |
from abc import ABC, abstractmethod
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple
class GenerationStrategy(ABC):
    """Base class for generation strategies.

    A strategy decides *how* text is produced for a prompt (single pass,
    repeated sampling with voting, beam search, ...). Subclasses must
    implement ``generate``; the base class itself cannot be instantiated.
    """

    @abstractmethod
    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Produce text for ``prompt`` using ``generator``.

        Args:
            generator: Object exposing the model, tokenizer and device to use
                (``DefaultStrategy`` reads ``generator.tokenizer``,
                ``generator.model`` and ``generator.device``).
            prompt: Input text to condition on.
            model_kwargs: Keyword arguments forwarded to the model's
                ``generate`` call.
            **kwargs: Strategy-specific options.

        Returns:
            The generated text.
        """
        raise NotImplementedError
class DefaultStrategy(GenerationStrategy):
    """Plain single-call generation with the generator's model."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], **kwargs) -> str:
        """Encode ``prompt``, call ``model.generate`` once and decode the first sequence.

        Args:
            generator: Supplies ``tokenizer``, ``model`` and ``device``.
            prompt: Input text.
            model_kwargs: Forwarded unchanged to ``generator.model.generate``.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The first output sequence decoded with special tokens skipped.
            NOTE(review): for decoder-only models this presumably includes the
            prompt text as well — confirm against the model in use.
        """
        encoded = generator.tokenizer(prompt, return_tensors="pt")
        prompt_ids = encoded.input_ids.to(generator.device)
        sequences = generator.model.generate(prompt_ids, **model_kwargs)
        first_sequence = sequences[0]
        return generator.tokenizer.decode(first_sequence, skip_special_tokens=True)
class MajorityVotingStrategy(GenerationStrategy):
    """Generate several completions and return the most frequent one."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Sample ``num_samples`` completions and return the modal decoded text.

        Votes compare full decoded strings exactly. NOTE(review): voting is
        presumably only meaningful when ``model_kwargs`` enables sampling
        (e.g. ``do_sample=True``); with deterministic decoding every sample
        would be identical — confirm with callers.

        Args:
            generator: Supplies ``tokenizer``, ``model`` and ``device``.
            prompt: Input text.
            model_kwargs: Forwarded to every ``generator.model.generate`` call.
            num_samples: Number of completions to draw (must be >= 1).
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The most frequent decoded completion. Ties go to the completion
            that was generated first.

        Raises:
            ValueError: If ``num_samples`` is less than 1.
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")
        # The prompt is identical for every sample, so encode it only once.
        input_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        outputs = []
        for _ in range(num_samples):
            output = generator.model.generate(input_ids, **model_kwargs)
            outputs.append(generator.tokenizer.decode(output[0], skip_special_tokens=True))
        # Counter counts in O(n). most_common orders equal counts by first
        # occurrence, so the result is deterministic. The previous
        # max(set(...)) depended on hash-randomized set order.
        return Counter(outputs).most_common(1)[0][0]
class BestOfN(GenerationStrategy):
    """Best-of-N: draw several completions and keep the highest-scoring one."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Generate ``num_samples`` completions and return the best-scored one.

        Args:
            generator: Supplies ``tokenizer``, ``model``, ``device`` and (assumed)
                ``prm_model``, the scoring model.
            prompt: Input text.
            model_kwargs: Forwarded to every ``generator.model.generate`` call.
            num_samples: Number of completions to draw (must be >= 1).
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The completion with the highest score. Among equal top scores, the
            earliest sample wins.

        Raises:
            ValueError: If ``num_samples`` is less than 1.
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")
        tokenizer = generator.tokenizer
        # The prompt is identical for every sample, so encode it only once.
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        scored_outputs: List[Tuple[str, float]] = []
        for _ in range(num_samples):
            output = generator.model.generate(input_ids, **model_kwargs)
            response = tokenizer.decode(output[0], skip_special_tokens=True)
            scored_outputs.append((response, self._score(generator, response)))
        return max(scored_outputs, key=lambda pair: pair[1])[0]

    @staticmethod
    def _score(generator: 'BaseGenerator', text: str) -> float:
        """Return the mean of the scoring model's logits for ``text``."""
        # NOTE(review): assumes the generator exposes ``prm_model``. The
        # previous code read it from ``self``, which a strategy does not have.
        # Also assumes the scorer accepts the generation tokenizer's encoding.
        # Confirm both.
        encoded = generator.tokenizer(text, return_tensors="pt").to(generator.device)
        return generator.prm_model(**encoded).logits.mean().item()
class BeamSearch(GenerationStrategy):
    """Beam-search decoding that returns every returned beam."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> List[str]:
        """Run beam search with ``num_samples`` beams and decode all returned sequences.

        Args:
            generator: Supplies ``tokenizer``, ``model`` and ``device``.
            prompt: Input text.
            model_kwargs: Forwarded to ``generator.model.generate``.
            num_samples: Used as both ``num_beams`` and ``num_return_sequences``.
                These override any same-named entries in ``model_kwargs``.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            One decoded string per returned sequence. Note this is a list,
            unlike the single ``str`` returned by the other strategies.
        """
        tokenizer = generator.tokenizer
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        # Merge rather than passing num_beams=... next to **model_kwargs. The
        # latter raises "got multiple values for keyword argument" when the
        # caller's kwargs already contain either key.
        beam_kwargs = {**model_kwargs, "num_beams": num_samples, "num_return_sequences": num_samples}
        outputs = generator.model.generate(input_ids, **beam_kwargs)
        return [tokenizer.decode(sequence, skip_special_tokens=True) for sequence in outputs]
class DVT(GenerationStrategy):
    """Score-guided tree search over repeated continuations.

    Level 1 samples ``breadth`` completions of the prompt. Each further level
    works as follows:

    * take the ``breadth`` highest-scoring texts found so far (from all levels);
    * continue each one with another ``model.generate`` call;
    * score the result.

    The best-scoring text overall is returned.
    """

    def generate(
        self,
        generator: 'BaseGenerator',
        prompt: str,
        model_kwargs: Dict[str, Any],
        num_samples: int = 5,
        breadth: Optional[int] = None,
        depth: int = 2,
        **kwargs,
    ) -> str:
        """Run the search and return the highest-scoring text.

        Args:
            generator: Supplies ``tokenizer``, ``model``, ``device`` and (assumed)
                ``prm_model``, the scoring model.
            prompt: Input text.
            model_kwargs: Forwarded to every ``generator.model.generate`` call.
            num_samples: Default for ``breadth`` when it is not given.
            breadth: Completions sampled at level 1 and the frontier size at
                each later level. Defaults to ``num_samples``.
            depth: Total number of levels (>= 1); level 1 is the initial
                sampling. The default is a placeholder — tune per use case.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            The text with the highest score. Ties go to the earliest one
            produced.

        Raises:
            ValueError: If ``breadth`` or ``depth`` is less than 1.
        """
        if breadth is None:
            breadth = num_samples
        if breadth < 1 or depth < 1:
            raise ValueError(f"breadth and depth must be >= 1, got breadth={breadth}, depth={depth}")

        # The prompt is identical for every first-level sample; encode it once.
        prompt_ids = generator.tokenizer(prompt, return_tensors="pt").input_ids.to(generator.device)
        results: List[Tuple[str, float]] = []
        for _ in range(breadth):
            response = self._continue(generator, prompt_ids, model_kwargs)
            results.append((response, self._score(generator, response)))

        for _ in range(depth - 1):
            # The frontier is drawn from *all* results so far, not only the
            # previous level. A strong early text can therefore be extended
            # again. With deterministic decoding that may yield duplicates.
            frontier = sorted(results, key=lambda pair: pair[1], reverse=True)[:breadth]
            for response, _ in frontier:
                response_ids = generator.tokenizer(response, return_tensors="pt").input_ids.to(generator.device)
                extended = self._continue(generator, response_ids, model_kwargs)
                results.append((extended, self._score(generator, extended)))

        return max(results, key=lambda pair: pair[1])[0]

    @staticmethod
    def _continue(generator: 'BaseGenerator', input_ids: Any, model_kwargs: Dict[str, Any]) -> str:
        """Run one ``model.generate`` call on ``input_ids`` and decode the first sequence."""
        output = generator.model.generate(input_ids, **model_kwargs)
        return generator.tokenizer.decode(output[0], skip_special_tokens=True)

    @staticmethod
    def _score(generator: 'BaseGenerator', text: str) -> float:
        """Return the mean of the scoring model's logits for ``text``."""
        # NOTE(review): assumes the generator exposes ``prm_model``. The
        # previous code read it from ``self``, which a strategy does not have.
        # Also assumes the scorer accepts the generation tokenizer's encoding.
        # Confirm both.
        encoded = generator.tokenizer(text, return_tensors="pt").to(generator.device)
        return generator.prm_model(**encoded).logits.mean().item()
class COT(GenerationStrategy):
    """Chain-of-thought strategy (placeholder — not implemented yet)."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Return a fixed placeholder string; no generation is performed.

        TODO: implement the chain-of-thought strategy.
        NOTE(review): callers receive this placeholder as if it were model
        output; consider raising NotImplementedError once callers can handle it.
        """
        placeholder = "Not implemented yet"
        return placeholder
class ReAct(GenerationStrategy):
    """ReAct (reason + act) strategy (placeholder — not implemented yet)."""

    def generate(self, generator: 'BaseGenerator', prompt: str, model_kwargs: Dict[str, Any], num_samples: int = 5, **kwargs) -> str:
        """Return a fixed placeholder string; no generation is performed.

        TODO: implement the ReAct framework.
        NOTE(review): callers receive this placeholder as if it were model
        output; consider raising NotImplementedError once callers can handle it.
        """
        placeholder = "Not implemented yet"
        return placeholder
| # Add other strategy implementations... | |