// tgf/internal/utils/reader.go
// Chunked streaming reader for Telegram files.
// (chunk size changed to 1 MB in commit 32197b4)
package utils
import (
"context"
"fmt"
"io"
"time"
"github.com/celestix/gotgproto"
"github.com/gotd/td/tg"
"github.com/gotd/td/tgerr"
"go.uber.org/zap"
)
// telegramReader streams a byte range of a Telegram file as an
// io.ReadCloser, fetching it chunk by chunk with retry support.
type telegramReader struct {
	messageID     int                       // message owning the file; used to refresh an expired file reference
	ctx           context.Context           // cancels in-flight chunk downloads
	log           *zap.Logger               // named sub-logger for this reader
	client        *gotgproto.Client         // Telegram MTProto client
	location      tg.InputFileLocationClass // download location; replaced after a file-reference refresh
	start         int64                     // first byte of the requested range (inclusive)
	end           int64                     // last byte of the requested range (inclusive)
	next          func() ([]byte, error)    // yields the next trimmed chunk; empty slice = end of range
	buffer        []byte                    // current chunk being served by Read
	bytesread     int64                     // bytes delivered to callers so far
	chunkSize     int64                     // bytes requested per UploadGetFile call
	i             int64                     // read cursor within buffer
	contentLength int64                     // total bytes in the requested range
	maxRetries    int                       // max attempts per chunk before giving up
}
// Close implements io.Closer. The reader holds no resources of its own
// (the underlying client is owned by the caller), so this is a no-op.
func (r *telegramReader) Close() error {
	return nil
}
// NewTelegramReader builds an io.ReadCloser that streams the byte range
// [start, end] of the Telegram file at location. contentLength must be
// the number of bytes in that range; messageID is retained so an expired
// file reference can be refreshed mid-stream.
func NewTelegramReader(
	ctx context.Context,
	client *gotgproto.Client,
	location tg.InputFileLocationClass,
	start int64,
	end int64,
	contentLength int64,
	messageID int,
) (io.ReadCloser, error) {
	r := &telegramReader{
		messageID:     messageID,
		ctx:           ctx,
		log:           Logger.Named("telegramReader"),
		location:      location,
		client:        client,
		start:         start,
		end:           end,
		chunkSize:     int64(1024 * 1024), // 1 MB chunk size
		contentLength: contentLength,
		maxRetries:    5, // Allow up to 5 retries for any chunk failure
	}
	r.log.Sugar().Debug("Start")
	r.next = r.partStream()
	return r, nil
}
// Read implements io.Reader. It serves bytes out of the current buffer
// and pulls the next chunk via r.next() (which retries internally) once
// the buffer is exhausted.
func (r *telegramReader) Read(p []byte) (int, error) {
	// Every requested byte has been delivered — end of stream.
	if r.bytesread == r.contentLength {
		return 0, io.EOF
	}

	// Buffer drained: fetch the next chunk.
	if r.i >= int64(len(r.buffer)) {
		chunk, fetchErr := r.next()
		if fetchErr != nil {
			// next() already retried; at this point the error is fatal.
			r.log.Error("Failed to read next buffer after all retries", zap.Error(fetchErr))
			return 0, fetchErr
		}
		if len(chunk) == 0 {
			// An empty chunk is the stream's end-of-data signal.
			return 0, io.EOF
		}
		r.buffer = chunk
		r.i = 0
	}

	copied := copy(p, r.buffer[r.i:])
	r.i += int64(copied)
	r.bytesread += int64(copied)
	return copied, nil
}
// chunk downloads one chunk of the file starting at offset (limit bytes
// requested). It retries up to r.maxRetries times: a FILE_REFERENCE_EXPIRED
// error triggers an immediate reference refresh and retry, while any other
// error gets a linear backoff. The backoff and the retry loop both respect
// context cancellation, so an abandoned download stops promptly instead of
// sleeping through its remaining attempts.
func (r *telegramReader) chunk(offset int64, limit int64) ([]byte, error) {
	var lastErr error
	for attempt := 0; attempt < r.maxRetries; attempt++ {
		// Stop retrying if the caller has gone away.
		if err := r.ctx.Err(); err != nil {
			return nil, err
		}
		// Rebuild the request every iteration: r.location may have been
		// replaced by a file-reference refresh on a previous attempt.
		req := &tg.UploadGetFileRequest{
			Offset:   offset,
			Limit:    int(limit),
			Location: r.location,
		}
		res, err := r.client.API().UploadGetFile(r.ctx, req)
		// --- Success path ---
		if err == nil {
			switch result := res.(type) {
			case *tg.UploadFile:
				return result.Bytes, nil
			default:
				// Defensive: a successful call should always yield *tg.UploadFile.
				return nil, fmt.Errorf("unexpected success type %T", result)
			}
		}
		// --- Error handling & retry path ---
		lastErr = err // keep the most recent error in case we exhaust retries
		r.log.Warn("Failed to download chunk, will retry",
			zap.Int("attempt", attempt+1),
			zap.Int("max_retries", r.maxRetries),
			zap.Int64("offset", offset),
			zap.Error(err),
		)
		if tgerr.Is(err, "FILE_REFERENCE_EXPIRED") {
			r.log.Info("File reference expired. Attempting to refresh.", zap.Int("messageID", r.messageID))
			newFile, refreshErr := RefreshFileReference(r.ctx, r.client, r.messageID)
			if refreshErr != nil {
				// If the refresh itself fails the error is fatal: break and return.
				r.log.Error("Failed to refresh file reference, cannot recover chunk.", zap.Error(refreshErr))
				lastErr = fmt.Errorf("could not refresh file reference after it expired: %w", refreshErr)
				break
			}
			// Refresh succeeded — adopt the new location and retry immediately.
			r.location = newFile.Location
			r.log.Info("File reference refreshed successfully. Retrying download immediately.")
			continue
		}
		// Linear backoff before the next attempt, aborting early if the
		// context is cancelled while we wait.
		backoff := time.Duration(attempt+1) * 500 * time.Millisecond
		select {
		case <-r.ctx.Done():
			return nil, r.ctx.Err()
		case <-time.After(backoff):
		}
	}
	// Loop finished without returning: all retries are exhausted.
	r.log.Error("Exhausted all retries for chunk download", zap.Error(lastErr))
	return nil, fmt.Errorf("failed to download chunk at offset %d after %d retries: %w", offset, r.maxRetries, lastErr)
}
// partStream returns a closure that yields the file's bytes in
// chunk-sized parts, trimmed so their concatenation covers exactly the
// range [r.start, r.end]. Once the range is fully served the closure
// returns an empty slice, which Read treats as end-of-stream.
func (r *telegramReader) partStream() func() ([]byte, error) {
	start := r.start
	end := r.end
	// Align the first request down to a chunk boundary (Telegram requires
	// aligned offsets), then trim the excess off the first/last parts.
	offset := start - (start % r.chunkSize)
	firstPartCut := start - offset         // bytes to drop from the front of the first chunk
	lastPartCut := (end % r.chunkSize) + 1 // bytes to keep from the front of the last chunk
	partCount := int((end - offset + r.chunkSize) / r.chunkSize)
	currentPart := 1
	readData := func() ([]byte, error) {
		if currentPart > partCount {
			// Range fully served — empty slice signals EOF to Read.
			return make([]byte, 0), nil
		}
		res, err := r.chunk(offset, r.chunkSize)
		if err != nil {
			return nil, err
		}
		if len(res) == 0 {
			// Server returned no data; propagate as end-of-stream.
			return res, nil
		} else if partCount == 1 {
			// Whole range fits in one chunk: trim both ends.
			res = res[firstPartCut:lastPartCut]
		} else if currentPart == 1 {
			res = res[firstPartCut:]
		} else if currentPart == partCount {
			res = res[:lastPartCut]
		}
		// Log before advancing so the reported number matches the part just
		// produced (previously this logged after the increment, reporting
		// the NEXT part — e.g. "Part 2/1" for a single-part stream).
		r.log.Sugar().Debugf("Part %d/%d", currentPart, partCount)
		currentPart++
		offset += r.chunkSize
		return res, nil
	}
	return readData
}