File size: 5,449 Bytes
9a9ec03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from transformers import pipeline, AutoTokenizer
import torch
from fastapi import HTTPException
import asyncio 

class HuggingFaceTextGenerationService:
    def __init__(self, model_name_or_path: str, device: str = None, task: str = "text-generation"):
        self.model_name_or_path = model_name_or_path
        self.task = task
        self.pipeline = None
        self.tokenizer = None

        if device is None:
            self.device_index = 0 if torch.cuda.is_available() else -1
        elif device == "cuda" and torch.cuda.is_available():
            self.device_index = 0
        elif device == "cpu":
            self.device_index = -1
        else:
            self.device_index = -1
        
        if self.device_index == 0:
            print("CUDA (GPU) is available. Using GPU.")
        else:
            print(f"Device set to use {'cpu' if self.device_index == -1 else f'cuda:{self.device_index}'}")


    async def initialize(self):
        try:
            print(f"Initializing Hugging Face pipeline for model: {self.model_name_or_path} on device index: {self.device_index}")
            self.tokenizer = await asyncio.to_thread(
                AutoTokenizer.from_pretrained, self.model_name_or_path, trust_remote_code=True
            )
            self.pipeline = await asyncio.to_thread(
                pipeline,
                self.task,
                model=self.model_name_or_path,
                tokenizer=self.tokenizer,
                device=self.device_index,
                torch_dtype=torch.bfloat16 if self.device_index != -1 and torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32,
                trust_remote_code=True,
            )
            print(f"Pipeline for model {self.model_name_or_path} initialized successfully.")
        except Exception as e:
            print(f"Error initializing HuggingFace pipeline: {e}")
            raise HTTPException(status_code=503, detail=f"LLM (HuggingFace) model could not be loaded: {str(e)}")

    async def generate_text(self, prompt_text: str = None, chat_template_messages: list = None, max_new_tokens: int = 250, temperature: float = 0.7, top_p: float = 0.9, do_sample: bool = True, **kwargs) -> str:
        if not self.pipeline or not self.tokenizer:
            raise Exception("Pipeline is not initialized. Call initialize() first.")

        formatted_prompt_input = "" 
        if chat_template_messages:
            try:
                formatted_prompt_input = self.tokenizer.apply_chat_template(
                    chat_template_messages,
                    tokenize=False,
                    add_generation_prompt=True 
                )
            except Exception as e:
                print(f"Could not apply chat template, falling back to raw prompt if available. Error: {e}")
                if prompt_text:
                    formatted_prompt_input = prompt_text
                else:
                    raise ValueError("Cannot generate text without a valid prompt or chat_template_messages.")
        elif prompt_text:
            formatted_prompt_input = prompt_text
        else:
            raise ValueError("Either prompt_text or chat_template_messages must be provided.")

        try:
            generated_outputs = await asyncio.to_thread(
                self.pipeline,
                formatted_prompt_input, 
                max_new_tokens=max_new_tokens,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.eos_token_id if self.tokenizer.pad_token_id is None else self.tokenizer.pad_token_id, # Common setting
                **kwargs
            )
            
            if generated_outputs and isinstance(generated_outputs, list) and "generated_text" in generated_outputs[0]:
                full_generated_sequence = generated_outputs[0]["generated_text"]
                
                assistant_response = ""
                if full_generated_sequence.startswith(formatted_prompt_input):
                    assistant_response = full_generated_sequence[len(formatted_prompt_input):]
                else:
                    assistant_marker = "<|im_start|>assistant\n" 
                    last_marker_pos = full_generated_sequence.rfind(assistant_marker)
                    if last_marker_pos != -1:
                        assistant_response = full_generated_sequence[last_marker_pos + len(assistant_marker):]
                        print("Warning: Used fallback parsing for assistant response.")
                    else:
                        print("Error: Could not isolate assistant response from the full generated sequence.")
                        assistant_response = full_generated_sequence 

                if assistant_response.endswith("<|im_end|>"):
                    assistant_response = assistant_response[:-len("<|im_end|>")]
                
                return assistant_response.strip() 
            else:
                print(f"Unexpected output format from pipeline: {generated_outputs}")
                return "Error: Could not parse generated text from pipeline output."

        except Exception as e:
            print(f"Error during text generation with {self.model_name_or_path}: {e}")
            raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")