Spaces:
Runtime error
Runtime error
| package main | |
import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"unicode/utf8"

	"github.com/gin-gonic/gin"
	"github.com/jackc/pgx/v5"
	got "github.com/joho/godotenv"
	"github.com/ledongthuc/pdf"
	"github.com/ollama/ollama/api"
	"github.com/pgvector/pgvector-go"
)
| // @title RAGTAG API | |
| // @version 1.0 | |
| // @description This is the API for the RAGTAG system. | |
| // @termsOfService http://swagger.io/terms/ | |
| // @contact.name James Campbell | |
| // @contact.email james@enigmalabs.com | |
| // @license.name Apache 2.0 | |
| // @license.url http://www.apache.org/licenses/LICENSE-2.0.html | |
| // @host localhost:8080 | |
| // @BasePath / | |
| type Session struct { | |
| Messages []api.Message | |
| TitleFilter string | |
| } | |
| var sessions = make(map[string]*Session) | |
| func generateEmbedding(input string) ([]float32, error) { | |
| ollamaHost := os.Getenv("OLLAMA_HOST") | |
| if ollamaHost == "" { | |
| ollamaHost = "localhost" // fallback to localhost if not set | |
| } | |
| ollamaURL, err := url.Parse(fmt.Sprintf("http://%s:11434", ollamaHost)) | |
| if err != nil { | |
| return nil, err | |
| } | |
| client := api.NewClient(ollamaURL, http.DefaultClient) | |
| // Create an embedding request | |
| req := &api.EmbedRequest{ | |
| Model: "llama3.1", // Ensure this is an embedding-capable model | |
| Input: input, | |
| } | |
| // Call the Embed function | |
| resp, err := client.Embed(context.Background(), req) | |
| if err != nil { | |
| return nil, err | |
| } | |
| return resp.Embeddings[0], nil | |
| } | |
| func insertItem(conn *pgx.Conn, title string, docText string, embedding []float32) error { | |
| // Combine title and docText for embedding | |
| combinedText := title + " " + docText | |
| _, err := conn.Exec(context.Background(), | |
| "INSERT INTO items (title, doc, embedding) VALUES ($1, $2, $3)", | |
| title, combinedText, pgvector.NewVector(embedding)) | |
| return err | |
| } | |
| func queryEmbeddings(conn *pgx.Conn, query string, session *Session, c *gin.Context) error { | |
| // Generate embedding for the query | |
| queryEmbedding, err := generateEmbedding(query) | |
| if err != nil { | |
| return err | |
| } | |
| // Prepare the SQL query | |
| sqlQuery := "SELECT doc, COALESCE(title, 'Untitled') FROM items" | |
| if session.TitleFilter != "" { | |
| sqlQuery += fmt.Sprintf(" WHERE title LIKE '%%%s%%'", session.TitleFilter) | |
| } | |
| sqlQuery += " ORDER BY embedding <-> $1 LIMIT 5" | |
| // Query the database for similar documents | |
| rows, err := conn.Query(context.Background(), sqlQuery, pgvector.NewVector(queryEmbedding)) | |
| if err != nil { | |
| return err | |
| } | |
| defer rows.Close() | |
| var docs []string | |
| var sources []string | |
| for rows.Next() { | |
| var doc, title string | |
| if err := rows.Scan(&doc, &title); err != nil { | |
| return err | |
| } | |
| docs = append(docs, doc) | |
| sources = append(sources, fmt.Sprintf("Source: %s", title)) | |
| } | |
| // Combine the retrieved documents | |
| contextText := strings.Join(docs, "\n\n") | |
| // Create a chat request | |
| ollamaHost := os.Getenv("OLLAMA_HOST") | |
| if ollamaHost == "" { | |
| ollamaHost = "localhost" // fallback to localhost if not set | |
| } | |
| ollamaURL, err := url.Parse(fmt.Sprintf("http://%s:11434", ollamaHost)) | |
| if err != nil { | |
| return err | |
| } | |
| client := api.NewClient(ollamaURL, http.DefaultClient) | |
| // Add the new query to the session | |
| session.Messages = append(session.Messages, api.Message{Role: "user", Content: query}) | |
| // Prepare the messages for the chat request | |
| messages := []api.Message{ | |
| {Role: "system", Content: "You are an assistant that answers questions based on the given context."}, | |
| {Role: "user", Content: "Here's the context:\n" + contextText}, | |
| } | |
| messages = append(messages, session.Messages...) | |
| req := &api.ChatRequest{ | |
| Model: "llama3.1", | |
| Messages: messages, | |
| Stream: new(bool), // Use new(bool) to create a pointer to a boolean | |
| } | |
| *req.Stream = true // Set the value to true | |
| // Call the Chat function with streaming | |
| err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error { | |
| // Send the raw content without any modifications | |
| if resp.Message.Content != "" { | |
| c.SSEvent("message", resp.Message.Content) | |
| c.Writer.Flush() // Ensure the content is sent immediately | |
| } | |
| return nil | |
| }) | |
| if err != nil { | |
| return err | |
| } | |
| // Add the AI response to the session | |
| session.Messages = append(session.Messages, api.Message{Role: "assistant", Content: "Response sent via streaming"}) | |
| return nil | |
| } | |
| func getDocuments(conn *pgx.Conn) ([]map[string]interface{}, error) { | |
| rows, err := conn.Query(context.Background(), "SELECT DISTINCT ON (SPLIT_PART(title, '_chunk_', 1)) SPLIT_PART(title, '_chunk_', 1) as title, COUNT(*) as count FROM items GROUP BY SPLIT_PART(title, '_chunk_', 1)") | |
| if err != nil { | |
| return nil, err | |
| } | |
| defer rows.Close() | |
| var documents []map[string]interface{} | |
| for rows.Next() { | |
| var title string | |
| var count int | |
| if err := rows.Scan(&title, &count); err != nil { | |
| return nil, err | |
| } | |
| documents = append(documents, map[string]interface{}{ | |
| "title": title, | |
| "count": count, | |
| }) | |
| } | |
| return documents, nil | |
| } | |
| func deleteDocument(conn *pgx.Conn, title string) error { | |
| _, err := conn.Exec(context.Background(), "DELETE FROM items WHERE title LIKE $1 || '%'", title) | |
| return err | |
| } | |
| func uploadDocument(c *gin.Context, conn *pgx.Conn) { | |
| title := c.PostForm("title") | |
| file, header, err := c.Request.FormFile("file") | |
| if err != nil { | |
| log.Printf("Error getting file: %v", err) | |
| c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| defer file.Close() | |
| // Create uploads directory if it doesn't exist | |
| uploadsDir := "uploads" | |
| if err := os.MkdirAll(uploadsDir, 0755); err != nil { | |
| log.Printf("Error creating uploads directory: %v", err) | |
| c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create uploads directory"}) | |
| return | |
| } | |
| filename := filepath.Join(uploadsDir, header.Filename) | |
| out, err := os.Create(filename) | |
| if err != nil { | |
| log.Printf("Error creating file: %v", err) | |
| c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| defer out.Close() | |
| _, err = io.Copy(out, file) | |
| if err != nil { | |
| log.Printf("Error copying file: %v", err) | |
| c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| // Debug: log file size and first 16 bytes | |
| stat, statErr := os.Stat(filename) | |
| if statErr == nil { | |
| log.Printf("Uploaded file size: %d bytes", stat.Size()) | |
| fcheck, ferr := os.Open(filename) | |
| if ferr == nil { | |
| buf := make([]byte, 16) | |
| n, _ := fcheck.Read(buf) | |
| log.Printf("First 16 bytes: % x", buf[:n]) | |
| fcheck.Close() | |
| } | |
| } | |
| var textContent string | |
| ext := strings.ToLower(filepath.Ext(filename)) | |
| if ext == ".jpg" || ext == ".jpeg" || ext == ".png" { | |
| // Generate image summary using the llava model | |
| summary, err := generateImageSummary(filename) | |
| if err != nil { | |
| log.Printf("Error generating image summary: %v", err) | |
| c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| textContent = summary | |
| } else if ext == ".pdf" { | |
| // Check PDF signature before parsing | |
| f, rErr := os.Open(filename) | |
| if rErr != nil { | |
| log.Printf("Error opening PDF: %v", rErr) | |
| c.JSON(http.StatusInternalServerError, gin.H{"error": rErr.Error()}) | |
| return | |
| } | |
| defer f.Close() | |
| buf := make([]byte, 5) | |
| _, err := f.Read(buf) | |
| if err != nil || string(buf) != "%PDF-" { | |
| log.Printf("Uploaded file is not a valid PDF (missing %PDF- header): %s", filename) | |
| c.JSON(http.StatusBadRequest, gin.H{"error": "Uploaded file is not a valid PDF (missing %PDF- header)"}) | |
| return | |
| } | |
| stat, _ := f.Stat() | |
| // Loosen EOF check: search last 1KB for %%EOF | |
| eofCheckSize := int64(1024) | |
| if stat.Size() < eofCheckSize { | |
| eofCheckSize = stat.Size() | |
| } | |
| endBuf := make([]byte, eofCheckSize) | |
| _, err = f.ReadAt(endBuf, stat.Size()-eofCheckSize) | |
| if err != nil || !strings.Contains(string(endBuf), "%%EOF") { | |
| log.Printf("Uploaded file is not a valid PDF (missing %%EOF): %s", filename) | |
| c.JSON(http.StatusBadRequest, gin.H{"error": "Uploaded file is not a valid PDF (missing %%EOF)"}) | |
| return | |
| } | |
| // Reset file pointer for pdf.NewReader | |
| f.Seek(0, 0) | |
| reader, pdfErr := pdf.NewReader(f, stat.Size()) | |
| if pdfErr != nil { | |
| log.Printf("Error reading PDF: %v", pdfErr) | |
| c.JSON(http.StatusBadRequest, gin.H{"error": "Uploaded file is not a valid PDF or is corrupted"}) | |
| return | |
| } | |
| var sb strings.Builder | |
| for i := 1; i <= reader.NumPage(); i++ { | |
| page := reader.Page(i) | |
| if page.V.IsNull() { | |
| continue | |
| } | |
| content, err := page.GetPlainText(nil) | |
| if err != nil { | |
| log.Printf("Error extracting text from page %d: %v", i, err) | |
| continue | |
| } | |
| sb.WriteString(content) | |
| } | |
| textContent = sb.String() | |
| log.Printf("Extracted text length: %d", len(textContent)) | |
| } else { | |
| content, err := os.ReadFile(filename) | |
| if err != nil { | |
| log.Printf("Error reading file: %v", err) | |
| c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| textContent = string(content) | |
| } | |
| // Remove null bytes (Postgres TEXT cannot contain 0x00) | |
| textContent = strings.ReplaceAll(textContent, "\x00", "") | |
| // Validate UTF-8 | |
| if !utf8.ValidString(textContent) { | |
| log.Printf("Invalid UTF-8 detected in document: %s", filename) | |
| c.JSON(http.StatusBadRequest, gin.H{"error": "Uploaded document is not valid UTF-8"}) | |
| return | |
| } | |
| // Generate embedding for the text content using llama3.1 | |
| embedding, err := generateEmbedding(textContent) | |
| if err != nil { | |
| log.Printf("Error generating embedding: %v", err) | |
| c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| // Insert the document into the database | |
| err = insertItem(conn, title, textContent, embedding) | |
| if err != nil { | |
| log.Printf("Error inserting item: %v", err) | |
| c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| c.JSON(http.StatusOK, gin.H{"message": "Document uploaded and processed successfully"}) | |
| } | |
// chunkText splits text into whitespace-separated words and groups them into
// chunks of at most chunkSize words each, joined by single spaces.
//
// It returns nil when text contains no words or when chunkSize is not
// positive. Without that guard a chunkSize of zero never advances the loop
// and a negative chunkSize produces an invalid slice range.
func chunkText(text string, chunkSize int) []string {
	if chunkSize <= 0 {
		return nil
	}
	words := strings.Fields(text)
	if len(words) == 0 {
		return nil
	}

	chunks := make([]string, 0, len(words)/chunkSize+1)
	for start := 0; start < len(words); start += chunkSize {
		// Compare against the remaining length instead of computing
		// start+chunkSize first, which could overflow for huge sizes.
		end := len(words)
		if chunkSize < end-start {
			end = start + chunkSize
		}
		chunks = append(chunks, strings.Join(words[start:end], " "))
	}
	return chunks
}
// generateImageSummary reads the image at imagePath and asks the "llava"
// model for a detailed description of it.
//
// The image is base64-encoded and POSTed as JSON to
// http://$OLLAMA_HOST:11434/api/generate (host defaults to "localhost") with
// streaming enabled. The "response" fragments of the streamed JSON objects
// are concatenated until an object reports "done" or the stream ends.
//
// It returns an error when the file cannot be read, the payload cannot be
// encoded, the request fails, the status is not 200, a streamed object cannot
// be decoded, or the accumulated description is empty. The description is
// also printed to standard output.
func generateImageSummary(imagePath string) (string, error) {
	raw, err := os.ReadFile(imagePath)
	if err != nil {
		return "", fmt.Errorf("failed to read image file: %w", err)
	}

	reqBody, err := json.Marshal(map[string]interface{}{
		"model":  "llava",
		"prompt": "Describe this image in detail:",
		"images": []string{base64.StdEncoding.EncodeToString(raw)},
		"stream": true,
	})
	if err != nil {
		return "", fmt.Errorf("failed to marshal JSON payload: %w", err)
	}

	host := os.Getenv("OLLAMA_HOST")
	if host == "" {
		host = "localhost"
	}
	endpoint := fmt.Sprintf("http://%s:11434/api/generate", host)

	resp, err := http.Post(endpoint, "application/json", bytes.NewBuffer(reqBody))
	if err != nil {
		return "", fmt.Errorf("failed to send POST request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		errBody, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("unexpected response status: %d, body: %s", resp.StatusCode, string(errBody))
	}

	// fragment mirrors the two fields read from each streamed JSON object.
	type fragment struct {
		Response string `json:"response"`
		Done     bool   `json:"done"`
	}

	var collected strings.Builder
	dec := json.NewDecoder(resp.Body)
	for finished := false; !finished; {
		var part fragment
		err := dec.Decode(&part)
		if err == io.EOF {
			// The stream ended without a final "done" object.
			break
		}
		if err != nil {
			return "", fmt.Errorf("failed to decode JSON response: %w", err)
		}
		collected.WriteString(part.Response)
		finished = part.Done
	}

	if collected.Len() == 0 {
		return "", fmt.Errorf("empty response from llava model")
	}
	text := collected.String()
	fmt.Println("The summary of the image is: ", text)
	return text, nil
}
| func main() { | |
| // Set up the database connection | |
| // load env variables | |
| got.Load() | |
| conn, err := pgx.Connect(context.Background(), os.Getenv("DB_URL")) | |
| if err != nil { | |
| log.Fatal("Unable to connect to database:", err) | |
| } | |
| defer conn.Close(context.Background()) | |
| // Set up the Gin router | |
| r := gin.Default() | |
| // Define the /add_document endpoint | |
| r.POST("/add_document", func(c *gin.Context) { | |
| var request struct { | |
| Title string `json:"title"` | |
| DocText string `json:"doc_text"` | |
| } | |
| if err := c.BindJSON(&request); err != nil { | |
| c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| // Generate the embedding | |
| embedding, err := generateEmbedding(request.DocText) | |
| if err != nil { | |
| c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| // Insert the document and its embedding into the items table | |
| err = insertItem(conn, request.Title, request.DocText, embedding) | |
| if err != nil { | |
| c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| c.JSON(http.StatusOK, gin.H{"message": "Document chunk embedded and stored successfully!"}) | |
| }) | |
| // Add the new /query endpoint | |
| r.POST("/query", func(c *gin.Context) { | |
| var request struct { | |
| Query string `json:"query"` | |
| SessionID string `json:"sessionId"` | |
| } | |
| if err := c.BindJSON(&request); err != nil { | |
| c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| session, ok := sessions[request.SessionID] | |
| if !ok { | |
| session = &Session{ | |
| Messages: []api.Message{ | |
| {Role: "system", Content: "You are an assistant that answers questions based on the given context."}, | |
| }, | |
| TitleFilter: "", | |
| } | |
| sessions[request.SessionID] = session | |
| } | |
| // Check for @title in the query | |
| if strings.Contains(request.Query, "@") { | |
| parts := strings.Split(request.Query, "@") | |
| if len(parts) > 1 { | |
| session.TitleFilter = strings.Split(parts[1], " ")[0] | |
| request.Query = strings.Replace(request.Query, "@"+session.TitleFilter, "", 1) | |
| } | |
| } | |
| c.Header("Content-Type", "text/event-stream") | |
| c.Header("Cache-Control", "no-cache") | |
| c.Header("Connection", "keep-alive") | |
| c.Header("Access-Control-Allow-Origin", "*") | |
| c.Header("Access-Control-Allow-Credentials", "true") | |
| c.Header("Access-Control-Allow-Headers", "Content-Type") | |
| c.Header("Access-Control-Allow-Methods", "POST") | |
| c.Header("encoding", "chunked") | |
| err := queryEmbeddings(conn, request.Query, session, c) | |
| if err != nil { | |
| c.SSEvent("error", err.Error()) | |
| } | |
| c.SSEvent("done", "") | |
| }) | |
| // Serve the index.html file | |
| r.GET("/", func(c *gin.Context) { | |
| c.File("index.html") | |
| }) | |
| // Serve the docmanager.html file | |
| r.GET("/docmanager", func(c *gin.Context) { | |
| c.File("docmanager.html") | |
| }) | |
| // Add a new endpoint to fetch documents | |
| r.GET("/documents", func(c *gin.Context) { | |
| documents, err := getDocuments(conn) | |
| if err != nil { | |
| c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| c.JSON(http.StatusOK, documents) | |
| }) | |
| // Add a new endpoint to delete documents | |
| r.POST("/delete_document", func(c *gin.Context) { | |
| var request struct { | |
| Title string `json:"title"` | |
| } | |
| if err := c.BindJSON(&request); err != nil { | |
| c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| err := deleteDocument(conn, request.Title) | |
| if err != nil { | |
| c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| c.JSON(http.StatusOK, gin.H{"message": "Document deleted successfully"}) | |
| }) | |
| // Add a new endpoint to upload documents | |
| r.POST("/upload_document", func(c *gin.Context) { | |
| uploadDocument(c, conn) | |
| }) | |
| // Add a new endpoint to clear the chat session | |
| r.POST("/clear_session", func(c *gin.Context) { | |
| var request struct { | |
| SessionID string `json:"sessionId"` | |
| } | |
| if err := c.BindJSON(&request); err != nil { | |
| c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| delete(sessions, request.SessionID) | |
| c.JSON(http.StatusOK, gin.H{"message": "Chat session cleared successfully"}) | |
| }) | |
| // Add a new endpoint to check if Twitter data exists | |
| r.GET("/check_data", func(c *gin.Context) { | |
| rows, err := conn.Query(context.Background(), "SELECT DISTINCT title FROM items WHERE title LIKE '%twitter%'") | |
| if err != nil { | |
| c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| defer rows.Close() | |
| var titles []string | |
| for rows.Next() { | |
| var title string | |
| if err := rows.Scan(&title); err != nil { | |
| c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| titles = append(titles, title) | |
| } | |
| c.JSON(http.StatusOK, gin.H{"twitter_titles": titles}) | |
| }) | |
| // Serve the describer.html file | |
| r.GET("/describer", func(c *gin.Context) { | |
| c.File("describer.html") | |
| }) | |
| // Handle image description | |
| r.POST("/describe_image", func(c *gin.Context) { | |
| file, _, err := c.Request.FormFile("file") | |
| if err != nil { | |
| c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| defer file.Close() | |
| // Create a temporary file to store the uploaded image | |
| tempFile, err := os.CreateTemp("", "uploaded-*.jpg") | |
| if err != nil { | |
| c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| defer os.Remove(tempFile.Name()) | |
| defer tempFile.Close() | |
| // Copy the uploaded file to the temporary file | |
| _, err = io.Copy(tempFile, file) | |
| if err != nil { | |
| c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | |
| return | |
| } | |
| c.Header("Content-Type", "text/event-stream") | |
| c.Header("Cache-Control", "no-cache") | |
| c.Header("Connection", "keep-alive") | |
| c.Header("Access-Control-Allow-Origin", "*") | |
| c.Header("Access-Control-Allow-Credentials", "true") | |
| c.Header("Access-Control-Allow-Headers", "Content-Type") | |
| c.Header("Access-Control-Allow-Methods", "POST") | |
| c.Header("encoding", "chunked") | |
| imageData, err := os.ReadFile(tempFile.Name()) | |
| if err != nil { | |
| c.SSEvent("error", err.Error()) | |
| return | |
| } | |
| base64Image := base64.StdEncoding.EncodeToString(imageData) | |
| payload := map[string]interface{}{ | |
| "model": "llava", | |
| "prompt": "Describe this image in detail:", | |
| "images": []string{base64Image}, | |
| "stream": true, | |
| } | |
| jsonPayload, err := json.Marshal(payload) | |
| if err != nil { | |
| c.SSEvent("error", err.Error()) | |
| return | |
| } | |
| ollamaHost := os.Getenv("OLLAMA_HOST") | |
| if ollamaHost == "" { | |
| ollamaHost = "localhost" | |
| } | |
| url := fmt.Sprintf("http://%s:11434/api/generate", ollamaHost) | |
| resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonPayload)) | |
| if err != nil { | |
| c.SSEvent("error", err.Error()) | |
| return | |
| } | |
| defer resp.Body.Close() | |
| if resp.StatusCode != http.StatusOK { | |
| body, _ := io.ReadAll(resp.Body) | |
| c.SSEvent("error", fmt.Sprintf("Unexpected response status: %d, body: %s", resp.StatusCode, string(body))) | |
| return | |
| } | |
| decoder := json.NewDecoder(resp.Body) | |
| for { | |
| var result struct { | |
| Response string `json:"response"` | |
| Done bool `json:"done"` | |
| } | |
| if err := decoder.Decode(&result); err != nil { | |
| if err == io.EOF { | |
| break | |
| } | |
| c.SSEvent("error", err.Error()) | |
| return | |
| } | |
| if result.Response != "" { | |
| c.SSEvent("message", result.Response) | |
| c.Writer.Flush() // Ensure the content is sent immediately | |
| } | |
| if result.Done { | |
| break | |
| } | |
| } | |
| c.SSEvent("done", "") | |
| }) | |
| // Run the Gin server | |
| r.Run(":8080") | |
| } | |