import gc
import sys
from typing import Dict

import torch


def generate_stream_exllama(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
):
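    # Note: device, context_len, stream_interval, and judge_sent_end are not
    # used below; presumably they exist so every backend's generate_stream_*
    # function shares one signature (an assumption, not stated in the source).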
    try:
        from exllamav2.generator import (
            ExLlamaV2Sampler,
            ExLlamaV2StreamingGenerator,
        )
    except ImportError as e:
        print(f"Error: Failed to load Exllamav2. {e}")
        sys.exit(-1)

    prompt = params["prompt"]

    generator = ExLlamaV2StreamingGenerator(model.model, model.cache, tokenizer)
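
    # Sampler settings: read each knob from the request params, falling back
    # to a default when a field is absent.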
    settings = ExLlamaV2Sampler.Settings()
    settings.temperature = float(params.get("temperature", 0.85))
    settings.top_k = int(params.get("top_k", 50))
    settings.top_p = float(params.get("top_p", 0.8))
    settings.token_repetition_penalty = float(params.get("repetition_penalty", 1.15))
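    # Suppress the raw EOS token; termination comes from the stop conditions
    # set below or from the max_new_tokens budget.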
    settings.disallow_tokens(generator.tokenizer, [generator.tokenizer.eos_token_id])

    max_new_tokens = int(params.get("max_new_tokens", 256))
    generator.set_stop_conditions(params.get("stop_token_ids", None) or [])
    echo = bool(params.get("echo", True))
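
    # Tokenize the prompt, record its length for usage accounting, and start
    # the streaming pass.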
    input_ids = generator.tokenizer.encode(prompt)
    prompt_tokens = input_ids.shape[-1]
    generator.begin_stream(input_ids, settings)
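
    # Decode token by token, accumulating the text. Intermediate chunks are
    # yielded with finish_reason=None; the final chunk reports why we stopped.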
    generated_tokens = 0
    if echo:
        output = prompt
    else:
        output = ""
    while True:
        chunk, eos, _ = generator.stream()
        output += chunk
        generated_tokens += 1
        if generated_tokens == max_new_tokens:
            finish_reason = "length"
            break
        elif eos:
            # A stop condition fired, so this is a natural stop, not a
            # length cutoff.
            finish_reason = "stop"
            break
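        # Stream the cumulative text so far.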
        yield {
            "text": output,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": generated_tokens,
                "total_tokens": prompt_tokens + generated_tokens,
            },
            "finish_reason": None,
        }
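
    # One final chunk with the terminal finish_reason, then release memory.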
    yield {
        "text": output,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": generated_tokens,
            "total_tokens": prompt_tokens + generated_tokens,
        },
        "finish_reason": finish_reason,
    }
    gc.collect()
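

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It shows
# one way the generator above might be consumed. `ExllamaModel` is a
# hypothetical wrapper: generate_stream_exllama only needs an object exposing
# `.model` and `.cache`, so any object with those two attributes would do.
# ---------------------------------------------------------------------------

from dataclasses import dataclass

from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Tokenizer


@dataclass
class ExllamaModel:
    model: ExLlamaV2
    cache: ExLlamaV2Cache


def stream_demo(model: ExllamaModel, tokenizer: ExLlamaV2Tokenizer) -> None:
    params = {
        "prompt": "Explain token streaming in one sentence.",
        "temperature": 0.7,
        "max_new_tokens": 64,
        "echo": False,
    }
    last = None
    for chunk in generate_stream_exllama(
        model, tokenizer, params, device="cuda", context_len=4096
    ):
        last = chunk  # chunk["text"] is the full output accumulated so far
    if last is not None:
        print(last["text"], "| finish_reason:", last["finish_reason"])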