jmisak committed on
Commit
aec570d
·
verified ·
1 Parent(s): f6e7587

Upload 16 files

Browse files
src/writing_studio/core/config.py CHANGED
@@ -43,7 +43,6 @@ class Settings(BaseSettings):
43
 
44
  # Model Configuration
45
  default_model: str = Field(default="distilgpt2", description="Default HuggingFace model")
46
- model_cache_dir: str = Field(default="./models", description="Model cache directory")
47
  max_model_length: int = Field(default=512, ge=1, description="Maximum model input length")
48
  default_max_length: int = Field(default=300, ge=1, description="Default generation length")
49
  default_num_sequences: int = Field(default=1, ge=1, description="Number of sequences")
@@ -91,7 +90,7 @@ class Settings(BaseSettings):
91
  return [origin.strip() for origin in v.split(",") if origin.strip()]
92
  return v
93
 
94
- @field_validator("model_cache_dir", "log_file_path")
95
  @classmethod
96
  def ensure_directory_exists(cls, v: str) -> str:
97
  """Ensure directory exists for file paths."""
 
43
 
44
  # Model Configuration
45
  default_model: str = Field(default="distilgpt2", description="Default HuggingFace model")
 
46
  max_model_length: int = Field(default=512, ge=1, description="Maximum model input length")
47
  default_max_length: int = Field(default=300, ge=1, description="Default generation length")
48
  default_num_sequences: int = Field(default=1, ge=1, description="Number of sequences")
 
90
  return [origin.strip() for origin in v.split(",") if origin.strip()]
91
  return v
92
 
93
+ @field_validator("log_file_path")
94
  @classmethod
95
  def ensure_directory_exists(cls, v: str) -> str:
96
  """Ensure directory exists for file paths."""
src/writing_studio/services/model_service.py CHANGED
@@ -58,10 +58,10 @@ class ModelService:
58
  start_time = time.time()
59
 
60
  # Load model with error handling
 
61
  self._current_model = pipeline(
62
  "text-generation",
63
  model=model_name,
64
- cache_dir=settings.model_cache_dir,
65
  )
66
  self._current_model_name = model_name
67
 
@@ -119,13 +119,14 @@ class ModelService:
119
  logger.info(f"Generating text with model: {self._current_model_name}")
120
  start_time = time.time()
121
 
122
- # Generate text
123
  result = self._current_model(
124
  prompt,
125
  max_length=params["max_length"],
126
  num_return_sequences=params["num_sequences"],
127
  do_sample=True,
128
  temperature=params["temperature"],
 
129
  )
130
 
131
  generated_text = result[0]["generated_text"]
 
58
  start_time = time.time()
59
 
60
  # Load model with error handling
61
+ # Note: cache_dir is handled automatically by transformers
62
  self._current_model = pipeline(
63
  "text-generation",
64
  model=model_name,
 
65
  )
66
  self._current_model_name = model_name
67
 
 
119
  logger.info(f"Generating text with model: {self._current_model_name}")
120
  start_time = time.time()
121
 
122
+ # Generate text with proper parameters
123
  result = self._current_model(
124
  prompt,
125
  max_length=params["max_length"],
126
  num_return_sequences=params["num_sequences"],
127
  do_sample=True,
128
  temperature=params["temperature"],
129
+ pad_token_id=self._current_model.tokenizer.eos_token_id, # Avoid warnings
130
  )
131
 
132
  generated_text = result[0]["generated_text"]