tcmmichaelb139 commited on
Commit
45d9469
·
1 Parent(s): 89493ee

small updates + api cookies

Browse files
evolutiontransformer/api.py CHANGED
@@ -5,6 +5,7 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
5
  import uuid
6
  from typing import List, Tuple
7
  from fastapi import FastAPI, Depends, HTTPException, Request, Response
 
8
  from pydantic import BaseModel
9
  from celery import Celery
10
  from dotenv import load_dotenv
@@ -19,6 +20,14 @@ celery_app = Celery("tasks", broker=REDIS_URL, backend=REDIS_URL)
19
 
20
  app = FastAPI()
21
 
 
 
 
 
 
 
 
 
22
 
23
  class GenerateRequest(BaseModel):
24
  model_name: str
 
5
  import uuid
6
  from typing import List, Tuple
7
  from fastapi import FastAPI, Depends, HTTPException, Request, Response
8
+ from fastapi.middleware.cors import CORSMiddleware
9
  from pydantic import BaseModel
10
  from celery import Celery
11
  from dotenv import load_dotenv
 
20
 
21
  app = FastAPI()
22
 
23
+ app.add_middleware(
24
+ CORSMiddleware,
25
+ allow_origins=["http://localhost:5173"], # todo add website url
26
+ allow_credentials=True,
27
+ allow_methods=["*"],
28
+ allow_headers=["*"],
29
+ )
30
+
31
 
32
  class GenerateRequest(BaseModel):
33
  model_name: str
frontend/index.html CHANGED
@@ -1,13 +1,16 @@
1
  <!doctype html>
2
  <html lang="en">
3
- <head>
4
- <meta charset="UTF-8" />
5
- <link rel="icon" type="image/svg+xml" href="/vite.svg" />
6
- <meta name="viewport" content="width=device-width, initial-scale=1.0" />
7
- <title>Vite + React</title>
8
- </head>
9
- <body>
10
- <div id="root"></div>
11
- <script type="module" src="/src/main.jsx"></script>
12
- </body>
13
- </html>
 
 
 
 
1
  <!doctype html>
2
  <html lang="en">
3
+
4
+ <head>
5
+ <meta charset="UTF-8" />
6
+ <link rel="icon" type="image/svg+xml" href="/favicon.svg" />
7
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
8
+ <title>Evolution Transformer</title>
9
+ </head>
10
+
11
+ <body>
12
+ <div id="root"></div>
13
+ <script type="module" src="/src/main.jsx"></script>
14
+ </body>
15
+
16
+ </html>
frontend/public/favicon.svg ADDED
frontend/public/vite.svg DELETED
frontend/src/App.jsx CHANGED
@@ -3,6 +3,7 @@ import "./App.css";
3
  import Options from "./components/Options";
4
  import Recipe from "./components/Recipe";
5
  import { setModelLayers } from "./utils/modelCookies";
 
6
 
