# NovelCrafter2 / main.py — Hugging Face Space by NoLev (revision 59006f9)
# Standard library
import hashlib
import json
import os
import secrets
from urllib.parse import urlparse

# Third-party
import mysql.connector
import nltk
import numpy as np
import requests
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# Suppress FutureWarnings for cleaner logs
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
# Set NLTK data path to a writable directory (container filesystems are
# typically only writable under /app).
NLTK_DATA_PATH = os.getenv("NLTK_DATA", "/app/nltk_data")
os.makedirs(NLTK_DATA_PATH, exist_ok=True)
nltk.data.path.append(NLTK_DATA_PATH)
# Download NLTK punkt data if not already present. Failure is non-fatal and
# only logged; sentence tokenization will then fail (and be caught) at call
# time in the summary/prompt helpers.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    try:
        nltk.download('punkt', download_dir=NLTK_DATA_PATH)
    except Exception as e:
        print(f"Failed to download NLTK punkt data: {e}")
# Set Hugging Face cache directory.
# NOTE(review): `transformers` is imported at the top of this file; some
# versions read TRANSFORMERS_CACHE at import time, so setting it here may be
# too late to take effect — confirm, or set it before the import.
os.environ["TRANSFORMERS_CACHE"] = os.getenv("TRANSFORMERS_CACHE", "/app/hf_cache")
os.makedirs(os.environ["TRANSFORMERS_CACHE"], exist_ok=True)
# Shared-access password for the UI gate (see /check_password).
# NOTE(review): the fallback is a weak, well-known default — consider failing
# fast when PASSWORD_SECRET is unset.
PASSWORD_SECRET = os.getenv("PASSWORD_SECRET", "default_password")
app = FastAPI()
# Mount static files (frontend)
app.mount("/static", StaticFiles(directory="static"), name="static")
# Database configuration from DATABASE_URL
# (expected shape: mysql://user:password@host:port/database)
database_url = os.getenv("DATABASE_URL")
if database_url:
    parsed = urlparse(database_url)
    db_config = {
        "host": parsed.hostname,
        "user": parsed.username,
        "password": parsed.password,
        "database": parsed.path.lstrip("/"),
        "port": parsed.port or 3306
    }
else:
    # SECURITY NOTE(review): live-looking database credentials are hardcoded
    # here and committed to source control. This password should be rotated
    # and the fallback removed in favor of DATABASE_URL / env configuration.
    db_config = {
        "host": "novelcrafter-novelcrafter.g.aivencloud.com",
        "user": "avnadmin",
        "password": "AVNS_6uHxC3wASDZVJ_PxQXJ",
        "database": "defaultdb",
        "port": 12221
    }
# OpenRouter API configuration (API key must be provided via environment;
# requests will fail with an invalid Authorization header if unset).
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"
# Hugging Face model setup
SUMMARIZER_MODEL = "facebook/bart-large-cnn"  # abstractive summarization
EXCERPT_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # sentence embeddings
# When true, summarization calls are delegated to the hosted HF Inference API
# instead of running the BART model locally on CPU.
USE_INFERENCE_API = os.getenv("USE_HF_INFERENCE_API", "false").lower() == "true"
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
if USE_INFERENCE_API:
    from huggingface_hub import InferenceClient
    summarizer = InferenceClient(model=SUMMARIZER_MODEL, token=HF_API_TOKEN)
else:
    summarizer = pipeline("summarization", model=SUMMARIZER_MODEL, device=-1)  # CPU
# The embedding model always runs locally (used for excerpts and prompts).
excerpt_model = SentenceTransformer(EXCERPT_MODEL)
# Pydantic models for request validation
class ProseRequest(BaseModel):
    """Request body for POST /generate (prose generation)."""
    model: str                        # OpenRouter model identifier
    manuscript: str = ""              # full manuscript text (may be empty)
    outline: str = ""                 # story outline text
    characters: str = ""              # character descriptions, "Name: desc" per line
    prompt: str = ""                  # user's specific writing request
    project_id: str = "default"       # namespace for stored inputs/history
    manuscript_mode: str = "summary"  # Options: "full", "summary", "excerpts"
