# ModelForge / app.py
# Author: Ali Mohsin
import gradio as gr
from huggingface_hub import InferenceClient
import json
import random
import os
# Initialize the client (uses HUGGING_FACE_HUB_TOKEN from environment).
# token is None when the env var is unset; InferenceClient then runs
# unauthenticated, which may hit stricter rate limits on the hosted API.
token = os.getenv("HUGGING_FACE_HUB_TOKEN")
# Module-level singleton: one client reused across all Gradio requests.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=token)
# System instructions sent verbatim to the LLM on every request. Describes the
# strict JSON schema the model must emit (consumed downstream by the gr.JSON
# output component); do not edit the schema text without updating consumers.
SYSTEM_PROMPT = """You are ModelForge, an expert AI architecture assistant. Your goal is to analyze machine learning problems and generate detailed, deployable solutions in strict JSON format.
You must analyze the user's request and return a JSON object with the following structure:
{
"analysis": {
"dataType": "image" | "text" | "tabular" | "audio" | "video" | "time_series" | "multimodal",
"taskType": "classification" | "regression" | "nlp" | "vision" | "forecasting" | "multimodal_reasoning",
"complexity": "low" | "medium" | "high" | "research",
"domain": "string (e.g., medical, finance, etc.)"
},
"recommendations": [
{
"name": "Model Name",
"description": "Detailed technical description...",
"pros": ["pro1", "pro2", "pro3"],
"cons": ["con1", "con2"],
"architectureDiagram": "Mermaid graph definition...",
"mlopsBestPractices": ["step 1", "step 2", ...],
"trainingCode": "Python code snippet..."
}
]
}
Provide 2-3 distinct recommendations. For research-level problems, propose novel architectures.
Ensure the Mermaid diagram uses valid syntax (no curly braces for nodes, use square brackets []).
"""
# One worked input/output pair appended after SYSTEM_PROMPT to anchor the
# model's output format (few-shot prompting). The example must itself be
# valid JSON matching the schema above, or the model tends to drift.
FEW_SHOT_EXAMPLES = """
Example Input: "Detect fraud in credit card transactions"
Example Output:
{
"analysis": { "dataType": "tabular", "taskType": "classification", "complexity": "medium", "domain": "finance" },
"recommendations": [
{
"name": "XGBoost Fraud Detector",
"description": "Gradient boosting ensemble optimized for imbalanced tabular data...",
"pros": ["High interpretability", "Handles missing data"],
"cons": ["Feature engineering required"],
"architectureDiagram": "graph TD\\nA[Raw Data] --> B[Preprocessing]\\nB --> C[XGBoost]",
"mlopsBestPractices": ["Use DVC for data", "Monitor drift"],
"trainingCode": "import xgboost as xgb..."
}
]
}
"""
def generate_solution(description):
    """Generate ML architecture recommendations for a free-text problem.

    Builds a prompt from the system instructions plus the few-shot example,
    queries the hosted LLM, strips any Markdown code fences, and validates
    that the reply is parseable JSON before returning it. On any failure
    (network error, missing token, unparseable model output) a JSON string
    with an "error" key and empty recommendations is returned so the
    gr.JSON output component always has something renderable.

    Args:
        description: The user's ML problem description.

    Returns:
        str: A JSON document (validated via json.loads on the success path).
    """
    prompt = f"{SYSTEM_PROMPT}\n\n{FEW_SHOT_EXAMPLES}\n\nUser Input: \"{description}\"\n\nJSON Response:"
    try:
        response = client.text_generation(
            prompt,
            max_new_tokens=2048,
            temperature=0.7,
            top_p=0.95,
            return_full_text=False
        )
        json_str = response.strip()
        # Strip the Markdown code fences the model sometimes wraps JSON in.
        if json_str.startswith("```json"):
            json_str = json_str.split("```json")[1].split("```")[0].strip()
        elif json_str.startswith("```"):
            json_str = json_str.split("```")[1].split("```")[0].strip()
        # Validate the payload. If the model prepended/appended prose, fall
        # back to the outermost {...} span; if still invalid, re-raise so
        # the error path below returns well-formed JSON instead of garbage.
        try:
            json.loads(json_str)
        except json.JSONDecodeError:
            start = json_str.find("{")
            end = json_str.rfind("}")
            if start != -1 and end > start:
                candidate = json_str[start:end + 1]
                json.loads(candidate)  # raises if still not valid JSON
                json_str = candidate
            else:
                raise
        return json_str
    except Exception as e:
        # Broad catch is deliberate: this is the UI boundary and must always
        # return a renderable JSON document, never propagate an exception.
        return json.dumps({
            "error": str(e),
            "analysis": {"dataType": "text", "taskType": "nlp", "complexity": "low", "domain": "general"},
            "recommendations": []
        })
# Gradio UI wiring: one textbox in, one JSON viewer out. gr.JSON accepts the
# JSON *string* returned by generate_solution and parses it for display.
demo = gr.Interface(
    fn=generate_solution,
    inputs=gr.Textbox(lines=5, placeholder="Describe your ML problem..."),
    outputs=gr.JSON(label="Recommendations"),
    title="ModelForge AI Backend",
    description="Generates ML recommendations via LLM."
)
# Launch only when run as a script; on HF Spaces the platform imports app.py
# and this guard still fires because Spaces executes it as __main__.
if __name__ == "__main__":
    demo.launch()