7
  function App() {
8
  const [models, setModels] = useState([]);
@@ -14,11 +15,39 @@ function App() {
14
  const [mergedName, setMergedName] = useState("merged");
15
  const [numLayers, setNumLayers] = useState(12);
16
 
 
 
17
  useEffect(() => {
18
- setModels(["svamp", "tinystories"]);
19
  setModelLayers("svamp", 24);
20
  setModelLayers("tinystories", 24);
21
- }, []);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  return (
24
  <div className="h-screen bg-gradient-to-br from-primary-50 to-secondary-50 overflow-hidden">
 
3
  import Options from "./components/Options";
4
  import Recipe from "./components/Recipe";
5
  import { setModelLayers } from "./utils/modelCookies";
6
+ import { useAPI } from "./hooks/useAPI";
7
 
8
  function App() {
9
  const [models, setModels] = useState([]);
 
15
  const [mergedName, setMergedName] = useState("merged");
16
  const [numLayers, setNumLayers] = useState(12);
17
 
18
+ const { fetchModels, checkTaskStatus } = useAPI();
19
+
20
  useEffect(() => {
 
21
  setModelLayers("svamp", 24);
22
  setModelLayers("tinystories", 24);
23
+
24
+ const loadModels = async () => {
25
+ try {
26
+ console.log("Loading models...");
27
+ const taskId = await fetchModels();
28
+ console.log("Got task ID:", taskId);
29
+
30
+ if (taskId) {
31
+ checkTaskStatus(
32
+ taskId,
33
+ (result) => {
34
+ console.log("Models loaded successfully:", result);
35
+ if (result && Array.isArray(result.response)) {
36
+ setModels(result.response);
37
+ }
38
+ },
39
+ (error) => {
40
+ console.error("Failed to load models:", error);
41
+ }
42
+ );
43
+ }
44
+ } catch (error) {
45
+ console.error("Error fetching models:", error);
46
+ }
47
+ };
48
+
49
+ loadModels();
50
+ }, [fetchModels, checkTaskStatus]);
51
 
52
  return (
53
  <div className="h-screen bg-gradient-to-br from-primary-50 to-secondary-50 overflow-hidden">
frontend/src/components/InferencePopup.jsx CHANGED
@@ -9,7 +9,7 @@ const InferencePopup = ({ isOpen, onClose, models }) => {
9
  const [isLoading, setIsLoading] = useState(false);
10
  const [error, setError] = useState("");
11
 
12
- const { inference } = useAPI();
13
 
14
  const handleInference = async () => {
15
  if (!selectedModel || !prompt.trim()) {
@@ -23,23 +23,51 @@ const InferencePopup = ({ isOpen, onClose, models }) => {
23
 
24
  try {
25
  const inferenceData = {
26
- model: selectedModel,
27
  prompt: prompt,
28
- max_length: 100, // You can make this configurable if needed
 
29
  };
30
 
 
31
  const result = await inference(inferenceData);
32
-
33
- if (result && result.generated_text) {
34
- setResponse(result.generated_text);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  } else if (result && result.error) {
36
- setError(`Inference failed: ${result.error}`);
 
 
 
 
37
  } else {
38
- setError("No response received from the model");
 
39
  }
40
  } catch (err) {
 
41
  setError(`Error: ${err.message}`);
42
- } finally {
43
  setIsLoading(false);
44
  }
45
  };
 
9
  const [isLoading, setIsLoading] = useState(false);
10
  const [error, setError] = useState("");
11
 
12
+ const { inference, checkTaskStatus } = useAPI();
13
 
14
  const handleInference = async () => {
15
  if (!selectedModel || !prompt.trim()) {
 
23
 
24
  try {
25
  const inferenceData = {
26
+ model_name: selectedModel,
27
  prompt: prompt,
28
+ max_new_tokens: 100, // You can make this configurable if needed
29
+ temperature: 0.7, // Add temperature field
30
  };
31
 
32
+ console.log("Starting inference with data:", inferenceData);
33
  const result = await inference(inferenceData);
34
+ console.log("Got inference result:", result);
35
+
36
+ if (result && result.task_id) {
37
+ // Check task status for inference result
38
+ checkTaskStatus(
39
+ result.task_id,
40
+ (taskResult) => {
41
+ console.log("Inference task result:", taskResult);
42
+ if (taskResult && taskResult.generated_text) {
43
+ setResponse(taskResult.generated_text);
44
+ } else if (taskResult && taskResult.error) {
45
+ setError(`Inference failed: ${taskResult.error}`);
46
+ } else {
47
+ setError("No response received from the model");
48
+ }
49
+ setIsLoading(false);
50
+ },
51
+ (errorMessage) => {
52
+ // Error callback for task status check
53
+ console.error("Inference task failed:", errorMessage);
54
+ setError(`Task failed: ${errorMessage}`);
55
+ setIsLoading(false);
56
+ }
57
+ );
58
  } else if (result && result.error) {
59
+ // Check if it's a server error
60
+ const isServerError = result.error.includes("HTTP 5");
61
+ const errorPrefix = isServerError ? "🔴 Server Error: " : "Error: ";
62
+ setError(`${errorPrefix}${result.error}`);
63
+ setIsLoading(false);
64
  } else {
65
+ setError("No task ID received");
66
+ setIsLoading(false);
67
  }
68
  } catch (err) {
69
+ console.error("Inference error:", err);
70
  setError(`Error: ${err.message}`);
 
71
  setIsLoading(false);
72
  }
73
  };
frontend/src/components/Options.jsx CHANGED
@@ -43,36 +43,44 @@ const Options = ({
43
 
44
  try {
45
  const mergeData = {
46
- model1: selectedModel1,
47
- model2: selectedModel2,
48
  layer_recipe: layerRecipe,
49
  embedding_lambdas: embeddingLambdas,
50
  linear_lambdas: linearLambdas,
51
- num_layers: numLayers,
52
  merged_name: mergedName,
53
  };
54
 
 
55
  const taskId = await mergeModels(mergeData);
 
56
 
57
  if (taskId) {
58
- checkTaskStatus(taskId, (result) => {
59
- if (result.success) {
60
- setMergeStatus("Merge successful!");
61
-
62
- const newModelName = result.model_name || mergedName;
63
- setModels((prev) => [...prev, newModelName]);
64
-
65
- setModelLayers(newModelName, numLayers);
66
- } else {
67
- setMergeStatus(`Merge failed: ${result.error || "Unknown error"}`);
 
 
 
 
 
 
 
 
 
 
68
  }
69
- setIsLoading(false);
70
- });
71
- } else {
72
- setMergeStatus("Failed to start merge process");
73
- setIsLoading(false);
74
  }
75
  } catch (error) {
 
76
  setMergeStatus(`Error: ${error.message}`);
77
  setIsLoading(false);
78
  }
 
43
 
44
  try {
45
  const mergeData = {
46
+ model1_name: selectedModel1,
47
+ model2_name: selectedModel2,
48
  layer_recipe: layerRecipe,
49
  embedding_lambdas: embeddingLambdas,
50
  linear_lambdas: linearLambdas,
 
51
  merged_name: mergedName,
52
  };
53
 
54
+ console.log("Starting merge with data:", mergeData);
55
  const taskId = await mergeModels(mergeData);
56
+ console.log("Got merge task ID:", taskId);
57
 
58
  if (taskId) {
59
+ checkTaskStatus(
60
+ taskId,
61
+ (taskResult) => {
62
+ console.log("Merge result:", taskResult);
63
+ if (taskResult.response) {
64
+ setMergeStatus("Merge successful!");
65
+ const newModelName = taskResult.response || mergedName;
66
+ setModels((prev) => [...prev, newModelName]);
67
+ setModelLayers(newModelName, numLayers);
68
+ } else {
69
+ setMergeStatus(
70
+ `Merge failed: ${taskResult.error || "Unknown error"}`
71
+ );
72
+ }
73
+ setIsLoading(false);
74
+ },
75
+ (error) => {
76
+ console.error("Merge task failed:", error);
77
+ setMergeStatus(`Merge failed: ${error}`);
78
+ setIsLoading(false);
79
  }
80
+ );
 
 
 
 
81
  }
82
  } catch (error) {
83
+ console.error("Merge error:", error);
84
  setMergeStatus(`Error: ${error.message}`);
85
  setIsLoading(false);
86
  }
frontend/src/hooks/useAPI.js CHANGED
@@ -3,62 +3,102 @@ import { useCallback } from "react";
3
  const API_BASE = "https://tcmmichaelb139-evolutiontransformer.hf.space";
4
 
5
  export const useAPI = () => {
6
- const checkTaskStatus = useCallback(async (taskId, callback) => {
7
- try {
8
- const response = await fetch(`${API_BASE}/tasks/${taskId}`);
9
- const data = await response.json();
 
 
 
 
 
 
 
 
 
10
 
11
- if (data.status === "SUCCESS") {
12
- callback(data.result);
13
- } else if (data.status === "PENDING") {
14
- setTimeout(() => checkTaskStatus(taskId, callback), 1000);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  }
16
- } catch (error) {
17
- console.error("Task check failed:", error);
18
- }
19
- }, []);
20
 
21
  const fetchModels = useCallback(async () => {
22
  try {
 
23
  const response = await fetch(`${API_BASE}/list_models`, {
24
  method: "POST",
25
  headers: { "Content-Type": "application/json" },
 
26
  });
 
27
  const data = await response.json();
 
28
  return data.task_id;
29
  } catch (error) {
30
- console.error("Failed to fetch models:", error);
31
- return null;
32
  }
33
  }, []);
34
 
35
  const mergeModels = useCallback(async (mergeData) => {
36
  try {
 
37
  const response = await fetch(`${API_BASE}/merge`, {
38
  method: "POST",
39
  headers: { "Content-Type": "application/json" },
40
  body: JSON.stringify(mergeData),
 
41
  });
 
42
  const data = await response.json();
 
43
  return data.task_id;
44
  } catch (error) {
45
- console.error("Merge failed:", error);
46
- return null;
47
  }
48
  }, []);
49
 
50
  const inference = useCallback(async (inferenceData) => {
51
  try {
 
52
  const response = await fetch(`${API_BASE}/generate`, {
53
  method: "POST",
54
  headers: { "Content-Type": "application/json" },
55
  body: JSON.stringify(inferenceData),
 
56
  });
 
 
 
 
 
 
 
57
  const data = await response.json();
 
58
  return data;
59
  } catch (error) {
60
- console.error("Inference failed:", error);
61
- return null;
62
  }
63
  }, []);
64
 
 
3
  const API_BASE = "https://tcmmichaelb139-evolutiontransformer.hf.space";
4
 
5
  export const useAPI = () => {
6
+ const checkTaskStatus = useCallback(
7
+ async (taskId, successCallback, errorCallback) => {
8
+ try {
9
+ const response = await fetch(`${API_BASE}/tasks/${taskId}`, {
10
+ credentials: "include",
11
+ });
12
+
13
+ if (!response.ok) {
14
+ const error = `HTTP ${response.status}: ${response.statusText}`;
15
+ console.error("Task check failed:", error);
16
+ if (errorCallback) errorCallback(error);
17
+ return;
18
+ }
19
 
20
+ const data = await response.json();
21
+ console.log("Task status:", data.status);
22
+
23
+ if (data.status === "SUCCESS") {
24
+ successCallback(data.result);
25
+ } else if (data.status === "PENDING") {
26
+ setTimeout(
27
+ () => checkTaskStatus(taskId, successCallback, errorCallback),
28
+ 1000
29
+ );
30
+ } else if (data.status === "FAILURE") {
31
+ const error = data.result || "Task failed";
32
+ console.error("Task failed:", error);
33
+ if (errorCallback) errorCallback(error);
34
+ }
35
+ } catch (error) {
36
+ console.error("Task check error:", error);
37
+ if (errorCallback) errorCallback(error.message);
38
  }
39
+ },
40
+ []
41
+ );
 
42
 
43
  const fetchModels = useCallback(async () => {
44
  try {
45
+ console.log("Fetching models...");
46
  const response = await fetch(`${API_BASE}/list_models`, {
47
  method: "POST",
48
  headers: { "Content-Type": "application/json" },
49
+ credentials: "include",
50
  });
51
+
52
  const data = await response.json();
53
+ console.log("Fetch models response:", data);
54
  return data.task_id;
55
  } catch (error) {
56
+ console.error("Fetch models error:", error);
57
+ throw error;
58
  }
59
  }, []);
60
 
61
  const mergeModels = useCallback(async (mergeData) => {
62
  try {
63
+ console.log("Merging models with data:", mergeData);
64
  const response = await fetch(`${API_BASE}/merge`, {
65
  method: "POST",
66
  headers: { "Content-Type": "application/json" },
67
  body: JSON.stringify(mergeData),
68
+ credentials: "include",
69
  });
70
+
71
  const data = await response.json();
72
+ console.log("Merge response:", data);
73
  return data.task_id;
74
  } catch (error) {
75
+ console.error("Merge error:", error);
76
+ throw error;
77
  }
78
  }, []);
79
 
80
  const inference = useCallback(async (inferenceData) => {
81
  try {
82
+ console.log("Running inference with data:", inferenceData);
83
  const response = await fetch(`${API_BASE}/generate`, {
84
  method: "POST",
85
  headers: { "Content-Type": "application/json" },
86
  body: JSON.stringify(inferenceData),
87
+ credentials: "include",
88
  });
89
+
90
+ if (!response.ok) {
91
+ const error = `HTTP ${response.status}: ${response.statusText}`;
92
+ console.error("Inference failed:", error);
93
+ throw new Error(error);
94
+ }
95
+
96
  const data = await response.json();
97
+ console.log("Inference response:", data);
98
  return data;
99
  } catch (error) {
100
+ console.error("Inference error:", error);
101
+ throw error;
102
  }
103
  }, []);
104
 
frontend/src/index.css CHANGED
@@ -64,13 +64,13 @@ body {
64
  }
65
 
66
  ::-webkit-scrollbar-thumb {
67
- background: linear-gradient(135deg, #0ea5e9, #22c55e);
68
  border-radius: 4px;
69
  transition: background 0.2s ease;
70
  }
71
 
72
  ::-webkit-scrollbar-thumb:hover {
73
- background: linear-gradient(135deg, #0284c7, #16a34a);
74
  }
75
 
76
  ::-webkit-scrollbar-corner {
 
64
  }
65
 
66
  ::-webkit-scrollbar-thumb {
67
+ background: #0ea5e9;
68
  border-radius: 4px;
69
  transition: background 0.2s ease;
70
  }
71
 
72
  ::-webkit-scrollbar-thumb:hover {
73
+ background: #0284c7;
74
  }
75
 
76
  ::-webkit-scrollbar-corner {