class PromptRequest(BaseModel):
    """Request body for POST /generate_prompts; empty fields fall back to
    the latest saved inputs for the project."""
    project_id: str = "default"
    manuscript: str = ""
    outline: str = ""
    characters: str = ""
class PasswordRequest(BaseModel):
    """Request body for POST /check_password."""
    password: str  # compared against PASSWORD_SECRET
class ProseSaveRequest(BaseModel):
    """Request body for POST /save_prose."""
    project_id: str
    prose: str  # prose text to persist in prose_history
# Maximum characters for full manuscript mode (approx. 37,500 tokens);
# applied when manuscript_mode == "full" in generate_prose_stream.
MAX_MANUSCRIPT_CHARS = 50000
# Maximum characters for MySQL MEDIUMTEXT (safety threshold); every text
# field is clipped to this before INSERT.
MAX_MEDIUMTEXT_CHARS = 1000000
# Initialize database tables with MEDIUMTEXT and ensure schema compatibility
def init_db() -> None:
    """Create or migrate the MySQL tables used by the app.

    Ensures the `inputs`, `prompt_history` and `prose_history` tables exist,
    adds columns introduced by later versions of the schema, and widens any
    legacy columns to MEDIUMTEXT so large manuscripts fit.

    Raises:
        mysql.connector.Error: re-raised after logging so startup fails loudly.
    """
    # Pre-bind so the finally block is safe even when connect() itself fails
    # (the original raised NameError there, masking the real error).
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(**db_config)
        cursor = conn.cursor()
        # inputs: one row per save; the latest row per project_id is current.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS inputs (
                id INT AUTO_INCREMENT PRIMARY KEY,
                project_id VARCHAR(255),
                manuscript MEDIUMTEXT,
                outline MEDIUMTEXT,
                characters MEDIUMTEXT,
                generated_prompts MEDIUMTEXT,
                manuscript_summary MEDIUMTEXT,
                manuscript_excerpts MEDIUMTEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        # Migrate older deployments: add columns that were introduced later.
        cursor.execute("SHOW COLUMNS FROM inputs")
        columns = {row[0]: row[1].lower() for row in cursor.fetchall()}
        if 'manuscript_summary' not in columns:
            cursor.execute("ALTER TABLE inputs ADD manuscript_summary MEDIUMTEXT")
        if 'manuscript_excerpts' not in columns:
            cursor.execute("ALTER TABLE inputs ADD manuscript_excerpts MEDIUMTEXT")
        if 'generated_prompts' not in columns:
            cursor.execute("ALTER TABLE inputs ADD generated_prompts MEDIUMTEXT")
        # Widen any legacy text columns to MEDIUMTEXT.
        text_columns = ['manuscript', 'outline', 'characters', 'generated_prompts', 'manuscript_summary', 'manuscript_excerpts']
        for col in text_columns:
            if col in columns and 'mediumtext' not in columns[col]:
                cursor.execute(f"ALTER TABLE inputs MODIFY {col} MEDIUMTEXT")
        # prompt_history: prompt/response pairs, reused as LLM context.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS prompt_history (
                id INT AUTO_INCREMENT PRIMARY KEY,
                project_id VARCHAR(255),
                prompt MEDIUMTEXT,
                response MEDIUMTEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        cursor.execute("SHOW COLUMNS FROM prompt_history")
        columns = {row[0]: row[1].lower() for row in cursor.fetchall()}
        for col in ['prompt', 'response']:
            if col in columns and 'mediumtext' not in columns[col]:
                cursor.execute(f"ALTER TABLE prompt_history MODIFY {col} MEDIUMTEXT")
        # prose_history: snapshots of generated/approved prose.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS prose_history (
                id INT AUTO_INCREMENT PRIMARY KEY,
                project_id VARCHAR(255),
                prose MEDIUMTEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        cursor.execute("SHOW COLUMNS FROM prose_history")
        columns = {row[0]: row[1].lower() for row in cursor.fetchall()}
        if 'prose' in columns and 'mediumtext' not in columns['prose']:
            cursor.execute("ALTER TABLE prose_history MODIFY prose MEDIUMTEXT")
        conn.commit()
    except mysql.connector.Error as e:
        print(f"Database initialization error: {e}")
        raise
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
# Call init_db on startup
@app.on_event("startup")
def startup_event():
    """Run DB schema creation/migration once when the server starts.

    NOTE(review): on_event is the legacy FastAPI startup hook; the lifespan
    API is the newer equivalent — confirm the installed FastAPI version
    before migrating.
    """
    init_db()
