Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import requests | |
| import uvicorn | |
| import threading | |
| import sqlite3 | |
| import os | |
| import bcrypt | |
| from fastapi import FastAPI, Request, Depends, HTTPException, BackgroundTasks, status | |
| from fastapi.responses import JSONResponse, FileResponse | |
| from fastapi.concurrency import run_in_threadpool | |
| from datetime import datetime, timedelta | |
| import tempfile | |
| import asyncio | |
| import html | |
| import re | |
| from pathlib import Path | |
| from pydantic import BaseModel | |
| from typing import List, Optional | |
| from fpdf import FPDF | |
| from contextlib import asynccontextmanager | |
| import time | |
| # JWT specific imports | |
| from jose import JWTError, jwt | |
| from fastapi.security import OAuth2PasswordBearer | |
| # Assuming other python files are in the same directory for local import | |
| # Make sure chatbot.py, db.py, and auth.py are in the same directory | |
| from chatbot import ask_groq | |
| from db import init_db, get_db | |
| from auth import create_user, verify_user | |
# ===================================================================
# JWT Configuration
# ===================================================================
# The signing key must come from the JWT_SECRET_KEY environment variable in
# production. If it is missing (or empty) we fall back to a random key generated
# for this process instead of a well-known hard-coded string: a public default
# would let anyone forge valid tokens.
# Trade-offs of the fallback: tokens stop validating after a restart, and
# multiple worker processes would each get a different key.
_env_secret_key = os.getenv("JWT_SECRET_KEY")
if not _env_secret_key:
    print("WARNING: JWT_SECRET_KEY is not set; using a random per-process signing key.")
SECRET_KEY = _env_secret_key or os.urandom(32).hex()
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 24 * 60  # 24 hours, matching the 24-hour session checkbox on the login tab
# tokenUrl is only used by the OpenAPI docs; the actual login is handled by the /login route.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT carrying ``data`` plus an ``exp`` (expiry) claim.

    Args:
        data: Claims to embed, e.g. ``{"sub": username}``. Not mutated.
        expires_delta: Token lifetime. ``None`` selects the default of
            ``ACCESS_TOKEN_EXPIRE_MINUTES``. An explicit ``timedelta(0)`` is
            honoured instead of being silently replaced by the default.

    Returns:
        The encoded token string, signed with ``SECRET_KEY`` / ``ALGORITHM``.
    """
    from datetime import timezone  # only datetime/timedelta are imported at file top

    # `is None` rather than truthiness: timedelta(0) is falsy but is a real value.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Timezone-aware UTC; datetime.utcnow() is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode = {**data, "exp": expire}
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def verify_access_token(token: str, credentials_exception):
    """Decode and verify ``token``, returning its ``sub`` claim (the username).

    Raises:
        credentials_exception: if the token fails to decode or verify, or if it
            carries no ``sub`` claim.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception
    subject: str = claims.get("sub")
    if subject is None:
        raise credentials_exception
    # A DB lookup could be added here to confirm the user is still active.
    return subject
