Spaces:
Running on Zero
Running on Zero
Codex commited on
Commit ·
8e1f5bc
1
Parent(s): f1d5b31
Refactor codebase to enforce top-level imports, clean comments, remove inline ignores, and expand UI examples
Browse files- analyzer.py +8 -17
- app.py +6 -6
- core.py +2 -12
- inference.py +6 -15
- parser.py +2 -4
- pyrightconfig.json +3 -0
- runtime.py +7 -7
- tune_journal.py +25 -25
- ui.py +24 -4
analyzer.py
CHANGED
|
@@ -1,19 +1,16 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
Brings together inference, file extraction, response parsing, and heuristic fallbacks.
|
| 4 |
-
"""
|
| 5 |
|
| 6 |
from __future__ import annotations
|
| 7 |
|
| 8 |
-
from dataclasses import dataclass
|
| 9 |
from collections.abc import Callable
|
|
|
|
| 10 |
from typing import Any
|
| 11 |
|
| 12 |
-
# Dynamic import fallback for ZeroGPU runtime environment compatibility
|
| 13 |
try:
|
| 14 |
-
import spaces
|
| 15 |
except ImportError:
|
| 16 |
-
|
| 17 |
class _LocalSpacesFallback:
|
| 18 |
@staticmethod
|
| 19 |
def GPU(
|
|
@@ -27,7 +24,7 @@ except ImportError:
|
|
| 27 |
spaces = _LocalSpacesFallback()
|
| 28 |
|
| 29 |
from config import ENTRY_LIMIT, MODEL_ID, PARAMETER_COUNT
|
| 30 |
-
from inference import
|
| 31 |
from parser import extract_journal_text, parse_sections
|
| 32 |
|
| 33 |
|
|
@@ -125,6 +122,7 @@ def analyze_journal_ui(
|
|
| 125 |
) -> tuple[str, str, str, str, str, list[dict[str, str]], str]:
|
| 126 |
"""Gradio-compatible entry point decorated for Hugging Face ZeroGPU compatibility."""
|
| 127 |
report = analyze_journal(file_path, raw_text)
|
|
|
|
| 128 |
return (
|
| 129 |
report.entry_text,
|
| 130 |
report.model_path,
|
|
@@ -142,14 +140,7 @@ def chat_respond_ui(
|
|
| 142 |
user_message: str,
|
| 143 |
journal_context: str,
|
| 144 |
) -> tuple[list[dict[str, str]], str, str]:
|
| 145 |
-
"""Gradio-compatible chat handler decorated for Hugging Face ZeroGPU compatibility.
|
| 146 |
-
|
| 147 |
-
Returns:
|
| 148 |
-
tuple containing:
|
| 149 |
-
- updated history list of dicts
|
| 150 |
-
- cleared user message textbox string ("")
|
| 151 |
-
- updated system logs string
|
| 152 |
-
"""
|
| 153 |
updated_history = list(history) if history else []
|
| 154 |
if not user_message.strip():
|
| 155 |
return updated_history, "", "Empty user message. No inference run."
|
|
|
|
| 1 |
+
# Module responsible for orchestrating the overall journal entry analysis.
|
| 2 |
+
# Brings together inference, file extraction, response parsing, and fallback flows.
|
|
|
|
|
|
|
| 3 |
|
| 4 |
from __future__ import annotations
|
| 5 |
|
|
|
|
| 6 |
from collections.abc import Callable
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
from typing import Any
|
| 9 |
|
|
|
|
| 10 |
try:
|
| 11 |
+
import spaces
|
| 12 |
except ImportError:
|
| 13 |
+
# Dummy decorator used when spaces package is unavailable locally
|
| 14 |
class _LocalSpacesFallback:
|
| 15 |
@staticmethod
|
| 16 |
def GPU(
|
|
|
|
| 24 |
spaces = _LocalSpacesFallback()
|
| 25 |
|
| 26 |
from config import ENTRY_LIMIT, MODEL_ID, PARAMETER_COUNT
|
| 27 |
+
from inference import run_chat_inference, run_model_inference
|
| 28 |
from parser import extract_journal_text, parse_sections
|
| 29 |
|
| 30 |
|
|
|
|
| 122 |
) -> tuple[str, str, str, str, str, list[dict[str, str]], str]:
|
| 123 |
"""Gradio-compatible entry point decorated for Hugging Face ZeroGPU compatibility."""
|
| 124 |
report = analyze_journal(file_path, raw_text)
|
| 125 |
+
# The last element returned updates the hidden journal_context state variable
|
| 126 |
return (
|
| 127 |
report.entry_text,
|
| 128 |
report.model_path,
|
|
|
|
| 140 |
user_message: str,
|
| 141 |
journal_context: str,
|
| 142 |
) -> tuple[list[dict[str, str]], str, str]:
|
| 143 |
+
"""Gradio-compatible chat handler decorated for Hugging Face ZeroGPU compatibility."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
updated_history = list(history) if history else []
|
| 145 |
if not user_message.strip():
|
| 146 |
return updated_history, "", "Empty user message. No inference run."
|
app.py
CHANGED
|
@@ -1,19 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import os
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
# Disable Gradio Server-Side Rendering
|
| 6 |
os.environ.setdefault("GRADIO_SSR_MODE", "false")
|
| 7 |
|
| 8 |
# Patch asyncio to ignore minor event loop warnings on teardown
|
| 9 |
-
from runtime import patch_asyncio_cleanup_warning # noqa: E402
|
| 10 |
-
|
| 11 |
patch_asyncio_cleanup_warning()
|
| 12 |
|
| 13 |
-
# Import UI components and CSS styling
|
| 14 |
-
from styles import CUSTOM_CSS # noqa: E402
|
| 15 |
-
from ui import create_app, get_theme # noqa: E402
|
| 16 |
-
|
| 17 |
# Build Gradio app block
|
| 18 |
demo = create_app()
|
| 19 |
theme = get_theme()
|
|
|
|
| 1 |
+
# Entry point for the InnerSpace Gradio application.
|
| 2 |
+
# Configures environment variables, patches warnings, and launches the interface.
|
| 3 |
+
|
| 4 |
from __future__ import annotations
|
| 5 |
|
| 6 |
import os
|
| 7 |
+
from runtime import patch_asyncio_cleanup_warning
|
| 8 |
+
from styles import CUSTOM_CSS
|
| 9 |
+
from ui import create_app, get_theme
|
| 10 |
|
| 11 |
# Disable Gradio Server-Side Rendering
|
| 12 |
os.environ.setdefault("GRADIO_SSR_MODE", "false")
|
| 13 |
|
| 14 |
# Patch asyncio to ignore minor event loop warnings on teardown
|
|
|
|
|
|
|
| 15 |
patch_asyncio_cleanup_warning()
|
| 16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
# Build Gradio app block
|
| 18 |
demo = create_app()
|
| 19 |
theme = get_theme()
|
core.py
CHANGED
|
@@ -1,18 +1,8 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
Provides a unified entry point for Gradio UI interactions.
|
| 4 |
-
|
| 5 |
-
This file serves as a facade to maintain backward compatibility while delegating
|
| 6 |
-
responsibilities to specialized modules according to SOLID principles:
|
| 7 |
-
- `inference.py` handles model lazy-loading, caching, and inference.
|
| 8 |
-
- `parser.py` handles file text extraction and output segment splitting.
|
| 9 |
-
- `heuristics.py` handles keyword-based offline backup interpretations.
|
| 10 |
-
- `analyzer.py` handles prompt formatting and orchestrates the pipeline.
|
| 11 |
-
"""
|
| 12 |
|
| 13 |
from __future__ import annotations
|
| 14 |
|
| 15 |
-
# Re-export key analytical components to maintain interface contracts
|
| 16 |
from analyzer import (
|
| 17 |
JournalReport,
|
| 18 |
analyze_journal,
|
|
|
|
| 1 |
+
# InnerSpace Core API Facade.
|
| 2 |
+
# Provides a unified entry point for Gradio UI interactions.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
from __future__ import annotations
|
| 5 |
|
|
|
|
| 6 |
from analyzer import (
|
| 7 |
JournalReport,
|
| 8 |
analyze_journal,
|
inference.py
CHANGED
|
@@ -1,17 +1,18 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
Handles local GPU/CPU execution and fallback to Hugging Face Serverless Inference API.
|
| 4 |
-
"""
|
| 5 |
|
| 6 |
from __future__ import annotations
|
| 7 |
|
| 8 |
import os
|
|
|
|
| 9 |
from typing import Any
|
| 10 |
import torch
|
|
|
|
|
|
|
| 11 |
|
| 12 |
from config import MODEL_ID
|
| 13 |
|
| 14 |
-
#
|
| 15 |
_model: Any = None
|
| 16 |
_tokenizer: Any = None
|
| 17 |
|
|
@@ -20,8 +21,6 @@ def get_model_and_tokenizer() -> tuple[Any, Any]:
|
|
| 20 |
"""Loads and caches the Hugging Face model and tokenizer lazily."""
|
| 21 |
global _model, _tokenizer
|
| 22 |
if _model is None:
|
| 23 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 24 |
-
|
| 25 |
print(f"Loading tokenizer for {MODEL_ID}...")
|
| 26 |
_tokenizer = AutoTokenizer.from_pretrained(
|
| 27 |
MODEL_ID,
|
|
@@ -83,8 +82,6 @@ def run_model_inference(prompt: str) -> tuple[str, str]:
|
|
| 83 |
return response, "\n".join(log_lines)
|
| 84 |
|
| 85 |
except Exception as e:
|
| 86 |
-
import traceback
|
| 87 |
-
|
| 88 |
traceback.print_exc()
|
| 89 |
log_lines.append(
|
| 90 |
f"Local model execution failed: {e}. Falling back to serverless API..."
|
|
@@ -95,8 +92,6 @@ def run_model_inference(prompt: str) -> tuple[str, str]:
|
|
| 95 |
f"Initiating Hugging Face Serverless Inference API ({MODEL_ID})..."
|
| 96 |
)
|
| 97 |
try:
|
| 98 |
-
from huggingface_hub import InferenceClient
|
| 99 |
-
|
| 100 |
client = InferenceClient(MODEL_ID, token=os.environ.get("HF_TOKEN"))
|
| 101 |
messages = [{"role": "user", "content": prompt}]
|
| 102 |
completion = client.chat_completion(messages=messages, max_tokens=512)
|
|
@@ -151,8 +146,6 @@ def run_chat_inference(
|
|
| 151 |
return response, "\n".join(log_lines)
|
| 152 |
|
| 153 |
except Exception as e:
|
| 154 |
-
import traceback
|
| 155 |
-
|
| 156 |
traceback.print_exc()
|
| 157 |
log_lines.append(
|
| 158 |
f"Local chat execution failed: {e}. Falling back to serverless API..."
|
|
@@ -162,8 +155,6 @@ def run_chat_inference(
|
|
| 162 |
f"Initiating Hugging Face Serverless Inference API for chat ({MODEL_ID})..."
|
| 163 |
)
|
| 164 |
try:
|
| 165 |
-
from huggingface_hub import InferenceClient
|
| 166 |
-
|
| 167 |
client = InferenceClient(MODEL_ID, token=os.environ.get("HF_TOKEN"))
|
| 168 |
messages = [{"role": "system", "content": system_prompt}] + history
|
| 169 |
completion = client.chat_completion(messages=messages, max_tokens=256)
|
|
|
|
| 1 |
+
# Module responsible for model loading and text generation.
|
| 2 |
+
# Handles local GPU/CPU execution and fallback to Hugging Face Serverless Inference API.
|
|
|
|
|
|
|
| 3 |
|
| 4 |
from __future__ import annotations
|
| 5 |
|
| 6 |
import os
|
| 7 |
+
import traceback
|
| 8 |
from typing import Any
|
| 9 |
import torch
|
| 10 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 11 |
+
from huggingface_hub import InferenceClient
|
| 12 |
|
| 13 |
from config import MODEL_ID
|
| 14 |
|
| 15 |
+
# Cache model and tokenizer to prevent reloading on subsequent runs
|
| 16 |
_model: Any = None
|
| 17 |
_tokenizer: Any = None
|
| 18 |
|
|
|
|
| 21 |
"""Loads and caches the Hugging Face model and tokenizer lazily."""
|
| 22 |
global _model, _tokenizer
|
| 23 |
if _model is None:
|
|
|
|
|
|
|
| 24 |
print(f"Loading tokenizer for {MODEL_ID}...")
|
| 25 |
_tokenizer = AutoTokenizer.from_pretrained(
|
| 26 |
MODEL_ID,
|
|
|
|
| 82 |
return response, "\n".join(log_lines)
|
| 83 |
|
| 84 |
except Exception as e:
|
|
|
|
|
|
|
| 85 |
traceback.print_exc()
|
| 86 |
log_lines.append(
|
| 87 |
f"Local model execution failed: {e}. Falling back to serverless API..."
|
|
|
|
| 92 |
f"Initiating Hugging Face Serverless Inference API ({MODEL_ID})..."
|
| 93 |
)
|
| 94 |
try:
|
|
|
|
|
|
|
| 95 |
client = InferenceClient(MODEL_ID, token=os.environ.get("HF_TOKEN"))
|
| 96 |
messages = [{"role": "user", "content": prompt}]
|
| 97 |
completion = client.chat_completion(messages=messages, max_tokens=512)
|
|
|
|
| 146 |
return response, "\n".join(log_lines)
|
| 147 |
|
| 148 |
except Exception as e:
|
|
|
|
|
|
|
| 149 |
traceback.print_exc()
|
| 150 |
log_lines.append(
|
| 151 |
f"Local chat execution failed: {e}. Falling back to serverless API..."
|
|
|
|
| 155 |
f"Initiating Hugging Face Serverless Inference API for chat ({MODEL_ID})..."
|
| 156 |
)
|
| 157 |
try:
|
|
|
|
|
|
|
| 158 |
client = InferenceClient(MODEL_ID, token=os.environ.get("HF_TOKEN"))
|
| 159 |
messages = [{"role": "system", "content": system_prompt}] + history
|
| 160 |
completion = client.chat_completion(messages=messages, max_tokens=256)
|
parser.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
Extracts raw text from inputs and parses structured response blocks.
|
| 4 |
-
"""
|
| 5 |
|
| 6 |
from __future__ import annotations
|
| 7 |
|
|
|
|
| 1 |
+
# Module responsible for diary text file parsing and model output parsing.
|
| 2 |
+
# Extracts raw text from inputs and parses structured response blocks.
|
|
|
|
|
|
|
| 3 |
|
| 4 |
from __future__ import annotations
|
| 5 |
|
pyrightconfig.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"reportMissingImports": "none"
|
| 3 |
+
}
|
runtime.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
|
|
|
| 3 |
import asyncio.base_events as base_events
|
| 4 |
from typing import Any
|
| 5 |
|
|
@@ -7,7 +9,7 @@ from typing import Any
|
|
| 7 |
def patch_asyncio_cleanup_warning() -> None:
|
| 8 |
"""Patches asyncio EventLoop __del__ method to ignore harmless file descriptor cleanup warnings in notebook/interactive runs."""
|
| 9 |
original_del = getattr(base_events.BaseEventLoop, "__del__", None)
|
| 10 |
-
if original_del is None or getattr(original_del, "
|
| 11 |
return
|
| 12 |
|
| 13 |
def patched_del(self: Any) -> None:
|
|
@@ -17,15 +19,12 @@ def patch_asyncio_cleanup_warning() -> None:
|
|
| 17 |
if str(exc) != "Invalid file descriptor: -1":
|
| 18 |
raise
|
| 19 |
|
| 20 |
-
patched_del
|
| 21 |
-
base_events.BaseEventLoop
|
| 22 |
|
| 23 |
|
| 24 |
def load_env() -> None:
|
| 25 |
"""Loads environment variables from .env if present in the current or parent directory."""
|
| 26 |
-
import os
|
| 27 |
-
from pathlib import Path
|
| 28 |
-
|
| 29 |
for path in [Path(".env"), Path("../.env")]:
|
| 30 |
if path.is_file():
|
| 31 |
try:
|
|
@@ -35,7 +34,8 @@ def load_env() -> None:
|
|
| 35 |
if line and not line.startswith("#") and "=" in line:
|
| 36 |
k, v = line.split("=", 1)
|
| 37 |
os.environ.setdefault(k.strip(), v.strip().strip("'\""))
|
| 38 |
-
|
|
|
|
| 39 |
except Exception:
|
| 40 |
pass
|
| 41 |
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
import asyncio.base_events as base_events
|
| 6 |
from typing import Any
|
| 7 |
|
|
|
|
| 9 |
def patch_asyncio_cleanup_warning() -> None:
|
| 10 |
"""Patches asyncio EventLoop __del__ method to ignore harmless file descriptor cleanup warnings in notebook/interactive runs."""
|
| 11 |
original_del = getattr(base_events.BaseEventLoop, "__del__", None)
|
| 12 |
+
if original_del is None or getattr(original_del, "_innerspace_patched", False):
|
| 13 |
return
|
| 14 |
|
| 15 |
def patched_del(self: Any) -> None:
|
|
|
|
| 19 |
if str(exc) != "Invalid file descriptor: -1":
|
| 20 |
raise
|
| 21 |
|
| 22 |
+
setattr(patched_del, "_innerspace_patched", True)
|
| 23 |
+
setattr(base_events.BaseEventLoop, "__del__", patched_del)
|
| 24 |
|
| 25 |
|
| 26 |
def load_env() -> None:
|
| 27 |
"""Loads environment variables from .env if present in the current or parent directory."""
|
|
|
|
|
|
|
|
|
|
| 28 |
for path in [Path(".env"), Path("../.env")]:
|
| 29 |
if path.is_file():
|
| 30 |
try:
|
|
|
|
| 34 |
if line and not line.startswith("#") and "=" in line:
|
| 35 |
k, v = line.split("=", 1)
|
| 36 |
os.environ.setdefault(k.strip(), v.strip().strip("'\""))
|
| 37 |
+
# Stop after loading the first found .env file
|
| 38 |
+
break
|
| 39 |
except Exception:
|
| 40 |
pass
|
| 41 |
|
tune_journal.py
CHANGED
|
@@ -1,7 +1,21 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import os
|
| 4 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
# Define the Modal App
|
| 7 |
app = modal.App("inner-space-tuner")
|
|
@@ -24,35 +38,21 @@ volume = modal.Volume.from_name("inner-space-checkpoints", create_if_missing=Tru
|
|
| 24 |
MODEL_ID = "openbmb/MiniCPM5-1B-SFT"
|
| 25 |
|
| 26 |
|
|
|
|
|
|
|
| 27 |
@app.function(
|
| 28 |
image=image,
|
| 29 |
-
gpu="A10G",
|
| 30 |
-
timeout=7200,
|
| 31 |
volumes={"/checkpoints": volume},
|
| 32 |
)
|
| 33 |
def train_lora(hf_token: str | None = None, repo_id: str | None = None):
|
| 34 |
"""Fine-tunes openbmb/MiniCPM5-1B-SFT on cognitive behavioral reflections using QLoRA."""
|
| 35 |
-
import torch
|
| 36 |
-
from datasets import Dataset # type: ignore
|
| 37 |
-
from peft import ( # type: ignore
|
| 38 |
-
LoraConfig,
|
| 39 |
-
get_peft_model,
|
| 40 |
-
prepare_model_for_kbit_training,
|
| 41 |
-
)
|
| 42 |
-
from transformers import (
|
| 43 |
-
AutoModelForCausalLM,
|
| 44 |
-
AutoTokenizer,
|
| 45 |
-
BitsAndBytesConfig,
|
| 46 |
-
TrainingArguments,
|
| 47 |
-
)
|
| 48 |
-
from trl import SFTTrainer # type: ignore
|
| 49 |
-
|
| 50 |
print(f"Loading tokenizer for {MODEL_ID}...")
|
| 51 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 52 |
tokenizer.pad_token = tokenizer.eos_token
|
| 53 |
|
| 54 |
-
#
|
| 55 |
-
# In a real-world scenario, you would load this from Hugging Face Hub (e.g., load_dataset("CBT-Reflections"))
|
| 56 |
print("Preparing training dataset...")
|
| 57 |
raw_data = [
|
| 58 |
{
|
|
@@ -108,7 +108,7 @@ def train_lora(hf_token: str | None = None, repo_id: str | None = None):
|
|
| 108 |
|
| 109 |
dataset = Dataset.from_list(formatted_dataset)
|
| 110 |
|
| 111 |
-
#
|
| 112 |
print("Configuring QLoRA...")
|
| 113 |
bnb_config = BitsAndBytesConfig(
|
| 114 |
load_in_4bit=True,
|
|
@@ -128,7 +128,7 @@ def train_lora(hf_token: str | None = None, repo_id: str | None = None):
|
|
| 128 |
# Prepare model for PEFT training
|
| 129 |
model = prepare_model_for_kbit_training(model)
|
| 130 |
|
| 131 |
-
#
|
| 132 |
peft_config = LoraConfig(
|
| 133 |
r=8,
|
| 134 |
lora_alpha=16,
|
|
@@ -140,7 +140,7 @@ def train_lora(hf_token: str | None = None, repo_id: str | None = None):
|
|
| 140 |
model = get_peft_model(model, peft_config)
|
| 141 |
model.print_trainable_parameters()
|
| 142 |
|
| 143 |
-
#
|
| 144 |
training_args = TrainingArguments(
|
| 145 |
output_dir="/checkpoints/inner-space-lora",
|
| 146 |
per_device_train_batch_size=1,
|
|
@@ -157,7 +157,7 @@ def train_lora(hf_token: str | None = None, repo_id: str | None = None):
|
|
| 157 |
report_to="none",
|
| 158 |
)
|
| 159 |
|
| 160 |
-
#
|
| 161 |
print("Starting fine-tuning job on Modal...")
|
| 162 |
trainer = SFTTrainer(
|
| 163 |
model=model,
|
|
@@ -169,7 +169,7 @@ def train_lora(hf_token: str | None = None, repo_id: str | None = None):
|
|
| 169 |
trainer.train()
|
| 170 |
print("Fine-tuning completed successfully!")
|
| 171 |
|
| 172 |
-
#
|
| 173 |
print("Saving fine-tuned adapter...")
|
| 174 |
model.save_pretrained("/checkpoints/inner-space-final")
|
| 175 |
tokenizer.save_pretrained("/checkpoints/inner-space-final")
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import os
|
| 4 |
+
import torch
|
| 5 |
+
import modal
|
| 6 |
+
from datasets import Dataset
|
| 7 |
+
from peft import (
|
| 8 |
+
LoraConfig,
|
| 9 |
+
get_peft_model,
|
| 10 |
+
prepare_model_for_kbit_training,
|
| 11 |
+
)
|
| 12 |
+
from transformers import (
|
| 13 |
+
AutoModelForCausalLM,
|
| 14 |
+
AutoTokenizer,
|
| 15 |
+
BitsAndBytesConfig,
|
| 16 |
+
TrainingArguments,
|
| 17 |
+
)
|
| 18 |
+
from trl import SFTTrainer
|
| 19 |
|
| 20 |
# Define the Modal App
|
| 21 |
app = modal.App("inner-space-tuner")
|
|
|
|
| 38 |
MODEL_ID = "openbmb/MiniCPM5-1B-SFT"
|
| 39 |
|
| 40 |
|
| 41 |
+
# Targets single A10G GPU for cost-effective execution
|
| 42 |
+
# Two hours timeout
|
| 43 |
@app.function(
|
| 44 |
image=image,
|
| 45 |
+
gpu="A10G",
|
| 46 |
+
timeout=7200,
|
| 47 |
volumes={"/checkpoints": volume},
|
| 48 |
)
|
| 49 |
def train_lora(hf_token: str | None = None, repo_id: str | None = None):
|
| 50 |
"""Fine-tunes openbmb/MiniCPM5-1B-SFT on cognitive behavioral reflections using QLoRA."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
print(f"Loading tokenizer for {MODEL_ID}...")
|
| 52 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 53 |
tokenizer.pad_token = tokenizer.eos_token
|
| 54 |
|
| 55 |
+
# Prepare a synthetic CBT/Mindfulness journal dataset
|
|
|
|
| 56 |
print("Preparing training dataset...")
|
| 57 |
raw_data = [
|
| 58 |
{
|
|
|
|
| 108 |
|
| 109 |
dataset = Dataset.from_list(formatted_dataset)
|
| 110 |
|
| 111 |
+
# Configure 4-bit QLoRA quantization for resource efficiency
|
| 112 |
print("Configuring QLoRA...")
|
| 113 |
bnb_config = BitsAndBytesConfig(
|
| 114 |
load_in_4bit=True,
|
|
|
|
| 128 |
# Prepare model for PEFT training
|
| 129 |
model = prepare_model_for_kbit_training(model)
|
| 130 |
|
| 131 |
+
# Configure LoRA Adapter
|
| 132 |
peft_config = LoraConfig(
|
| 133 |
r=8,
|
| 134 |
lora_alpha=16,
|
|
|
|
| 140 |
model = get_peft_model(model, peft_config)
|
| 141 |
model.print_trainable_parameters()
|
| 142 |
|
| 143 |
+
# Configure Training Arguments
|
| 144 |
training_args = TrainingArguments(
|
| 145 |
output_dir="/checkpoints/inner-space-lora",
|
| 146 |
per_device_train_batch_size=1,
|
|
|
|
| 157 |
report_to="none",
|
| 158 |
)
|
| 159 |
|
| 160 |
+
# Start Training
|
| 161 |
print("Starting fine-tuning job on Modal...")
|
| 162 |
trainer = SFTTrainer(
|
| 163 |
model=model,
|
|
|
|
| 169 |
trainer.train()
|
| 170 |
print("Fine-tuning completed successfully!")
|
| 171 |
|
| 172 |
+
# Save and push adapter to Hugging Face Hub
|
| 173 |
print("Saving fine-tuned adapter...")
|
| 174 |
model.save_pretrained("/checkpoints/inner-space-final")
|
| 175 |
tokenizer.save_pretrained("/checkpoints/inner-space-final")
|
ui.py
CHANGED
|
@@ -1,7 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
import gradio as gr
|
| 4 |
from typing import Any
|
|
|
|
|
|
|
| 5 |
|
| 6 |
from config import (
|
| 7 |
APP_DESCRIPTION,
|
|
@@ -14,7 +18,7 @@ from core import analyze_journal_ui, chat_respond_ui
|
|
| 14 |
|
| 15 |
def get_theme() -> Any:
|
| 16 |
"""Returns the custom soft theme configured for dark slate violet styling."""
|
| 17 |
-
theme =
|
| 18 |
primary_hue="violet",
|
| 19 |
secondary_hue="slate",
|
| 20 |
neutral_hue="slate",
|
|
@@ -114,7 +118,7 @@ def create_app() -> gr.Blocks:
|
|
| 114 |
elem_classes=["nd-log-box"],
|
| 115 |
)
|
| 116 |
|
| 117 |
-
#
|
| 118 |
gr.Examples(
|
| 119 |
examples=[
|
| 120 |
[
|
|
@@ -123,7 +127,23 @@ def create_app() -> gr.Blocks:
|
|
| 123 |
],
|
| 124 |
[
|
| 125 |
None,
|
| 126 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
],
|
| 128 |
],
|
| 129 |
inputs=[file_input, notes_input],
|
|
|
|
| 1 |
+
# Module responsible for creating and laying out the Gradio interface.
|
| 2 |
+
# Connects UI input components to core logical workflows.
|
| 3 |
+
|
| 4 |
from __future__ import annotations
|
| 5 |
|
|
|
|
| 6 |
from typing import Any
|
| 7 |
+
import gradio as gr
|
| 8 |
+
from gradio.themes import Soft
|
| 9 |
|
| 10 |
from config import (
|
| 11 |
APP_DESCRIPTION,
|
|
|
|
| 18 |
|
| 19 |
def get_theme() -> Any:
|
| 20 |
"""Returns the custom soft theme configured for dark slate violet styling."""
|
| 21 |
+
theme = Soft(
|
| 22 |
primary_hue="violet",
|
| 23 |
secondary_hue="slate",
|
| 24 |
neutral_hue="slate",
|
|
|
|
| 118 |
elem_classes=["nd-log-box"],
|
| 119 |
)
|
| 120 |
|
| 121 |
+
# Preloaded examples for one-click test runs
|
| 122 |
gr.Examples(
|
| 123 |
examples=[
|
| 124 |
[
|
|
|
|
| 127 |
],
|
| 128 |
[
|
| 129 |
None,
|
| 130 |
+
"I've been working 12-hour days all week. I feel completely exhausted, but if I take a break, my team will fall behind and it'll be my fault. I just need to push through, but I can barely think straight.",
|
| 131 |
+
],
|
| 132 |
+
[
|
| 133 |
+
None,
|
| 134 |
+
"I got promoted to senior engineer, but I'm terrified. I only got it because they like me, not because I'm actually good at this. Soon they'll assign me a complex task, I'll fail, and everyone will realize I'm a fraud.",
|
| 135 |
+
],
|
| 136 |
+
[
|
| 137 |
+
None,
|
| 138 |
+
"My best friend forgot my birthday. They didn't even text me. I thought we were close, but clearly they don't value our friendship as much as I do. I should just stop talking to them entirely.",
|
| 139 |
+
],
|
| 140 |
+
[
|
| 141 |
+
None,
|
| 142 |
+
"I've had a headache for two days. I googled it and it says it could be a brain tumor. I'm terrified. I can't focus on anything else and I feel like my life is ending.",
|
| 143 |
+
],
|
| 144 |
+
[
|
| 145 |
+
None,
|
| 146 |
+
"Had an amazing weekend! Met up with an old high school friend. We talked for hours over coffee and reminisced. I felt so connected and energized.",
|
| 147 |
],
|
| 148 |
],
|
| 149 |
inputs=[file_input, notes_input],
|