NoLev committed on
Commit
c9c801f
·
verified ·
1 Parent(s): fbdfe3a

Create static/main.py

Browse files
Files changed (1) hide show
  1. static/main.py +470 -0
static/main.py ADDED
@@ -0,0 +1,470 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.staticfiles import StaticFiles
3
+ from fastapi.responses import HTMLResponse
4
+ from pydantic import BaseModel
5
+ import mysql.connector
6
+ import os
7
+ import requests
8
+ import json
9
+ from urllib.parse import urlparse
10
+ from fastapi.responses import StreamingResponse
11
+ from transformers import pipeline
12
+ from sentence_transformers import SentenceTransformer
13
+ import nltk
14
+ import numpy as np
15
+ import hashlib
16
+
17
# Suppress FutureWarnings for cleaner logs (e.g. from transformers).
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

# Set NLTK data path to a writable directory — presumably only /app is
# writable in this container; TODO confirm for the actual deployment.
NLTK_DATA_PATH = os.getenv("NLTK_DATA", "/app/nltk_data")
os.makedirs(NLTK_DATA_PATH, exist_ok=True)
nltk.data.path.append(NLTK_DATA_PATH)

# Download NLTK punkt data if not already present. A failed download is only
# logged, so startup continues; sent_tokenize() will fail later instead.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    try:
        nltk.download('punkt', download_dir=NLTK_DATA_PATH)
    except Exception as e:
        print(f"Failed to download NLTK punkt data: {e}")

# Set Hugging Face cache directory; must happen before any model is loaded.
# NOTE(review): TRANSFORMERS_CACHE is deprecated in newer transformers
# releases in favor of HF_HOME — worth updating when the pin is bumped.
os.environ["TRANSFORMERS_CACHE"] = os.getenv("TRANSFORMERS_CACHE", "/app/hf_cache")
os.makedirs(os.environ["TRANSFORMERS_CACHE"], exist_ok=True)
38
+
39
app = FastAPI()

# Mount static files (frontend) under /static; "/" serves index.html via a
# dedicated route below.
app.mount("/static", StaticFiles(directory="static"), name="static")
43
+
44
# Database configuration from DATABASE_URL (mysql://user:pass@host:port/db).
# SECURITY NOTE(review): the fallback values below include a real-looking
# password committed to source control — it should be rotated and supplied
# exclusively via environment variables.
database_url = os.getenv("DATABASE_URL")
if database_url:
    parsed = urlparse(database_url)
    db_config = {
        "host": parsed.hostname,
        "user": parsed.username,
        "password": parsed.password,
        "database": parsed.path.lstrip("/"),
        "port": parsed.port or 3306,  # MySQL default when the URL omits a port
    }
else:
    # Fallback: individual env vars, defaulting to the original literals so
    # existing deployments keep working until the secrets are moved out.
    db_config = {
        "host": os.getenv("DB_HOST", "novelcrafter-novelcrafter.g.aivencloud.com"),
        "user": os.getenv("DB_USER", "avnadmin"),
        "password": os.getenv("DB_PASSWORD", "AVNS_6uHxC3wASDZVJ_PxQXJ"),
        "database": os.getenv("DB_NAME", "defaultdb"),
        "port": int(os.getenv("DB_PORT", "12221")),
    }
63
+
64
# OpenRouter API configuration (the key must come from the environment).
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"

# Hugging Face model setup
SUMMARIZER_MODEL = "facebook/bart-large-cnn"
EXCERPT_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# When true, summarization is delegated to the hosted HF Inference API
# instead of a local pipeline (less RAM, more network latency).
USE_INFERENCE_API = os.getenv("USE_HF_INFERENCE_API", "false").lower() == "true"
HF_API_TOKEN = os.getenv("HF_API_TOKEN")

if USE_INFERENCE_API:
    from huggingface_hub import InferenceClient
    summarizer = InferenceClient(model=SUMMARIZER_MODEL, token=HF_API_TOKEN)
else:
    summarizer = pipeline("summarization", model=SUMMARIZER_MODEL, device=-1)  # CPU

# Embedding model used to pick prompt-relevant manuscript excerpts.
excerpt_model = SentenceTransformer(EXCERPT_MODEL)
81
+
82
# Pydantic models for request validation
class ProseRequest(BaseModel):
    # Model name forwarded verbatim to the OpenRouter chat-completions API.
    model: str
    manuscript: str = ""   # full manuscript text (may be very large)
    outline: str = ""      # story outline used as generation context
    characters: str = ""   # character descriptions used as generation context
    prompt: str = ""       # the user's specific generation request
    project_id: str = "default"
    manuscript_mode: str = "summary"  # Options: "full", "summary", "excerpts"