| # async def get_current_user_id(token: str = Depends(oauth2_scheme), db: sqlite3.Connection = Depends(get_db)): | |
| # credentials_exception = HTTPException( | |
| # status_code=status.HTTP_401_UNAUTHORIZED, | |
| # detail="Could not validate credentials", | |
| # headers={"WWW-Authenticate": "Bearer"}, | |
| # ) | |
| # username = verify_access_token(token, credentials_exception) | |
| # # Fetch user_id from database using username | |
| # # FIX: Correctly await run_in_threadpool and then call fetchone() | |
| # cursor = await run_in_threadpool(db.execute, "SELECT id FROM users WHERE username = ?", (username,)) | |
| # user_row = cursor.fetchone() | |
| # if user_row is None: | |
| # raise credentials_exception | |
| # return user_row["id"] # Return the user_id | |
async def get_current_user_id(token: str = Depends(oauth2_scheme), db: sqlite3.Connection = Depends(get_db)):
    """FastAPI dependency that resolves the bearer token to the caller's user id.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            token is invalid or its username has no row in ``users``.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    username = verify_access_token(token, unauthorized)

    def _lookup_user_row():
        # execute() and fetchone() run together inside a single worker-thread
        # call, so the cursor is never touched from two different threads.
        cursor = db.execute("SELECT id FROM users WHERE username = ?", (username,))
        return cursor.fetchone()

    row = await run_in_threadpool(_lookup_user_row)
    if row is None:
        raise unauthorized
    return row["id"]
# ===================================================================
# 1. GRADIO/UI CODE
# ===================================================================
# API_URL was intended as the path prefix of the FastAPI routes relative to the
# app root (empty: routes at the root, per the original design with Gradio
# mounted under a subpath).
# NOTE(review): no UI handler in this file reads API_URL or request.url.netloc.
# Each one builds its backend URL as f"{request.url.scheme}://localhost:7860"
# and appends the route path. API_URL appears unused; confirm before relying on it.
API_URL = ""
# The shared requests.Session was removed; handlers now send the JWT kept in
# Gradio state as an "Authorization: Bearer" header on each request.
def check_and_set_token(token_from_local_storage: Optional[str]):
    """Pass a token read from browser localStorage through unchanged.

    Intended to receive the value that the page's JavaScript reads from
    localStorage, so it can be stored in the Gradio token state. No validation
    happens here; ``on_load`` asks the backend whether the token is still valid.
    """
    preview = token_from_local_storage[:20] if token_from_local_storage else 'None'
    print(f"check_and_set_token called. Token from JS: {preview}")
    return token_from_local_storage
| # Helper to convert backend history to Gradio's format | |
| def _format_history_for_chatbot(history_list): | |
| formatted = [] | |
| for user_msg, bot_reply in history_list: | |
| formatted.append([user_msg, bot_reply]) | |
| return formatted | |
| # Helper to convert Gradio's format back for backend use | |
| def _convert_chatbot_history_to_backend_format(chatbot_history): | |
| backend_history = [] | |
| for pair in chatbot_history: | |
| if isinstance(pair, (list, tuple)) and len(pair) == 2: | |
| user_msg, bot_reply = pair | |
| backend_history.append({'message': user_msg, 'response': bot_reply}) | |
| elif isinstance(pair, (list, tuple)) and len(pair) == 1: | |
| user_msg = pair[0] | |
| backend_history.append({'message': user_msg, 'response': None}) | |
| return backend_history | |
# Helper function to format warning messages with HTML/CSS
def format_warning(message, escape: bool = True):
    """Wrap ``message`` in a ``<div class='warning-message'>`` for styled display.

    Args:
        message: Text to show. Callers embed text taken from backend responses
            (e.g. a ``message``/``detail`` field), so it is HTML-escaped by
            default to keep it from injecting markup into the page.
        escape: Pass ``False`` only for trusted, pre-built HTML.

    Returns:
        The HTML snippet as a string.
    """
    text = html.escape(str(message)) if escape else message
    return f"<div class='warning-message'>{text}</div>"
# signup_success_flag is created inside the `demo` Blocks layout, where the
# sign-up event chain uses it. A module-level gr.State here would be an orphan
# component created outside any Blocks and immediately shadowed.
# NOTE: http_session will now store the JWT token string, not a requests.Session object
def sign_up(username, password, http_session_token: Optional[str], request: gr.Request):
    """Create an account through the backend ``/signup`` route.

    Args:
        username, password: Values typed into the Sign Up tab.
        http_session_token: Current token state. Unused, because signing up
            does not log the user in.
        request: Gradio request; its URL scheme is used for the internal call.

    Returns:
        ``(signup_user, signup_pass, status_output, signup_success_flag,
        http_session)``.
        - On success, both inputs are cleared and the flag is ``True``.
        - Otherwise, the inputs are left untouched, a warning is shown, and the
          flag is ``False``.
        - The token state is always ``None``.
    """
    if not username or not password:
        return gr.update(), gr.update(), format_warning("⚠️ Username and password cannot be empty."), False, None
    url = f"{request.url.scheme}://localhost:7860/signup"
    try:
        # No token exists yet, so this is a plain unauthenticated request.
        # The timeout keeps an unresponsive backend from blocking this worker forever.
        r = requests.post(url, json={"username": username, "password": password}, timeout=30)
        r.raise_for_status()
        result = r.json()
    except requests.exceptions.HTTPError as e:
        try:
            message = e.response.json().get("message", "Signup failed.")
        except Exception:
            message = "Signup failed."
        return gr.update(), gr.update(), format_warning(f"⚠️ Signup failed: {message}"), False, None
    except requests.RequestException as e:
        print(f"Signup request exception: {e}")
        return gr.update(), gr.update(), format_warning("⚠️ Signup error: Unable to connect to the server. Please try again later."), False, None
    # .get(): a response without a "success" key is reported as a failure
    # instead of raising KeyError inside the event handler.
    if result.get("success"):
        return (
            gr.update(value=""),
            gr.update(value=""),
            "✅ Signup successful! Switching to login...",
            True,
            None,  # No token to set yet
        )
    return gr.update(), gr.update(), format_warning(f"⚠️ {result.get('message', 'Signup failed.')}"), False, None
# NOTE: http_session_token will be updated with the JWT token string
def _login_failure(message):
    """Build log_in's 12 outputs for a failed attempt.

    The first nine components are left untouched, ``message`` is shown as a
    warning, the token state is cleared, and no script is emitted.
    """
    # A fresh update dict per output, so no single dict object is shared
    # between several components.
    untouched = tuple(gr.update() for _ in range(9))
    return untouched + (format_warning(message), None, "")


def log_in(username, password, remember_me, http_session_token: Optional[str], request: gr.Request):
    """Authenticate via the backend ``/login`` route and load the user's chats.

    Returns a 12-tuple of Gradio outputs.

    On success it:
    - hides the auth UI and shows the chat UI;
    - loads the current chat history and fills the conversation dropdown;
    - clears the login inputs and checkbox, and hides the about image;
    - shows a welcome message and stores the new JWT in the token state;
    - returns a snippet calling ``setTokenInLocalStorage``.

    On any failure only the status message changes, the token state is cleared,
    and the script output is empty.
    """
    if not username or not password:
        return _login_failure("⚠️ Username and password cannot be empty.")
    if not remember_me:
        return _login_failure("⚠️ You must agree to the 24-hour session to log in.")
    base_internal_url = f"{request.url.scheme}://localhost:7860"
    try:
        # remember_me is not sent: the 24-hour lifetime is enforced by the JWT expiry.
        payload = {"username": username, "password": password}
        # Timeouts keep an unresponsive backend from hanging this handler forever.
        r = requests.post(f"{base_internal_url}/login", json=payload, timeout=30)
        r.raise_for_status()
        result = r.json()
        token = result.get("access_token")
        if not token:
            return _login_failure(f"⚠️ {result.get('message', 'Login failed. Please check your credentials.')}")
        print(f"Login successful. Received token: {token[:20]}...")
        headers = {"Authorization": f"Bearer {token}"}
        chat_data = requests.get(f"{base_internal_url}/get_current_chat_history", headers=headers, timeout=30).json()
        conv_data = requests.get(f"{base_internal_url}/get_conversations", headers=headers, timeout=30).json()
        conv_choices = [("🗁 New Chat", "EMPTY_CONVO")] + [(c['title'], c['id']) for c in conv_data.get('conversations', [])]
        selected_value = chat_data.get("current_conversation_id")
        # Fall back to the "New Chat" placeholder if the backend's current id is not in the list.
        dropdown_selected_value = selected_value if any(c_id == selected_value for _, c_id in conv_choices) else "EMPTY_CONVO"
    except requests.exceptions.HTTPError as e:
        try:
            body = e.response.json()  # parse the error body once
            message = body.get("detail", body.get("message", "Login failed."))
        except Exception:
            message = "Login failed."
        return _login_failure(f"⚠️ Login failed: {message}")
    except requests.RequestException as e:
        print(f"Login request exception: {e}")
        return _login_failure("⚠️ Login error: Unable to connect to the server. Please try again later.")
    return (
        gr.update(visible=False),  # auth UI
        gr.update(visible=True),  # chat UI
        _format_history_for_chatbot(chat_data.get("history", [])),
        dropdown_selected_value,  # current conversation id
        gr.update(choices=conv_choices, value=dropdown_selected_value),
        gr.update(value=""),  # username box
        gr.update(value=""),  # password box
        gr.update(value=False),  # 24-hour session checkbox
        gr.update(visible=False),  # about image
        f"✅ Login successful! Welcome {username}.",
        token,  # new token state
        # NOTE(review): browsers do not execute <script> tags inserted through
        # innerHTML. Confirm this snippet actually runs in its output component;
        # an event `js=` callback is the usual way to touch localStorage.
        f"<script>setTokenInLocalStorage('{token}');</script>",
    )
# NOTE: http_session_token will be cleared
def log_out(http_session_token: Optional[str], request: gr.Request):
    """Log the user out on the client side.

    With JWTs and no server-side blacklist, logging out just means dropping the
    token, so no backend call is made.

    Returns 7 outputs:
    - show the auth UI and hide the chat UI;
    - empty the chatbot and reset the conversation dropdown;
    - show the about image;
    - clear the token state;
    - a snippet calling ``clearTokenFromLocalStorage``.
    """
    gr.Info("Logged out successfully!")
    show_auth = gr.update(visible=True)
    hide_chat = gr.update(visible=False)
    reset_dropdown = gr.update(choices=[], value=None)
    show_about = gr.update(visible=True)
    clear_script = "<script>clearTokenFromLocalStorage();</script>"
    return show_auth, hide_chat, [], reset_dropdown, show_about, None, clear_script
def sleep_for_a_moment(seconds: float = 2):
    """Block for ``seconds`` (default 2) so the previous status message stays visible.

    Used in the sign-up event chain between the sign-up call and the switch to
    the Login tab. The delay is a parameter so callers can shorten or skip it.
    """
    time.sleep(seconds)
def switch_to_login_tab_if_success(success_flag):
    """Select the "Login" tab when ``success_flag`` is truthy; otherwise leave the tabs unchanged."""
    if not success_flag:
        return gr.update()
    return gr.update(selected="Login")
# NOTE: http_session_token used for authorization
def chat_with_bot(msg, history, http_session_token: Optional[str], request: gr.Request):
    """Send ``msg`` to the backend ``/chat`` route and append the reply to ``history``.

    Returns ``(textbox_value, history, token)``.
    - On success the textbox is cleared and ``(msg, reply)`` is appended.
    - On failure the message stays in the textbox so the user can retry.
    - A 401 also clears the token state.
    """
    if not msg:
        return "", history, http_session_token
    if not http_session_token:
        gr.Warning("Please log in to chat.")
        return msg, history, http_session_token
    if history is None:  # an empty chatbot may arrive as None
        history = []
    url = f"{request.url.scheme}://localhost:7860/chat"
    headers = {"Authorization": f"Bearer {http_session_token}"}
    try:
        # Model replies can be slow; the timeout is generous but still bounded.
        r = requests.post(url, json={"message": msg}, headers=headers, timeout=120)
        r.raise_for_status()
        reply = r.json()["reply"]
    except requests.exceptions.HTTPError as e:
        if e.response.status_code == 401:
            gr.Warning("Your session has expired. Please log in again.")
            return msg, history, None  # Clear token state
        gr.Warning(f"Chat error: {e}")
        return msg, history, http_session_token
    except (requests.RequestException, KeyError, TypeError) as e:
        # KeyError/TypeError: a malformed response body without a "reply" field.
        gr.Warning(f"Chat error: {e}")
        return msg, history, http_session_token
    history.append((msg, reply))
    return "", history, http_session_token
# NOTE: http_session_token used for authorization
def start_new_conversation(http_session_token: Optional[str], request: gr.Request):
    """Create an empty conversation through the backend ``/new_conversation`` route.

    Returns ``(chatbot_history, new_conversation_id, token)``.
    - The chatbot is always cleared.
    - The id is ``None`` on failure.
    - The token is cleared when the user is not logged in or the backend answers 401.
    """
    if not http_session_token:
        gr.Warning("Please log in to start a new conversation.")
        return [], None, None
    url = f"{request.url.scheme}://localhost:7860/new_conversation"
    headers = {"Authorization": f"Bearer {http_session_token}"}
    try:
        # Bounded wait so an unresponsive backend cannot hang this handler.
        r_new = requests.post(url, headers=headers, timeout=30)
        r_new.raise_for_status()
        new_conv_id = r_new.json().get("conversation_id")
    except requests.exceptions.HTTPError as e:
        if e.response.status_code == 401:
            gr.Warning("Your session has expired. Please log in again.")
            return [], None, None  # Clear token state
        gr.Warning(f"Failed to start new conversation: {e}")
        return [], None, http_session_token
    except requests.RequestException as e:
        gr.Warning(f"Failed to start new conversation: {e}")
        return [], None, http_session_token
    gr.Info("New conversation started!")
    return [], new_conv_id, http_session_token
# NOTE: http_session_token used for authorization
def _empty_conversation_dropdown():
    """Dropdown update offering only the "New Chat" placeholder, selected."""
    return gr.update(choices=[("🗁 New Chat", "EMPTY_CONVO")], value="EMPTY_CONVO")


def _fetch_conversation_choices(base_internal_url, headers):
    """GET ``/get_conversations`` and return ``(title, id)`` dropdown choices.

    The list is prefixed with the "New Chat" placeholder. ``requests``
    exceptions propagate to the caller.
    """
    resp = requests.get(f"{base_internal_url}/get_conversations", headers=headers, timeout=30)
    resp.raise_for_status()
    conversations = resp.json().get("conversations", [])
    return [("🗁 New Chat", "EMPTY_CONVO")] + [(c['title'], c['id']) for c in conversations]


def load_selected_conversation(conv_id, http_session_token: Optional[str], request: gr.Request):
    """Load ``conv_id``'s history and refresh the conversation dropdown.

    For the "New Chat" placeholder (or an empty id) only the dropdown is
    refreshed and the chatbot is cleared.

    Returns ``(chatbot_history, dropdown_update, token)``.
    - On failure the dropdown is reset to just the placeholder.
    - On a 401, or when not logged in, the token state is also cleared.
    """
    if not http_session_token:
        gr.Warning("Please log in to load conversations.")
        return [], _empty_conversation_dropdown(), None
    base_internal_url = f"{request.url.scheme}://localhost:7860"
    headers = {"Authorization": f"Bearer {http_session_token}"}
    is_placeholder = not conv_id or conv_id == "EMPTY_CONVO"
    error_label = "Failed to load conversations" if is_placeholder else "Failed to load conversation"
    try:
        if is_placeholder:
            history = []
        else:
            r = requests.get(f"{base_internal_url}/load_conversation/{conv_id}", headers=headers, timeout=30)
            r.raise_for_status()
            history = _format_history_for_chatbot(r.json().get("history", []))
        choices = _fetch_conversation_choices(base_internal_url, headers)
    except requests.exceptions.HTTPError as e:
        if e.response.status_code == 401:
            gr.Warning("Your session has expired. Please log in again.")
            return [], _empty_conversation_dropdown(), None
        gr.Warning(f"{error_label}: {e}")
        return [], _empty_conversation_dropdown(), http_session_token
    except requests.RequestException as e:
        gr.Warning(f"{error_label}: {e}")
        return [], _empty_conversation_dropdown(), http_session_token
    if is_placeholder:
        selected = "EMPTY_CONVO"
    else:
        # Keep the selection only if the conversation still exists server-side.
        selected = conv_id if any(c_id == conv_id for _, c_id in choices) else "EMPTY_CONVO"
    return history, gr.update(choices=choices, value=selected), http_session_token
# NOTE: http_session_token used for authorization
def generate_summary(chat_history, http_session_token: Optional[str], request: gr.Request):
    """Ask the backend ``/summarize_chat`` route for a PDF summary of the chat.

    Returns ``(file_update, status_markdown_update, token)``.
    - On success the file is shown for download and the status message is hidden.
    - On failure the file is hidden and an explanatory message is shown.
    - The token is cleared when not logged in or on a 401.

    Uses the generic ``gr.update``: component-specific ``gr.File.update`` /
    ``gr.Markdown.update`` no longer exist in Gradio 4.
    """
    def _failure(text, token):
        return gr.update(value=None, visible=False), gr.update(value=text, visible=True), token

    if not http_session_token:
        gr.Warning("Please log in to generate summaries.")
        return _failure("Please log in.", None)
    if not chat_history:
        gr.Warning("Chat is empty, nothing to summarize.")
        return _failure("Chat is empty.", http_session_token)
    backend_history = _convert_chatbot_history_to_backend_format(chat_history)
    url = f"{request.url.scheme}://localhost:7860/summarize_chat"
    headers = {"Authorization": f"Bearer {http_session_token}"}
    try:
        # Summarization can take a while; the timeout is generous but bounded.
        r = requests.post(url, json={"history": backend_history}, headers=headers, timeout=300)
        r.raise_for_status()
        result = r.json()
    except requests.exceptions.HTTPError as e:
        if e.response.status_code == 401:
            gr.Warning("Your session has expired. Please log in again.")
            return _failure("Session expired. Please log in.", None)
        return _failure(f"Error generating summary: {e}", http_session_token)
    except requests.RequestException as e:
        return _failure(f"Error generating summary: {e}", http_session_token)
    if result.get("success"):
        return (
            gr.update(value=result["file_path"], visible=True),
            gr.update(value="Summary ready!", visible=False),
            http_session_token,
        )
    return _failure(f"Error: {result.get('message', 'Unknown error during summarization.')}", http_session_token)
# NOTE: http_session_token used for authorization
def generate_flashcards(file_format, chat_history, http_session_token: Optional[str], request: gr.Request):
    """Ask the backend ``/generate_flashcards`` route for a flashcard file.

    Args:
        file_format: Radio choice such as "PDF" or "HTML (Interactive)". It is
            sent lower-cased; ``None`` falls back to "PDF", the radio's default.
        chat_history: Gradio chatbot rows to build flashcards from.
        http_session_token: JWT used as a bearer token.
        request: Gradio request; its URL scheme is used for the internal call.

    Returns:
        ``(file_update, status_markdown_update, token)``. The token is cleared
        when not logged in or on a 401.
    """
    def _failure(text, token):
        return gr.update(value=None, visible=False), gr.update(value=text, visible=True), token

    if not http_session_token:
        gr.Warning("Please log in to generate flashcards.")
        return _failure("Please log in.", None)
    if not chat_history:
        gr.Warning("Chat is empty, nothing to generate flashcards from.")
        return _failure("Chat is empty.", http_session_token)
    backend_history = _convert_chatbot_history_to_backend_format(chat_history)
    fmt = (file_format or "PDF").lower()
    url = f"{request.url.scheme}://localhost:7860/generate_flashcards"
    headers = {"Authorization": f"Bearer {http_session_token}"}
    try:
        # Generation can take a while; the timeout is generous but bounded.
        r = requests.post(url, json={"history": backend_history, "format": fmt}, headers=headers, timeout=300)
        r.raise_for_status()
        result = r.json()
    except requests.exceptions.HTTPError as e:
        if e.response.status_code == 401:
            gr.Warning("Your session has expired. Please log in again.")
            return _failure("Session expired. Please log in.", None)
        return _failure(f"Error generating flashcards: {e}", http_session_token)
    except requests.RequestException as e:
        return _failure(f"Error generating flashcards: {e}", http_session_token)
    if result.get("success"):
        # PDF and HTML results are presented identically: a downloadable file.
        return (
            gr.update(value=result["file_path"], visible=True),
            gr.update(value="Flashcards ready!", visible=True),
            http_session_token,
        )
    return _failure(f"Error: {result.get('message', 'Unknown error during flashcard generation.')}", http_session_token)
# NOTE: http_session_token holds the JWT token string (or None)
def on_load(http_session_token: Optional[str], request: gr.Request):
    """Restore the UI on page load from the token held in Gradio state.

    Without a token, or when ``/check_login_status`` does not report
    ``logged_in``, the logged-out layout is returned and the token state is
    cleared. Otherwise the chat UI is shown with the current history and the
    conversation dropdown filled.

    Returns a 7-tuple: auth UI, chat UI, chatbot history, current conversation
    id, conversation dropdown, about image, token state.
    """
    base_internal_url = f"{request.url.scheme}://localhost:7860"
    logged_out_ui_updates = (
        gr.update(visible=True),  # Auth UI visible
        gr.update(visible=False),  # Chat UI hidden
        [],  # Chatbot empty
        None,  # Current conversation ID
        gr.update(choices=[], value=None),  # Conversation dropdown empty
        gr.update(visible=True),  # About image visible
        None,  # Clear token state
    )
    if not http_session_token:
        print("on_load: No token found in Gradio state. User is logged out.")
        return logged_out_ui_updates
    headers = {"Authorization": f"Bearer {http_session_token}"}
    try:
        # Bounded waits: a stuck backend must not freeze page load indefinitely.
        r = requests.get(f"{base_internal_url}/check_login_status", headers=headers, timeout=30)
        print(f"on_load: check_login_status response: {r.status_code}, {r.text}")
        if not (r.ok and r.json().get("logged_in")):
            # Token invalid/expired, or the backend says not logged in.
            print("on_load: Backend reports not logged in or token invalid. Clearing state.")
            return logged_out_ui_updates
        chat_data = requests.get(f"{base_internal_url}/get_current_chat_history", headers=headers, timeout=30).json()
        conv_data = requests.get(f"{base_internal_url}/get_conversations", headers=headers, timeout=30).json()
        conv_choices = [("🗁 New Chat", "EMPTY_CONVO")] + [(c['title'], c['id']) for c in conv_data.get('conversations', [])]
        selected_value = chat_data.get("current_conversation_id")
        dropdown_selected_value = selected_value if any(c_id == selected_value for _, c_id in conv_choices) else "EMPTY_CONVO"
        return (
            gr.update(visible=False),  # Auth UI hidden
            gr.update(visible=True),  # Chat UI visible
            _format_history_for_chatbot(chat_data.get("history", [])),
            selected_value,
            gr.update(choices=conv_choices, value=dropdown_selected_value),
            gr.update(visible=False),  # About image hidden
            http_session_token,  # Keep the token in state
        )
    except requests.ConnectionError:
        gr.Warning("Could not connect to the backend.")
        return logged_out_ui_updates
    except Exception as e:
        # Top-level boundary of the page-load handler: report the error and fall back to logged-out.
        print(f"Error on load: {e}")
        gr.Warning(f"An error occurred during loading: {e}")
        return logged_out_ui_updates
def show_generating_summary():
    """Reveal the "Generating Summary" progress notice."""
    reveal = gr.update(visible=True)
    return reveal


def hide_generating_summary():
    """Hide the "Generating Summary" progress notice again."""
    conceal = gr.update(visible=False)
    return conceal


def show_generating_flashcards():
    """Reveal the "Generating flashcards" progress notice."""
    reveal = gr.update(visible=True)
    return reveal


def hide_generating_flashcards():
    """Hide the "Generating flashcards" progress notice again."""
    conceal = gr.update(visible=False)
    return conceal
# Optional stylesheet shared by the Blocks below; a missing file means no custom CSS.
# EAFP avoids the exists()/read race, and the explicit encoding keeps the read
# independent of the platform's locale default.
try:
    custom_css = Path("style.css").read_text(encoding="utf-8")
except FileNotFoundError:
    custom_css = ""
# Flashcard export controls, built as a separate Blocks so they can be embedded
# in the main layout via flashcard_ui.render() (inside the "Export Tools" accordion).
# Component creation order below is the render order.
with gr.Blocks(theme=None, elem_id="flashcard_block", css=custom_css) as flashcard_ui:
    # Output format; generate_flashcards sends this value lower-cased to the backend.
    flashcard_format = gr.Radio(["PDF", "HTML (Interactive)"], show_label=False, value="PDF")
    # NOTE(review): "submit_buttons" is reused as elem_id by several buttons; ids should be unique.
    flashcard_btn = gr.Button("Generate Flashcards", elem_id="submit_buttons")
    # Progress notice, presumably toggled by show/hide_generating_flashcards — confirm in the event wiring.
    generating_flashcards_msg = gr.Markdown("Generating flashcards, please wait...", visible=False)
    # Download slot, hidden until a flashcard file is ready.
    flashcard_file = gr.File(label="Download Flashcards", visible=False, interactive=False)
    flashcard_output = gr.Markdown()
| with gr.Blocks(theme=gr.themes.Soft(), title="Query Foxy Chat", css=custom_css) as demo: | |
| # Initialize gr.State to store the JWT token string (Optional[str]) | |
| http_session = gr.State(None) | |
| # NEW: HTML component to inject JavaScript for setting/clearing local storage | |
| # js_executor = gr.HTML(value="", visible=False) # Start with empty value and hidden | |
| # Keep http_session and local_storage_token_input as they are. | |
| # MODIFIED: js_executor now holds global JS functions for setting/clearing localStorage | |
| js_executor = gr.HTML( | |
| value=""" | |
| <script> | |
| // Define a global function to set the token in localStorage | |
| function setTokenInLocalStorage(token) { | |
| console.log('JS function: Attempting to set token in localStorage...'); | |
| localStorage.setItem('gradio_jwt_token', token); | |
| console.log('JS function: Token set successfully (or attempted).'); | |
| } | |
| // Define a global function to clear the token from localStorage | |
| function clearTokenFromLocalStorage() { | |
| console.log('JS function: Attempting to clear token from localStorage...'); | |
| localStorage.removeItem('gradio_jwt_token'); | |
| console.log('JS function: Token cleared successfully (or attempted).'); | |
| } | |
| </script> | |
| """, | |
| visible=False # Keep it hidden | |
| ) | |
| # # New: Component to trigger JavaScript for setting/clearing local storage | |
| # token_js_trigger = gr.JSON( | |
| # value={"action": "none", "token": None}, | |
| # visible=False, # This component is hidden | |
| # _js=r""" | |
| # (value) => { | |
| # if (value && value.action === 'set' && value.token) { | |
| # localStorage.setItem('gradio_jwt_token', value.token); | |
| # console.log('Token set in localStorage:', value.token.substring(0, 20) + '...'); | |
| # } else if (value && value.action === 'clear') { | |
| # localStorage.removeItem('gradio_jwt_token'); | |
| # console.log('Token cleared from localStorage'); | |
| # } | |
| # return value; // Return the value unchanged | |
| # } | |
| # """ | |
| # ) | |
| # New: Hidden Textbox to receive token from localStorage on load | |
| local_storage_token_input = gr.Textbox(visible=False) | |
| gr.Image( | |
| value="assets/logo.png", | |
| elem_id="logo-img", | |
| show_label=False, | |
| interactive=False, | |
| show_download_button=False, | |
| width=250, | |
| height=105 | |
| ) | |
| with gr.Column(visible=False) as chat_ui: | |
| with gr.Row(): | |
| with gr.Column(scale=1, elem_id="chatbot-cont"): | |
| gr.Markdown("### Past Chats:", elem_classes="pastconvos-label") | |
| conversation_dd = gr.Dropdown( | |
| show_label=False, | |
| interactive=True, | |
| allow_custom_value=False, | |
| ) | |
| new_chat_btn = gr.Button("New Chat", elem_id="submit_buttons") | |
| logout_btn = gr.Button("Logout", elem_id="submit_buttons") | |
| with gr.Accordion("Export Tools", open=False): | |
| gr.Markdown("#### Generate Summary") | |
| summary_btn = gr.Button("Generate PDF Summary", elem_id="submit_buttons") | |
| generating_summary_msg = gr.Markdown("Generating Summary, please wait...", visible=False) | |
| summary_file = gr.File(label="Download Summary", visible=False, interactive=False) | |
| summary_output = gr.Markdown() | |
| gr.Markdown("---") | |
| gr.Markdown("#### Generate Flashcards") | |
| gr.Markdown("### Format", elem_classes="custom-label") | |
| flashcard_ui.render() | |
| with gr.Column(scale=3, elem_id="chatbot-cont"): | |
| chatbot = gr.Chatbot( | |
| label="Query Foxy", height=500, | |
| avatar_images=(None, "https://github.com/MahekTrivedi44/logo/blob/main/download%20(13).jpg?raw=true") | |
| ) | |
| msg_txt = gr.Textbox(label="Your Message", placeholder="Type here...", show_label=False, lines=1) | |
| send_btn = gr.Button("Send", elem_id="send_buttons") | |
| with gr.Row(): | |
| with gr.Column(scale=2, visible=True, elem_id="about_img_col") as about_img_col: | |
| gr.Image( | |
| value="assets/about.png", | |
| elem_id="about-img", | |
| show_label=False, | |
| interactive=False, | |
| show_download_button=False, | |
| ) | |
| with gr.Column(scale=1) as auth_ui: | |
| with gr.Column(visible=True, elem_id="auth_container"): | |
| with gr.Tabs(selected="Login", elem_id="auth_tabs") as auth_tabs: | |
| with gr.Tab("Log In", id="Login"): | |
| gr.Markdown("### Username", elem_classes="custom-label") | |
| login_user = gr.Textbox(show_label=False, placeholder="Enter your username", elem_id="login-user") | |
| gr.Markdown("### Password", elem_classes="custom-label") | |
| login_pass = gr.Textbox(type="password", show_label=False, placeholder="Enter your password", elem_id="login-pass") | |
| remember_chk = gr.Checkbox(label="I understand this session lasts 24 hrs unless I log out.") | |
| login_btn = gr.Button("Log In", elem_id="submit_buttons") | |
| with gr.Tab("Sign Up", id="SignUp"): | |
| gr.Markdown("### Username", elem_classes="custom-label") | |
| signup_user = gr.Textbox(show_label=False, placeholder="Create your username", elem_id="login-user") | |
| gr.Markdown("### Password", elem_classes="custom-label") | |
| signup_pass = gr.Textbox(type="password", show_label=False, placeholder="Create your password", elem_id="login-pass") | |
| signup_btn = gr.Button("Sign Up", elem_id="submit_buttons") | |
| signup_success_flag = gr.State(value=False) | |
| status_output = gr.Markdown() | |
| with gr.Column(elem_id="creator_info_container"): | |
| gr.Markdown("### Connect with the Creator", elem_classes="custom-label") | |
| with gr.Row(elem_id="creator_button_row"): | |
| gr.Button("LinkedIn", link="https://www.linkedin.com/in/mahek-devang-trivedi-511a1b29a/", elem_id="btn_linkedin") | |
| gr.Button("GitHub", link="https://github.com/MahekTrivedi44", elem_id="btn_github") | |
| gr.Button("Mail", link="mailto:mahektrivedi2006@gmail.com", elem_id="btn_mail") | |
| current_conversation_id_state = gr.State(None) | |
| def clear_status_on_tab_change(): | |
| return "" | |
| # @login_btn.click( | |
| # inputs=[login_user, login_pass, remember_chk, http_session], | |
| # outputs=[auth_ui, chat_ui, chatbot, current_conversation_id_state, conversation_dd, login_user, login_pass, remember_chk, about_img_col, status_output, http_session] | |
| # ) | |
| # def login_handler(username, password, remember_me, http_session_val, request: gr.Request): | |
| # return log_in(username, password, remember_me, http_session_val, request=request) | |
def login_handler(username, password, remember_me, http_session_val, request: gr.Request):
    """Forward the login form values to ``log_in``, passing the Gradio request through.

    NOTE(review): the ``login_btn.click`` binding for this handler is commented
    out just above it, and no other ``login_btn`` event binding appears in this
    part of the file. Confirm that the Log In button is actually connected.
    """
    outcome = log_in(username, password, remember_me, http_session_val, request=request)
    return outcome
