Spaces:
Runtime error
Runtime error
Create static/main.py
Browse files- static/main.py +470 -0
static/main.py
ADDED
|
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, HTTPException
|
| 2 |
+
from fastapi.staticfiles import StaticFiles
|
| 3 |
+
from fastapi.responses import HTMLResponse
|
| 4 |
+
from pydantic import BaseModel
|
| 5 |
+
import mysql.connector
|
| 6 |
+
import os
|
| 7 |
+
import requests
|
| 8 |
+
import json
|
| 9 |
+
from urllib.parse import urlparse
|
| 10 |
+
from fastapi.responses import StreamingResponse
|
| 11 |
+
from transformers import pipeline
|
| 12 |
+
from sentence_transformers import SentenceTransformer
|
| 13 |
+
import nltk
|
| 14 |
+
import numpy as np
|
| 15 |
+
import hashlib
|
| 16 |
+
|
| 17 |
+
# Suppress FutureWarnings for cleaner logs
|
| 18 |
+
import warnings
|
| 19 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 20 |
+
|
| 21 |
+
# Set NLTK data path to a writable directory
|
| 22 |
+
NLTK_DATA_PATH = os.getenv("NLTK_DATA", "/app/nltk_data")
|
| 23 |
+
os.makedirs(NLTK_DATA_PATH, exist_ok=True)
|
| 24 |
+
nltk.data.path.append(NLTK_DATA_PATH)
|
| 25 |
+
|
| 26 |
+
# Download NLTK punkt data if not already present
|
| 27 |
+
try:
|
| 28 |
+
nltk.data.find('tokenizers/punkt')
|
| 29 |
+
except LookupError:
|
| 30 |
+
try:
|
| 31 |
+
nltk.download('punkt', download_dir=NLTK_DATA_PATH)
|
| 32 |
+
except Exception as e:
|
| 33 |
+
print(f"Failed to download NLTK punkt data: {e}")
|
| 34 |
+
|
| 35 |
+
# Set Hugging Face cache directory
|
| 36 |
+
os.environ["TRANSFORMERS_CACHE"] = os.getenv("TRANSFORMERS_CACHE", "/app/hf_cache")
|
| 37 |
+
os.makedirs(os.environ["TRANSFORMERS_CACHE"], exist_ok=True)
|
| 38 |
+
|
| 39 |
+
app = FastAPI()
|
| 40 |
+
|
| 41 |
+
# Mount static files (frontend)
|
| 42 |
+
app.mount("/static", StaticFiles(directory="static"), name="static")
|
| 43 |
+
|
| 44 |
+
# Database configuration from DATABASE_URL
|
| 45 |
+
database_url = os.getenv("DATABASE_URL")
|
| 46 |
+
if database_url:
|
| 47 |
+
parsed = urlparse(database_url)
|
| 48 |
+
db_config = {
|
| 49 |
+
"host": parsed.hostname,
|
| 50 |
+
"user": parsed.username,
|
| 51 |
+
"password": parsed.password,
|
| 52 |
+
"database": parsed.path.lstrip("/"),
|
| 53 |
+
"port": parsed.port or 3306
|
| 54 |
+
}
|
| 55 |
+
else:
|
| 56 |
+
db_config = {
|
| 57 |
+
"host": "novelcrafter-novelcrafter.g.aivencloud.com",
|
| 58 |
+
"user": "avnadmin",
|
| 59 |
+
"password": "AVNS_6uHxC3wASDZVJ_PxQXJ",
|
| 60 |
+
"database": "defaultdb",
|
| 61 |
+
"port": 12221
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
# OpenRouter API configuration
|
| 65 |
+
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
|
| 66 |
+
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"
|
| 67 |
+
|
| 68 |
+
# Hugging Face model setup
|
| 69 |
+
SUMMARIZER_MODEL = "facebook/bart-large-cnn"
|
| 70 |
+
EXCERPT_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
|
| 71 |
+
USE_INFERENCE_API = os.getenv("USE_HF_INFERENCE_API", "false").lower() == "true"
|
| 72 |
+
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
|
| 73 |
+
|
| 74 |
+
if USE_INFERENCE_API:
|
| 75 |
+
from huggingface_hub import InferenceClient
|
| 76 |
+
summarizer = InferenceClient(model=SUMMARIZER_MODEL, token=HF_API_TOKEN)
|
| 77 |
+
else:
|
| 78 |
+
summarizer = pipeline("summarization", model=SUMMARIZER_MODEL, device=-1) # CPU
|
| 79 |
+
|
| 80 |
+
excerpt_model = SentenceTransformer(EXCERPT_MODEL)
|
| 81 |
+
|
| 82 |
+
# Pydantic models for request validation
|
| 83 |
+
class ProseRequest(BaseModel):
|
| 84 |
+
model: str
|
| 85 |
+
manuscript: str = ""
|
| 86 |
+
outline: str = ""
|
| 87 |
+
characters: str = ""
|
| 88 |
+
prompt: str = ""
|
| 89 |
+
project_id: str = "default"
|
| 90 |
+
manuscript_mode: str = "summary" # Options: "full", "summary", "excerpts"
|
| 91 |
+
|
| 92 |
+
# Maximum characters for full manuscript mode (approx. 37,500 tokens)
|
| 93 |
+
MAX_MANUSCRIPT_CHARS = 50000
|
| 94 |
+
# Maximum characters for MySQL MEDIUMTEXT (safety threshold)
|
| 95 |
+
MAX_MEDIUMTEXT_CHARS = 1000000
|
| 96 |
+
|
| 97 |
+
# Initialize database tables with MEDIUMTEXT and ensure schema compatibility
|
| 98 |
+
def init_db():
|
| 99 |
+
try:
|
| 100 |
+
conn = mysql.connector.connect(**db_config)
|
| 101 |
+
cursor = conn.cursor()
|
| 102 |
+
|
| 103 |
+
# Create inputs table if it doesn't exist
|
| 104 |
+
cursor.execute("""
|
| 105 |
+
CREATE TABLE IF NOT EXISTS inputs (
|
| 106 |
+
id INT AUTO_INCREMENT PRIMARY KEY,
|
| 107 |
+
project_id VARCHAR(255),
|
| 108 |
+
manuscript MEDIUMTEXT,
|
| 109 |
+
outline MEDIUMTEXT,
|
| 110 |
+
characters MEDIUMTEXT,
|
| 111 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 112 |
+
)
|
| 113 |
+
""")
|
| 114 |
+
|
| 115 |
+
# Check and update inputs table schema
|
| 116 |
+
cursor.execute("SHOW COLUMNS FROM inputs")
|
| 117 |
+
columns = {row[0]: row[1].lower() for row in cursor.fetchall()}
|
| 118 |
+
|
| 119 |
+
# Add manuscript_summary and manuscript_excerpts if missing
|
| 120 |
+
if 'manuscript_summary' not in columns:
|
| 121 |
+
cursor.execute("ALTER TABLE inputs ADD manuscript_summary MEDIUMTEXT")
|
| 122 |
+
if 'manuscript_excerpts' not in columns:
|
| 123 |
+
cursor.execute("ALTER TABLE inputs ADD manuscript_excerpts MEDIUMTEXT")
|
| 124 |
+
|
| 125 |
+
# Ensure existing columns are MEDIUMTEXT
|
| 126 |
+
text_columns = ['manuscript', 'outline', 'characters']
|
| 127 |
+
for col in text_columns:
|
| 128 |
+
if col in columns and 'mediumtext' not in columns[col]:
|
| 129 |
+
cursor.execute(f"ALTER TABLE inputs MODIFY {col} MEDIUMTEXT")
|
| 130 |
+
|
| 131 |
+
# Create prompt_history table if it doesn't exist
|
| 132 |
+
cursor.execute("""
|
| 133 |
+
CREATE TABLE IF NOT EXISTS prompt_history (
|
| 134 |
+
id INT AUTO_INCREMENT PRIMARY KEY,
|
| 135 |
+
project_id VARCHAR(255),
|
| 136 |
+
prompt MEDIUMTEXT,
|
| 137 |
+
response MEDIUMTEXT,
|
| 138 |
+
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
| 139 |
+
)
|
| 140 |
+
""")
|
| 141 |
+
|
| 142 |
+
# Check and update prompt_history table schema
|
| 143 |
+
cursor.execute("SHOW COLUMNS FROM prompt_history")
|
| 144 |
+
columns = {row[0]: row[1].lower() for row in cursor.fetchall()}
|
| 145 |
+
for col in ['prompt', 'response']:
|
| 146 |
+
if col in columns and 'mediumtext' not in columns[col]:
|
| 147 |
+
cursor.execute(f"ALTER TABLE prompt_history MODIFY {col} MEDIUMTEXT")
|
| 148 |
+
|
| 149 |
+
conn.commit()
|
| 150 |
+
except mysql.connector.Error as e:
|
| 151 |
+
print(f"Database initialization error: {e}")
|
| 152 |
+
raise
|
| 153 |
+
finally:
|
| 154 |
+
cursor.close()
|
| 155 |
+
conn.close()
|
| 156 |
+
|
| 157 |
+
# Call init_db on startup
|
| 158 |
+
@app.on_event("startup")
|
| 159 |
+
def startup_event():
|
| 160 |
+
init_db()
|
| 161 |
+
|
| 162 |
+
# Generate manuscript summary using Hugging Face model
|
| 163 |
+
def generate_manuscript_summary(manuscript: str):
|
| 164 |
+
if not manuscript:
|
| 165 |
+
return ""
|
| 166 |
+
|
| 167 |
+
# Split manuscript into chunks to avoid memory issues
|
| 168 |
+
try:
|
| 169 |
+
sentences = nltk.sent_tokenize(manuscript)
|
| 170 |
+
except Exception as e:
|
| 171 |
+
print(f"Sentence tokenization error: {e}")
|
| 172 |
+
return ""
|
| 173 |
+
|
| 174 |
+
chunks = []
|
| 175 |
+
current_chunk = ""
|
| 176 |
+
for sentence in sentences:
|
| 177 |
+
if len(current_chunk) + len(sentence) < 4000: # BART max length ~1024 tokens
|
| 178 |
+
current_chunk += sentence + " "
|
| 179 |
+
else:
|
| 180 |
+
chunks.append(current_chunk.strip())
|
| 181 |
+
current_chunk = sentence + " "
|
| 182 |
+
if current_chunk:
|
| 183 |
+
chunks.append(current_chunk.strip())
|
| 184 |
+
|
| 185 |
+
# Summarize each chunk
|
| 186 |
+
summaries = []
|
| 187 |
+
for chunk in chunks:
|
| 188 |
+
try:
|
| 189 |
+
if USE_INFERENCE_API:
|
| 190 |
+
summary = summarizer.summarization(
|
| 191 |
+
chunk,
|
| 192 |
+
max_length=200,
|
| 193 |
+
min_length=50,
|
| 194 |
+
do_sample=False
|
| 195 |
+
)
|
| 196 |
+
summaries.append(summary[0]['summary_text'])
|
| 197 |
+
else:
|
| 198 |
+
summary = summarizer(chunk, max_length=200, min_length=50, do_sample=False)
|
| 199 |
+
summaries.append(summary[0]['summary_text'])
|
| 200 |
+
except Exception as e:
|
| 201 |
+
print(f"Summary generation error: {e}")
|
| 202 |
+
summaries.append("")
|
| 203 |
+
|
| 204 |
+
# Combine and summarize again if too long
|
| 205 |
+
combined_summary = " ".join([s for s in summaries if s])
|
| 206 |
+
if len(combined_summary) > 1000:
|
| 207 |
+
try:
|
| 208 |
+
if USE_INFERENCE_API:
|
| 209 |
+
final_summary = summarizer.summarization(
|
| 210 |
+
combined_summary,
|
| 211 |
+
max_length=1000,
|
| 212 |
+
min_length=500,
|
| 213 |
+
do_sample=False
|
| 214 |
+
)[0]['summary_text']
|
| 215 |
+
else:
|
| 216 |
+
final_summary = summarizer(
|
| 217 |
+
combined_summary,
|
| 218 |
+
max_length=1000,
|
| 219 |
+
min_length=500,
|
| 220 |
+
do_sample=False
|
| 221 |
+
)[0]['summary_text']
|
| 222 |
+
return final_summary
|
| 223 |
+
except Exception as e:
|
| 224 |
+
print(f"Final summary error: {e}")
|
| 225 |
+
return combined_summary[:1000]
|
| 226 |
+
return combined_summary
|
| 227 |
+
|
| 228 |
+
# Generate manuscript excerpts using semantic similarity
|
| 229 |
+
def generate_manuscript_excerpts(manuscript: str, prompt: str):
|
| 230 |
+
if not manuscript or not prompt:
|
| 231 |
+
return ""
|
| 232 |
+
|
| 233 |
+
# Split manuscript into sections (e.g., paragraphs)
|
| 234 |
+
sections = [s.strip() for s in manuscript.split("\n\n") if s.strip()]
|
| 235 |
+
if not sections:
|
| 236 |
+
return ""
|
| 237 |
+
|
| 238 |
+
# Encode prompt and sections
|
| 239 |
+
try:
|
| 240 |
+
prompt_embedding = excerpt_model.encode(prompt)
|
| 241 |
+
section_embeddings = excerpt_model.encode(sections)
|
| 242 |
+
|
| 243 |
+
# Compute cosine similarities
|
| 244 |
+
similarities = np.dot(section_embeddings, prompt_embedding) / (
|
| 245 |
+
np.linalg.norm(section_embeddings, axis=1) * np.linalg.norm(prompt_embedding)
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# Select top 2 sections
|
| 249 |
+
top_indices = similarities.argsort()[-2:][::-1]
|
| 250 |
+
selected_sections = [sections[i] for i in top_indices]
|
| 251 |
+
excerpts = "\n\n".join(selected_sections)
|
| 252 |
+
|
| 253 |
+
# Truncate to ~2,000 words (~8,000 chars)
|
| 254 |
+
return excerpts[:8000]
|
| 255 |
+
except Exception as e:
|
| 256 |
+
print(f"Excerpt generation error: {e}")
|
| 257 |
+
return ""
|
| 258 |
+
|
| 259 |
+
# Save inputs to MySQL with summary/excerpts
|
| 260 |
+
def save_inputs(project_id: str, manuscript: str, outline: str, characters: str, prompt: str):
|
| 261 |
+
# Truncate inputs to avoid exceeding MEDIUMTEXT limits
|
| 262 |
+
if len(manuscript) > MAX_MEDIUMTEXT_CHARS:
|
| 263 |
+
manuscript = manuscript[:MAX_MEDIUMTEXT_CHARS]
|
| 264 |
+
print(f"Manuscript truncated to {MAX_MEDIUMTEXT_CHARS} characters")
|
| 265 |
+
if len(outline) > MAX_MEDIUMTEXT_CHARS:
|
| 266 |
+
outline = outline[:MAX_MEDIUMTEXT_CHARS]
|
| 267 |
+
print(f"Outline truncated to {MAX_MEDIUMTEXT_CHARS} characters")
|
| 268 |
+
if len(characters) > MAX_MEDIUMTEXT_CHARS:
|
| 269 |
+
characters = characters[:MAX_MEDIUMTEXT_CHARS]
|
| 270 |
+
print(f"Characters truncated to {MAX_MEDIUMTEXT_CHARS} characters")
|
| 271 |
+
|
| 272 |
+
# Check if manuscript is unchanged to avoid redundant processing
|
| 273 |
+
manuscript_hash = hashlib.md5(manuscript.encode()).hexdigest()
|
| 274 |
+
inputs = get_latest_inputs(project_id)
|
| 275 |
+
if inputs.get("manuscript") and hashlib.md5(inputs["manuscript"].encode()).hexdigest() == manuscript_hash:
|
| 276 |
+
summary = inputs["manuscript_summary"]
|
| 277 |
+
excerpts = inputs["manuscript_excerpts"]
|
| 278 |
+
else:
|
| 279 |
+
summary = generate_manuscript_summary(manuscript)
|
| 280 |
+
excerpts = generate_manuscript_excerpts(manuscript, prompt)
|
| 281 |
+
|
| 282 |
+
try:
|
| 283 |
+
conn = mysql.connector.connect(**db_config)
|
| 284 |
+
cursor = conn.cursor()
|
| 285 |
+
cursor.execute(
|
| 286 |
+
"INSERT INTO inputs (project_id, manuscript, manuscript_summary, manuscript_excerpts, outline, characters) VALUES (%s, %s, %s, %s, %s, %s)",
|
| 287 |
+
(project_id, manuscript, summary, excerpts, outline, characters)
|
| 288 |
+
)
|
| 289 |
+
conn.commit()
|
| 290 |
+
except mysql.connector.Error as e:
|
| 291 |
+
print(f"Database save error: {e}")
|
| 292 |
+
raise HTTPException(status_code=500, detail=f"Failed to save inputs: {str(e)}")
|
| 293 |
+
finally:
|
| 294 |
+
cursor.close()
|
| 295 |
+
conn.close()
|
| 296 |
+
|
| 297 |
+
# Save prompt and response to history
|
| 298 |
+
def save_prompt_history(project_id: str, prompt: str, response: str):
|
| 299 |
+
if len(prompt) > MAX_MEDIUMTEXT_CHARS:
|
| 300 |
+
prompt = prompt[:MAX_MEDIUMTEXT_CHARS]
|
| 301 |
+
print(f"Prompt truncated to {MAX_MEDIUMTEXT_CHARS} characters")
|
| 302 |
+
if len(response) > MAX_MEDIUMTEXT_CHARS:
|
| 303 |
+
response = response[:MAX_MEDIUMTEXT_CHARS]
|
| 304 |
+
print(f"Response truncated to {MAX_MEDIUMTEXT_CHARS} characters")
|
| 305 |
+
|
| 306 |
+
try:
|
| 307 |
+
conn = mysql.connector.connect(**db_config)
|
| 308 |
+
cursor = conn.cursor()
|
| 309 |
+
cursor.execute(
|
| 310 |
+
"INSERT INTO prompt_history (project_id, prompt, response) VALUES (%s, %s, %s)",
|
| 311 |
+
(project_id, prompt, response)
|
| 312 |
+
)
|
| 313 |
+
conn.commit()
|
| 314 |
+
except mysql.connector.Error as e:
|
| 315 |
+
print(f"Prompt history save error: {e}")
|
| 316 |
+
raise HTTPException(status_code=500, detail=f"Failed to save prompt history: {str(e)}")
|
| 317 |
+
finally:
|
| 318 |
+
cursor.close()
|
| 319 |
+
conn.close()
|
| 320 |
+
|
| 321 |
+
# Retrieve recent prompt history for context
|
| 322 |
+
def get_prompt_history(project_id: str, limit: int = 2):
|
| 323 |
+
try:
|
| 324 |
+
conn = mysql.connector.connect(**db_config)
|
| 325 |
+
cursor = conn.cursor()
|
| 326 |
+
cursor.execute(
|
| 327 |
+
"SELECT prompt, response FROM prompt_history WHERE project_id = %s ORDER BY created_at DESC LIMIT %s",
|
| 328 |
+
(project_id, limit)
|
| 329 |
+
)
|
| 330 |
+
history = cursor.fetchall()
|
| 331 |
+
cursor.close()
|
| 332 |
+
conn.close()
|
| 333 |
+
return [{"prompt": p, "response": r} for p, r in history]
|
| 334 |
+
except mysql.connector.Error as e:
|
| 335 |
+
print(f"Prompt history retrieval error: {e}")
|
| 336 |
+
return []
|
| 337 |
+
|
| 338 |
+
# Retrieve latest inputs for a project
|
| 339 |
+
def get_latest_inputs(project_id: str):
|
| 340 |
+
try:
|
| 341 |
+
conn = mysql.connector.connect(**db_config)
|
| 342 |
+
cursor = conn.cursor()
|
| 343 |
+
cursor.execute(
|
| 344 |
+
"SELECT manuscript, manuscript_summary, manuscript_excerpts, outline, characters FROM inputs WHERE project_id = %s ORDER BY created_at DESC LIMIT 1",
|
| 345 |
+
(project_id,)
|
| 346 |
+
)
|
| 347 |
+
result = cursor.fetchone()
|
| 348 |
+
cursor.close()
|
| 349 |
+
conn.close()
|
| 350 |
+
return {
|
| 351 |
+
"manuscript": result[0] if result else None,
|
| 352 |
+
"manuscript_summary": result[1] if result else None,
|
| 353 |
+
"manuscript_excerpts": result[2] if result else None,
|
| 354 |
+
"outline": result[3] if result else None,
|
| 355 |
+
"characters": result[4] if result else None
|
| 356 |
+
}
|
| 357 |
+
except mysql.connector.Error as e:
|
| 358 |
+
print(f"Inputs retrieval error: {e}")
|
| 359 |
+
return {}
|
| 360 |
+
|
| 361 |
+
# Generate prose with OpenRouter API
|
| 362 |
+
async def generate_prose_stream(request: ProseRequest):
|
| 363 |
+
# Save full inputs to MySQL with summary/excerpts
|
| 364 |
+
save_inputs(request.project_id, request.manuscript, request.outline, request.characters, request.prompt)
|
| 365 |
+
|
| 366 |
+
# Get latest inputs
|
| 367 |
+
inputs = get_latest_inputs(request.project_id)
|
| 368 |
+
|
| 369 |
+
# Select manuscript content based on mode
|
| 370 |
+
if request.manuscript_mode == "full":
|
| 371 |
+
manuscript_content = (inputs.get("manuscript") or "")[:MAX_MANUSCRIPT_CHARS]
|
| 372 |
+
manuscript_label = "Manuscript Pages (Truncated)"
|
| 373 |
+
elif request.manuscript_mode == "excerpts":
|
| 374 |
+
manuscript_content = inputs.get("manuscript_excerpts") or ""
|
| 375 |
+
manuscript_label = "Manuscript Excerpts"
|
| 376 |
+
else: # summary
|
| 377 |
+
manuscript_content = inputs.get("manuscript_summary") or ""
|
| 378 |
+
manuscript_label = "Manuscript Summary"
|
| 379 |
+
|
| 380 |
+
# Get recent prompt history for context
|
| 381 |
+
history = get_prompt_history(request.project_id)
|
| 382 |
+
history_context = "\n".join(
|
| 383 |
+
[f"Previous Prompt: {h['prompt']}\nPrevious Response: {h['response']}" for h in history]
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
# Construct system prompt
|
| 387 |
+
system_prompt = f"""
|
| 388 |
+
You are a creative writing assistant tasked with generating novel prose. Use the following inputs and prompt history to create coherent, engaging prose that aligns with the provided manuscript content, outline, and character descriptions. If the manuscript content is a summary or excerpts, rely on the outline and characters for additional context. Ensure the tone and style match the provided context and follow the specific request in the prompt.
|
| 389 |
+
|
| 390 |
+
**{manuscript_label}**:
|
| 391 |
+
{manuscript_content or "No manuscript content provided."}
|
| 392 |
+
|
| 393 |
+
**Outline**:
|
| 394 |
+
{request.outline or "No outline provided."}
|
| 395 |
+
|
| 396 |
+
**Character Descriptions**:
|
| 397 |
+
{request.characters or "No character descriptions provided."}
|
| 398 |
+
|
| 399 |
+
**Prompt History**:
|
| 400 |
+
{history_context or "No prompt history available."}
|
| 401 |
+
|
| 402 |
+
**Specific Request**:
|
| 403 |
+
{request.prompt or "Generate a continuation of the manuscript with creative prose."}
|
| 404 |
+
"""
|
| 405 |
+
|
| 406 |
+
headers = {
|
| 407 |
+
"Authorization": f"Bearer {OPENROUTER_API_KEY}",
|
| 408 |
+
"Content-Type": "application/json",
|
| 409 |
+
"HTTP-Referer": "https://huggingface.co/spaces/NoLev/NovelCrafter",
|
| 410 |
+
"X-Title": "Novel Prose Generator"
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
payload = {
|
| 414 |
+
"model": request.model,
|
| 415 |
+
"messages": [
|
| 416 |
+
{"role": "system", "content": system_prompt},
|
| 417 |
+
{"role": "user", "content": request.prompt or "Generate novel prose based on the provided inputs."}
|
| 418 |
+
],
|
| 419 |
+
"stream": True,
|
| 420 |
+
"temperature": 0.7,
|
| 421 |
+
"max_tokens": 1000
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
try:
|
| 425 |
+
response = requests.post(OPENROUTER_API_URL, headers=headers, json=payload, stream=True)
|
| 426 |
+
if response.status_code != 200:
|
| 427 |
+
raise HTTPException(status_code=response.status_code, detail="Error from OpenRouter API")
|
| 428 |
+
except Exception as e:
|
| 429 |
+
raise HTTPException(status_code=500, detail=f"OpenRouter API request failed: {str(e)}")
|
| 430 |
+
|
| 431 |
+
# Stream response and collect full response for caching
|
| 432 |
+
full_response = ""
|
| 433 |
+
async def stream_response():
|
| 434 |
+
nonlocal full_response
|
| 435 |
+
for line in response.iter_lines():
|
| 436 |
+
if line:
|
| 437 |
+
decoded_line = line.decode('utf-8')
|
| 438 |
+
if decoded_line.startswith("data: "):
|
| 439 |
+
data = decoded_line[6:]
|
| 440 |
+
if data == "[DONE]":
|
| 441 |
+
continue
|
| 442 |
+
try:
|
| 443 |
+
json_data = json.loads(data)
|
| 444 |
+
content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "")
|
| 445 |
+
if content:
|
| 446 |
+
full_response += content
|
| 447 |
+
yield content
|
| 448 |
+
except json.JSONDecodeError:
|
| 449 |
+
continue
|
| 450 |
+
# Save prompt and response to history after streaming
|
| 451 |
+
save_prompt_history(request.project_id, request.prompt, full_response)
|
| 452 |
+
|
| 453 |
+
return StreamingResponse(stream_response(), media_type="text/plain")
|
| 454 |
+
|
| 455 |
+
# API endpoint to generate prose
|
| 456 |
+
@app.post("/generate")
|
| 457 |
+
async def generate_prose(request: ProseRequest):
|
| 458 |
+
return await generate_prose_stream(request)
|
| 459 |
+
|
| 460 |
+
# API endpoint to retrieve latest inputs
|
| 461 |
+
@app.get("/inputs/{project_id}")
|
| 462 |
+
async def get_inputs(project_id: str):
|
| 463 |
+
inputs = get_latest_inputs(project_id)
|
| 464 |
+
return inputs
|
| 465 |
+
|
| 466 |
+
# Serve the frontend
|
| 467 |
+
@app.get("/")
|
| 468 |
+
async def serve_index():
|
| 469 |
+
with open("static/index.html", "r") as f:
|
| 470 |
+
return HTMLResponse(content=f.read())
|