Spaces:
Sleeping
Sleeping
Upload 16 files
Browse files
src/writing_studio/core/config.py
CHANGED
|
@@ -43,7 +43,6 @@ class Settings(BaseSettings):
|
|
| 43 |
|
| 44 |
# Model Configuration
|
| 45 |
default_model: str = Field(default="distilgpt2", description="Default HuggingFace model")
|
| 46 |
-
model_cache_dir: str = Field(default="./models", description="Model cache directory")
|
| 47 |
max_model_length: int = Field(default=512, ge=1, description="Maximum model input length")
|
| 48 |
default_max_length: int = Field(default=300, ge=1, description="Default generation length")
|
| 49 |
default_num_sequences: int = Field(default=1, ge=1, description="Number of sequences")
|
|
@@ -91,7 +90,7 @@ class Settings(BaseSettings):
|
|
| 91 |
return [origin.strip() for origin in v.split(",") if origin.strip()]
|
| 92 |
return v
|
| 93 |
|
| 94 |
-
@field_validator("model_cache_dir", "log_file_path")
|
| 95 |
@classmethod
|
| 96 |
def ensure_directory_exists(cls, v: str) -> str:
|
| 97 |
"""Ensure directory exists for file paths."""
|
|
|
|
| 43 |
|
| 44 |
# Model Configuration
|
| 45 |
default_model: str = Field(default="distilgpt2", description="Default HuggingFace model")
|
|
|
|
| 46 |
max_model_length: int = Field(default=512, ge=1, description="Maximum model input length")
|
| 47 |
default_max_length: int = Field(default=300, ge=1, description="Default generation length")
|
| 48 |
default_num_sequences: int = Field(default=1, ge=1, description="Number of sequences")
|
|
|
|
| 90 |
return [origin.strip() for origin in v.split(",") if origin.strip()]
|
| 91 |
return v
|
| 92 |
|
| 93 |
+
@field_validator("log_file_path")
|
| 94 |
@classmethod
|
| 95 |
def ensure_directory_exists(cls, v: str) -> str:
|
| 96 |
"""Ensure directory exists for file paths."""
|
src/writing_studio/services/model_service.py
CHANGED
|
@@ -58,10 +58,10 @@ class ModelService:
|
|
| 58 |
start_time = time.time()
|
| 59 |
|
| 60 |
# Load model with error handling
|
|
|
|
| 61 |
self._current_model = pipeline(
|
| 62 |
"text-generation",
|
| 63 |
model=model_name,
|
| 64 |
-
cache_dir=settings.model_cache_dir,
|
| 65 |
)
|
| 66 |
self._current_model_name = model_name
|
| 67 |
|
|
@@ -119,13 +119,14 @@ class ModelService:
|
|
| 119 |
logger.info(f"Generating text with model: {self._current_model_name}")
|
| 120 |
start_time = time.time()
|
| 121 |
|
| 122 |
-
# Generate text
|
| 123 |
result = self._current_model(
|
| 124 |
prompt,
|
| 125 |
max_length=params["max_length"],
|
| 126 |
num_return_sequences=params["num_sequences"],
|
| 127 |
do_sample=True,
|
| 128 |
temperature=params["temperature"],
|
|
|
|
| 129 |
)
|
| 130 |
|
| 131 |
generated_text = result[0]["generated_text"]
|
|
|
|
| 58 |
start_time = time.time()
|
| 59 |
|
| 60 |
# Load model with error handling
|
| 61 |
+
# Note: cache_dir is handled automatically by transformers
|
| 62 |
self._current_model = pipeline(
|
| 63 |
"text-generation",
|
| 64 |
model=model_name,
|
|
|
|
| 65 |
)
|
| 66 |
self._current_model_name = model_name
|
| 67 |
|
|
|
|
| 119 |
logger.info(f"Generating text with model: {self._current_model_name}")
|
| 120 |
start_time = time.time()
|
| 121 |
|
| 122 |
+
# Generate text with proper parameters
|
| 123 |
result = self._current_model(
|
| 124 |
prompt,
|
| 125 |
max_length=params["max_length"],
|
| 126 |
num_return_sequences=params["num_sequences"],
|
| 127 |
do_sample=True,
|
| 128 |
temperature=params["temperature"],
|
| 129 |
+
pad_token_id=self._current_model.tokenizer.eos_token_id, # Avoid warnings
|
| 130 |
)
|
| 131 |
|
| 132 |
generated_text = result[0]["generated_text"]
|