# Generate manuscript summary using Hugging Face model
def generate_manuscript_summary(manuscript: str) -> str:
    """Summarize a manuscript with BART (local pipeline or HF Inference API).

    Packs sentences into ~4,000-character chunks, summarizes each chunk,
    joins the chunk summaries, and re-summarizes the combined text once more
    if it still exceeds 1,000 characters.

    Returns:
        The summary string, or "" for empty input / tokenizer failure.
    """
    if not manuscript:
        return ""
    try:
        sentences = nltk.sent_tokenize(manuscript)
    except Exception as e:
        print(f"Sentence tokenization error: {e}")
        return ""
    # Greedily pack whole sentences into chunks under ~4,000 characters so
    # each chunk fits the summarizer's input window.
    chunks = []
    current_chunk = ""
    for sentence in sentences:
        if len(current_chunk) + len(sentence) < 4000:
            current_chunk += sentence + " "
        else:
            chunks.append(current_chunk.strip())
            current_chunk = sentence + " "
    if current_chunk:
        chunks.append(current_chunk.strip())
    summaries = []
    for chunk in chunks:
        try:
            # Scale the target summary length to the chunk's word count so
            # max_length never exceeds roughly half the input.
            input_length = len(chunk.split())
            max_length = min(200, input_length // 2) if input_length > 0 else 50
            min_length = max(10, max_length // 4)
            if USE_INFERENCE_API:
                summary = summarizer.summarization(
                    chunk,
                    max_length=max_length,
                    min_length=min_length,
                    do_sample=False
                )
                # NOTE(review): depending on the huggingface_hub version,
                # summarization() may return an output object rather than a
                # list — confirm summary[0]['summary_text'] is valid here.
                summaries.append(summary[0]['summary_text'])
            else:
                summary = summarizer(chunk, max_length=max_length, min_length=min_length, do_sample=False)
                summaries.append(summary[0]['summary_text'])
        except Exception as e:
            # A failed chunk contributes an empty string instead of aborting.
            print(f"Summary generation error: {e}")
            summaries.append("")
    combined_summary = " ".join([s for s in summaries if s])
    # Second pass: compress the combined summary if it is still too long.
    if len(combined_summary) > 1000:
        try:
            input_length = len(combined_summary.split())
            max_length = min(1000, input_length // 2) if input_length > 0 else 500
            min_length = max(100, max_length // 2)
            if USE_INFERENCE_API:
                final_summary = summarizer.summarization(
                    combined_summary,
                    max_length=max_length,
                    min_length=min_length,
                    do_sample=False
                )[0]['summary_text']
            else:
                final_summary = summarizer(
                    combined_summary,
                    max_length=max_length,
                    min_length=min_length,
                    do_sample=False
                )[0]['summary_text']
            return final_summary
        except Exception as e:
            # Fall back to a hard character cut if re-summarization fails.
            print(f"Final summary error: {e}")
            return combined_summary[:1000]
    return combined_summary
# Generate manuscript excerpts using semantic similarity
def generate_manuscript_excerpts(manuscript: str, prompt: str):
    """Pick the two manuscript sections most semantically similar to the prompt.

    Splits the manuscript on blank lines, scores every section against the
    prompt with cosine similarity, and returns the top two (highest first)
    joined by blank lines, capped at 8,000 characters.

    Returns "" when either input is missing, when no sections exist, or on
    any embedding failure.
    """
    if not manuscript or not prompt:
        return ""
    paragraphs = [part.strip() for part in manuscript.split("\n\n") if part.strip()]
    if not paragraphs:
        return ""
    try:
        query_vec = excerpt_model.encode(prompt)
        para_vecs = excerpt_model.encode(paragraphs)
        # Cosine similarity of each paragraph against the prompt.
        scores = np.dot(para_vecs, query_vec) / (
            np.linalg.norm(para_vecs, axis=1) * np.linalg.norm(query_vec)
        )
        best = scores.argsort()[-2:][::-1]
        chosen = [paragraphs[idx] for idx in best]
        return "\n\n".join(chosen)[:8000]
    except Exception as e:
        print(f"Excerpt generation error: {e}")
        return ""
# Generate prompts for continuing the novel
def generate_prompts(outline: str, characters: str, manuscript_summary: str):
    """Build up to three concrete writing prompts from the outline.

    For each of the first five outline sentences, picks a character
    (round-robin over "Name: description" lines) and phrases the prompt as a
    continuation when the sentence is semantically close to the manuscript
    summary (cosine similarity > 0.5), or as a fresh scene otherwise.

    Returns:
        A list of 1-3 prompt strings; a single explanatory string when
        neither outline nor characters were provided.
    """
    if not outline and not characters:
        return ["No outline or character descriptions provided to generate prompts."]
    try:
        sentences = nltk.sent_tokenize(outline)
    except Exception as e:
        print(f"Outline tokenization error: {e}")
        sentences = [outline] if outline else []
    prompts = []
    # Character names come from lines shaped like "Name: description".
    character_lines = [line.strip() for line in characters.split("\n") if line.strip()]
    characters_list = [line.split(":")[0].strip() for line in character_lines if ":" in line]
    # The summary embedding is loop-invariant; encode it once instead of once
    # per outline sentence (the original re-encoded it on every iteration).
    summary_embedding = excerpt_model.encode(manuscript_summary) if manuscript_summary else None
    summary_norm = float(np.linalg.norm(summary_embedding)) if summary_embedding is not None else 0.0
    for i, section in enumerate(sentences[:5]):
        if section:
            section_embedding = excerpt_model.encode(section)
            if summary_embedding is None:
                # Matches the original's zero-vector behavior: similarity 0.
                similarity = 0.0
            else:
                similarity = np.dot(section_embedding, summary_embedding) / (
                    np.linalg.norm(section_embedding) * (summary_norm if summary_norm else 1)
                )
            character = characters_list[i % len(characters_list)] if characters_list else "a key character"
            if similarity > 0.5 and manuscript_summary:
                prompt = f"Continue the manuscript by writing a scene that advances the plot point: '{section}'. Focus on {character} and ensure the scene builds on the existing manuscript's tone and context."
            else:
                prompt = f"Write a new scene for the novel based on the outline point: '{section}'. Feature {character} prominently and maintain consistency with the provided character descriptions and outline."
            prompts.append(prompt)
    # Fallbacks when the outline produced no usable sentences.
    if not prompts and characters_list:
        prompts.append(f"Write a scene that continues the novel, featuring {characters_list[0]} and advancing the story based on the provided outline and character descriptions.")
    elif not prompts:
        prompts.append("Continue the novel with a new scene that advances the story based on the provided outline.")
    return prompts[:3]
# Save inputs to MySQL with summary/excerpts/prompts
def save_inputs(project_id: str, manuscript: str, outline: str, characters: str, prompt: str, generated_prompts: str = "") -> None:
    """Persist a snapshot of user inputs, reusing cached summary/excerpts.

    Every text field is clipped to the MEDIUMTEXT safety limit. When the
    manuscript is unchanged since the last save (MD5 match), the previously
    stored summary/excerpts are reused instead of re-running the models.

    Raises:
        HTTPException: 500 when the INSERT fails.
    """
    def _clip(text: str, label: str) -> str:
        # Keep values under the MEDIUMTEXT safety threshold.
        if len(text) > MAX_MEDIUMTEXT_CHARS:
            print(f"{label} truncated to {MAX_MEDIUMTEXT_CHARS} characters")
            return text[:MAX_MEDIUMTEXT_CHARS]
        return text

    manuscript = _clip(manuscript, "Manuscript")
    outline = _clip(outline, "Outline")
    characters = _clip(characters, "Characters")
    generated_prompts = _clip(generated_prompts, "Generated prompts")
    # MD5 is used only as a cheap change detector here, not for security.
    manuscript_hash = hashlib.md5(manuscript.encode()).hexdigest()
    inputs = get_latest_inputs(project_id)
    if inputs.get("manuscript") and hashlib.md5(inputs["manuscript"].encode()).hexdigest() == manuscript_hash:
        summary = inputs["manuscript_summary"]
        excerpts = inputs["manuscript_excerpts"]
    else:
        summary = generate_manuscript_summary(manuscript)
        excerpts = generate_manuscript_excerpts(manuscript, prompt)
    # Pre-bind so the finally block is safe when connect() itself fails
    # (the original raised NameError there, masking the real error).
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(**db_config)
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO inputs (project_id, manuscript, manuscript_summary, manuscript_excerpts, outline, characters, generated_prompts) VALUES (%s, %s, %s, %s, %s, %s, %s)",
            (project_id, manuscript, summary, excerpts, outline, characters, generated_prompts)
        )
        conn.commit()
    except mysql.connector.Error as e:
        print(f"Database save error: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to save inputs: {str(e)}")
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
# Save prompt and response to history
def save_prompt_history(project_id: str, prompt: str, response: str) -> None:
    """Append a prompt/response pair to prompt_history.

    Both fields are clipped to the MEDIUMTEXT safety limit before insert.

    Raises:
        HTTPException: 500 when the INSERT fails.
    """
    if len(prompt) > MAX_MEDIUMTEXT_CHARS:
        prompt = prompt[:MAX_MEDIUMTEXT_CHARS]
        print(f"Prompt truncated to {MAX_MEDIUMTEXT_CHARS} characters")
    if len(response) > MAX_MEDIUMTEXT_CHARS:
        response = response[:MAX_MEDIUMTEXT_CHARS]
        print(f"Response truncated to {MAX_MEDIUMTEXT_CHARS} characters")
    # Pre-bind so the finally block is safe when connect() itself fails
    # (the original raised NameError there, masking the real error).
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(**db_config)
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO prompt_history (project_id, prompt, response) VALUES (%s, %s, %s)",
            (project_id, prompt, response)
        )
        conn.commit()
    except mysql.connector.Error as e:
        print(f"Prompt history save error: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to save prompt history: {str(e)}")
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
# Save generated prose to history
def save_prose(project_id: str, prose: str) -> None:
    """Append one generated-prose snapshot to prose_history.

    The prose is clipped to the MEDIUMTEXT safety limit before insert.

    Raises:
        HTTPException: 500 when the INSERT fails.
    """
    if len(prose) > MAX_MEDIUMTEXT_CHARS:
        prose = prose[:MAX_MEDIUMTEXT_CHARS]
        print(f"Prose truncated to {MAX_MEDIUMTEXT_CHARS} characters")
    # Pre-bind so the finally block is safe when connect() itself fails
    # (the original raised NameError there, masking the real error).
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(**db_config)
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO prose_history (project_id, prose) VALUES (%s, %s)",
            (project_id, prose)
        )
        conn.commit()
    except mysql.connector.Error as e:
        print(f"Prose save error: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to save prose: {str(e)}")
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
# Retrieve recent prompt history for context
def get_prompt_history(project_id: str, limit: int = 2):
    """Return the `limit` most recent prompt/response pairs for a project.

    Returns [] on any database error — history is optional LLM context, so
    failures here are non-fatal.
    """
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(**db_config)
        cursor = conn.cursor()
        cursor.execute(
            "SELECT prompt, response FROM prompt_history WHERE project_id = %s ORDER BY created_at DESC LIMIT %s",
            (project_id, limit)
        )
        history = cursor.fetchall()
        return [{"prompt": p, "response": r} for p, r in history]
    except mysql.connector.Error as e:
        print(f"Prompt history retrieval error: {e}")
        return []
    finally:
        # Close even when execute/fetch raises (the original only closed on
        # the success path and leaked the connection on error).
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
# Retrieve latest inputs for a project
def get_latest_inputs(project_id: str):
    """Fetch the most recent saved inputs row for a project.

    Returns:
        A dict with the six stored fields (values are None when no row
        exists for the project), or {} on database error — callers use
        .get() accordingly.
    """
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(**db_config)
        cursor = conn.cursor()
        cursor.execute(
            "SELECT manuscript, manuscript_summary, manuscript_excerpts, outline, characters, generated_prompts FROM inputs WHERE project_id = %s ORDER BY created_at DESC LIMIT 1",
            (project_id,)
        )
        result = cursor.fetchone()
        keys = ("manuscript", "manuscript_summary", "manuscript_excerpts",
                "outline", "characters", "generated_prompts")
        if result:
            return dict(zip(keys, result))
        return {key: None for key in keys}
    except mysql.connector.Error as e:
        print(f"Inputs retrieval error: {e}")
        return {}
    finally:
        # Close even when execute/fetch raises (the original only closed on
        # the success path and leaked the connection on error).
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
# Check password
@app.post("/check_password")
async def check_password(request: PasswordRequest):
    """Validate the shared access password.

    Uses secrets.compare_digest for a constant-time comparison so response
    timing does not leak how much of the password matched (the original
    used `!=`, which short-circuits on the first differing character).
    """
    return {"success": secrets.compare_digest(request.password, PASSWORD_SECRET)}
# Generate prose with OpenRouter API
async def generate_prose_stream(request: ProseRequest):
    """Stream prose from OpenRouter, persisting inputs and the final result.

    Saves the request inputs first (which also refreshes cached
    summary/excerpts), builds a system prompt according to manuscript_mode,
    relays OpenRouter's SSE stream to the client as plain text, and records
    the prompt/response pair once the stream completes.
    """
    # Persist inputs up front; this also (re)computes summary/excerpts.
    save_inputs(request.project_id, request.manuscript, request.outline, request.characters, request.prompt)
    inputs = get_latest_inputs(request.project_id)
    # Choose how much manuscript context to send, per manuscript_mode.
    if request.manuscript_mode == "full":
        manuscript_content = (inputs.get("manuscript") or "")[:MAX_MANUSCRIPT_CHARS]
        manuscript_label = "Manuscript Pages (Truncated)"
    elif request.manuscript_mode == "excerpts":
        manuscript_content = inputs.get("manuscript_excerpts") or ""
        manuscript_label = "Manuscript Excerpts"
    else:
        # Default ("summary") mode.
        manuscript_content = inputs.get("manuscript_summary") or ""
        manuscript_label = "Manuscript Summary"
    history = get_prompt_history(request.project_id)
    history_context = "\n".join(
        [f"Previous Prompt: {h['prompt']}\nPrevious Response: {h['response']}" for h in history]
    )
    system_prompt = f"""
You are a creative writing assistant tasked with generating novel prose. Use the following inputs and prompt history to create coherent, engaging prose that aligns with the provided manuscript content, outline, and character descriptions. If the manuscript content is a summary or excerpts, rely on the outline and characters for additional context. Ensure the tone and style match the provided context and follow the specific request in the prompt.
**{manuscript_label}**:
{manuscript_content or "No manuscript content provided."}
**Outline**:
{request.outline or "No outline provided."}
**Character Descriptions**:
{request.characters or "No character descriptions provided."}
**Prompt History**:
{history_context or "No prompt history available."}
**Specific Request**:
{request.prompt or "Generate a continuation of the manuscript with creative prose."}
"""
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
        # OpenRouter attribution headers.
        "HTTP-Referer": "https://huggingface.co/spaces/NoLev/NovelCrafter",
        "X-Title": "Novel Prose Generator"
    }
    payload = {
        "model": request.model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": request.prompt or "Generate novel prose based on the provided inputs."}
        ],
        "stream": True,  # request SSE streaming from OpenRouter
        "temperature": 0.7,
        "max_tokens": 1000
    }
    # NOTE(review): `requests` is synchronous — this call (and the iteration
    # below) blocks the event loop; consider an async HTTP client. Also, the
    # HTTPException raised for a non-200 status is immediately caught by the
    # broad `except Exception` and re-wrapped as a 500 — confirm intended.
    try:
        response = requests.post(OPENROUTER_API_URL, headers=headers, json=payload, stream=True)
        if response.status_code != 200:
            raise HTTPException(status_code=response.status_code, detail="Error from OpenRouter API")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"OpenRouter API request failed: {str(e)}")
    full_response = ""
    async def stream_response():
        # Accumulate the full text while relaying deltas to the client.
        nonlocal full_response
        for line in response.iter_lines():
            if line:
                decoded_line = line.decode('utf-8')
                # SSE frames look like "data: {...}"; "[DONE]" marks the end.
                if decoded_line.startswith("data: "):
                    data = decoded_line[6:]
                    if data == "[DONE]":
                        continue
                    try:
                        json_data = json.loads(data)
                        content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "")
                        if content:
                            full_response += content
                            yield content
                    except json.JSONDecodeError:
                        # Ignore keep-alive / non-JSON frames.
                        continue
        # Persist only after the stream is fully consumed; if the client
        # disconnects early these saves may not run.
        save_prompt_history(request.project_id, request.prompt, full_response)
        save_prose(request.project_id, full_response)
    return StreamingResponse(stream_response(), media_type="text/plain")