# Maximum characters for full manuscript mode (approx. 37,500 tokens)
MAX_MANUSCRIPT_CHARS = 50000
# Maximum characters for MySQL MEDIUMTEXT (safety threshold)
MAX_MEDIUMTEXT_CHARS = 1000000
96
+
97
# Initialize database tables with MEDIUMTEXT and ensure schema compatibility.
def init_db():
    """Create the `inputs` and `prompt_history` tables if missing and migrate
    older schemas so all large text columns are MEDIUMTEXT.

    Raises:
        mysql.connector.Error: if the database is unreachable or a DDL
        statement fails (logged, then re-raised).
    """
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(**db_config)
        cursor = conn.cursor()

        # Create inputs table if it doesn't exist.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS inputs (
                id INT AUTO_INCREMENT PRIMARY KEY,
                project_id VARCHAR(255),
                manuscript MEDIUMTEXT,
                outline MEDIUMTEXT,
                characters MEDIUMTEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # Inspect the live schema so older deployments can be migrated in place.
        cursor.execute("SHOW COLUMNS FROM inputs")
        columns = {row[0]: row[1].lower() for row in cursor.fetchall()}

        # Add summary/excerpt columns introduced after the initial release.
        if 'manuscript_summary' not in columns:
            cursor.execute("ALTER TABLE inputs ADD manuscript_summary MEDIUMTEXT")
        if 'manuscript_excerpts' not in columns:
            cursor.execute("ALTER TABLE inputs ADD manuscript_excerpts MEDIUMTEXT")

        # Upgrade any legacy text columns to MEDIUMTEXT. Column names come
        # from a fixed list, so the f-string cannot inject arbitrary SQL.
        for col in ['manuscript', 'outline', 'characters']:
            if col in columns and 'mediumtext' not in columns[col]:
                cursor.execute(f"ALTER TABLE inputs MODIFY {col} MEDIUMTEXT")

        # Create prompt_history table if it doesn't exist.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS prompt_history (
                id INT AUTO_INCREMENT PRIMARY KEY,
                project_id VARCHAR(255),
                prompt MEDIUMTEXT,
                response MEDIUMTEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # Same MEDIUMTEXT migration for prompt_history.
        cursor.execute("SHOW COLUMNS FROM prompt_history")
        columns = {row[0]: row[1].lower() for row in cursor.fetchall()}
        for col in ['prompt', 'response']:
            if col in columns and 'mediumtext' not in columns[col]:
                cursor.execute(f"ALTER TABLE prompt_history MODIFY {col} MEDIUMTEXT")

        conn.commit()
    except mysql.connector.Error as e:
        print(f"Database initialization error: {e}")
        raise
    finally:
        # BUG FIX: the original referenced cursor/conn unconditionally here,
        # raising NameError whenever connect() itself failed.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
156
+
157
# Call init_db on startup so the schema exists before the first request.
# NOTE(review): @app.on_event is deprecated in newer FastAPI versions in
# favor of lifespan handlers — worth migrating when FastAPI is upgraded.
@app.on_event("startup")
def startup_event():
    """FastAPI startup hook: create/migrate the MySQL tables."""
    init_db()
161
+
162
# Generate manuscript summary using the configured Hugging Face backend.
def _summarize_text(text: str, max_length: int, min_length: int) -> str:
    """Run a single summarization call through the configured backend."""
    if USE_INFERENCE_API:
        # BUG FIX: InferenceClient.summarization() takes generation options
        # via a `parameters` dict and returns a SummarizationOutput object.
        # The original passed pipeline-style kwargs and indexed the result
        # like a list of dicts, which fails at runtime on this path.
        result = summarizer.summarization(
            text,
            parameters={"max_length": max_length, "min_length": min_length, "do_sample": False},
        )
        return result.summary_text
    # Local transformers pipeline path.
    return summarizer(text, max_length=max_length, min_length=min_length, do_sample=False)[0]['summary_text']


