mwtuni commited on
Commit
92d7d12
·
1 Parent(s): 69d4745

split api and app

Browse files
Files changed (3) hide show
  1. api.py +134 -0
  2. app.py +3 -109
  3. requirements.txt +1 -0
api.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+
5
+ from fastapi import FastAPI, HTTPException
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+
8
+ from avatar_store import avatar_generated_path
9
+ from generate_image_with_nano import build_prompt, run_edit
10
+ from mcp_tools import (
11
+ create_avatar,
12
+ delete_avatar_memory,
13
+ delete_avatar_portrait,
14
+ delete_generated_images,
15
+ ensure_public_avatar,
16
+ generate_as_avatar,
17
+ get_avatar,
18
+ get_avatar_context,
19
+ record_generated_image,
20
+ retrieve_avatar_snippets,
21
+ set_avatar_portrait,
22
+ store_avatar_memory,
23
+ summarize_avatar_context,
24
+ )
25
+
26
+ fastapi_app = FastAPI()
27
+ fastapi_app.add_middleware(
28
+ CORSMiddleware,
29
+ allow_origins=["*"],
30
+ allow_credentials=True,
31
+ allow_methods=["*"],
32
+ allow_headers=["*"],
33
+ )
34
+
35
+
36
+ def _handle(func, payload):
37
+ try:
38
+ return func(payload or {})
39
+ except ValueError as exc:
40
+ raise HTTPException(status_code=400, detail=str(exc))
41
+
42
+
43
+ def _resolve_path(path_str):
44
+ if not path_str:
45
+ return None
46
+ path = Path(path_str)
47
+ if not path.is_absolute():
48
+ path = Path(__file__).parent / path
49
+ return str(path)
50
+
51
+
52
+ @fastapi_app.post("/mcp/create_avatar")
53
+ def api_create_avatar(payload: dict):
54
+ return _handle(create_avatar, payload)
55
+
56
+
57
+ @fastapi_app.post("/mcp/get_avatar")
58
+ def api_get_avatar(payload: dict):
59
+ return _handle(get_avatar, payload)
60
+
61
+
62
+ @fastapi_app.post("/mcp/generate_as_avatar")
63
+ def api_generate_as_avatar(payload: dict):
64
+ return _handle(generate_as_avatar, payload)
65
+
66
+
67
+ @fastapi_app.post("/mcp/generate_image")
68
+ def api_generate_image(payload: dict):
69
+ payload = payload or {}
70
+ avatar_id = payload.get("avatar_id")
71
+ if not avatar_id:
72
+ raise HTTPException(status_code=400, detail="avatar_id required")
73
+ message = (payload.get("message") or "").strip()
74
+ reply = (payload.get("reply") or "").strip()
75
+ context = message or reply or "scene with the avatar"
76
+ try:
77
+ avatar = ensure_public_avatar(avatar_id)
78
+ except ValueError as exc:
79
+ raise HTTPException(status_code=400, detail=str(exc))
80
+ portrait_path = _resolve_path(avatar.get("portrait"))
81
+ if not portrait_path or not Path(portrait_path).exists():
82
+ raise HTTPException(status_code=400, detail="portrait not found for avatar")
83
+ prompt = build_prompt(avatar.get("persona", "Avatar"), context)
84
+ timestamp = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
85
+ out_path = avatar_generated_path(avatar_id, timestamp)
86
+ out_path.parent.mkdir(parents=True, exist_ok=True)
87
+ try:
88
+ rc = run_edit(Path(portrait_path), prompt, out_path)
89
+ except SystemExit:
90
+ rc = 1
91
+ if rc != 0:
92
+ raise HTTPException(status_code=500, detail="image generation failed")
93
+ record_generated_image(avatar_id, out_path, prompt, timestamp)
94
+ return {
95
+ "status": "generated",
96
+ "path": str(out_path),
97
+ "prompt": prompt,
98
+ "timestamp": timestamp,
99
+ }
100
+
101
+
102
+ @fastapi_app.post("/mcp/store_avatar_memory")
103
+ def api_store_avatar_memory(payload: dict):
104
+ return _handle(store_avatar_memory, payload)
105
+
106
+
107
+ @fastapi_app.post("/mcp/delete_avatar_memory")
108
+ def api_delete_avatar_memory(payload: dict):
109
+ return _handle(delete_avatar_memory, payload)
110
+
111
+
112
+ @fastapi_app.post("/mcp/get_avatar_context")
113
+ def api_get_avatar_context(payload: dict):
114
+ return _handle(get_avatar_context, payload)
115
+
116
+
117
+ @fastapi_app.post("/mcp/delete_generated_images")
118
+ def api_delete_generated_images(payload: dict):
119
+ return _handle(delete_generated_images, payload)
120
+
121
+
122
+ @fastapi_app.post("/mcp/retrieve_snippets")
123
+ def api_retrieve_snippets(payload: dict):
124
+ return _handle(retrieve_avatar_snippets, payload)
125
+
126
+
127
+ @fastapi_app.post("/mcp/summarize_avatar")
128
+ def api_summarize_avatar(payload: dict):
129
+ return _handle(summarize_avatar_context, payload)
130
+
131
+
132
+ @fastapi_app.get("/")
133
+ def root():
134
+ return {"status": "ok"}
app.py CHANGED
@@ -4,44 +4,23 @@ from datetime import datetime
4
  from pathlib import Path