| # logout_btn.click(log_out, inputs=[http_session], outputs=[auth_ui, chat_ui, chatbot, conversation_dd, about_img_col, http_session]) | |
| logout_btn.click( | |
| log_out, | |
| inputs=[http_session], | |
| outputs=[auth_ui, chat_ui, chatbot, conversation_dd, about_img_col, http_session, js_executor] # NEW: Add js_executor here | |
| ) | |
| signup_btn.click( | |
| fn=sign_up, | |
| inputs=[signup_user, signup_pass, http_session], # http_session is passed but not used for actual signup in new JWT flow | |
| outputs=[signup_user, signup_pass, status_output, signup_success_flag, http_session] | |
| ).then( | |
| fn=sleep_for_a_moment, | |
| inputs=None, | |
| outputs=None | |
| ).then( | |
| fn=switch_to_login_tab_if_success, | |
| inputs=signup_success_flag, | |
| outputs=[auth_tabs] | |
| ) | |
| msg_txt.submit(chat_with_bot, inputs=[msg_txt, chatbot, http_session], outputs=[msg_txt, chatbot, http_session]) | |
| send_btn.click(chat_with_bot, inputs=[msg_txt, chatbot, http_session], outputs=[msg_txt, chatbot, http_session]) | |
| new_chat_btn.click(start_new_conversation, inputs=[http_session], outputs=[chatbot, current_conversation_id_state, http_session]).then(fn=load_selected_conversation, inputs=[current_conversation_id_state, http_session], outputs=[chatbot, conversation_dd, http_session]) | |
| conversation_dd.change(load_selected_conversation, inputs=[conversation_dd, http_session], outputs=[chatbot, conversation_dd, http_session]) | |
| summary_btn.click( | |
| show_generating_summary, inputs=[], outputs=generating_summary_msg | |
| ).then( | |
| fn=generate_summary, inputs=[chatbot, http_session], outputs=[summary_file, summary_output, http_session] | |
| ).then( | |
| fn=hide_generating_summary, inputs=[], outputs=generating_summary_msg | |
| ) | |
| flashcard_btn.click( | |
| show_generating_flashcards, inputs=[], outputs=generating_flashcards_msg | |
| ).then( | |
| fn=generate_flashcards, inputs=[flashcard_format, chatbot, http_session], outputs=[flashcard_file, flashcard_output, http_session] | |
| ).then( | |
| fn=hide_generating_flashcards, inputs=[], outputs=generating_flashcards_msg | |
| ) | |
| # demo.load(on_load, inputs=[http_session], outputs=[auth_ui, chat_ui, chatbot, current_conversation_id_state, conversation_dd, about_img_col, http_session]) | |
| # demo.load( | |
| # # First, run a JS function to get token from localStorage and update local_storage_token_input | |
| # None, # No Python inputs for this initial JS call | |
| # outputs=[local_storage_token_input], # Output to the hidden textbox | |
| # _js="() => localStorage.getItem('gradio_jwt_token')" # The JavaScript to execute | |
| # ).then( | |
| # # Then, take the token from local_storage_token_input and update http_session via check_and_set_token | |
| # fn=check_and_set_token, | |
| # inputs=[local_storage_token_input], | |
| # outputs=[http_session] # Update the http_session state | |
| # ).then( | |
| # # Finally, run your original on_load function using the (now potentially populated) http_session | |
| # on_load, | |
| # inputs=[http_session], | |
| # outputs=[auth_ui, chat_ui, chatbot, current_conversation_id_state, conversation_dd, about_img_col, http_session] | |
| # ) | |
| # demo.load( | |
| # # First, run a JS function to get token from localStorage and update local_storage_token_input | |
| # # The _js for demo.load is expected to return a value that updates its first output. | |
| # None, # No Python inputs for this initial JS call | |
| # outputs=[local_storage_token_input], # Output to the hidden textbox | |
| # _js="() => localStorage.getItem('gradio_jwt_token')" # The JavaScript to execute on load | |
| # ).then( | |
| # # Then, take the token from local_storage_token_input and update http_session via check_and_set_token | |
| # fn=check_and_set_token, | |
| # inputs=[local_storage_token_input], | |
| # outputs=[http_session] # Update the http_session state | |
| # ).then( | |
| # # Finally, run your original on_load function using the (now potentially populated) http_session | |
| # on_load, | |
| # inputs=[http_session], | |
| # outputs=[auth_ui, chat_ui, chatbot, current_conversation_id_state, conversation_dd, about_img_col, http_session] | |
| # ) | |
| demo.load( | |
| None, # No Python inputs for this initial JS call | |
| outputs=[local_storage_token_input], # The JS return value populates this Textbox | |
| _js="() => { console.log('Attempting to get token from localStorage...'); const token = localStorage.getItem('gradio_jwt_token'); console.log('Token from localStorage:', token ? token.substring(0, 20) + '...' : 'None'); return token; }" | |
| ) | |
| # Step 2: When `local_storage_token_input`'s value changes (i.e., when it gets the token from JS), | |
| # pass that token to `check_and_set_token` to update `http_session`. | |
| local_storage_token_input.change( | |
| fn=check_and_set_token, | |
| inputs=[local_storage_token_input], | |
| outputs=[http_session] # `check_and_set_token` returns the token, which updates http_session | |
| ).then( | |
| # Step 3: Once `http_session` is updated, call the `on_load` function to set up the UI. | |
| fn=on_load, | |
| inputs=[http_session], | |
| outputs=[auth_ui, chat_ui, chatbot, current_conversation_id_state, conversation_dd, about_img_col, http_session] | |
| ) | |
| # =================================================================== | |
| # 2. BACKEND/API CODE (the Gradio UI is mounted onto this FastAPI app at the end of the file) | |
| # =================================================================== | |
# --- Pydantic Models for Request Validation ---
class UserCredentials(BaseModel):
    # Request body shared by the signup and login endpoints.
    username: str
    password: str
    # There is deliberately no "remember me" field: session length now comes
    # from the JWT expiry (ACCESS_TOKEN_EXPIRE_MINUTES).