def generate_manuscript_summary(manuscript: str):
    """Produce a condensed summary of `manuscript` for use as LLM context.

    The text is split on sentence boundaries into ~4,000-character chunks
    (BART's input limit is ~1024 tokens), each chunk is summarized, and if
    the joined chunk summaries exceed 1,000 characters they are compressed
    with one more summarization pass.

    Returns:
        str: the summary, or "" for empty input or tokenization failure.
    """
    if not manuscript:
        return ""

    # Sentence-split so chunk boundaries never cut a sentence in half.
    try:
        sentences = nltk.sent_tokenize(manuscript)
    except Exception as e:
        print(f"Sentence tokenization error: {e}")
        return ""

    chunks = []
    current_chunk = ""
    for sentence in sentences:
        if len(current_chunk) + len(sentence) < 4000:
            current_chunk += sentence + " "
        else:
            chunks.append(current_chunk.strip())
            current_chunk = sentence + " "
    if current_chunk:
        chunks.append(current_chunk.strip())

    # Summarize each chunk independently; a failed chunk contributes nothing.
    summaries = []
    for chunk in chunks:
        try:
            summaries.append(_summarize_text(chunk, max_length=200, min_length=50))
        except Exception as e:
            print(f"Summary generation error: {e}")
            summaries.append("")

    # Combine, and compress once more if the combination is still long.
    combined_summary = " ".join([s for s in summaries if s])
    if len(combined_summary) > 1000:
        try:
            return _summarize_text(combined_summary, max_length=1000, min_length=500)
        except Exception as e:
            print(f"Final summary error: {e}")
            return combined_summary[:1000]
    return combined_summary
227
+
228
# Generate manuscript excerpts using semantic similarity.
def generate_manuscript_excerpts(manuscript: str, prompt: str):
    """Pick the manuscript paragraphs most semantically relevant to `prompt`.

    Paragraphs are embedded with the sentence-transformer model and ranked by
    cosine similarity against the prompt embedding; the two best matches are
    joined (best first) and truncated to ~8,000 characters.

    Returns:
        str: the selected excerpts, or "" on empty input / embedding failure.
    """
    if not manuscript or not prompt:
        return ""

    # Paragraph-level sections, skipping blank runs.
    paragraphs = [part.strip() for part in manuscript.split("\n\n") if part.strip()]
    if not paragraphs:
        return ""

    try:
        query_vec = excerpt_model.encode(prompt)
        para_vecs = excerpt_model.encode(paragraphs)

        # Cosine similarity of every paragraph against the prompt.
        norms = np.linalg.norm(para_vecs, axis=1) * np.linalg.norm(query_vec)
        scores = np.dot(para_vecs, query_vec) / norms

        # Indices of the two highest-scoring paragraphs, best first.
        best = scores.argsort()[-2:][::-1]
        chosen = "\n\n".join(paragraphs[i] for i in best)

        # Cap at ~2,000 words (~8,000 chars) to keep the LLM context small.
        return chosen[:8000]
    except Exception as e:
        print(f"Excerpt generation error: {e}")
        return ""
258
+
259
# Save inputs to MySQL with generated summary/excerpts.
def save_inputs(project_id: str, manuscript: str, outline: str, characters: str, prompt: str):
    """Insert a snapshot of the project's inputs, reusing the cached summary
    and excerpts when the manuscript is byte-identical to the latest stored
    copy (MD5 compare — change detection only, not security).

    Raises:
        HTTPException: 500 when the insert fails.
    """
    # Truncate inputs to avoid exceeding MEDIUMTEXT limits.
    if len(manuscript) > MAX_MEDIUMTEXT_CHARS:
        manuscript = manuscript[:MAX_MEDIUMTEXT_CHARS]
        print(f"Manuscript truncated to {MAX_MEDIUMTEXT_CHARS} characters")
    if len(outline) > MAX_MEDIUMTEXT_CHARS:
        outline = outline[:MAX_MEDIUMTEXT_CHARS]
        print(f"Outline truncated to {MAX_MEDIUMTEXT_CHARS} characters")
    if len(characters) > MAX_MEDIUMTEXT_CHARS:
        characters = characters[:MAX_MEDIUMTEXT_CHARS]
        print(f"Characters truncated to {MAX_MEDIUMTEXT_CHARS} characters")

    # Skip the expensive model passes when the manuscript is unchanged.
    manuscript_hash = hashlib.md5(manuscript.encode()).hexdigest()
    inputs = get_latest_inputs(project_id)
    if inputs.get("manuscript") and hashlib.md5(inputs["manuscript"].encode()).hexdigest() == manuscript_hash:
        summary = inputs["manuscript_summary"]
        excerpts = inputs["manuscript_excerpts"]
    else:
        summary = generate_manuscript_summary(manuscript)
        excerpts = generate_manuscript_excerpts(manuscript, prompt)

    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(**db_config)
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO inputs (project_id, manuscript, manuscript_summary, manuscript_excerpts, outline, characters) VALUES (%s, %s, %s, %s, %s, %s)",
            (project_id, manuscript, summary, excerpts, outline, characters)
        )
        conn.commit()
    except mysql.connector.Error as e:
        print(f"Database save error: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to save inputs: {str(e)}")
    finally:
        # BUG FIX: guard closes so a failed connect() does not raise
        # NameError on the undefined cursor/conn names.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
