import gradio as gr
import os
import time
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import torch
from threading import Thread
import logging
import spaces
from functools import lru_cache

print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Read the Hugging Face token from the environment
HF_TOKEN = os.environ.get("HF_TOKEN", None)

DESCRIPTION = '''
<div>
<h1 style="text-align: center;">ContenteaseAI custom trained model</h1>
</div>
'''

LICENSE = """
<p/>
---
For more information, visit our [website](https://contentease.ai).
"""

PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
    <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">ContenteaseAI Custom AI trained model</h1>
    <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Enter the text extracted from the PDF:</p>
</div>
"""

css = """
h1 {
    text-align: center;
    display: block;
}
"""

# Load the tokenizer and model with quantization
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
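
# Note: bitsandbytes 4-bit quantization generally requires a CUDA-capable GPU;
# on CPU-only Space hardware the model load below is expected to fail.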

def load_model_and_tokenizer():
    try:
        start_time = time.time()
        logger.info("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        logger.info("Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            quantization_config=bnb_config,
            torch_dtype=torch.bfloat16
        )
        # Llama tokenizers ship without a pad token by default; fall back to EOS.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        model.generation_config.pad_token_id = tokenizer.pad_token_id
        end_time = time.time()
        logger.info(f"Model and tokenizer loaded successfully in {end_time - start_time:.2f} seconds.")
        return model, tokenizer
    except Exception as e:
        logger.error(f"Error loading model or tokenizer: {e}")
        raise

try:
    model, tokenizer = load_model_and_tokenizer()
except Exception as e:
    logger.error(f"Failed to load model and tokenizer: {e}")
    raise

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
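# Generation stops at either the tokenizer's end-of-sequence token or Llama 3's
# end-of-turn marker <|eot_id|>, whichever the model emits first.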

SYS_PROMPT = """
Given the text of a hotel property improvement plan, extract the items to be replaced for only the Guest Rooms/Suites and Guest Bathrooms/Suite Bathrooms.
First, find the section of the PDF that describes improvements to be made to the Guest Rooms and Guest Bathrooms, then find the items to be replaced.
Ignore items from other sections of the hotel.
Items to be replaced are usually preceded by the words replace, install, or provide.
Return the results as JSON with "Guest Room" and "Guest Bathroom" as keys, where each value is the list of unique items to be replaced.
Return only the JSON with no extra text.

Example Text:
"
Site & Building Exterior
Replace all exterior decorative lighting
...
Guestrooms
Replace [ORG] C-Table.
Provide full length mirror.
Replace cabinets - Kitchen.
at doors where brass hardware finishes exist – replace with stainless
...
Guest Bathrooms - (FRCM) Replace mirrors. Install a vanity mirror that has integrated lighting
Guest Bathrooms - (FRCM) Replace artwork and decorative accessories.
...
Suites - Replace microwave, refrigerator, and associated casegood cabinet.
"

Example Response:
{
    "Guest Room": [
        "C-Table",
        "full length mirror",
        "kitchen cabinets",
        "stainless steel door hardware",
        "microwave",
        "refrigerator",
        "casegood cabinet"
    ],
    "Guest Bathroom": [
        "vanity mirror with integrated lighting",
        "artwork",
        "decorative accessories"
    ]
}
"""

def chunk_text(text, chunk_size=5000):
    """
    Splits the input text into chunks of the specified size.

    Args:
        text (str): The input text to be chunked.
        chunk_size (int): The size of each chunk in words.

    Returns:
        list: A list of text chunks.
    """
    words = text.split()
    chunks = [' '.join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]
    logger.info(f"Total chunks created: {len(chunks)}")
    return chunks
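
# Rough sketch of the chunking arithmetic: a 12,000-word document produces
# three chunks of 5,000, 5,000, and 2,000 words. Note that chunk_size counts
# whitespace-separated words, not model tokens.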

def combine_responses(responses):
    """
    Combines the responses from all chunks into a final output string.

    Args:
        responses (list): A list of responses from each chunk.

    Returns:
        str: The combined output string.
    """
    combined_output = " ".join(responses)
    return combined_output
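
# Because each chunk is prompted to return its own JSON object, the combined
# output above is a space-separated sequence of JSON blobs rather than one
# merged object. A possible post-processing step (a sketch only; not wired into
# the app, and the key names assume the SYS_PROMPT schema) could parse each
# response and union the item lists:
import json

def merge_json_responses(responses):
    merged = {"Guest Room": [], "Guest Bathroom": []}
    for raw in responses:
        try:
            data = json.loads(raw)
        except json.JSONDecodeError:
            continue  # skip chunks where the model returned malformed JSON
        for key in merged:
            for item in data.get(key, []):
                if item not in merged[key]:
                    merged[key].append(item)
    return json.dumps(merged, indent=2)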

def generate_response_for_chunk(chunk, history, temperature, max_new_tokens):
    start_time = time.time()

    # Drop the most recent exchange from the history before building the prompt.
    if history:
        history.pop()

    conversation = [{"role": "system", "content": SYS_PROMPT}]
    for user, assistant in history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": chunk})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id
    )
    # Use greedy decoding when the temperature slider is set to 0.
    if temperature == 0:
        generate_kwargs['do_sample'] = False

    # Run generation in a background thread; the streamer yields decoded text
    # on this thread as tokens become available.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)

    end_time = time.time()
    logger.info(f"Time taken for generating response for a chunk: {end_time - start_time:.2f} seconds")

    return "".join(outputs)

def chat_llama3_8b(message: str, history: list, temperature: float, max_new_tokens: int):
    """
    Generate a response using the llama3-8b model, processing the input message in chunks.

    Args:
        message (str): The input message.
        history (list): The conversation history used by ChatInterface.
        temperature (float): The temperature for generating the response.
        max_new_tokens (int): The maximum number of new tokens to generate.

    Yields:
        str: The combined response from all chunks.
    """
    try:
        start_time = time.time()
        chunks = chunk_text(message)
        responses = []
        for count, chunk in enumerate(chunks, start=1):
            logger.info(f"Processing chunk {count}/{len(chunks)}")
            response = generate_response_for_chunk(chunk, history, temperature, max_new_tokens)
            responses.append(response)
        final_output = combine_responses(responses)
        end_time = time.time()
        logger.info(f"Total time taken for generating response: {end_time - start_time:.2f} seconds")
        yield final_output
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        yield "An error occurred while generating the response. Please try again."

# Gradio block
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')

with gr.Blocks(fill_height=True, css=css) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.95, label="Temperature", render=False),
            gr.Slider(minimum=128, maximum=2000, step=1, value=700, label="Max new tokens", render=False),
        ]
    )
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    try:
        demo.launch(show_error=True)
    except Exception as e:
        logger.error(f"Error launching Gradio demo: {e}")