jmisak committed on
Commit
f277022
·
verified ·
1 Parent(s): 2e32647

Upload 19 files

Browse files
src/writing_studio/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (305 Bytes). View file
 
src/writing_studio/core/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (210 Bytes). View file
 
src/writing_studio/core/__pycache__/config.cpython-312.pyc ADDED
Binary file (6.07 kB). View file
 
src/writing_studio/core/analyzer.py CHANGED
@@ -67,16 +67,30 @@ class WritingAnalyzer:
67
  logger.info(f"Loading new model: {model_name}")
68
  self.model_service.load_model(model_name)
69
 
70
- # Generate prompt
71
  prompt = self.prompt_service.generate_prompt(user_text, prompt_pack)
72
 
73
- # Generate revision
 
74
  with generation_duration.time():
75
- revision = self.model_service.generate_text(prompt)
76
-
77
- # Extract only the revised part (after "Revised Text:")
78
- if "Revised Text:" in revision:
79
- revision = revision.split("Revised Text:")[-1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  # Analyze with rubric
82
  rubric_results = self.rubric_service.analyze_text(user_text)
 
67
  logger.info(f"Loading new model: {model_name}")
68
  self.model_service.load_model(model_name)
69
 
70
+ # Generate prompt using selected pack
71
  prompt = self.prompt_service.generate_prompt(user_text, prompt_pack)
72
 
73
+ # Generate AI revision
74
+ logger.info("Generating AI revision...")
75
  with generation_duration.time():
76
+ revision = self.model_service.generate_text(
77
+ prompt,
78
+ max_length=min(len(user_text.split()) * 2 + 100, settings.max_model_length),
79
+ use_cache=True
80
+ )
81
+
82
+ # Clean up revision (remove any prompt artifacts)
83
+ if prompt_pack in revision:
84
+ revision = revision.split(prompt_pack)[-1].strip()
85
+ if "Revised text:" in revision:
86
+ revision = revision.split("Revised text:")[-1].strip()
87
+ if user_text in revision:
88
+ # Model might include original text, extract just the revision
89
+ revision = revision.replace(user_text, "").strip()
90
+
91
+ # If revision is empty or too similar, provide a note
92
+ if not revision or revision == user_text:
93
+ revision = user_text + "\n\n[Note: The AI model kept the text as-is, suggesting it's already well-written!]"
94
 
95
  # Analyze with rubric
96
  rubric_results = self.rubric_service.analyze_text(user_text)
src/writing_studio/core/config.py CHANGED
@@ -42,9 +42,12 @@ class Settings(BaseSettings):
42
  server_workers: int = Field(default=4, ge=1, description="Number of worker processes")
43
 
44
  # Model Configuration
45
- default_model: str = Field(default="distilgpt2", description="Default HuggingFace model")
 
 
 
46
  max_model_length: int = Field(default=512, ge=1, description="Maximum model input length")
47
- default_max_length: int = Field(default=300, ge=1, description="Default generation length")
48
  default_num_sequences: int = Field(default=1, ge=1, description="Number of sequences")
49
 
50
  # Security
 
42
  server_workers: int = Field(default=4, ge=1, description="Number of worker processes")
43
 
44
  # Model Configuration
45
+ default_model: str = Field(
46
+ default="google/flan-t5-base",
47
+ description="Default HuggingFace model (instruction-tuned for revision)"
48
+ )
49
  max_model_length: int = Field(default=512, ge=1, description="Maximum model input length")
50
+ default_max_length: int = Field(default=512, ge=1, description="Default generation length")
51
  default_num_sequences: int = Field(default=1, ge=1, description="Number of sequences")
52
 
53
  # Security
src/writing_studio/services/model_service.py CHANGED
@@ -5,7 +5,7 @@ import time
5
  from functools import lru_cache
6
  from typing import Any, Dict, Optional
7
 
8
- from transformers import pipeline
9
 
10
  from writing_studio.core.config import settings
11
  from writing_studio.core.exceptions import ModelLoadError, TextGenerationError
@@ -20,6 +20,7 @@ class ModelService:
20
  """Initialize the model service."""
21
  self._current_model: Optional[Any] = None
22
  self._current_model_name: Optional[str] = None
 
23
  self._cache: Dict[str, Any] = {}
24
  self._load_default_model()
25
 
@@ -57,13 +58,24 @@ class ModelService:
57
  logger.info(f"Loading model: {model_name}")
58
  start_time = time.time()
59
 
 
 
 
 
 
 
 
 
 
 
60
  # Load model with error handling
61
- # Note: cache_dir is handled automatically by transformers
62
  self._current_model = pipeline(
63
- "text-generation",
64
  model=model_name,
 
65
  )
66
  self._current_model_name = model_name
 
67
 
68
  load_time = time.time() - start_time
69
  logger.info(f"Model loaded successfully in {load_time:.2f}s: {model_name}")
@@ -119,19 +131,32 @@ class ModelService:
119
  logger.info(f"Generating text with model: {self._current_model_name}")
120
  start_time = time.time()
121
 
122
- # Generate text with proper parameters
123
- result = self._current_model(
124
- prompt,
125
- max_length=params["max_length"],
126
- num_return_sequences=params["num_sequences"],
127
- do_sample=True,
128
- temperature=params["temperature"],
129
- pad_token_id=self._current_model.tokenizer.eos_token_id, # Avoid warnings
130
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
- generated_text = result[0]["generated_text"]
133
  generation_time = time.time() - start_time
134
-
135
  logger.info(f"Text generated in {generation_time:.2f}s")
136
 
137
  # Cache result if enabled
 
5
  from functools import lru_cache
6
  from typing import Any, Dict, Optional
7
 
8
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
9
 
10
  from writing_studio.core.config import settings
11
  from writing_studio.core.exceptions import ModelLoadError, TextGenerationError
 
20
  """Initialize the model service."""
21
  self._current_model: Optional[Any] = None
22
  self._current_model_name: Optional[str] = None
23
+ self._task_type: str = "text2text-generation" # Default for FLAN-T5
24
  self._cache: Dict[str, Any] = {}
25
  self._load_default_model()
26
 
 
58
  logger.info(f"Loading model: {model_name}")
59
  start_time = time.time()
60
 
61
+ # Detect model type and use appropriate pipeline
62
+ # FLAN-T5, T5 = text2text-generation
63
+ # GPT-2, GPT = text-generation
64
+ if any(x in model_name.lower() for x in ['t5', 'flan']):
65
+ task = "text2text-generation"
66
+ logger.info(f"Detected instruction-following model, using {task} pipeline")
67
+ else:
68
+ task = "text-generation"
69
+ logger.info(f"Detected text generation model, using {task} pipeline")
70
+
71
  # Load model with error handling
 
72
  self._current_model = pipeline(
73
+ task,
74
  model=model_name,
75
+ max_length=settings.max_model_length,
76
  )
77
  self._current_model_name = model_name
78
+ self._task_type = task
79
 
80
  load_time = time.time() - start_time
81
  logger.info(f"Model loaded successfully in {load_time:.2f}s: {model_name}")
 
131
  logger.info(f"Generating text with model: {self._current_model_name}")
132
  start_time = time.time()
133
 
134
+ # Generate text with parameters appropriate for model type
135
+ if self._task_type == "text2text-generation":
136
+ # T5/FLAN-T5 models
137
+ result = self._current_model(
138
+ prompt,
139
+ max_new_tokens=params["max_length"],
140
+ num_return_sequences=params["num_sequences"],
141
+ do_sample=True,
142
+ temperature=params["temperature"],
143
+ truncation=True,
144
+ )
145
+ # T5 models return generated_text directly
146
+ generated_text = result[0]["generated_text"]
147
+ else:
148
+ # GPT-2 style models
149
+ result = self._current_model(
150
+ prompt,
151
+ max_length=params["max_length"],
152
+ num_return_sequences=params["num_sequences"],
153
+ do_sample=True,
154
+ temperature=params["temperature"],
155
+ pad_token_id=self._current_model.tokenizer.eos_token_id,
156
+ )
157
+ generated_text = result[0]["generated_text"]
158
 
 
159
  generation_time = time.time() - start_time
 
160
  logger.info(f"Text generated in {generation_time:.2f}s")
161
 
162
  # Cache result if enabled
src/writing_studio/services/prompt_service.py CHANGED
@@ -9,27 +9,28 @@ class PromptService:
9
  """Service for managing and generating prompts."""
10
 
11
  def __init__(self):
12
- """Initialize the prompt service with templates."""
 
13
  self.prompt_packs = {
14
  "General": {
15
- "instruction": "Revise this text for clarity, conciseness, and audience fit",
16
- "context": "Focus on improving overall readability and effectiveness.",
17
  },
18
  "Literature": {
19
- "instruction": "Revise this literary analysis with attention to theme, style, and evidence",
20
- "context": "Ensure proper use of literary terminology and textual support.",
21
  },
22
  "Tech Comm": {
23
  "instruction": "Revise this technical document for precision, clarity, and professional tone",
24
- "context": "Emphasize accuracy, clear instructions, and appropriate technical level.",
25
  },
26
  "Academic": {
27
- "instruction": "Revise this academic writing for formal tone, organization, and scholarly support",
28
- "context": "Maintain formal register and ensure proper citation indicators.",
29
  },
30
  "Creative": {
31
- "instruction": "Revise this creative writing with focus on imagery, voice, and engagement",
32
- "context": "Enhance descriptive language and narrative flow.",
33
  },
34
  }
35
 
@@ -46,6 +47,8 @@ class PromptService:
46
  """
47
  Generate a complete prompt from user text and pack template.
48
 
 
 
49
  Args:
50
  user_text: User's input text
51
  pack_name: Name of the prompt pack to use
@@ -60,14 +63,8 @@ class PromptService:
60
  pack = self.prompt_packs[pack_name]
61
  logger.info(f"Generating prompt with pack: {pack_name}")
62
 
63
- prompt = f"""{pack['instruction']}.
64
-
65
- Context: {pack['context']}
66
-
67
- Original Text:
68
- {user_text}
69
-
70
- Revised Text:"""
71
 
72
  return prompt
73
 
 
9
  """Service for managing and generating prompts."""
10
 
11
  def __init__(self):
12
+ """Initialize the prompt service with templates for instruction-following models."""
13
+ # Optimized for FLAN-T5 and other instruction-tuned models
14
  self.prompt_packs = {
15
  "General": {
16
+ "instruction": "Revise the following text to improve clarity, conciseness, and readability",
17
+ "context": "Make it clear and easy to understand while maintaining the original meaning.",
18
  },
19
  "Literature": {
20
+ "instruction": "Revise this literary analysis to strengthen the argument with better evidence and literary terminology",
21
+ "context": "Enhance academic rigor and use of textual support.",
22
  },
23
  "Tech Comm": {
24
  "instruction": "Revise this technical document for precision, clarity, and professional tone",
25
+ "context": "Make it accurate, clear, and appropriate for technical audiences.",
26
  },
27
  "Academic": {
28
+ "instruction": "Revise this academic writing to improve formal tone, organization, and scholarly voice",
29
+ "context": "Ensure formal register and proper academic style.",
30
  },
31
  "Creative": {
32
+ "instruction": "Revise this creative writing to enhance imagery, voice, and reader engagement",
33
+ "context": "Improve descriptive language and narrative flow.",
34
  },
35
  }
36
 
 
47
  """
48
  Generate a complete prompt from user text and pack template.
49
 
50
+ Optimized for instruction-following models like FLAN-T5.
51
+
52
  Args:
53
  user_text: User's input text
54
  pack_name: Name of the prompt pack to use
 
63
  pack = self.prompt_packs[pack_name]
64
  logger.info(f"Generating prompt with pack: {pack_name}")
65
 
66
+ # Format optimized for FLAN-T5 and similar instruction-tuned models
67
+ prompt = f"{pack['instruction']}. {pack['context']}\n\nText: {user_text}\n\nRevised text:"
 
 
 
 
 
 
68
 
69
  return prompt
70