class ChatMessage(BaseModel):
    # Request body for the chat endpoint: the user's new message text.
    message: str
class HistoryPayload(BaseModel):
    # Request body for the summary endpoint. The handler reads "message" and
    # "response" from each entry (one user/assistant exchange per dict).
    history: List[dict]
class FlashcardPayload(BaseModel):
    # Request body for the flashcard endpoint. Entries in ``history`` are
    # {"message": ..., "response": ...} dicts, like HistoryPayload's.
    # ``format`` selects the output: "pdf", "html" or "html (interactive)".
    history: List[dict]
    format: str
# --- Helper Functions (from main.py) ---
def get_or_create_default_conversation(user_id: int, db: sqlite3.Connection):
    """Insert a new "Chat <YYYY-MM-DD HH:MM>" conversation for *user_id*; return its id.

    Despite the name, this never looks up an existing conversation: every
    call inserts and commits a fresh row. The conversation id is not kept in
    a server session. The history/chat endpoints call this only after finding
    that the user has no conversation at all.
    """
    title = "Chat " + datetime.now().strftime('%Y-%m-%d %H:%M')
    inserted_id = db.execute(
        "INSERT INTO conversations (user_id, title) VALUES (?, ?)",
        (user_id, title),
    ).lastrowid
    db.commit()
    return inserted_id