# API endpoint to generate prompts
@app.post("/generate_prompts")
async def generate_prompts_endpoint(request: PromptRequest):
    """Generate up to three writing prompts from outline/characters/summary.

    Falls back to the latest saved inputs for any field the request leaves
    empty. Stored fields may be None (see get_latest_inputs), so values are
    coerced to "" before being passed on — the original could leak None into
    generate_prompts.
    """
    inputs = get_latest_inputs(request.project_id)
    # Prefer a fresh summary of the submitted manuscript; else the cached one.
    if request.manuscript:
        summary = generate_manuscript_summary(request.manuscript)
    else:
        summary = inputs.get("manuscript_summary") or ""
    outline = request.outline or inputs.get("outline") or ""
    characters = request.characters or inputs.get("characters") or ""
    prompts = generate_prompts(outline, characters, summary)
    generated_prompts = "\n\n".join(prompts)
    save_inputs(request.project_id, request.manuscript, request.outline, request.characters, "", generated_prompts)
    return prompts
# API endpoint to generate prose
@app.post("/generate")
async def generate_prose(request: ProseRequest):
    """POST /generate — stream generated prose for the given request."""
    return await generate_prose_stream(request)
# API endpoint to save prose
@app.post("/save_prose")
async def save_prose_endpoint(request: ProseSaveRequest):
    """POST /save_prose — persist a prose snapshot to prose_history."""
    save_prose(request.project_id, request.prose)
    return {"status": "success"}
# API endpoint to retrieve latest inputs
@app.get("/inputs/{project_id}")
async def get_inputs(project_id: str):
    """GET /inputs/{project_id} — latest saved inputs ({} on DB error)."""
    inputs = get_latest_inputs(project_id)
    return inputs
# Serve the frontend
@app.get("/")
async def serve_index():
    """Serve the single-page frontend.

    Reads static/index.html on every request (no caching). The file is
    decoded explicitly as UTF-8 so the page renders identically regardless
    of the host's locale default encoding.
    """
    with open("static/index.html", "r", encoding="utf-8") as f:
        return HTMLResponse(content=f.read())