# ModelForge AI Backend — Hugging Face Space (Gradio app).
# Standard library
import json
import os
import random

# Third-party
import gradio as gr
from huggingface_hub import InferenceClient

# Hub credentials come from the environment (HUGGING_FACE_HUB_TOKEN);
# getenv returns None when unset, in which case the client runs unauthenticated.
token = os.getenv("HUGGING_FACE_HUB_TOKEN")
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=token)
# Instructions sent to the LLM on every request: defines the assistant persona
# and the exact JSON schema the model must return.
SYSTEM_PROMPT = """You are ModelForge, an expert AI architecture assistant. Your goal is to analyze machine learning problems and generate detailed, deployable solutions in strict JSON format.
You must analyze the user's request and return a JSON object with the following structure:
{
"analysis": {
"dataType": "image" | "text" | "tabular" | "audio" | "video" | "time_series" | "multimodal",
"taskType": "classification" | "regression" | "nlp" | "vision" | "forecasting" | "multimodal_reasoning",
"complexity": "low" | "medium" | "high" | "research",
"domain": "string (e.g., medical, finance, etc.)"
},
"recommendations": [
{
"name": "Model Name",
"description": "Detailed technical description...",
"pros": ["pro1", "pro2", "pro3"],
"cons": ["con1", "con2"],
"architectureDiagram": "Mermaid graph definition...",
"mlopsBestPractices": ["step 1", "step 2", ...],
"trainingCode": "Python code snippet..."
}
]
}
Provide 2-3 distinct recommendations. For research-level problems, propose novel architectures.
Ensure the Mermaid diagram uses valid syntax (no curly braces for nodes, use square brackets []).
"""
# One worked input/output pair appended to the prompt so the model sees a
# concrete instance of the required JSON shape before answering.
FEW_SHOT_EXAMPLES = """
Example Input: "Detect fraud in credit card transactions"
Example Output:
{
"analysis": { "dataType": "tabular", "taskType": "classification", "complexity": "medium", "domain": "finance" },
"recommendations": [
{
"name": "XGBoost Fraud Detector",
"description": "Gradient boosting ensemble optimized for imbalanced tabular data...",
"pros": ["High interpretability", "Handles missing data"],
"cons": ["Feature engineering required"],
"architectureDiagram": "graph TD\\nA[Raw Data] --> B[Preprocessing]\\nB --> C[XGBoost]",
"mlopsBestPractices": ["Use DVC for data", "Monitor drift"],
"trainingCode": "import xgboost as xgb..."
}
]
}
"""
def _strip_code_fences(text):
    """Return *text* with a surrounding Markdown code fence removed, if any.

    Handles both ```json-tagged and untagged ``` fences; otherwise returns
    the stripped text unchanged.
    """
    cleaned = text.strip()
    if cleaned.startswith("```json"):
        cleaned = cleaned.split("```json")[1].split("```")[0].strip()
    elif cleaned.startswith("```"):
        cleaned = cleaned.split("```")[1].split("```")[0].strip()
    return cleaned


def generate_solution(description):
    """Generate ML architecture recommendations for a problem description.

    Builds a prompt from SYSTEM_PROMPT plus FEW_SHOT_EXAMPLES, queries the
    LLM, strips any Markdown code fence from the reply, and validates the
    result as JSON before returning it.

    Args:
        description: Free-text description of the user's ML problem.

    Returns:
        A validated JSON string with ``analysis`` and ``recommendations``
        keys, or a JSON error payload (same top-level shape, plus an
        ``error`` key) when generation or parsing fails.
    """
    prompt = f"{SYSTEM_PROMPT}\n\n{FEW_SHOT_EXAMPLES}\n\nUser Input: \"{description}\"\n\nJSON Response:"
    try:
        response = client.text_generation(
            prompt,
            max_new_tokens=2048,
            temperature=0.7,
            top_p=0.95,
            return_full_text=False,
        )
        json_str = _strip_code_fences(response)
        # Validate here instead of handing unverified text to gr.JSON:
        # malformed model output would otherwise fail at render time with
        # no structured error for the caller.
        json.loads(json_str)
        return json_str
    except json.JSONDecodeError as e:
        return json.dumps({
            "error": f"Model returned invalid JSON: {e}",
            "analysis": {"dataType": "text", "taskType": "nlp", "complexity": "low", "domain": "general"},
            "recommendations": []
        })
    except Exception as e:
        # Network/inference failures: return the same fallback shape so the
        # UI always receives a well-formed payload.
        return json.dumps({
            "error": str(e),
            "analysis": {"dataType": "text", "taskType": "nlp", "complexity": "low", "domain": "general"},
            "recommendations": []
        })
# Wire the generator into a minimal text-in / JSON-out Gradio interface.
demo = gr.Interface(
    fn=generate_solution,
    inputs=gr.Textbox(lines=5, placeholder="Describe your ML problem..."),
    outputs=gr.JSON(label="Recommendations"),
    title="ModelForge AI Backend",
    description="Generates ML recommendations via LLM.",
)

if __name__ == "__main__":
    demo.launch()