new-api / service /file_decoder.go
liuzhao521
Deploy New API v0.9.25+ (commit b47cf4ef) to HuggingFace Spaces
4674012
package service
import (
"bytes"
"encoding/base64"
"fmt"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"io"
"net/http"
"strings"
"github.com/QuantumNous/new-api/common"
"github.com/QuantumNous/new-api/constant"
"github.com/QuantumNous/new-api/logger"
"github.com/QuantumNous/new-api/types"
"github.com/gin-gonic/gin"
)
// GetFileTypeFromUrl 获取文件类型,返回 mime type, 例如 image/jpeg, image/png, image/gif, image/bmp, image/tiff, application/pdf
// 如果获取失败,返回 application/octet-stream
func GetFileTypeFromUrl(c *gin.Context, url string, reason ...string) (string, error) {
response, err := DoDownloadRequest(url, []string{"get_mime_type", strings.Join(reason, ", ")}...)
if err != nil {
common.SysLog(fmt.Sprintf("fail to get file type from url: %s, error: %s", url, err.Error()))
return "", err
}
defer response.Body.Close()
if response.StatusCode != 200 {
logger.LogError(c, fmt.Sprintf("failed to download file from %s, status code: %d", url, response.StatusCode))
return "", fmt.Errorf("failed to download file, status code: %d", response.StatusCode)
}
if headerType := strings.TrimSpace(response.Header.Get("Content-Type")); headerType != "" {
if i := strings.Index(headerType, ";"); i != -1 {
headerType = headerType[:i]
}
if headerType != "application/octet-stream" {
return headerType, nil
}
}
if cd := response.Header.Get("Content-Disposition"); cd != "" {
parts := strings.Split(cd, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(strings.ToLower(part), "filename=") {
name := strings.TrimSpace(strings.TrimPrefix(part, "filename="))
if len(name) > 2 && name[0] == '"' && name[len(name)-1] == '"' {
name = name[1 : len(name)-1]
}
if dot := strings.LastIndex(name, "."); dot != -1 && dot+1 < len(name) {
ext := strings.ToLower(name[dot+1:])
if ext != "" {
mt := GetMimeTypeByExtension(ext)
if mt != "application/octet-stream" {
return mt, nil
}
}
}
break
}
}
}
cleanedURL := url
if q := strings.Index(cleanedURL, "?"); q != -1 {
cleanedURL = cleanedURL[:q]
}
if slash := strings.LastIndex(cleanedURL, "/"); slash != -1 && slash+1 < len(cleanedURL) {
last := cleanedURL[slash+1:]
if dot := strings.LastIndex(last, "."); dot != -1 && dot+1 < len(last) {
ext := strings.ToLower(last[dot+1:])
if ext != "" {
mt := GetMimeTypeByExtension(ext)
if mt != "application/octet-stream" {
return mt, nil
}
}
}
}
var readData []byte
limits := []int{512, 8 * 1024, 24 * 1024, 64 * 1024}
for _, limit := range limits {
logger.LogDebug(c, fmt.Sprintf("Trying to read %d bytes to determine file type", limit))
if len(readData) < limit {
need := limit - len(readData)
tmp := make([]byte, need)
n, _ := io.ReadFull(response.Body, tmp)
if n > 0 {
readData = append(readData, tmp[:n]...)
}
}
if len(readData) == 0 {
continue
}
sniffed := http.DetectContentType(readData)
if sniffed != "" && sniffed != "application/octet-stream" {
return sniffed, nil
}
if _, format, err := image.DecodeConfig(bytes.NewReader(readData)); err == nil {
switch strings.ToLower(format) {
case "jpeg", "jpg":
return "image/jpeg", nil
case "png":
return "image/png", nil
case "gif":
return "image/gif", nil
case "bmp":
return "image/bmp", nil
case "tiff":
return "image/tiff", nil
default:
if format != "" {
return "image/" + strings.ToLower(format), nil
}
}
}
}
// Fallback
return "application/octet-stream", nil
}
func GetFileBase64FromUrl(c *gin.Context, url string, reason ...string) (*types.LocalFileData, error) {
contextKey := fmt.Sprintf("file_download_%s", common.GenerateHMAC(url))
// Check if the file has already been downloaded in this request
if cachedData, exists := c.Get(contextKey); exists {
if common.DebugEnabled {
logger.LogDebug(c, fmt.Sprintf("Using cached file data for URL: %s", url))
}
return cachedData.(*types.LocalFileData), nil
}
var maxFileSize = constant.MaxFileDownloadMB * 1024 * 1024
resp, err := DoDownloadRequest(url, reason...)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// Always use LimitReader to prevent oversized downloads
fileBytes, err := io.ReadAll(io.LimitReader(resp.Body, int64(maxFileSize+1)))
if err != nil {
return nil, err
}
// Check actual size after reading
if len(fileBytes) > maxFileSize {
return nil, fmt.Errorf("file size exceeds maximum allowed size: %dMB", constant.MaxFileDownloadMB)
}
// Convert to base64
base64Data := base64.StdEncoding.EncodeToString(fileBytes)
mimeType := resp.Header.Get("Content-Type")
if len(strings.Split(mimeType, ";")) > 1 {
// If Content-Type has parameters, take the first part
mimeType = strings.Split(mimeType, ";")[0]
}
if mimeType == "application/octet-stream" {
logger.LogDebug(c, fmt.Sprintf("MIME type is application/octet-stream for URL: %s", url))
// try to guess the MIME type from the url last segment
urlParts := strings.Split(url, "/")
if len(urlParts) > 0 {
lastSegment := urlParts[len(urlParts)-1]
if strings.Contains(lastSegment, ".") {
// Extract the file extension
filename := strings.Split(lastSegment, ".")
if len(filename) > 1 {
ext := strings.ToLower(filename[len(filename)-1])
// Guess MIME type based on file extension
mimeType = GetMimeTypeByExtension(ext)
}
}
} else {
// try to guess the MIME type from the file extension
fileName := resp.Header.Get("Content-Disposition")
if fileName != "" {
// Extract the filename from the Content-Disposition header
parts := strings.Split(fileName, ";")
for _, part := range parts {
if strings.HasPrefix(strings.TrimSpace(part), "filename=") {
fileName = strings.TrimSpace(strings.TrimPrefix(part, "filename="))
// Remove quotes if present
if len(fileName) > 2 && fileName[0] == '"' && fileName[len(fileName)-1] == '"' {
fileName = fileName[1 : len(fileName)-1]
}
// Guess MIME type based on file extension
if ext := strings.ToLower(strings.TrimPrefix(fileName, ".")); ext != "" {
mimeType = GetMimeTypeByExtension(ext)
}
break
}
}
}
}
}
data := &types.LocalFileData{
Base64Data: base64Data,
MimeType: mimeType,
Size: int64(len(fileBytes)),
}
// Store the file data in the context to avoid re-downloading
c.Set(contextKey, data)
return data, nil
}
func GetMimeTypeByExtension(ext string) string {
// Convert to lowercase for case-insensitive comparison
ext = strings.ToLower(ext)
switch ext {
// Text files
case "txt", "md", "markdown", "csv", "json", "xml", "html", "htm":
return "text/plain"
// Image files
case "jpg", "jpeg":
return "image/jpeg"
case "png":
return "image/png"
case "gif":
return "image/gif"
// Audio files
case "mp3":
return "audio/mp3"
case "wav":
return "audio/wav"
case "mpeg":
return "audio/mpeg"
// Video files
case "mp4":
return "video/mp4"
case "wmv":
return "video/wmv"
case "flv":
return "video/flv"
case "mov":
return "video/mov"
case "mpg":
return "video/mpg"
case "avi":
return "video/avi"
case "mpegps":
return "video/mpegps"
// Document files
case "pdf":
return "application/pdf"
default:
return "application/octet-stream" // Default for unknown types
}
}