Commit · 45d9469
Parent(s): 89493ee

small updates + api cookies
Files changed:
- evolutiontransformer/api.py +9 -0
- frontend/index.html +14 -11
- frontend/public/favicon.svg +5 -0
- frontend/public/vite.svg +0 -1
- frontend/src/App.jsx +31 -2
- frontend/src/components/InferencePopup.jsx +37 -9
- frontend/src/components/Options.jsx +26 -18
- frontend/src/hooks/useAPI.js +58 -18
- frontend/src/index.css +2 -2
evolutiontransformer/api.py
CHANGED

@@ -5,6 +5,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
 import uuid
 from typing import List, Tuple
 from fastapi import FastAPI, Depends, HTTPException, Request, Response
+from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 from celery import Celery
 from dotenv import load_dotenv

@@ -19,6 +20,14 @@ celery_app = Celery("tasks", broker=REDIS_URL, backend=REDIS_URL)
 
 app = FastAPI()
 
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["http://localhost:5173"],  # todo add website url
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
 
 class GenerateRequest(BaseModel):
     model_name: str

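The commit message mentions "api cookies": the CORS block above allows credentialed cross-origin requests from the dev origin, and the frontend changes later in this commit add credentials: "include" to every fetch so the browser actually sends the cookie. The cookie-issuing code itself is not part of this hunk, so the following is only a minimal sketch of how a FastAPI route could set a session cookie that survives a cross-origin request; the /session path and the session_id cookie name are illustrative assumptions, not code from this repository.

# Sketch only: one way api.py could issue a session cookie for credentialed
# CORS requests. The route and cookie names are assumptions; "app" is the
# FastAPI() instance created above, and uuid is already imported in api.py.
import uuid

from fastapi import Request, Response


@app.post("/session")
def create_session(request: Request, response: Response):
    # Reuse the cookie the browser sent, otherwise mint a fresh session id.
    session_id = request.cookies.get("session_id") or str(uuid.uuid4())
    response.set_cookie(
        key="session_id",
        value=session_id,
        httponly=True,
        samesite="none",  # cross-site requests only send SameSite=None cookies
        secure=True,      # and SameSite=None cookies must also be Secure
    )
    return {"session_id": session_id}

Note that allow_credentials=True only works with an explicit origin list; browsers reject credentialed responses that advertise a wildcard origin, which is why the localhost URL is spelled out above.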
frontend/index.html
CHANGED

@@ -1,13 +1,16 @@
 <!doctype html>
 <html lang="en">
-  …
+
+  <head>
+    <meta charset="UTF-8" />
+    <link rel="icon" type="image/svg+xml" href="/favicon.svg" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+    <title>Evolution Transformer</title>
+  </head>
+
+  <body>
+    <div id="root"></div>
+    <script type="module" src="/src/main.jsx"></script>
+  </body>
+
+</html>

frontend/public/favicon.svg
ADDED
frontend/public/vite.svg
DELETED
frontend/src/App.jsx
CHANGED

@@ -3,6 +3,7 @@ import "./App.css";
 import Options from "./components/Options";
 import Recipe from "./components/Recipe";
 import { setModelLayers } from "./utils/modelCookies";
+import { useAPI } from "./hooks/useAPI";
 
 function App() {
   const [models, setModels] = useState([]);

@@ -14,11 +15,39 @@
   const [mergedName, setMergedName] = useState("merged");
   const [numLayers, setNumLayers] = useState(12);
 
+  const { fetchModels, checkTaskStatus } = useAPI();
+
   useEffect(() => {
-    setModels(["svamp", "tinystories"]);
     setModelLayers("svamp", 24);
     setModelLayers("tinystories", 24);
-  }, []);
+
+    const loadModels = async () => {
+      try {
+        console.log("Loading models...");
+        const taskId = await fetchModels();
+        console.log("Got task ID:", taskId);
+
+        if (taskId) {
+          checkTaskStatus(
+            taskId,
+            (result) => {
+              console.log("Models loaded successfully:", result);
+              if (result && Array.isArray(result.response)) {
+                setModels(result.response);
+              }
+            },
+            (error) => {
+              console.error("Failed to load models:", error);
+            }
+          );
+        }
+      } catch (error) {
+        console.error("Error fetching models:", error);
+      }
+    };
+
+    loadModels();
+  }, [fetchModels, checkTaskStatus]);
 
   return (
     <div className="h-screen bg-gradient-to-br from-primary-50 to-secondary-50 overflow-hidden">

frontend/src/components/InferencePopup.jsx
CHANGED

@@ -9,7 +9,7 @@ const InferencePopup = ({ isOpen, onClose, models }) => {
   const [isLoading, setIsLoading] = useState(false);
   const [error, setError] = useState("");
 
-  const { inference } = useAPI();
+  const { inference, checkTaskStatus } = useAPI();
 
   const handleInference = async () => {
     if (!selectedModel || !prompt.trim()) {

@@ -23,23 +23,51 @@
 
     try {
       const inferenceData = {
-        …
+        model_name: selectedModel,
         prompt: prompt,
-        …
+        max_new_tokens: 100, // You can make this configurable if needed
+        temperature: 0.7, // Add temperature field
       };
 
+      console.log("Starting inference with data:", inferenceData);
       const result = await inference(inferenceData);
-      …
+      console.log("Got inference result:", result);
+
+      if (result && result.task_id) {
+        // Check task status for inference result
+        checkTaskStatus(
+          result.task_id,
+          (taskResult) => {
+            console.log("Inference task result:", taskResult);
+            if (taskResult && taskResult.generated_text) {
+              setResponse(taskResult.generated_text);
+            } else if (taskResult && taskResult.error) {
+              setError(`Inference failed: ${taskResult.error}`);
+            } else {
+              setError("No response received from the model");
+            }
+            setIsLoading(false);
+          },
+          (errorMessage) => {
+            // Error callback for task status check
+            console.error("Inference task failed:", errorMessage);
+            setError(`Task failed: ${errorMessage}`);
+            setIsLoading(false);
+          }
+        );
       } else if (result && result.error) {
-        …
+        // Check if it's a server error
+        const isServerError = result.error.includes("HTTP 5");
+        const errorPrefix = isServerError ? "🔴 Server Error: " : "Error: ";
+        setError(`${errorPrefix}${result.error}`);
+        setIsLoading(false);
       } else {
-        setError("No …
+        setError("No task ID received");
+        setIsLoading(false);
       }
     } catch (err) {
+      console.error("Inference error:", err);
       setError(`Error: ${err.message}`);
-    } finally {
       setIsLoading(false);
     }
   };

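The payload built in handleInference (model_name, prompt, max_new_tokens, temperature) has to match the request model that /generate validates on the FastAPI side. Only model_name: str of GenerateRequest appears in the api.py hunk above, so the other fields and defaults in the sketch below are assumptions inferred from this frontend payload rather than the repository's actual model.

# Assumed shape of GenerateRequest, mirroring what the popup sends.
# Only the model_name field is confirmed by the api.py hunk in this commit.
from pydantic import BaseModel


class GenerateRequest(BaseModel):
    model_name: str
    prompt: str
    max_new_tokens: int = 100
    temperature: float = 0.7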
frontend/src/components/Options.jsx
CHANGED

@@ -43,36 +43,44 @@ const Options = ({
 
     try {
       const mergeData = {
-        …
+        model1_name: selectedModel1,
+        model2_name: selectedModel2,
         layer_recipe: layerRecipe,
         embedding_lambdas: embeddingLambdas,
         linear_lambdas: linearLambdas,
-        num_layers: numLayers,
         merged_name: mergedName,
       };
 
+      console.log("Starting merge with data:", mergeData);
       const taskId = await mergeModels(mergeData);
+      console.log("Got merge task ID:", taskId);
 
       if (taskId) {
-        checkTaskStatus(…
-        …
+        checkTaskStatus(
+          taskId,
+          (taskResult) => {
+            console.log("Merge result:", taskResult);
+            if (taskResult.response) {
+              setMergeStatus("Merge successful!");
+              const newModelName = taskResult.response || mergedName;
+              setModels((prev) => [...prev, newModelName]);
+              setModelLayers(newModelName, numLayers);
+            } else {
+              setMergeStatus(
+                `Merge failed: ${taskResult.error || "Unknown error"}`
+              );
+            }
+            setIsLoading(false);
+          },
+          (error) => {
+            console.error("Merge task failed:", error);
+            setMergeStatus(`Merge failed: ${error}`);
+            setIsLoading(false);
           }
-        });
-      } else {
-        setMergeStatus("Failed to start merge process");
-        setIsLoading(false);
+        );
       }
     } catch (error) {
+      console.error("Merge error:", error);
       setMergeStatus(`Error: ${error.message}`);
       setIsLoading(false);
     }

frontend/src/hooks/useAPI.js
CHANGED

@@ -3,62 +3,102 @@ import { useCallback } from "react";
 const API_BASE = "https://tcmmichaelb139-evolutiontransformer.hf.space";
 
 export const useAPI = () => {
-  const checkTaskStatus = useCallback(…
-  …
-  }, []);
+  const checkTaskStatus = useCallback(
+    async (taskId, successCallback, errorCallback) => {
+      try {
+        const response = await fetch(`${API_BASE}/tasks/${taskId}`, {
+          credentials: "include",
+        });
+
+        if (!response.ok) {
+          const error = `HTTP ${response.status}: ${response.statusText}`;
+          console.error("Task check failed:", error);
+          if (errorCallback) errorCallback(error);
+          return;
+        }
 
+        const data = await response.json();
+        console.log("Task status:", data.status);
+
+        if (data.status === "SUCCESS") {
+          successCallback(data.result);
+        } else if (data.status === "PENDING") {
+          setTimeout(
+            () => checkTaskStatus(taskId, successCallback, errorCallback),
+            1000
+          );
+        } else if (data.status === "FAILURE") {
+          const error = data.result || "Task failed";
+          console.error("Task failed:", error);
+          if (errorCallback) errorCallback(error);
+        }
+      } catch (error) {
+        console.error("Task check error:", error);
+        if (errorCallback) errorCallback(error.message);
       }
+    },
+    []
+  );
 
   const fetchModels = useCallback(async () => {
     try {
+      console.log("Fetching models...");
       const response = await fetch(`${API_BASE}/list_models`, {
         method: "POST",
         headers: { "Content-Type": "application/json" },
+        credentials: "include",
       });
+
       const data = await response.json();
+      console.log("Fetch models response:", data);
       return data.task_id;
     } catch (error) {
-      console.error("…
+      console.error("Fetch models error:", error);
+      throw error;
     }
   }, []);
 
   const mergeModels = useCallback(async (mergeData) => {
     try {
+      console.log("Merging models with data:", mergeData);
       const response = await fetch(`${API_BASE}/merge`, {
         method: "POST",
         headers: { "Content-Type": "application/json" },
         body: JSON.stringify(mergeData),
+        credentials: "include",
       });
+
       const data = await response.json();
+      console.log("Merge response:", data);
       return data.task_id;
     } catch (error) {
-      console.error("Merge …
+      console.error("Merge error:", error);
+      throw error;
     }
   }, []);
 
   const inference = useCallback(async (inferenceData) => {
     try {
+      console.log("Running inference with data:", inferenceData);
       const response = await fetch(`${API_BASE}/generate`, {
         method: "POST",
         headers: { "Content-Type": "application/json" },
         body: JSON.stringify(inferenceData),
+        credentials: "include",
       });
+
+      if (!response.ok) {
+        const error = `HTTP ${response.status}: ${response.statusText}`;
+        console.error("Inference failed:", error);
+        throw new Error(error);
+      }
+
       const data = await response.json();
+      console.log("Inference response:", data);
       return data;
     } catch (error) {
-      console.error("Inference …
+      console.error("Inference error:", error);
+      throw error;
     }
   }, []);
 

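checkTaskStatus polls GET /tasks/{task_id} and branches on data.status (PENDING, SUCCESS, FAILURE) and data.result. That route is not part of this diff, so the sketch below is only an assumption about its shape: a thin wrapper over Celery's AsyncResult on the celery_app defined in api.py, returning exactly the two fields the hook reads.

# Assumed implementation of the /tasks/{task_id} route the hook polls; it is
# not shown in this commit. "app" and "celery_app" are the objects created in
# api.py above.
from celery.result import AsyncResult


@app.get("/tasks/{task_id}")
def get_task_status(task_id: str):
    task = AsyncResult(task_id, app=celery_app)
    if task.failed():
        result = str(task.result)  # the raised exception, rendered as text
    elif task.successful():
        result = task.result       # whatever the Celery task returned
    else:
        result = None              # still PENDING (or STARTED)
    # The frontend re-polls once a second while status is PENDING and stops
    # on SUCCESS or FAILURE.
    return {"status": task.status, "result": result}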
frontend/src/index.css
CHANGED

@@ -64,13 +64,13 @@ body {
 }
 
 ::-webkit-scrollbar-thumb {
-  background: …
+  background: #0ea5e9;
   border-radius: 4px;
   transition: background 0.2s ease;
 }
 
 ::-webkit-scrollbar-thumb:hover {
-  background: …
+  background: #0284c7;
 }
 
 ::-webkit-scrollbar-corner {