296
+
297
# Save prompt and response to history.
def save_prompt_history(project_id: str, prompt: str, response: str):
    """Append one prompt/response pair to `prompt_history`, clamping both
    fields to the MEDIUMTEXT safety threshold.

    Raises:
        HTTPException: 500 when the insert fails.
    """
    if len(prompt) > MAX_MEDIUMTEXT_CHARS:
        prompt = prompt[:MAX_MEDIUMTEXT_CHARS]
        print(f"Prompt truncated to {MAX_MEDIUMTEXT_CHARS} characters")
    if len(response) > MAX_MEDIUMTEXT_CHARS:
        response = response[:MAX_MEDIUMTEXT_CHARS]
        print(f"Response truncated to {MAX_MEDIUMTEXT_CHARS} characters")

    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(**db_config)
        cursor = conn.cursor()
        cursor.execute(
            "INSERT INTO prompt_history (project_id, prompt, response) VALUES (%s, %s, %s)",
            (project_id, prompt, response)
        )
        conn.commit()
    except mysql.connector.Error as e:
        print(f"Prompt history save error: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to save prompt history: {str(e)}")
    finally:
        # BUG FIX: guard closes so a failed connect() does not raise
        # NameError on the undefined cursor/conn names.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
320
+
321
# Retrieve recent prompt history for context.
def get_prompt_history(project_id: str, limit: int = 2):
    """Return the `limit` most recent prompt/response pairs for a project.

    Returns:
        list[dict]: [{"prompt": ..., "response": ...}, ...] newest first,
        or [] when the query fails.
    """
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(**db_config)
        cursor = conn.cursor()
        cursor.execute(
            "SELECT prompt, response FROM prompt_history WHERE project_id = %s ORDER BY created_at DESC LIMIT %s",
            (project_id, limit)
        )
        return [{"prompt": p, "response": r} for p, r in cursor.fetchall()]
    except mysql.connector.Error as e:
        print(f"Prompt history retrieval error: {e}")
        return []
    finally:
        # BUG FIX: the original only closed the cursor/connection on the
        # success path, leaking the connection whenever the query raised.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