async def delete_file_later(path: str, delay: int = 300):
    """Delete the file at *path* after sleeping *delay* seconds.

    Scheduled as a background task so generated downloads don't accumulate
    in the temp dir. A file that is already gone is not an error. Other OS
    errors (e.g. permissions) are logged rather than raised.
    """
    await asyncio.sleep(delay)
    # EAFP: removing directly avoids the exists()/remove() race and doesn't
    # report a concurrently-deleted file as an error.
    try:
        os.remove(path)
    except FileNotFoundError:
        return
    except OSError as e:
        print(f"Error deleting file {path}: {e}")
        return
    print(f"🧹 Deleted temp file: {path}")
# --- PDF Generation Classes and Helpers ---
def safe_multicell(pdf_obj, line):
    """Write *line* as a full-width multi-line cell without propagating write errors.

    Every character outside printable ASCII is stripped first, except ``\\n``
    and ``\\r``. Tabs, accents, bullets and emoji are therefore removed. The
    cell is 6 units high per line and spans the page width inside the left
    margin (applied to both sides).

    If the write fails, it is retried once with the text truncated to 200
    characters. If that also fails, the error is only logged. ``None`` is
    treated as an empty string and other non-str values are converted with
    ``str()``.
    """
    text = "" if line is None else str(line)
    # Computed before the try so the fallback path can always use them;
    # previously a failure here left them unbound inside the except block.
    cleaned = re.sub(r'[^\x20-\x7E\n\r]', '', text)
    page_width = pdf_obj.w - 2 * pdf_obj.l_margin
    try:
        pdf_obj.multi_cell(page_width, 6, cleaned)
    except Exception as e:
        print(f"⚠️ PDF error in safe_multicell: {e} for line: {text[:50]}...")
        truncated = cleaned[:200] + "..." if len(cleaned) > 200 else cleaned
        try:
            pdf_obj.multi_cell(page_width, 6, truncated)
        except Exception as e_fallback:
            print(f"⚠️ PDF fallback multi_cell also failed: {e_fallback}")
class CustomPDF(FPDF):
    """FPDF subclass with a page header, chapter helpers and a Unicode font.

    On construction it tries to register DejaVuSans (regular and bold) from
    the TTF files in the working directory. If loading fails, it falls back to
    the core Arial font. Every later font switch uses whichever family was
    actually registered, so it never asks for an undefined family.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        try:
            # fpdf2 always embeds TTF fonts as Unicode; the old ``uni=True``
            # flag is deprecated there and only emits a warning.
            self.add_font('DejaVuSans', '', 'DejaVuSans.ttf')
            self.add_font('DejaVuSans', 'B', 'DejaVuSans-Bold.ttf')
            self._text_family = 'DejaVuSans'
        except (OSError, RuntimeError):
            # fpdf2 reports a missing TTF as FileNotFoundError (an OSError);
            # legacy PyFPDF raised RuntimeError. Catching only RuntimeError
            # let the constructor crash whenever the font files were absent.
            print("Warning: DejaVu fonts not found. Falling back to Arial.")
            self._text_family = 'Arial'
        self.set_font(self._text_family, '', 10)

    def _use_font(self, style, size):
        """Switch to the registered text family at *style* / *size*."""
        self.set_font(self._text_family, style, size)

    def ensure_space(self, min_height=15):
        """Start a new page if there's not enough vertical space left."""
        if self.get_y() + min_height > self.page_break_trigger:
            self.add_page()

    def header(self):
        """Draw the bold 15pt title banner that FPDF calls on each new page."""
        self._use_font('B', 15)
        safe_multicell(self, "💖 Query Foxy Learning Material ")
        self.ln(10)

    def chapter_title(self, title):
        """Print *title* as a full-width, light-blue filled bar in bold 12pt."""
        from fpdf.enums import XPos, YPos
        self._use_font('B', 12)
        self.set_fill_color(200, 220, 255)
        if self._text_family != 'DejaVuSans':
            # Core fonts only cover latin-1. Substitute "?" for anything else
            # instead of letting cell() raise on e.g. an emoji in a topic name.
            title = title.encode('latin-1', 'replace').decode('latin-1')
        self.cell(0, 10, title, new_x=XPos.LMARGIN, new_y=YPos.NEXT, align='L', fill=True)
        self.ln(4)

    def chapter_body(self, body):
        """Print *body* in regular 10pt via safe_multicell, then add spacing."""
        self._use_font('', 10)
        safe_multicell(self, body)
        self.ln(6)