5
 
6
  import gradio as gr
7
- from fastapi import FastAPI, HTTPException
8
- from fastapi.middleware.cors import CORSMiddleware
9
 
 
10
  from avatar_store import avatar_generated_path
 
11
  from mcp_tools import (
12
  create_avatar,
13
  delete_avatar_memory,
14
  delete_avatar_portrait,
15
- ensure_public_avatar,
16
  generate_as_avatar,
17
  get_avatar,
18
- get_avatar_context,
19
  record_generated_image,
20
- delete_generated_images,
21
  set_avatar_portrait,
22
  store_avatar_memory,
23
- retrieve_avatar_snippets,
24
- summarize_avatar_context,
25
- )
26
- from generate_image_with_nano import build_prompt, run_edit
27
-
28
- fastapi_app = FastAPI()
29
- fastapi_app.add_middleware(
30
- CORSMiddleware,
31
- allow_origins=["*"],
32
- allow_credentials=True,
33
- allow_methods=["*"],
34
- allow_headers=["*"],
35
  )
36
 
37
 
38
- def _handle(func, payload):
39
- try:
40
- return func(payload or {})
41
- except ValueError as exc:
42
- raise HTTPException(status_code=400, detail=str(exc))
43
-
44
-
45
  def _resolve_path(path_str):
46
  if not path_str:
47
  return None
@@ -51,86 +30,6 @@ def _resolve_path(path_str):
51
  return str(path)
52
 
53
 
