"""Generate Q&A transcripts (with voice-over text) for a Manim scene via Ollama.

Run with command-line arguments to use the CLI; run with none to launch the
Gradio UI.
"""
import argparse
import json
import logging
import re
import sys
from typing import List

import gradio as gr

# Wildcard import kept in its original position (after the stdlib/gradio
# imports, before pydantic/ollama) so name resolution is unchanged.
# NOTE(review): nothing in this file visibly uses scene_gen — confirm whether
# it is still needed (it may be imported for side effects).
from scene_gen import *  # noqa: F401,F403

from pydantic import BaseModel, RootModel, ValidationError
from ollama import chat


# -----------------------------
# Models
# -----------------------------
class QAItem(BaseModel):
    """One generated question/answer entry."""

    Question: str
    Answer: str
    # Presumably the narration text for the scene — TODO confirm with scene_gen.
    Voice_Over: str
    # Model-chosen flag; the prompt describes it only as "true/false".
    include_audio: bool


class QAList(RootModel[List[QAItem]]):
    """RootModel wrapping a list of QAItem (a top-level JSON array)."""

    root: List[QAItem]


# -----------------------------
# Configuration
# -----------------------------
MODEL = "qwen3:0.6b"

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")

# Characters that are path separators or invalid in file names on common
# filesystems, plus ASCII control characters.
_UNSAFE_FILENAME_CHARS = re.compile(r'[<>:"/\\|?*\x00-\x1f]')


# -----------------------------
# Core logic
# -----------------------------
def generate_qa(topic: str, count: int = 10) -> List[QAItem]:
    """Ask the Ollama model for ``count`` Q&A items about ``topic``.

    The model's output is constrained to ``QAList``'s JSON schema through the
    ``format`` argument of ``ollama.chat`` and then validated with pydantic.

    Args:
        topic: Subject the questions should cover.
        count: Number of items requested. The model is not forced to honour
            it exactly; a mismatch is logged as a warning, not treated as an
            error.

    Returns:
        The validated list of QAItem instances.

    Raises:
        pydantic.ValidationError: if the response does not match the schema
            (logged before re-raising).
        Any exception raised by ``ollama.chat`` propagates unchanged.
    """
    schema = QAList.model_json_schema()
    prompt = (
        f'Given the topic "{topic}", generate {count} entries in JSON format, '
        "each with keys Question, Answer, Voice_Over, and include_audio (true/false)."
    )
    response = chat(
        model=MODEL,
        think=False,
        messages=[{"role": "user", "content": prompt}],
        format=schema,
        options={"temperature": 0},
    )
    try:
        qa_list = QAList.model_validate_json(response.message.content)
    except ValidationError as e:
        logging.error("Response validation failed:\n%s", e)
        raise

    items = qa_list.root
    if len(items) != count:
        logging.warning("Requested %d QA items but the model returned %d", count, len(items))
    return items


def _items_to_json_text(items: List[QAItem]) -> str:
    """Serialise QA items to pretty-printed, non-ASCII-preserving JSON text."""
    return json.dumps([item.model_dump() for item in items], indent=2, ensure_ascii=False)


def _topic_to_filename(topic: str, suffix: str = ".json") -> str:
    """Turn a free-form topic into a file name in the current directory.

    Path separators and characters invalid on common filesystems are replaced
    with "_", and leading/trailing spaces and dots are stripped. As a result a
    topic such as "a/b" or "../x" cannot point the write at another directory.
    Ordinary topics (including inner spaces) are kept as-is. Falls back to
    "qa" if nothing usable remains.
    """
    stem = _UNSAFE_FILENAME_CHARS.sub("_", topic).strip(" .")
    return f"{stem or 'qa'}{suffix}"


# -----------------------------
# CLI entrypoint
# -----------------------------
def _positive_int(value: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1."""
    number = int(value)  # argparse reports a ValueError as an invalid value
    if number < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value!r}")
    return number


def cli_main() -> int:
    """Command-line entry point: generate QA items, print them, save to disk.

    The JSON is printed to stdout and written to ``<topic>.json`` (topic
    sanitised by ``_topic_to_filename``) in the current directory.

    Returns:
        Process exit status: 0 on success, 1 if generation failed.
    """
    parser = argparse.ArgumentParser(description="Generate QA JSON via Ollama")
    parser.add_argument("topic", type=str, help="Topic to generate Q&A for")
    parser.add_argument(
        "-n", "--count", type=_positive_int, default=10,
        help="Number of QA items to generate (default: 10)"
    )
    args = parser.parse_args()

    logging.info("Generating %d QA items for topic: %s", args.count, args.topic)
    try:
        items = generate_qa(args.topic, args.count)
    except Exception:
        # Top-level boundary: keep the traceback in the log and signal the
        # failure through the exit status instead of exiting 0.
        logging.critical("Aborting due to errors", exc_info=True)
        return 1

    text = _items_to_json_text(items)

    # 1) Pretty-print to stdout
    print(text)

    # 2) Save to file
    filename = _topic_to_filename(args.topic)
    with open(filename, "w", encoding="utf-8") as f:
        f.write(text)
    logging.info("Saved output to %s", filename)
    return 0


# -----------------------------
# Gradio entrypoint
# -----------------------------
def gradio_generate(topic: str, count: int = 10) -> str:
    """Gradio callback: generate items, save to questions.json, return the JSON.

    Args:
        topic: Topic text from the UI textbox.
        count: Value from the UI slider; coerced to ``int`` before use.

    Returns:
        The generated items as a pretty-printed JSON string.

    Raises:
        gr.Error: if ``topic`` is blank, so the UI shows a message instead of
            sending an empty prompt to the model.
    """
    if not topic or not topic.strip():
        raise gr.Error("Please enter a topic.")
    # NOTE: the slider value may arrive as a float (e.g. 10.0); normalise it so
    # the prompt asks for "10 entries" rather than "10.0 entries".
    items = generate_qa(topic, int(count))
    text = _items_to_json_text(items)
    # Fixed output path, overwritten on every call.
    with open("questions.json", "w", encoding="utf-8") as f:
        f.write(text)
    return text


def app():
    """Build and launch the Gradio interface around ``gradio_generate``."""
    demo = gr.Interface(
        fn=gradio_generate,
        inputs=[
            gr.Textbox(label="Topic", placeholder="Enter your topic here"),
            gr.Slider(minimum=1, maximum=50, step=1, label="Number of Q&A items", value=10),
        ],
        outputs=gr.Textbox(label="Generated JSON"),
        title="Transcript Generator for Manim Scene",
        description="Generates JSON transcript for Manim Scene, with voiceover."
    )
    # share=True requests a public Gradio share link; mcp_server=True also
    # exposes the function as an MCP tool.
    demo.launch(share=True, mcp_server=True)


# -----------------------------
# Bootstrap
# -----------------------------
if __name__ == "__main__":
    # Any command-line argument selects the CLI; none launches the Gradio UI.
    if len(sys.argv) > 1:
        sys.exit(cli_main())
    else:
        app()