Commit 0d87629 · committed by GitHub Actions
Parent(s): 49cd8d6

🚀 Auto-deploy from GitHub
Files changed:
- .gitignore +1 -1
- NOTES.md +6 -0
- app/api/v1/endpoints/generate.py +32 -23
- app/api/v1/schemas/card_schemas.py +2 -2
- app/core/card_renderer.py +7 -7
- app/core/config.py +1 -1
- app/core/constraints.py +6 -16
- app/core/model_loader.py +50 -17
- app/main.py +15 -13
- app/utils/qr_utils.py +33 -5
- scripts/cleanup_large_files.sh +0 -0
- scripts/quick_test.sh +0 -58
- scripts/start.sh +0 -24
- scripts/setup_training.sh → setup_training.sh +27 -10
- static/images/base/base0.png +0 -0
- static/images/base/base1.png +0 -0
- static/images/base/base10.png +0 -0
- static/images/base/base11.png +0 -0
- static/images/base/base2.png +0 -0
- static/images/base/base3.png +0 -0
- static/images/base/base4.png +0 -0
- static/images/base/base5.png +0 -0
- static/images/base/base6.png +0 -0
- static/images/base/base7.png +0 -0
- static/images/base/base8.png +0 -0
- static/images/base/base9.png +0 -0
- static/images/symbols/symbol0.png +0 -0
- static/images/symbols/symbol1.png +0 -0
- static/images/symbols/symbol2.png +0 -0
- static/images/symbols/symbol3.png +0 -0
- static/images/symbols/symbol4.png +0 -0
- static/images/symbols/symbol5.png +0 -0
- static/images/symbols/symbol6.png +0 -0
- static/images/symbols/symbol7.png +0 -0
.gitignore
CHANGED
@@ -161,7 +161,7 @@ models/*/
 training/models/
 training/checkpoints/
 # training/outputs/ # Moved to logs/outputs as it's often mixed
-
+/models/lora-checkpoint/
 # Common model file extensions
 *.bin
 *.safetensors
NOTES.md
ADDED
@@ -0,0 +1,6 @@
+
+
+https://huggingface.co/docs/peft/en/task_guides/lora_based_methods
+https://chatgpt.com/share/6841f8f1-3164-800e-99ad-3ef13c4400e9
+
+https://git-lfs.com/
app/api/v1/endpoints/generate.py
CHANGED
@@ -3,17 +3,26 @@ from dotenv import load_dotenv
 from supabase import Client
 import uuid
 from ..schemas.card_schemas import CardGenerateRequest, CardGenerateResponse
-from ....core.generator import build_prompt
-from ....core.card_renderer import generate_card as
-from ....utils.qr_utils import
+from ....core.generator import build_prompt  # get_constellation wird hier nicht direkt verwendet
+from ....core.card_renderer import generate_card as render_card_sync  # Umbenennen für Klarheit
+from ....utils.qr_utils import generate_qr_code_sync  # Umbenennen für Klarheit
 from ....services.database import get_supabase_client, save_card
 from ....core.config import settings
 from ....core.model_loader import get_generator
 from ....core.constraints import generate_with_retry, check_constraints
+from fastapi.concurrency import run_in_threadpool  # Importieren

 load_dotenv()
 router = APIRouter()

+# Asynchrone Wrapper für blockierende Funktionen
+async def render_card_async(*args, **kwargs):
+    return await run_in_threadpool(render_card_sync, *args, **kwargs)
+
+async def generate_qr_code_async(*args, **kwargs):
+    return await run_in_threadpool(generate_qr_code_sync, *args, **kwargs)
+
+
 @router.post("/generate", response_model=CardGenerateResponse)
 async def generate_endpoint(
     request: CardGenerateRequest,
@@ -23,13 +32,14 @@ async def generate_endpoint(
     lang = request.lang or "de"
     input_date_str = request.card_date.isoformat()

+    # build_prompt ist schnell und CPU-gebunden, kann synchron bleiben
     card_prompt = build_prompt(
         lang=lang,
         card_date=input_date_str,
         terms=request.terms
     )

-    llm_pipeline = get_generator()
+    llm_pipeline = get_generator()  # Bleibt synchron, da es gecacht ist und schnell sein sollte nach dem ersten Mal

     generation_params = {
         "max_new_tokens": settings.GENERATION_MAX_NEW_TOKENS,
@@ -40,8 +50,8 @@ async def generate_endpoint(
         "return_full_text": False
     }

-    #
-    card_text = generate_with_retry(
+    # generate_with_retry ist jetzt asynchron
+    card_text = await generate_with_retry(
         prompt=card_prompt,
         generator=llm_pipeline,
         terms=request.terms,
@@ -55,22 +65,22 @@ async def generate_endpoint(
             detail="Kartentext konnte nicht generiert werden oder erfüllt nicht die Bedingungen."
         )

-    # 4. QR-Code generieren
     card_id_for_url = str(uuid.uuid4())
     qr_content_url = f"{settings.FRONTEND_BASE_URL}/card/{card_id_for_url}"

-
+    # QR-Code Generierung asynchron
+    qr_code_file_id = await generate_qr_code_async(
         data=qr_content_url,
         output_path=settings.resolved_qr_code_path,
         size=settings.QR_CODE_SIZE
     )
     qr_code_url = f"{settings.API_PREFIX}/static/images/qr/{qr_code_file_id}.png"

-    # 5. Karte rendern
     card_design_id_to_render = request.card_design_id_override or 1
     symbol_ids_to_render = request.symbol_ids_override or [1, 2]

-
+    # Karten-Rendering asynchron
+    card_file_id = await render_card_async(
         card_design_id=card_design_id_to_render,
         symbol_ids=symbol_ids_to_render,
         text=card_text,
@@ -90,28 +100,27 @@ async def generate_endpoint(
         "session_id": uuid.UUID(card_id_for_url),
         "lang": lang,
         "prompt_text": card_prompt,
+        # llm_pipeline.model.config kann potenziell blockierend sein, wenn es I/O macht.
+        # Für den Moment belassen wir es, aber es könnte auch in einen Threadpool, falls nötig.
         "ml_model_info": llm_pipeline.model.config.to_dict() if hasattr(llm_pipeline, 'model') and hasattr(llm_pipeline.model, 'config') else {"name": str(type(llm_pipeline.model).__name__)},
         "generation_params": generation_params
     }

-
-
-
-
-
-
-
-
-    except Exception as e:
-        print(f"Fehler beim Speichern der Karte in Supabase: {e}")
+    # save_card ist bereits asynchron (await)
+    db_response = await save_card(supabase, card_data_for_db)
+    db_id = None
+    # ... existing database response handling ...
+    if db_response and hasattr(db_response, 'data') and db_response.data and len(db_response.data) > 0:
+        db_id = str(db_response.data[0].get('id'))
+    elif isinstance(db_response, list) and db_response and isinstance(db_response[0], dict):
+        db_id = str(db_response[0].get('id'))

     return CardGenerateResponse(
-        message="
-        # Der card_id in der Response sollte nun auch die neue UUID sein, wenn db_id nicht verfügbar ist
+        message="Karte erfolgreich generiert.",
         card_id=db_id if db_id else card_id_for_url,
         qr_code_image_url=qr_code_url
     )
-
+    # ... existing error handling ...
     except FileNotFoundError as e:
         print(f"FileNotFoundError in generate_endpoint: {e}")
         raise HTTPException(status_code=500, detail=f"Ein benötigtes Template oder eine Datei wurde nicht gefunden: {e.filename}")
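Note: the core of this change is moving blocking work (text generation, QR creation, Pillow rendering) off the event loop via FastAPI's run_in_threadpool. A minimal, self-contained sketch of the pattern; blocking_render is a hypothetical stand-in for render_card_sync:

import time
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool

app = FastAPI()

def blocking_render(text: str) -> str:
    # Hypothetical stand-in for a sync, CPU/IO-heavy helper like render_card_sync
    time.sleep(1)
    return f"rendered: {text}"

@app.get("/demo")
async def demo():
    # The sync function runs in Starlette's threadpool, so the event loop
    # stays free to serve other requests during the one-second sleep.
    result = await run_in_threadpool(blocking_render, "hello")
    return {"result": result}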
app/api/v1/schemas/card_schemas.py
CHANGED
@@ -79,7 +79,7 @@ class CardData(BaseModel):
     session_id: Optional[uuid.UUID] = Field(None, description="Session ID used for the request.")
     lang: Optional[str] = Field(None, description="Language used for generation.")
     prompt_text: Optional[str] = Field(None, description="The prompt text used for generation.")
-
+    ml_model_info: Optional[dict] = Field(None, description="Information about the model used.")
     generation_params: Optional[dict] = Field(None, description="Parameters used for text generation.")

 class CardDBSchema(BaseModel):
@@ -92,7 +92,7 @@ class CardDBSchema(BaseModel):
     session_id: UUID4  # Corresponds to uuid.UUID(card_id_for_url)
     lang: str
     prompt_text: str
-
+    ml_model_info: Dict[str, Any]  # The field causing the warning
     generation_params: Dict[str, Any]

     # Add this model_config to resolve the warning
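Note: the comments around ml_model_info point at a Pydantic warning. If it is Pydantic v2's protected "model_" namespace warning (plausible if the field was once named model_info), the usual fix is the model_config mentioned in the context line; a sketch under that assumption:

from typing import Any, Dict
from pydantic import BaseModel, ConfigDict

class CardDBSchema(BaseModel):
    # Assumption: the warning is the protected-namespace UserWarning that
    # Pydantic v2 emits for field names starting with "model_". Clearing
    # protected_namespaces silences it; renaming the field also avoids it.
    model_config = ConfigDict(protected_namespaces=())

    ml_model_info: Dict[str, Any]
    generation_params: Dict[str, Any]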
app/core/card_renderer.py
CHANGED
@@ -12,28 +12,28 @@ def generate_card(
     output_path: Path
 ) -> str:
     """
-    Generiert eine
+    Generiert eine Karte und speichert sie.
     Verwendet jetzt übergebene Pfade für mehr Flexibilität und Testbarkeit.
     Gibt die UUID der generierten Datei (ohne Erweiterung) zurück.
     """
     try:
         # Basiskarte laden
-        base_image_file = base_images_path / f"{card_design_id}.png"
+        base_image_file = base_images_path / f"base{card_design_id}.png"
         if not base_image_file.exists():
-            raise FileNotFoundError(f"Basiskartenbild nicht gefunden: {base_image_file}")
+            raise FileNotFoundError(f"Basiskartenbild nicht gefunden: base{base_image_file}")
         card_design_img = Image.open(base_image_file).convert("RGBA")

         # Symbole hinzufügen
         # Die Positionierung hier ist ein Beispiel und muss ggf. angepasst werden
         symbol_x_start = 50
-        symbol_y_start = 400
-        symbol_spacing = 10
+        symbol_y_start = 400
+        symbol_spacing = 10
         current_x = symbol_x_start

         for i, sid in enumerate(symbol_ids):
-            symbol_file = symbols_images_path / f"{sid}.png"
+            symbol_file = symbols_images_path / f"symbol{sid}.png"
             if not symbol_file.exists():
-                print(f"Warnung: Symbolbild nicht gefunden: {symbol_file}, wird übersprungen.")
+                print(f"Warnung: Symbolbild nicht gefunden: symbol{symbol_file}, wird übersprungen.")
                 continue
             symbol_img = Image.open(symbol_file).convert("RGBA")
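Note: the rename to base{n}.png / symbol{n}.png matches the image files added below. When compositing RGBA symbols in Pillow, the symbol must also be passed as its own paste mask, otherwise its transparency is flattened; a minimal sketch of that compositing step, reusing the layout constants from generate_card:

from PIL import Image

base = Image.open("static/images/base/base0.png").convert("RGBA")
x, y, spacing = 50, 400, 10  # same example positions as in generate_card
for sid in (1, 2):
    symbol = Image.open(f"static/images/symbols/symbol{sid}.png").convert("RGBA")
    # Passing the symbol as the mask keeps its alpha channel intact.
    base.paste(symbol, (x, y), symbol)
    x += symbol.width + spacing
base.save("composited.png")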
app/core/config.py
CHANGED
@@ -1,7 +1,6 @@
 from pathlib import Path
 from pydantic_settings import BaseSettings, SettingsConfigDict
 from typing import ClassVar
-import os

 # Determine base directory based on environment
 current_file = Path(__file__).resolve()
@@ -19,6 +18,7 @@ class Settings(BaseSettings):

     MODEL_PATH: str = str(_CARDSERVER_DIR_CLS / "models" / "lora-checkpoint")
     DEFAULT_MODEL_ID: str = "teknium/OpenHermes-2.5-Mistral-7B"
+    MODEL_LOAD_IN_4BIT: bool = True  # Default to True for 4-bit loading

     GENERATED_PATH: str = str(_APP_DEFAULT_STATIC_DIR_CLS / "images" / "generated")
     BASE_PATH: str = str(_APP_DEFAULT_STATIC_DIR_CLS / "images" / "base")
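Note: MODEL_LOAD_IN_4BIT is a plain pydantic-settings field, so it can be flipped per deployment through an environment variable of the same name without code changes. A minimal sketch, assuming default pydantic-settings behavior:

from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    MODEL_LOAD_IN_4BIT: bool = True  # read from env var MODEL_LOAD_IN_4BIT if set

settings = Settings()
# e.g. `MODEL_LOAD_IN_4BIT=false uvicorn app.main:app` would disable 4-bit loading
print(settings.MODEL_LOAD_IN_4BIT)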
app/core/constraints.py
CHANGED
@@ -1,34 +1,24 @@
 # constraints.py
 # Constraints: Begriffe prüfen, Filter
+from fastapi.concurrency import run_in_threadpool  # Importieren

 def check_constraints(output: str, terms: list[str]) -> bool:
     if not terms:  # Wenn keine Begriffe vorgegeben sind, ist die Bedingung immer erfüllt
         return True
     return all(term.lower() in output.lower() for term in terms)

-def generate_with_retry(prompt: str, generator, terms: list[str], max_retries: int = 3, generation_params: dict | None = None):
+async def generate_with_retry(prompt: str, generator, terms: list[str], max_retries: int = 3, generation_params: dict | None = None):  # async def
     """
     Generiert Text mit Wiederholungsversuchen, bis die Bedingungen (Constraints) erfüllt sind.
-
-    Args:
-        prompt (str): Der Eingabe-Prompt für den Generator.
-        generator: Die Text-Generierungs-Pipeline oder -Funktion.
-        terms (list[str]): Eine Liste von Begriffen, die im generierten Text enthalten sein müssen.
-        max_retries (int): Maximale Anzahl von Wiederholungsversuchen.
-        generation_params (dict | None): Zusätzliche Parameter für den Generator.
-
-    Returns:
-        str: Der generierte Text, der die Bedingungen erfüllt, oder eine Fehlermeldung.
+    Führt die eigentliche Generierung in einem Threadpool aus.
     """
     if generation_params is None:
         generation_params = {}

     for attempt in range(max_retries):
         try:
-            #
-
-            # Die meisten Hugging Face Pipelines erwarten den Prompt als positional argument.
-            responses = generator(prompt, **generation_params)
+            # Führe die blockierende Generator-Funktion im Threadpool aus
+            responses = await run_in_threadpool(generator, prompt, **generation_params)

             # Die Struktur der Antwort kann variieren. Üblich ist eine Liste von Diktionären.
             if responses and isinstance(responses, list) and responses[0].get("generated_text"):
@@ -44,6 +34,6 @@ def generate_with_retry(prompt: str, generator, terms: list[str], max_retries: i
             print(f"Fehler bei der Textgenerierung (Versuch {attempt + 1}/{max_retries}): {e}")
             # Optional: Kurze Pause vor dem nächsten Versuch
             # import time
-            #
+            # await asyncio.sleep(0.5)  # Wenn async, dann asyncio.sleep

     return "Leider konnte kein gültiger Text erzeugt werden."
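Note: the retry contract is easy to exercise in isolation: run the blocking generator through the threadpool, check the constraint, and loop. A self-contained sketch with a hypothetical fake_generator that only satisfies the terms on its second call:

import asyncio
from fastapi.concurrency import run_in_threadpool

def check_constraints(output: str, terms: list[str]) -> bool:
    return all(t.lower() in output.lower() for t in terms) if terms else True

calls = {"n": 0}

def fake_generator(prompt, **params):
    # Hypothetical pipeline: fails the constraint once, then succeeds.
    calls["n"] += 1
    text = "Liebe und Glück begleiten dich." if calls["n"] > 1 else "nichts Passendes"
    return [{"generated_text": text}]

async def main():
    for attempt in range(3):
        responses = await run_in_threadpool(fake_generator, "prompt")
        text = responses[0]["generated_text"]
        if check_constraints(text, ["Liebe", "Glück"]):
            print(f"ok after {attempt + 1} attempt(s): {text}")
            return
    print("Leider konnte kein gültiger Text erzeugt werden.")

asyncio.run(main())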
app/core/model_loader.py
CHANGED
@@ -6,20 +6,25 @@ from pathlib import Path
 import os
 from .config import settings
 from .hf_api import HuggingFaceWrapper
+from functools import lru_cache  # Import lru_cache

 logger = logging.getLogger(__name__)


-
+# Globale Variable für die Pipeline, um sie zwischenzuspeichern
+# _cached_generator_pipeline = None  # Entfernt, da wir lru_cache verwenden
+
+def load_model_and_tokenizer():  # Umbenannt und gibt jetzt model und tokenizer zurück
     """
     Optimierter Model Loader mit LoRA-Support.
+    Lädt Basismodell und Tokenizer.
     Kann LoRA-Adapter von Hugging Face Hub herunterladen.
     Automatische Konfiguration basierend auf verfügbaren Ressourcen.
     """
     base_model_id = settings.DEFAULT_MODEL_ID
     hf_token = os.getenv("HF_API_KEY")

-    logger.info(f"Lade Basismodell: {base_model_id}")
+    logger.info(f"Lade Basismodell und Tokenizer: {base_model_id}")

     try:
         tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=hf_token)
@@ -32,9 +37,9 @@
         tokenizer.pad_token = tokenizer.eos_token

         model_kwargs = {
-            "torch_dtype": torch.float16,
+            "torch_dtype": torch.float16,
             "device_map": "auto",
-            "trust_remote_code": True,
+            "trust_remote_code": True,
             "token": hf_token
         }

@@ -120,15 +125,46 @@
     else:
         logger.info("Keine LoRA-Gewichte zum Laden spezifiziert oder gefunden. Verwende Basismodell.")

-
+    return model, tokenizer
+
+
+@lru_cache(maxsize=None)  # Cache die Pipeline-Erstellung
+def get_generator():
+    """
+    Lädt das Modell und den Tokenizer (beim ersten Aufruf)
+    und erstellt eine Textgenerierungs-Pipeline.
+    Die Pipeline wird gecacht.
+    """
+    # global _cached_generator_pipeline  # Entfernt
+    # if _cached_generator_pipeline is None:  # Entfernt
+    logger.info("Initialisiere Textgenerierungs-Pipeline...")
+    model, tokenizer = load_model_and_tokenizer()  # Ruft die geänderte Funktion auf
+
+    # Sicherstellen, dass pad_token_id gesetzt ist, wenn es im Tokenizer existiert
+    # Dies ist wichtig für einige Modelle, um Warnungen oder Fehler zu vermeiden
+    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
+        logger.info(f"pad_token_id nicht im Tokenizer gefunden. Setze pad_token_id auf eos_token_id ({tokenizer.eos_token_id}).")
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        # Das Modell muss möglicherweise auch aktualisiert werden, wenn pad_token_id zur Laufzeit geändert wird
+        # Dies ist jedoch oft nicht notwendig, wenn das Modell bereits mit einem eos_token trainiert wurde.
+        # model.config.pad_token_id = tokenizer.pad_token_id
+
+    # Device für die Pipeline explizit setzen, falls nicht automatisch korrekt erkannt
+    # device = 0 if torch.cuda.is_available() else -1  # 0 für erste GPU, -1 für CPU
+    # Wenn device_map="auto" im Modell verwendet wird, sollte die Pipeline dies respektieren.
+    # Für explizite Kontrolle:
+    device = model.device  # Das Gerät des Modells verwenden
+
+    _cached_generator_pipeline = pipeline(
         "text-generation",
         model=model,
         tokenizer=tokenizer,
-
+        device=device  # Gerät explizit übergeben
     )
-
-
-
+    logger.info(f"Textgenerierungs-Pipeline erfolgreich initialisiert und auf Gerät {device} geladen.")
+    # else:  # Entfernt
+    #     logger.debug("Verwende gecachte Textgenerierungs-Pipeline.")  # Entfernt
+    return _cached_generator_pipeline


 def get_model_info():
@@ -142,11 +178,8 @@ def get_model_info():
         "gpu_count": torch.cuda.device_count() if torch.cuda.is_available() else 0
     }

-
-
-
-
-
-    if _generator is None:
-        _generator = load_model()
-    return _generator
+# Optional: Pre-load model at startup if desired (in main.py or similar)
+# def preload_model():
+#     logger.info("Starte Pre-Loading des Modells...")
+#     get_generator()
+#     logger.info("Modell erfolgreich vorab geladen.")
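Note: @lru_cache(maxsize=None) on a zero-argument function is an idiomatic process-wide singleton: the body runs once, and every later call returns the cached object. A minimal sketch of the behavior get_generator now relies on:

from functools import lru_cache

@lru_cache(maxsize=None)
def get_resource():
    print("expensive load runs exactly once")
    return {"pipeline": object()}

first = get_resource()   # triggers the load
second = get_resource()  # served from the cache
assert first is second

One caveat: lru_cache keys on arguments, so this only acts as a singleton while get_generator stays argument-free.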
app/main.py
CHANGED
@@ -3,6 +3,7 @@ from fastapi.staticfiles import StaticFiles
 from fastapi.middleware.cors import CORSMiddleware
 from .api.v1.endpoints import generate, download, health
 from .core.config import settings
+from .core.model_loader import get_generator  # Import get_generator
 from contextlib import asynccontextmanager
 from pathlib import Path
 import logging
@@ -20,20 +21,21 @@ settings.resolved_qr_code_path.mkdir(parents=True, exist_ok=True)
 # Ensure the static mount directory exists
 settings.resolved_static_files_mount_dir.mkdir(parents=True, exist_ok=True)

+# Lifecycle management for the model
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-
-    logger.info(
-
-
-
-
-
-
-
-    logger.info(f" Default Font Path: {settings.resolved_default_font_path}")
+    # Startup: Preload the model
+    logger.info("Anwendung startet... Lade das LLM-Modell vorab.")
+    try:
+        get_generator()  # Calls get_generator to load and cache the model
+        logger.info("LLM-Modell erfolgreich vorab geladen und Pipeline initialisiert.")
+    except Exception as e:
+        logger.error(f"Fehler beim Vorabladen des LLM-Modells: {e}", exc_info=True)
+        # Decide whether to prevent the application from starting
+        # raise  # Uncomment to prevent startup on error
     yield
-
+    # Shutdown: Cleanup actions could go here (not currently needed for the model)
+    logger.info("Anwendung wird heruntergefahren.")

 app = FastAPI(
     title=settings.PROJECT_NAME,
@@ -48,7 +50,7 @@ try:
     logger.info(f"Attempting to mount static directory: {static_dir}")
     logger.info(f"Static directory exists: {static_dir.exists()}")
     if static_dir.exists():
-        app.mount("/static", StaticFiles(directory=static_dir), name="static")
+        app.mount(f"{settings.API_PREFIX}/static", StaticFiles(directory=static_dir), name="static")
         logger.info("Static files mounted successfully")
     else:
         logger.warning(f"Static directory does not exist: {static_dir}")
@@ -57,7 +59,7 @@ try:
     # Create basic subdirectories
     (static_dir / "images").mkdir(exist_ok=True)
     (static_dir / "fonts").mkdir(exist_ok=True)
-    app.mount("/static", StaticFiles(directory=static_dir), name="static")
+    app.mount(f"{settings.API_PREFIX}/static", StaticFiles(directory=static_dir), name="static")
     logger.info("Static files mounted with created directory")
 except Exception as e:
     logger.error(f"Failed to mount static files: {e}")
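Note: the lifespan hook shifts the expensive first get_generator() call to application startup, so the first user request no longer pays the model-loading cost. A minimal sketch of the same shape; load_pipeline is a hypothetical stand-in:

from contextlib import asynccontextmanager
from fastapi import FastAPI

def load_pipeline():
    # Hypothetical stand-in for get_generator(); imagine seconds of work here.
    return object()

@asynccontextmanager
async def lifespan(app: FastAPI):
    app.state.pipeline = load_pipeline()  # startup: pay the cost once
    yield
    # shutdown: cleanup would go here (nothing needed for this model)

app = FastAPI(lifespan=lifespan)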
app/utils/qr_utils.py
CHANGED
@@ -1,8 +1,36 @@
 import qrcode
 from pathlib import Path
+import uuid  # Import uuid

-def
-
-
-
-
+def generate_qr_code_sync(data: str, output_path: Path, size: int) -> str:
+    """
+    Generiert einen QR-Code und speichert ihn.
+    Gibt die UUID der generierten Datei (ohne Erweiterung) zurück.
+    Der Parameter 'size' wird hier nicht direkt von qrcode.make verwendet,
+    aber die Standardeinstellungen sind oft ausreichend. Für eine exakte Pixelgröße
+    müsste man qrcode.QRCode mit box_size und border verwenden.
+    """
+    qr = qrcode.QRCode(
+        version=1,  # Standard
+        error_correction=qrcode.constants.ERROR_CORRECT_L,  # Standard
+        # box_size steuert die Pixel pro "Box" des QR-Codes.
+        # Um die Gesamtgröße (size) zu erreichen, müsste man box_size berechnen.
+        # Beispiel: box_size=size // (anzahl_module + 2 * border_module)
+        # Für Einfachheit lassen wir es bei den Defaults oder einem festen Wert.
+        box_size=10,  # Kann angepasst werden, um die 'size' besser zu treffen
+        border=4,  # Standard
+    )
+    qr.add_data(data)
+    qr.make(fit=True)
+
+    img = qr.make_image(fill_color="black", back_color="white")
+
+    # Sicherstellen, dass das Ausgabeverzeichnis existiert
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    file_id = str(uuid.uuid4())  # Eindeutige ID für die Datei
+    file_path_with_id = output_path / f"{file_id}.png"
+
+    img.save(file_path_with_id)
+
+    return file_id  # Gibt die UUID (ohne .png) zurück
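Note: as the new docstring says, qrcode's defaults ignore the requested pixel size unless box_size is derived from the module count. A sketch of that calculation, assuming the standard qrcode package (modules_count is populated after make()):

import qrcode

def make_qr_with_pixel_size(data: str, size: int):
    border = 4
    # Probe pass: learn how many modules this payload needs.
    probe = qrcode.QRCode(border=border)
    probe.add_data(data)
    probe.make(fit=True)
    # Total width in modules is modules_count + 2 * border; choose box_size
    # so the rendered image lands close to `size` pixels.
    box_size = max(1, size // (probe.modules_count + 2 * border))
    qr = qrcode.QRCode(box_size=box_size, border=border)
    qr.add_data(data)
    qr.make(fit=True)
    return qr.make_image(fill_color="black", back_color="white")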
scripts/cleanup_large_files.sh
DELETED
File without changes
scripts/quick_test.sh
DELETED
@@ -1,58 +0,0 @@
-#!/bin/bash
-# Quick test script for Museum Sexoskop App
-
-BASE_URL="https://ch404-cardserver.hf.space"
-echo "🧪 Testing Museum Sexoskop App at: $BASE_URL"
-echo "=============================================="
-
-# Function to test an endpoint
-test_endpoint() {
-    local name="$1"
-    local url="$2"
-    local method="$3"
-    local data="$4"
-
-    echo -e "\n🔍 Testing: $name"
-    echo "   URL: $url"
-
-    if [ "$method" = "POST" ]; then
-        response=$(curl -s -w "\n%{http_code}" -X POST "$url" \
-            -H "Content-Type: application/json" \
-            -d "$data")
-    else
-        response=$(curl -s -w "\n%{http_code}" "$url")
-    fi
-
-    # Split response and status code
-    status_code=$(echo "$response" | tail -1)
-    body=$(echo "$response" | sed '$d')
-
-    if [ "$status_code" = "200" ]; then
-        echo "   ✅ SUCCESS (200 OK)"
-        if command -v jq &> /dev/null; then
-            echo "$body" | jq . 2>/dev/null || echo "   📝 Response: $body"
-        else
-            echo "   📝 Response: $body"
-        fi
-    elif [ "$status_code" = "404" ]; then
-        echo "   ❌ NOT FOUND (404) - Space may not be running yet"
-    elif [ "$status_code" = "500" ]; then
-        echo "   ❌ SERVER ERROR (500) - Check space logs"
-    else
-        echo "   ⚠️ Status Code: $status_code"
-        echo "   📝 Response: $body"
-    fi
-}
-
-# Test endpoints
-test_endpoint "Health Check" "$BASE_URL/api/v1/health" "GET"
-
-test_endpoint "Root Endpoint" "$BASE_URL/" "GET"
-
-test_endpoint "Generate Horoscope" "$BASE_URL/api/v1/generate-horoscope" "POST" \
-    '{"terms": ["Test","Deploy","Success","Working","Happy"], "date_of_birth": "1990-01-01"}'
-
-echo -e "\n🏁 Testing complete!"
-echo "💡 If you see 404 errors, the space may still be deploying."
-echo "💡 If you see 500 errors, check the space logs on HuggingFace."
-echo "💡 Space URL: https://huggingface.co/spaces/ch404/cardserver"
scripts/start.sh
DELETED
@@ -1,24 +0,0 @@
-#!/bin/bash
-# Startup script for debugging HF Space deployment
-
-echo "🔍 DEBUG: Starting Museum Sexoskop App"
-echo "📁 Current directory: $(pwd)"
-echo "📁 Directory contents:"
-ls -la
-
-echo "🐍 Python version: $(python --version)"
-echo "📦 Installed packages:"
-pip list | grep -E "(fastapi|uvicorn|pydantic|pillow|qrcode|transformers|torch)"
-
-echo "📁 App directory structure:"
-find /app -type d -name "app" -o -name "static" -o -name "templates" | head -20
-
-echo "🔧 Testing configuration..."
-if [ -f "/app/tests/test_config.py" ]; then
-    python tests/test_config.py
-else
-    echo "❌ tests/test_config.py not found"
-fi
-
-echo "🚀 Starting FastAPI server..."
-exec uvicorn app.main:app --host 0.0.0.0 --port 7860 --log-level debug
scripts/setup_training.sh → setup_training.sh
RENAMED
@@ -1,23 +1,33 @@
 #!/bin/bash
 # Schnelles Setup-Script für LORA-Training

+set -e  # Exit immediately if a command exits with a non-zero status.
+
+# Determine the absolute path of the script (which is also the project root)
+PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
 echo "🚀 Setup für kostengünstiges LORA-Training"
 echo "=========================================="
+echo "🔧 Script wird ausgeführt von: $PROJECT_ROOT"
+echo "📂 Projekt-Root-Verzeichnis: $PROJECT_ROOT"

 # 1. Virtuelle Umgebung erstellen (optional)
-echo "📦 Erstelle virtuelle Umgebung..."
-python3 -m venv venv_training
-
+echo "📦 Erstelle virtuelle Umgebung in '$PROJECT_ROOT/venv_training'..."
+python3 -m venv "$PROJECT_ROOT/venv_training"
+echo "✨ Virtuelle Umgebung '$PROJECT_ROOT/venv_training' erstellt."
+echo "👉 Zum manuellen Aktivieren für spätere Sitzungen: source '$PROJECT_ROOT/venv_training/bin/activate'"
+source "$PROJECT_ROOT/venv_training/bin/activate"
+echo "✅ Virtuelle Umgebung für diese Skript-Sitzung aktiviert."

 # 2. Requirements installieren
 echo "📥 Installiere Training-Dependencies..."
-pip install -r training_requirements.txt
-pip install -r
+pip install -r "$PROJECT_ROOT/training/training_requirements.txt"
+pip install -r "$PROJECT_ROOT/requirements.txt"  # General project requirements from root

 # 3. Training-Ordner vorbereiten
 echo "📁 Erstelle Training-Struktur..."
-mkdir -p
-mkdir -p data
+mkdir -p "$PROJECT_ROOT/models/lora-checkpoint"
+mkdir -p "$PROJECT_ROOT/training/data"  # For training data

 # 4. GPU-Check
 echo "🔍 GPU-Verfügbarkeit prüfen..."
@@ -27,11 +37,18 @@ echo ""
 echo "✅ Setup abgeschlossen!"
 echo ""
 echo "🎯 Nächste Schritte:"
-echo "1. Trainingsdaten in data/ ablegen (JSON-Format)"
-echo "2.
-echo "
+echo "1. Trainingsdaten in '$PROJECT_ROOT/training/data/' ablegen (JSON-Format, z.B. cards_training_data.json)."
+echo "2. Sicherstellen, dass die virtuelle Umgebung aktiv ist. Falls nicht, aktivieren mit:"
+echo "   source '$PROJECT_ROOT/venv_training/bin/activate'"
+echo "3. Zum Training-Verzeichnis wechseln und Training starten:"
+echo "   cd '$PROJECT_ROOT/training/'"
+echo "   python train_lora.py"
+echo "4. Geschätzter Speicherbedarf: ~8-12GB RAM + ~4GB VRAM"
 echo ""
 echo "💰 Kostenoptimierung:"
 echo "- Lokales Training: 0€ (nur Stromkosten)"
 echo "- Cloud-Alternative: Google Colab Pro (~10€/Monat)"
 echo "- Training-Zeit: ~2-4 Stunden je nach Datenmenge"
+echo ""
+echo "💡 Die virtuelle Umgebung '$PROJECT_ROOT/venv_training' ist derzeit in dieser Shell-Sitzung aktiv."
+echo "   Um sie zu verlassen (deaktivieren), tippe: deactivate"
static/images/base/base0.png
ADDED
static/images/base/base1.png
ADDED
static/images/base/base10.png
ADDED
static/images/base/base11.png
ADDED
static/images/base/base2.png
ADDED
static/images/base/base3.png
ADDED
static/images/base/base4.png
ADDED
static/images/base/base5.png
ADDED
static/images/base/base6.png
ADDED
static/images/base/base7.png
ADDED
static/images/base/base8.png
ADDED
static/images/base/base9.png
ADDED
static/images/symbols/symbol0.png
ADDED
static/images/symbols/symbol1.png
ADDED
static/images/symbols/symbol2.png
ADDED
static/images/symbols/symbol3.png
ADDED
static/images/symbols/symbol4.png
ADDED
static/images/symbols/symbol5.png
ADDED
static/images/symbols/symbol6.png
ADDED
static/images/symbols/symbol7.png
ADDED