54
- @fastapi_app.post("/mcp/create_avatar")
55
- def api_create_avatar(payload: dict):
56
- return _handle(create_avatar, payload)
57
-
58
-
59
- @fastapi_app.post("/mcp/get_avatar")
60
- def api_get_avatar(payload: dict):
61
- return _handle(get_avatar, payload)
62
-
63
-
64
- @fastapi_app.post("/mcp/generate_as_avatar")
65
- def api_generate_as_avatar(payload: dict):
66
- return _handle(generate_as_avatar, payload)
67
-
68
-
69
- @fastapi_app.post("/mcp/generate_image")
70
- def api_generate_image(payload: dict):
71
- payload = payload or {}
72
- avatar_id = payload.get("avatar_id")
73
- if not avatar_id:
74
- raise HTTPException(status_code=400, detail="avatar_id required")
75
- message = (payload.get("message") or "").strip()
76
- reply = (payload.get("reply") or "").strip()
77
- context = message or reply or "scene with the avatar"
78
- try:
79
- avatar = ensure_public_avatar(avatar_id)
80
- except ValueError as exc:
81
- raise HTTPException(status_code=400, detail=str(exc))
82
- portrait_path = _resolve_path(avatar.get("portrait"))
83
- if not portrait_path or not Path(portrait_path).exists():
84
- raise HTTPException(status_code=400, detail="portrait not found for avatar")
85
- prompt = build_prompt(avatar.get("persona", "Avatar"), context)
86
- timestamp = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
87
- out_path = avatar_generated_path(avatar_id, timestamp)
88
- out_path.parent.mkdir(parents=True, exist_ok=True)
89
- try:
90
- rc = run_edit(Path(portrait_path), prompt, out_path)
91
- except SystemExit:
92
- rc = 1
93
- if rc != 0:
94
- raise HTTPException(status_code=500, detail="image generation failed")
95
- record_generated_image(avatar_id, out_path, prompt, timestamp)
96
- return {
97
- "status": "generated",
98
- "path": str(out_path),
99
- "prompt": prompt,
100
- "timestamp": timestamp,
101
- }
102
-
103
-
104
- @fastapi_app.post("/mcp/store_avatar_memory")
105
- def api_store_avatar_memory(payload: dict):
106
- return _handle(store_avatar_memory, payload)
107
-
108
-
109
- @fastapi_app.post("/mcp/delete_avatar_memory")
110
- def api_delete_avatar_memory(payload: dict):
111
- return _handle(delete_avatar_memory, payload)
112
-
113
-
114
- @fastapi_app.post("/mcp/get_avatar_context")
115
- def api_get_avatar_context(payload: dict):
116
- return _handle(get_avatar_context, payload)
117
-
118
-
119
- @fastapi_app.post("/mcp/delete_generated_images")
120
- def api_delete_generated_images(payload: dict):
121
- return _handle(delete_generated_images, payload)
122
-
123
-
124
- @fastapi_app.post("/mcp/retrieve_snippets")
125
- def api_retrieve_snippets(payload: dict):
126
- return _handle(retrieve_avatar_snippets, payload)
127
-
128
-
129
- @fastapi_app.post("/mcp/summarize_avatar")
130
- def api_summarize_avatar(payload: dict):
131
- return _handle(summarize_avatar_context, payload)
132
-
133
-
134
  def decide_tool(message: str) -> str:
135
  text = (message or "").lower()
136
  if "remember" in text:
@@ -535,11 +434,6 @@ with gr.Blocks() as ui:
535
  app = gr.mount_gradio_app(fastapi_app, ui, path="/")
536
 
537
 
538
- @fastapi_app.get("/")
539
- def root():
540
- return {"status": "ok"}
541
-
542
-
543
  if __name__ == "__main__" and not os.getenv("SPACE_ID"):
544
  import uvicorn
545
 
 
4
  from pathlib import Path
5
 
6
  import gradio as gr
 
 
7
 
8
+ from api import fastapi_app
9
  from avatar_store import avatar_generated_path
10
+ from generate_image_with_nano import build_prompt, run_edit
11
  from mcp_tools import (
12
  create_avatar,
13
  delete_avatar_memory,
14
  delete_avatar_portrait,
15
+ delete_generated_images,
16
  generate_as_avatar,
17
  get_avatar,
 
18
  record_generated_image,
 
19
  set_avatar_portrait,
20
  store_avatar_memory,
 
 
 
 
 
 
 
 
 
 
 
 
21
  )
22
 
23
 
 
 
 
 
 
 
 
24
  def _resolve_path(path_str):
25
  if not path_str:
26
  return None
 
30
  return str(path)
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def decide_tool(message: str) -> str:
34
  text = (message or "").lower()
35
  if "remember" in text:
 
434
  app = gr.mount_gradio_app(fastapi_app, ui, path="/")
435
 
436
 
 
 
 
 
 
437
  if __name__ == "__main__" and not os.getenv("SPACE_ID"):
438
  import uvicorn
439
 
requirements.txt CHANGED
@@ -4,3 +4,4 @@ uvicorn
4
  httpx
5
  google-genai
6
  Pillow
 
 
4
  httpx
5
  google-genai
6
  Pillow
7
+ pydantic>=2,<3