Spaces:
Sleeping
Sleeping
| import re | |
| import json | |
| from huggingface_hub import InferenceClient | |
| import gradio as gr | |
# Module-wide Hugging Face inference client; every generation request in this
# file goes through it, targeting the Mixtral 8x7B instruct model.
client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
# Build the strict, JSON-only instruction prompt that is sent to the model
def format_prompt(topic, description, difficulty):
    """Return the instruction prompt for one learning-content request.

    The prompt has three parts: a request sentence that embeds ``topic``,
    ``description`` and ``difficulty`` verbatim, the JSON skeleton the model
    is asked to fill in, and a numbered list of output rules (including an
    empty-placeholder example).
    """
    request = (
        f"You are an expert educator. Generate a structured, highly engaging, and educational JSON object on the topic '{topic}'. "
        f"Use the following description as context: '{description}'. "
        f"The content must be suitable for a '{difficulty}' difficulty level and strictly adhere to the following JSON structure:\n\n"
    )

    # JSON skeleton shown to the model; a blank line separates it from the rules.
    schema_lines = [
        '{',
        ' "title": "[A descriptive and concise title for the topic]",',
        ' "sections": [',
        ' {',
        ' "subheading": "[A clear and concise subheading summarizing the section]",',
        ' "content": "[A detailed, engaging explanation of the section content written in clear, accessible language.]"',
        ' }',
        ' ]',
        '}',
    ]
    schema = "\n".join(schema_lines) + "\n\n"

    # Every rule line (and each line of the embedded example) ends with a newline.
    rule_lines = [
        '### Strict Output Rules:',
        '1. The output **must be a valid JSON object** and nothing else.',
        '2. All keys and string values must be enclosed in double quotes ("").',
        '3. The `sections` field must be a non-empty list of objects, each containing `subheading` and `content`.',
        '4. Avoid extra characters, trailing commas, or malformed syntax.',
        '5. Close all brackets and braces properly.',
        '6. If there is insufficient information, return a JSON object with empty placeholders, e.g.,',
        '{',
        ' "title": "",',
        ' "sections": []',
        '}',
        '7. Validate the output to ensure it complies with the required JSON structure.',
    ]
    rules = "".join(f"{line}\n" for line in rule_lines)

    return request + schema + rules
# Helper: append the closers a truncated JSON text is missing, innermost first
def _close_open_containers(text):
    """Return ``text`` with any unclosed ``{`` / ``[`` closed in nesting order.

    Brackets that appear inside double-quoted strings are ignored. Closers
    that do not match the innermost open container are left alone.
    """
    pending = []  # closers still owed, outermost first
    in_string = False
    escaped = False
    for ch in text:
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == "{":
            pending.append("}")
        elif ch == "[":
            pending.append("]")
        elif ch in "}]" and pending and pending[-1] == ch:
            pending.pop()
    return text + "".join(reversed(pending))


# Helper: heuristic clean-up applied only when the text is not valid JSON as-is
def _repair_json_text(text, unescape_quotes):
    """Return a best-effort repaired copy of a malformed JSON candidate.

    ``unescape_quotes`` additionally turns every ``\\"`` into ``"``; that
    rescues output in which the model escaped all quotes, but it would break
    legitimately escaped quotes, so the caller tries it last.
    """
    text = re.sub(r'`|<s>|</s>|◀|▶', '', text)  # stray markup / special tokens
    text = re.sub(r'\s+', ' ', text).strip()  # also removes raw newlines inside strings
    if unescape_quotes:
        text = text.replace('\\"', '"')
    # A "sections" key with no value gets an empty list. The lookahead keeps
    # the following ',' or '}' so the surrounding object stays well formed.
    text = re.sub(r'"sections"\s*:\s*(?=[,}])', '"sections": []', text)
    text = re.sub(r',\s*(\}|\])', r'\1', text)  # trailing commas
    text = re.sub(r'(\})(\s*\{)', r'\1,\2', text)  # missing comma between objects
    return _close_open_containers(text)


# Helper: structural checks on the parsed learning content
def _validate_learning_content(data):
    """Raise ``ValueError`` unless ``data`` has the expected structure.

    A ``sections`` value that is not a list is replaced in place by ``[]``.
    """
    if not isinstance(data, dict) or "title" not in data or "sections" not in data:
        raise ValueError("Missing required keys: 'title' or 'sections'.")
    if not isinstance(data["sections"], list):
        data["sections"] = []
        return
    for section in data["sections"]:
        if not isinstance(section, dict) or "subheading" not in section or "content" not in section:
            raise ValueError("Each section must contain 'subheading' and 'content'.")