# --- HTML Flashcard Generation ---
def generate_flashcards_html(flashcards_text):
    """Render "Q:/A:" flashcard text as a standalone interactive HTML page.

    Expected input (the format the flashcard prompt requests)::

        === Topic ===
        Q: question
        A: answer

    Parsing rules:

    * ``Q:`` / ``A:`` prefixes are matched case-insensitively.
    * An answer pairs with the most recent unanswered question.
    * Non-blank lines directly after an ``A:`` line (e.g. a numbered list)
      continue that answer and are joined with ``<br>``. A blank line, a new
      ``Q:`` or a topic line closes the card.
    * An ``A:`` with no pending question, and any other stray text, is ignored.

    All question, answer and topic text is HTML-escaped.
    """
    cards = []                 # (question, answer_lines, topic) per card
    current_topic = ""
    pending_question = None    # question still waiting for its "A:" line
    open_card = None           # card whose answer may still get continuation lines
    for raw_line in flashcards_text.split('\n'):
        line = raw_line.strip()
        question_match = re.match(r'Q:\s*(.*)', line, re.IGNORECASE)
        answer_match = re.match(r'A:\s*(.*)', line, re.IGNORECASE)
        if not line:
            open_card = None
        elif line.startswith("=== ") and line.endswith(" ==="):
            current_topic = line.replace("===", "").strip()
            open_card = None
        elif question_match:
            pending_question = question_match.group(1).strip()
            open_card = None
        elif answer_match and pending_question is not None:
            open_card = (pending_question, [answer_match.group(1).strip()], current_topic)
            cards.append(open_card)
            pending_question = None
        elif open_card is not None and not answer_match:
            # Previously these lines were dropped, truncating list answers.
            open_card[1].append(line)

    cards_html = []
    for question, answer_lines, topic in cards:
        escaped_question = html.escape(question)
        escaped_answer = "<br>".join(html.escape(part) for part in answer_lines)
        escaped_topic = html.escape(topic)
        cards_html.append(f"""
<div class="flashcard-container">
<div class="flashcard" onclick="this.classList.toggle('flipped');">
<div class="flashcard-front"><p class="card-question">{escaped_question}</p></div>
<div class="flashcard-back"><p class="card-answer">{escaped_answer}</p></div>
</div>
<div class="flashcard-topic">{escaped_topic}</div>
</div>
""")
    return f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Query Foxy Flashcards</title>