337
+
338
# Retrieve latest inputs for a project.
def get_latest_inputs(project_id: str):
    """Return the most recent stored inputs row for `project_id`.

    Returns:
        dict: keys manuscript, manuscript_summary, manuscript_excerpts,
        outline, characters (each None when no row exists), or {} when the
        query fails.
    """
    conn = None
    cursor = None
    try:
        conn = mysql.connector.connect(**db_config)
        cursor = conn.cursor()
        cursor.execute(
            "SELECT manuscript, manuscript_summary, manuscript_excerpts, outline, characters FROM inputs WHERE project_id = %s ORDER BY created_at DESC LIMIT 1",
            (project_id,)
        )
        result = cursor.fetchone()
        return {
            "manuscript": result[0] if result else None,
            "manuscript_summary": result[1] if result else None,
            "manuscript_excerpts": result[2] if result else None,
            "outline": result[3] if result else None,
            "characters": result[4] if result else None
        }
    except mysql.connector.Error as e:
        print(f"Inputs retrieval error: {e}")
        return {}
    finally:
        # BUG FIX: the original only closed the cursor/connection on the
        # success path, leaking the connection whenever the query raised.
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
360
+
361
# Generate prose with the OpenRouter API, streamed back to the client.
async def generate_prose_stream(request: ProseRequest):
    """Persist the request's inputs, build a context-rich system prompt, and
    stream the OpenRouter completion back as plain text.

    Manuscript context is chosen per `request.manuscript_mode`: "full"
    (truncated raw text), "excerpts", or "summary" (the default). The full
    response text is saved to prompt_history once streaming finishes.

    Raises:
        HTTPException: when the upstream request fails (500) or returns a
        non-200 status (that status is propagated).
    """
    # Save full inputs to MySQL (also refreshes summary/excerpts if needed).
    save_inputs(request.project_id, request.manuscript, request.outline, request.characters, request.prompt)

    # Re-read stored inputs so cached summary/excerpts are available.
    inputs = get_latest_inputs(request.project_id)

    # Select manuscript content based on mode.
    if request.manuscript_mode == "full":
        manuscript_content = (inputs.get("manuscript") or "")[:MAX_MANUSCRIPT_CHARS]
        manuscript_label = "Manuscript Pages (Truncated)"
    elif request.manuscript_mode == "excerpts":
        manuscript_content = inputs.get("manuscript_excerpts") or ""
        manuscript_label = "Manuscript Excerpts"
    else:  # summary
        manuscript_content = inputs.get("manuscript_summary") or ""
        manuscript_label = "Manuscript Summary"

    # Fold recent prompt/response pairs in for continuity across requests.
    history = get_prompt_history(request.project_id)
    history_context = "\n".join(
        [f"Previous Prompt: {h['prompt']}\nPrevious Response: {h['response']}" for h in history]
    )

    # Construct system prompt
    system_prompt = f"""
You are a creative writing assistant tasked with generating novel prose. Use the following inputs and prompt history to create coherent, engaging prose that aligns with the provided manuscript content, outline, and character descriptions. If the manuscript content is a summary or excerpts, rely on the outline and characters for additional context. Ensure the tone and style match the provided context and follow the specific request in the prompt.

**{manuscript_label}**:
{manuscript_content or "No manuscript content provided."}

**Outline**:
{request.outline or "No outline provided."}

**Character Descriptions**:
{request.characters or "No character descriptions provided."}

**Prompt History**:
{history_context or "No prompt history available."}

**Specific Request**:
{request.prompt or "Generate a continuation of the manuscript with creative prose."}
"""

    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://huggingface.co/spaces/NoLev/NovelCrafter",
        "X-Title": "Novel Prose Generator"
    }

    payload = {
        "model": request.model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": request.prompt or "Generate novel prose based on the provided inputs."}
        ],
        "stream": True,
        "temperature": 0.7,
        "max_tokens": 1000
    }

    # NOTE(review): requests is blocking; inside an async handler this ties
    # up the event loop for the duration of the upstream call — consider
    # httpx.AsyncClient when dependencies are next touched.
    try:
        response = requests.post(OPENROUTER_API_URL, headers=headers, json=payload, stream=True)
    except requests.RequestException as e:
        raise HTTPException(status_code=500, detail=f"OpenRouter API request failed: {str(e)}")
    # BUG FIX: the original raised this HTTPException inside a try whose
    # `except Exception` caught it and re-wrapped it as a generic 500 with a
    # misleading "request failed" message; check status outside the try.
    if response.status_code != 200:
        raise HTTPException(status_code=response.status_code, detail="Error from OpenRouter API")

    # Stream SSE chunks through to the client while accumulating the full
    # text so it can be written to prompt_history afterwards.
    full_response = ""
    async def stream_response():
        nonlocal full_response
        for line in response.iter_lines():
            if not line:
                continue
            decoded_line = line.decode('utf-8')
            if not decoded_line.startswith("data: "):
                continue
            data = decoded_line[6:]
            if data == "[DONE]":
                continue
            try:
                json_data = json.loads(data)
                content = json_data.get("choices", [{}])[0].get("delta", {}).get("content", "")
                if content:
                    full_response += content
                    yield content
            except json.JSONDecodeError:
                # Ignore keep-alive / malformed SSE fragments.
                continue
        # Save prompt and response to history after streaming completes.
        save_prompt_history(request.project_id, request.prompt, full_response)

    return StreamingResponse(stream_response(), media_type="text/plain")
454
+
455
# API endpoint to generate prose.
@app.post("/generate")
async def generate_prose(request: ProseRequest):
    """POST /generate — stream generated prose for the given project inputs."""
    streaming = await generate_prose_stream(request)
    return streaming
459
+
460
# API endpoint to retrieve latest inputs.
@app.get("/inputs/{project_id}")
async def get_inputs(project_id: str):
    """GET /inputs/{project_id} — the most recent stored inputs for a project."""
    return get_latest_inputs(project_id)
465
+
466
# Serve the frontend.
@app.get("/")
async def serve_index():
    """Serve the single-page frontend.

    Reads static/index.html on every request, so frontend edits show up
    without a server restart.
    """
    # FIX: explicit UTF-8 — the platform default encoding is not guaranteed.
    with open("static/index.html", "r", encoding="utf-8") as f:
        return HTMLResponse(content=f.read())