from transformers import pipeline, AutoTokenizer
import torch
from fastapi import HTTPException
import asyncio


class HuggingFaceTextGenerationService:
    def __init__(self, model_name_or_path: str, device: str = None, task: str = "text-generation"):
        self.model_name_or_path = model_name_or_path
        self.task = task
        self.pipeline = None
        self.tokenizer = None
        # Resolve the pipeline device index: 0 = first CUDA GPU, -1 = CPU.
        if device is None:
            self.device_index = 0 if torch.cuda.is_available() else -1
        elif device == "cuda" and torch.cuda.is_available():
            self.device_index = 0
        elif device == "cpu":
            self.device_index = -1
        else:
            # Unknown device string, or "cuda" requested without CUDA available.
            self.device_index = -1
        if self.device_index == 0:
            print("CUDA (GPU) is available. Using GPU.")
        else:
            print(f"Device set to use {'cpu' if self.device_index == -1 else f'cuda:{self.device_index}'}")

    async def initialize(self):
        try:
            print(f"Initializing Hugging Face pipeline for model: {self.model_name_or_path} on device index: {self.device_index}")
            # Load the tokenizer and pipeline in worker threads so the event loop stays responsive.
            self.tokenizer = await asyncio.to_thread(
                AutoTokenizer.from_pretrained, self.model_name_or_path, trust_remote_code=True
            )
            self.pipeline = await asyncio.to_thread(
                pipeline,
                self.task,
                model=self.model_name_or_path,
                tokenizer=self.tokenizer,
                device=self.device_index,
                # Prefer bfloat16 on GPUs that support it; fall back to float32 otherwise.
                torch_dtype=torch.bfloat16 if self.device_index != -1 and torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32,
                trust_remote_code=True,
            )
            print(f"Pipeline for model {self.model_name_or_path} initialized successfully.")
        except Exception as e:
            print(f"Error initializing Hugging Face pipeline: {e}")
            raise HTTPException(status_code=503, detail=f"LLM (HuggingFace) model could not be loaded: {str(e)}")

    async def generate_text(self, prompt_text: str = None, chat_template_messages: list = None, max_new_tokens: int = 250, temperature: float = 0.7, top_p: float = 0.9, do_sample: bool = True, **kwargs) -> str:
        if not self.pipeline or not self.tokenizer:
            raise RuntimeError("Pipeline is not initialized. Call initialize() first.")
        # Build the final prompt string, preferring the model's chat template when messages are given.
        formatted_prompt_input = ""
        if chat_template_messages:
            try:
                formatted_prompt_input = self.tokenizer.apply_chat_template(
                    chat_template_messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            except Exception as e:
                print(f"Could not apply chat template; falling back to raw prompt if available. Error: {e}")
                if prompt_text:
                    formatted_prompt_input = prompt_text
                else:
                    raise ValueError("Cannot generate text without a valid prompt or chat_template_messages.")
        elif prompt_text:
            formatted_prompt_input = prompt_text
        else:
            raise ValueError("Either prompt_text or chat_template_messages must be provided.")
        try:
            # Run the blocking pipeline call in a worker thread.
            generated_outputs = await asyncio.to_thread(
                self.pipeline,
                formatted_prompt_input,
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                eos_token_id=self.tokenizer.eos_token_id,
                # Many models define no pad token; reuse the EOS token in that case.
                pad_token_id=self.tokenizer.eos_token_id if self.tokenizer.pad_token_id is None else self.tokenizer.pad_token_id,
                **kwargs
            )
            if generated_outputs and isinstance(generated_outputs, list) and "generated_text" in generated_outputs[0]:
                full_generated_sequence = generated_outputs[0]["generated_text"]
                assistant_response = ""
                # The pipeline returns prompt + completion; strip the prompt prefix to keep only the completion.
                if full_generated_sequence.startswith(formatted_prompt_input):
                    assistant_response = full_generated_sequence[len(formatted_prompt_input):]
                else:
                    # Fallback for ChatML-style templates: cut at the last assistant marker.
                    assistant_marker = "<|im_start|>assistant\n"
                    last_marker_pos = full_generated_sequence.rfind(assistant_marker)
                    if last_marker_pos != -1:
                        assistant_response = full_generated_sequence[last_marker_pos + len(assistant_marker):]
                        print("Warning: Used fallback parsing for assistant response.")
                    else:
                        print("Error: Could not isolate assistant response from the full generated sequence.")
                        assistant_response = full_generated_sequence
                # Trim a trailing ChatML end-of-turn marker if the model emitted one.
                if assistant_response.endswith("<|im_end|>"):
                    assistant_response = assistant_response[:-len("<|im_end|>")]
                return assistant_response.strip()
            else:
                print(f"Unexpected output format from pipeline: {generated_outputs}")
                return "Error: Could not parse generated text from pipeline output."
        except Exception as e:
            print(f"Error during text generation with {self.model_name_or_path}: {e}")
            raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
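

# --- Usage sketch (illustrative; not part of the original service) ---
# A minimal smoke test showing the expected call sequence: construct the
# service, await initialize(), then await generate_text() with a list of
# chat messages in the standard transformers role/content format. The
# model name below is a placeholder assumption (any ChatML-style instruct
# model should work, given the <|im_start|>/<|im_end|> parsing above).
if __name__ == "__main__":
    async def _demo():
        service = HuggingFaceTextGenerationService("Qwen/Qwen2.5-0.5B-Instruct")
        await service.initialize()
        reply = await service.generate_text(
            chat_template_messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Summarize what a tokenizer does in one sentence."},
            ],
            max_new_tokens=64,
        )
        print(reply)

    asyncio.run(_demo())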