import gradio as gr
import os
import time
import requests

"""
This app uses the Hugging Face Inference API to generate responses from the
Trinoid/Data_Management_Mistral model.
"""
# Get the API token from the environment
HF_TOKEN = os.environ.get("HF_TOKEN")
print(f"HF_TOKEN is {'available' if HF_TOKEN else 'not available'}")

# Set up the Hugging Face Inference API endpoint for the model
MODEL_ID = "Trinoid/Data_Management_Mistral"
API_URL = f"https://api-inference.huggingface.co/models/{MODEL_ID}"
headers = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
# Check that the model exists and is accessible
try:
    print(f"Checking if model {MODEL_ID} exists...")
    response = requests.get(API_URL, headers=headers)
    print(f"Status: {response.status_code}")
    if response.status_code == 200:
        print("Model exists and is accessible")
        print(f"Response: {response.text[:200]}...")
    else:
        print(f"Response: {response.text}")
except Exception as e:
    print(f"Error checking model: {str(e)}")
# Global state for tracking model status
model_loaded = False
estimated_time = None
use_simple_format = True  # Toggle between a plain Q/A prompt and Mistral's chat format
def format_prompt(messages):
    """Format chat messages into a text prompt that Mistral models can understand."""
    if use_simple_format:
        # Simple format: keep only the system message and the latest user message
        system = next((m["content"] for m in messages if m["role"] == "system"), "")
        last_user_msg = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
        if system:
            return f"{system}\n\nQuestion: {last_user_msg}\n\nAnswer:"
        else:
            return f"Question: {last_user_msg}\n\nAnswer:"
    else:
        # Mistral instruction format: [INST] ... [/INST] blocks wrapped in <s>...</s>
        formatted = ""
        for msg in messages:
            if msg["role"] == "system":
                formatted += f"<s>[INST] {msg['content']} [/INST]</s>\n"
            elif msg["role"] == "user":
                formatted += f"<s>[INST] {msg['content']} [/INST]"
            elif msg["role"] == "assistant":
                formatted += f" {msg['content']} </s>\n"
        return formatted
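
# Example (illustration only, not used by the app): with use_simple_format=True,
#   format_prompt([
#       {"role": "system", "content": "You are a data management expert."},
#       {"role": "user", "content": "How do I set a retention policy?"},
#   ])
# returns:
#   "You are a data management expert.\n\nQuestion: How do I set a retention policy?\n\nAnswer:"
# With use_simple_format=False, the same input becomes:
#   "<s>[INST] You are a data management expert. [/INST]</s>\n<s>[INST] How do I set a retention policy? [/INST]"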
def query_model_text_generation(prompt, parameters=None):
    """Query the model through the Inference API text-generation endpoint."""
    global estimated_time

    payload = {"inputs": prompt}
    if parameters:
        payload["parameters"] = parameters

    print("Sending text generation query to API...")
    print(f"Prompt: {prompt[:100]}...")
    try:
        # Use a long timeout so the request can survive a cold start
        response = requests.post(
            API_URL,
            headers=headers,
            json=payload,
            timeout=180,  # 3-minute timeout
        )
        print(f"API response status: {response.status_code}")

        # Success: return the parsed JSON body
        if response.status_code == 200:
            print(f"Success! Response: {response.text[:200]}...")
            return response.json()

        # A 503 with an estimate means the model is still loading
        if response.status_code == 503 and "estimated_time" in response.json():
            estimated_time = response.json()["estimated_time"]
            print(f"Model is loading. Estimated time: {estimated_time:.2f} seconds")
            return None

        # Any other error
        print(f"API error: {response.text}")
        return None
    except Exception as e:
        print(f"Request exception: {str(e)}")
        return None
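
# For reference (based on the Inference API's usual behavior; exact shape may vary):
# a successful text-generation call typically returns a list with one dict, e.g.
#   [{"generated_text": "..."}]
# while a cold model returns HTTP 503 with a body like
#   {"error": "Model Trinoid/Data_Management_Mistral is currently loading", "estimated_time": 120.0}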
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Respond to user messages, yielding status updates until an answer arrives."""
    global estimated_time

    # Build the message list from the system prompt and the chat history
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Format the prompt
    prompt = format_prompt(messages)

    # Generation parameters
    parameters = {
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
        "return_full_text": False,  # Return only the generated text, not the prompt
    }
    # Initial status message while the model responds or loads
    if estimated_time:
        initial_msg = f"⌛ The model is loading... estimated time: {estimated_time:.0f} seconds. Please be patient."
    else:
        initial_msg = "⌛ Working on your request..."
    yield initial_msg

    # Retry several times, waiting longer between attempts (10s, 20s, ..., capped at 60s)
    max_retries = 6
    for attempt in range(max_retries):
        if attempt > 0:
            wait_time = min(60, 10 * attempt)
            yield f"⌛ Still working on your request... (attempt {attempt + 1}/{max_retries})"
            time.sleep(wait_time)
        try:
            # Query the model via text generation
            result = query_model_text_generation(prompt, parameters)
            if result:
                # The API may return a list of dicts, a single dict, or a bare string
                if isinstance(result, list) and len(result) > 0:
                    if isinstance(result[0], dict) and "generated_text" in result[0]:
                        yield result[0]["generated_text"]
                        return
                if isinstance(result, dict) and "generated_text" in result:
                    yield result["generated_text"]
                    return
                # String or other format
                yield str(result)
                return

            # If the model is still loading, refresh the time estimate
            if estimated_time and attempt < max_retries - 1:
                try:
                    response = requests.get(API_URL, headers=headers)
                    if response.status_code == 503 and "estimated_time" in response.json():
                        estimated_time = response.json()["estimated_time"]
                        print(f"Updated loading time: {estimated_time:.0f} seconds")
                except Exception:
                    pass
        except Exception as e:
            print(f"Error in attempt {attempt + 1}: {str(e)}")
            if attempt == max_retries - 1:
                yield f"""❌ Sorry, I couldn't generate a response after multiple attempts.

Error details: {str(e)}

Please try again later or contact support if this persists."""
                return
    # All retries failed
    yield """❌ The model couldn't be accessed after multiple attempts.

This could be due to:
1. Heavy server load
2. The model being too large for the current hardware
3. Temporary service issues

Please try again later. For best results with large models like Mistral-7B, consider:
- Using a smaller model
- Creating a 4-bit quantized version
- Using Hugging Face Inference Endpoints instead of Spaces"""
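
# Usage sketch (hypothetical; ChatInterface normally drives this generator itself):
# each yielded string replaces the previous one in the chat window, so the status
# messages appear in place until the final answer or error is shown.
#   for partial in respond("What is a retention label?", [], "You are helpful.", 256, 0.7, 0.95):
#       print(partial)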
| """ | |
| For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
| """ | |
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a data management expert specializing in Microsoft 365 services.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    description="""This interface uses a fine-tuned Mistral model for Microsoft 365 data management.
First requests may take 2-3 minutes as the model loads.""",
)
| if __name__ == "__main__": | |
| demo.launch() |