# Function to clean and format the AI output
def clean_and_format_learning_content(output):
    """Parse, repair if necessary, and validate the model's JSON output.

    The text between the first ``{`` and the last ``}`` is parsed as-is
    first, so output that is already valid JSON is returned without any
    lossy clean-up (escaped quotes, backticks and non-ASCII text survive).
    Only when that fails are the heuristic repairs tried: first without,
    then with, unescaping of ``\\"``.

    Args:
        output: Raw text produced by the model. ``None`` is treated as empty.

    Returns:
        The parsed object (a dict with ``title`` and ``sections``) on
        success. On failure, a dict with the keys ``error``, ``details``
        (the exception message) and ``output`` (the last text that was
        attempted), so callers never have to catch a parsing exception.
    """
    if output is None:
        output = ""
    elif not isinstance(output, str):
        output = str(output)

    # Keep only the outermost '{' ... '}' span; anything else is chatter.
    start = output.find("{")
    end = output.rfind("}")
    candidate = output[start:end + 1] if 0 <= start < end else ""

    candidates = [candidate]
    for unescape_quotes in (False, True):
        repaired = _repair_json_text(candidate, unescape_quotes)
        if repaired not in candidates:
            candidates.append(repaired)

    cleaned_output = candidate
    try:
        for cleaned_output in candidates:
            try:
                json_output = json.loads(cleaned_output)
            except json.JSONDecodeError as exc:
                last_error = exc
            else:
                break
        else:
            # No candidate parsed; report the error of the most repaired one.
            raise last_error

        _validate_learning_content(json_output)
        return json_output
    except (json.JSONDecodeError, ValueError) as e:
        # Provide detailed error information for debugging
        return {
            "error": "Failed to parse or validate output as JSON",
            "details": str(e),
            "output": cleaned_output,
        }
# Function to generate learning content
def generate_learning_content(topic, description, difficulty, temperature=0.9, max_new_tokens=2000, top_p=0.95, repetition_penalty=1.2):
    """Generate learning content for a topic and return the validated result.

    Builds the prompt with ``format_prompt``, streams the completion from the
    module-level ``client`` and hands the collected text to
    ``clean_and_format_learning_content``.

    Args:
        topic, description, difficulty: Forwarded to ``format_prompt``.
        temperature: Sampling temperature; floored at 0.01.
        max_new_tokens: Generation length limit; coerced to ``int`` because
            UI sliders may deliver the value as a float.
        top_p: Nucleus-sampling threshold; coerced to ``float``.
        repetition_penalty: Coerced to ``float``.

    Returns:
        Whatever ``clean_and_format_learning_content`` returns for the
        generated text.
    """
    generate_kwargs = dict(
        temperature=max(float(temperature), 1e-2),  # Ensure minimum temperature
        max_new_tokens=int(max_new_tokens),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        seed=42,  # fixed seed is sent with every request
    )

    # Format the prompt
    formatted_prompt = format_prompt(topic, description, difficulty)

    # Stream the output from the model and collect the token texts
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    raw_output = "".join(response.token.text for response in stream)

    # Clean and validate the raw output
    return clean_and_format_learning_content(raw_output)
# Gradio UI: request inputs, sampling controls, result box and trigger button
with gr.Blocks(theme="ocean") as demo:
    gr.HTML("<h1><center>Learning Content Generator</center></h1>")

    # What to generate: topic, free-text context and target difficulty
    topic_input = gr.Textbox(
        label="Topic",
        placeholder="Enter the topic for learning content.",
    )
    description_input = gr.Textbox(
        label="Description",
        placeholder="Enter a brief description of the topic.",
    )
    difficulty_input = gr.Dropdown(
        label="Difficulty Level",
        choices=["High", "Medium", "Low"],
        value="Medium",
        interactive=True,
    )

    # Sampling controls forwarded to generate_learning_content
    temperature_slider = gr.Slider(
        label="Temperature",
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        value=0.9,
    )
    # NOTE(review): 1048 is not reachable in steps of 64 starting at 128;
    # 1024 or 2048 may have been intended - confirm.
    tokens_slider = gr.Slider(
        label="Max new tokens",
        minimum=128,
        maximum=1048,
        step=64,
        value=512,
    )
    top_p_slider = gr.Slider(
        label="Top-p (nucleus sampling)",
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        value=0.95,
    )
    repetition_penalty_slider = gr.Slider(
        label="Repetition penalty",
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        value=1.2,
    )

    # Where the result is shown.
    # NOTE(review): the callback returns a dict; presumably Gradio stringifies
    # it for the Textbox - confirm, or consider a JSON component.
    output = gr.Textbox(label="Generated Learning Content", lines=15)

    submit_button = gr.Button("Generate Learning Content")

    # Clicking the button runs the generator; the input order matches the
    # positional parameters of generate_learning_content.
    submit_button.click(
        fn=generate_learning_content,
        inputs=[
            topic_input,
            description_input,
            difficulty_input,
            temperature_slider,
            tokens_slider,
            top_p_slider,
            repetition_penalty_slider,
        ],
        outputs=output,
    )

# Start the app only when this file is executed directly
if __name__ == "__main__":
    demo.launch()