<link href="https://fonts.googleapis.com/css2?family=Love+Ya+Like+A+Sister&display=swap" rel="stylesheet">
<style>
body {{ font-family: 'Arial', sans-serif; background-color: #FFFFFF; display: flex; flex-direction: column; align-items: center; padding: 20px; }}
h1 {{ color: #6c493b; font-family: 'Love Ya Like A Sister', cursive; }}
.logo-container {{ text-align: center; margin-bottom: 20px; }}
.logo-container img {{ max-width: 200px; height: auto; }}
.flashcards-grid {{ display: flex; flex-wrap: wrap; gap: 20px; justify-content: center; }}
.flashcard-container {{ perspective: 1000px; width: 300px; height: 200px; margin-bottom: 20px; }}
.flashcard {{ width: 100%; height: 100%; position: absolute; transform-style: preserve-3d; transition: transform 0.6s; border-radius: 15px; box-shadow: 0 4px 8px rgba(0,0,0,0.2); cursor: pointer; }}
.flashcard.flipped {{ transform: rotateY(180deg); }}
.flashcard-front, .flashcard-back {{ position: absolute; width: 100%; height: 100%; backface-visibility: hidden; display: flex; justify-content: center; align-items: center; padding: 15px; box-sizing: border-box; box-shadow: 0 0 15px 5px rgba(108, 73, 59, 0.5) !important; border-radius: 20px; text-align: center; font-family: 'Love Ya Like A Sister', cursive; font-size: 1.2em; color: #ffffff; }}
.flashcard-front {{ background-color: #b77a5a; }}
.flashcard-back {{ background-color: #6c493b; transform: rotateY(180deg); }}
.flashcard-topic {{ position: absolute; bottom: -25px; left: 0; right: 0; text-align: center; font-family: 'Love Ya Like A Sister', cursive; font-size: 0.9em; color: #6c493b; font-weight: bold; }}
</style>
</head>
<body>
<div class="logo-container">
<img src="https://github.com/MahekTrivedi44/logo/blob/main/logo.png?raw=true" alt="Query Foxy Logo">
</div>
<div class="flashcards-grid">{''.join(cards_html)}</div>
</body>
</html>
"""
# --- FastAPI App Initialization and Mounting ---
from fastapi.middleware.cors import CORSMiddleware
from gradio import mount_gradio_app

# The API application. The Gradio UI is mounted onto it further down.
# (Server-side sessions were removed; auth is via JWT bearer tokens.)
fastapi_app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive. Since auth uses bearer tokens rather than cookies, consider an
# explicit origin list (or allow_credentials=False) in production.
fastapi_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define the lifespan for the FastAPI app to initialize the database.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialise the database on startup and log on shutdown.

    Starlette enters ``router.lifespan_context(app)`` with ``async with``, and
    so does Gradio's mount wrapper. The function must therefore be an async
    context-manager factory. A bare async generator function fails there,
    because Starlette only auto-wraps those when lifespan is passed to the
    constructor, not when assigned directly as below.
    """
    print("FastAPI app starting up: Initializing database...")
    init_db()
    yield
    print("FastAPI app shutting down.")


# Apply the lifespan to the FastAPI app.
fastapi_app.router.lifespan_context = lifespan
| from fastapi.responses import JSONResponse, FileResponse, RedirectResponse | |
| # --- FastAPI Routes (Now using fastapi_app instead of demo.app) --- | |
async def signup(credentials: UserCredentials, db: sqlite3.Connection = Depends(get_db)):
    """Register a new account.

    ``create_user`` is synchronous and receives the DB connection, so it runs
    in the threadpool to keep the event loop free. On failure, responds 400
    with the helper's message.

    NOTE(review): no route decorator is visible above this handler; confirm
    how it is registered on ``fastapi_app``.
    """
    ok, reason = await run_in_threadpool(create_user, credentials.username, credentials.password, db)
    if ok:
        return JSONResponse(content={"success": True, "message": "Signup successful! You can now log in."})
    return JSONResponse(status_code=400, content={"success": False, "message": reason})
async def login(credentials: UserCredentials, db: sqlite3.Connection = Depends(get_db)):
    """Check credentials and issue a bearer JWT.

    The token's ``sub`` claim is the username. It expires after
    ACCESS_TOKEN_EXPIRE_MINUTES. Raises HTTP 401 when ``verify_user``
    returns a falsy id.
    """
    print("Received login request for user:", credentials.username)
    user_id = await run_in_threadpool(verify_user, credentials.username, credentials.password, db)
    if not user_id:
        raise HTTPException(status_code=401, detail="Invalid credentials.")
    token = create_access_token(
        data={"sub": str(credentials.username)},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    body = {
        "success": True,
        "message": "Login successful!",
        "access_token": token,
        "token_type": "bearer",
    }
    return JSONResponse(content=body)
async def logout():
    """Acknowledge a logout request.

    JWTs are stateless: the client simply discards its token and nothing is
    revoked server-side. A token blacklist would be added here.
    """
    body = {"success": True, "message": "Logged out successfully!"}
    return JSONResponse(content=body)
async def check_login_status(user_id: int = Depends(get_current_user_id)):
    """Report ``{"logged_in": True}``.

    The check is entirely the ``get_current_user_id`` dependency, which
    presumably rejects a missing or invalid token before this body runs.
    ``user_id`` itself is unused.
    """
    status_payload = {"logged_in": True}
    return JSONResponse(content=status_payload)
async def new_conversation(user_id: int = Depends(get_current_user_id), db: sqlite3.Connection = Depends(get_db)):
    """Create an empty "New Chat <YYYY-MM-DD HH:MM>" conversation for the user.

    Returns ``{"success": True, "conversation_id": <new id>}``.
    """
    title = "New Chat " + datetime.now().strftime('%Y-%m-%d %H:%M')

    def insert_conversation():
        # Blocking SQLite insert + commit, executed in the threadpool.
        new_id = db.execute(
            "INSERT INTO conversations (user_id, title) VALUES (?, ?)", (user_id, title)
        ).lastrowid
        db.commit()
        print(f"Backend /new_conversation: Created new conversation with ID {new_id} for user {user_id}")
        return new_id

    created_id = await run_in_threadpool(insert_conversation)
    return JSONResponse(content={"success": True, "conversation_id": created_id})
async def get_current_chat_history(user_id: int = Depends(get_current_user_id), db: sqlite3.Connection = Depends(get_db)):
    """Return the history of the user's most recent conversation.

    "Most recent" means the highest conversation id. If the user has none,
    a default conversation is created first. The response carries
    ``history`` as ``[message, response]`` pairs in send order, plus
    ``current_conversation_id``.
    """
    def get_history_for_user():
        row = db.execute(
            "SELECT id FROM conversations WHERE user_id = ? ORDER BY id DESC LIMIT 1",
            (user_id,),
        ).fetchone()
        conv_id = row["id"] if row else get_or_create_default_conversation(user_id, db)
        # rowid breaks ties between messages sharing a timestamp (which can
        # happen if timestamps are stored at coarse, e.g. per-second,
        # resolution). Otherwise SQLite may return them in any order.
        messages = db.execute(
            "SELECT message, response FROM messages WHERE conversation_id = ? "
            "ORDER BY timestamp ASC, rowid ASC",
            (conv_id,),
        ).fetchall()
        return {
            "success": True,
            "history": [[m["message"], m["response"]] for m in messages],
            "current_conversation_id": conv_id,
        }

    return JSONResponse(content=await run_in_threadpool(get_history_for_user))
def _strip_summary_markdown(text):
    """Unwrap the light markdown the model emits despite the plain-text prompt.

    Unwraps bold (``**x**``), underscore italics (``_x_``) and inline code
    (`` `x` ``), and drops leading whitespace on every line. Underscore
    italics only match when the underscores are not attached to word
    characters. Identifiers in code examples such as ``snake_case_name`` or
    ``__init__`` are left intact (the old pattern turned them into
    ``snakecasename`` / ``init``).
    """
    text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
    text = re.sub(r'(?<!\w)_([^_\n]+?)_(?!\w)', r'\1', text)
    text = re.sub(r'`(.*?)`', r'\1', text)
    return re.sub(r'^\s+', '', text, flags=re.MULTILINE)


def _render_summary_pdf(summary_text):
    """Lay out a cleaned summary as a PDF in the temp dir and return its path.

    ``=== Title ===`` lines become chapter titles; repeated titles are
    skipped. The three section labels are printed bold with any inline
    content after them.

    Duplicate body lines and back-to-back repeats of the same label are
    skipped *within a topic*. Both trackers reset at each new topic. The old
    document-wide de-duplication removed the "Explanation:" etc. headings
    from every topic after the first.
    """
    section_pattern = re.compile(r"(Explanation|Examples / Applications|Tips / Mnemonics)[::]?\s*(.*)", re.IGNORECASE)
    pdf = CustomPDF()
    pdf.add_page()
    seen_topics = set()
    seen_lines = set()
    last_section = None
    for raw_line in summary_text.split('\n'):
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith("=== ") and line.endswith(" ==="):
            if line in seen_topics:
                continue
            seen_topics.add(line)
            seen_lines = set()
            last_section = None
            pdf.ensure_space(20)
            pdf.chapter_title(line.replace("===", "").strip())
            pdf.ln(4)
            continue
        if line in seen_lines:
            continue
        seen_lines.add(line)
        match = section_pattern.match(line)
        if match:
            label = match.group(1).strip()
            content = match.group(2).strip()
            if label == last_section:
                continue
            last_section = label
            pdf.ensure_space(15)
            pdf.set_font('', 'B')
            safe_multicell(pdf, label + ":")
            pdf.set_font('', '')
            if content:
                safe_multicell(pdf, content)
            pdf.ln(3)
            continue
        # Markdown bullets become ASCII dashes: safe_multicell strips
        # non-ASCII, so a "•" bullet would vanish and leave a stray space.
        line = re.sub(r'^\*+\s*', '- ', line)
        line = re.sub(r'\s{2,}', ' ', line)
        pdf.set_font('', '')
        pdf.ensure_space(10)
        safe_multicell(pdf, line)
        pdf.ln(2)
    # Create the file only after layout succeeded and close our handle, so
    # fpdf writes to a path nobody else holds open. Remove it if writing fails.
    fd, file_path = tempfile.mkstemp(suffix=".pdf")
    os.close(fd)
    try:
        pdf.output(file_path)
    except Exception:
        os.remove(file_path)
        raise
    return file_path


async def summarize_chat(payload: HistoryPayload, request: Request, background_tasks: BackgroundTasks, user_id: int = Depends(get_current_user_id)):
    """Turn a chat history into a structured learning-report PDF.

    Each ``payload.history`` entry must have ``message`` and ``response``
    keys; otherwise the handler responds 400. The transcript is embedded in
    a single prompt to the model. The reply is de-markdowned and rendered
    to a temp PDF in the threadpool. The PDF is scheduled for deletion via
    ``delete_file_later``.

    Returns ``{"success": True, "file_path": <path>}``; PDF failures become
    HTTP 500. ``request`` and ``user_id`` are unused in the body; the
    ``user_id`` dependency presumably enforces authentication.
    """
    transcript_lines = []
    for entry in payload.history:
        if "message" not in entry or "response" not in entry:
            raise HTTPException(status_code=400, detail="Each history entry needs 'message' and 'response'.")
        transcript_lines.append(f"user: {entry['message']}")
        transcript_lines.append(f"assistant: {entry['response']}")
    summarize_prompt = (
        "You are an academic tutor and curriculum writer tasked with generating a detailed, structured learning report from the following conversation. "
        "Your objective is to extract all educational content, group it by topic, and provide an in-depth explanation of each topic as if teaching it to a student. "
        "Do not summarize the conversation or reference specific dialogue. Instead, reconstruct the content into a clear, well-organized report that fully explains each subject discussed. "
        "Include additional context, definitions, and examples where needed. Fill in any gaps where a concept was mentioned but not thoroughly explained. "
        "If practical examples, case studies, **code**, logic, syntax, functions, methods, pseudocode, scenarios, or analogies were discussed in the conversation, include them in the relevant sections. "
        "If such examples were not provided, **GENERATE appropriate examples**, illustrations, or simplified explanations to help reinforce understanding. These can be from real-world situations, sample problems, or thought experiments. "
        "Where helpful, include memory techniques, mnemonics, diagrams (as descriptions), or analogies to enhance understanding and retention.\n\n"
        "For formatting: "
        "Use plain text only, EXCEPT for the subheadings (Explanation, Examples / Applications, Tips / Mnemonics) which MUST be bolded as shown in the structure below. Do not use other markdown like asterisks (*), backticks (`), or other symbols for emphasis. "
        "For lists, use numbered bullets like '1.', '2.', '3.' instead of asterisks or dashes. "
        "Ignore small talk, greetings, or tool usage unless directly relevant to the learning content.\n\n"
        "Important: Structure the report, exactly as below, and ensure EVERY topic (including any introductory sections) contains ALL three subsections. If content is not directly available from the conversation for 'Examples / Applications' or 'Tips / Mnemonics', you MUST generate relevant content for those sections:\n\n"
        "=== [Topic Title] ===\n"
        "**Explanation:**\nFull teaching-style explanation here.\n\n"
        "**Examples / Applications:**\nReal-world or code examples (if relevant). If no direct examples from the conversation, generate new ones.\n\n"
        "**Tips / Mnemonics:**\nUseful memory aids or tricks. If no direct tips/mnemonics from the conversation, generate new ones.\n\n"
        "Conversation:\n" +
        "\n".join(transcript_lines)
    )
    # The transcript is already inside the prompt, so only the prompt is sent.
    # Also replaying the history as separate chat turns doubled the tokens
    # and could overflow the model's context on long chats.
    summary_text = await ask_groq([{"role": "user", "content": summarize_prompt}])
    summary_text = _strip_summary_markdown(summary_text)
    print(f"Generated summary:\n{summary_text}")
    try:
        # fpdf layout and file I/O are blocking; keep them off the event loop.
        file_path = await run_in_threadpool(_render_summary_pdf, summary_text)
    except Exception as e:
        print(f"Error creating summary file: {e}")
        raise HTTPException(status_code=500, detail=f"Error creating summary file: {e}") from e
    background_tasks.add_task(delete_file_later, file_path)
    return JSONResponse(content={"success": True, "file_path": file_path})
def _render_flashcards_pdf(flashcards_text):
    """Write flashcards as a PDF in the temp dir and return its path.

    ``=== Topic ===`` lines become chapter titles, ``Q:`` lines are bold and
    ``A:`` lines regular. Non-blank lines directly after an ``A:`` line (e.g.
    a numbered list) are kept as part of that answer. A blank line, a topic
    line or the next ``Q:`` ends it. Any other text is ignored.
    """
    pdf = CustomPDF()
    pdf.add_page()
    in_answer = False
    for raw_line in flashcards_text.split('\n'):
        line = raw_line.strip()
        if not line:
            in_answer = False
        elif line.startswith("=== ") and line.endswith(" ==="):
            in_answer = False
            pdf.chapter_title(line.replace("===", "").strip())
        elif line.startswith("Q:"):
            in_answer = False
            pdf.set_font('', 'B')
            safe_multicell(pdf, line)
            pdf.ln(2)
        elif line.startswith("A:") or in_answer:
            in_answer = True
            pdf.set_font('', '')
            safe_multicell(pdf, line)
            pdf.ln(2)
    fd, file_path = tempfile.mkstemp(suffix=".pdf")
    os.close(fd)  # fpdf writes by path; don't keep a second handle open.
    try:
        pdf.output(file_path)
    except Exception:
        os.remove(file_path)
        raise
    return file_path


def _write_flashcards_html(flashcards_text):
    """Write the interactive flashcard page to a temp ``.html`` file; return its path."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".html", mode="w", encoding="utf-8") as temp:
        temp.write(generate_flashcards_html(flashcards_text))
    return temp.name


async def generate_flashcards(payload: FlashcardPayload, request: Request, background_tasks: BackgroundTasks, user_id: int = Depends(get_current_user_id)):
    """Generate study flashcards from a chat history as a PDF or HTML file.

    ``payload.format`` is case-insensitive: "pdf", "html" or
    "html (interactive)". Any other value is rejected with 400 *before* the
    model is called. Previously the 400 was raised inside the ``try`` and
    re-wrapped as a 500. History entries must have ``message`` and
    ``response`` keys (400 otherwise).

    The file is rendered in the threadpool and scheduled for deletion via
    ``delete_file_later``. Returns ``{"success": True, "file_path": <path>}``.
    ``request`` and ``user_id`` are unused in the body; the ``user_id``
    dependency presumably enforces authentication.
    """
    requested_format = payload.format.strip().lower()
    if requested_format == "pdf":
        render = _render_flashcards_pdf
    elif requested_format in ("html", "html (interactive)"):
        render = _write_flashcards_html
    else:
        raise HTTPException(status_code=400, detail="Unsupported file format.")

    transcript_lines = []
    for entry in payload.history:
        if "message" not in entry or "response" not in entry:
            raise HTTPException(status_code=400, detail="Each history entry needs 'message' and 'response'.")
        transcript_lines.append(f"user: {entry['message']}")
        transcript_lines.append(f"assistant: {entry['response']}")
    flashcard_prompt = (
        "You are an instructional designer and subject matter expert. Your task is to generate high-quality educational flashcards from the following conversation. "
        "Ignore greetings, social chat, and tool-related comments. Focus solely on extracting learning content from the conversation, even if it spans multiple topics. "
        "Group flashcards by topic, and ensure each card tests important concepts, definitions, processes, or problem-solving methods discussed. "
        "Where relevant, include flashcards for concepts that were only briefly mentioned or implied but are necessary for complete understanding. "
        "\n\nFlashcards must include a **mix** of question types depending on the subject and content:\n"
        "- Conceptual: definitions, distinctions, 'what' and 'why'\n"
        "- Applied: case studies, real-world examples, diagnosis-based, analysis questions\n"
        "- Practical: code snippets, pseudo-scenarios, data interpretation, step-by-step problems\n"
        "- Process-oriented: questions about sequences, protocols, workflows\n"
        "- Mnemonics & memory hacks: where helpful, embed memory aids or analogies\n"
        "Use simple yet precise language for both questions and answers. Provide mnemonics, analogies, or real-world examples where they can enhance understanding or retention. "
        "Use plain text only. Do not use markdown (e.g., no **bold**, *, or backticks). "
        "For lists, use numbered bullets like '1.', '2.', '3.' instead of asterisks or dashes. "
        "Do not reference specific user messages — focus on converting the knowledge into effective active recall material.\n\n"
        "Format the output as follows:\n\n"
        "=== [Topic Name] ===\n"
        "Q: ...\n"
        "A: ... [Answer in no more than 190 characters total. If the full explanation is longer, split it into multiple Q&A pairs to keep each answer within the limit.]\n\n"
        "Conversation:\n" +
        "\n".join(transcript_lines)
    )
    # The transcript is already inside the prompt; sending the history as
    # separate turns as well only doubled the token count.
    flashcards_text = await ask_groq([{"role": "user", "content": flashcard_prompt}])
    try:
        # PDF layout / file writing are blocking; keep them off the event loop.
        file_path = await run_in_threadpool(render, flashcards_text)
    except Exception as e:
        print(f"Error creating flashcard file: {e}")
        raise HTTPException(status_code=500, detail=f"Error creating flashcard file: {e}") from e
    background_tasks.add_task(delete_file_later, file_path)
    return JSONResponse(content={"success": True, "file_path": file_path})
| async def serve_file(filename: str): | |
| file_path = os.path.join(tempfile.gettempdir(), filename) | |
| if os.path.exists(file_path): | |
| media_type = "application/octet-stream" | |
| if filename.endswith(".pdf"): | |
| media_type = "application/pdf" | |
| elif filename.endswith(".html"): | |
| media_type = "text/html" | |
| return FileResponse( | |
| path=file_path, | |
| filename=filename, | |
| media_type=media_type | |
| ) | |
| raise HTTPException(status_code=404, detail="File not found") | |
async def get_conversations(user_id: int = Depends(get_current_user_id), db: sqlite3.Connection = Depends(get_db)):
    """List the user's non-empty conversations, most recently active first.

    Each item is ``{"id", "title"}``. When the conversation's first message
    differs from the stored title, the title gets the suffix
    ``" - <first message>"`` (first 30 chars, plus "..." if longer).

    Ties are broken deterministically:
    * equal latest-activity timestamps put the newer conversation id first;
    * equal first-message timestamps pick the earlier-inserted message.
    """
    def fetch_conversations():
        rows = db.execute(
            """
            SELECT c.id, c.title,
                   (SELECT message FROM messages
                     WHERE conversation_id = c.id
                     ORDER BY timestamp ASC, rowid ASC LIMIT 1) AS preview
            FROM conversations c
            WHERE c.user_id = ?
              AND EXISTS (SELECT 1 FROM messages WHERE conversation_id = c.id)
            ORDER BY (SELECT MAX(timestamp) FROM messages WHERE conversation_id = c.id) DESC,
                     c.id DESC
            """,
            (user_id,),
        ).fetchall()
        conv_list = []
        for row in rows:
            title = row["title"]
            preview = row["preview"]
            if preview and preview != title:
                snippet = preview[:30] + "..." if len(preview) > 30 else preview
                title = f"{title} - {snippet}"
            conv_list.append({"id": row["id"], "title": title})
        return {"success": True, "conversations": conv_list}

    return JSONResponse(content=await run_in_threadpool(fetch_conversations))
async def load_conversation(conversation_id: int, user_id: int = Depends(get_current_user_id), db: sqlite3.Connection = Depends(get_db)):
    """Return one conversation's history if it belongs to the caller.

    Responds 404 when the conversation doesn't exist or is owned by someone
    else. Otherwise returns ``history`` as ``[message, response]`` pairs in
    send order, plus ``conversation_id``.
    """
    def load():
        owned = db.execute(
            "SELECT id FROM conversations WHERE id = ? AND user_id = ?",
            (conversation_id, user_id),
        ).fetchone()
        if not owned:
            raise HTTPException(status_code=404, detail="Conversation not found.")
        # rowid keeps messages that share a timestamp in insertion order.
        rows = db.execute(
            "SELECT message, response FROM messages WHERE conversation_id = ? "
            "ORDER BY timestamp ASC, rowid ASC",
            (conversation_id,),
        ).fetchall()
        return {
            "success": True,
            "history": [[r["message"], r["response"]] for r in rows],
            "conversation_id": conversation_id,
        }

    return JSONResponse(content=await run_in_threadpool(load))
async def chat(chat_message: ChatMessage, user_id: int = Depends(get_current_user_id), db: sqlite3.Connection = Depends(get_db), conversation_id: Optional[int] = None):
    """Send one user message to the model and store the exchange.

    Which conversation receives the message:
    * ``conversation_id`` (optional query parameter), if given. It must
      belong to the caller, else 404.
    * Otherwise the caller's most recent conversation (highest id), created
      if the user has none.

    The model sees that conversation's prior exchanges in send order,
    followed by the new message. Empty or whitespace-only messages are
    rejected with 400. Returns ``{"success": True, "reply": <model reply>}``.
    """
    user_msg = chat_message.message
    if not user_msg or not user_msg.strip():
        raise HTTPException(status_code=400, detail="Empty message.")

    def resolve_conversation_and_history():
        # One threadpool hop for both the conversation lookup and its history.
        if conversation_id is not None:
            row = db.execute(
                "SELECT id FROM conversations WHERE id = ? AND user_id = ?",
                (conversation_id, user_id),
            ).fetchone()
            if not row:
                raise HTTPException(status_code=404, detail="Conversation not found.")
            conv_id = row["id"]
        else:
            row = db.execute(
                "SELECT id FROM conversations WHERE user_id = ? ORDER BY id DESC LIMIT 1",
                (user_id,),
            ).fetchone()
            conv_id = row["id"] if row else get_or_create_default_conversation(user_id, db)
        past = db.execute(
            "SELECT message, response FROM messages WHERE conversation_id = ? "
            "ORDER BY timestamp ASC, rowid ASC",
            (conv_id,),
        ).fetchall()
        turns = []
        for exchange in past:
            turns.append({"role": "user", "content": exchange["message"]})
            turns.append({"role": "assistant", "content": exchange["response"]})
        turns.append({"role": "user", "content": user_msg})
        return conv_id, turns

    conv_id, messages_for_groq = await run_in_threadpool(resolve_conversation_and_history)
    reply = await ask_groq(messages_for_groq)

    def save():
        db.execute(
            "INSERT INTO messages (conversation_id, user_id, message, response) VALUES (?, ?, ?, ?)",
            (conv_id, user_id, user_msg, reply),
        )
        db.commit()

    await run_in_threadpool(save)
    return JSONResponse(content={"success": True, "reply": reply})
# ===================================================================
# 3. LAUNCH THE APP
# ===================================================================
from gradio import mount_gradio_app

# Mount the Gradio UI onto the FastAPI app at the site root. This runs at
# import time, so an external ASGI server serving fastapi_app gets the UI too.
mount_gradio_app(fastapi_app, demo, path="/")

if __name__ == "__main__":
    # Bind on all interfaces; the port comes from $PORT (default 7860).
    port = int(os.environ.get("PORT", 7860))
    host = "0.0.0.0"
    print(f"Gradio app is running at: http://{host}:{port}/")
    uvicorn.run(fastapi_app, host=host, port=port, reload=False)