diff --git a/.venv/lib/python3.13/site-packages/instructor-1.14.3.dist-info/licenses/LICENSE b/.venv/lib/python3.13/site-packages/instructor-1.14.3.dist-info/licenses/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f3325f8da4271c8e711369623d138881737177cf --- /dev/null +++ b/.venv/lib/python3.13/site-packages/instructor-1.14.3.dist-info/licenses/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Jason Liu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/.venv/lib/python3.13/site-packages/pythonjsonlogger/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/pythonjsonlogger/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb3108206bbaa6bcdc3926156830a49e49d58236 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/pythonjsonlogger/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/pythonjsonlogger/__pycache__/core.cpython-313.pyc b/.venv/lib/python3.13/site-packages/pythonjsonlogger/__pycache__/core.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc675503667cb36d7c882d0ebc53ae864d0fc342 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/pythonjsonlogger/__pycache__/core.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/pythonjsonlogger/__pycache__/defaults.cpython-313.pyc b/.venv/lib/python3.13/site-packages/pythonjsonlogger/__pycache__/defaults.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17a52ddf6ed4e2fba2a53fd301a471fa55388bfe Binary files /dev/null and b/.venv/lib/python3.13/site-packages/pythonjsonlogger/__pycache__/defaults.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/pythonjsonlogger/__pycache__/exception.cpython-313.pyc b/.venv/lib/python3.13/site-packages/pythonjsonlogger/__pycache__/exception.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5264d4e04a09a885915ab6ad025a67f93452cee8 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/pythonjsonlogger/__pycache__/exception.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/pythonjsonlogger/__pycache__/json.cpython-313.pyc b/.venv/lib/python3.13/site-packages/pythonjsonlogger/__pycache__/json.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca701d687a6abef2789ec86839d7ae89d51c2ece 
class DotBox(Widget):
    """A small labelled grid ("style box") with a single dot marking a cell.

    The grid has one column per entry in ``xlabels`` and one row per entry in
    ``ylabels``.  Row labels are drawn to the left of the grid, column labels
    are drawn rotated by 90 degrees above it, and a filled circle is placed at
    (``dotXPosition``, ``dotYPosition``) measured in grid divisions from the
    lower-left corner of the grid.
    """

    #Doesn't use TypedPropertyCollection for labels - this can be a later improvement
    _attrMap = AttrMap(
        xlabels = AttrMapValue(isNoneOrListOfNoneOrStrings,
                desc="List of text labels for the columns, drawn above the grid"),
        ylabels = AttrMapValue(isNoneOrListOfNoneOrStrings,
                desc="List of text labels for the rows, drawn on the left hand side"),
        labelFontName = AttrMapValue(isString,
                desc="Name of font used for the labels"),
        labelFontSize = AttrMapValue(isNumber,
                desc="Size of font used for the labels"),
        labelOffset = AttrMapValue(isNumber,
                desc="Space between label text and grid edge"),
        strokeWidth = AttrMapValue(isNumber,
                desc='Width of the grid and dot outline'),
        gridDivWidth = AttrMapValue(isNumber,
                desc="Width of each 'box'"),
        gridColor = AttrMapValue(isColor,
                desc='Colour for the box and gridding'),
        dotDiameter = AttrMapValue(isNumber,
                desc="Diameter of the circle used for the 'dot'"),
        dotColor = AttrMapValue(isColor,
                desc='Colour of the circle on the box'),
        dotXPosition = AttrMapValue(isNumber,
                desc='X Position of the circle'),
        dotYPosition = AttrMapValue(isNumber,
                desc='Y Position of the circle'),
        x = AttrMapValue(isNumber,
                desc='X Position of dotbox'),
        y = AttrMapValue(isNumber,
                desc='Y Position of dotbox'),
        )

    def __init__(self):
        """Set up a 3x3 example box with the dot at grid position (1, 1)."""
        self.xlabels=["Value", "Blend", "Growth"]
        self.ylabels=["Small", "Medium", "Large"]
        self.labelFontName = "Helvetica"
        self.labelFontSize = 6
        self.labelOffset = 5
        self.strokeWidth = 0.5
        self.gridDivWidth=0.5*cm
        # NOTE(review): `colors` is not imported by name in this module; it is
        # presumably supplied by one of the star imports - confirm.
        self.gridColor=colors.Color(25/255.0,77/255.0,135/255.0)
        self.dotDiameter=0.4*cm
        self.dotColor=colors.Color(232/255.0,224/255.0,119/255.0)
        self.dotXPosition = 1
        self.dotYPosition = 1
        self.x = 30
        self.y = 5

    def _getDrawingDimensions(self):
        """Return the (width, height) used by demo() to size its Drawing.

        Each extent is the grid size plus fixed padding, the label offset and
        the widest label lying along that direction.
        """
        leftPadding=rightPadding=topPadding=bottomPadding=5
        #grid width plus padding (and offset)
        tx=len(self.xlabels)*self.gridDivWidth+leftPadding+rightPadding+self.labelOffset
        #the row (y) labels are drawn horizontally to the left of the grid,
        #so it is their text width that extends the drawing horizontally
        tx=tx+_maxWidth(self.ylabels, self.labelFontName, self.labelFontSize)
        #grid height plus padding (and offset)
        ty=len(self.ylabels)*self.gridDivWidth+topPadding+bottomPadding+self.labelOffset
        #the column (x) labels are rotated by 90 degrees above the grid,
        #so it is their text width that extends the drawing vertically
        ty=ty+_maxWidth(self.xlabels, self.labelFontName, self.labelFontSize)
        return (tx,ty)

    def demo(self,drawing=None):
        """Add this widget to ``drawing`` (a new one is created if falsy) and return it."""
        if not drawing:
            tx,ty=self._getDrawingDimensions()
            drawing = Drawing(tx,ty)
        drawing.add(self.draw())
        return drawing

    def draw(self):
        """Return a Group containing the box, its gridding, the dot and the labels."""
        g = Group()

        # NOTE(review): the attribute map allows xlabels/ylabels to be None,
        # but len() is taken of both here, before the None checks further
        # down - confirm whether None is really meant to be supported.
        cell = self.gridDivWidth
        gridWidth = len(self.xlabels)*cell
        gridHeight = len(self.ylabels)*cell

        #box
        g.add(Rect(self.x,self.y,gridWidth,gridHeight,
            strokeColor=self.gridColor,
            strokeWidth=self.strokeWidth,
            fillColor=None))

        #internal gridding
        for f in range(1,len(self.ylabels)):
            #horizontal
            ypos = self.y+f*cell
            g.add(Line(strokeColor=self.gridColor,
                strokeWidth=self.strokeWidth,
                x1 = self.x,
                y1 = ypos,
                x2 = self.x+gridWidth,
                y2 = ypos))
        for f in range(1,len(self.xlabels)):
            #vertical
            xpos = self.x+f*cell
            g.add(Line(strokeColor=self.gridColor,
                strokeWidth=self.strokeWidth,
                x1 = xpos,
                y1 = self.y,
                x2 = xpos,
                y2 = self.y+gridHeight))

        # draw the 'dot'
        g.add(Circle(strokeColor=self.gridColor,
            strokeWidth=self.strokeWidth,
            fillColor=self.dotColor,
            cx = self.x+(self.dotXPosition*cell),
            cy = self.y+(self.dotYPosition*cell),
            r = self.dotDiameter/2.0))

        #used for centering the labels within their grid division (below)
        ascent=getFont(self.labelFontName).face.ascent
        if ascent==0:
            ascent=0.718 # default (from helvetica)
        ascent=ascent*self.labelFontSize # normalize

        #do y-labels (added from the last row down to the first)
        if self.ylabels is not None:
            for f in range(len(self.ylabels)-1,-1,-1):
                if self.ylabels[f] is not None:
                    g.add(String(strokeColor=self.gridColor,
                        text = self.ylabels[f],
                        fontName = self.labelFontName,
                        fontSize = self.labelFontSize,
                        fillColor=_PCMYK_black,
                        x = self.x-self.labelOffset,
                        y = self.y+(f*cell+(cell-ascent)/2.0),
                        textAnchor = 'end'))

        #do x-labels
        if self.xlabels is not None:
            for f in range(0,len(self.xlabels)):
                if self.xlabels[f] is not None:
                    l=Label()
                    l.x=self.x+(f*cell)+(cell+ascent)/2.0
                    l.y=self.y+gridHeight+self.labelOffset
                    l.angle=90
                    l.textAnchor='start'
                    l.fontName = self.labelFontName
                    l.fontSize = self.labelFontSize
                    l.fillColor = _PCMYK_black
                    l.setText(self.xlabels[f])
                    l.boxAnchor = 'sw'
                    # result is discarded, as in the original; kept in case
                    # draw() has side effects on the label - TODO confirm
                    l.draw()
                    g.add(l)

        return g
class AbstractLineChart(PlotArea):
    """Shared behaviour for line charts: legend swatches and series names.

    Both methods read the per-series style collection ``self.lines``, which
    subclasses are expected to provide.
    """

    def makeSwatchSample(self,rowNo, x, y, width, height):
        """Return a small sample of series ``rowNo`` drawn inside the given box.

        The style is taken cyclically from ``self.lines``.  A 'bar' line style
        gives a filled Rect, a joined line gives a horizontal Line through the
        middle of the box, anything else gives no line.  If the style (or the
        collection) has a symbol it is centred in the box; when both a line
        shape and a symbol exist they are returned together in a Group.
        May return None when there is neither.
        """
        baseStyle = self.lines
        styleIdx = rowNo % len(baseStyle)
        style = baseStyle[styleIdx]
        color = style.strokeColor
        yh2 = y+height/2.
        lineStyle = getattr(style,'lineStyle',None)
        isBar = lineStyle=='bar'

        # self.joinedLines is only consulted for non-bar styles (short circuit)
        if isBar or self.joinedLines or lineStyle=='joinedLine':
            # per-row value first, then the collection-wide value, then None
            dash = getattr(style, 'strokeDashArray', getattr(baseStyle,'strokeDashArray',None))
            strokeWidth = getattr(style, 'strokeWidth', getattr(baseStyle,'strokeWidth',None))
            if isBar:
                L = Rect(x,y,width,height,strokeWidth=strokeWidth,strokeColor=color,strokeLineCap=0,strokeDashArray=dash,fillColor=getattr(style,'fillColor',color))
            else:
                L = Line(x,yh2,x+width,yh2,strokeColor=color,strokeLineCap=0)
                # leave the Line's own defaults in place unless a value is set
                if strokeWidth: L.strokeWidth = strokeWidth
                if dash: L.strokeDashArray = dash
        else:
            L = None

        if hasattr(style, 'symbol'):
            S = style.symbol
        elif hasattr(baseStyle, 'symbol'):
            S = baseStyle.symbol
        else:
            S = None

        if S: S = uSymbol2Symbol(S,x+width/2.,yh2,color)
        if S and L:
            g = Group()
            g.add(L)
            g.add(S)
            return g
        return S or L

    def getSeriesName(self,i,default=None):
        '''return series name i or default'''
        return _objStr(getattr(self.lines[i],'name',default))
+ lineLabelNudge: distance of data labels to data points + lineLabels: labels associated with data values + lineLabelFormat: format string or callback function + groupSpacing: space between categories + + joinedLines: enables drawing of lines + + strokeColor: color of chart lines (?) + fillColor: color for chart background (?) + lines: style list, used cyclically for data series + + valueAxis: value axis object + categoryAxis: category axis object + categoryNames: category names + + data: chart data, a list of data series of equal length + """ + _flipXY = 0 + + _attrMap = AttrMap(BASE=LineChart, + useAbsolute = AttrMapValue(isNumber, desc='Flag to use absolute spacing values.',advancedUsage=1), + lineLabelNudge = AttrMapValue(isNumber, desc='Distance between a data point and its label.',advancedUsage=1), + lineLabels = AttrMapValue(None, desc='Handle to the list of data point labels.'), + lineLabelFormat = AttrMapValue(None, desc='Formatting string or function used for data point labels.'), + lineLabelArray = AttrMapValue(None, desc='explicit array of line label values, must match size of data if present.'), + groupSpacing = AttrMapValue(isNumber, desc='? 
def __init__(self):
    """Create a line chart populated with example data and default styling."""
    LineChart.__init__(self)

    # No bounding rectangle (outline or fill) unless the caller sets one.
    self.strokeColor = None
    self.fillColor = None

    # The axis classes depend on orientation; the attribute names stay the
    # same so the rest of the class does not care which way round it is.
    if self._flipXY:
        categoryAxisClass, valueAxisClass = YCategoryAxis, XValueAxis
    else:
        categoryAxisClass, valueAxisClass = XCategoryAxis, YValueAxis
    self.categoryAxis = categoryAxisClass()
    self.valueAxis = valueAxisClass()

    # Example data: two series of four points, one point per category.
    self.data = [(100,110,120,130),
                 (70, 80, 80, 90)]
    self.categoryNames = ('North','South','East','West')

    # Per-series line styles, used cyclically.
    self.lines = TypedPropertyCollection(LineChartProperties)
    self.lines.strokeWidth = 1
    for seriesNo, seriesColor in enumerate((colors.red, colors.green, colors.blue)):
        self.lines[seriesNo].strokeColor = seriesColor

    # Spacing control: with useAbsolute = 1 the spacing parameters are in
    # points; otherwise they are 'proportions' normalized to fit the
    # available space.
    self.useAbsolute = 0   #- not done yet
    self.groupSpacing = 1  #5

    # Data point labels.
    self.lineLabels = TypedPropertyCollection(Label)
    self.lineLabelFormat = None
    self.lineLabelArray = None

    # Distance of the label origin from the data point: +10 puts it ten
    # points above the point for values > 0 and ten points below for
    # values < 0.  Unlike the label dx/dy this depends on the sign of the
    # data.
    self.lineLabelNudge = 10

    # Line chart specific switches.
    self.joinedLines = 1        # Connect items with straight lines.
    self.inFill = 0
    self.reversePlotOrder = 0
+ G must have an add method''' + G.add(self._innerDrawLabel(rowNo,colNo,x,y)) + + def makeLines(self): + g = Group() + + labelFmt = self.lineLabelFormat + P = self._positions + if self.reversePlotOrder: P.reverse() + lines = self.lines + styleCount = len(lines) + flipXY = self._flipXY + cA = self.categoryAxis + vA = self.valueAxis + _inFill = self.inFill + if (_inFill or self._pairInFills or + [rowNo for rowNo in range(len(P)) + if getattr(lines[rowNo%styleCount],'inFill',False)] + ): + if flipXY: + infillC = cA._x + infillV0 = vA._y + infillV1 = infillV0 + cA._length + else: + infillC = cA._y + infillV0 = vA._x + infillV1 = infillV0 + cA._length + inFillG = getattr(self,'_inFillG',g) + vzero = self._vzero + bypos = None + + # Iterate over data rows. + for rowNo, row in enumerate(reversed(P) if self.reversePlotOrder else P): + styleIdx = rowNo % styleCount + rowStyle = lines[styleIdx] + strokeColor = rowStyle.strokeColor + fillColor = getattr(rowStyle,'fillColor',strokeColor) + inFill = getattr(rowStyle,'inFill',_inFill) + dash = getattr(rowStyle, 'strokeDashArray', None) + lineStyle = getattr(rowStyle,'lineStyle',None) + + if hasattr(rowStyle, 'strokeWidth'): + strokeWidth = rowStyle.strokeWidth + elif hasattr(lines, 'strokeWidth'): + strokeWidth = lines.strokeWidth + else: + strokeWidth = None + + # Iterate over data columns. + if lineStyle=='bar': + if bypos is None: + if flipXY: + bypos = max(vA._x,vzero) + byneg = min(vA._x+vA._length,vzero) + else: + bypos = max(vA._y,vzero) + byneg = min(vA._y+vA._length,vzero) + barWidth = getattr(rowStyle,'barWidth',Percentage(50)) + if isinstance(barWidth,Percentage): + hbw = self._hngs*barWidth*0.01 + else: + hbw = barWidth*0.5 + for x, y in row: + if flipXY: + v0 = byneg if x x + self.width) or (crossesAt < x)): + x = crossesAt + cA.setPosition(x, y, self.height) + else: + # If zero is in chart, put x axis there, otherwise + # use bottom. 
+ crossesAt = vA.scale(0) + if not ((crossesAt > y + self.height) or (crossesAt < y)): + y = crossesAt + cA.setPosition(x, y, self.width) + cA.configure(self.data) + + self.calcPositions() + + g = Group() + g.add(self.makeBackground()) + if self.inFill: + self._inFillG = Group() + g.add(self._inFillG) + + g.add(cA) + g.add(vA) + cAdgl = getattr(cA,'drawGridLast',False) + vAdgl = getattr(vA,'drawGridLast',False) + if not cAdgl: cA.makeGrid(g,parent=self,dim=vA.getGridDims) + if not vAdgl: vA.makeGrid(g,parent=self,dim=cA.getGridDims) + g.add(self.makeLines()) + if cAdgl: cA.makeGrid(g,parent=self,dim=vA.getGridDims) + if vAdgl: vA.makeGrid(g,parent=self,dim=cA.getGridDims) + for a in getattr(self,'annotations',()): g.add(a(self,cA.scale,vA.scale)) + return g + +def _fakeItemKey(a): + '''t, z0, z1, x, y = a[:5]''' + return (-a[1],a[3],a[0],-a[4]) + +class _FakeGroup: + def __init__(self): + self._data = [] + + def add(self,what): + if what: self._data.append(what) + + def value(self): + return self._data + + def sort(self): + self._data.sort(key=_fakeItemKey) + #for t in self._data: print t + +class HorizontalLineChart3D(HorizontalLineChart): + _attrMap = AttrMap(BASE=HorizontalLineChart, + theta_x = AttrMapValue(isNumber, desc='dx/dz'), + theta_y = AttrMapValue(isNumber, desc='dy/dz'), + zDepth = AttrMapValue(isNumber, desc='depth of an individual series'), + zSpace = AttrMapValue(isNumber, desc='z gap around series'), + ) + theta_x = .5 + theta_y = .5 + zDepth = 10 + zSpace = 3 + + def calcPositions(self): + HorizontalLineChart.calcPositions(self) + nSeries = self._seriesCount + zSpace = self.zSpace + zDepth = self.zDepth + if self.categoryAxis.style=='parallel_3d': + _3d_depth = nSeries*zDepth+(nSeries+1)*zSpace + else: + _3d_depth = zDepth + 2*zSpace + self._3d_dx = self.theta_x*_3d_depth + self._3d_dy = self.theta_y*_3d_depth + + def _calc_z0(self,rowNo): + zSpace = self.zSpace + if self.categoryAxis.style=='parallel_3d': + z0 = rowNo*(self.zDepth+zSpace)+zSpace 
def makeLines(self):
    """Return a Group with the 3D line segments, markers and data labels.

    Rows of ``self._positions`` are processed in index order (reversed when
    ``reversePlotOrder`` is set).  Every drawable is first collected in a
    _FakeGroup, which is sorted before the items are moved into the
    returned Group, so the collection order only matters between entries
    whose sort keys are equal.

    Raises AssertionError when ``inFill`` is set; it is not supported here.
    """
    rowOrder = list(range(len(self._positions)))
    if self.reversePlotOrder: rowOrder.reverse()
    assert not self.inFill, "inFill not supported for 3d yet"

    zDepth = self.zDepth
    _zadjust = self._zadjust
    theta_x = self.theta_x
    theta_y = self.theta_y
    F = _FakeGroup()
    from reportlab.graphics.charts.utils3d import _make_3d_line_info
    tileWidth = getattr(self,'_3d_tilewidth',None)
    if not tileWidth and self.categoryAxis.style!='parallel_3d': tileWidth = 1

    lines = self.lines
    # Styles are used cyclically; the count does not change inside the loop
    # because only indices below it are ever looked up.
    styleCount = len(lines)

    # Iterate over data rows.
    for rowNo in rowOrder:
        row = self._positions[rowNo]
        n = len(row)
        rowStyle = lines[rowNo % styleCount]
        rowColor = rowStyle.strokeColor
        z0 = self._calc_z0(rowNo)
        z1 = z0 + zDepth

        # Join consecutive points with 3D segments.
        # NOTE(review): the segments are always built with strokeColor,
        # strokeWidth and strokeDashArray set to None, so the row's own
        # strokeWidth/strokeDashArray have no effect here - confirm that
        # this is intended.
        if self.joinedLines and n:
            x0, y0 = row[0]
            for colNo in range(1,n):
                x1, y1 = row[colNo]
                _make_3d_line_info( F, x0, x1, y0, y1, z0, z1,
                        theta_x, theta_y,
                        rowColor, fillColorShaded=None, tileWidth=tileWidth,
                        strokeColor=None, strokeWidth=None, strokeDashArray=None,
                        shading=0.1)
                x0, y0 = x1, y1

        # Per-row symbol first, then the collection-wide one.
        if hasattr(rowStyle, 'symbol'):
            uSymbol = rowStyle.symbol
        elif hasattr(lines, 'symbol'):
            uSymbol = lines.symbol
        else:
            uSymbol = None

        # Markers for the whole row are collected before any of its labels;
        # this keeps the relative order of entries with equal sort keys.
        if uSymbol:
            for colNo in range(n):
                x1, y1 = row[colNo]
                x1, y1 = _zadjust(x1,y1,z0)
                symbol = uSymbol2Symbol(uSymbol,x1,y1,rowColor)
                if symbol: F.add((2,z0,z0,x1,y1,symbol))

        # Draw item labels.
        for colNo in range(n):
            x1, y1 = row[colNo]
            x1, y1 = _zadjust(x1,y1,z0)
            L = self._innerDrawLabel(rowNo, colNo, x1, y1)
            if L: F.add((2,z0,z0,x1,y1,L))

    F.sort()
    g = Group()
    # the drawable is the last element of each collected tuple
    for v in F.value(): g.add(v[-1])
    return g
def sample1a():
    """Return a 400x200 Drawing holding a SampleHorizontalLineChart with two
    monthly series, diamond markers and data labels on a grey background."""
    chart = SampleHorizontalLineChart()

    # Position and size of the plot area inside the drawing.
    chart.x, chart.y = 50, 50
    chart.height, chart.width = 125, 300

    chart.data = [
        (13, 5, 20, 22, 37, 45, 19, 4),
        (5, 20, 46, 38, 23, 21, 6, 14)
        ]

    # Appearance of the plot area, lines, markers and labels.
    chart.joinedLines = 1
    chart.strokeColor = colors.white
    chart.fillColor = colors.HexColor(0xCCCCCC)
    chart.lines.symbol = makeMarker('FilledDiamond')
    chart.lineLabelFormat = '%2.0f'

    # Category axis: month names anchored at the top of each label box.
    categoryAxis = chart.categoryAxis
    categoryAxis.categoryNames = 'Jan Feb Mar Apr May Jun Jul Aug'.split(' ')
    categoryAxis.labels.boxAnchor = 'n'

    # Value axis: fixed range 0..60 in steps of 15.
    valueAxis = chart.valueAxis
    valueAxis.valueMin = 0
    valueAxis.valueMax = 60
    valueAxis.valueStep = 15

    canvas = Drawing(400, 200)
    canvas.add(chart)
    return canvas
def sampleCandleStick():
    """Return a 400x200 Drawing of a HorizontalLineChart whose single series
    is drawn with CandleSticks symbols instead of a visible line."""
    from reportlab.graphics.widgetbase import CandleSticks

    d = Drawing(400, 200)
    chart = HorizontalLineChart()
    d.add(chart)
    chart.y = 20

    # Candle geometry, all derived from the box mid values.
    boxMid = (100, 110, 120, 130)
    hi = [mid+10 for mid in boxMid]
    lo = [mid-10 for mid in boxMid]
    boxHi = [mid+6 for mid in boxMid]
    boxLo = [mid-4 for mid in boxMid]

    # The value axis has to span every value a candle can reach.
    valueAxis = chart.valueAxis
    valueAxis.avoidBoundSpace = 5
    extremes = list(boxMid) + hi + lo + boxLo + boxHi
    valueAxis.valueMin = min(extremes)
    valueAxis.valueMax = max(extremes)

    # Hide the line itself; only the candle symbols are shown.
    lines = chart.lines
    lines[0].strokeColor = None
    chart.data = [boxMid]

    candles = CandleSticks(
        chart=chart,
        boxFillColor=colors.pink,
        boxWidth=20,
        crossWidth=10,
        strokeWidth=0.5,
        strokeColor=colors.black,
        )
    lines[0].symbol = candles

    for i, (mid, low, high, bLo, bHi) in enumerate(zip(boxMid, lo, hi, boxLo, boxHi)):
        candles[i].setProperties(dict(position=i,boxMid=mid,crossLo=low,crossHi=high,boxLo=bLo,boxHi=bHi))
    return d
class Bubble(_DrawingEditorMixin,Drawing):
    """Sample drawing: a ScatterPlot with large circle markers, a title and a legend."""

    def __init__(self,width=200,height=150,*args,**kw):
        Drawing.__init__(self,width,height,*args,**kw)

        # --- main chart -------------------------------------------------
        self._add(self,ScatterPlot(),name='chart',validate=None,desc="The main chart")
        chart = self.chart
        chart.width = 115
        chart.height = 80
        chart.x = 30
        chart.y = 40
        palette = (color01, color02, color03, color04, color05,
                   color06, color07, color08, color09, color10)
        for seriesNo, seriesColor in enumerate(palette):
            chart.lines[seriesNo].strokeColor = seriesColor
        chart.lines.symbol.kind ='Circle'
        chart.lines.symbol.size = 15
        chart.fillColor = backgroundGrey
        chart.lineLabels.fontName = 'Helvetica'
        chart.xValueAxis.labels.fontName = 'Helvetica'
        chart.xValueAxis.labels.fontSize = 7
        chart.xValueAxis.forceZero = 0
        chart.data = [((100,100), (200,200), (250,210), (300,300), (350,450))]
        chart.xValueAxis.avoidBoundFrac = 1
        chart.xValueAxis.gridEnd = 115
        chart.xValueAxis.tickDown = 3
        chart.xValueAxis.visibleGrid = 1
        chart.yValueAxis.tickLeft = 3
        chart.yValueAxis.labels.fontName = 'Helvetica'
        chart.yValueAxis.labels.fontSize = 7

        # --- title ------------------------------------------------------
        self._add(self,Label(),name='Title',validate=None,desc="The title at the top of the chart")
        title = self.Title
        title.fontName = 'Helvetica-Bold'
        title.fontSize = 7
        title.x = 100
        title.y = 135
        title._text = 'Chart Title'
        title.maxWidth = 180
        title.height = 20
        title.textAnchor ='middle'

        # --- legend -----------------------------------------------------
        self._add(self,Legend(),name='Legend',validate=None,desc="The legend or key for the chart")
        legend = self.Legend
        legend.colorNamePairs = [(color01, 'Widgets')]
        legend.fontName = 'Helvetica'
        legend.fontSize = 7
        legend.x = 153
        legend.y = 85
        legend.dxTextSpace = 5
        legend.dy = 5
        legend.dx = 5
        legend.deltay = 5
        legend.alignment ='right'

        # --- later settings; some replace values assigned above ----------
        chart = self.chart
        chart.lineLabelFormat = None
        chart.xLabel = 'X Axis'
        chart.y = 30
        chart.yLabel = 'Y Axis'
        chart.yValueAxis.labelTextFormat = '%d'
        chart.yValueAxis.forceZero = 1
        chart.xValueAxis.forceZero = 1

        self._add(self,0,name='preview',validate=None,desc=None)
/dev/null +++ b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/clustered_bar.py @@ -0,0 +1,84 @@ +#Autogenerated by ReportLab guiedit do not edit +from reportlab.graphics.charts.legends import Legend +from reportlab.graphics.samples.excelcolors import * +from reportlab.graphics.charts.barcharts import HorizontalBarChart +from reportlab.graphics.shapes import Drawing, _DrawingEditorMixin +from reportlab.graphics.charts.textlabels import Label + +class ClusteredBar(_DrawingEditorMixin,Drawing): + def __init__(self,width=200,height=150,*args,**kw): + Drawing.__init__(self,width,height,*args,**kw) + self._add(self,HorizontalBarChart(),name='chart',validate=None,desc="The main chart") + self.chart.width = 115 + self.chart.height = 80 + self.chart.x = 30 + self.chart.y = 40 + self.chart.bars[0].fillColor = color01 + self.chart.bars[1].fillColor = color02 + self.chart.bars[2].fillColor = color03 + self.chart.bars[3].fillColor = color04 + self.chart.bars[4].fillColor = color05 + self.chart.bars[5].fillColor = color06 + self.chart.bars[6].fillColor = color07 + self.chart.bars[7].fillColor = color08 + self.chart.bars[8].fillColor = color09 + self.chart.bars[9].fillColor = color10 + self.chart.fillColor = backgroundGrey + self.chart.barLabels.fontName = 'Helvetica' + self.chart.valueAxis.labels.fontName = 'Helvetica' + self.chart.valueAxis.labels.fontSize = 6 + self.chart.valueAxis.forceZero = 1 + self.chart.data = [(100, 150, 180), (125, 180, 200)] + self.chart.groupSpacing = 15 + self.chart.valueAxis.avoidBoundFrac = 1 + self.chart.valueAxis.gridEnd = 80 + self.chart.valueAxis.tickDown = 3 + self.chart.valueAxis.visibleGrid = 1 + self.chart.categoryAxis.categoryNames = ['North', 'South', 'Central'] + self.chart.categoryAxis.tickLeft = 3 + self.chart.categoryAxis.labels.fontName = 'Helvetica' + self.chart.categoryAxis.labels.fontSize = 6 + self.chart.categoryAxis.labels.dx = -3 + self._add(self,Label(),name='Title',validate=None,desc="The title at the top of 
the chart") + self.Title.fontName = 'Helvetica-Bold' + self.Title.fontSize = 7 + self.Title.x = 100 + self.Title.y = 135 + self.Title._text = 'Chart Title' + self.Title.maxWidth = 180 + self.Title.height = 20 + self.Title.textAnchor ='middle' + self._add(self,Legend(),name='Legend',validate=None,desc="The legend or key for the chart") + self.Legend.colorNamePairs = [(color01, 'Widgets'), (color02, 'Sprockets')] + self.Legend.fontName = 'Helvetica' + self.Legend.fontSize = 7 + self.Legend.x = 153 + self.Legend.y = 85 + self.Legend.dxTextSpace = 5 + self.Legend.dy = 5 + self.Legend.dx = 5 + self.Legend.deltay = 5 + self.Legend.alignment ='right' + self._add(self,Label(),name='XLabel',validate=None,desc="The label on the horizontal axis") + self.XLabel.fontName = 'Helvetica' + self.XLabel.fontSize = 7 + self.XLabel.x = 85 + self.XLabel.y = 10 + self.XLabel.textAnchor ='middle' + self.XLabel.maxWidth = 100 + self.XLabel.height = 20 + self.XLabel._text = "X Axis" + self._add(self,Label(),name='YLabel',validate=None,desc="The label on the vertical axis") + self.YLabel.fontName = 'Helvetica' + self.YLabel.fontSize = 7 + self.YLabel.x = 12 + self.YLabel.y = 80 + self.YLabel.angle = 90 + self.YLabel.textAnchor ='middle' + self.YLabel.maxWidth = 100 + self.YLabel.height = 20 + self.YLabel._text = "Y Axis" + self._add(self,0,name='preview',validate=None,desc=None) + +if __name__=="__main__": #NORUNTESTS + ClusteredBar().save(formats=['pdf'],outDir=None,fnRoot='clustered_bar') diff --git a/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/clustered_column.py b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/clustered_column.py new file mode 100644 index 0000000000000000000000000000000000000000..055f1ea7754f978c55b33195dab184650e015eda --- /dev/null +++ b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/clustered_column.py @@ -0,0 +1,83 @@ +#Autogenerated by ReportLab guiedit do not edit +from reportlab.graphics.charts.legends import 
Legend +from reportlab.graphics.samples.excelcolors import * +from reportlab.graphics.charts.barcharts import VerticalBarChart +from reportlab.graphics.shapes import Drawing, _DrawingEditorMixin +from reportlab.graphics.charts.textlabels import Label + +class ClusteredColumn(_DrawingEditorMixin,Drawing): + def __init__(self,width=200,height=150,*args,**kw): + Drawing.__init__(self,width,height,*args,**kw) + self._add(self,VerticalBarChart(),name='chart',validate=None,desc="The main chart") + self.chart.width = 115 + self.chart.height = 80 + self.chart.x = 30 + self.chart.y = 40 + self.chart.bars[0].fillColor = color01 + self.chart.bars[1].fillColor = color02 + self.chart.bars[2].fillColor = color03 + self.chart.bars[3].fillColor = color04 + self.chart.bars[4].fillColor = color05 + self.chart.bars[5].fillColor = color06 + self.chart.bars[6].fillColor = color07 + self.chart.bars[7].fillColor = color08 + self.chart.bars[8].fillColor = color09 + self.chart.bars[9].fillColor = color10 + self.chart.fillColor = backgroundGrey + self.chart.barLabels.fontName = 'Helvetica' + self.chart.valueAxis.labels.fontName = 'Helvetica' + self.chart.valueAxis.labels.fontSize = 7 + self.chart.valueAxis.forceZero = 1 + self.chart.data = [(100, 150, 180), (125, 180, 200)] + self.chart.groupSpacing = 15 + self.chart.valueAxis.avoidBoundFrac = 1 + self.chart.valueAxis.gridEnd = 115 + self.chart.valueAxis.tickLeft = 3 + self.chart.valueAxis.visibleGrid = 1 + self.chart.categoryAxis.categoryNames = ['North', 'South', 'Central'] + self.chart.categoryAxis.tickDown = 3 + self.chart.categoryAxis.labels.fontName = 'Helvetica' + self.chart.categoryAxis.labels.fontSize = 7 + self._add(self,Label(),name='Title',validate=None,desc="The title at the top of the chart") + self.Title.fontName = 'Helvetica-Bold' + self.Title.fontSize = 7 + self.Title.x = 100 + self.Title.y = 135 + self.Title._text = 'Chart Title' + self.Title.maxWidth = 180 + self.Title.height = 20 + self.Title.textAnchor ='middle' + 
self._add(self,Legend(),name='Legend',validate=None,desc="The legend or key for the chart") + self.Legend.colorNamePairs = [(color01, 'Widgets'), (color02, 'Sprockets')] + self.Legend.fontName = 'Helvetica' + self.Legend.fontSize = 7 + self.Legend.x = 153 + self.Legend.y = 85 + self.Legend.dxTextSpace = 5 + self.Legend.dy = 5 + self.Legend.dx = 5 + self.Legend.deltay = 5 + self.Legend.alignment ='right' + self._add(self,Label(),name='XLabel',validate=None,desc="The label on the horizontal axis") + self.XLabel.fontName = 'Helvetica' + self.XLabel.fontSize = 7 + self.XLabel.x = 85 + self.XLabel.y = 10 + self.XLabel.textAnchor ='middle' + self.XLabel.maxWidth = 100 + self.XLabel.height = 20 + self.XLabel._text = "X Axis" + self._add(self,Label(),name='YLabel',validate=None,desc="The label on the vertical axis") + self.YLabel.fontName = 'Helvetica' + self.YLabel.fontSize = 7 + self.YLabel.x = 12 + self.YLabel.y = 80 + self.YLabel.angle = 90 + self.YLabel.textAnchor ='middle' + self.YLabel.maxWidth = 100 + self.YLabel.height = 20 + self.YLabel._text = "Y Axis" + self._add(self,0,name='preview',validate=None,desc=None) + +if __name__=="__main__": #NORUNTESTS + ClusteredColumn().save(formats=['pdf'],outDir=None,fnRoot='clustered_column') diff --git a/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/excelcolors.py b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/excelcolors.py new file mode 100644 index 0000000000000000000000000000000000000000..23b0213cd0ca3f3909182991e47139a6f440e6ff --- /dev/null +++ b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/excelcolors.py @@ -0,0 +1,45 @@ +# define standard colors to mimic those used by Microsoft Excel +from reportlab.lib.colors import PCMYKColor + +#colour names as comments at the end of each line are as a memory jogger ONLY +#NOT HTML named colours! 
+ +#Main colours as used for bars etc +color01 = PCMYKColor(40,40,0,0) # Lavender +color02 = PCMYKColor(0,66,33,39) # Maroon +color03 = PCMYKColor(0,0,20,0) # Yellow +color04 = PCMYKColor(20,0,0,0) # Cyan +color05 = PCMYKColor(0,100,0,59) # Purple +color06 = PCMYKColor(0,49,49,0) # Salmon +color07 = PCMYKColor(100,49,0,19) # Blue +color08 = PCMYKColor(20,20,0,0) # PaleLavender +color09 = PCMYKColor(100,100,0,49) # NavyBlue +color10 = PCMYKColor(0,100,0,0) # Purple + +#Highlight colors - eg for the tops of bars +color01Light = PCMYKColor(39,39,0,25) # Light Lavender +color02Light = PCMYKColor(0,66,33,54) # Light Maroon +color03Light = PCMYKColor(0,0,19,25) # Light Yellow +color04Light = PCMYKColor(19,0,0,25) # Light Cyan +color05Light = PCMYKColor(0,100,0,69) # Light Purple +color06Light = PCMYKColor(0,49,49,25) # Light Salmon +color07Light = PCMYKColor(100,49,0,39) # Light Blue +color08Light = PCMYKColor(19,19,0,25) # Light PaleLavender +color09Light = PCMYKColor(100,100,0,62) # Light NavyBlue +color10Light = PCMYKColor(0,100,0,25) # Light Purple + +#Lowlight colors - eg for the sides of bars +color01Dark = PCMYKColor(39,39,0,49) # Dark Lavender +color02Dark = PCMYKColor(0,66,33,69) # Dark Maroon +color03Dark = PCMYKColor(0,0,20,49) # Dark Yellow +color04Dark = PCMYKColor(20,0,0,49) # Dark Cyan +color05Dark = PCMYKColor(0,100,0,80) # Dark Purple +color06Dark = PCMYKColor(0,50,50,49) # Dark Salmon +color07Dark = PCMYKColor(100,50,0,59) # Dark Blue +color08Dark = PCMYKColor(20,20,0,49) # Dark PaleLavender +color09Dark = PCMYKColor(100,100,0,79) # Dark NavyBlue +color10Dark = PCMYKColor(0,100,0,49) # Dark Purple + +#for standard grey backgrounds +backgroundGrey = PCMYKColor(0,0,0,24) + diff --git a/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/exploded_pie.py b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/exploded_pie.py new file mode 100644 index 0000000000000000000000000000000000000000..c4d99bc7b582bee4feea55449cd87759b2c9e93b --- 
/dev/null +++ b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/exploded_pie.py @@ -0,0 +1,65 @@ +#Autogenerated by ReportLab guiedit do not edit +from reportlab.graphics.charts.piecharts import Pie +from reportlab.graphics.samples.excelcolors import * +from reportlab.graphics.widgets.grids import ShadedRect +from reportlab.graphics.charts.legends import Legend +from reportlab.graphics.shapes import Drawing, _DrawingEditorMixin +from reportlab.graphics.charts.textlabels import Label + +class ExplodedPie(_DrawingEditorMixin,Drawing): + def __init__(self,width=200,height=150,*args,**kw): + Drawing.__init__(self,width,height,*args,**kw) + self._add(self,Pie(),name='chart',validate=None,desc="The main chart") + self.chart.width = 100 + self.chart.height = 100 + self.chart.x = 25 + self.chart.y = 25 + self.chart.slices[0].fillColor = color01 + self.chart.slices[1].fillColor = color02 + self.chart.slices[2].fillColor = color03 + self.chart.slices[3].fillColor = color04 + self.chart.slices[4].fillColor = color05 + self.chart.slices[5].fillColor = color06 + self.chart.slices[6].fillColor = color07 + self.chart.slices[7].fillColor = color08 + self.chart.slices[8].fillColor = color09 + self.chart.slices[9].fillColor = color10 + self.chart.data = (100, 150, 180) + self.chart.startAngle = -90 + self._add(self,Label(),name='Title',validate=None,desc="The title at the top of the chart") + self.Title.fontName = 'Helvetica-Bold' + self.Title.fontSize = 7 + self.Title.x = 100 + self.Title.y = 135 + self.Title._text = 'Chart Title' + self.Title.maxWidth = 180 + self.Title.height = 20 + self.Title.textAnchor ='middle' + self._add(self,Legend(),name='Legend',validate=None,desc="The legend or key for the chart") + self.Legend.colorNamePairs = [(color01, 'North'), (color02, 'South'), (color03, 'Central')] + self.Legend.fontName = 'Helvetica' + self.Legend.fontSize = 7 + self.Legend.x = 160 + self.Legend.y = 85 + self.Legend.dxTextSpace = 5 + self.Legend.dy = 5 + 
self.Legend.dx = 5 + self.Legend.deltay = 5 + self.Legend.alignment ='right' + self.Legend.columnMaximum = 10 + self.chart.slices.strokeWidth = 1 + self.chart.slices.fontName = 'Helvetica' + self.background = ShadedRect() + self.background.fillColorStart = backgroundGrey + self.background.fillColorEnd = backgroundGrey + self.background.numShades = 1 + self.background.strokeWidth = 0.5 + self.background.x = 20 + self.background.y = 20 + self.chart.slices.popout = 5 + self.background.height = 110 + self.background.width = 110 + self._add(self,0,name='preview',validate=None,desc=None) + +if __name__=="__main__": #NORUNTESTS + ExplodedPie().save(formats=['pdf'],outDir=None,fnRoot='exploded_pie') diff --git a/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/filled_radar.py b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/filled_radar.py new file mode 100644 index 0000000000000000000000000000000000000000..7761a4a55a36a7d58eb4aab667ce8f262944d313 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/filled_radar.py @@ -0,0 +1,54 @@ +#Autogenerated by ReportLab guiedit do not edit +from reportlab.graphics.charts.legends import Legend +from reportlab.graphics.charts.spider import SpiderChart +from reportlab.graphics.shapes import Drawing, _DrawingEditorMixin +from reportlab.graphics.charts.textlabels import Label +from reportlab.graphics.samples.excelcolors import * + +class FilledRadarChart(_DrawingEditorMixin,Drawing): + def __init__(self,width=200,height=150,*args,**kw): + Drawing.__init__(self,width,height,*args,**kw) + self._add(self,SpiderChart(),name='chart',validate=None,desc="The main chart") + self.chart.width = 90 + self.chart.height = 90 + self.chart.x = 45 + self.chart.y = 25 + self.chart.strands[0].fillColor = color01 + self.chart.strands[1].fillColor = color02 + self.chart.strands[2].fillColor = color03 + self.chart.strands[3].fillColor = color04 + self.chart.strands[4].fillColor = color05 + 
self.chart.strands[5].fillColor = color06 + self.chart.strands[6].fillColor = color07 + self.chart.strands[7].fillColor = color08 + self.chart.strands[8].fillColor = color09 + self.chart.strands[9].fillColor = color10 + self.chart.strandLabels.fontName = 'Helvetica' + self.chart.strandLabels.fontSize = 6 + self.chart.fillColor = backgroundGrey + self.chart.data = [(125, 180, 200), (100, 150, 180)] + self.chart.labels = ['North', 'South', 'Central'] + self._add(self,Label(),name='Title',validate=None,desc="The title at the top of the chart") + self.Title.fontName = 'Helvetica-Bold' + self.Title.fontSize = 7 + self.Title.x = 100 + self.Title.y = 135 + self.Title._text = 'Chart Title' + self.Title.maxWidth = 180 + self.Title.height = 20 + self.Title.textAnchor ='middle' + self._add(self,Legend(),name='Legend',validate=None,desc="The legend or key for the chart") + self.Legend.colorNamePairs = [(color01, 'Widgets'), (color02, 'Sprockets')] + self.Legend.fontName = 'Helvetica' + self.Legend.fontSize = 7 + self.Legend.x = 153 + self.Legend.y = 85 + self.Legend.dxTextSpace = 5 + self.Legend.dy = 5 + self.Legend.dx = 5 + self.Legend.deltay = 5 + self.Legend.alignment ='right' + self._add(self,0,name='preview',validate=None,desc=None) + +if __name__=="__main__": #NORUNTESTS + FilledRadarChart().save(formats=['pdf'],outDir=None,fnRoot='filled_radar') diff --git a/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/line_chart.py b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/line_chart.py new file mode 100644 index 0000000000000000000000000000000000000000..dc6bcab005c53378021677396817a432c107a14c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/line_chart.py @@ -0,0 +1,83 @@ +#Autogenerated by ReportLab guiedit do not edit +from reportlab.graphics.charts.legends import Legend +from reportlab.graphics.charts.lineplots import LinePlot +from reportlab.graphics.shapes import Drawing, _DrawingEditorMixin +from 
reportlab.graphics.charts.textlabels import Label +from reportlab.graphics.samples.excelcolors import * + +class LineChart(_DrawingEditorMixin,Drawing): + def __init__(self,width=200,height=150,*args,**kw): + Drawing.__init__(self,width,height,*args,**kw) + self._add(self,LinePlot(),name='chart',validate=None,desc="The main chart") + self.chart.width = 115 + self.chart.height = 80 + self.chart.x = 30 + self.chart.y = 40 + self.chart.lines[0].strokeColor = color01 + self.chart.lines[1].strokeColor = color02 + self.chart.lines[2].strokeColor = color03 + self.chart.lines[3].strokeColor = color04 + self.chart.lines[4].strokeColor = color05 + self.chart.lines[5].strokeColor = color06 + self.chart.lines[6].strokeColor = color07 + self.chart.lines[7].strokeColor = color08 + self.chart.lines[8].strokeColor = color09 + self.chart.lines[9].strokeColor = color10 + self.chart.fillColor = backgroundGrey + self.chart.lineLabels.fontName = 'Helvetica' + self.chart.xValueAxis.labels.fontName = 'Helvetica' + self.chart.xValueAxis.labels.fontSize = 7 + self.chart.xValueAxis.forceZero = 0 + self.chart.data = [((0, 50), (100,100), (200,200), (250,210), (300,300), (400,500)), ((0, 150), (100,200), (200,300), (250,200), (300,400), (400, 600))] + self.chart.xValueAxis.avoidBoundFrac = 1 + self.chart.xValueAxis.gridEnd = 115 + self.chart.xValueAxis.tickDown = 3 + self.chart.xValueAxis.visibleGrid = 1 + self.chart.yValueAxis.tickLeft = 3 + self.chart.yValueAxis.labels.fontName = 'Helvetica' + self.chart.yValueAxis.labels.fontSize = 7 + self._add(self,Label(),name='Title',validate=None,desc="The title at the top of the chart") + self.Title.fontName = 'Helvetica-Bold' + self.Title.fontSize = 7 + self.Title.x = 100 + self.Title.y = 135 + self.Title._text = 'Chart Title' + self.Title.maxWidth = 180 + self.Title.height = 20 + self.Title.textAnchor ='middle' + self._add(self,Legend(),name='Legend',validate=None,desc="The legend or key for the chart") + self.Legend.colorNamePairs = [(color01, 
'Widgets'), (color02, 'Sprockets')] + self.Legend.fontName = 'Helvetica' + self.Legend.fontSize = 7 + self.Legend.x = 153 + self.Legend.y = 85 + self.Legend.dxTextSpace = 5 + self.Legend.dy = 5 + self.Legend.dx = 5 + self.Legend.deltay = 5 + self.Legend.alignment ='right' + self._add(self,Label(),name='XLabel',validate=None,desc="The label on the horizontal axis") + self.XLabel.fontName = 'Helvetica' + self.XLabel.fontSize = 7 + self.XLabel.x = 85 + self.XLabel.y = 10 + self.XLabel.textAnchor ='middle' + self.XLabel.maxWidth = 100 + self.XLabel.height = 20 + self.XLabel._text = "X Axis" + self._add(self,Label(),name='YLabel',validate=None,desc="The label on the vertical axis") + self.YLabel.fontName = 'Helvetica' + self.YLabel.fontSize = 7 + self.YLabel.x = 12 + self.YLabel.y = 80 + self.YLabel.angle = 90 + self.YLabel.textAnchor ='middle' + self.YLabel.maxWidth = 100 + self.YLabel.height = 20 + self.YLabel._text = "Y Axis" + self.chart.yValueAxis.forceZero = 1 + self.chart.xValueAxis.forceZero = 1 + self._add(self,0,name='preview',validate=None,desc=None) + +if __name__=="__main__": #NORUNTESTS + LineChart().save(formats=['pdf'],outDir=None,fnRoot='line_chart') diff --git a/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/linechart_with_markers.py b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/linechart_with_markers.py new file mode 100644 index 0000000000000000000000000000000000000000..981e69ac47a323d4a56c669008b859403a590985 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/linechart_with_markers.py @@ -0,0 +1,94 @@ +#Autogenerated by ReportLab guiedit do not edit +from reportlab.graphics.charts.legends import Legend +from reportlab.graphics.charts.lineplots import LinePlot +from reportlab.graphics.shapes import Drawing, _DrawingEditorMixin +from reportlab.graphics.widgets.markers import makeMarker +from reportlab.graphics.charts.textlabels import Label +from reportlab.graphics.samples.excelcolors 
import * + +class LineChartWithMarkers(_DrawingEditorMixin,Drawing): + def __init__(self,width=200,height=150,*args,**kw): + Drawing.__init__(self,width,height,*args,**kw) + self._add(self,LinePlot(),name='chart',validate=None,desc="The main chart") + self.chart.width = 115 + self.chart.height = 80 + self.chart.x = 30 + self.chart.y = 40 + self.chart.lines[0].strokeColor = color01 + self.chart.lines[1].strokeColor = color02 + self.chart.lines[2].strokeColor = color03 + self.chart.lines[3].strokeColor = color04 + self.chart.lines[4].strokeColor = color05 + self.chart.lines[5].strokeColor = color06 + self.chart.lines[6].strokeColor = color07 + self.chart.lines[7].strokeColor = color08 + self.chart.lines[8].strokeColor = color09 + self.chart.lines[9].strokeColor = color10 + self.chart.lines[0].symbol = makeMarker('FilledSquare') + self.chart.lines[1].symbol = makeMarker('FilledDiamond') + self.chart.lines[2].symbol = makeMarker('FilledStarFive') + self.chart.lines[3].symbol = makeMarker('FilledTriangle') + self.chart.lines[4].symbol = makeMarker('FilledCircle') + self.chart.lines[5].symbol = makeMarker('FilledPentagon') + self.chart.lines[6].symbol = makeMarker('FilledStarSix') + self.chart.lines[7].symbol = makeMarker('FilledHeptagon') + self.chart.lines[8].symbol = makeMarker('FilledOctagon') + self.chart.lines[9].symbol = makeMarker('FilledCross') + self.chart.fillColor = backgroundGrey + self.chart.lineLabels.fontName = 'Helvetica' + self.chart.xValueAxis.labels.fontName = 'Helvetica' + self.chart.xValueAxis.labels.fontSize = 7 + self.chart.xValueAxis.forceZero = 0 + self.chart.data = [((0, 50), (100,100), (200,200), (250,210), (300,300), (400,500)), ((0, 150), (100,200), (200,300), (250,200), (300,400), (400, 600))] + self.chart.xValueAxis.avoidBoundFrac = 1 + self.chart.xValueAxis.gridEnd = 115 + self.chart.xValueAxis.tickDown = 3 + self.chart.xValueAxis.visibleGrid = 1 + self.chart.yValueAxis.tickLeft = 3 + self.chart.yValueAxis.labels.fontName = 'Helvetica' + 
self.chart.yValueAxis.labels.fontSize = 7 + self._add(self,Label(),name='Title',validate=None,desc="The title at the top of the chart") + self.Title.fontName = 'Helvetica-Bold' + self.Title.fontSize = 7 + self.Title.x = 100 + self.Title.y = 135 + self.Title._text = 'Chart Title' + self.Title.maxWidth = 180 + self.Title.height = 20 + self.Title.textAnchor ='middle' + self._add(self,Legend(),name='Legend',validate=None,desc="The legend or key for the chart") + self.Legend.colorNamePairs = [(color01, 'Widgets'), (color02, 'Sprockets')] + self.Legend.fontName = 'Helvetica' + self.Legend.fontSize = 7 + self.Legend.x = 153 + self.Legend.y = 85 + self.Legend.dxTextSpace = 5 + self.Legend.dy = 5 + self.Legend.dx = 5 + self.Legend.deltay = 5 + self.Legend.alignment ='right' + self._add(self,Label(),name='XLabel',validate=None,desc="The label on the horizontal axis") + self.XLabel.fontName = 'Helvetica' + self.XLabel.fontSize = 7 + self.XLabel.x = 85 + self.XLabel.y = 10 + self.XLabel.textAnchor ='middle' + self.XLabel.maxWidth = 100 + self.XLabel.height = 20 + self.XLabel._text = "X Axis" + self._add(self,Label(),name='YLabel',validate=None,desc="The label on the vertical axis") + self.YLabel.fontName = 'Helvetica' + self.YLabel.fontSize = 7 + self.YLabel.x = 12 + self.YLabel.y = 80 + self.YLabel.angle = 90 + self.YLabel.textAnchor ='middle' + self.YLabel.maxWidth = 100 + self.YLabel.height = 20 + self.YLabel._text = "Y Axis" + self.chart.yValueAxis.forceZero = 1 + self.chart.xValueAxis.forceZero = 1 + self._add(self,0,name='preview',validate=None,desc=None) + +if __name__=="__main__": #NORUNTESTS + LineChartWithMarkers().save(formats=['pdf'],outDir=None,fnRoot='linechart_with_markers') diff --git a/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/radar.py b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/radar.py new file mode 100644 index 0000000000000000000000000000000000000000..b7480546ec963bc43dfdc64d4c278ff2b049da31 --- /dev/null +++ 
b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/radar.py @@ -0,0 +1,66 @@ +#Autogenerated by ReportLab guiedit do not edit +from reportlab.graphics.charts.legends import Legend +from reportlab.graphics.samples.excelcolors import * +from reportlab.graphics.charts.spider import SpiderChart +from reportlab.graphics.shapes import Drawing, _DrawingEditorMixin +from reportlab.graphics.charts.textlabels import Label + +class RadarChart(_DrawingEditorMixin,Drawing): + def __init__(self,width=200,height=150,*args,**kw): + Drawing.__init__(self,width,height,*args,**kw) + self._add(self,SpiderChart(),name='chart',validate=None,desc="The main chart") + self.chart.width = 90 + self.chart.height = 90 + self.chart.x = 45 + self.chart.y = 25 + self.chart.strands[0].strokeColor= color01 + self.chart.strands[1].strokeColor= color02 + self.chart.strands[2].strokeColor= color03 + self.chart.strands[3].strokeColor= color04 + self.chart.strands[4].strokeColor= color05 + self.chart.strands[5].strokeColor= color06 + self.chart.strands[6].strokeColor= color07 + self.chart.strands[7].strokeColor= color08 + self.chart.strands[8].strokeColor= color09 + self.chart.strands[9].strokeColor= color10 + self.chart.strands[0].fillColor = None + self.chart.strands[1].fillColor = None + self.chart.strands[2].fillColor = None + self.chart.strands[3].fillColor = None + self.chart.strands[4].fillColor = None + self.chart.strands[5].fillColor = None + self.chart.strands[6].fillColor = None + self.chart.strands[7].fillColor = None + self.chart.strands[8].fillColor = None + self.chart.strands[9].fillColor = None + self.chart.strands.strokeWidth = 1 + self.chart.strandLabels.fontName = 'Helvetica' + self.chart.strandLabels.fontSize = 6 + self.chart.fillColor = backgroundGrey + self.chart.data = [(125, 180, 200), (100, 150, 180)] + self.chart.labels = ['North', 'South', 'Central'] + self._add(self,Label(),name='Title',validate=None,desc="The title at the top of the chart") + self.Title.fontName 
= 'Helvetica-Bold' + self.Title.fontSize = 7 + self.Title.x = 100 + self.Title.y = 135 + self.Title._text = 'Chart Title' + self.Title.maxWidth = 180 + self.Title.height = 20 + self.Title.textAnchor ='middle' + self._add(self,Legend(),name='Legend',validate=None,desc="The legend or key for the chart") + self.Legend.colorNamePairs = [(color01, 'Widgets'), (color02, 'Sprockets')] + self.Legend.fontName = 'Helvetica' + self.Legend.fontSize = 7 + self.Legend.x = 153 + self.Legend.y = 85 + self.Legend.dxTextSpace = 5 + self.Legend.dy = 5 + self.Legend.dx = 5 + self.Legend.deltay = 5 + self.Legend.alignment ='right' + self.chart.strands.strokeWidth = 1 + self._add(self,0,name='preview',validate=None,desc=None) + +if __name__=="__main__": #NORUNTESTS + RadarChart().save(formats=['pdf'],outDir=None,fnRoot='radar') diff --git a/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/runall.py b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/runall.py new file mode 100644 index 0000000000000000000000000000000000000000..9d6400a5af216801ad24f427a29e9a12651b1368 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/runall.py @@ -0,0 +1,57 @@ +# runs all the GUIedit charts in this directory - +# makes a PDF sample for eaxh existing chart type +import sys +import glob +import inspect + +def moduleClasses(mod): + def P(obj, m=mod.__name__, CT=type): + return (type(obj)==CT and obj.__module__==m) + try: + return inspect.getmembers(mod, P)[0][1] + except: + return None + +def getclass(f): + return moduleClasses(__import__(f)) + +def run(format, VERBOSE=0): + formats = format.split( ',') + for i in range(0, len(formats)): + formats[i] == formats[i].strip().lower() + allfiles = glob.glob('*.py') + allfiles.sort() + for fn in allfiles: + f = fn.split('.')[0] + c = getclass(f) + if c != None: + print(c.__name__) + try: + for fmt in formats: + if fmt: + c().save(formats=[fmt],outDir='.',fnRoot=c.__name__) + if VERBOSE: + print(" %s.%s" % 
(c.__name__, fmt)) + except: + print(" COULDN'T CREATE '%s.%s'!" % (c.__name__, format)) + +if __name__ == "__main__": + if len(sys.argv) == 1: + run('pdf,pict,png') + else: + try: + if sys.argv[1] == "-h": + print('usage: runall.py [FORMAT] [-h]') + print(' if format is supplied is should be one or more of pdf,gif,eps,png etc') + print(' if format is missing the following formats are assumed: pdf,pict,png') + print(' -h prints this message') + else: + t = sys.argv[1:] + for f in t: + run(f) + except: + print('usage: runall.py [FORMAT][-h]') + print(' if format is supplied is should be one or more of pdf,gif,eps,png etc') + print(' if format is missing the following formats are assumed: pdf,pict,png') + print(' -h prints this message') + raise diff --git a/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/scatter.py b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/scatter.py new file mode 100644 index 0000000000000000000000000000000000000000..a5a17f0231bba4ae548cc36ab0ef56167558a87a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/reportlab/graphics/samples/scatter.py @@ -0,0 +1,71 @@ +#Autogenerated by ReportLab guiedit do not edit +from reportlab.graphics.charts.legends import Legend +from reportlab.graphics.charts.lineplots import ScatterPlot +from reportlab.graphics.shapes import Drawing, _DrawingEditorMixin +from reportlab.graphics.charts.textlabels import Label +from reportlab.graphics.samples.excelcolors import * + +class Scatter(_DrawingEditorMixin,Drawing): + def __init__(self,width=200,height=150,*args,**kw): + Drawing.__init__(self,width,height,*args,**kw) + self._add(self,ScatterPlot(),name='chart',validate=None,desc="The main chart") + self.chart.width = 115 + self.chart.height = 80 + self.chart.x = 30 + self.chart.y = 40 + self.chart.lines[0].strokeColor = color01 + self.chart.lines[1].strokeColor = color02 + self.chart.lines[2].strokeColor = color03 + self.chart.lines[3].strokeColor = color04 + 
self.chart.lines[4].strokeColor = color05 + self.chart.lines[5].strokeColor = color06 + self.chart.lines[6].strokeColor = color07 + self.chart.lines[7].strokeColor = color08 + self.chart.lines[8].strokeColor = color09 + self.chart.lines[9].strokeColor = color10 + self.chart.fillColor = backgroundGrey + self.chart.lineLabels.fontName = 'Helvetica' + self.chart.xValueAxis.labels.fontName = 'Helvetica' + self.chart.xValueAxis.labels.fontSize = 7 + self.chart.xValueAxis.forceZero = 0 + self.chart.data = [((100,100), (200,200), (250,210), (300,300), (400,500)), ((100,200), (200,300), (250,200), (300,400), (400, 600))] + self.chart.xValueAxis.avoidBoundFrac = 1 + self.chart.xValueAxis.gridEnd = 115 + self.chart.xValueAxis.tickDown = 3 + self.chart.xValueAxis.visibleGrid = 1 + self.chart.yValueAxis.tickLeft = 3 + self.chart.yValueAxis.labels.fontName = 'Helvetica' + self.chart.yValueAxis.labels.fontSize = 7 + self._add(self,Label(),name='Title',validate=None,desc="The title at the top of the chart") + self.Title.fontName = 'Helvetica-Bold' + self.Title.fontSize = 7 + self.Title.x = 100 + self.Title.y = 135 + self.Title._text = 'Chart Title' + self.Title.maxWidth = 180 + self.Title.height = 20 + self.Title.textAnchor ='middle' + self._add(self,Legend(),name='Legend',validate=None,desc="The legend or key for the chart") + self.Legend.colorNamePairs = [(color01, 'Widgets'), (color02, 'Sprockets')] + self.Legend.fontName = 'Helvetica' + self.Legend.fontSize = 7 + self.Legend.x = 153 + self.Legend.y = 85 + self.Legend.dxTextSpace = 5 + self.Legend.dy = 5 + self.Legend.dx = 5 + self.Legend.deltay = 5 + self.Legend.alignment ='right' + self.chart.lineLabelFormat = None + self.chart.xLabel = 'X Axis' + self.chart.y = 30 + self.chart.yLabel = 'Y Axis' + self.chart.yValueAxis.labelTextFormat = '%d' + self.chart.yValueAxis.forceZero = 1 + self.chart.xValueAxis.forceZero = 1 + + + self._add(self,0,name='preview',validate=None,desc=None) + +if __name__=="__main__": #NORUNTESTS + 
class ScatterLines(_DrawingEditorMixin,Drawing):
    """Sample drawing: a scatter plot with joined lines and no point symbols.

    The children are registered through ``_add`` under the names ``chart``
    (a ScatterPlot), ``Title`` (a Label), ``Legend`` (a Legend) and
    ``preview``.
    """
    def __init__(self,width=200,height=150,*args,**kw):
        Drawing.__init__(self,width,height,*args,**kw)

        # the plot: geometry, then per-series stroke colour, then symbols off
        self._add(self,ScatterPlot(),name='chart',validate=None,desc="The main chart")
        plot = self.chart
        for key, val in (('width', 115), ('height', 80), ('x', 30), ('y', 40)):
            setattr(plot, key, val)
        palette = (color01, color02, color03, color04, color05,
                   color06, color07, color08, color09, color10)
        for idx, shade in enumerate(palette):
            plot.lines[idx].strokeColor = shade
        for idx in range(10):
            plot.lines[idx].symbol = None
        plot.fillColor = backgroundGrey
        plot.lineLabels.fontName = 'Helvetica'
        xAxis = plot.xValueAxis
        yAxis = plot.yValueAxis
        xAxis.labels.fontName = 'Helvetica'
        xAxis.labels.fontSize = 7
        xAxis.forceZero = 0
        plot.data = [((100,100), (200,200), (250,210), (300,300), (400,500)), ((100,200), (200,300), (250,200), (300,400), (400, 600))]
        for key, val in (('avoidBoundFrac', 1), ('gridEnd', 115), ('tickDown', 3), ('visibleGrid', 1)):
            setattr(xAxis, key, val)
        yAxis.tickLeft = 3
        yAxis.labels.fontName = 'Helvetica'
        yAxis.labels.fontSize = 7

        # the title label
        self._add(self,Label(),name='Title',validate=None,desc="The title at the top of the chart")
        heading = self.Title
        for key, val in (
                ('fontName', 'Helvetica-Bold'),
                ('fontSize', 7),
                ('x', 100),
                ('y', 135),
                ('_text', 'Chart Title'),
                ('maxWidth', 180),
                ('height', 20),
                ('textAnchor', 'middle'),
                ):
            setattr(heading, key, val)

        # the legend
        self._add(self,Legend(),name='Legend',validate=None,desc="The legend or key for the chart")
        key_box = self.Legend
        for key, val in (
                ('colorNamePairs', [(color01, 'Widgets'), (color02, 'Sprockets')]),
                ('fontName', 'Helvetica'),
                ('fontSize', 7),
                ('x', 153),
                ('y', 85),
                ('dxTextSpace', 5),
                ('dy', 5),
                ('dx', 5),
                ('deltay', 5),
                ('alignment', 'right'),
                ):
            setattr(key_box, key, val)

        # late plot tweaks (these override the earlier y and forceZero values)
        plot.lineLabelFormat = None
        plot.xLabel = 'X Axis'
        plot.y = 30
        plot.yLabel = 'Y Axis'
        yAxis.gridEnd = 115
        yAxis.visibleGrid = 1
        yAxis.labelTextFormat = '%d'
        yAxis.forceZero = 1
        xAxis.forceZero = 1
        plot.joinedLines = 1

        self._add(self,0,name='preview',validate=None,desc=None)
class ScatterLinesMarkers(_DrawingEditorMixin,Drawing):
    """Sample drawing: a scatter plot with joined lines (symbols untouched).

    The children are registered through ``_add`` under the names ``chart``
    (a ScatterPlot), ``Title`` (a Label), ``Legend`` (a Legend) and
    ``preview``.
    """
    def __init__(self,width=200,height=150,*args,**kw):
        Drawing.__init__(self,width,height,*args,**kw)

        # the plot: geometry first, then one stroke colour per series
        self._add(self,ScatterPlot(),name='chart',validate=None,desc="The main chart")
        plot = self.chart
        for key, val in (('width', 115), ('height', 80), ('x', 30), ('y', 40)):
            setattr(plot, key, val)
        palette = (color01, color02, color03, color04, color05,
                   color06, color07, color08, color09, color10)
        for idx, shade in enumerate(palette):
            plot.lines[idx].strokeColor = shade
        plot.fillColor = backgroundGrey
        plot.lineLabels.fontName = 'Helvetica'
        xAxis = plot.xValueAxis
        yAxis = plot.yValueAxis
        xAxis.labels.fontName = 'Helvetica'
        xAxis.labels.fontSize = 7
        xAxis.forceZero = 0
        plot.data = [((100,100), (200,200), (250,210), (300,300), (400,500)), ((100,200), (200,300), (250,200), (300,400), (400, 600))]
        for key, val in (('avoidBoundFrac', 1), ('gridEnd', 115), ('tickDown', 3), ('visibleGrid', 1)):
            setattr(xAxis, key, val)
        yAxis.tickLeft = 3
        yAxis.labels.fontName = 'Helvetica'
        yAxis.labels.fontSize = 7

        # the title label
        self._add(self,Label(),name='Title',validate=None,desc="The title at the top of the chart")
        heading = self.Title
        for key, val in (
                ('fontName', 'Helvetica-Bold'),
                ('fontSize', 7),
                ('x', 100),
                ('y', 135),
                ('_text', 'Chart Title'),
                ('maxWidth', 180),
                ('height', 20),
                ('textAnchor', 'middle'),
                ):
            setattr(heading, key, val)

        # the legend
        self._add(self,Legend(),name='Legend',validate=None,desc="The legend or key for the chart")
        key_box = self.Legend
        for key, val in (
                ('colorNamePairs', [(color01, 'Widgets'), (color02, 'Sprockets')]),
                ('fontName', 'Helvetica'),
                ('fontSize', 7),
                ('x', 153),
                ('y', 85),
                ('dxTextSpace', 5),
                ('dy', 5),
                ('dx', 5),
                ('deltay', 5),
                ('alignment', 'right'),
                ):
            setattr(key_box, key, val)

        # late plot tweaks (these override the earlier y and forceZero values)
        plot.lineLabelFormat = None
        plot.xLabel = 'X Axis'
        plot.y = 30
        plot.yLabel = 'Y Axis'
        yAxis.gridEnd = 115
        yAxis.visibleGrid = 1
        yAxis.labelTextFormat = '%d'
        yAxis.forceZero = 1
        xAxis.forceZero = 1
        plot.joinedLines = 1

        self._add(self,0,name='preview',validate=None,desc=None)
class SimplePie(_DrawingEditorMixin,Drawing):
    """Sample drawing: a three-slice pie with a title, legend and background.

    The children are registered through ``_add`` under the names ``chart``
    (a Pie), ``Title`` (a Label), ``Legend`` (a Legend) and ``preview``; the
    drawing's ``background`` is set to a ShadedRect.
    """
    def __init__(self,width=200,height=150,*args,**kw):
        Drawing.__init__(self,width,height,*args,**kw)

        # the pie: geometry, then one fill colour per slice, then the data
        self._add(self,Pie(),name='chart',validate=None,desc="The main chart")
        pie = self.chart
        for key, val in (('width', 100), ('height', 100), ('x', 25), ('y', 25)):
            setattr(pie, key, val)
        palette = (color01, color02, color03, color04, color05,
                   color06, color07, color08, color09, color10)
        for idx, shade in enumerate(palette):
            pie.slices[idx].fillColor = shade
        pie.data = (100, 150, 180)

        # the title label
        self._add(self,Label(),name='Title',validate=None,desc="The title at the top of the chart")
        heading = self.Title
        for key, val in (
                ('fontName', 'Helvetica-Bold'),
                ('fontSize', 7),
                ('x', 100),
                ('y', 135),
                ('_text', 'Chart Title'),
                ('maxWidth', 180),
                ('height', 20),
                ('textAnchor', 'middle'),
                ):
            setattr(heading, key, val)

        # the legend
        self._add(self,Legend(),name='Legend',validate=None,desc="The legend or key for the chart")
        key_box = self.Legend
        for key, val in (
                ('colorNamePairs', [(color01, 'North'), (color02, 'South'),(color03, 'Central')]),
                ('fontName', 'Helvetica'),
                ('fontSize', 7),
                ('x', 160),
                ('y', 85),
                ('dxTextSpace', 5),
                ('dy', 5),
                ('dx', 5),
                ('deltay', 5),
                ('alignment', 'right'),
                ):
            setattr(key_box, key, val)

        # settings applied to the slice collection as a whole
        pie.slices.strokeWidth = 1
        pie.slices.fontName = 'Helvetica'

        # single-shade grey rectangle used as the drawing background
        self.background = ShadedRect()
        backdrop = self.background
        for key, val in (
                ('fillColorStart', backgroundGrey),
                ('fillColorEnd', backgroundGrey),
                ('numShades', 1),
                ('strokeWidth', 0.5),
                ('x', 25),
                ('y', 25),
                ):
            setattr(backdrop, key, val)

        key_box.columnMaximum = 10
        self._add(self,0,name='preview',validate=None,desc=None)
self.chart.groupSpacing = 15 + self.chart.valueAxis.avoidBoundFrac = 1 + self.chart.valueAxis.gridEnd = 80 + self.chart.valueAxis.tickDown = 3 + self.chart.valueAxis.visibleGrid = 1 + self.chart.categoryAxis.categoryNames = ['North', 'South', 'Central'] + self.chart.categoryAxis.tickLeft = 3 + self.chart.categoryAxis.labels.fontName = 'Helvetica' + self.chart.categoryAxis.labels.fontSize = 6 + self.chart.categoryAxis.labels.dx = -3 + self._add(self,Label(),name='Title',validate=None,desc="The title at the top of the chart") + self.Title.fontName = 'Helvetica-Bold' + self.Title.fontSize = 7 + self.Title.x = 100 + self.Title.y = 135 + self.Title._text = 'Chart Title' + self.Title.maxWidth = 180 + self.Title.height = 20 + self.Title.textAnchor ='middle' + self._add(self,Legend(),name='Legend',validate=None,desc="The legend or key for the chart") + self.Legend.colorNamePairs = [(color01, 'Widgets'), (color02, 'Sprockets')] + self.Legend.fontName = 'Helvetica' + self.Legend.fontSize = 7 + self.Legend.x = 153 + self.Legend.y = 85 + self.Legend.dxTextSpace = 5 + self.Legend.dy = 5 + self.Legend.dx = 5 + self.Legend.deltay = 5 + self.Legend.alignment ='right' + self._add(self,Label(),name='XLabel',validate=None,desc="The label on the horizontal axis") + self.XLabel.fontName = 'Helvetica' + self.XLabel.fontSize = 7 + self.XLabel.x = 85 + self.XLabel.y = 10 + self.XLabel.textAnchor ='middle' + self.XLabel.maxWidth = 100 + self.XLabel.height = 20 + self.XLabel._text = "X Axis" + self._add(self,Label(),name='YLabel',validate=None,desc="The label on the vertical axis") + self.YLabel.fontName = 'Helvetica' + self.YLabel.fontSize = 7 + self.YLabel.x = 12 + self.YLabel.y = 80 + self.YLabel.angle = 90 + self.YLabel.textAnchor ='middle' + self.YLabel.maxWidth = 100 + self.YLabel.height = 20 + self.YLabel._text = "Y Axis" + self.chart.categoryAxis.style='stacked' + self._add(self,0,name='preview',validate=None,desc=None) + +if __name__=="__main__": #NORUNTESTS + 
class StackedColumn(_DrawingEditorMixin,Drawing):
    """Sample drawing: a stacked vertical bar chart with title, legend and axis labels.

    The children are registered through ``_add`` under the names ``chart``
    (a VerticalBarChart), ``Title``, ``XLabel``, ``YLabel`` (Labels),
    ``Legend`` (a Legend) and ``preview``.
    """
    def __init__(self,width=200,height=150,*args,**kw):
        Drawing.__init__(self,width,height,*args,**kw)

        # the chart: geometry, then one fill colour per series
        self._add(self,VerticalBarChart(),name='chart',validate=None,desc="The main chart")
        bars = self.chart
        for key, val in (('width', 115), ('height', 80), ('x', 30), ('y', 40)):
            setattr(bars, key, val)
        palette = (color01, color02, color03, color04, color05,
                   color06, color07, color08, color09, color10)
        for idx, shade in enumerate(palette):
            bars.bars[idx].fillColor = shade
        bars.fillColor = backgroundGrey
        bars.barLabels.fontName = 'Helvetica'
        valAxis = bars.valueAxis
        catAxis = bars.categoryAxis
        valAxis.labels.fontName = 'Helvetica'
        valAxis.labels.fontSize = 7
        valAxis.forceZero = 1
        bars.data = [(100, 150, 180), (125, 180, 200)]
        bars.groupSpacing = 15
        for key, val in (('avoidBoundFrac', 1), ('gridEnd', 115), ('tickLeft', 3), ('visibleGrid', 1)):
            setattr(valAxis, key, val)
        catAxis.categoryNames = ['North', 'South', 'Central']
        catAxis.tickDown = 3
        catAxis.labels.fontName = 'Helvetica'
        catAxis.labels.fontSize = 7

        # the title label
        self._add(self,Label(),name='Title',validate=None,desc="The title at the top of the chart")
        heading = self.Title
        for key, val in (
                ('fontName', 'Helvetica-Bold'),
                ('fontSize', 7),
                ('x', 100),
                ('y', 135),
                ('_text', 'Chart Title'),
                ('maxWidth', 180),
                ('height', 20),
                ('textAnchor', 'middle'),
                ):
            setattr(heading, key, val)

        # the legend
        self._add(self,Legend(),name='Legend',validate=None,desc="The legend or key for the chart")
        key_box = self.Legend
        for key, val in (
                ('colorNamePairs', [(color01, 'Widgets'), (color02, 'Sprockets')]),
                ('fontName', 'Helvetica'),
                ('fontSize', 7),
                ('x', 153),
                ('y', 85),
                ('dxTextSpace', 5),
                ('dy', 5),
                ('dx', 5),
                ('deltay', 5),
                ('alignment', 'right'),
                ):
            setattr(key_box, key, val)

        # horizontal axis caption
        self._add(self,Label(),name='XLabel',validate=None,desc="The label on the horizontal axis")
        x_caption = self.XLabel
        for key, val in (
                ('fontName', 'Helvetica'),
                ('fontSize', 7),
                ('x', 85),
                ('y', 10),
                ('textAnchor', 'middle'),
                ('maxWidth', 100),
                ('height', 20),
                ('_text', "X Axis"),
                ):
            setattr(x_caption, key, val)

        # vertical axis caption, rotated by 90
        self._add(self,Label(),name='YLabel',validate=None,desc="The label on the vertical axis")
        y_caption = self.YLabel
        for key, val in (
                ('fontName', 'Helvetica'),
                ('fontSize', 7),
                ('x', 12),
                ('y', 80),
                ('angle', 90),
                ('textAnchor', 'middle'),
                ('maxWidth', 100),
                ('height', 20),
                ('_text', "Y Axis"),
                ):
            setattr(y_caption, key, val)

        catAxis.style='stacked'
        self._add(self,0,name='preview',validate=None,desc=None)
class AdjustableArrow(Widget):
    """Arrow-shaped polygon widget with adjustable stem and head proportions.

    The outline is built from ``stemThickness``, ``stemLength``,
    ``headProjection``, ``headLength`` and ``headSweep`` (each multiplied by
    ``scale``), offset according to ``boxAnchor``, translated to ``(x, y)``
    and rotated by ``angle``.  With ``right`` true the arrow points right,
    otherwise it points up.
    """
    _attrMap = AttrMap(
        x = AttrMapValue(isNumber,desc='symbol x coordinate'),
        y = AttrMapValue(isNumber,desc='symbol y coordinate'),
        dx = AttrMapValue(isNumber,desc='symbol x coordinate adjustment'),
        dy = AttrMapValue(isNumber,desc='symbol y coordinate adjustment'),
        stemThickness = AttrMapValue(isNumber, 'width of the stem'),
        stemLength = AttrMapValue(isNumber, 'length of the stem'),
        headProjection = AttrMapValue(isNumber, 'how much the head projects from the stem'),
        headLength = AttrMapValue(isNumber, 'length of the head'),
        headSweep = AttrMapValue(isNumber, 'how much the head sweeps back (-ve) or forwards (+ve)'),
        scale = AttrMapValue(isNumber, 'scaling factor'),
        fillColor = AttrMapValue(isColorOrNone),
        strokeColor = AttrMapValue(isColorOrNone),
        strokeWidth = AttrMapValue(isNumber),
        boxAnchor = AttrMapValue(isBoxAnchor,desc='anchoring point of the label'),
        right =AttrMapValue(isBoolean,desc='If True (default) the arrow is horizontal pointing right\nFalse means it points up'),
        angle = AttrMapValue(isNumber, desc='angle of arrow default (0), right True 0 is horizontal to right else vertical up'),
        )

    def __init__(self,**kwds):
        """Apply the caller's keyword attributes, then the built-in defaults."""
        # NOTE(review): caller values are applied first, which presumably
        # relies on _setKeywords not overwriting attributes that are already
        # set -- confirm against Widget._setKeywords.
        self._setKeywords(**kwds)
        self._setKeywords(**dict(
            x = 0,
            y = 0,
            fillColor = colors.red,
            strokeWidth = 0,
            strokeColor = None,
            boxAnchor = 'c',
            angle = 0,
            stemThickness = 33,
            stemLength = 50,
            headProjection = 15,
            headLength = 50,
            headSweep = 0,
            scale = 1.,
            right=True,
            ))

    def draw(self):
        """Return a Group holding the arrow Polygon, translated and rotated."""
        g = Group()

        x = self.x
        y = self.y
        scale = self.scale
        stemThickness = self.stemThickness*scale
        stemLength = self.stemLength*scale
        headProjection = self.headProjection*scale
        headLength = self.headLength*scale
        headSweep = self.headSweep*scale
        # overall extent of a right-pointing arrow: w along it, h across it
        w = stemLength+headLength
        h = 2*headProjection+stemThickness
        # shift to the boxAnchor; the local dx/dy below are anchor offsets
        # (the dx/dy attributes declared in _attrMap are not read here)
        boxAnchor = self.boxAnchor
        if self.right:
            # NOTE(review): here the west anchors shift vertically and the
            # north anchors horizontally, unlike the upward branch below --
            # confirm this is intended.
            if boxAnchor in ('sw','w','nw'):
                dy = -h
            elif boxAnchor in ('s','c','n'):
                dy = -h*0.5
            else:
                dy = 0
            if boxAnchor in ('w','c','e'):
                dx = -w*0.5
            elif boxAnchor in ('nw','n','ne'):
                dx = -w
            else:
                dx = 0
            # flat x,y list walking round the outline of the arrow
            points = [
                dx, dy+headProjection+stemThickness,
                dx+stemLength, dy+headProjection+stemThickness,
                dx+stemLength+headSweep, dy+2*headProjection+stemThickness,
                dx+stemLength+headLength, dy+0.5*stemThickness+headProjection,
                dx+stemLength+headSweep, dy,
                dx+stemLength, dy+headProjection,
                dx, dy+headProjection,
                ]
        else:
            # upward arrow: the extents swap roles
            w,h = h,w
            if boxAnchor in ('nw','n','ne'):
                dy = -h
            elif boxAnchor in ('w','c','e'):
                dy = -h*0.5
            else:
                dy = 0
            if boxAnchor in ('ne','e','se'):
                dx = -w
            elif boxAnchor in ('n','c','s'):
                dx = -w*0.5
            else:
                dx = 0
            points = [
                dx+headProjection, dy,      #sw
                dx+headProjection+stemThickness, dy,    #se
                dx+headProjection+stemThickness, dy+stemLength,
                dx+w, dy+stemLength+headSweep,
                dx+headProjection+0.5*stemThickness, dy+h,
                dx, dy+stemLength+headSweep,
                dx+headProjection, dy+stemLength,
                ]

        g.add(Polygon(
            points = points,
            fillColor = self.fillColor,
            strokeColor = self.strokeColor,
            strokeWidth = self.strokeWidth,
            ))
        g.translate(x,y)
        g.rotate(self.angle)
        return g
class EventCalendar(Widget):
    """Timetable widget: a time column followed by one column per track.

    Each entry of ``data`` is a tuple
    ``(title, speaker, trackId, day, start, duration)`` with start and
    duration in hours.  ``trackId`` is 1-based, or None for a talk that spans
    all the track columns.  Only talks whose ``day`` equals ``self.day`` and
    which overlap the optional ``startTime``/``endTime`` window are drawn.
    """
    def __init__(self):
        self.x = 0
        self.y = 0
        self.width = 300
        self.height = 150
        self.timeColWidth = None # if declared, use it; otherwise auto-size.
        self.trackRowHeight = 20
        self.data = []   # list of (title, speaker, trackId, day, start, duration) tuples
        self.trackNames = None

        self.startTime = None  #displays ALL data on day if not set
        self.endTime = None    # displays ALL data on day if not set
        self.day = 0

        # we will keep any internal geometry variables
        # here.  These are computed by computeSize(),
        # which is the first thing done when drawing.
        self._talksVisible = []  # subset of data which will get plotted, cache
        self._startTime = None
        self._endTime = None
        self._trackCount = 0
        self._colWidths = []
        self._colLeftEdges = []  # left edge of each column

    def computeSize(self):
        "Called at start of draw.  Sets various column widths"
        self._talksVisible = self.getRelevantTalks(self.data)
        self._trackCount = len(self.getAllTracks())
        self.computeStartAndEndTimes()
        trackCount = self._trackCount
        if self.timeColWidth is None:
            # the time column and every track column share the width equally
            w = self.width / (1 + trackCount)
            self._colWidths = [w] * (1 + trackCount)
        else:
            # fixed time column; the tracks share what is left.  With no
            # tracks there is nothing to share, so do not divide by zero.
            w = (self.width - self.timeColWidth) / trackCount if trackCount else 0
            self._colWidths = [self.timeColWidth] + [w] * trackCount
        # Each column starts where the previous one ends.  (Stepping by the
        # track width alone put the first track at x + w instead of
        # x + timeColWidth, out of line with the headers drawn in draw().)
        edges = [self.x]
        for colWidth in self._colWidths[:-1]:
            edges.append(edges[-1] + colWidth)
        self._colLeftEdges = edges

    def computeStartAndEndTimes(self):
        "Work out first and last times to display"
        # Recomputed from scratch on every call so that a redraw after the
        # data, day or window changed does not keep stale extremes.
        talks = self._talksVisible
        # 'is not None' so that an explicit time of 0 counts as set, matching
        # the test used in getRelevantTalks()
        if self.startTime is not None:
            self._startTime = self.startTime
        else:
            starts = [start for (title, speaker, trackId, day, start, duration) in talks]
            self._startTime = min(starts) if starts else None

        if self.endTime is not None:
            self._endTime = self.endTime
        else:
            ends = [start + duration for (title, speaker, trackId, day, start, duration) in talks]
            self._endTime = max(ends) if ends else None

    def getAllTracks(self):
        "Return the sorted distinct track ids (None excluded) found in self.data"
        return sorted({trackId
                       for (title, speaker, trackId, day, hours, duration) in self.data
                       if trackId is not None})

    def getRelevantTalks(self, talkList):
        "Return the talks on self.day that overlap the startTime/endTime window"
        used = []
        for talk in talkList:
            (title, speaker, trackId, day, hours, duration) = talk
            assert trackId != 0, "trackId must be None or 1,2,3... zero not allowed!"
            if day == self.day:
                if (((self.startTime is None) or ((hours + duration) >= self.startTime))
                    and ((self.endTime is None) or (hours <= self.endTime))):
                    used.append(talk)
        return used

    def scaleTime(self, theTime):
        "Return y-value corresponding to times given"
        axisHeight = self.height - self.trackRowHeight
        # compute fraction between 0 and 1, 0 is at start of period
        span = self._endTime - self._startTime
        if span:
            proportionUp = (theTime - self._startTime) / span
        else:
            # empty period: pin everything to the start rather than divide by zero
            proportionUp = 0
        y = self.y + axisHeight - (axisHeight * proportionUp)
        return y

    def getTalkRect(self, startTime, duration, trackId, text):
        "Return shapes for a specific talk"
        g = Group()
        y_bottom = self.scaleTime(startTime + duration)
        y_top = self.scaleTime(startTime)
        y_height = y_top - y_bottom

        if trackId is None:
            #spans all columns: starts right after the time column.  Computed
            #from the time column width so it also works with no track columns.
            x = self.x + self._colWidths[0]
            width = self.width - self._colWidths[0]
        else:
            #trackId is 1-based and these arrays have the margin info in column
            #zero, so no need to add 1
            x = self._colLeftEdges[trackId]
            width = self._colWidths[trackId]

        lab = Label()
        lab.setText(text)
        lab.setOrigin(x + 0.5*width, y_bottom+0.5*y_height)
        lab.boxAnchor = 'c'
        lab.width = width
        lab.height = y_height
        lab.fontSize = 6

        r = Rect(x, y_bottom, width, y_height, fillColor=colors.cyan)
        g.add(r)
        g.add(lab)

        #now for a label
        # would expect to color-code and add text
        return g

    def draw(self):
        "Return a Group with the time column, the track headers and one box per visible talk"
        self.computeSize()
        g = Group()

        # time column
        g.add(Rect(self.x, self.y, self._colWidths[0], self.height - self.trackRowHeight, fillColor=colors.cornsilk))

        # track headers
        x = self.x + self._colWidths[0]
        y = self.y + self.height - self.trackRowHeight
        for trk in range(self._trackCount):
            wid = self._colWidths[trk+1]
            r = Rect(x, y, wid, self.trackRowHeight, fillColor=colors.yellow)
            # column trk+1 holds the talks of track id trk+1 (ids are 1-based)
            # NOTE(review): confirm String accepts 'align'; the Labels
            # elsewhere are configured through 'textAnchor'.
            s = String(x + 0.5*wid, y, 'Track %d' % (trk + 1), align='middle')
            g.add(r)
            g.add(s)
            x = x + wid

        for talk in self._talksVisible:
            (title, speaker, trackId, day, start, duration) = talk
            r = self.getTalkRect(start, duration, trackId, title + '\n' + speaker)
            g.add(r)

        return g
Isn't that obsolete", 'Duncan Grisby', 1, 1, 14.0, 1.5), + ("Python Design Patterns", 'Duncan Booth', 2, 1, 14.0, 1.5), + ("Inside Security Checks and Safe Exceptions", 'Brandon Bray', 3, 1, 14.0, 1.5), + ("Studying at a Distance", 'Panel Discussion, Panel to include Alan Lenton & Francis Glassborow', 4, 1, 14.0, 1.5), + ("Coding Standards - Given the ANSI C Standard why do I still need a coding Standard", 'Randy Marques', 5, 1, 14.0, 1.5), + + ("RESTful Python", 'Hamish Lawson', 1, 1, 16.0, 1.5), + ("Parsing made easier - a radical old idea", 'Andrew Koenig', 2, 1, 16.0, 1.5), + ("C++ & Multimethods", 'Julian Smith', 3, 1, 16.0, 1.5), + ("C++ Threading", 'Kevlin Henney', 4, 1, 16.0, 1.5), + ("The Organisation Strikes Back", 'Alan Griffiths & Sarah Lees', 5, 1, 16.0, 1.5), + + ('Birds of a Feather meeting', '', None, 1, 17.5, 2.0), + + ('Keynote: In the Spirit of C', 'Greg Colvin', None, 2, 9.0, 1.0), + + ('The Infinite Filing Cabinet - object storage in Python', 'Jacob Hallen', 1, 2, 10.5, 1.5), + ('Introduction to Python and Jython for C++ and Java Programmers', 'Alex Martelli', 2, 2, 10.5, 1.5), + ('Template metaprogramming in Haskell', 'Simon Peyton Jones', 3, 2, 10.5, 1.5), + ('Plenty People Programming: C++ Programming in a Group, Workshop with a difference', 'Nico Josuttis', 4, 2, 10.5, 1.5), + ('Design and Implementation of the Boost Graph Library', 'Jeremy Siek', 5, 2, 10.5, 1.5), + + ('lunch, short presentations, vendor presentations', '', None, 2, 12.0, 2.0), + + ("Building GUI Applications with PythonCard and PyCrust", 'Andy Todd', 1, 2, 14.0, 1.5), + ("Integrating Python, C and C++", 'Duncan Booth', 2, 2, 14.0, 1.5), + ("Secrets and Pitfalls of Templates", 'Nicolai Josuttis & David Vandevoorde', 3, 2, 14.0, 1.5), + ("Being a Mentor", 'Panel Discussion, Panel to include Alan Lenton & Francis Glassborow', 4, 2, 14.0, 1.5), + ("The Embedded C Extensions to C", 'Willem Wakker', 5, 2, 14.0, 1.5), + + ("Lightning Talks", 'Paul Brian', 1, 2, 16.0, 1.5), + 
("Scripting Java Applications with Jython", 'Anthony Eden', 2, 2, 16.0, 1.5), + ("Metaprogramming and the Boost Metaprogramming Library", 'David Abrahams', 3, 2, 16.0, 1.5), + ("A Common Vendor ABI for C++ -- GCC's why, what and not", 'Nathan Sidwell & Gabriel Dos Reis', 4, 2, 16.0, 1.5), + ("The Timing and Cost of Choices", 'Hubert Matthews', 5, 2, 16.0, 1.5), + + ('Birds of a Feather meeting', '', None, 2, 17.5, 2.0), + + ('Keynote: The Cost of C & C++ Compatibility', 'Andy Koenig', None, 3, 9.0, 1.0), + + ('Prying Eyes: Generic Observer Implementations in C++', 'Andrei Alexandrescu', 1, 2, 10.5, 1.5), + ('The Roadmap to Generative Programming With C++', 'Ulrich Eisenecker', 2, 2, 10.5, 1.5), + ('Design Patterns in C++ and C# for the Common Language Runtime', 'Brandon Bray', 3, 2, 10.5, 1.5), + ('Extreme Hour (XH): (workshop) - Jutta Eckstein and Nico Josuttis', 'Jutta Ecstein', 4, 2, 10.5, 1.5), + ('The Lambda Library : Unnamed Functions for C++', 'Jaako Jarvi', 5, 2, 10.5, 1.5), + + ('lunch, short presentations, vendor presentations', '', None, 3, 12.0, 2.0), + + ('Reflective Metaprogramming', 'Daveed Vandevoorde', 1, 3, 14.0, 1.5), + ('Advanced Template Issues and Solutions (double session)', 'Herb Sutter',2, 3, 14.0, 3), + ('Concurrent Programming in Java (double session)', 'Angelika Langer', 3, 3, 14.0, 3), + ('What can MISRA-C (2nd Edition) do for us?', 'Chris Hills', 4, 3, 14.0, 1.5), + ('C++ Metaprogramming Concepts and Results', 'Walter E Brown', 5, 3, 14.0, 1.5), + + ('Binding C++ to Python with the Boost Python Library', 'David Abrahams', 1, 3, 16.0, 1.5), + ('Using Aspect Oriented Programming for Enterprise Application Integration', 'Arno Schmidmeier', 4, 3, 16.0, 1.5), + ('Defective C++', 'Marc Paterno', 5, 3, 16.0, 1.5), + + ("Speakers' Banquet & Birds of a Feather meeting", '', None, 3, 17.5, 2.0), + + ('Keynote: The Internet, Software and Computers - A Report Card', 'Alan Lenton', None, 4, 9.0, 1.0), + + ('Multi-Platform Software Development; 
Lessons from the Boost libraries', 'Beman Dawes', 1, 5, 10.5, 1.5), + ('The Stability of the C++ ABI', 'Steve Clamage', 2, 5, 10.5, 1.5), + ('Generic Build Support - A Pragmatic Approach to the Software Build Process', 'Randy Marques', 3, 5, 10.5, 1.5), + ('How to Handle Project Managers: a survival guide', 'Barb Byro', 4, 5, 10.5, 1.5), + + ('lunch, ACCU AGM', '', None, 5, 12.0, 2.0), + + ('Sauce: An OO recursive descent parser; its design and implementation.', 'Jon Jagger', 1, 5, 14.0, 1.5), + ('GNIRTS ESAC REWOL - Bringing the UNIX filters to the C++ iostream library.', 'JC van Winkel', 2, 5, 14.0, 1.5), + ('Pattern Writing: Live and Direct', 'Frank Buschmann & Kevlin Henney', 3, 5, 14.0, 3.0), + ('The Future of Programming Languages - A Goldfish Bowl', 'Francis Glassborow and friends', 3, 5, 14.0, 1.5), + + ('Honey, I Shrunk the Threads: Compile-time checked multithreaded transactions in C++', 'Andrei Alexandrescu', 1, 5, 16.0, 1.5), + ('Fun and Functionality with Functors', 'Lois Goldthwaite', 2, 5, 16.0, 1.5), + ('Agile Enough?', 'Alan Griffiths', 4, 5, 16.0, 1.5), + ("Conference Closure: A brief plenary session", '', None, 5, 17.5, 0.5), + + ] + + #return cal + cal.day = 1 + + d.add(cal) + + + for format in ['pdf']:#,'gif','png']: + out = d.asString(format) + open('eventcal.%s' % format, 'wb').write(out) + print('saved eventcal.%s' % format) + +if __name__=='__main__': + test() diff --git a/.venv/lib/python3.13/site-packages/reportlab/graphics/widgets/flags.py b/.venv/lib/python3.13/site-packages/reportlab/graphics/widgets/flags.py new file mode 100644 index 0000000000000000000000000000000000000000..7afd6b5a1318dc581bd190b128a96bbc2c720a1c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/reportlab/graphics/widgets/flags.py @@ -0,0 +1,879 @@ +#see license.txt for license details +#history https://hg.reportlab.com/hg-public/reportlab/log/tip/src/reportlab/graphics/widgets/flags.py +# Flag Widgets - a collection of flags as widgets +# author: John Precedo 
(johnp@reportlab.com) + +__version__='3.3.0' +__doc__="""This file is a collection of flag graphics as widgets. + +All flags are represented at the ratio of 1:2, even where the official ratio for the flag is something else +(such as 3:5 for the German national flag). The only exceptions are for where this would look _very_ wrong, +such as the Danish flag whose (ratio is 28:37), or the Swiss flag (which is square). + +Unless otherwise stated, these flags are all the 'national flags' of the countries, rather than their +state flags, naval flags, ensigns or any other variants. (National flags are the flag flown by civilians +of a country and the ones usually used to represent a country abroad. State flags are the variants used by +the government and by diplomatic missions overseas). + +To check on how close these are to the 'official' representations of flags, check the World Flag Database at +http://www.flags.ndirect.co.uk/ + +The flags this file contains are: + +EU Members: +United Kingdom, Austria, Belgium, Denmark, Finland, France, Germany, Greece, Ireland, Italy, Luxembourg, +Holland (The Netherlands), Spain, Sweden + +Others: +USA, Czech Republic, European Union, Switzerland, Turkey, Brazil + +(Brazilian flag contributed by Publio da Costa Melo [publio@planetarium.com.br]). 
+""" + +from reportlab.lib import colors +from reportlab.lib.validators import * +from reportlab.lib.attrmap import * +from reportlab.graphics.shapes import Line, Rect, Polygon, Drawing, Group, String, Circle, Wedge +from reportlab.graphics import renderPDF +from reportlab.graphics.widgets.signsandsymbols import _Symbol +import copy +from math import sin, cos, pi + +validFlag=OneOf(None, + 'UK', + 'USA', + 'Afghanistan', + 'Austria', + 'Belgium', + 'China', + 'Cuba', + 'Denmark', + 'Finland', + 'France', + 'Germany', + 'Greece', + 'Ireland', + 'Italy', + 'Japan', + 'Luxembourg', + 'Holland', + 'Palestine', + 'Portugal', + 'Russia', + 'Spain', + 'Sweden', + 'Norway', + 'CzechRepublic', + 'Turkey', + 'Switzerland', + 'EU', + 'Brazil' + ) + +_size = 100. + +class Star(_Symbol): + """This draws a 5-pointed star. + + possible attributes: + 'x', 'y', 'size', 'fillColor', 'strokeColor' + + """ + _attrMap = AttrMap(BASE=_Symbol, + angle = AttrMapValue(isNumber, desc='angle in degrees'), + ) + _size = 100. 
+ + def __init__(self): + _Symbol.__init__(self) + self.size = 100 + self.fillColor = colors.yellow + self.strokeColor = None + self.angle = 0 + + def demo(self): + D = Drawing(200, 100) + et = Star() + et.x=50 + et.y=0 + D.add(et) + labelFontSize = 10 + D.add(String(et.x+(et.size/2.0),(et.y-(1.2*labelFontSize)), + et.__class__.__name__, fillColor=colors.black, textAnchor='middle', + fontSize=labelFontSize)) + return D + + def draw(self): + s = float(self.size) #abbreviate as we will use this a lot + g = Group() + + # new algorithm from markers.StarFive + R = float(self.size)/2 + r = R*sin(18*(pi/180.0))/cos(36*(pi/180.0)) + P = [] + angle = 90 + for i in range(5): + for radius in R, r: + theta = angle*(pi/180.0) + P.append(radius*cos(theta)) + P.append(radius*sin(theta)) + angle = angle + 36 + # star specific bits + star = Polygon(P, + fillColor = self.fillColor, + strokeColor = self.strokeColor, + strokeWidth=s/50) + g.rotate(self.angle) + g.shift(self.x+self.dx,self.y+self.dy) + g.add(star) + + return g + +class Flag(_Symbol): + """This is a generic flag class that all the flags in this file use as a basis. 
+ + This class basically provides edges and a tidy-up routine to hide any bits of + line that overlap the 'outside' of the flag + + possible attributes: + 'x', 'y', 'size', 'fillColor' + """ + + _attrMap = AttrMap(BASE=_Symbol, + fillColor = AttrMapValue(isColor, desc='Background color'), + border = AttrMapValue(isBoolean, 'Whether a background is drawn'), + kind = AttrMapValue(validFlag, desc='Which flag'), + ) + + _cache = {} + + def __init__(self,**kw): + _Symbol.__init__(self) + self.kind = None + self.size = 100 + self.fillColor = colors.white + self.border=1 + self.setProperties(kw) + + def availableFlagNames(self): + '''return a list of the things we can display''' + return [x for x in self._attrMap['kind'].validate._enum if x is not None] + + def _Flag_None(self): + s = _size # abbreviate as we will use this a lot + g = Group() + g.add(Rect(0, 0, s*2, s, fillColor = colors.purple, strokeColor = colors.black, strokeWidth=0)) + return g + + def _borderDraw(self,f): + s = self.size # abbreviate as we will use this a lot + g = Group() + g.add(f) + x, y, sW = self.x+self.dx, self.y+self.dy, self.strokeWidth/2. 
+ g.insert(0,Rect(-sW, -sW, width=getattr(self,'_width',2*s)+3*sW, height=getattr(self,'_height',s)+2*sW, + fillColor = None, strokeColor = self.strokeColor, strokeWidth=sW*2)) + g.shift(x,y) + g.scale(s/_size, s/_size) + return g + + def draw(self): + kind = self.kind or 'None' + f = self._cache.get(kind) + if not f: + f = getattr(self,'_Flag_'+kind)() + self._cache[kind] = f._explode() + return self._borderDraw(f) + + def clone(self): + return copy.copy(self) + + def demo(self): + D = Drawing(200, 100) + name = self.availableFlagNames() + import time + name = name[int(time.time()) % len(name)] + fx = Flag() + fx.kind = name + fx.x = 0 + fx.y = 0 + D.add(fx) + labelFontSize = 10 + D.add(String(fx.x+(fx.size/2.0),(fx.y-(1.2*labelFontSize)), + name, fillColor=colors.black, textAnchor='middle', + fontSize=labelFontSize)) + labelFontSize = int(fx.size/4.0) + D.add(String(fx.x+(fx.size),(fx.y+((fx.size/2.0))), + "SAMPLE", fillColor=colors.gold, textAnchor='middle', + fontSize=labelFontSize, fontName="Helvetica-Bold")) + return D + + def _Flag_UK(self): + s = _size + g = Group() + w = s*2 + g.add(Rect(0, 0, w, s, fillColor = colors.navy, strokeColor = colors.black, strokeWidth=0)) + g.add(Polygon([0,0, s*.225,0, w,s*(1-.1125), w,s, w-s*.225,s, 0, s*.1125], fillColor = colors.mintcream, strokeColor=None, strokeWidth=0)) + g.add(Polygon([0,s*(1-.1125), 0, s, s*.225,s, w, s*.1125, w,0, w-s*.225,0], fillColor = colors.mintcream, strokeColor=None, strokeWidth=0)) + g.add(Polygon([0, s-(s/15.0), (s-((s/10.0)*4)), (s*0.65), (s-(s/10.0)*3), (s*0.65), 0, s], fillColor = colors.red, strokeColor = None, strokeWidth=0)) + g.add(Polygon([0, 0, (s-((s/10.0)*3)), (s*0.35), (s-((s/10.0)*2)), (s*0.35), (s/10.0), 0], fillColor = colors.red, strokeColor = None, strokeWidth=0)) + g.add(Polygon([w, s, (s+((s/10.0)*3)), (s*0.65), (s+((s/10.0)*2)), (s*0.65), w-(s/10.0), s], fillColor = colors.red, strokeColor = None, strokeWidth=0)) + g.add(Polygon([w, (s/15.0), (s+((s/10.0)*4)), (s*0.35), 
(s+((s/10.0)*3)), (s*0.35), w, 0], fillColor = colors.red, strokeColor = None, strokeWidth=0)) + g.add(Rect(((s*0.42)*2), 0, width=(0.16*s)*2, height=s, fillColor = colors.mintcream, strokeColor = None, strokeWidth=0)) + g.add(Rect(0, (s*0.35), width=w, height=s*0.3, fillColor = colors.mintcream, strokeColor = None, strokeWidth=0)) + g.add(Rect(((s*0.45)*2), 0, width=(0.1*s)*2, height=s, fillColor = colors.red, strokeColor = None, strokeWidth=0)) + g.add(Rect(0, (s*0.4), width=w, height=s*0.2, fillColor = colors.red, strokeColor = None, strokeWidth=0)) + return g + + def _Flag_USA(self): + s = _size # abbreviate as we will use this a lot + g = Group() + + box = Rect(0, 0, s*2, s, fillColor = colors.mintcream, strokeColor = colors.black, strokeWidth=0) + g.add(box) + + for stripecounter in range (13,0, -1): + stripeheight = s/13.0 + if not (stripecounter%2 == 0): + stripecolor = colors.red + else: + stripecolor = colors.mintcream + redorwhiteline = Rect(0, (s-(stripeheight*stripecounter)), width=s*2, height=stripeheight, + fillColor = stripecolor, strokeColor = None, strokeWidth=20) + g.add(redorwhiteline) + + bluebox = Rect(0, (s-(stripeheight*7)), width=0.8*s, height=stripeheight*7, + fillColor = colors.darkblue, strokeColor = None, strokeWidth=0) + g.add(bluebox) + + lss = s*0.045 + lss2 = lss/2.0 + s9 = s/9.0 + s7 = s/7.0 + for starxcounter in range(5): + for starycounter in range(4): + ls = Star() + ls.size = lss + ls.x = 0-s/22.0+lss/2.0+s7+starxcounter*s7 + ls.fillColor = colors.mintcream + ls.y = s-(starycounter+1)*s9+lss2 + g.add(ls) + + for starxcounter in range(6): + for starycounter in range(5): + ls = Star() + ls.size = lss + ls.x = 0-(s/22.0)+lss/2.0+s/14.0+starxcounter*s7 + ls.fillColor = colors.mintcream + ls.y = s-(starycounter+1)*s9+(s/18.0)+lss2 + g.add(ls) + return g + + def _Flag_Afghanistan(self): + s = _size + g = Group() + + box = Rect(0, 0, s*2, s, + fillColor = colors.mintcream, strokeColor = colors.black, strokeWidth=0) + g.add(box) + + 
greenbox = Rect(0, ((s/3.0)*2.0), width=s*2.0, height=s/3.0, + fillColor = colors.limegreen, strokeColor = None, strokeWidth=0) + g.add(greenbox) + + blackbox = Rect(0, 0, width=s*2.0, height=s/3.0, + fillColor = colors.black, strokeColor = None, strokeWidth=0) + g.add(blackbox) + return g + + def _Flag_Austria(self): + s = _size # abbreviate as we will use this a lot + g = Group() + + box = Rect(0, 0, s*2, s, fillColor = colors.mintcream, + strokeColor = colors.black, strokeWidth=0) + g.add(box) + + + redbox1 = Rect(0, 0, width=s*2.0, height=s/3.0, + fillColor = colors.red, strokeColor = None, strokeWidth=0) + g.add(redbox1) + + redbox2 = Rect(0, ((s/3.0)*2.0), width=s*2.0, height=s/3.0, + fillColor = colors.red, strokeColor = None, strokeWidth=0) + g.add(redbox2) + return g + + def _Flag_Belgium(self): + s = _size + g = Group() + + box = Rect(0, 0, s*2, s, + fillColor = colors.black, strokeColor = colors.black, strokeWidth=0) + g.add(box) + + + box1 = Rect(0, 0, width=(s/3.0)*2.0, height=s, + fillColor = colors.black, strokeColor = None, strokeWidth=0) + g.add(box1) + + box2 = Rect(((s/3.0)*2.0), 0, width=(s/3.0)*2.0, height=s, + fillColor = colors.gold, strokeColor = None, strokeWidth=0) + g.add(box2) + + box3 = Rect(((s/3.0)*4.0), 0, width=(s/3.0)*2.0, height=s, + fillColor = colors.red, strokeColor = None, strokeWidth=0) + g.add(box3) + return g + + def _Flag_China(self): + s = _size + g = Group() + self._width = w = s*1.5 + g.add(Rect(0, 0, w, s, fillColor=colors.red, strokeColor=None, strokeWidth=0)) + + def addStar(x,y,size,angle,g=g,w=s/20.0,x0=0,y0=s/2.0): + s = Star() + s.fillColor=colors.yellow + s.angle = angle + s.size = size*w*2 + s.x = x*w+x0 + s.y = y*w+y0 + g.add(s) + + addStar(5,5,3, 0) + addStar(10,1,1,36.86989765) + addStar(12,3,1,8.213210702) + addStar(12,6,1,16.60154960) + addStar(10,8,1,53.13010235) + return g + + def _Flag_Cuba(self): + s = _size + g = Group() + + for i in range(5): + stripe = Rect(0, i*s/5.0, width=s*2, height=s/5.0, + 
fillColor = [colors.darkblue, colors.mintcream][i%2], + strokeColor = None, + strokeWidth=0) + g.add(stripe) + + redwedge = Polygon(points = [ 0, 0, 4*s/5.0, (s/2.0), 0, s], + fillColor = colors.red, strokeColor = None, strokeWidth=0) + g.add(redwedge) + + star = Star() + star.x = 2.5*s/10.0 + star.y = s/2.0 + star.size = 3*s/10.0 + star.fillColor = colors.white + g.add(star) + + box = Rect(0, 0, s*2, s, + fillColor = None, + strokeColor = colors.black, + strokeWidth=0) + g.add(box) + + return g + + def _Flag_Denmark(self): + s = _size + g = Group() + self._width = w = s*1.4 + + box = Rect(0, 0, w, s, + fillColor = colors.red, strokeColor = colors.black, strokeWidth=0) + g.add(box) + + whitebox1 = Rect(((s/5.0)*2), 0, width=s/6.0, height=s, + fillColor = colors.mintcream, strokeColor = None, strokeWidth=0) + g.add(whitebox1) + + whitebox2 = Rect(0, ((s/2.0)-(s/12.0)), width=w, height=s/6.0, + fillColor = colors.mintcream, strokeColor = None, strokeWidth=0) + g.add(whitebox2) + return g + + def _Flag_Finland(self): + s = _size + g = Group() + + # crossbox specific bits + box = Rect(0, 0, s*2, s, + fillColor = colors.ghostwhite, strokeColor = colors.black, strokeWidth=0) + g.add(box) + + blueline1 = Rect((s*0.6), 0, width=0.3*s, height=s, + fillColor = colors.darkblue, strokeColor = None, strokeWidth=0) + g.add(blueline1) + + blueline2 = Rect(0, (s*0.4), width=s*2, height=s*0.3, + fillColor = colors.darkblue, strokeColor = None, strokeWidth=0) + g.add(blueline2) + return g + + def _Flag_France(self): + s = _size + g = Group() + + box = Rect(0, 0, s*2, s, fillColor = colors.navy, strokeColor = colors.black, strokeWidth=0) + g.add(box) + + bluebox = Rect(0, 0, width=((s/3.0)*2.0), height=s, + fillColor = colors.blue, strokeColor = None, strokeWidth=0) + g.add(bluebox) + + whitebox = Rect(((s/3.0)*2.0), 0, width=((s/3.0)*2.0), height=s, + fillColor = colors.mintcream, strokeColor = None, strokeWidth=0) + g.add(whitebox) + + redbox = Rect(((s/3.0)*4.0), 0, 
width=((s/3.0)*2.0), height=s, + fillColor = colors.red, + strokeColor = None, + strokeWidth=0) + g.add(redbox) + return g + + def _Flag_Germany(self): + s = _size + g = Group() + + box = Rect(0, 0, s*2, s, + fillColor = colors.gold, strokeColor = colors.black, strokeWidth=0) + g.add(box) + + blackbox1 = Rect(0, ((s/3.0)*2.0), width=s*2.0, height=s/3.0, + fillColor = colors.black, strokeColor = None, strokeWidth=0) + g.add(blackbox1) + + redbox1 = Rect(0, (s/3.0), width=s*2.0, height=s/3.0, + fillColor = colors.orangered, strokeColor = None, strokeWidth=0) + g.add(redbox1) + return g + + def _Flag_Greece(self): + s = _size + g = Group() + + box = Rect(0, 0, s*2, s, fillColor = colors.gold, + strokeColor = colors.black, strokeWidth=0) + g.add(box) + + for stripecounter in range (9,0, -1): + stripeheight = s/9.0 + if not (stripecounter%2 == 0): + stripecolor = colors.deepskyblue + else: + stripecolor = colors.mintcream + + blueorwhiteline = Rect(0, (s-(stripeheight*stripecounter)), width=s*2, height=stripeheight, + fillColor = stripecolor, strokeColor = None, strokeWidth=20) + g.add(blueorwhiteline) + + bluebox1 = Rect(0, ((s)-stripeheight*5), width=(stripeheight*5), height=stripeheight*5, + fillColor = colors.deepskyblue, strokeColor = None, strokeWidth=0) + g.add(bluebox1) + + whiteline1 = Rect(0, ((s)-stripeheight*3), width=stripeheight*5, height=stripeheight, + fillColor = colors.mintcream, strokeColor = None, strokeWidth=0) + g.add(whiteline1) + + whiteline2 = Rect((stripeheight*2), ((s)-stripeheight*5), width=stripeheight, height=stripeheight*5, + fillColor = colors.mintcream, strokeColor = None, strokeWidth=0) + g.add(whiteline2) + + return g + + def _Flag_Ireland(self): + s = _size + g = Group() + + box = Rect(0, 0, s*2, s, + fillColor = colors.forestgreen, strokeColor = colors.black, strokeWidth=0) + g.add(box) + + whitebox = Rect(((s*2.0)/3.0), 0, width=(2.0*(s*2.0)/3.0), height=s, + fillColor = colors.mintcream, strokeColor = None, strokeWidth=0) + 
g.add(whitebox) + + orangebox = Rect(((2.0*(s*2.0)/3.0)), 0, width=(s*2.0)/3.0, height=s, + fillColor = colors.darkorange, strokeColor = None, strokeWidth=0) + g.add(orangebox) + return g + + def _Flag_Italy(self): + s = _size + g = Group() + g.add(Rect(0,0,s*2,s,fillColor=colors.forestgreen,strokeColor=None, strokeWidth=0)) + g.add(Rect((2*s)/3.0, 0, width=(s*4)/3.0, height=s, fillColor = colors.mintcream, strokeColor = None, strokeWidth=0)) + g.add(Rect((4*s)/3.0, 0, width=(s*2)/3.0, height=s, fillColor = colors.red, strokeColor = None, strokeWidth=0)) + return g + + def _Flag_Japan(self): + s = _size + g = Group() + w = self._width = s*1.5 + g.add(Rect(0,0,w,s,fillColor=colors.mintcream,strokeColor=None, strokeWidth=0)) + g.add(Circle(cx=w/2.0,cy=s/2.0,r=0.3*w,fillColor=colors.red,strokeColor=None, strokeWidth=0)) + return g + + def _Flag_Luxembourg(self): + s = _size + g = Group() + + box = Rect(0, 0, s*2, s, + fillColor = colors.mintcream, strokeColor = colors.black, strokeWidth=0) + g.add(box) + + redbox = Rect(0, ((s/3.0)*2.0), width=s*2.0, height=s/3.0, + fillColor = colors.red, strokeColor = None, strokeWidth=0) + g.add(redbox) + + bluebox = Rect(0, 0, width=s*2.0, height=s/3.0, + fillColor = colors.dodgerblue, strokeColor = None, strokeWidth=0) + g.add(bluebox) + return g + + def _Flag_Holland(self): + s = _size + g = Group() + + box = Rect(0, 0, s*2, s, + fillColor = colors.mintcream, strokeColor = colors.black, strokeWidth=0) + g.add(box) + + redbox = Rect(0, ((s/3.0)*2.0), width=s*2.0, height=s/3.0, + fillColor = colors.red, strokeColor = None, strokeWidth=0) + g.add(redbox) + + bluebox = Rect(0, 0, width=s*2.0, height=s/3.0, + fillColor = colors.darkblue, strokeColor = None, strokeWidth=0) + g.add(bluebox) + return g + + def _Flag_Portugal(self): + return Group() + + def _Flag_Russia(self): + s = _size + g = Group() + w = self._width = s*1.5 + t = s/3.0 + g.add(Rect(0, 0, width=w, height=t, fillColor = colors.red, strokeColor = None, strokeWidth=0)) + 
g.add(Rect(0, t, width=w, height=t, fillColor = colors.blue, strokeColor = None, strokeWidth=0)) + g.add(Rect(0, 2*t, width=w, height=t, fillColor = colors.mintcream, strokeColor = None, strokeWidth=0)) + return g + + def _Flag_Spain(self): + s = _size + g = Group() + w = self._width = s*1.5 + g.add(Rect(0, 0, width=w, height=s, fillColor = colors.red, strokeColor = None, strokeWidth=0)) + g.add(Rect(0, (s/4.0), width=w, height=s/2.0, fillColor = colors.yellow, strokeColor = None, strokeWidth=0)) + return g + + def _Flag_Sweden(self): + s = _size + g = Group() + self._width = s*1.4 + box = Rect(0, 0, self._width, s, + fillColor = colors.dodgerblue, strokeColor = colors.black, strokeWidth=0) + g.add(box) + + box1 = Rect(((s/5.0)*2), 0, width=s/6.0, height=s, + fillColor = colors.gold, strokeColor = None, strokeWidth=0) + g.add(box1) + + box2 = Rect(0, ((s/2.0)-(s/12.0)), width=self._width, height=s/6.0, + fillColor = colors.gold, + strokeColor = None, + strokeWidth=0) + g.add(box2) + return g + + def _Flag_Norway(self): + s = _size + g = Group() + self._width = s*1.4 + + box = Rect(0, 0, self._width, s, + fillColor = colors.red, strokeColor = colors.black, strokeWidth=0) + g.add(box) + + box = Rect(0, 0, self._width, s, + fillColor = colors.red, strokeColor = colors.black, strokeWidth=0) + g.add(box) + + whiteline1 = Rect(((s*0.2)*2), 0, width=s*0.2, height=s, + fillColor = colors.ghostwhite, strokeColor = None, strokeWidth=0) + g.add(whiteline1) + + whiteline2 = Rect(0, (s*0.4), width=self._width, height=s*0.2, + fillColor = colors.ghostwhite, strokeColor = None, strokeWidth=0) + g.add(whiteline2) + + blueline1 = Rect(((s*0.225)*2), 0, width=0.1*s, height=s, + fillColor = colors.darkblue, strokeColor = None, strokeWidth=0) + g.add(blueline1) + + blueline2 = Rect(0, (s*0.45), width=self._width, height=s*0.1, + fillColor = colors.darkblue, strokeColor = None, strokeWidth=0) + g.add(blueline2) + return g + + def _Flag_CzechRepublic(self): + s = _size + g = Group() + 
box = Rect(0, 0, s*2, s, + fillColor = colors.mintcream, + strokeColor = colors.black, + strokeWidth=0) + g.add(box) + + redbox = Rect(0, 0, width=s*2, height=s/2.0, + fillColor = colors.red, + strokeColor = None, + strokeWidth=0) + g.add(redbox) + + bluewedge = Polygon(points = [ 0, 0, s, (s/2.0), 0, s], + fillColor = colors.darkblue, strokeColor = None, strokeWidth=0) + g.add(bluewedge) + return g + + def _Flag_Palestine(self): + s = _size + g = Group() + box = Rect(0, s/3.0, s*2, s/3.0, + fillColor = colors.mintcream, + strokeColor = None, + strokeWidth=0) + g.add(box) + + greenbox = Rect(0, 0, width=s*2, height=s/3.0, + fillColor = colors.limegreen, + strokeColor = None, + strokeWidth=0) + g.add(greenbox) + + blackbox = Rect(0, 2*s/3.0, width=s*2, height=s/3.0, + fillColor = colors.black, + strokeColor = None, + strokeWidth=0) + g.add(blackbox) + + redwedge = Polygon(points = [ 0, 0, 2*s/3.0, (s/2.0), 0, s], + fillColor = colors.red, strokeColor = None, strokeWidth=0) + g.add(redwedge) + return g + + def _Flag_Turkey(self): + s = _size + g = Group() + + box = Rect(0, 0, s*2, s, + fillColor = colors.red, + strokeColor = colors.black, + strokeWidth=0) + g.add(box) + + whitecircle = Circle(cx=((s*0.35)*2), cy=s/2.0, r=s*0.3, + fillColor = colors.mintcream, + strokeColor = None, + strokeWidth=0) + g.add(whitecircle) + + redcircle = Circle(cx=((s*0.39)*2), cy=s/2.0, r=s*0.24, + fillColor = colors.red, + strokeColor = None, + strokeWidth=0) + g.add(redcircle) + + ws = Star() + ws.angle = 15 + ws.size = s/5.0 + ws.x = (s*0.5)*2+ws.size/2.0 + ws.y = (s*0.5) + ws.fillColor = colors.mintcream + ws.strokeColor = None + g.add(ws) + return g + + def _Flag_Switzerland(self): + s = _size + g = Group() + self._width = s + + g.add(Rect(0, 0, s, s, fillColor = colors.red, strokeColor = colors.black, strokeWidth=0)) + g.add(Line((s/2.0), (s/5.5), (s/2), (s-(s/5.5)), + fillColor = colors.mintcream, strokeColor = colors.mintcream, strokeWidth=(s/5.0))) + g.add(Line((s/5.5), 
(s/2.0), (s-(s/5.5)), (s/2.0), + fillColor = colors.mintcream, strokeColor = colors.mintcream, strokeWidth=s/5.0)) + return g + + def _Flag_EU(self): + s = _size + g = Group() + w = self._width = 1.5*s + + g.add(Rect(0, 0, w, s, fillColor = colors.darkblue, strokeColor = None, strokeWidth=0)) + centerx=w/2.0 + centery=s/2.0 + radius=s/3.0 + yradius = radius + xradius = radius + nStars = 12 + delta = 2*pi/nStars + for i in range(nStars): + rad = i*delta + gs = Star() + gs.x=cos(rad)*radius+centerx + gs.y=sin(rad)*radius+centery + gs.size=s/10.0 + gs.fillColor=colors.gold + g.add(gs) + return g + + def _Flag_Brazil(self): + s = _size # abbreviate as we will use this a lot + g = Group() + + m = s/14.0 + self._width = w = (m * 20) + + def addStar(x,y,size, g=g, w=w, s=s, m=m): + st = Star() + st.fillColor=colors.mintcream + st.size = size*m + st.x = (w/2.0) + (x * (0.35 * m)) + st.y = (s/2.0) + (y * (0.35 * m)) + g.add(st) + + g.add(Rect(0, 0, w, s, fillColor = colors.green, strokeColor = None, strokeWidth=0)) + g.add(Polygon(points = [ 1.7*m, (s/2.0), (w/2.0), s-(1.7*m), w-(1.7*m),(s/2.0),(w/2.0), 1.7*m], + fillColor = colors.yellow, strokeColor = None, strokeWidth=0)) + g.add(Circle(cx=w/2.0, cy=s/2.0, r=3.5*m, + fillColor=colors.blue,strokeColor=None, strokeWidth=0)) + g.add(Wedge((w/2.0)-(2*m), 0, 8.5*m, 50, 98.1, 8.5*m, + fillColor=colors.mintcream,strokeColor=None, strokeWidth=0)) + g.add(Wedge((w/2.0), (s/2.0), 3.501*m, 156, 352, 3.501*m, + fillColor=colors.mintcream,strokeColor=None, strokeWidth=0)) + g.add(Wedge((w/2.0)-(2*m), 0, 8*m, 48.1, 100, 8*m, + fillColor=colors.blue,strokeColor=None, strokeWidth=0)) + g.add(Rect(0, 0, w, (s/4.0) + 1.7*m, + fillColor = colors.green, strokeColor = None, strokeWidth=0)) + g.add(Polygon(points = [ 1.7*m,(s/2.0), (w/2.0),s/2.0 - 2*m, w-(1.7*m),(s/2.0) , (w/2.0),1.7*m], + fillColor = colors.yellow, strokeColor = None, strokeWidth=0)) + g.add(Wedge(w/2.0, s/2.0, 3.502*m, 166, 342.1, 3.502*m, + 
fillColor=colors.blue,strokeColor=None, strokeWidth=0)) + + addStar(3.2,3.5,0.3) + addStar(-8.5,1.5,0.3) + addStar(-7.5,-3,0.3) + addStar(-4,-5.5,0.3) + addStar(0,-4.5,0.3) + addStar(7,-3.5,0.3) + addStar(-3.5,-0.5,0.25) + addStar(0,-1.5,0.25) + addStar(1,-2.5,0.25) + addStar(3,-7,0.25) + addStar(5,-6.5,0.25) + addStar(6.5,-5,0.25) + addStar(7,-4.5,0.25) + addStar(-5.5,-3.2,0.25) + addStar(-6,-4.2,0.25) + addStar(-1,-2.75,0.2) + addStar(2,-5.5,0.2) + addStar(4,-5.5,0.2) + addStar(5,-7.5,0.2) + addStar(5,-5.5,0.2) + addStar(6,-5.5,0.2) + addStar(-8.8,-3.2,0.2) + addStar(2.5,0.5,0.2) + addStar(-0.2,-3.2,0.14) + addStar(-7.2,-2,0.14) + addStar(0,-8,0.1) + + sTmp = "ORDEM E PROGRESSO" + nTmp = len(sTmp) + delta = 0.850848010347/nTmp + radius = 7.9 *m + centerx = (w/2.0)-(2*m) + centery = 0 + for i in range(nTmp): + rad = 2*pi - i*delta -4.60766922527 + x=cos(rad)*radius+centerx + y=sin(rad)*radius+centery + if i == 6: + z = 0.35*m + else: + z= 0.45*m + g2 = Group(String(x, y, sTmp[i], fontName='Helvetica-Bold', + fontSize = z,strokeColor=None,fillColor=colors.green)) + g2.rotate(rad) + g.add(g2) + return g + +def makeFlag(name): + flag = Flag() + flag.kind = name + return flag + +def test(): + """This function produces three pdf files with examples of all the signs and symbols from this file. 
+ """ +# page 1 + + labelFontSize = 10 + + X = (20,245) + + flags = [ + 'UK', + 'USA', + 'Afghanistan', + 'Austria', + 'Belgium', + 'Denmark', + 'Cuba', + 'Finland', + 'France', + 'Germany', + 'Greece', + 'Ireland', + 'Italy', + 'Luxembourg', + 'Holland', + 'Palestine', + 'Portugal', + 'Spain', + 'Sweden', + 'Norway', + 'CzechRepublic', + 'Turkey', + 'Switzerland', + 'EU', + 'Brazil', + ] + y = Y0 = 530 + f = 0 + D = None + for name in flags: + if not D: D = Drawing(450,650) + flag = makeFlag(name) + i = flags.index(name) + flag.x = X[i%2] + flag.y = y + D.add(flag) + D.add(String(flag.x+(flag.size/2.0),(flag.y-(1.2*labelFontSize)), + name, fillColor=colors.black, textAnchor='middle', fontSize=labelFontSize)) + if i%2: y = y - 125 + if (i%2 and y<0) or name==flags[-1]: + renderPDF.drawToFile(D, 'flags%02d.pdf'%f, 'flags.py - Page #%d'%(f+1)) + y = Y0 + f = f+1 + D = None + +if __name__=='__main__': + test() diff --git a/.venv/lib/python3.13/site-packages/reportlab/graphics/widgets/grids.py b/.venv/lib/python3.13/site-packages/reportlab/graphics/widgets/grids.py new file mode 100644 index 0000000000000000000000000000000000000000..a1bfa6336f00667bc5c66d4b6406ef8a240b45b3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/reportlab/graphics/widgets/grids.py @@ -0,0 +1,542 @@ +#Copyright ReportLab Europe Ltd. 
2000-2017 +#see license.txt for license details +#history https://hg.reportlab.com/hg-public/reportlab/log/tip/src/reportlab/graphics/widgets/grids.py +__version__='3.3.0' + +from reportlab.lib import colors +from reportlab.lib.validators import isNumber, isColorOrNone, isBoolean, isListOfNumbers, OneOf, isListOfColors, isNumberOrNone +from reportlab.lib.attrmap import AttrMap, AttrMapValue +from reportlab.graphics.shapes import Drawing, Group, Line, Rect, LineShape, definePath, EmptyClipPath +from reportlab.graphics.widgetbase import Widget +from math import radians +from reportlab.graphics.transform import translate, rotate, mmult, transformPoints, inverse +from reportlab.lib.utils import flatten + +def frange(start, end=None, inc=None): + "A range function, that does accept float increments..." + + if end == None: + end = start + 0.0 + start = 0.0 + + if inc == None: + inc = 1.0 + + L = [] + end = end - inc*0.0001 #to avoid numrical problems + while 1: + next = start + len(L) * inc + if inc > 0 and next >= end: + break + elif inc < 0 and next <= end: + break + L.append(next) + + return L + + +def makeDistancesList(list): + """Returns a list of distances between adjacent numbers in some input list. + + E.g. [1, 1, 2, 3, 5, 7] -> [0, 1, 1, 2, 2] + """ + + d = [] + for i in range(len(list[:-1])): + d.append(list[i+1] - list[i]) + + return d + + +class Grid(Widget): + """This makes a rectangular grid of equidistant stripes. + + The grid contains an outer border rectangle, and stripes + inside which can be drawn with lines and/or as solid tiles. + The drawing order is: outer rectangle, then lines and tiles. + + The stripes' width is indicated as 'delta'. The sequence of + stripes can have an offset named 'delta0'. Both values need + to be positive! 
+ """ + + _attrMap = AttrMap( + x = AttrMapValue(isNumber, desc="The grid's lower-left x position."), + y = AttrMapValue(isNumber, desc="The grid's lower-left y position."), + width = AttrMapValue(isNumber, desc="The grid's width."), + height = AttrMapValue(isNumber, desc="The grid's height."), + orientation = AttrMapValue(OneOf(('vertical', 'horizontal')), + desc='Determines if stripes are vertical or horizontal.'), + useLines = AttrMapValue(OneOf((0, 1)), + desc='Determines if stripes are drawn with lines.'), + useRects = AttrMapValue(OneOf((0, 1)), + desc='Determines if stripes are drawn with solid rectangles.'), + delta = AttrMapValue(isNumber, + desc='Determines the width/height of the stripes.'), + delta0 = AttrMapValue(isNumber, + desc='Determines the stripes initial width/height offset.'), + deltaSteps = AttrMapValue(isListOfNumbers, + desc='List of deltas to be used cyclically.'), + stripeColors = AttrMapValue(isListOfColors, + desc='Colors applied cyclically in the right or upper direction.'), + fillColor = AttrMapValue(isColorOrNone, + desc='Background color for entire rectangle.'), + strokeColor = AttrMapValue(isColorOrNone, + desc='Color used for lines.'), + strokeWidth = AttrMapValue(isNumber, + desc='Width used for lines.'), + rectStrokeColor = AttrMapValue(isColorOrNone, desc='Color for outer rect stroke.'), + rectStrokeWidth = AttrMapValue(isNumberOrNone, desc='Width for outer rect stroke.'), + ) + + def __init__(self): + self.x = 0 + self.y = 0 + self.width = 100 + self.height = 100 + self.orientation = 'vertical' + self.useLines = 0 + self.useRects = 1 + self.delta = 20 + self.delta0 = 0 + self.deltaSteps = [] + self.fillColor = colors.white + self.stripeColors = [colors.red, colors.green, colors.blue] + self.strokeColor = colors.black + self.strokeWidth = 2 + + + def demo(self): + D = Drawing(100, 100) + + g = Grid() + D.add(g) + + return D + + def makeOuterRect(self): + strokeColor = getattr(self,'rectStrokeColor',self.strokeColor) + 
strokeWidth = getattr(self,'rectStrokeWidth',self.strokeWidth) + if self.fillColor or (strokeColor and strokeWidth): + rect = Rect(self.x, self.y, self.width, self.height) + rect.fillColor = self.fillColor + rect.strokeColor = strokeColor + rect.strokeWidth = strokeWidth + return rect + else: + return None + + def makeLinePosList(self, start, isX=0): + "Returns a list of positions where to place lines." + + w, h = self.width, self.height + if isX: + length = w + else: + length = h + if self.deltaSteps: + r = [start + self.delta0] + i = 0 + while 1: + if r[-1] > start + length: + del r[-1] + break + r.append(r[-1] + self.deltaSteps[i % len(self.deltaSteps)]) + i = i + 1 + else: + r = frange(start + self.delta0, start + length, self.delta) + + r.append(start + length) + if self.delta0 != 0: + r.insert(0, start) + #print 'Grid.makeLinePosList() -> %s' % r + return r + + + def makeInnerLines(self): + # inner grid lines + group = Group() + + w, h = self.width, self.height + + if self.useLines == 1: + if self.orientation == 'vertical': + r = self.makeLinePosList(self.x, isX=1) + for x in r: + line = Line(x, self.y, x, self.y + h) + line.strokeColor = self.strokeColor + line.strokeWidth = self.strokeWidth + group.add(line) + elif self.orientation == 'horizontal': + r = self.makeLinePosList(self.y, isX=0) + for y in r: + line = Line(self.x, y, self.x + w, y) + line.strokeColor = self.strokeColor + line.strokeWidth = self.strokeWidth + group.add(line) + + return group + + + def makeInnerTiles(self): + # inner grid lines + group = Group() + + w, h = self.width, self.height + + # inner grid stripes (solid rectangles) + if self.useRects == 1: + cols = self.stripeColors + + if self.orientation == 'vertical': + r = self.makeLinePosList(self.x, isX=1) + elif self.orientation == 'horizontal': + r = self.makeLinePosList(self.y, isX=0) + + dist = makeDistancesList(r) + + i = 0 + for j in range(len(dist)): + if self.orientation == 'vertical': + x = r[j] + stripe = Rect(x, self.y, 
dist[j], h) + elif self.orientation == 'horizontal': + y = r[j] + stripe = Rect(self.x, y, w, dist[j]) + stripe.fillColor = cols[i % len(cols)] + stripe.strokeColor = None + group.add(stripe) + i = i + 1 + + return group + + + def draw(self): + # general widget bits + group = Group() + + group.add(self.makeOuterRect()) + group.add(self.makeInnerTiles()) + group.add(self.makeInnerLines(),name='_gridLines') + + return group + + +class DoubleGrid(Widget): + """This combines two ordinary Grid objects orthogonal to each other. + """ + + _attrMap = AttrMap( + x = AttrMapValue(isNumber, desc="The grid's lower-left x position."), + y = AttrMapValue(isNumber, desc="The grid's lower-left y position."), + width = AttrMapValue(isNumber, desc="The grid's width."), + height = AttrMapValue(isNumber, desc="The grid's height."), + grid0 = AttrMapValue(None, desc="The first grid component."), + grid1 = AttrMapValue(None, desc="The second grid component."), + ) + + def __init__(self): + self.x = 0 + self.y = 0 + self.width = 100 + self.height = 100 + + g0 = Grid() + g0.x = self.x + g0.y = self.y + g0.width = self.width + g0.height = self.height + g0.orientation = 'vertical' + g0.useLines = 1 + g0.useRects = 0 + g0.delta = 20 + g0.delta0 = 0 + g0.deltaSteps = [] + g0.fillColor = colors.white + g0.stripeColors = [colors.red, colors.green, colors.blue] + g0.strokeColor = colors.black + g0.strokeWidth = 1 + + g1 = Grid() + g1.x = self.x + g1.y = self.y + g1.width = self.width + g1.height = self.height + g1.orientation = 'horizontal' + g1.useLines = 1 + g1.useRects = 0 + g1.delta = 20 + g1.delta0 = 0 + g1.deltaSteps = [] + g1.fillColor = colors.white + g1.stripeColors = [colors.red, colors.green, colors.blue] + g1.strokeColor = colors.black + g1.strokeWidth = 1 + + self.grid0 = g0 + self.grid1 = g1 + + +## # This gives an AttributeError: +## # DoubleGrid instance has no attribute 'grid0' +## def __setattr__(self, name, value): +## if name in ('x', 'y', 'width', 'height'): +## 
setattr(self.grid0, name, value) +## setattr(self.grid1, name, value) + + + def demo(self): + D = Drawing(100, 100) + g = DoubleGrid() + D.add(g) + return D + + + def draw(self): + group = Group() + g0, g1 = self.grid0, self.grid1 + # Order groups to make sure both v and h lines + # are visible (works only when there is only + # one kind of stripes, v or h). + G = g0.useRects == 1 and g1.useRects == 0 and (g0,g1) or (g1,g0) + for g in G: + group.add(g.makeOuterRect()) + for g in G: + group.add(g.makeInnerTiles()) + group.add(g.makeInnerLines(),name='_gridLines') + + return group + + +class ShadedRect(Widget): + """This makes a rectangle with shaded colors between two colors. + + Colors are interpolated linearly between 'fillColorStart' + and 'fillColorEnd', both of which appear at the margins. + If 'numShades' is set to one, though, only 'fillColorStart' + is used. + """ + + _attrMap = AttrMap( + x = AttrMapValue(isNumber, desc="The grid's lower-left x position."), + y = AttrMapValue(isNumber, desc="The grid's lower-left y position."), + width = AttrMapValue(isNumber, desc="The grid's width."), + height = AttrMapValue(isNumber, desc="The grid's height."), + orientation = AttrMapValue(OneOf(('vertical', 'horizontal')), desc='Determines if stripes are vertical or horizontal.'), + numShades = AttrMapValue(isNumber, desc='The number of interpolating colors.'), + fillColorStart = AttrMapValue(isColorOrNone, desc='Start value of the color shade.'), + fillColorEnd = AttrMapValue(isColorOrNone, desc='End value of the color shade.'), + strokeColor = AttrMapValue(isColorOrNone, desc='Color used for border line.'), + strokeWidth = AttrMapValue(isNumber, desc='Width used for lines.'), + cylinderMode = AttrMapValue(isBoolean, desc='True if shading reverses in middle.'), + ) + + def __init__(self,**kw): + self.x = 0 + self.y = 0 + self.width = 100 + self.height = 100 + self.orientation = 'vertical' + self.numShades = 20 + self.fillColorStart = colors.pink + self.fillColorEnd = 
colors.black + self.strokeColor = colors.black + self.strokeWidth = 2 + self.cylinderMode = 0 + self.setProperties(kw) + + def demo(self): + D = Drawing(100, 100) + g = ShadedRect() + D.add(g) + + return D + + def _flipRectCorners(self): + "Flip rectangle's corners if width or height is negative." + x, y, width, height, fillColorStart, fillColorEnd = self.x, self.y, self.width, self.height, self.fillColorStart, self.fillColorEnd + if width < 0 and height > 0: + x = x + width + width = -width + if self.orientation=='vertical': fillColorStart, fillColorEnd = fillColorEnd, fillColorStart + elif height<0 and width>0: + y = y + height + height = -height + if self.orientation=='horizontal': fillColorStart, fillColorEnd = fillColorEnd, fillColorStart + elif height < 0 and height < 0: + x = x + width + width = -width + y = y + height + height = -height + return x, y, width, height, fillColorStart, fillColorEnd + + def draw(self): + # general widget bits + group = Group() + x, y, w, h, c0, c1 = self._flipRectCorners() + vertical = self.orientation == 'vertical' + cylinderMode = self.cylinderMode + linG = getattr(getattr(self,'_canvas',None),'linearGradient',None) + if linG: + canv = linG.__self__ + canv.saveState() + p = canv.beginPath() + p.rect(x, y, w, h) + canv.clipPath(p, stroke=0) + if cylinderMode: + if vertical: + linG(x, y, x+w/2, y, (c0,c1), extend=False) + linG(x+w/2, y, x+w, y, (c1,c0), extend=False) + else: + linG(x, y, x, y+h/2, (c0,c1), extend=False) + linG(x, y+h/2, x, y+h, (c1,c0), extend=False) + else: + if vertical: + linG(x, y, x+w, y, (c0,c1), extend=False) + else: + linG(x, y, x, y+h, (c0,c1), extend=False) + canv.restoreState() + else: + numShades = self.numShades + if cylinderMode: + if not numShades%2: numShades = numShades+1 + halfNumShades = int((numShades-1)/2) + 1 + num = float(numShades) # must make it float! 
+ if vertical: + if numShades == 1: + V = [x] + else: + V = frange(x, x + w, w/num) + else: + if numShades == 1: + V = [y] + else: + V = frange(y, y + h, h/num) + + for v in V: + stripe = vertical and Rect(v, y, w/num, h) or Rect(x, v, w, h/num) + if cylinderMode: + if V.index(v)>=halfNumShades: + col = colors.linearlyInterpolatedColor(c1,c0,V[halfNumShades],V[-1], v) + else: + col = colors.linearlyInterpolatedColor(c0,c1,V[0],V[halfNumShades], v) + else: + col = colors.linearlyInterpolatedColor(c0,c1,V[0],V[-1], v) + stripe.fillColor = col + stripe.strokeColor = col + stripe.strokeWidth = 1 + group.add(stripe) + if self.strokeColor and self.strokeWidth>=0: + rect = Rect(x, y, w, h) + rect.strokeColor = self.strokeColor + rect.strokeWidth = self.strokeWidth + rect.fillColor = None + group.add(rect) + return group + + +def colorRange(c0, c1, n): + "Return a range of intermediate colors between c0 and c1" + if n==1: return [c0] + + C = [] + if n>1: + lim = n-1 + for i in range(n): + C.append(colors.linearlyInterpolatedColor(c0,c1,0,lim, i)) + return C + + +def centroid(P): + '''compute average point of a set of points''' + cx = 0 + cy = 0 + for x,y in P: + cx+=x + cy+=y + n = len(P) + return cx/n, cy/n + +def rotatedEnclosingRect(P, angle, rect): + ''' + given P a sequence P of x,y coordinate pairs and an angle in degrees + find the centroid of P and the axis at angle theta through it + find the extreme points of P wrt axis parallel distance and axis + orthogonal distance. Then compute the least rectangle that will still + enclose P when rotated by angle. Positive angles correspond to clockwise + rotation of the enclosing rect. 
+ ''' + x0, y0 = centroid(P) + theta = radians(angle) + #translate to the centroid and rotate + mx = mmult(translate(x0,y0),rotate(angle)) + + #compute min and max of x and y of the rotated points + tp = flatten(transformPoints(mx,P)) + xx = tp[::2] + yx = tp[1::2] + xn = min(xx) + xx = max(xx) + yn = min(yx) + yx = max(yx) + + #make the enclosing rect and invert the original transform + rect.x = xn + rect.width = xx-xn + rect.y = yn + rect.height = yx-yn + g = Group(transform=inverse(mx)) + g.add(rect) + return g + +class ShadedPolygon(Widget,LineShape): + '''given a list of points [(x0,y0),....] we construct an enclosing + shaded rectangle and mask using the polygon points. + At angle 0 the shading fillColorStart left --> fillColorEnd right. + positive angles rotate the shading clockwise. + ''' + _attrMap = AttrMap(BASE=LineShape, + angle = AttrMapValue(isNumber,desc="Shading angle"), + fillColorStart = AttrMapValue(isColorOrNone), + fillColorEnd = AttrMapValue(isColorOrNone), + numShades = AttrMapValue(isNumber, desc='The number of interpolating colors.'), + cylinderMode = AttrMapValue(isBoolean, desc='True if shading reverses in middle.'), + points = AttrMapValue(isListOfNumbers), + ) + + def __init__(self,**kw): + self.angle = 90 + self.fillColorStart = colors.red + self.fillColorEnd = colors.green + self.cylinderMode = 0 + self.numShades = 50 + self.points = [-1,-1,2,2,3,-1] + LineShape.__init__(self,kw) + + def draw(self): + P = self.points + P = list(zip(P[::2],P[1::2])) + path = definePath([('moveTo',)+P[0]]+[('lineTo',)+x for x in P[1:]]+['closePath'], + fillColor=None, strokeColor=None) + path.isClipPath = 1 + g = Group() + g.add(path) + angle = self.angle % 360 + orientation = 'horizontal' if 0<=angle<=45 or 315<=angle<=360 or 135<=angle<=225 else 'vertical' + rect = ShadedRect(strokeWidth=0,strokeColor=None,orientation=orientation) + for k in 'fillColorStart', 'fillColorEnd', 'numShades', 'cylinderMode': + setattr(rect,k,getattr(self,k)) + 
g.add(rotatedEnclosingRect(P, angle, rect)) + g.add(EmptyClipPath) + path = path.copy() + path.isClipPath = 0 + path.strokeColor = self.strokeColor + path.strokeWidth = self.strokeWidth + g.add(path) + return g + +if __name__=='__main__': #noruntests + angle=45 + D = Drawing(120,120) + D.add(ShadedPolygon(points=(10,10,60,60,110,10),strokeColor=None,strokeWidth=1,angle=90,numShades=50,cylinderMode=0)) + D.save(formats=['pdf','gif'],fnRoot='shobj',outDir='/tmp') diff --git a/.venv/lib/python3.13/site-packages/reportlab/graphics/widgets/signsandsymbols.py b/.venv/lib/python3.13/site-packages/reportlab/graphics/widgets/signsandsymbols.py new file mode 100644 index 0000000000000000000000000000000000000000..8b75b5bb61b146edd1596a5da3303750fed0606d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/reportlab/graphics/widgets/signsandsymbols.py @@ -0,0 +1,977 @@ +#Copyright ReportLab Europe Ltd. 2000-2017 +#see license.txt for license details +#history https://hg.reportlab.com/hg-public/reportlab/log/tip/src/reportlab/graphics/widgets/signsandsymbols.py +# signsandsymbols.py +# A collection of new widgets +# author: John Precedo (johnp@reportlab.com) + +__version__='3.3.0' +__doc__="""This file is a collection of widgets to produce some common signs and symbols. 
+ +Widgets include: + +- ETriangle (an equilateral triangle), +- RTriangle (a right angled triangle), +- Octagon, +- Crossbox, +- Tickbox, +- SmileyFace, +- StopSign, +- NoEntry, +- NotAllowed (the red roundel from 'no smoking' signs), +- NoSmoking, +- DangerSign (a black exclamation point in a yellow triangle), +- YesNo (returns a tickbox or a crossbox depending on a testvalue), +- FloppyDisk, +- ArrowOne, and +- ArrowTwo +- CrossHair +""" + +from reportlab.lib import colors +from reportlab.lib.validators import * +from reportlab.lib.attrmap import * +from reportlab.lib.utils import isStr, asUnicode +from reportlab.graphics import shapes +from reportlab.graphics.widgetbase import Widget +from reportlab.graphics import renderPDF + + +class _Symbol(Widget): + """Abstract base widget + possible attributes: + 'x', 'y', 'size', 'fillColor', 'strokeColor' + """ + _nodoc = 1 + _attrMap = AttrMap( + x = AttrMapValue(isNumber,desc='symbol x coordinate'), + y = AttrMapValue(isNumber,desc='symbol y coordinate'), + dx = AttrMapValue(isNumber,desc='symbol x coordinate adjustment'), + dy = AttrMapValue(isNumber,desc='symbol x coordinate adjustment'), + size = AttrMapValue(isNumber), + fillColor = AttrMapValue(isColorOrNone), + strokeColor = AttrMapValue(isColorOrNone), + strokeWidth = AttrMapValue(isNumber), + ) + def __init__(self): + assert self.__class__.__name__!='_Symbol', 'Abstract class _Symbol instantiated' + self.x = self.y = self.dx = self.dy = 0 + self.size = 100 + self.fillColor = colors.red + self.strokeColor = None + self.strokeWidth = 0.1 + + def demo(self): + D = shapes.Drawing(200, 100) + s = float(self.size) + ob = self.__class__() + ob.x=50 + ob.y=0 + ob.draw() + D.add(ob) + D.add(shapes.String(ob.x+(s/2),(ob.y-12), + ob.__class__.__name__, fillColor=colors.black, textAnchor='middle', + fontSize=10)) + return D + +class ETriangle(_Symbol): + """This draws an equilateral triangle.""" + + def __init__(self): + _Symbol.__init__(self) + + def draw(self): + # 
general widget bits + s = float(self.size) # abbreviate as we will use this a lot + g = shapes.Group() + + # Triangle specific bits + ae = s*0.125 #(ae = 'an eighth') + triangle = shapes.Polygon(points = [ + self.x, self.y, + self.x+s, self.y, + self.x+(s/2),self.y+s], + fillColor = self.fillColor, + strokeColor = self.strokeColor, + strokeWidth=s/50.) + g.add(triangle) + return g + +class RTriangle(_Symbol): + """This draws a right-angled triangle. + + possible attributes: + 'x', 'y', 'size', 'fillColor', 'strokeColor' + + """ + + def __init__(self): + self.x = 0 + self.y = 0 + self.size = 100 + self.fillColor = colors.green + self.strokeColor = None + + def draw(self): + # general widget bits + s = float(self.size) # abbreviate as we will use this a lot + g = shapes.Group() + + # Triangle specific bits + ae = s*0.125 #(ae = 'an eighth') + triangle = shapes.Polygon(points = [ + self.x, self.y, + self.x+s, self.y, + self.x,self.y+s], + fillColor = self.fillColor, + strokeColor = self.strokeColor, + strokeWidth=s/50.) + g.add(triangle) + return g + +class Octagon(_Symbol): + """This widget draws an Octagon. + + possible attributes: + 'x', 'y', 'size', 'fillColor', 'strokeColor' + + """ + + def __init__(self): + self.x = 0 + self.y = 0 + self.size = 100 + self.fillColor = colors.yellow + self.strokeColor = None + + def draw(self): + # general widget bits + s = float(self.size) # abbreviate as we will use this a lot + g = shapes.Group() + + # Octagon specific bits + athird=s/3 + + octagon = shapes.Polygon(points=[self.x+athird, self.y, + self.x, self.y+athird, + self.x, self.y+(athird*2), + self.x+athird, self.y+s, + self.x+(athird*2), self.y+s, + self.x+s, self.y+(athird*2), + self.x+s, self.y+athird, + self.x+(athird*2), self.y], + strokeColor = self.strokeColor, + fillColor = self.fillColor, + strokeWidth=10) + g.add(octagon) + return g + +class Crossbox(_Symbol): + """This draws a black box with a red cross in it - a 'checkbox'. 
+ + possible attributes: + 'x', 'y', 'size', 'crossColor', 'strokeColor', 'crosswidth' + + """ + + _attrMap = AttrMap(BASE=_Symbol, + crossColor = AttrMapValue(isColorOrNone), + crosswidth = AttrMapValue(isNumber), + ) + + def __init__(self): + self.x = 0 + self.y = 0 + self.size = 100 + self.fillColor = colors.white + self.crossColor = colors.red + self.strokeColor = colors.black + self.crosswidth = 10 + + def draw(self): + # general widget bits + s = float(self.size) # abbreviate as we will use this a lot + g = shapes.Group() + + # crossbox specific bits + box = shapes.Rect(self.x+1, self.y+1, s-2, s-2, + fillColor = self.fillColor, + strokeColor = self.strokeColor, + strokeWidth=2) + g.add(box) + + crossLine1 = shapes.Line(self.x+(s*0.15), self.y+(s*0.15), self.x+(s*0.85), self.y+(s*0.85), + fillColor = self.crossColor, + strokeColor = self.crossColor, + strokeWidth = self.crosswidth) + g.add(crossLine1) + + crossLine2 = shapes.Line(self.x+(s*0.15), self.y+(s*0.85), self.x+(s*0.85) ,self.y+(s*0.15), + fillColor = self.crossColor, + strokeColor = self.crossColor, + strokeWidth = self.crosswidth) + g.add(crossLine2) + + return g + + +class Tickbox(_Symbol): + """This draws a black box with a red tick in it - another 'checkbox'. 
+ + possible attributes: + 'x', 'y', 'size', 'tickColor', 'strokeColor', 'tickwidth' + +""" + + _attrMap = AttrMap(BASE=_Symbol, + tickColor = AttrMapValue(isColorOrNone), + tickwidth = AttrMapValue(isNumber), + ) + + def __init__(self): + self.x = 0 + self.y = 0 + self.size = 100 + self.tickColor = colors.red + self.strokeColor = colors.black + self.fillColor = colors.white + self.tickwidth = 10 + + def draw(self): + # general widget bits + s = float(self.size) # abbreviate as we will use this a lot + g = shapes.Group() + + # tickbox specific bits + box = shapes.Rect(self.x+1, self.y+1, s-2, s-2, + fillColor = self.fillColor, + strokeColor = self.strokeColor, + strokeWidth=2) + g.add(box) + + tickLine = shapes.PolyLine(points = [self.x+(s*0.15), self.y+(s*0.35), self.x+(s*0.35), self.y+(s*0.15), + self.x+(s*0.35), self.y+(s*0.15), self.x+(s*0.85) ,self.y+(s*0.85)], + fillColor = self.tickColor, + strokeColor = self.tickColor, + strokeWidth = self.tickwidth) + g.add(tickLine) + + return g + +class SmileyFace(_Symbol): + """This draws a classic smiley face. + + possible attributes: + 'x', 'y', 'size', 'fillColor' + + """ + + def __init__(self): + _Symbol.__init__(self) + self.x = 0 + self.y = 0 + self.size = 100 + self.fillColor = colors.yellow + self.strokeColor = colors.black + + def draw(self): + # general widget bits + s = float(self.size) # abbreviate as we will use this a lot + g = shapes.Group() + + # SmileyFace specific bits + g.add(shapes.Circle(cx=self.x+(s/2), cy=self.y+(s/2), r=s/2, + fillColor=self.fillColor, strokeColor=self.strokeColor, + strokeWidth=max(s/38.,self.strokeWidth))) + + for i in (1,2): + g.add(shapes.Ellipse(self.x+(s/3)*i,self.y+(s/3)*2, s/30, s/10, + fillColor=self.strokeColor, strokeColor = self.strokeColor, + strokeWidth=max(s/38.,self.strokeWidth))) + + # calculate a pointslist for the mouth + # THIS IS A HACK! 
- don't use if there is a 'shapes.Arc' + centerx=self.x+(s/2) + centery=self.y+(s/2) + radius=s/3 + yradius = radius + xradius = radius + startangledegrees=200 + endangledegrees=340 + degreedelta = 1 + pointslist = [] + a = pointslist.append + from math import sin, cos, pi + degreestoradians = pi/180.0 + radiansdelta = degreedelta*degreestoradians + startangle = startangledegrees*degreestoradians + endangle = endangledegrees*degreestoradians + while endangle None: + # override connection classes to use emscripten specific classes + # n.b. mypy complains about the overriding of classes below + # if it isn't ignored + HTTPConnectionPool.ConnectionCls = EmscriptenHTTPConnection + HTTPSConnectionPool.ConnectionCls = EmscriptenHTTPSConnection + urllib3.connection.HTTPConnection = EmscriptenHTTPConnection # type: ignore[misc,assignment] + urllib3.connection.HTTPSConnection = EmscriptenHTTPSConnection # type: ignore[misc,assignment] + urllib3.connection.VerifiedHTTPSConnection = EmscriptenHTTPSConnection # type: ignore[assignment] diff --git a/.venv/lib/python3.13/site-packages/urllib3/contrib/emscripten/connection.py b/.venv/lib/python3.13/site-packages/urllib3/contrib/emscripten/connection.py new file mode 100644 index 0000000000000000000000000000000000000000..63f79dd3be803db09671c909f79316c3f65d6916 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/contrib/emscripten/connection.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +import os +import typing + +# use http.client.HTTPException for consistency with non-emscripten +from http.client import HTTPException as HTTPException # noqa: F401 +from http.client import ResponseNotReady + +from ..._base_connection import _TYPE_BODY +from ...connection import HTTPConnection, ProxyConfig, port_by_scheme +from ...exceptions import TimeoutError +from ...response import BaseHTTPResponse +from ...util.connection import _TYPE_SOCKET_OPTIONS +from ...util.timeout import _DEFAULT_TIMEOUT, _TYPE_TIMEOUT +from 
...util.url import Url +from .fetch import _RequestError, _TimeoutError, send_request, send_streaming_request +from .request import EmscriptenRequest +from .response import EmscriptenHttpResponseWrapper, EmscriptenResponse + +if typing.TYPE_CHECKING: + from ..._base_connection import BaseHTTPConnection, BaseHTTPSConnection + + +class EmscriptenHTTPConnection: + default_port: typing.ClassVar[int] = port_by_scheme["http"] + default_socket_options: typing.ClassVar[_TYPE_SOCKET_OPTIONS] + + timeout: None | (float) + + host: str + port: int + blocksize: int + source_address: tuple[str, int] | None + socket_options: _TYPE_SOCKET_OPTIONS | None + + proxy: Url | None + proxy_config: ProxyConfig | None + + is_verified: bool = False + proxy_is_verified: bool | None = None + + response_class: type[BaseHTTPResponse] = EmscriptenHttpResponseWrapper + _response: EmscriptenResponse | None + + def __init__( + self, + host: str, + port: int = 0, + *, + timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT, + source_address: tuple[str, int] | None = None, + blocksize: int = 8192, + socket_options: _TYPE_SOCKET_OPTIONS | None = None, + proxy: Url | None = None, + proxy_config: ProxyConfig | None = None, + ) -> None: + self.host = host + self.port = port + self.timeout = timeout if isinstance(timeout, float) else 0.0 + self.scheme = "http" + self._closed = True + self._response = None + # ignore these things because we don't + # have control over that stuff + self.proxy = None + self.proxy_config = None + self.blocksize = blocksize + self.source_address = None + self.socket_options = None + self.is_verified = False + + def set_tunnel( + self, + host: str, + port: int | None = 0, + headers: typing.Mapping[str, str] | None = None, + scheme: str = "http", + ) -> None: + pass + + def connect(self) -> None: + pass + + def request( + self, + method: str, + url: str, + body: _TYPE_BODY | None = None, + headers: typing.Mapping[str, str] | None = None, + # We know *at least* botocore is depending on the 
order of the + # first 3 parameters so to be safe we only mark the later ones + # as keyword-only to ensure we have space to extend. + *, + chunked: bool = False, + preload_content: bool = True, + decode_content: bool = True, + enforce_content_length: bool = True, + ) -> None: + self._closed = False + if url.startswith("/"): + if self.port is not None: + port = f":{self.port}" + else: + port = "" + # no scheme / host / port included, make a full url + url = f"{self.scheme}://{self.host}{port}{url}" + request = EmscriptenRequest( + url=url, + method=method, + timeout=self.timeout if self.timeout else 0, + decode_content=decode_content, + ) + request.set_body(body) + if headers: + for k, v in headers.items(): + request.set_header(k, v) + self._response = None + try: + if not preload_content: + self._response = send_streaming_request(request) + if self._response is None: + self._response = send_request(request) + except _TimeoutError as e: + raise TimeoutError(e.message) from e + except _RequestError as e: + raise HTTPException(e.message) from e + + def getresponse(self) -> BaseHTTPResponse: + if self._response is not None: + return EmscriptenHttpResponseWrapper( + internal_response=self._response, + url=self._response.request.url, + connection=self, + ) + else: + raise ResponseNotReady() + + def close(self) -> None: + self._closed = True + self._response = None + + @property + def is_closed(self) -> bool: + """Whether the connection either is brand new or has been previously closed. + If this property is True then both ``is_connected`` and ``has_connected_to_proxy`` + properties must be False. + """ + return self._closed + + @property + def is_connected(self) -> bool: + """Whether the connection is actively connected to any origin (proxy or target)""" + return True + + @property + def has_connected_to_proxy(self) -> bool: + """Whether the connection has successfully connected to its proxy. + This returns False if no proxy is in use. 
Used to determine whether + errors are coming from the proxy layer or from tunnelling to the target origin. + """ + return False + + +class EmscriptenHTTPSConnection(EmscriptenHTTPConnection): + default_port = port_by_scheme["https"] + # all this is basically ignored, as browser handles https + cert_reqs: int | str | None = None + ca_certs: str | None = None + ca_cert_dir: str | None = None + ca_cert_data: None | str | bytes = None + cert_file: str | None + key_file: str | None + key_password: str | None + ssl_context: typing.Any | None + ssl_version: int | str | None = None + ssl_minimum_version: int | None = None + ssl_maximum_version: int | None = None + assert_hostname: None | str | typing.Literal[False] + assert_fingerprint: str | None = None + + def __init__( + self, + host: str, + port: int = 0, + *, + timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT, + source_address: tuple[str, int] | None = None, + blocksize: int = 16384, + socket_options: ( + None | _TYPE_SOCKET_OPTIONS + ) = HTTPConnection.default_socket_options, + proxy: Url | None = None, + proxy_config: ProxyConfig | None = None, + cert_reqs: int | str | None = None, + assert_hostname: None | str | typing.Literal[False] = None, + assert_fingerprint: str | None = None, + server_hostname: str | None = None, + ssl_context: typing.Any | None = None, + ca_certs: str | None = None, + ca_cert_dir: str | None = None, + ca_cert_data: None | str | bytes = None, + ssl_minimum_version: int | None = None, + ssl_maximum_version: int | None = None, + ssl_version: int | str | None = None, # Deprecated + cert_file: str | None = None, + key_file: str | None = None, + key_password: str | None = None, + ) -> None: + super().__init__( + host, + port=port, + timeout=timeout, + source_address=source_address, + blocksize=blocksize, + socket_options=socket_options, + proxy=proxy, + proxy_config=proxy_config, + ) + self.scheme = "https" + + self.key_file = key_file + self.cert_file = cert_file + self.key_password = key_password + 
self.ssl_context = ssl_context + self.server_hostname = server_hostname + self.assert_hostname = assert_hostname + self.assert_fingerprint = assert_fingerprint + self.ssl_version = ssl_version + self.ssl_minimum_version = ssl_minimum_version + self.ssl_maximum_version = ssl_maximum_version + self.ca_certs = ca_certs and os.path.expanduser(ca_certs) + self.ca_cert_dir = ca_cert_dir and os.path.expanduser(ca_cert_dir) + self.ca_cert_data = ca_cert_data + + self.cert_reqs = None + + # The browser will automatically verify all requests. + # We have no control over that setting. + self.is_verified = True + + def set_cert( + self, + key_file: str | None = None, + cert_file: str | None = None, + cert_reqs: int | str | None = None, + key_password: str | None = None, + ca_certs: str | None = None, + assert_hostname: None | str | typing.Literal[False] = None, + assert_fingerprint: str | None = None, + ca_cert_dir: str | None = None, + ca_cert_data: None | str | bytes = None, + ) -> None: + pass + + +# verify that this class implements BaseHTTP(s) connection correctly +if typing.TYPE_CHECKING: + _supports_http_protocol: BaseHTTPConnection = EmscriptenHTTPConnection("", 0) + _supports_https_protocol: BaseHTTPSConnection = EmscriptenHTTPSConnection("", 0) diff --git a/.venv/lib/python3.13/site-packages/urllib3/contrib/emscripten/emscripten_fetch_worker.js b/.venv/lib/python3.13/site-packages/urllib3/contrib/emscripten/emscripten_fetch_worker.js new file mode 100644 index 0000000000000000000000000000000000000000..faf141e1fa4113a0c14480d1681ddecb9678ced4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/contrib/emscripten/emscripten_fetch_worker.js @@ -0,0 +1,110 @@ +let Status = { + SUCCESS_HEADER: -1, + SUCCESS_EOF: -2, + ERROR_TIMEOUT: -3, + ERROR_EXCEPTION: -4, +}; + +let connections = new Map(); +let nextConnectionID = 1; +const encoder = new TextEncoder(); + +self.addEventListener("message", async function (event) { + if (event.data.close) { + let connectionID 
= event.data.close; + connections.delete(connectionID); + return; + } else if (event.data.getMore) { + let connectionID = event.data.getMore; + let { curOffset, value, reader, intBuffer, byteBuffer } = + connections.get(connectionID); + // if we still have some in buffer, then just send it back straight away + if (!value || curOffset >= value.length) { + // read another buffer if required + try { + let readResponse = await reader.read(); + + if (readResponse.done) { + // read everything - clear connection and return + connections.delete(connectionID); + Atomics.store(intBuffer, 0, Status.SUCCESS_EOF); + Atomics.notify(intBuffer, 0); + // finished reading successfully + // return from event handler + return; + } + curOffset = 0; + connections.get(connectionID).value = readResponse.value; + value = readResponse.value; + } catch (error) { + console.log("Request exception:", error); + let errorBytes = encoder.encode(error.message); + let written = errorBytes.length; + byteBuffer.set(errorBytes); + intBuffer[1] = written; + Atomics.store(intBuffer, 0, Status.ERROR_EXCEPTION); + Atomics.notify(intBuffer, 0); + } + } + + // send as much buffer as we can + let curLen = value.length - curOffset; + if (curLen > byteBuffer.length) { + curLen = byteBuffer.length; + } + byteBuffer.set(value.subarray(curOffset, curOffset + curLen), 0); + + Atomics.store(intBuffer, 0, curLen); // store current length in bytes + Atomics.notify(intBuffer, 0); + curOffset += curLen; + connections.get(connectionID).curOffset = curOffset; + + return; + } else { + // start fetch + let connectionID = nextConnectionID; + nextConnectionID += 1; + const intBuffer = new Int32Array(event.data.buffer); + const byteBuffer = new Uint8Array(event.data.buffer, 8); + try { + const response = await fetch(event.data.url, event.data.fetchParams); + // return the headers first via textencoder + var headers = []; + for (const pair of response.headers.entries()) { + headers.push([pair[0], pair[1]]); + } + let headerObj 
= { + headers: headers, + status: response.status, + connectionID, + }; + const headerText = JSON.stringify(headerObj); + let headerBytes = encoder.encode(headerText); + let written = headerBytes.length; + byteBuffer.set(headerBytes); + intBuffer[1] = written; + // make a connection + connections.set(connectionID, { + reader: response.body.getReader(), + intBuffer: intBuffer, + byteBuffer: byteBuffer, + value: undefined, + curOffset: 0, + }); + // set header ready + Atomics.store(intBuffer, 0, Status.SUCCESS_HEADER); + Atomics.notify(intBuffer, 0); + // all fetching after this goes through a new postmessage call with getMore + // this allows for parallel requests + } catch (error) { + console.log("Request exception:", error); + let errorBytes = encoder.encode(error.message); + let written = errorBytes.length; + byteBuffer.set(errorBytes); + intBuffer[1] = written; + Atomics.store(intBuffer, 0, Status.ERROR_EXCEPTION); + Atomics.notify(intBuffer, 0); + } + } +}); +self.postMessage({ inited: true }); diff --git a/.venv/lib/python3.13/site-packages/urllib3/contrib/emscripten/fetch.py b/.venv/lib/python3.13/site-packages/urllib3/contrib/emscripten/fetch.py new file mode 100644 index 0000000000000000000000000000000000000000..612cfddc4c28d2f0edf47522278fa6d9b7906623 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/contrib/emscripten/fetch.py @@ -0,0 +1,726 @@ +""" +Support for streaming http requests in emscripten. + +A few caveats - + +If your browser (or Node.js) has WebAssembly JavaScript Promise Integration enabled +https://github.com/WebAssembly/js-promise-integration/blob/main/proposals/js-promise-integration/Overview.md +*and* you launch pyodide using `pyodide.runPythonAsync`, this will fetch data using the +JavaScript asynchronous fetch api (wrapped via `pyodide.ffi.call_sync`). In this case +timeouts and streaming should just work. + +Otherwise, it uses a combination of XMLHttpRequest and a web-worker for streaming. 
+ +This approach has several caveats: + +Firstly, you can't do streaming http in the main UI thread, because atomics.wait isn't allowed. +Streaming only works if you're running pyodide in a web worker. + +Secondly, this uses an extra web worker and SharedArrayBuffer to do the asynchronous fetch +operation, so it requires that you have crossOriginIsolation enabled, by serving over https +(or from localhost) with the two headers below set: + + Cross-Origin-Opener-Policy: same-origin + Cross-Origin-Embedder-Policy: require-corp + +You can tell if cross origin isolation is successfully enabled by looking at the global crossOriginIsolated variable in +JavaScript console. If it isn't, streaming requests will fallback to XMLHttpRequest, i.e. getting the whole +request into a buffer and then returning it. it shows a warning in the JavaScript console in this case. + +Finally, the webworker which does the streaming fetch is created on initial import, but will only be started once +control is returned to javascript. Call `await wait_for_streaming_ready()` to wait for streaming fetch. + +NB: in this code, there are a lot of JavaScript objects. They are named js_* +to make it clear what type of object they are. +""" + +from __future__ import annotations + +import io +import json +from email.parser import Parser +from importlib.resources import files +from typing import TYPE_CHECKING, Any + +import js # type: ignore[import-not-found] +from pyodide.ffi import ( # type: ignore[import-not-found] + JsArray, + JsException, + JsProxy, + to_js, +) + +if TYPE_CHECKING: + from typing_extensions import Buffer + +from .request import EmscriptenRequest +from .response import EmscriptenResponse + +""" +There are some headers that trigger unintended CORS preflight requests. 
class _ReadStream(io.RawIOBase):
    """
    Read-only raw stream over a response body delivered by the streaming worker.

    Data is exchanged through two JavaScript typed-array views. Element 0 of
    ``int_buffer`` carries either the number of body bytes currently available
    in ``byte_buffer`` (a positive value) or one of the negative status codes;
    element 1 carries the length of the error text when the status is
    ``ERROR_EXCEPTION``.

    :param int_buffer:
        Integer view used for signalling with the worker via ``Atomics``.

    :param byte_buffer:
        Byte view from which body data (or error text) is read.

    :param timeout:
        Timeout in seconds for each wait on the worker; values <= 0 disable it.

    :param worker:
        The JavaScript worker that performs the fetch.

    :param connection_id:
        Identifier the worker assigned to this connection.

    :param request:
        The request this stream belongs to (attached to raised errors).
    """

    def __init__(
        self,
        int_buffer: JsArray,
        byte_buffer: JsArray,
        timeout: float,
        worker: JsProxy,
        connection_id: int,
        request: EmscriptenRequest,
    ):
        self.int_buffer = int_buffer
        self.byte_buffer = byte_buffer
        # read_pos / read_len describe the unread window inside byte_buffer
        self.read_pos = 0
        self.read_len = 0
        self.connection_id = connection_id
        self.worker = worker
        # converted to integer milliseconds for Atomics.wait; None when no
        # positive timeout was requested
        self.timeout = int(1000 * timeout) if timeout > 0 else None
        # is_live: the worker side of the connection has not been closed yet
        self.is_live = True
        self._is_closed = False
        self.request: EmscriptenRequest | None = request

    def __del__(self) -> None:
        self.close()

    # this is compatible with _base_connection
    def is_closed(self) -> bool:
        return self._is_closed

    # for compatibility with RawIOBase
    @property
    def closed(self) -> bool:
        return self.is_closed()

    def close(self) -> None:
        """Drop the buffers and, if still live, ask the worker to close the connection."""
        if self.is_closed():
            return
        self.read_len = 0
        self.read_pos = 0
        self.int_buffer = None
        self.byte_buffer = None
        self._is_closed = True
        self.request = None
        if self.is_live:
            self.worker.postMessage(_obj_from_dict({"close": self.connection_id}))
            self.is_live = False
        super().close()

    def readable(self) -> bool:
        return True

    def writable(self) -> bool:
        return False

    def seekable(self) -> bool:
        return False

    def readinto(self, byte_obj: Buffer) -> int:
        """
        Copy up to ``len(byte_obj)`` body bytes into ``byte_obj``.

        :return: Number of bytes written; 0 at end of stream.
        :raises _StreamingError: If the stream has no buffer (already closed)
            or the worker reported an exception.
        :raises _TimeoutError: If the worker does not answer within the timeout.
        """
        if self.int_buffer is None:
            raise _StreamingError(
                "No buffer for stream in _ReadStream.readinto",
                request=self.request,
                response=None,
            )
        if self.read_len == 0:
            # wait for the worker to send something
            js.Atomics.store(self.int_buffer, 0, ERROR_TIMEOUT)
            self.worker.postMessage(_obj_from_dict({"getMore": self.connection_id}))
            if (
                js.Atomics.wait(self.int_buffer, 0, ERROR_TIMEOUT, self.timeout)
                == "timed-out"
            ):
                # carry a message and the request, like the other errors
                # raised from this module
                raise _TimeoutError(
                    "Timeout reading from streaming request",
                    request=self.request,
                    response=None,
                )
            data_len = self.int_buffer[0]
            if data_len > 0:
                self.read_len = data_len
                self.read_pos = 0
            elif data_len == ERROR_EXCEPTION:
                string_len = self.int_buffer[1]
                # decode the error string
                js_decoder = js.TextDecoder.new()
                json_str = js_decoder.decode(self.byte_buffer.slice(0, string_len))
                raise _StreamingError(
                    f"Exception thrown in fetch: {json_str}",
                    request=self.request,
                    response=None,
                )
            else:
                # any other status is treated as EOF: free the buffers and
                # the request and return zero. The worker side is considered
                # finished, so no "close" message is posted.
                self.is_live = False
                self.close()
                return 0
        # copy from the shared byte buffer into the caller's buffer
        target = memoryview(byte_obj)
        ret_length = min(self.read_len, len(target))
        subarray = self.byte_buffer.subarray(
            self.read_pos, self.read_pos + ret_length
        ).to_py()
        target[0:ret_length] = subarray
        self.read_len -= ret_length
        self.read_pos += ret_length
        return ret_length
def send(self, request: EmscriptenRequest) -> EmscriptenResponse:
    """
    Start ``request`` in the streaming worker and wait for its response headers.

    A fresh shared buffer is allocated per request. The worker reports back
    through element 0 of its Int32 view (status code) and element 1 (length of
    the JSON header text or error text stored in the byte view).

    :param request:
        The request to send.

    :return: A response whose body is a :class:`_ReadStream` over the shared buffer.

    :raises _TimeoutError: If the worker does not answer within the timeout.
    :raises _StreamingError: If the worker reports an exception or an
        unknown status.
    """
    # Compare header names case-insensitively (as send_request does): the
    # keys of request.headers are not lowercase, so a case-sensitive test
    # never filters anything.
    headers = {
        k: v
        for k, v in request.headers.items()
        if k.lower() not in HEADERS_TO_IGNORE
    }

    body = request.body
    fetch_data = {"headers": headers, "body": to_js(body), "method": request.method}
    # start the request off in the worker
    timeout = int(1000 * request.timeout) if request.timeout > 0 else None
    js_shared_buffer = js.SharedArrayBuffer.new(1048576)
    js_int_buffer = js.Int32Array.new(js_shared_buffer)
    # the byte view starts after the two status ints (8 bytes)
    js_byte_buffer = js.Uint8Array.new(js_shared_buffer, 8)

    js.Atomics.store(js_int_buffer, 0, ERROR_TIMEOUT)
    js.Atomics.notify(js_int_buffer, 0)
    js_absolute_url = js.URL.new(request.url, js.location).href
    self.js_worker.postMessage(
        _obj_from_dict(
            {
                "buffer": js_shared_buffer,
                "url": js_absolute_url,
                "fetchParams": fetch_data,
            }
        )
    )
    # wait for the worker to send something
    js.Atomics.wait(js_int_buffer, 0, ERROR_TIMEOUT, timeout)
    status = js_int_buffer[0]
    if status == ERROR_TIMEOUT:
        # the slot still holds the value stored above: nothing arrived
        raise _TimeoutError(
            "Timeout connecting to streaming request",
            request=request,
            response=None,
        )
    elif status == SUCCESS_HEADER:
        # got response
        # header length is in second int of intBuffer
        string_len = js_int_buffer[1]
        # decode the rest to a JSON string
        js_decoder = js.TextDecoder.new()
        # this does a copy (the slice) because decode can't work on shared array
        json_str = js_decoder.decode(js_byte_buffer.slice(0, string_len))
        # get it as an object
        response_obj = json.loads(json_str)
        return EmscriptenResponse(
            request=request,
            status_code=response_obj["status"],
            headers=response_obj["headers"],
            body=_ReadStream(
                js_int_buffer,
                js_byte_buffer,
                request.timeout,
                self.js_worker,
                response_obj["connectionID"],
                request,
            ),
        )
    elif status == ERROR_EXCEPTION:
        string_len = js_int_buffer[1]
        # decode the error string
        js_decoder = js.TextDecoder.new()
        json_str = js_decoder.decode(js_byte_buffer.slice(0, string_len))
        raise _StreamingError(
            f"Exception thrown in fetch: {json_str}", request=request, response=None
        )
    else:
        raise _StreamingError(
            f"Unknown status from worker in fetch: {js_int_buffer[0]}",
            request=request,
            response=None,
        )
def readinto(self, byte_obj: Buffer) -> int:
    """
    Copy bytes from the current chunk into ``byte_obj``.

    Chunks are fetched on demand with ``self._get_next_buffer()``. At most
    one chunk is consumed per call, so fewer than ``len(byte_obj)`` bytes may
    be returned.

    :param byte_obj:
        Writable buffer that receives the data.

    :return: Number of bytes written; 0 once the stream is finished (the
        stream is closed in that case).
    """
    # A return value of 0 is read as end-of-stream by raw-stream callers, so
    # a zero-length chunk must not be reported as a read. Keep fetching until
    # a chunk holds data or the reader is done.
    while self.current_buffer is None or self.current_buffer_pos >= len(
        self.current_buffer
    ):
        self.current_buffer = None
        if not self._get_next_buffer() or self.current_buffer is None:
            self.close()
            return 0
    start = self.current_buffer_pos
    ret_length = min(len(byte_obj), len(self.current_buffer) - start)
    byte_obj[0:ret_length] = self.current_buffer[start : start + ret_length]
    self.current_buffer_pos = start + ret_length
    if self.current_buffer_pos == len(self.current_buffer):
        # chunk fully consumed; the next call fetches a new one
        self.current_buffer = None
    return ret_length
def send_streaming_request(request: EmscriptenRequest) -> EmscriptenResponse | None:
    """
    Send ``request`` with a streamed response body, if streaming is possible.

    JSPI is used when available. Otherwise, Node.js is rejected and the
    worker-based fetcher is used if it exists and is ready.

    :param request:
        The request to send.

    :return: The streaming response, or ``None`` (after showing a console
        warning) when the worker-based fetcher cannot be used.

    :raises _RequestError: In Node.js when JSPI is not usable.
    """
    if has_jspi():
        return send_jspi_request(request, True)
    if is_in_node():
        raise _RequestError(
            message=NODE_JSPI_ERROR,
            request=request,
            response=None,
        )

    worker_usable = _fetcher and streaming_ready()
    if not worker_usable:
        _show_streaming_warning()
        return None
    return _fetcher.send(request)
def send_jspi_request(
    request: EmscriptenRequest, streaming: bool
) -> EmscriptenResponse:
    """
    Send a request using WebAssembly JavaScript Promise Integration
    to wrap the asynchronous JavaScript fetch api (experimental).

    :param request:
        Request to send

    :param streaming:
        Whether to stream the response

    :return: The response object
    :rtype: EmscriptenResponse

    :raises _TimeoutError: If the fetch (or body download) times out
    :raises _RequestError: If JavaScript raises any other exception
    """
    timeout = request.timeout
    js_abort_controller = js.AbortController.new()
    # Compare header names case-insensitively (as send_request does): the
    # keys of request.headers are not lowercase, so a case-sensitive test
    # never filters anything.
    request_headers = {
        k: v for k, v in request.headers.items() if k.lower() not in HEADERS_TO_IGNORE
    }
    fetch_data = {
        "headers": request_headers,
        "body": to_js(request.body),
        "method": request.method,
        "signal": js_abort_controller.signal,
    }
    # Node.js returns the whole response (unlike opaqueredirect in browsers),
    # so urllib3 can set `redirect: manual` to control redirects itself.
    # https://stackoverflow.com/a/78524615
    if _is_node_js():
        fetch_data["redirect"] = "manual"
    # Call JavaScript fetch (async api, returns a promise)
    fetcher_promise_js = js.fetch(request.url, _obj_from_dict(fetch_data))
    # Now suspend WebAssembly until we resolve that promise
    # or time out.
    response_js = _run_sync_with_timeout(
        fetcher_promise_js,
        timeout,
        js_abort_controller,
        request=request,
        response=None,
    )
    # Walk the JavaScript headers iterator by hand and copy the entries
    # into a plain dict of Python strings.
    response_headers = {}
    header_iter = response_js.headers.entries()
    while True:
        iter_value_js = header_iter.next()
        if getattr(iter_value_js, "done", False):
            break
        response_headers[str(iter_value_js.value[0])] = str(iter_value_js.value[1])
    status_code = response_js.status
    body: bytes | io.RawIOBase = b""

    # The response is created with an empty body first because the streaming
    # reader needs a reference to it; the real body is attached below.
    response = EmscriptenResponse(
        status_code=status_code, headers=response_headers, body=b"", request=request
    )
    if streaming:
        # get via inputstream; if fetch gave no body the empty bytes are kept
        if response_js.body is not None:
            # get a reader from the fetch response
            body_stream_js = response_js.body.getReader()
            body = _JSPIReadStream(
                body_stream_js, timeout, request, response, js_abort_controller
            )
    else:
        # get directly via arraybuffer
        # n.b. this is another async JavaScript call.
        body = _run_sync_with_timeout(
            response_js.arrayBuffer(),
            timeout,
            js_abort_controller,
            request=request,
            response=response,
        ).to_py()
    response.body = body
    return response
@dataclass
class EmscriptenRequest:
    """
    Plain description of an HTTP request to be sent from Emscripten.

    ``timeout`` is in seconds; the default of 0 means no timeout was requested.
    """

    method: str
    url: str
    params: dict[str, str] | None = None
    body: _TYPE_BODY | None = None
    headers: dict[str, str] = field(default_factory=dict)
    timeout: float = 0
    decode_content: bool = True

    def set_header(self, name: str, value: str) -> None:
        """Store ``value`` under ``name.capitalize()``, replacing any earlier value."""
        stored_name = name.capitalize()
        self.headers[stored_name] = value

    def set_body(self, body: _TYPE_BODY | None) -> None:
        """Replace the request body."""
        self.body = body
+++ b/.venv/lib/python3.13/site-packages/urllib3/contrib/emscripten/response.py @@ -0,0 +1,277 @@ +from __future__ import annotations + +import json as _json +import logging +import typing +from contextlib import contextmanager +from dataclasses import dataclass +from http.client import HTTPException as HTTPException +from io import BytesIO, IOBase + +from ...exceptions import InvalidHeader, TimeoutError +from ...response import BaseHTTPResponse +from ...util.retry import Retry +from .request import EmscriptenRequest + +if typing.TYPE_CHECKING: + from ..._base_connection import BaseHTTPConnection, BaseHTTPSConnection + +log = logging.getLogger(__name__) + + +@dataclass +class EmscriptenResponse: + status_code: int + headers: dict[str, str] + body: IOBase | bytes + request: EmscriptenRequest + + +class EmscriptenHttpResponseWrapper(BaseHTTPResponse): + def __init__( + self, + internal_response: EmscriptenResponse, + url: str | None = None, + connection: BaseHTTPConnection | BaseHTTPSConnection | None = None, + ): + self._pool = None # set by pool class + self._body = None + self._response = internal_response + self._url = url + self._connection = connection + self._closed = False + super().__init__( + headers=internal_response.headers, + status=internal_response.status_code, + request_url=url, + version=0, + version_string="HTTP/?", + reason="", + decode_content=True, + ) + self.length_remaining = self._init_length(self._response.request.method) + self.length_is_certain = False + + @property + def url(self) -> str | None: + return self._url + + @url.setter + def url(self, url: str | None) -> None: + self._url = url + + @property + def connection(self) -> BaseHTTPConnection | BaseHTTPSConnection | None: + return self._connection + + @property + def retries(self) -> Retry | None: + return self._retries + + @retries.setter + def retries(self, retries: Retry | None) -> None: + # Override the request_url if retries has a redirect location. 
def _init_length(self, request_method: str | None) -> int | None:
    """
    Work out the expected body length from the ``content-length`` header.

    :param request_method:
        HTTP method of the request that produced this response.

    :return: The declared length; ``None`` when the header is missing,
        not an integer, or negative; ``0`` for responses that carry no body
        (status 1xx, 204, 304, or a ``HEAD`` request).

    :raises InvalidHeader: If the header lists several different lengths.
    """
    header_value: str | None = self.headers.get("content-length")
    length: int | None = None

    if header_value is not None:
        try:
            # RFC 7230 section 3.3.2 allows the same length to be repeated
            # in one header (e.g. "42, 42"). Collecting the parsed values in
            # a set validates each one and collapses identical repeats; more
            # than one distinct value makes the header invalid.
            candidates = set(map(int, header_value.split(",")))
            if len(candidates) > 1:
                raise InvalidHeader(
                    "Content-Length contained multiple "
                    "unmatching values (%s)" % header_value
                )
            parsed = candidates.pop()
        except ValueError:
            # a non-integer value: length stays unknown
            pass
        else:
            if parsed >= 0:
                length = parsed

    # Check for responses that shouldn't include a body
    status = self.status
    no_body = request_method == "HEAD" or status in (204, 304) or 100 <= status < 200
    if no_body:
        length = 0

    return length
def close(self) -> None:
    """
    Close the response: close a stream body, close and forget the
    connection, and mark the response closed. Later calls do nothing.
    """
    if self._closed:
        return

    body = self._response.body
    if isinstance(body, IOBase):
        body.close()

    connection = self._connection
    if connection:
        connection.close()
        self._connection = None

    self._closed = True
+ if not clean_exit: + # The response may not be closed but we're not going to use it + # anymore so close it now + if ( + isinstance(self._response.body, IOBase) + and not self._response.body.closed + ): + self._response.body.close() + # release the connection back to the pool + self.release_conn() + else: + # If we have read everything from the response stream, + # return the connection back to the pool. + if ( + isinstance(self._response.body, IOBase) + and self._response.body.closed + ): + self.release_conn() diff --git a/.venv/lib/python3.13/site-packages/urllib3/contrib/pyopenssl.py b/.venv/lib/python3.13/site-packages/urllib3/contrib/pyopenssl.py new file mode 100644 index 0000000000000000000000000000000000000000..8e05d3d785d53021a97a713cbdbb1f43708c9150 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/contrib/pyopenssl.py @@ -0,0 +1,564 @@ +""" +Module for using pyOpenSSL as a TLS backend. This module was relevant before +the standard library ``ssl`` module supported SNI, but now that we've dropped +support for Python 2.7 all relevant Python versions support SNI so +**this module is no longer recommended**. + +This needs the following packages installed: + +* `pyOpenSSL`_ (tested with 16.0.0) +* `cryptography`_ (minimum 1.3.4, from pyopenssl) +* `idna`_ (minimum 2.0) + +However, pyOpenSSL depends on cryptography, so while we use all three directly here we +end up having relatively few packages required. + +You can install them with the following command: + +.. code-block:: bash + + $ python -m pip install pyopenssl cryptography idna + +To activate certificate checking, call +:func:`~urllib3.contrib.pyopenssl.inject_into_urllib3` from your Python code +before you begin making HTTP requests. This can be done in a ``sitecustomize`` +module, or at any other time before your application begins using ``urllib3``, +like this: + +.. 
code-block:: python + + try: + import urllib3.contrib.pyopenssl + urllib3.contrib.pyopenssl.inject_into_urllib3() + except ImportError: + pass + +.. _pyopenssl: https://www.pyopenssl.org +.. _cryptography: https://cryptography.io +.. _idna: https://github.com/kjd/idna +""" + +from __future__ import annotations + +import OpenSSL.SSL # type: ignore[import-not-found] +from cryptography import x509 + +try: + from cryptography.x509 import UnsupportedExtension # type: ignore[attr-defined] +except ImportError: + # UnsupportedExtension is gone in cryptography >= 2.1.0 + class UnsupportedExtension(Exception): # type: ignore[no-redef] + pass + + +import logging +import ssl +import typing +from io import BytesIO +from socket import socket as socket_cls +from socket import timeout + +from .. import util + +if typing.TYPE_CHECKING: + from OpenSSL.crypto import X509 # type: ignore[import-not-found] + + +__all__ = ["inject_into_urllib3", "extract_from_urllib3"] + +# Map from urllib3 to PyOpenSSL compatible parameter-values. 
+_openssl_versions: dict[int, int] = { + util.ssl_.PROTOCOL_TLS: OpenSSL.SSL.SSLv23_METHOD, # type: ignore[attr-defined] + util.ssl_.PROTOCOL_TLS_CLIENT: OpenSSL.SSL.SSLv23_METHOD, # type: ignore[attr-defined] + ssl.PROTOCOL_TLSv1: OpenSSL.SSL.TLSv1_METHOD, +} + +if hasattr(ssl, "PROTOCOL_TLSv1_1") and hasattr(OpenSSL.SSL, "TLSv1_1_METHOD"): + _openssl_versions[ssl.PROTOCOL_TLSv1_1] = OpenSSL.SSL.TLSv1_1_METHOD + +if hasattr(ssl, "PROTOCOL_TLSv1_2") and hasattr(OpenSSL.SSL, "TLSv1_2_METHOD"): + _openssl_versions[ssl.PROTOCOL_TLSv1_2] = OpenSSL.SSL.TLSv1_2_METHOD + + +_stdlib_to_openssl_verify = { + ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE, + ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER, + ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER + + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT, +} +_openssl_to_stdlib_verify = {v: k for k, v in _stdlib_to_openssl_verify.items()} + +# The SSLvX values are the most likely to be missing in the future +# but we check them all just to be sure. +_OP_NO_SSLv2_OR_SSLv3: int = getattr(OpenSSL.SSL, "OP_NO_SSLv2", 0) | getattr( + OpenSSL.SSL, "OP_NO_SSLv3", 0 +) +_OP_NO_TLSv1: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1", 0) +_OP_NO_TLSv1_1: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_1", 0) +_OP_NO_TLSv1_2: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_2", 0) +_OP_NO_TLSv1_3: int = getattr(OpenSSL.SSL, "OP_NO_TLSv1_3", 0) + +_openssl_to_ssl_minimum_version: dict[int, int] = { + ssl.TLSVersion.MINIMUM_SUPPORTED: _OP_NO_SSLv2_OR_SSLv3, + ssl.TLSVersion.TLSv1: _OP_NO_SSLv2_OR_SSLv3, + ssl.TLSVersion.TLSv1_1: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1, + ssl.TLSVersion.TLSv1_2: _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1, + ssl.TLSVersion.TLSv1_3: ( + _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2 + ), + ssl.TLSVersion.MAXIMUM_SUPPORTED: ( + _OP_NO_SSLv2_OR_SSLv3 | _OP_NO_TLSv1 | _OP_NO_TLSv1_1 | _OP_NO_TLSv1_2 + ), +} +_openssl_to_ssl_maximum_version: dict[int, int] = { + ssl.TLSVersion.MINIMUM_SUPPORTED: ( + _OP_NO_SSLv2_OR_SSLv3 + | 
def inject_into_urllib3() -> None:
    """
    Monkey-patch urllib3 with PyOpenSSL-backed SSL-support.

    Dependencies are validated first, so nothing is patched if
    ``_validate_dependencies_met`` raises.
    """
    _validate_dependencies_met()

    patched_modules = (util, util.ssl_)
    for module in patched_modules:
        module.SSLContext = PyOpenSSLContext  # type: ignore[assignment]
    for module in patched_modules:
        module.IS_PYOPENSSL = True
" + "Try upgrading to v0.14 or newer." + ) + + +def _dnsname_to_stdlib(name: str) -> str | None: + """ + Converts a dNSName SubjectAlternativeName field to the form used by the + standard library on the given Python version. + + Cryptography produces a dNSName as a unicode string that was idna-decoded + from ASCII bytes. We need to idna-encode that string to get it back, and + then on Python 3 we also need to convert to unicode via UTF-8 (the stdlib + uses PyUnicode_FromStringAndSize on it, which decodes via UTF-8). + + If the name cannot be idna-encoded then we return None signalling that + the name given should be skipped. + """ + + def idna_encode(name: str) -> bytes | None: + """ + Borrowed wholesale from the Python Cryptography Project. It turns out + that we can't just safely call `idna.encode`: it can explode for + wildcard names. This avoids that problem. + """ + import idna + + try: + for prefix in ["*.", "."]: + if name.startswith(prefix): + name = name[len(prefix) :] + return prefix.encode("ascii") + idna.encode(name) + return idna.encode(name) + except idna.core.IDNAError: + return None + + # Don't send IPv6 addresses through the IDNA encoder. + if ":" in name: + return name + + encoded_name = idna_encode(name) + if encoded_name is None: + return None + return encoded_name.decode("utf-8") + + +def get_subj_alt_name(peer_cert: X509) -> list[tuple[str, str]]: + """ + Given an PyOpenSSL certificate, provides all the subject alternative names. + """ + cert = peer_cert.to_cryptography() + + # We want to find the SAN extension. Ask Cryptography to locate it (it's + # faster than looping in Python) + try: + ext = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName).value + except x509.ExtensionNotFound: + # No such extension, return the empty list. 
+ return [] + except ( + x509.DuplicateExtension, + UnsupportedExtension, + x509.UnsupportedGeneralNameType, + UnicodeError, + ) as e: + # A problem has been found with the quality of the certificate. Assume + # no SAN field is present. + log.warning( + "A problem was encountered with the certificate that prevented " + "urllib3 from finding the SubjectAlternativeName field. This can " + "affect certificate validation. The error was %s", + e, + ) + return [] + + # We want to return dNSName and iPAddress fields. We need to cast the IPs + # back to strings because the match_hostname function wants them as + # strings. + # Sadly the DNS names need to be idna encoded and then, on Python 3, UTF-8 + # decoded. This is pretty frustrating, but that's what the standard library + # does with certificates, and so we need to attempt to do the same. + # We also want to skip over names which cannot be idna encoded. + names = [ + ("DNS", name) + for name in map(_dnsname_to_stdlib, ext.get_values_for_type(x509.DNSName)) + if name is not None + ] + names.extend( + ("IP Address", str(name)) for name in ext.get_values_for_type(x509.IPAddress) + ) + + return names + + +class WrappedSocket: + """API-compatibility wrapper for Python OpenSSL's Connection-class.""" + + def __init__( + self, + connection: OpenSSL.SSL.Connection, + socket: socket_cls, + suppress_ragged_eofs: bool = True, + ) -> None: + self.connection = connection + self.socket = socket + self.suppress_ragged_eofs = suppress_ragged_eofs + self._io_refs = 0 + self._closed = False + + def fileno(self) -> int: + return self.socket.fileno() + + # Copy-pasted from Python 3.5 source code + def _decref_socketios(self) -> None: + if self._io_refs > 0: + self._io_refs -= 1 + if self._closed: + self.close() + + def recv(self, *args: typing.Any, **kwargs: typing.Any) -> bytes: + try: + data = self.connection.recv(*args, **kwargs) + except OpenSSL.SSL.SysCallError as e: + if self.suppress_ragged_eofs and e.args == (-1, "Unexpected 
EOF"): + return b"" + else: + raise OSError(e.args[0], str(e)) from e + except OpenSSL.SSL.ZeroReturnError: + if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN: + return b"" + else: + raise + except OpenSSL.SSL.WantReadError as e: + if not util.wait_for_read(self.socket, self.socket.gettimeout()): + raise timeout("The read operation timed out") from e + else: + return self.recv(*args, **kwargs) + + # TLS 1.3 post-handshake authentication + except OpenSSL.SSL.Error as e: + raise ssl.SSLError(f"read error: {e!r}") from e + else: + return data # type: ignore[no-any-return] + + def recv_into(self, *args: typing.Any, **kwargs: typing.Any) -> int: + try: + return self.connection.recv_into(*args, **kwargs) # type: ignore[no-any-return] + except OpenSSL.SSL.SysCallError as e: + if self.suppress_ragged_eofs and e.args == (-1, "Unexpected EOF"): + return 0 + else: + raise OSError(e.args[0], str(e)) from e + except OpenSSL.SSL.ZeroReturnError: + if self.connection.get_shutdown() == OpenSSL.SSL.RECEIVED_SHUTDOWN: + return 0 + else: + raise + except OpenSSL.SSL.WantReadError as e: + if not util.wait_for_read(self.socket, self.socket.gettimeout()): + raise timeout("The read operation timed out") from e + else: + return self.recv_into(*args, **kwargs) + + # TLS 1.3 post-handshake authentication + except OpenSSL.SSL.Error as e: + raise ssl.SSLError(f"read error: {e!r}") from e + + def settimeout(self, timeout: float) -> None: + return self.socket.settimeout(timeout) + + def _send_until_done(self, data: bytes) -> int: + while True: + try: + return self.connection.send(data) # type: ignore[no-any-return] + except OpenSSL.SSL.WantWriteError as e: + if not util.wait_for_write(self.socket, self.socket.gettimeout()): + raise timeout() from e + continue + except OpenSSL.SSL.SysCallError as e: + raise OSError(e.args[0], str(e)) from e + + def sendall(self, data: bytes) -> None: + total_sent = 0 + while total_sent < len(data): + sent = self._send_until_done( + 
data[total_sent : total_sent + SSL_WRITE_BLOCKSIZE] + ) + total_sent += sent + + def shutdown(self, how: int) -> None: + try: + self.connection.shutdown() + except OpenSSL.SSL.Error as e: + raise ssl.SSLError(f"shutdown error: {e!r}") from e + + def close(self) -> None: + self._closed = True + if self._io_refs <= 0: + self._real_close() + + def _real_close(self) -> None: + try: + return self.connection.close() # type: ignore[no-any-return] + except OpenSSL.SSL.Error: + return + + def getpeercert( + self, binary_form: bool = False + ) -> dict[str, list[typing.Any]] | None: + x509 = self.connection.get_peer_certificate() + + if not x509: + return x509 # type: ignore[no-any-return] + + if binary_form: + return OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_ASN1, x509) # type: ignore[no-any-return] + + return { + "subject": ((("commonName", x509.get_subject().CN),),), # type: ignore[dict-item] + "subjectAltName": get_subj_alt_name(x509), + } + + def version(self) -> str: + return self.connection.get_protocol_version_name() # type: ignore[no-any-return] + + def selected_alpn_protocol(self) -> str | None: + alpn_proto = self.connection.get_alpn_proto_negotiated() + return alpn_proto.decode() if alpn_proto else None + + +WrappedSocket.makefile = socket_cls.makefile # type: ignore[attr-defined] + + +class PyOpenSSLContext: + """ + I am a wrapper class for the PyOpenSSL ``Context`` object. I am responsible + for translating the interface of the standard library ``SSLContext`` object + to calls into PyOpenSSL. 
+ """ + + def __init__(self, protocol: int) -> None: + self.protocol = _openssl_versions[protocol] + self._ctx = OpenSSL.SSL.Context(self.protocol) + self._options = 0 + self.check_hostname = False + self._minimum_version: int = ssl.TLSVersion.MINIMUM_SUPPORTED + self._maximum_version: int = ssl.TLSVersion.MAXIMUM_SUPPORTED + self._verify_flags: int = ssl.VERIFY_X509_TRUSTED_FIRST + + @property + def options(self) -> int: + return self._options + + @options.setter + def options(self, value: int) -> None: + self._options = value + self._set_ctx_options() + + @property + def verify_flags(self) -> int: + return self._verify_flags + + @verify_flags.setter + def verify_flags(self, value: int) -> None: + self._verify_flags = value + self._ctx.get_cert_store().set_flags(self._verify_flags) + + @property + def verify_mode(self) -> int: + return _openssl_to_stdlib_verify[self._ctx.get_verify_mode()] + + @verify_mode.setter + def verify_mode(self, value: ssl.VerifyMode) -> None: + self._ctx.set_verify(_stdlib_to_openssl_verify[value], _verify_callback) + + def set_default_verify_paths(self) -> None: + self._ctx.set_default_verify_paths() + + def set_ciphers(self, ciphers: bytes | str) -> None: + if isinstance(ciphers, str): + ciphers = ciphers.encode("utf-8") + self._ctx.set_cipher_list(ciphers) + + def load_verify_locations( + self, + cafile: str | None = None, + capath: str | None = None, + cadata: bytes | None = None, + ) -> None: + if cafile is not None: + cafile = cafile.encode("utf-8") # type: ignore[assignment] + if capath is not None: + capath = capath.encode("utf-8") # type: ignore[assignment] + try: + self._ctx.load_verify_locations(cafile, capath) + if cadata is not None: + self._ctx.load_verify_locations(BytesIO(cadata)) + except OpenSSL.SSL.Error as e: + raise ssl.SSLError(f"unable to load trusted certificates: {e!r}") from e + + def load_cert_chain( + self, + certfile: str, + keyfile: str | None = None, + password: str | None = None, + ) -> None: + try: + 
self._ctx.use_certificate_chain_file(certfile) + if password is not None: + if not isinstance(password, bytes): + password = password.encode("utf-8") # type: ignore[assignment] + self._ctx.set_passwd_cb(lambda *_: password) + self._ctx.use_privatekey_file(keyfile or certfile) + except OpenSSL.SSL.Error as e: + raise ssl.SSLError(f"Unable to load certificate chain: {e!r}") from e + + def set_alpn_protocols(self, protocols: list[bytes | str]) -> None: + protocols = [util.util.to_bytes(p, "ascii") for p in protocols] + return self._ctx.set_alpn_protos(protocols) # type: ignore[no-any-return] + + def wrap_socket( + self, + sock: socket_cls, + server_side: bool = False, + do_handshake_on_connect: bool = True, + suppress_ragged_eofs: bool = True, + server_hostname: bytes | str | None = None, + ) -> WrappedSocket: + cnx = OpenSSL.SSL.Connection(self._ctx, sock) + + # If server_hostname is an IP, don't use it for SNI, per RFC6066 Section 3 + if server_hostname and not util.ssl_.is_ipaddress(server_hostname): + if isinstance(server_hostname, str): + server_hostname = server_hostname.encode("utf-8") + cnx.set_tlsext_host_name(server_hostname) + + cnx.set_connect_state() + + while True: + try: + cnx.do_handshake() + except OpenSSL.SSL.WantReadError as e: + if not util.wait_for_read(sock, sock.gettimeout()): + raise timeout("select timed out") from e + continue + except OpenSSL.SSL.Error as e: + raise ssl.SSLError(f"bad handshake: {e!r}") from e + break + + return WrappedSocket(cnx, sock) + + def _set_ctx_options(self) -> None: + self._ctx.set_options( + self._options + | _openssl_to_ssl_minimum_version[self._minimum_version] + | _openssl_to_ssl_maximum_version[self._maximum_version] + ) + + @property + def minimum_version(self) -> int: + return self._minimum_version + + @minimum_version.setter + def minimum_version(self, minimum_version: int) -> None: + self._minimum_version = minimum_version + self._set_ctx_options() + + @property + def maximum_version(self) -> int: + return 
self._maximum_version + + @maximum_version.setter + def maximum_version(self, maximum_version: int) -> None: + self._maximum_version = maximum_version + self._set_ctx_options() + + +def _verify_callback( + cnx: OpenSSL.SSL.Connection, + x509: X509, + err_no: int, + err_depth: int, + return_code: int, +) -> bool: + return err_no == 0 diff --git a/.venv/lib/python3.13/site-packages/urllib3/contrib/socks.py b/.venv/lib/python3.13/site-packages/urllib3/contrib/socks.py new file mode 100644 index 0000000000000000000000000000000000000000..e3239b569d93c6139f9c6a86118a5884daf1dabd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/contrib/socks.py @@ -0,0 +1,228 @@ +""" +This module contains provisional support for SOCKS proxies from within +urllib3. This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and +SOCKS5. To enable its functionality, either install PySocks or install this +module with the ``socks`` extra. + +The SOCKS implementation supports the full range of urllib3 features. It also +supports the following SOCKS features: + +- SOCKS4A (``proxy_url='socks4a://...``) +- SOCKS4 (``proxy_url='socks4://...``) +- SOCKS5 with remote DNS (``proxy_url='socks5h://...``) +- SOCKS5 with local DNS (``proxy_url='socks5://...``) +- Usernames and passwords for the SOCKS proxy + +.. note:: + It is recommended to use ``socks5h://`` or ``socks4a://`` schemes in + your ``proxy_url`` to ensure that DNS resolution is done from the remote + server instead of client-side when connecting to a domain name. + +SOCKS4 supports IPv4 and domain names with the SOCKS4A extension. SOCKS5 +supports IPv4, IPv6, and domain names. + +When connecting to a SOCKS4 proxy the ``username`` portion of the ``proxy_url`` +will be sent as the ``userid`` section of the SOCKS request: + +.. 
code-block:: python + + proxy_url="socks4a://@proxy-host" + +When connecting to a SOCKS5 proxy the ``username`` and ``password`` portion +of the ``proxy_url`` will be sent as the username/password to authenticate +with the proxy: + +.. code-block:: python + + proxy_url="socks5h://:@proxy-host" + +""" + +from __future__ import annotations + +try: + import socks # type: ignore[import-untyped] +except ImportError: + import warnings + + from ..exceptions import DependencyWarning + + warnings.warn( + ( + "SOCKS support in urllib3 requires the installation of optional " + "dependencies: specifically, PySocks. For more information, see " + "https://urllib3.readthedocs.io/en/latest/advanced-usage.html#socks-proxies" + ), + DependencyWarning, + ) + raise + +import typing +from socket import timeout as SocketTimeout + +from ..connection import HTTPConnection, HTTPSConnection +from ..connectionpool import HTTPConnectionPool, HTTPSConnectionPool +from ..exceptions import ConnectTimeoutError, NewConnectionError +from ..poolmanager import PoolManager +from ..util.url import parse_url + +try: + import ssl +except ImportError: + ssl = None # type: ignore[assignment] + + +class _TYPE_SOCKS_OPTIONS(typing.TypedDict): + socks_version: int + proxy_host: str | None + proxy_port: str | None + username: str | None + password: str | None + rdns: bool + + +class SOCKSConnection(HTTPConnection): + """ + A plain-text HTTP connection that connects via a SOCKS proxy. + """ + + def __init__( + self, + _socks_options: _TYPE_SOCKS_OPTIONS, + *args: typing.Any, + **kwargs: typing.Any, + ) -> None: + self._socks_options = _socks_options + super().__init__(*args, **kwargs) + + def _new_conn(self) -> socks.socksocket: + """ + Establish a new connection via the SOCKS proxy. 
+ """ + extra_kw: dict[str, typing.Any] = {} + if self.source_address: + extra_kw["source_address"] = self.source_address + + if self.socket_options: + extra_kw["socket_options"] = self.socket_options + + try: + conn = socks.create_connection( + (self.host, self.port), + proxy_type=self._socks_options["socks_version"], + proxy_addr=self._socks_options["proxy_host"], + proxy_port=self._socks_options["proxy_port"], + proxy_username=self._socks_options["username"], + proxy_password=self._socks_options["password"], + proxy_rdns=self._socks_options["rdns"], + timeout=self.timeout, + **extra_kw, + ) + + except SocketTimeout as e: + raise ConnectTimeoutError( + self, + f"Connection to {self.host} timed out. (connect timeout={self.timeout})", + ) from e + + except socks.ProxyError as e: + # This is fragile as hell, but it seems to be the only way to raise + # useful errors here. + if e.socket_err: + error = e.socket_err + if isinstance(error, SocketTimeout): + raise ConnectTimeoutError( + self, + f"Connection to {self.host} timed out. (connect timeout={self.timeout})", + ) from e + else: + # Adding `from e` messes with coverage somehow, so it's omitted. + # See #2386. + raise NewConnectionError( + self, f"Failed to establish a new connection: {error}" + ) + else: + raise NewConnectionError( + self, f"Failed to establish a new connection: {e}" + ) from e + + except OSError as e: # Defensive: PySocks should catch all these. + raise NewConnectionError( + self, f"Failed to establish a new connection: {e}" + ) from e + + return conn + + +# We don't need to duplicate the Verified/Unverified distinction from +# urllib3/connection.py here because the HTTPSConnection will already have been +# correctly set to either the Verified or Unverified form by that module. This +# means the SOCKSHTTPSConnection will automatically be the correct type. 
+class SOCKSHTTPSConnection(SOCKSConnection, HTTPSConnection): + pass + + +class SOCKSHTTPConnectionPool(HTTPConnectionPool): + ConnectionCls = SOCKSConnection + + +class SOCKSHTTPSConnectionPool(HTTPSConnectionPool): + ConnectionCls = SOCKSHTTPSConnection + + +class SOCKSProxyManager(PoolManager): + """ + A version of the urllib3 ProxyManager that routes connections via the + defined SOCKS proxy. + """ + + pool_classes_by_scheme = { + "http": SOCKSHTTPConnectionPool, + "https": SOCKSHTTPSConnectionPool, + } + + def __init__( + self, + proxy_url: str, + username: str | None = None, + password: str | None = None, + num_pools: int = 10, + headers: typing.Mapping[str, str] | None = None, + **connection_pool_kw: typing.Any, + ): + parsed = parse_url(proxy_url) + + if username is None and password is None and parsed.auth is not None: + split = parsed.auth.split(":") + if len(split) == 2: + username, password = split + if parsed.scheme == "socks5": + socks_version = socks.PROXY_TYPE_SOCKS5 + rdns = False + elif parsed.scheme == "socks5h": + socks_version = socks.PROXY_TYPE_SOCKS5 + rdns = True + elif parsed.scheme == "socks4": + socks_version = socks.PROXY_TYPE_SOCKS4 + rdns = False + elif parsed.scheme == "socks4a": + socks_version = socks.PROXY_TYPE_SOCKS4 + rdns = True + else: + raise ValueError(f"Unable to determine SOCKS version from {proxy_url}") + + self.proxy_url = proxy_url + + socks_options = { + "socks_version": socks_version, + "proxy_host": parsed.host, + "proxy_port": parsed.port, + "username": username, + "password": password, + "rdns": rdns, + } + connection_pool_kw["_socks_options"] = socks_options + + super().__init__(num_pools, headers, **connection_pool_kw) + + self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme diff --git a/.venv/lib/python3.13/site-packages/urllib3/http2/__init__.py b/.venv/lib/python3.13/site-packages/urllib3/http2/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..133e1d8f237f6fddd557ae1c0e0cf738f7cc2748 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/http2/__init__.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from importlib.metadata import version + +__all__ = [ + "inject_into_urllib3", + "extract_from_urllib3", +] + +import typing + +orig_HTTPSConnection: typing.Any = None + + +def inject_into_urllib3() -> None: + # First check if h2 version is valid + h2_version = version("h2") + if not h2_version.startswith("4."): + raise ImportError( + "urllib3 v2 supports h2 version 4.x.x, currently " + f"the 'h2' module is compiled with {h2_version!r}. " + "See: https://github.com/urllib3/urllib3/issues/3290" + ) + + # Import here to avoid circular dependencies. + from .. import connection as urllib3_connection + from .. import util as urllib3_util + from ..connectionpool import HTTPSConnectionPool + from ..util import ssl_ as urllib3_util_ssl + from .connection import HTTP2Connection + + global orig_HTTPSConnection + orig_HTTPSConnection = urllib3_connection.HTTPSConnection + + HTTPSConnectionPool.ConnectionCls = HTTP2Connection + urllib3_connection.HTTPSConnection = HTTP2Connection # type: ignore[misc] + + # TODO: Offer 'http/1.1' as well, but for testing purposes this is handy. + urllib3_util.ALPN_PROTOCOLS = ["h2"] + urllib3_util_ssl.ALPN_PROTOCOLS = ["h2"] + + +def extract_from_urllib3() -> None: + from .. import connection as urllib3_connection + from .. 
import util as urllib3_util + from ..connectionpool import HTTPSConnectionPool + from ..util import ssl_ as urllib3_util_ssl + + HTTPSConnectionPool.ConnectionCls = orig_HTTPSConnection + urllib3_connection.HTTPSConnection = orig_HTTPSConnection # type: ignore[misc] + + urllib3_util.ALPN_PROTOCOLS = ["http/1.1"] + urllib3_util_ssl.ALPN_PROTOCOLS = ["http/1.1"] diff --git a/.venv/lib/python3.13/site-packages/urllib3/http2/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/urllib3/http2/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a110ad27c1bc1078c07335e997af830fa1c2ec6b Binary files /dev/null and b/.venv/lib/python3.13/site-packages/urllib3/http2/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/urllib3/http2/__pycache__/probe.cpython-313.pyc b/.venv/lib/python3.13/site-packages/urllib3/http2/__pycache__/probe.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37881df8e122f39f89f7bed89f1167b362eb15ec Binary files /dev/null and b/.venv/lib/python3.13/site-packages/urllib3/http2/__pycache__/probe.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/urllib3/http2/connection.py b/.venv/lib/python3.13/site-packages/urllib3/http2/connection.py new file mode 100644 index 0000000000000000000000000000000000000000..0a026da0a8357e324ded47b82b24042713b9bf06 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/http2/connection.py @@ -0,0 +1,356 @@ +from __future__ import annotations + +import logging +import re +import threading +import types +import typing + +import h2.config +import h2.connection +import h2.events + +from .._base_connection import _TYPE_BODY +from .._collections import HTTPHeaderDict +from ..connection import HTTPSConnection, _get_default_user_agent +from ..exceptions import ConnectionError +from ..response import BaseHTTPResponse + +orig_HTTPSConnection = HTTPSConnection + 
+T = typing.TypeVar("T") + +log = logging.getLogger(__name__) + +RE_IS_LEGAL_HEADER_NAME = re.compile(rb"^[!#$%&'*+\-.^_`|~0-9a-z]+$") +RE_IS_ILLEGAL_HEADER_VALUE = re.compile(rb"[\0\x00\x0a\x0d\r\n]|^[ \r\n\t]|[ \r\n\t]$") + + +def _is_legal_header_name(name: bytes) -> bool: + """ + "An implementation that validates fields according to the definitions in Sections + 5.1 and 5.5 of [HTTP] only needs an additional check that field names do not + include uppercase characters." (https://httpwg.org/specs/rfc9113.html#n-field-validity) + + `http.client._is_legal_header_name` does not validate the field name according to the + HTTP 1.1 spec, so we do that here, in addition to checking for uppercase characters. + + This does not allow for the `:` character in the header name, so should not + be used to validate pseudo-headers. + """ + return bool(RE_IS_LEGAL_HEADER_NAME.match(name)) + + +def _is_illegal_header_value(value: bytes) -> bool: + """ + "A field value MUST NOT contain the zero value (ASCII NUL, 0x00), line feed + (ASCII LF, 0x0a), or carriage return (ASCII CR, 0x0d) at any position. A field + value MUST NOT start or end with an ASCII whitespace character (ASCII SP or HTAB, + 0x20 or 0x09)." (https://httpwg.org/specs/rfc9113.html#n-field-validity) + """ + return bool(RE_IS_ILLEGAL_HEADER_VALUE.search(value)) + + +class _LockedObject(typing.Generic[T]): + """ + A wrapper class that hides a specific object behind a lock. + The goal here is to provide a simple way to protect access to an object + that cannot safely be simultaneously accessed from multiple threads. The + intended use of this class is simple: take hold of it with a context + manager, which returns the protected object. 
+ """ + + __slots__ = ( + "lock", + "_obj", + ) + + def __init__(self, obj: T): + self.lock = threading.RLock() + self._obj = obj + + def __enter__(self) -> T: + self.lock.acquire() + return self._obj + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: types.TracebackType | None, + ) -> None: + self.lock.release() + + +class HTTP2Connection(HTTPSConnection): + def __init__( + self, host: str, port: int | None = None, **kwargs: typing.Any + ) -> None: + self._h2_conn = self._new_h2_conn() + self._h2_stream: int | None = None + self._headers: list[tuple[bytes, bytes]] = [] + + if "proxy" in kwargs or "proxy_config" in kwargs: # Defensive: + raise NotImplementedError("Proxies aren't supported with HTTP/2") + + super().__init__(host, port, **kwargs) + + if self._tunnel_host is not None: + raise NotImplementedError("Tunneling isn't supported with HTTP/2") + + def _new_h2_conn(self) -> _LockedObject[h2.connection.H2Connection]: + config = h2.config.H2Configuration(client_side=True) + return _LockedObject(h2.connection.H2Connection(config=config)) + + def connect(self) -> None: + super().connect() + with self._h2_conn as conn: + conn.initiate_connection() + if data_to_send := conn.data_to_send(): + self.sock.sendall(data_to_send) + + def putrequest( # type: ignore[override] + self, + method: str, + url: str, + **kwargs: typing.Any, + ) -> None: + """putrequest + This deviates from the HTTPConnection method signature since we never need to override + sending accept-encoding headers or the host header. 
+ """ + if "skip_host" in kwargs: + raise NotImplementedError("`skip_host` isn't supported") + if "skip_accept_encoding" in kwargs: + raise NotImplementedError("`skip_accept_encoding` isn't supported") + + self._request_url = url or "/" + self._validate_path(url) # type: ignore[attr-defined] + + if ":" in self.host: + authority = f"[{self.host}]:{self.port or 443}" + else: + authority = f"{self.host}:{self.port or 443}" + + self._headers.append((b":scheme", b"https")) + self._headers.append((b":method", method.encode())) + self._headers.append((b":authority", authority.encode())) + self._headers.append((b":path", url.encode())) + + with self._h2_conn as conn: + self._h2_stream = conn.get_next_available_stream_id() + + def putheader(self, header: str | bytes, *values: str | bytes) -> None: # type: ignore[override] + # TODO SKIPPABLE_HEADERS from urllib3 are ignored. + header = header.encode() if isinstance(header, str) else header + header = header.lower() # A lot of upstream code uses capitalized headers. + if not _is_legal_header_name(header): + raise ValueError(f"Illegal header name {str(header)}") + + for value in values: + value = value.encode() if isinstance(value, str) else value + if _is_illegal_header_value(value): + raise ValueError(f"Illegal header value {str(value)}") + self._headers.append((header, value)) + + def endheaders(self, message_body: typing.Any = None) -> None: # type: ignore[override] + if self._h2_stream is None: + raise ConnectionError("Must call `putrequest` first.") + + with self._h2_conn as conn: + conn.send_headers( + stream_id=self._h2_stream, + headers=self._headers, + end_stream=(message_body is None), + ) + if data_to_send := conn.data_to_send(): + self.sock.sendall(data_to_send) + self._headers = [] # Reset headers for the next request. + + def send(self, data: typing.Any) -> None: + """Send data to the server. + `data` can be: `str`, `bytes`, an iterable, or file-like objects + that support a .read() method. 
+ """ + if self._h2_stream is None: + raise ConnectionError("Must call `putrequest` first.") + + with self._h2_conn as conn: + if data_to_send := conn.data_to_send(): + self.sock.sendall(data_to_send) + + if hasattr(data, "read"): # file-like objects + while True: + chunk = data.read(self.blocksize) + if not chunk: + break + if isinstance(chunk, str): + chunk = chunk.encode() + conn.send_data(self._h2_stream, chunk, end_stream=False) + if data_to_send := conn.data_to_send(): + self.sock.sendall(data_to_send) + conn.end_stream(self._h2_stream) + return + + if isinstance(data, str): # str -> bytes + data = data.encode() + + try: + if isinstance(data, bytes): + conn.send_data(self._h2_stream, data, end_stream=True) + if data_to_send := conn.data_to_send(): + self.sock.sendall(data_to_send) + else: + for chunk in data: + conn.send_data(self._h2_stream, chunk, end_stream=False) + if data_to_send := conn.data_to_send(): + self.sock.sendall(data_to_send) + conn.end_stream(self._h2_stream) + except TypeError: + raise TypeError( + "`data` should be str, bytes, iterable, or file. got %r" + % type(data) + ) + + def set_tunnel( + self, + host: str, + port: int | None = None, + headers: typing.Mapping[str, str] | None = None, + scheme: str = "http", + ) -> None: + raise NotImplementedError( + "HTTP/2 does not support setting up a tunnel through a proxy" + ) + + def getresponse( # type: ignore[override] + self, + ) -> HTTP2Response: + status = None + data = bytearray() + with self._h2_conn as conn: + end_stream = False + while not end_stream: + # TODO: Arbitrary read value. 
+ if received_data := self.sock.recv(65535): + events = conn.receive_data(received_data) + for event in events: + if isinstance(event, h2.events.ResponseReceived): + headers = HTTPHeaderDict() + for header, value in event.headers: + if header == b":status": + status = int(value.decode()) + else: + headers.add( + header.decode("ascii"), value.decode("ascii") + ) + + elif isinstance(event, h2.events.DataReceived): + data += event.data + conn.acknowledge_received_data( + event.flow_controlled_length, event.stream_id + ) + + elif isinstance(event, h2.events.StreamEnded): + end_stream = True + + if data_to_send := conn.data_to_send(): + self.sock.sendall(data_to_send) + + assert status is not None + return HTTP2Response( + status=status, + headers=headers, + request_url=self._request_url, + data=bytes(data), + ) + + def request( # type: ignore[override] + self, + method: str, + url: str, + body: _TYPE_BODY | None = None, + headers: typing.Mapping[str, str] | None = None, + *, + preload_content: bool = True, + decode_content: bool = True, + enforce_content_length: bool = True, + **kwargs: typing.Any, + ) -> None: + """Send an HTTP/2 request""" + if "chunked" in kwargs: + # TODO this is often present from upstream. + # raise NotImplementedError("`chunked` isn't supported with HTTP/2") + pass + + if self.sock is not None: + self.sock.settimeout(self.timeout) + + self.putrequest(method, url) + + headers = headers or {} + for k, v in headers.items(): + if k.lower() == "transfer-encoding" and v == "chunked": + continue + else: + self.putheader(k, v) + + if b"user-agent" not in dict(self._headers): + self.putheader(b"user-agent", _get_default_user_agent()) + + if body: + self.endheaders(message_body=body) + self.send(body) + else: + self.endheaders() + + def close(self) -> None: + with self._h2_conn as conn: + try: + conn.close_connection() + if data := conn.data_to_send(): + self.sock.sendall(data) + except Exception: + pass + + # Reset all our HTTP/2 connection state. 
+ self._h2_conn = self._new_h2_conn() + self._h2_stream = None + self._headers = [] + + super().close() + + +class HTTP2Response(BaseHTTPResponse): + # TODO: This is a woefully incomplete response object, but works for non-streaming. + def __init__( + self, + status: int, + headers: HTTPHeaderDict, + request_url: str, + data: bytes, + decode_content: bool = False, # TODO: support decoding + ) -> None: + super().__init__( + status=status, + headers=headers, + # Following CPython, we map HTTP versions to major * 10 + minor integers + version=20, + version_string="HTTP/2", + # No reason phrase in HTTP/2 + reason=None, + decode_content=decode_content, + request_url=request_url, + ) + self._data = data + self.length_remaining = 0 + + @property + def data(self) -> bytes: + return self._data + + def get_redirect_location(self) -> None: + return None + + def close(self) -> None: + pass diff --git a/.venv/lib/python3.13/site-packages/urllib3/http2/probe.py b/.venv/lib/python3.13/site-packages/urllib3/http2/probe.py new file mode 100644 index 0000000000000000000000000000000000000000..9ea900764f0885eafaac9454523417d86e33df2d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/http2/probe.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import threading + + +class _HTTP2ProbeCache: + __slots__ = ( + "_lock", + "_cache_locks", + "_cache_values", + ) + + def __init__(self) -> None: + self._lock = threading.Lock() + self._cache_locks: dict[tuple[str, int], threading.RLock] = {} + self._cache_values: dict[tuple[str, int], bool | None] = {} + + def acquire_and_get(self, host: str, port: int) -> bool | None: + # By the end of this block we know that + # _cache_[values,locks] is available. + value = None + with self._lock: + key = (host, port) + try: + value = self._cache_values[key] + # If it's a known value we return right away. 
+ if value is not None: + return value + except KeyError: + self._cache_locks[key] = threading.RLock() + self._cache_values[key] = None + + # If the value is unknown, we acquire the lock to signal + # to the requesting thread that the probe is in progress + # or that the current thread needs to return their findings. + key_lock = self._cache_locks[key] + key_lock.acquire() + try: + # If the by the time we get the lock the value has been + # updated we want to return the updated value. + value = self._cache_values[key] + + # In case an exception like KeyboardInterrupt is raised here. + except BaseException as e: # Defensive: + assert not isinstance(e, KeyError) # KeyError shouldn't be possible. + key_lock.release() + raise + + return value + + def set_and_release( + self, host: str, port: int, supports_http2: bool | None + ) -> None: + key = (host, port) + key_lock = self._cache_locks[key] + with key_lock: # Uses an RLock, so can be locked again from same thread. + if supports_http2 is None and self._cache_values[key] is not None: + raise ValueError( + "Cannot reset HTTP/2 support for origin after value has been set." + ) # Defensive: not expected in normal usage + + self._cache_values[key] = supports_http2 + key_lock.release() + + def _values(self) -> dict[tuple[str, int], bool | None]: + """This function is for testing purposes only. Gets the current state of the probe cache""" + with self._lock: + return {k: v for k, v in self._cache_values.items()} + + def _reset(self) -> None: + """This function is for testing purposes only. 
Reset the cache values""" + with self._lock: + self._cache_locks = {} + self._cache_values = {} + + +_HTTP2_PROBE_CACHE = _HTTP2ProbeCache() + +set_and_release = _HTTP2_PROBE_CACHE.set_and_release +acquire_and_get = _HTTP2_PROBE_CACHE.acquire_and_get +_values = _HTTP2_PROBE_CACHE._values +_reset = _HTTP2_PROBE_CACHE._reset + +__all__ = [ + "set_and_release", + "acquire_and_get", +] diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/__init__.py b/.venv/lib/python3.13/site-packages/urllib3/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..534126033c083203649022fa9b753a433f005556 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/util/__init__.py @@ -0,0 +1,42 @@ +# For backwards compatibility, provide imports that used to be here. +from __future__ import annotations + +from .connection import is_connection_dropped +from .request import SKIP_HEADER, SKIPPABLE_HEADERS, make_headers +from .response import is_fp_closed +from .retry import Retry +from .ssl_ import ( + ALPN_PROTOCOLS, + IS_PYOPENSSL, + SSLContext, + assert_fingerprint, + create_urllib3_context, + resolve_cert_reqs, + resolve_ssl_version, + ssl_wrap_socket, +) +from .timeout import Timeout +from .url import Url, parse_url +from .wait import wait_for_read, wait_for_write + +__all__ = ( + "IS_PYOPENSSL", + "SSLContext", + "ALPN_PROTOCOLS", + "Retry", + "Timeout", + "Url", + "assert_fingerprint", + "create_urllib3_context", + "is_connection_dropped", + "is_fp_closed", + "parse_url", + "make_headers", + "resolve_cert_reqs", + "resolve_ssl_version", + "ssl_wrap_socket", + "wait_for_read", + "wait_for_write", + "SKIP_HEADER", + "SKIPPABLE_HEADERS", +) diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e64b630dc5f84e441162a9b82a21f3ef73b6d9f0 Binary files 
/dev/null and b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/connection.cpython-313.pyc b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/connection.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed2f8faa469c108743714b92ba0257bfca7fc09d Binary files /dev/null and b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/connection.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/proxy.cpython-313.pyc b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/proxy.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ba8654f31b01dc2bd97745fda962df6869b055b Binary files /dev/null and b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/proxy.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/request.cpython-313.pyc b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/request.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88bab586066bd645b78c7428247c238fcb487e91 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/request.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/response.cpython-313.pyc b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/response.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cccf9f101e7d5a5167b39e1836341f681ad91c8 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/response.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/retry.cpython-313.pyc b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/retry.cpython-313.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..a97016bcfa97668c2578be41516044421b7b2e88 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/retry.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/ssl_.cpython-313.pyc b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/ssl_.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbdd1f7663a500a5126f8ee281789c7d84ff4c8a Binary files /dev/null and b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/ssl_.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/ssl_match_hostname.cpython-313.pyc b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/ssl_match_hostname.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76915dcec3818cd94596608e86a94a940e222bae Binary files /dev/null and b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/ssl_match_hostname.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/ssltransport.cpython-313.pyc b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/ssltransport.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b8723344905b728ca1211488a0fdef7ca92c9c4 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/ssltransport.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/timeout.cpython-313.pyc b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/timeout.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..177461255b77210c10c31da4d3572f7649e622f6 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/timeout.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/url.cpython-313.pyc 
b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/url.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6355dae34fded009c1effecf95971f9aace20a2 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/url.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/wait.cpython-313.pyc b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/wait.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d01cdb1ea58e83667d9b1c2177777de5c439361a Binary files /dev/null and b/.venv/lib/python3.13/site-packages/urllib3/util/__pycache__/wait.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/connection.py b/.venv/lib/python3.13/site-packages/urllib3/util/connection.py new file mode 100644 index 0000000000000000000000000000000000000000..f92519ee9124e91e5da7d60ccc3f274312ed3514 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/util/connection.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import socket +import typing + +from ..exceptions import LocationParseError +from .timeout import _DEFAULT_TIMEOUT, _TYPE_TIMEOUT + +_TYPE_SOCKET_OPTIONS = list[tuple[int, int, typing.Union[int, bytes]]] + +if typing.TYPE_CHECKING: + from .._base_connection import BaseHTTPConnection + + +def is_connection_dropped(conn: BaseHTTPConnection) -> bool: # Platform-specific + """ + Returns True if the connection is dropped and should be closed. + :param conn: :class:`urllib3.connection.HTTPConnection` object. + """ + return not conn.is_connected + + +# This function is copied from socket.py in the Python 2.7 standard +# library test suite. Added to its signature is only `socket_options`. +# One additional modification is that we avoid binding to IPv6 servers +# discovered in DNS if the system doesn't have IPv6 functionality. 
+def create_connection( + address: tuple[str, int], + timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT, + source_address: tuple[str, int] | None = None, + socket_options: _TYPE_SOCKET_OPTIONS | None = None, +) -> socket.socket: + """Connect to *address* and return the socket object. + + Convenience function. Connect to *address* (a 2-tuple ``(host, + port)``) and return the socket object. Passing the optional + *timeout* parameter will set the timeout on the socket instance + before attempting to connect. If no *timeout* is supplied, the + global default timeout setting returned by :func:`socket.getdefaulttimeout` + is used. If *source_address* is set it must be a tuple of (host, port) + for the socket to bind as a source address before making the connection. + An host of '' or port 0 tells the OS to use the default. + """ + + host, port = address + if host.startswith("["): + host = host.strip("[]") + err = None + + # Using the value from allowed_gai_family() in the context of getaddrinfo lets + # us select whether to work with IPv4 DNS records, IPv6 records, or both. + # The original create_connection function always returns all records. + family = allowed_gai_family() + + try: + host.encode("idna") + except UnicodeError: + raise LocationParseError(f"'{host}', label empty or too long") from None + + for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + sock = None + try: + sock = socket.socket(af, socktype, proto) + + # If provided, set socket level options before connecting. 
+ _set_socket_options(sock, socket_options) + + if timeout is not _DEFAULT_TIMEOUT: + sock.settimeout(timeout) + if source_address: + sock.bind(source_address) + sock.connect(sa) + # Break explicitly a reference cycle + err = None + return sock + + except OSError as _: + err = _ + if sock is not None: + sock.close() + + if err is not None: + try: + raise err + finally: + # Break explicitly a reference cycle + err = None + else: + raise OSError("getaddrinfo returns an empty list") + + +def _set_socket_options( + sock: socket.socket, options: _TYPE_SOCKET_OPTIONS | None +) -> None: + if options is None: + return + + for opt in options: + sock.setsockopt(*opt) + + +def allowed_gai_family() -> socket.AddressFamily: + """This function is designed to work in the context of + getaddrinfo, where family=socket.AF_UNSPEC is the default and + will perform a DNS search for both IPv6 and IPv4 records.""" + + family = socket.AF_INET + if HAS_IPV6: + family = socket.AF_UNSPEC + return family + + +def _has_ipv6(host: str) -> bool: + """Returns True if the system can bind an IPv6 address.""" + sock = None + has_ipv6 = False + + if socket.has_ipv6: + # has_ipv6 returns true if cPython was compiled with IPv6 support. + # It does not tell us if the system has IPv6 support enabled. To + # determine that we must bind to an IPv6 address. 
+ # https://github.com/urllib3/urllib3/pull/611 + # https://bugs.python.org/issue658327 + try: + sock = socket.socket(socket.AF_INET6) + sock.bind((host, 0)) + has_ipv6 = True + except Exception: + pass + + if sock: + sock.close() + return has_ipv6 + + +HAS_IPV6 = _has_ipv6("::1") diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/proxy.py b/.venv/lib/python3.13/site-packages/urllib3/util/proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..908fc6621d0afbed16bde2c1957a5cf28d3a84d8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/util/proxy.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import typing + +from .url import Url + +if typing.TYPE_CHECKING: + from ..connection import ProxyConfig + + +def connection_requires_http_tunnel( + proxy_url: Url | None = None, + proxy_config: ProxyConfig | None = None, + destination_scheme: str | None = None, +) -> bool: + """ + Returns True if the connection requires an HTTP CONNECT through the proxy. + + :param URL proxy_url: + URL of the proxy. + :param ProxyConfig proxy_config: + Proxy configuration from poolmanager.py + :param str destination_scheme: + The scheme of the destination. (i.e https, http, etc) + """ + # If we're not using a proxy, no way to use a tunnel. + if proxy_url is None: + return False + + # HTTP destinations never require tunneling, we always forward. + if destination_scheme == "http": + return False + + # Support for forwarding with HTTPS proxies and HTTPS destinations. + if ( + proxy_url.scheme == "https" + and proxy_config + and proxy_config.use_forwarding_for_https + ): + return False + + # Otherwise always use a tunnel. 
+ return True diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/request.py b/.venv/lib/python3.13/site-packages/urllib3/util/request.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2372ba7e777826a4eb124ddfb54f0240b65d67 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/util/request.py @@ -0,0 +1,263 @@ +from __future__ import annotations + +import io +import sys +import typing +from base64 import b64encode +from enum import Enum + +from ..exceptions import UnrewindableBodyError +from .util import to_bytes + +if typing.TYPE_CHECKING: + from typing import Final + +# Pass as a value within ``headers`` to skip +# emitting some HTTP headers that are added automatically. +# The only headers that are supported are ``Accept-Encoding``, +# ``Host``, and ``User-Agent``. +SKIP_HEADER = "@@@SKIP_HEADER@@@" +SKIPPABLE_HEADERS = frozenset(["accept-encoding", "host", "user-agent"]) + +ACCEPT_ENCODING = "gzip,deflate" +try: + try: + import brotlicffi as _unused_module_brotli # type: ignore[import-not-found] # noqa: F401 + except ImportError: + import brotli as _unused_module_brotli # type: ignore[import-not-found] # noqa: F401 +except ImportError: + pass +else: + ACCEPT_ENCODING += ",br" + +try: + if sys.version_info >= (3, 14): + from compression import zstd as _unused_module_zstd # noqa: F401 + else: + from backports import zstd as _unused_module_zstd # noqa: F401 +except ImportError: + pass +else: + ACCEPT_ENCODING += ",zstd" + + +class _TYPE_FAILEDTELL(Enum): + token = 0 + + +_FAILEDTELL: Final[_TYPE_FAILEDTELL] = _TYPE_FAILEDTELL.token + +_TYPE_BODY_POSITION = typing.Union[int, _TYPE_FAILEDTELL] + +# When sending a request with these methods we aren't expecting +# a body so don't need to set an explicit 'Content-Length: 0' +# The reason we do this in the negative instead of tracking methods +# which 'should' have a body is because unknown methods should be +# treated as if they were 'POST' which *does* expect a body. 
+_METHODS_NOT_EXPECTING_BODY = {"GET", "HEAD", "DELETE", "TRACE", "OPTIONS", "CONNECT"} + + +def make_headers( + keep_alive: bool | None = None, + accept_encoding: bool | list[str] | str | None = None, + user_agent: str | None = None, + basic_auth: str | None = None, + proxy_basic_auth: str | None = None, + disable_cache: bool | None = None, +) -> dict[str, str]: + """ + Shortcuts for generating request headers. + + :param keep_alive: + If ``True``, adds 'connection: keep-alive' header. + + :param accept_encoding: + Can be a boolean, list, or string. + ``True`` translates to 'gzip,deflate'. If the dependencies for + Brotli (either the ``brotli`` or ``brotlicffi`` package) and/or + Zstandard (the ``backports.zstd`` package for Python before 3.14) + algorithms are installed, then their encodings are + included in the string ('br' and 'zstd', respectively). + List will get joined by comma. + String will be used as provided. + + :param user_agent: + String representing the user-agent you want, such as + "python-urllib3/0.6" + + :param basic_auth: + Colon-separated username:password string for 'authorization: basic ...' + auth header. + + :param proxy_basic_auth: + Colon-separated username:password string for 'proxy-authorization: basic ...' + auth header. + + :param disable_cache: + If ``True``, adds 'cache-control: no-cache' header. + + Example: + + .. 
code-block:: python + + import urllib3 + + print(urllib3.util.make_headers(keep_alive=True, user_agent="Batman/1.0")) + # {'connection': 'keep-alive', 'user-agent': 'Batman/1.0'} + print(urllib3.util.make_headers(accept_encoding=True)) + # {'accept-encoding': 'gzip,deflate'} + """ + headers: dict[str, str] = {} + if accept_encoding: + if isinstance(accept_encoding, str): + pass + elif isinstance(accept_encoding, list): + accept_encoding = ",".join(accept_encoding) + else: + accept_encoding = ACCEPT_ENCODING + headers["accept-encoding"] = accept_encoding + + if user_agent: + headers["user-agent"] = user_agent + + if keep_alive: + headers["connection"] = "keep-alive" + + if basic_auth: + headers["authorization"] = ( + f"Basic {b64encode(basic_auth.encode('latin-1')).decode()}" + ) + + if proxy_basic_auth: + headers["proxy-authorization"] = ( + f"Basic {b64encode(proxy_basic_auth.encode('latin-1')).decode()}" + ) + + if disable_cache: + headers["cache-control"] = "no-cache" + + return headers + + +def set_file_position( + body: typing.Any, pos: _TYPE_BODY_POSITION | None +) -> _TYPE_BODY_POSITION | None: + """ + If a position is provided, move file to that point. + Otherwise, we'll attempt to record a position for future use. + """ + if pos is not None: + rewind_body(body, pos) + elif getattr(body, "tell", None) is not None: + try: + pos = body.tell() + except OSError: + # This differentiates from None, allowing us to catch + # a failed `tell()` later when trying to rewind the body. + pos = _FAILEDTELL + + return pos + + +def rewind_body(body: typing.IO[typing.AnyStr], body_pos: _TYPE_BODY_POSITION) -> None: + """ + Attempt to rewind body to a certain position. + Primarily used for request redirects and retries. + + :param body: + File-like object that supports seek. + + :param int pos: + Position to seek to in file. 
+ """ + body_seek = getattr(body, "seek", None) + if body_seek is not None and isinstance(body_pos, int): + try: + body_seek(body_pos) + except OSError as e: + raise UnrewindableBodyError( + "An error occurred when rewinding request body for redirect/retry." + ) from e + elif body_pos is _FAILEDTELL: + raise UnrewindableBodyError( + "Unable to record file position for rewinding " + "request body during a redirect/retry." + ) + else: + raise ValueError( + f"body_pos must be of type integer, instead it was {type(body_pos)}." + ) + + +class ChunksAndContentLength(typing.NamedTuple): + chunks: typing.Iterable[bytes] | None + content_length: int | None + + +def body_to_chunks( + body: typing.Any | None, method: str, blocksize: int +) -> ChunksAndContentLength: + """Takes the HTTP request method, body, and blocksize and + transforms them into an iterable of chunks to pass to + socket.sendall() and an optional 'Content-Length' header. + + A 'Content-Length' of 'None' indicates the length of the body + can't be determined so should use 'Transfer-Encoding: chunked' + for framing instead. + """ + + chunks: typing.Iterable[bytes] | None + content_length: int | None + + # No body, we need to make a recommendation on 'Content-Length' + # based on whether that request method is expected to have + # a body or not. + if body is None: + chunks = None + if method.upper() not in _METHODS_NOT_EXPECTING_BODY: + content_length = 0 + else: + content_length = None + + # Bytes or strings become bytes + elif isinstance(body, (str, bytes)): + chunks = (to_bytes(body),) + content_length = len(chunks[0]) + + # File-like object, TODO: use seek() and tell() for length? 
+ elif hasattr(body, "read"): + + def chunk_readable() -> typing.Iterable[bytes]: + encode = isinstance(body, io.TextIOBase) + while True: + datablock = body.read(blocksize) + if not datablock: + break + if encode: + datablock = datablock.encode("utf-8") + yield datablock + + chunks = chunk_readable() + content_length = None + + # Otherwise we need to start checking via duck-typing. + else: + try: + # Check if the body implements the buffer API. + mv = memoryview(body) + except TypeError: + try: + # Check if the body is an iterable + chunks = iter(body) + content_length = None + except TypeError: + raise TypeError( + f"'body' must be a bytes-like object, file-like " + f"object, or iterable. Instead was {body!r}" + ) from None + else: + # Since it implements the buffer API can be passed directly to socket.sendall() + chunks = (body,) + content_length = mv.nbytes + + return ChunksAndContentLength(chunks=chunks, content_length=content_length) diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/response.py b/.venv/lib/python3.13/site-packages/urllib3/util/response.py new file mode 100644 index 0000000000000000000000000000000000000000..0f4578696fa2e17a900c6890ec26d65e860b0b72 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/util/response.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import http.client as httplib +from email.errors import MultipartInvariantViolationDefect, StartBoundaryNotFoundDefect + +from ..exceptions import HeaderParsingError + + +def is_fp_closed(obj: object) -> bool: + """ + Checks whether a given file-like object is closed. + + :param obj: + The file-like object to check. + """ + + try: + # Check `isclosed()` first, in case Python3 doesn't set `closed`. + # GH Issue #928 + return obj.isclosed() # type: ignore[no-any-return, attr-defined] + except AttributeError: + pass + + try: + # Check via the official file-like-object way. 
+ return obj.closed # type: ignore[no-any-return, attr-defined] + except AttributeError: + pass + + try: + # Check if the object is a container for another file-like object that + # gets released on exhaustion (e.g. HTTPResponse). + return obj.fp is None # type: ignore[attr-defined] + except AttributeError: + pass + + raise ValueError("Unable to determine whether fp is closed.") + + +def assert_header_parsing(headers: httplib.HTTPMessage) -> None: + """ + Asserts whether all headers have been successfully parsed. + Extracts encountered errors from the result of parsing headers. + + Only works on Python 3. + + :param http.client.HTTPMessage headers: Headers to verify. + + :raises urllib3.exceptions.HeaderParsingError: + If parsing errors are found. + """ + + # This will fail silently if we pass in the wrong kind of parameter. + # To make debugging easier add an explicit check. + if not isinstance(headers, httplib.HTTPMessage): + raise TypeError(f"expected httplib.Message, got {type(headers)}.") + + unparsed_data = None + + # get_payload is actually email.message.Message.get_payload; + # we're only interested in the result if it's not a multipart message + if not headers.is_multipart(): + payload = headers.get_payload() + + if isinstance(payload, (bytes, str)): + unparsed_data = payload + + # httplib is assuming a response body is available + # when parsing headers even when httplib only sends + # header data to parse_headers() This results in + # defects on multipart responses in particular. + # See: https://github.com/urllib3/urllib3/issues/800 + + # So we ignore the following defects: + # - StartBoundaryNotFoundDefect: + # The claimed start boundary was never found. + # - MultipartInvariantViolationDefect: + # A message claimed to be a multipart but no subparts were found. 
+ defects = [ + defect + for defect in headers.defects + if not isinstance( + defect, (StartBoundaryNotFoundDefect, MultipartInvariantViolationDefect) + ) + ] + + if defects or unparsed_data: + raise HeaderParsingError(defects=defects, unparsed_data=unparsed_data) + + +def is_response_to_head(response: httplib.HTTPResponse) -> bool: + """ + Checks whether the request of a response has been a HEAD-request. + + :param http.client.HTTPResponse response: + Response to check if the originating request + used 'HEAD' as a method. + """ + # FIXME: Can we do this somehow without accessing private httplib _method? + method_str = response._method # type: str # type: ignore[attr-defined] + return method_str.upper() == "HEAD" diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/retry.py b/.venv/lib/python3.13/site-packages/urllib3/util/retry.py new file mode 100644 index 0000000000000000000000000000000000000000..b21b4b64ebbd4748eb6fa4301f947b0d4965da8b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/util/retry.py @@ -0,0 +1,549 @@ +from __future__ import annotations + +import email +import logging +import random +import re +import time +import typing +from itertools import takewhile +from types import TracebackType + +from ..exceptions import ( + ConnectTimeoutError, + InvalidHeader, + MaxRetryError, + ProtocolError, + ProxyError, + ReadTimeoutError, + ResponseError, +) +from .util import reraise + +if typing.TYPE_CHECKING: + from typing_extensions import Self + + from ..connectionpool import ConnectionPool + from ..response import BaseHTTPResponse + +log = logging.getLogger(__name__) + + +# Data structure for representing the metadata of requests that result in a retry. +class RequestHistory(typing.NamedTuple): + method: str | None + url: str | None + error: Exception | None + status: int | None + redirect_location: str | None + + +class Retry: + """Retry configuration. 
+ + Each retry attempt will create a new Retry object with updated values, so + they can be safely reused. + + Retries can be defined as a default for a pool: + + .. code-block:: python + + retries = Retry(connect=5, read=2, redirect=5) + http = PoolManager(retries=retries) + response = http.request("GET", "https://example.com/") + + Or per-request (which overrides the default for the pool): + + .. code-block:: python + + response = http.request("GET", "https://example.com/", retries=Retry(10)) + + Retries can be disabled by passing ``False``: + + .. code-block:: python + + response = http.request("GET", "https://example.com/", retries=False) + + Errors will be wrapped in :class:`~urllib3.exceptions.MaxRetryError` unless + retries are disabled, in which case the causing exception will be raised. + + :param int total: + Total number of retries to allow. Takes precedence over other counts. + + Set to ``None`` to remove this constraint and fall back on other + counts. + + Set to ``0`` to fail on the first retry. + + Set to ``False`` to disable and imply ``raise_on_redirect=False``. + + :param int connect: + How many connection-related errors to retry on. + + These are errors raised before the request is sent to the remote server, + which we assume has not triggered the server to process the request. + + Set to ``0`` to fail on the first retry of this type. + + :param int read: + How many times to retry on read errors. + + These errors are raised after the request was sent to the server, so the + request may have side-effects. + + Set to ``0`` to fail on the first retry of this type. + + :param int redirect: + How many redirects to perform. Limit this to avoid infinite redirect + loops. + + A redirect is a HTTP response with a status code 301, 302, 303, 307 or + 308. + + Set to ``0`` to fail on the first retry of this type. + + Set to ``False`` to disable and imply ``raise_on_redirect=False``. + + :param int status: + How many times to retry on bad status codes. 
+ + These are retries made on responses, where status code matches + ``status_forcelist``. + + Set to ``0`` to fail on the first retry of this type. + + :param int other: + How many times to retry on other errors. + + Other errors are errors that are not connect, read, redirect or status errors. + These errors might be raised after the request was sent to the server, so the + request might have side-effects. + + Set to ``0`` to fail on the first retry of this type. + + If ``total`` is not set, it's a good idea to set this to 0 to account + for unexpected edge cases and avoid infinite retry loops. + + :param Collection allowed_methods: + Set of uppercased HTTP method verbs that we should retry on. + + By default, we only retry on methods which are considered to be + idempotent (multiple requests with the same parameters end with the + same state). See :attr:`Retry.DEFAULT_ALLOWED_METHODS`. + + Set to a ``None`` value to retry on any verb. + + :param Collection status_forcelist: + A set of integer HTTP status codes that we should force a retry on. + A retry is initiated if the request method is in ``allowed_methods`` + and the response status code is in ``status_forcelist``. + + By default, this is disabled with ``None``. + + :param float backoff_factor: + A backoff factor to apply between attempts after the second try + (most errors are resolved immediately by a second try without a + delay). urllib3 will sleep for:: + + {backoff factor} * (2 ** ({number of previous retries})) + + seconds. If `backoff_jitter` is non-zero, this sleep is extended by:: + + random.uniform(0, {backoff jitter}) + + seconds. For example, if the backoff_factor is 0.1, then :func:`Retry.sleep` will + sleep for [0.0s, 0.2s, 0.4s, 0.8s, ...] between retries. No backoff will ever + be longer than `backoff_max`. + + By default, backoff is disabled (factor set to 0). 
+ + :param bool raise_on_redirect: Whether, if the number of redirects is + exhausted, to raise a MaxRetryError, or to return a response with a + response code in the 3xx range. + + :param bool raise_on_status: Similar meaning to ``raise_on_redirect``: + whether we should raise an exception, or return a response, + if status falls in ``status_forcelist`` range and retries have + been exhausted. + + :param tuple history: The history of the request encountered during + each call to :meth:`~Retry.increment`. The list is in the order + the requests occurred. Each list item is of class :class:`RequestHistory`. + + :param bool respect_retry_after_header: + Whether to respect Retry-After header on status codes defined as + :attr:`Retry.RETRY_AFTER_STATUS_CODES` or not. + + :param Collection remove_headers_on_redirect: + Sequence of headers to remove from the request when a response + indicating a redirect is returned before firing off the redirected + request. + + :param int retry_after_max: Number of seconds to allow as the maximum for + Retry-After headers. Defaults to :attr:`Retry.DEFAULT_RETRY_AFTER_MAX`. + Any Retry-After headers larger than this value will be limited to this + value. + """ + + #: Default methods to be used for ``allowed_methods`` + DEFAULT_ALLOWED_METHODS = frozenset( + ["HEAD", "GET", "PUT", "DELETE", "OPTIONS", "TRACE"] + ) + + #: Default status codes to be used for ``status_forcelist`` + RETRY_AFTER_STATUS_CODES = frozenset([413, 429, 503]) + + #: Default headers to be used for ``remove_headers_on_redirect`` + DEFAULT_REMOVE_HEADERS_ON_REDIRECT = frozenset( + ["Cookie", "Authorization", "Proxy-Authorization"] + ) + + #: Default maximum backoff time. + DEFAULT_BACKOFF_MAX = 120 + + # This is undocumented in the RFC. Setting to 6 hours matches other popular libraries. + #: Default maximum allowed value for Retry-After headers in seconds + DEFAULT_RETRY_AFTER_MAX: typing.Final[int] = 21600 + + # Backward compatibility; assigned outside of the class. 
+ DEFAULT: typing.ClassVar[Retry] + + def __init__( + self, + total: bool | int | None = 10, + connect: int | None = None, + read: int | None = None, + redirect: bool | int | None = None, + status: int | None = None, + other: int | None = None, + allowed_methods: typing.Collection[str] | None = DEFAULT_ALLOWED_METHODS, + status_forcelist: typing.Collection[int] | None = None, + backoff_factor: float = 0, + backoff_max: float = DEFAULT_BACKOFF_MAX, + raise_on_redirect: bool = True, + raise_on_status: bool = True, + history: tuple[RequestHistory, ...] | None = None, + respect_retry_after_header: bool = True, + remove_headers_on_redirect: typing.Collection[ + str + ] = DEFAULT_REMOVE_HEADERS_ON_REDIRECT, + backoff_jitter: float = 0.0, + retry_after_max: int = DEFAULT_RETRY_AFTER_MAX, + ) -> None: + self.total = total + self.connect = connect + self.read = read + self.status = status + self.other = other + + if redirect is False or total is False: + redirect = 0 + raise_on_redirect = False + + self.redirect = redirect + self.status_forcelist = status_forcelist or set() + self.allowed_methods = allowed_methods + self.backoff_factor = backoff_factor + self.backoff_max = backoff_max + self.retry_after_max = retry_after_max + self.raise_on_redirect = raise_on_redirect + self.raise_on_status = raise_on_status + self.history = history or () + self.respect_retry_after_header = respect_retry_after_header + self.remove_headers_on_redirect = frozenset( + h.lower() for h in remove_headers_on_redirect + ) + self.backoff_jitter = backoff_jitter + + def new(self, **kw: typing.Any) -> Self: + params = dict( + total=self.total, + connect=self.connect, + read=self.read, + redirect=self.redirect, + status=self.status, + other=self.other, + allowed_methods=self.allowed_methods, + status_forcelist=self.status_forcelist, + backoff_factor=self.backoff_factor, + backoff_max=self.backoff_max, + retry_after_max=self.retry_after_max, + raise_on_redirect=self.raise_on_redirect, + 
raise_on_status=self.raise_on_status, + history=self.history, + remove_headers_on_redirect=self.remove_headers_on_redirect, + respect_retry_after_header=self.respect_retry_after_header, + backoff_jitter=self.backoff_jitter, + ) + + params.update(kw) + return type(self)(**params) # type: ignore[arg-type] + + @classmethod + def from_int( + cls, + retries: Retry | bool | int | None, + redirect: bool | int | None = True, + default: Retry | bool | int | None = None, + ) -> Retry: + """Backwards-compatibility for the old retries format.""" + if retries is None: + retries = default if default is not None else cls.DEFAULT + + if isinstance(retries, Retry): + return retries + + redirect = bool(redirect) and None + new_retries = cls(retries, redirect=redirect) + log.debug("Converted retries value: %r -> %r", retries, new_retries) + return new_retries + + def get_backoff_time(self) -> float: + """Formula for computing the current backoff + + :rtype: float + """ + # We want to consider only the last consecutive errors sequence (Ignore redirects). 
+ consecutive_errors_len = len( + list( + takewhile(lambda x: x.redirect_location is None, reversed(self.history)) + ) + ) + if consecutive_errors_len <= 1: + return 0 + + backoff_value = self.backoff_factor * (2 ** (consecutive_errors_len - 1)) + if self.backoff_jitter != 0.0: + backoff_value += random.random() * self.backoff_jitter + return float(max(0, min(self.backoff_max, backoff_value))) + + def parse_retry_after(self, retry_after: str) -> float: + seconds: float + # Whitespace: https://tools.ietf.org/html/rfc7230#section-3.2.4 + if re.match(r"^\s*[0-9]+\s*$", retry_after): + seconds = int(retry_after) + else: + retry_date_tuple = email.utils.parsedate_tz(retry_after) + if retry_date_tuple is None: + raise InvalidHeader(f"Invalid Retry-After header: {retry_after}") + + retry_date = email.utils.mktime_tz(retry_date_tuple) + seconds = retry_date - time.time() + + seconds = max(seconds, 0) + + # Check the seconds do not exceed the specified maximum + if seconds > self.retry_after_max: + seconds = self.retry_after_max + + return seconds + + def get_retry_after(self, response: BaseHTTPResponse) -> float | None: + """Get the value of Retry-After in seconds.""" + + retry_after = response.headers.get("Retry-After") + + if retry_after is None: + return None + + return self.parse_retry_after(retry_after) + + def sleep_for_retry(self, response: BaseHTTPResponse) -> bool: + retry_after = self.get_retry_after(response) + if retry_after: + time.sleep(retry_after) + return True + + return False + + def _sleep_backoff(self) -> None: + backoff = self.get_backoff_time() + if backoff <= 0: + return + time.sleep(backoff) + + def sleep(self, response: BaseHTTPResponse | None = None) -> None: + """Sleep between retry attempts. + + This method will respect a server's ``Retry-After`` response header + and sleep the duration of the time requested. If that is not present, it + will use an exponential backoff. 
By default, the backoff factor is 0 and + this method will return immediately. + """ + + if self.respect_retry_after_header and response: + slept = self.sleep_for_retry(response) + if slept: + return + + self._sleep_backoff() + + def _is_connection_error(self, err: Exception) -> bool: + """Errors when we're fairly sure that the server did not receive the + request, so it should be safe to retry. + """ + if isinstance(err, ProxyError): + err = err.original_error + return isinstance(err, ConnectTimeoutError) + + def _is_read_error(self, err: Exception) -> bool: + """Errors that occur after the request has been started, so we should + assume that the server began processing it. + """ + return isinstance(err, (ReadTimeoutError, ProtocolError)) + + def _is_method_retryable(self, method: str) -> bool: + """Checks if a given HTTP method should be retried upon, depending if + it is included in the allowed_methods + """ + if self.allowed_methods and method.upper() not in self.allowed_methods: + return False + return True + + def is_retry( + self, method: str, status_code: int, has_retry_after: bool = False + ) -> bool: + """Is this method/status code retryable? 
(Based on allowlists and control + variables such as the number of total retries to allow, whether to + respect the Retry-After header, whether this header is present, and + whether the returned status code is on the list of status codes to + be retried upon on the presence of the aforementioned header) + """ + if not self._is_method_retryable(method): + return False + + if self.status_forcelist and status_code in self.status_forcelist: + return True + + return bool( + self.total + and self.respect_retry_after_header + and has_retry_after + and (status_code in self.RETRY_AFTER_STATUS_CODES) + ) + + def is_exhausted(self) -> bool: + """Are we out of retries?""" + retry_counts = [ + x + for x in ( + self.total, + self.connect, + self.read, + self.redirect, + self.status, + self.other, + ) + if x + ] + if not retry_counts: + return False + + return min(retry_counts) < 0 + + def increment( + self, + method: str | None = None, + url: str | None = None, + response: BaseHTTPResponse | None = None, + error: Exception | None = None, + _pool: ConnectionPool | None = None, + _stacktrace: TracebackType | None = None, + ) -> Self: + """Return a new Retry object with incremented retry counters. + + :param response: A response object, or None, if the server did not + return a response. + :type response: :class:`~urllib3.response.BaseHTTPResponse` + :param Exception error: An error encountered during the request, or + None if the response was received successfully. + + :return: A new ``Retry`` object. + """ + if self.total is False and error: + # Disabled, indicate to re-raise the error. + raise reraise(type(error), error, _stacktrace) + + total = self.total + if total is not None: + total -= 1 + + connect = self.connect + read = self.read + redirect = self.redirect + status_count = self.status + other = self.other + cause = "unknown" + status = None + redirect_location = None + + if error and self._is_connection_error(error): + # Connect retry? 
+ if connect is False: + raise reraise(type(error), error, _stacktrace) + elif connect is not None: + connect -= 1 + + elif error and self._is_read_error(error): + # Read retry? + if read is False or method is None or not self._is_method_retryable(method): + raise reraise(type(error), error, _stacktrace) + elif read is not None: + read -= 1 + + elif error: + # Other retry? + if other is not None: + other -= 1 + + elif response and response.get_redirect_location(): + # Redirect retry? + if redirect is not None: + redirect -= 1 + cause = "too many redirects" + response_redirect_location = response.get_redirect_location() + if response_redirect_location: + redirect_location = response_redirect_location + status = response.status + + else: + # Incrementing because of a server error like a 500 in + # status_forcelist and the given method is in the allowed_methods + cause = ResponseError.GENERIC_ERROR + if response and response.status: + if status_count is not None: + status_count -= 1 + cause = ResponseError.SPECIFIC_ERROR.format(status_code=response.status) + status = response.status + + history = self.history + ( + RequestHistory(method, url, error, status, redirect_location), + ) + + new_retry = self.new( + total=total, + connect=connect, + read=read, + redirect=redirect, + status=status_count, + other=other, + history=history, + ) + + if new_retry.is_exhausted(): + reason = error or ResponseError(cause) + raise MaxRetryError(_pool, url, reason) from reason # type: ignore[arg-type] + + log.debug("Incremented Retry for (url='%s'): %r", url, new_retry) + + return new_retry + + def __repr__(self) -> str: + return ( + f"{type(self).__name__}(total={self.total}, connect={self.connect}, " + f"read={self.read}, redirect={self.redirect}, status={self.status})" + ) + + +# For backwards compatibility (equivalent to pre-v1.9): +Retry.DEFAULT = Retry(3) diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/ssl_.py b/.venv/lib/python3.13/site-packages/urllib3/util/ssl_.py 
new file mode 100644 index 0000000000000000000000000000000000000000..56fe9093adaa86b30085aef2435e49f84841df12 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/util/ssl_.py @@ -0,0 +1,527 @@ +from __future__ import annotations + +import hashlib +import hmac +import os +import socket +import sys +import typing +import warnings +from binascii import unhexlify + +from ..exceptions import ProxySchemeUnsupported, SSLError +from .url import _BRACELESS_IPV6_ADDRZ_RE, _IPV4_RE + +SSLContext = None +SSLTransport = None +HAS_NEVER_CHECK_COMMON_NAME = False +IS_PYOPENSSL = False +ALPN_PROTOCOLS = ["http/1.1"] + +_TYPE_VERSION_INFO = tuple[int, int, int, str, int] + +# Maps the length of a digest to a possible hash function producing this digest +HASHFUNC_MAP = { + length: getattr(hashlib, algorithm, None) + for length, algorithm in ((32, "md5"), (40, "sha1"), (64, "sha256")) +} + + +def _is_bpo_43522_fixed( + implementation_name: str, + version_info: _TYPE_VERSION_INFO, + pypy_version_info: _TYPE_VERSION_INFO | None, +) -> bool: + """Return True for CPython 3.9.3+ or 3.10+ and PyPy 7.3.8+ where + setting SSLContext.hostname_checks_common_name to False works. + + Outside of CPython and PyPy we don't know which implementations work + or not so we conservatively use our hostname matching as we know that works + on all implementations. 
+ + https://github.com/urllib3/urllib3/issues/2192#issuecomment-821832963 + https://foss.heptapod.net/pypy/pypy/-/issues/3539 + """ + if implementation_name == "pypy": + # https://foss.heptapod.net/pypy/pypy/-/issues/3129 + return pypy_version_info >= (7, 3, 8) # type: ignore[operator] + elif implementation_name == "cpython": + major_minor = version_info[:2] + micro = version_info[2] + return (major_minor == (3, 9) and micro >= 3) or major_minor >= (3, 10) + else: # Defensive: + return False + + +def _is_has_never_check_common_name_reliable( + openssl_version: str, + openssl_version_number: int, + implementation_name: str, + version_info: _TYPE_VERSION_INFO, + pypy_version_info: _TYPE_VERSION_INFO | None, +) -> bool: + # As of May 2023, all released versions of LibreSSL fail to reject certificates with + # only common names, see https://github.com/urllib3/urllib3/pull/3024 + is_openssl = openssl_version.startswith("OpenSSL ") + # Before fixing OpenSSL issue #14579, the SSL_new() API was not copying hostflags + # like X509_CHECK_FLAG_NEVER_CHECK_SUBJECT, which tripped up CPython. + # https://github.com/openssl/openssl/issues/14579 + # This was released in OpenSSL 1.1.1l+ (>=0x101010cf) + is_openssl_issue_14579_fixed = openssl_version_number >= 0x101010CF + + return is_openssl and ( + is_openssl_issue_14579_fixed + or _is_bpo_43522_fixed(implementation_name, version_info, pypy_version_info) + ) + + +if typing.TYPE_CHECKING: + from ssl import VerifyMode + from typing import TypedDict + + from .ssltransport import SSLTransport as SSLTransportType + + class _TYPE_PEER_CERT_RET_DICT(TypedDict, total=False): + subjectAltName: tuple[tuple[str, str], ...] + subject: tuple[tuple[tuple[str, str], ...], ...] + serialNumber: str + + +# Mapping from 'ssl.PROTOCOL_TLSX' to 'TLSVersion.X' +_SSL_VERSION_TO_TLS_VERSION: dict[int, int] = {} + +try: # Do we have ssl at all? 
+ import ssl + from ssl import ( # type: ignore[assignment] + CERT_REQUIRED, + HAS_NEVER_CHECK_COMMON_NAME, + OP_NO_COMPRESSION, + OP_NO_TICKET, + OPENSSL_VERSION, + OPENSSL_VERSION_NUMBER, + PROTOCOL_TLS, + PROTOCOL_TLS_CLIENT, + VERIFY_X509_STRICT, + OP_NO_SSLv2, + OP_NO_SSLv3, + SSLContext, + TLSVersion, + ) + + PROTOCOL_SSLv23 = PROTOCOL_TLS + + # Needed for Python 3.9 which does not define this + VERIFY_X509_PARTIAL_CHAIN = getattr(ssl, "VERIFY_X509_PARTIAL_CHAIN", 0x80000) + + # Setting SSLContext.hostname_checks_common_name = False didn't work before CPython + # 3.9.3, and 3.10 (but OK on PyPy) or OpenSSL 1.1.1l+ + if HAS_NEVER_CHECK_COMMON_NAME and not _is_has_never_check_common_name_reliable( + OPENSSL_VERSION, + OPENSSL_VERSION_NUMBER, + sys.implementation.name, + sys.version_info, + sys.pypy_version_info if sys.implementation.name == "pypy" else None, # type: ignore[attr-defined] + ): # Defensive: for Python < 3.9.3 + HAS_NEVER_CHECK_COMMON_NAME = False + + # Need to be careful here in case old TLS versions get + # removed in future 'ssl' module implementations. 
+ for attr in ("TLSv1", "TLSv1_1", "TLSv1_2"): + try: + _SSL_VERSION_TO_TLS_VERSION[getattr(ssl, f"PROTOCOL_{attr}")] = getattr( + TLSVersion, attr + ) + except AttributeError: # Defensive: + continue + + from .ssltransport import SSLTransport # type: ignore[assignment] +except ImportError: + OP_NO_COMPRESSION = 0x20000 # type: ignore[assignment, misc] + OP_NO_TICKET = 0x4000 # type: ignore[assignment, misc] + OP_NO_SSLv2 = 0x1000000 # type: ignore[assignment, misc] + OP_NO_SSLv3 = 0x2000000 # type: ignore[assignment, misc] + PROTOCOL_SSLv23 = PROTOCOL_TLS = 2 # type: ignore[assignment, misc] + PROTOCOL_TLS_CLIENT = 16 # type: ignore[assignment, misc] + VERIFY_X509_PARTIAL_CHAIN = 0x80000 + VERIFY_X509_STRICT = 0x20 # type: ignore[assignment, misc] + + +_TYPE_PEER_CERT_RET = typing.Union["_TYPE_PEER_CERT_RET_DICT", bytes, None] + + +def assert_fingerprint(cert: bytes | None, fingerprint: str) -> None: + """ + Checks if given fingerprint matches the supplied certificate. + + :param cert: + Certificate as bytes object. + :param fingerprint: + Fingerprint as string of hexdigits, can be interspersed by colons. + """ + + if cert is None: + raise SSLError("No certificate for the peer.") + + fingerprint = fingerprint.replace(":", "").lower() + digest_length = len(fingerprint) + if digest_length not in HASHFUNC_MAP: + raise SSLError(f"Fingerprint of invalid length: {fingerprint}") + hashfunc = HASHFUNC_MAP.get(digest_length) + if hashfunc is None: + raise SSLError( + f"Hash function implementation unavailable for fingerprint length: {digest_length}" + ) + + # We need encode() here for py32; works on py2 and p33. + fingerprint_bytes = unhexlify(fingerprint.encode()) + + cert_digest = hashfunc(cert).digest() + + if not hmac.compare_digest(cert_digest, fingerprint_bytes): + raise SSLError( + f'Fingerprints did not match. 
Expected "{fingerprint}", got "{cert_digest.hex()}"' + ) + + +def resolve_cert_reqs(candidate: None | int | str) -> VerifyMode: + """ + Resolves the argument to a numeric constant, which can be passed to + the wrap_socket function/method from the ssl module. + Defaults to :data:`ssl.CERT_REQUIRED`. + If given a string it is assumed to be the name of the constant in the + :mod:`ssl` module or its abbreviation. + (So you can specify `REQUIRED` instead of `CERT_REQUIRED`. + If it's neither `None` nor a string we assume it is already the numeric + constant which can directly be passed to wrap_socket. + """ + if candidate is None: + return CERT_REQUIRED + + if isinstance(candidate, str): + res = getattr(ssl, candidate, None) + if res is None: + res = getattr(ssl, "CERT_" + candidate) + return res # type: ignore[no-any-return] + + return candidate # type: ignore[return-value] + + +def resolve_ssl_version(candidate: None | int | str) -> int: + """ + like resolve_cert_reqs + """ + if candidate is None: + return PROTOCOL_TLS + + if isinstance(candidate, str): + res = getattr(ssl, candidate, None) + if res is None: + res = getattr(ssl, "PROTOCOL_" + candidate) + return typing.cast(int, res) + + return candidate + + +def create_urllib3_context( + ssl_version: int | None = None, + cert_reqs: int | None = None, + options: int | None = None, + ciphers: str | None = None, + ssl_minimum_version: int | None = None, + ssl_maximum_version: int | None = None, + verify_flags: int | None = None, +) -> ssl.SSLContext: + """Creates and configures an :class:`ssl.SSLContext` instance for use with urllib3. + + :param ssl_version: + The desired protocol version to use. This will default to + PROTOCOL_SSLv23 which will negotiate the highest protocol that both + the server and your installation of OpenSSL support. + + This parameter is deprecated instead use 'ssl_minimum_version'. + :param ssl_minimum_version: + The minimum version of TLS to be used. 
Use the 'ssl.TLSVersion' enum for specifying the value. + :param ssl_maximum_version: + The maximum version of TLS to be used. Use the 'ssl.TLSVersion' enum for specifying the value. + Not recommended to set to anything other than 'ssl.TLSVersion.MAXIMUM_SUPPORTED' which is the + default value. + :param cert_reqs: + Whether to require the certificate verification. This defaults to + ``ssl.CERT_REQUIRED``. + :param options: + Specific OpenSSL options. These default to ``ssl.OP_NO_SSLv2``, + ``ssl.OP_NO_SSLv3``, ``ssl.OP_NO_COMPRESSION``, and ``ssl.OP_NO_TICKET``. + :param ciphers: + Which cipher suites to allow the server to select. Defaults to either system configured + ciphers if OpenSSL 1.1.1+, otherwise uses a secure default set of ciphers. + :param verify_flags: + The flags for certificate verification operations. These default to + ``ssl.VERIFY_X509_PARTIAL_CHAIN`` and ``ssl.VERIFY_X509_STRICT`` for Python 3.13+. + :returns: + Constructed SSLContext object with specified options + :rtype: SSLContext + """ + if SSLContext is None: + raise TypeError("Can't create an SSLContext object without an ssl module") + + # This means 'ssl_version' was specified as an exact value. + if ssl_version not in (None, PROTOCOL_TLS, PROTOCOL_TLS_CLIENT): + # Disallow setting 'ssl_version' and 'ssl_minimum|maximum_version' + # to avoid conflicts. + if ssl_minimum_version is not None or ssl_maximum_version is not None: + raise ValueError( + "Can't specify both 'ssl_version' and either " + "'ssl_minimum_version' or 'ssl_maximum_version'" + ) + + # 'ssl_version' is deprecated and will be removed in the future. + else: + # Use 'ssl_minimum_version' and 'ssl_maximum_version' instead. 
+ ssl_minimum_version = _SSL_VERSION_TO_TLS_VERSION.get( + ssl_version, TLSVersion.MINIMUM_SUPPORTED + ) + ssl_maximum_version = _SSL_VERSION_TO_TLS_VERSION.get( + ssl_version, TLSVersion.MAXIMUM_SUPPORTED + ) + + # This warning message is pushing users to use 'ssl_minimum_version' + # instead of both min/max. Best practice is to only set the minimum version and + # keep the maximum version to be it's default value: 'TLSVersion.MAXIMUM_SUPPORTED' + warnings.warn( + "'ssl_version' option is deprecated and will be " + "removed in urllib3 v2.6.0. Instead use 'ssl_minimum_version'", + category=DeprecationWarning, + stacklevel=2, + ) + + # PROTOCOL_TLS is deprecated in Python 3.10 so we always use PROTOCOL_TLS_CLIENT + context = SSLContext(PROTOCOL_TLS_CLIENT) + + if ssl_minimum_version is not None: + context.minimum_version = ssl_minimum_version + else: # Python <3.10 defaults to 'MINIMUM_SUPPORTED' so explicitly set TLSv1.2 here + context.minimum_version = TLSVersion.TLSv1_2 + + if ssl_maximum_version is not None: + context.maximum_version = ssl_maximum_version + + # Unless we're given ciphers defer to either system ciphers in + # the case of OpenSSL 1.1.1+ or use our own secure default ciphers. + if ciphers: + context.set_ciphers(ciphers) + + # Setting the default here, as we may have no ssl module on import + cert_reqs = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs + + if options is None: + options = 0 + # SSLv2 is easily broken and is considered harmful and dangerous + options |= OP_NO_SSLv2 + # SSLv3 has several problems and is now dangerous + options |= OP_NO_SSLv3 + # Disable compression to prevent CRIME attacks for OpenSSL 1.0+ + # (issue #309) + options |= OP_NO_COMPRESSION + # TLSv1.2 only. Unless set explicitly, do not request tickets. + # This may save some bandwidth on wire, and although the ticket is encrypted, + # there is a risk associated with it being on wire, + # if the server is not rotating its ticketing keys properly. 
+ options |= OP_NO_TICKET + + context.options |= options + + if verify_flags is None: + verify_flags = 0 + # In Python 3.13+ ssl.create_default_context() sets VERIFY_X509_PARTIAL_CHAIN + # and VERIFY_X509_STRICT so we do the same + if sys.version_info >= (3, 13): + verify_flags |= VERIFY_X509_PARTIAL_CHAIN + verify_flags |= VERIFY_X509_STRICT + + context.verify_flags |= verify_flags + + # Enable post-handshake authentication for TLS 1.3, see GH #1634. PHA is + # necessary for conditional client cert authentication with TLS 1.3. + # The attribute is None for OpenSSL <= 1.1.0 or does not exist when using + # an SSLContext created by pyOpenSSL. + if getattr(context, "post_handshake_auth", None) is not None: + context.post_handshake_auth = True + + # The order of the below lines setting verify_mode and check_hostname + # matter due to safe-guards SSLContext has to prevent an SSLContext with + # check_hostname=True, verify_mode=NONE/OPTIONAL. + # We always set 'check_hostname=False' for pyOpenSSL so we rely on our own + # 'ssl.match_hostname()' implementation. 
+ if cert_reqs == ssl.CERT_REQUIRED and not IS_PYOPENSSL: + context.verify_mode = cert_reqs + context.check_hostname = True + else: + context.check_hostname = False + context.verify_mode = cert_reqs + + try: + context.hostname_checks_common_name = False + except AttributeError: # Defensive: for CPython < 3.9.3; for PyPy < 7.3.8 + pass + + if "SSLKEYLOGFILE" in os.environ: + sslkeylogfile = os.path.expandvars(os.environ.get("SSLKEYLOGFILE")) + else: + sslkeylogfile = None + if sslkeylogfile: + context.keylog_filename = sslkeylogfile + + return context + + +@typing.overload +def ssl_wrap_socket( + sock: socket.socket, + keyfile: str | None = ..., + certfile: str | None = ..., + cert_reqs: int | None = ..., + ca_certs: str | None = ..., + server_hostname: str | None = ..., + ssl_version: int | None = ..., + ciphers: str | None = ..., + ssl_context: ssl.SSLContext | None = ..., + ca_cert_dir: str | None = ..., + key_password: str | None = ..., + ca_cert_data: None | str | bytes = ..., + tls_in_tls: typing.Literal[False] = ..., +) -> ssl.SSLSocket: ... + + +@typing.overload +def ssl_wrap_socket( + sock: socket.socket, + keyfile: str | None = ..., + certfile: str | None = ..., + cert_reqs: int | None = ..., + ca_certs: str | None = ..., + server_hostname: str | None = ..., + ssl_version: int | None = ..., + ciphers: str | None = ..., + ssl_context: ssl.SSLContext | None = ..., + ca_cert_dir: str | None = ..., + key_password: str | None = ..., + ca_cert_data: None | str | bytes = ..., + tls_in_tls: bool = ..., +) -> ssl.SSLSocket | SSLTransportType: ... 
+ + +def ssl_wrap_socket( + sock: socket.socket, + keyfile: str | None = None, + certfile: str | None = None, + cert_reqs: int | None = None, + ca_certs: str | None = None, + server_hostname: str | None = None, + ssl_version: int | None = None, + ciphers: str | None = None, + ssl_context: ssl.SSLContext | None = None, + ca_cert_dir: str | None = None, + key_password: str | None = None, + ca_cert_data: None | str | bytes = None, + tls_in_tls: bool = False, +) -> ssl.SSLSocket | SSLTransportType: + """ + All arguments except for server_hostname, ssl_context, tls_in_tls, ca_cert_data and + ca_cert_dir have the same meaning as they do when using + :func:`ssl.create_default_context`, :meth:`ssl.SSLContext.load_cert_chain`, + :meth:`ssl.SSLContext.set_ciphers` and :meth:`ssl.SSLContext.wrap_socket`. + + :param server_hostname: + When SNI is supported, the expected hostname of the certificate + :param ssl_context: + A pre-made :class:`SSLContext` object. If none is provided, one will + be created using :func:`create_urllib3_context`. + :param ciphers: + A string of ciphers we wish the client to support. + :param ca_cert_dir: + A directory containing CA certificates in multiple separate files, as + supported by OpenSSL's -CApath flag or the capath argument to + SSLContext.load_verify_locations(). + :param key_password: + Optional password if the keyfile is encrypted. + :param ca_cert_data: + Optional string containing CA certificates in PEM format suitable for + passing as the cadata parameter to SSLContext.load_verify_locations() + :param tls_in_tls: + Use SSLTransport to wrap the existing socket. + """ + context = ssl_context + if context is None: + # Note: This branch of code and all the variables in it are only used in tests. + # We should consider deprecating and removing this code. 
+ context = create_urllib3_context(ssl_version, cert_reqs, ciphers=ciphers) + + if ca_certs or ca_cert_dir or ca_cert_data: + try: + context.load_verify_locations(ca_certs, ca_cert_dir, ca_cert_data) + except OSError as e: + raise SSLError(e) from e + + elif ssl_context is None and hasattr(context, "load_default_certs"): + # try to load OS default certs; works well on Windows. + context.load_default_certs() + + # Attempt to detect if we get the goofy behavior of the + # keyfile being encrypted and OpenSSL asking for the + # passphrase via the terminal and instead error out. + if keyfile and key_password is None and _is_key_file_encrypted(keyfile): + raise SSLError("Client private key is encrypted, password is required") + + if certfile: + if key_password is None: + context.load_cert_chain(certfile, keyfile) + else: + context.load_cert_chain(certfile, keyfile, key_password) + + context.set_alpn_protocols(ALPN_PROTOCOLS) + + ssl_sock = _ssl_wrap_socket_impl(sock, context, tls_in_tls, server_hostname) + return ssl_sock + + +def is_ipaddress(hostname: str | bytes) -> bool: + """Detects whether the hostname given is an IPv4 or IPv6 address. + Also detects IPv6 addresses with Zone IDs. + + :param str hostname: Hostname to examine. + :return: True if the hostname is an IP address, False otherwise. + """ + if isinstance(hostname, bytes): + # IDN A-label bytes are ASCII compatible. 
+ hostname = hostname.decode("ascii") + return bool(_IPV4_RE.match(hostname) or _BRACELESS_IPV6_ADDRZ_RE.match(hostname)) + + +def _is_key_file_encrypted(key_file: str) -> bool: + """Detects if a key file is encrypted or not.""" + with open(key_file) as f: + for line in f: + # Look for Proc-Type: 4,ENCRYPTED + if "ENCRYPTED" in line: + return True + + return False + + +def _ssl_wrap_socket_impl( + sock: socket.socket, + ssl_context: ssl.SSLContext, + tls_in_tls: bool, + server_hostname: str | None = None, +) -> ssl.SSLSocket | SSLTransportType: + if tls_in_tls: + if not SSLTransport: + # Import error, ssl is not available. + raise ProxySchemeUnsupported( + "TLS in TLS requires support for the 'ssl' module" + ) + + SSLTransport._validate_ssl_context_for_tls_in_tls(ssl_context) + return SSLTransport(sock, ssl_context, server_hostname) + + return ssl_context.wrap_socket(sock, server_hostname=server_hostname) diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/ssl_match_hostname.py b/.venv/lib/python3.13/site-packages/urllib3/util/ssl_match_hostname.py new file mode 100644 index 0000000000000000000000000000000000000000..25d91000419ea4a860f511ebe669fe171b79254c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/util/ssl_match_hostname.py @@ -0,0 +1,159 @@ +"""The match_hostname() function from Python 3.5, essential when using SSL.""" + +# Note: This file is under the PSF license as the code comes from the python +# stdlib. http://docs.python.org/3/license.html +# It is modified to remove commonName support. 
+ +from __future__ import annotations + +import ipaddress +import re +import typing +from ipaddress import IPv4Address, IPv6Address + +if typing.TYPE_CHECKING: + from .ssl_ import _TYPE_PEER_CERT_RET_DICT + +__version__ = "3.5.0.1" + + +class CertificateError(ValueError): + pass + + +def _dnsname_match( + dn: typing.Any, hostname: str, max_wildcards: int = 1 +) -> typing.Match[str] | None | bool: + """Matching according to RFC 6125, section 6.4.3 + + http://tools.ietf.org/html/rfc6125#section-6.4.3 + """ + pats = [] + if not dn: + return False + + # Ported from python3-syntax: + # leftmost, *remainder = dn.split(r'.') + parts = dn.split(r".") + leftmost = parts[0] + remainder = parts[1:] + + wildcards = leftmost.count("*") + if wildcards > max_wildcards: + # Issue #17980: avoid denials of service by refusing more + # than one wildcard per fragment. A survey of established + # policy among SSL implementations showed it to be a + # reasonable choice. + raise CertificateError( + "too many wildcards in certificate DNS name: " + repr(dn) + ) + + # speed up common case w/o wildcards + if not wildcards: + return bool(dn.lower() == hostname.lower()) + + # RFC 6125, section 6.4.3, subitem 1. + # The client SHOULD NOT attempt to match a presented identifier in which + # the wildcard character comprises a label other than the left-most label. + if leftmost == "*": + # When '*' is a fragment by itself, it matches a non-empty dotless + # fragment. + pats.append("[^.]+") + elif leftmost.startswith("xn--") or hostname.startswith("xn--"): + # RFC 6125, section 6.4.3, subitem 3. + # The client SHOULD NOT attempt to match a presented identifier + # where the wildcard character is embedded within an A-label or + # U-label of an internationalized domain name. + pats.append(re.escape(leftmost)) + else: + # Otherwise, '*' matches any dotless string, e.g. 
www* + pats.append(re.escape(leftmost).replace(r"\*", "[^.]*")) + + # add the remaining fragments, ignore any wildcards + for frag in remainder: + pats.append(re.escape(frag)) + + pat = re.compile(r"\A" + r"\.".join(pats) + r"\Z", re.IGNORECASE) + return pat.match(hostname) + + +def _ipaddress_match(ipname: str, host_ip: IPv4Address | IPv6Address) -> bool: + """Exact matching of IP addresses. + + RFC 9110 section 4.3.5: "A reference identity of IP-ID contains the decoded + bytes of the IP address. An IP version 4 address is 4 octets, and an IP + version 6 address is 16 octets. [...] A reference identity of type IP-ID + matches if the address is identical to an iPAddress value of the + subjectAltName extension of the certificate." + """ + # OpenSSL may add a trailing newline to a subjectAltName's IP address + # Divergence from upstream: ipaddress can't handle byte str + ip = ipaddress.ip_address(ipname.rstrip()) + return bool(ip.packed == host_ip.packed) + + +def match_hostname( + cert: _TYPE_PEER_CERT_RET_DICT | None, + hostname: str, + hostname_checks_common_name: bool = False, +) -> None: + """Verify that *cert* (in decoded format as returned by + SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125 + rules are followed, but IP addresses are not accepted for *hostname*. + + CertificateError is raised on failure. On success, the function + returns nothing. + """ + if not cert: + raise ValueError( + "empty or no certificate, match_hostname needs a " + "SSL socket or SSL context with either " + "CERT_OPTIONAL or CERT_REQUIRED" + ) + try: + # Divergence from upstream: ipaddress can't handle byte str + # + # The ipaddress module shipped with Python < 3.9 does not support + # scoped IPv6 addresses so we unconditionally strip the Zone IDs for + # now. Once we drop support for Python 3.9 we can remove this branch. 
+ if "%" in hostname: + host_ip = ipaddress.ip_address(hostname[: hostname.rfind("%")]) + else: + host_ip = ipaddress.ip_address(hostname) + + except ValueError: + # Not an IP address (common case) + host_ip = None + dnsnames = [] + san: tuple[tuple[str, str], ...] = cert.get("subjectAltName", ()) + key: str + value: str + for key, value in san: + if key == "DNS": + if host_ip is None and _dnsname_match(value, hostname): + return + dnsnames.append(value) + elif key == "IP Address": + if host_ip is not None and _ipaddress_match(value, host_ip): + return + dnsnames.append(value) + + # We only check 'commonName' if it's enabled and we're not verifying + # an IP address. IP addresses aren't valid within 'commonName'. + if hostname_checks_common_name and host_ip is None and not dnsnames: + for sub in cert.get("subject", ()): + for key, value in sub: + if key == "commonName": + if _dnsname_match(value, hostname): + return + dnsnames.append(value) # Defensive: for Python < 3.9.3 + + if len(dnsnames) > 1: + raise CertificateError( + "hostname %r " + "doesn't match either of %s" % (hostname, ", ".join(map(repr, dnsnames))) + ) + elif len(dnsnames) == 1: + raise CertificateError(f"hostname {hostname!r} doesn't match {dnsnames[0]!r}") + else: + raise CertificateError("no appropriate subjectAltName fields were found") diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/ssltransport.py b/.venv/lib/python3.13/site-packages/urllib3/util/ssltransport.py new file mode 100644 index 0000000000000000000000000000000000000000..6d59bc3bce2489c3a0aa5bcb83b737dcf33c033b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/util/ssltransport.py @@ -0,0 +1,271 @@ +from __future__ import annotations + +import io +import socket +import ssl +import typing + +from ..exceptions import ProxySchemeUnsupported + +if typing.TYPE_CHECKING: + from typing_extensions import Self + + from .ssl_ import _TYPE_PEER_CERT_RET, _TYPE_PEER_CERT_RET_DICT + + +_WriteBuffer = 
typing.Union[bytearray, memoryview] +_ReturnValue = typing.TypeVar("_ReturnValue") + +SSL_BLOCKSIZE = 16384 + + +class SSLTransport: + """ + The SSLTransport wraps an existing socket and establishes an SSL connection. + + Contrary to Python's implementation of SSLSocket, it allows you to chain + multiple TLS connections together. It's particularly useful if you need to + implement TLS within TLS. + + The class supports most of the socket API operations. + """ + + @staticmethod + def _validate_ssl_context_for_tls_in_tls(ssl_context: ssl.SSLContext) -> None: + """ + Raises a ProxySchemeUnsupported if the provided ssl_context can't be used + for TLS in TLS. + + The only requirement is that the ssl_context provides the 'wrap_bio' + methods. + """ + + if not hasattr(ssl_context, "wrap_bio"): + raise ProxySchemeUnsupported( + "TLS in TLS requires SSLContext.wrap_bio() which isn't " + "available on non-native SSLContext" + ) + + def __init__( + self, + socket: socket.socket, + ssl_context: ssl.SSLContext, + server_hostname: str | None = None, + suppress_ragged_eofs: bool = True, + ) -> None: + """ + Create an SSLTransport around socket using the provided ssl_context. + """ + self.incoming = ssl.MemoryBIO() + self.outgoing = ssl.MemoryBIO() + + self.suppress_ragged_eofs = suppress_ragged_eofs + self.socket = socket + + self.sslobj = ssl_context.wrap_bio( + self.incoming, self.outgoing, server_hostname=server_hostname + ) + + # Perform initial handshake. 
+ self._ssl_io_loop(self.sslobj.do_handshake) + + def __enter__(self) -> Self: + return self + + def __exit__(self, *_: typing.Any) -> None: + self.close() + + def fileno(self) -> int: + return self.socket.fileno() + + def read(self, len: int = 1024, buffer: typing.Any | None = None) -> int | bytes: + return self._wrap_ssl_read(len, buffer) + + def recv(self, buflen: int = 1024, flags: int = 0) -> int | bytes: + if flags != 0: + raise ValueError("non-zero flags not allowed in calls to recv") + return self._wrap_ssl_read(buflen) + + def recv_into( + self, + buffer: _WriteBuffer, + nbytes: int | None = None, + flags: int = 0, + ) -> None | int | bytes: + if flags != 0: + raise ValueError("non-zero flags not allowed in calls to recv_into") + if nbytes is None: + nbytes = len(buffer) + return self.read(nbytes, buffer) + + def sendall(self, data: bytes, flags: int = 0) -> None: + if flags != 0: + raise ValueError("non-zero flags not allowed in calls to sendall") + count = 0 + with memoryview(data) as view, view.cast("B") as byte_view: + amount = len(byte_view) + while count < amount: + v = self.send(byte_view[count:]) + count += v + + def send(self, data: bytes, flags: int = 0) -> int: + if flags != 0: + raise ValueError("non-zero flags not allowed in calls to send") + return self._ssl_io_loop(self.sslobj.write, data) + + def makefile( + self, + mode: str, + buffering: int | None = None, + *, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> typing.BinaryIO | typing.TextIO | socket.SocketIO: + """ + Python's httpclient uses makefile and buffered io when reading HTTP + messages and we need to support it. + + This is unfortunately a copy and paste of socket.py makefile with small + changes to point to the socket directly. 
+ """ + if not set(mode) <= {"r", "w", "b"}: + raise ValueError(f"invalid mode {mode!r} (only r, w, b allowed)") + + writing = "w" in mode + reading = "r" in mode or not writing + assert reading or writing + binary = "b" in mode + rawmode = "" + if reading: + rawmode += "r" + if writing: + rawmode += "w" + raw = socket.SocketIO(self, rawmode) # type: ignore[arg-type] + self.socket._io_refs += 1 # type: ignore[attr-defined] + if buffering is None: + buffering = -1 + if buffering < 0: + buffering = io.DEFAULT_BUFFER_SIZE + if buffering == 0: + if not binary: + raise ValueError("unbuffered streams must be binary") + return raw + buffer: typing.BinaryIO + if reading and writing: + buffer = io.BufferedRWPair(raw, raw, buffering) # type: ignore[assignment] + elif reading: + buffer = io.BufferedReader(raw, buffering) + else: + assert writing + buffer = io.BufferedWriter(raw, buffering) + if binary: + return buffer + text = io.TextIOWrapper(buffer, encoding, errors, newline) + text.mode = mode # type: ignore[misc] + return text + + def unwrap(self) -> None: + self._ssl_io_loop(self.sslobj.unwrap) + + def close(self) -> None: + self.socket.close() + + @typing.overload + def getpeercert( + self, binary_form: typing.Literal[False] = ... + ) -> _TYPE_PEER_CERT_RET_DICT | None: ... + + @typing.overload + def getpeercert(self, binary_form: typing.Literal[True]) -> bytes | None: ... 
+ + def getpeercert(self, binary_form: bool = False) -> _TYPE_PEER_CERT_RET: + return self.sslobj.getpeercert(binary_form) # type: ignore[return-value] + + def version(self) -> str | None: + return self.sslobj.version() + + def cipher(self) -> tuple[str, str, int] | None: + return self.sslobj.cipher() + + def selected_alpn_protocol(self) -> str | None: + return self.sslobj.selected_alpn_protocol() + + def shared_ciphers(self) -> list[tuple[str, str, int]] | None: + return self.sslobj.shared_ciphers() + + def compression(self) -> str | None: + return self.sslobj.compression() + + def settimeout(self, value: float | None) -> None: + self.socket.settimeout(value) + + def gettimeout(self) -> float | None: + return self.socket.gettimeout() + + def _decref_socketios(self) -> None: + self.socket._decref_socketios() # type: ignore[attr-defined] + + def _wrap_ssl_read(self, len: int, buffer: bytearray | None = None) -> int | bytes: + try: + return self._ssl_io_loop(self.sslobj.read, len, buffer) + except ssl.SSLError as e: + if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs: + return 0 # eof, return 0. + else: + raise + + # func is sslobj.do_handshake or sslobj.unwrap + @typing.overload + def _ssl_io_loop(self, func: typing.Callable[[], None]) -> None: ... + + # func is sslobj.write, arg1 is data + @typing.overload + def _ssl_io_loop(self, func: typing.Callable[[bytes], int], arg1: bytes) -> int: ... + + # func is sslobj.read, arg1 is len, arg2 is buffer + @typing.overload + def _ssl_io_loop( + self, + func: typing.Callable[[int, bytearray | None], bytes], + arg1: int, + arg2: bytearray | None, + ) -> bytes: ... 
+ + def _ssl_io_loop( + self, + func: typing.Callable[..., _ReturnValue], + arg1: None | bytes | int = None, + arg2: bytearray | None = None, + ) -> _ReturnValue: + """Performs an I/O loop between incoming/outgoing and the socket.""" + should_loop = True + ret = None + + while should_loop: + errno = None + try: + if arg1 is None and arg2 is None: + ret = func() + elif arg2 is None: + ret = func(arg1) + else: + ret = func(arg1, arg2) + except ssl.SSLError as e: + if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): + # WANT_READ, and WANT_WRITE are expected, others are not. + raise e + errno = e.errno + + buf = self.outgoing.read() + self.socket.sendall(buf) + + if errno is None: + should_loop = False + elif errno == ssl.SSL_ERROR_WANT_READ: + buf = self.socket.recv(SSL_BLOCKSIZE) + if buf: + self.incoming.write(buf) + else: + self.incoming.write_eof() + return typing.cast(_ReturnValue, ret) diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/timeout.py b/.venv/lib/python3.13/site-packages/urllib3/util/timeout.py new file mode 100644 index 0000000000000000000000000000000000000000..4bb1be11d9cb06900dd82ecebd06aa6a7c5de916 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/util/timeout.py @@ -0,0 +1,275 @@ +from __future__ import annotations + +import time +import typing +from enum import Enum +from socket import getdefaulttimeout + +from ..exceptions import TimeoutStateError + +if typing.TYPE_CHECKING: + from typing import Final + + +class _TYPE_DEFAULT(Enum): + # This value should never be passed to socket.settimeout() so for safety we use a -1. + # socket.settimout() raises a ValueError for negative values. + token = -1 + + +_DEFAULT_TIMEOUT: Final[_TYPE_DEFAULT] = _TYPE_DEFAULT.token + +_TYPE_TIMEOUT = typing.Optional[typing.Union[float, _TYPE_DEFAULT]] + + +class Timeout: + """Timeout configuration. + + Timeouts can be defined as a default for a pool: + + .. 
code-block:: python + + import urllib3 + + timeout = urllib3.util.Timeout(connect=2.0, read=7.0) + + http = urllib3.PoolManager(timeout=timeout) + + resp = http.request("GET", "https://example.com/") + + print(resp.status) + + Or per-request (which overrides the default for the pool): + + .. code-block:: python + + response = http.request("GET", "https://example.com/", timeout=Timeout(10)) + + Timeouts can be disabled by setting all the parameters to ``None``: + + .. code-block:: python + + no_timeout = Timeout(connect=None, read=None) + response = http.request("GET", "https://example.com/", timeout=no_timeout) + + + :param total: + This combines the connect and read timeouts into one; the read timeout + will be set to the time leftover from the connect attempt. In the + event that both a connect timeout and a total are specified, or a read + timeout and a total are specified, the shorter timeout will be applied. + + Defaults to None. + + :type total: int, float, or None + + :param connect: + The maximum amount of time (in seconds) to wait for a connection + attempt to a server to succeed. Omitting the parameter will default the + connect timeout to the system default, probably `the global default + timeout in socket.py + `_. + None will set an infinite timeout for connection attempts. + + :type connect: int, float, or None + + :param read: + The maximum amount of time (in seconds) to wait between consecutive + read operations for a response from the server. Omitting the parameter + will default the read timeout to the system default, probably `the + global default timeout in socket.py + `_. + None will set an infinite timeout. + + :type read: int, float, or None + + .. note:: + + Many factors can affect the total amount of time for urllib3 to return + an HTTP response. + + For example, Python's DNS resolver does not obey the timeout specified + on the socket. 
Other factors that can affect total request time include + high CPU load, high swap, the program running at a low priority level, + or other behaviors. + + In addition, the read and total timeouts only measure the time between + read operations on the socket connecting the client and the server, + not the total amount of time for the request to return a complete + response. For most requests, the timeout is raised because the server + has not sent the first byte in the specified time. This is not always + the case; if a server streams one byte every fifteen seconds, a timeout + of 20 seconds will not trigger, even though the request will take + several minutes to complete. + """ + + #: A sentinel object representing the default timeout value + DEFAULT_TIMEOUT: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT + + def __init__( + self, + total: _TYPE_TIMEOUT = None, + connect: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT, + read: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT, + ) -> None: + self._connect = self._validate_timeout(connect, "connect") + self._read = self._validate_timeout(read, "read") + self.total = self._validate_timeout(total, "total") + self._start_connect: float | None = None + + def __repr__(self) -> str: + return f"{type(self).__name__}(connect={self._connect!r}, read={self._read!r}, total={self.total!r})" + + # __str__ provided for backwards compatibility + __str__ = __repr__ + + @staticmethod + def resolve_default_timeout(timeout: _TYPE_TIMEOUT) -> float | None: + return getdefaulttimeout() if timeout is _DEFAULT_TIMEOUT else timeout + + @classmethod + def _validate_timeout(cls, value: _TYPE_TIMEOUT, name: str) -> _TYPE_TIMEOUT: + """Check that a timeout attribute is valid. + + :param value: The timeout value to validate + :param name: The name of the timeout attribute to validate. This is + used to specify in error messages. + :return: The validated and casted version of the given value. 
+ :raises ValueError: If it is a numeric value less than or equal to + zero, or the type is not an integer, float, or None. + """ + if value is None or value is _DEFAULT_TIMEOUT: + return value + + if isinstance(value, bool): + raise ValueError( + "Timeout cannot be a boolean value. It must " + "be an int, float or None." + ) + try: + float(value) + except (TypeError, ValueError): + raise ValueError( + "Timeout value %s was %s, but it must be an " + "int, float or None." % (name, value) + ) from None + + try: + if value <= 0: + raise ValueError( + "Attempted to set %s timeout to %s, but the " + "timeout cannot be set to a value less " + "than or equal to 0." % (name, value) + ) + except TypeError: + raise ValueError( + "Timeout value %s was %s, but it must be an " + "int, float or None." % (name, value) + ) from None + + return value + + @classmethod + def from_float(cls, timeout: _TYPE_TIMEOUT) -> Timeout: + """Create a new Timeout from a legacy timeout value. + + The timeout value used by httplib.py sets the same timeout on the + connect(), and recv() socket requests. This creates a :class:`Timeout` + object that sets the individual timeouts to the ``timeout`` value + passed to this function. + + :param timeout: The legacy timeout value. + :type timeout: integer, float, :attr:`urllib3.util.Timeout.DEFAULT_TIMEOUT`, or None + :return: Timeout object + :rtype: :class:`Timeout` + """ + return Timeout(read=timeout, connect=timeout) + + def clone(self) -> Timeout: + """Create a copy of the timeout object + + Timeout properties are stored per-pool but each request needs a fresh + Timeout object to ensure each one has its own start/stop configured. + + :return: a copy of the timeout object + :rtype: :class:`Timeout` + """ + # We can't use copy.deepcopy because that will also create a new object + # for _GLOBAL_DEFAULT_TIMEOUT, which socket.py uses as a sentinel to + # detect the user default. 
+ return Timeout(connect=self._connect, read=self._read, total=self.total) + + def start_connect(self) -> float: + """Start the timeout clock, used during a connect() attempt + + :raises urllib3.exceptions.TimeoutStateError: if you attempt + to start a timer that has been started already. + """ + if self._start_connect is not None: + raise TimeoutStateError("Timeout timer has already been started.") + self._start_connect = time.monotonic() + return self._start_connect + + def get_connect_duration(self) -> float: + """Gets the time elapsed since the call to :meth:`start_connect`. + + :return: Elapsed time in seconds. + :rtype: float + :raises urllib3.exceptions.TimeoutStateError: if you attempt + to get duration for a timer that hasn't been started. + """ + if self._start_connect is None: + raise TimeoutStateError( + "Can't get connect duration for timer that has not started." + ) + return time.monotonic() - self._start_connect + + @property + def connect_timeout(self) -> _TYPE_TIMEOUT: + """Get the value to use when setting a connection timeout. + + This will be a positive float or integer, the value None + (never timeout), or the default system timeout. + + :return: Connect timeout. + :rtype: int, float, :attr:`Timeout.DEFAULT_TIMEOUT` or None + """ + if self.total is None: + return self._connect + + if self._connect is None or self._connect is _DEFAULT_TIMEOUT: + return self.total + + return min(self._connect, self.total) # type: ignore[type-var] + + @property + def read_timeout(self) -> float | None: + """Get the value for the read timeout. + + This assumes some time has elapsed in the connection timeout and + computes the read timeout appropriately. + + If self.total is set, the read timeout is dependent on the amount of + time taken by the connect timeout. If the connection time has not been + established, a :exc:`~urllib3.exceptions.TimeoutStateError` will be + raised. + + :return: Value to use for the read timeout. 
+ :rtype: int, float or None + :raises urllib3.exceptions.TimeoutStateError: If :meth:`start_connect` + has not yet been called on this object. + """ + if ( + self.total is not None + and self.total is not _DEFAULT_TIMEOUT + and self._read is not None + and self._read is not _DEFAULT_TIMEOUT + ): + # In case the connect timeout has not yet been established. + if self._start_connect is None: + return self._read + return max(0, min(self.total - self.get_connect_duration(), self._read)) + elif self.total is not None and self.total is not _DEFAULT_TIMEOUT: + return max(0, self.total - self.get_connect_duration()) + else: + return self.resolve_default_timeout(self._read) diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/url.py b/.venv/lib/python3.13/site-packages/urllib3/util/url.py new file mode 100644 index 0000000000000000000000000000000000000000..db057f17be610174f30928748b5004dcbf6c501c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/util/url.py @@ -0,0 +1,469 @@ +from __future__ import annotations + +import re +import typing + +from ..exceptions import LocationParseError +from .util import to_str + +# We only want to normalize urls with an HTTP(S) scheme. +# urllib3 infers URLs without a scheme (None) to be http. +_NORMALIZABLE_SCHEMES = ("http", "https", None) + +# Almost all of these patterns were derived from the +# 'rfc3986' module: https://github.com/python-hyper/rfc3986 +_PERCENT_RE = re.compile(r"%[a-fA-F0-9]{2}") +_SCHEME_RE = re.compile(r"^(?:[a-zA-Z][a-zA-Z0-9+-]*:|/)") +_URI_RE = re.compile( + r"^(?:([a-zA-Z][a-zA-Z0-9+.-]*):)?" + r"(?://([^\\/?#]*))?" + r"([^?#]*)" + r"(?:\?([^#]*))?" 
+ r"(?:#(.*))?$", + re.UNICODE | re.DOTALL, +) + +_IPV4_PAT = r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}" +_HEX_PAT = "[0-9A-Fa-f]{1,4}" +_LS32_PAT = "(?:{hex}:{hex}|{ipv4})".format(hex=_HEX_PAT, ipv4=_IPV4_PAT) +_subs = {"hex": _HEX_PAT, "ls32": _LS32_PAT} +_variations = [ + # 6( h16 ":" ) ls32 + "(?:%(hex)s:){6}%(ls32)s", + # "::" 5( h16 ":" ) ls32 + "::(?:%(hex)s:){5}%(ls32)s", + # [ h16 ] "::" 4( h16 ":" ) ls32 + "(?:%(hex)s)?::(?:%(hex)s:){4}%(ls32)s", + # [ *1( h16 ":" ) h16 ] "::" 3( h16 ":" ) ls32 + "(?:(?:%(hex)s:)?%(hex)s)?::(?:%(hex)s:){3}%(ls32)s", + # [ *2( h16 ":" ) h16 ] "::" 2( h16 ":" ) ls32 + "(?:(?:%(hex)s:){0,2}%(hex)s)?::(?:%(hex)s:){2}%(ls32)s", + # [ *3( h16 ":" ) h16 ] "::" h16 ":" ls32 + "(?:(?:%(hex)s:){0,3}%(hex)s)?::%(hex)s:%(ls32)s", + # [ *4( h16 ":" ) h16 ] "::" ls32 + "(?:(?:%(hex)s:){0,4}%(hex)s)?::%(ls32)s", + # [ *5( h16 ":" ) h16 ] "::" h16 + "(?:(?:%(hex)s:){0,5}%(hex)s)?::%(hex)s", + # [ *6( h16 ":" ) h16 ] "::" + "(?:(?:%(hex)s:){0,6}%(hex)s)?::", +] + +_UNRESERVED_PAT = r"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._\-~" +_IPV6_PAT = "(?:" + "|".join([x % _subs for x in _variations]) + ")" +_ZONE_ID_PAT = "(?:%25|%)(?:[" + _UNRESERVED_PAT + "]|%[a-fA-F0-9]{2})+" +_IPV6_ADDRZ_PAT = r"\[" + _IPV6_PAT + r"(?:" + _ZONE_ID_PAT + r")?\]" +_REG_NAME_PAT = r"(?:[^\[\]%:/?#]|%[a-fA-F0-9]{2})*" +_TARGET_RE = re.compile(r"^(/[^?#]*)(?:\?([^#]*))?(?:#.*)?$") + +_IPV4_RE = re.compile("^" + _IPV4_PAT + "$") +_IPV6_RE = re.compile("^" + _IPV6_PAT + "$") +_IPV6_ADDRZ_RE = re.compile("^" + _IPV6_ADDRZ_PAT + "$") +_BRACELESS_IPV6_ADDRZ_RE = re.compile("^" + _IPV6_ADDRZ_PAT[2:-2] + "$") +_ZONE_ID_RE = re.compile("(" + _ZONE_ID_PAT + r")\]$") + +_HOST_PORT_PAT = ("^(%s|%s|%s)(?::0*?(|0|[1-9][0-9]{0,4}))?$") % ( + _REG_NAME_PAT, + _IPV4_PAT, + _IPV6_ADDRZ_PAT, +) +_HOST_PORT_RE = re.compile(_HOST_PORT_PAT, re.UNICODE | re.DOTALL) + +_UNRESERVED_CHARS = set( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._-~" +) 
+_SUB_DELIM_CHARS = set("!$&'()*+,;=") +_USERINFO_CHARS = _UNRESERVED_CHARS | _SUB_DELIM_CHARS | {":"} +_PATH_CHARS = _USERINFO_CHARS | {"@", "/"} +_QUERY_CHARS = _FRAGMENT_CHARS = _PATH_CHARS | {"?"} + + +class Url( + typing.NamedTuple( + "Url", + [ + ("scheme", typing.Optional[str]), + ("auth", typing.Optional[str]), + ("host", typing.Optional[str]), + ("port", typing.Optional[int]), + ("path", typing.Optional[str]), + ("query", typing.Optional[str]), + ("fragment", typing.Optional[str]), + ], + ) +): + """ + Data structure for representing an HTTP URL. Used as a return value for + :func:`parse_url`. Both the scheme and host are normalized as they are + both case-insensitive according to RFC 3986. + """ + + def __new__( # type: ignore[no-untyped-def] + cls, + scheme: str | None = None, + auth: str | None = None, + host: str | None = None, + port: int | None = None, + path: str | None = None, + query: str | None = None, + fragment: str | None = None, + ): + if path and not path.startswith("/"): + path = "/" + path + if scheme is not None: + scheme = scheme.lower() + return super().__new__(cls, scheme, auth, host, port, path, query, fragment) + + @property + def hostname(self) -> str | None: + """For backwards-compatibility with urlparse. We're nice like that.""" + return self.host + + @property + def request_uri(self) -> str: + """Absolute path including the query string.""" + uri = self.path or "/" + + if self.query is not None: + uri += "?" + self.query + + return uri + + @property + def authority(self) -> str | None: + """ + Authority component as defined in RFC 3986 3.2. + This includes userinfo (auth), host and port. + + i.e. + userinfo@host:port + """ + userinfo = self.auth + netloc = self.netloc + if netloc is None or userinfo is None: + return netloc + else: + return f"{userinfo}@{netloc}" + + @property + def netloc(self) -> str | None: + """ + Network location including host and port. 
+ + If you need the equivalent of urllib.parse's ``netloc``, + use the ``authority`` property instead. + """ + if self.host is None: + return None + if self.port: + return f"{self.host}:{self.port}" + return self.host + + @property + def url(self) -> str: + """ + Convert self into a url + + This function should more or less round-trip with :func:`.parse_url`. The + returned url may not be exactly the same as the url inputted to + :func:`.parse_url`, but it should be equivalent by the RFC (e.g., urls + with a blank port will have : removed). + + Example: + + .. code-block:: python + + import urllib3 + + U = urllib3.util.parse_url("https://google.com/mail/") + + print(U.url) + # "https://google.com/mail/" + + print( urllib3.util.Url("https", "username:password", + "host.com", 80, "/path", "query", "fragment" + ).url + ) + # "https://username:password@host.com:80/path?query#fragment" + """ + scheme, auth, host, port, path, query, fragment = self + url = "" + + # We use "is not None" we want things to happen with empty strings (or 0 port) + if scheme is not None: + url += scheme + "://" + if auth is not None: + url += auth + "@" + if host is not None: + url += host + if port is not None: + url += ":" + str(port) + if path is not None: + url += path + if query is not None: + url += "?" + query + if fragment is not None: + url += "#" + fragment + + return url + + def __str__(self) -> str: + return self.url + + +@typing.overload +def _encode_invalid_chars( + component: str, allowed_chars: typing.Container[str] +) -> str: # Abstract + ... + + +@typing.overload +def _encode_invalid_chars( + component: None, allowed_chars: typing.Container[str] +) -> None: # Abstract + ... + + +def _encode_invalid_chars( + component: str | None, allowed_chars: typing.Container[str] +) -> str | None: + """Percent-encodes a URI component without reapplying + onto an already percent-encoded component. 
+ """ + if component is None: + return component + + component = to_str(component) + + # Normalize existing percent-encoded bytes. + # Try to see if the component we're encoding is already percent-encoded + # so we can skip all '%' characters but still encode all others. + component, percent_encodings = _PERCENT_RE.subn( + lambda match: match.group(0).upper(), component + ) + + uri_bytes = component.encode("utf-8", "surrogatepass") + is_percent_encoded = percent_encodings == uri_bytes.count(b"%") + encoded_component = bytearray() + + for i in range(0, len(uri_bytes)): + # Will return a single character bytestring + byte = uri_bytes[i : i + 1] + byte_ord = ord(byte) + if (is_percent_encoded and byte == b"%") or ( + byte_ord < 128 and byte.decode() in allowed_chars + ): + encoded_component += byte + continue + encoded_component.extend(b"%" + (hex(byte_ord)[2:].encode().zfill(2).upper())) + + return encoded_component.decode() + + +def _remove_path_dot_segments(path: str) -> str: + # See http://tools.ietf.org/html/rfc3986#section-5.2.4 for pseudo-code + segments = path.split("/") # Turn the path into a list of segments + output = [] # Initialize the variable to use to store output + + for segment in segments: + # '.' is the current directory, so ignore it, it is superfluous + if segment == ".": + continue + # Anything other than '..', should be appended to the output + if segment != "..": + output.append(segment) + # In this case segment == '..', if we can, we should pop the last + # element + elif output: + output.pop() + + # If the path starts with '/' and the output is empty or the first string + # is non-empty + if path.startswith("/") and (not output or output[0]): + output.insert(0, "") + + # If the path starts with '/.' or '/..' ensure we add one more empty + # string to add a trailing '/' + if path.endswith(("/.", "/..")): + output.append("") + + return "/".join(output) + + +@typing.overload +def _normalize_host(host: None, scheme: str | None) -> None: ... 
+ + +@typing.overload +def _normalize_host(host: str, scheme: str | None) -> str: ... + + +def _normalize_host(host: str | None, scheme: str | None) -> str | None: + if host: + if scheme in _NORMALIZABLE_SCHEMES: + is_ipv6 = _IPV6_ADDRZ_RE.match(host) + if is_ipv6: + # IPv6 hosts of the form 'a::b%zone' are encoded in a URL as + # such per RFC 6874: 'a::b%25zone'. Unquote the ZoneID + # separator as necessary to return a valid RFC 4007 scoped IP. + match = _ZONE_ID_RE.search(host) + if match: + start, end = match.span(1) + zone_id = host[start:end] + + if zone_id.startswith("%25") and zone_id != "%25": + zone_id = zone_id[3:] + else: + zone_id = zone_id[1:] + zone_id = _encode_invalid_chars(zone_id, _UNRESERVED_CHARS) + return f"{host[:start].lower()}%{zone_id}{host[end:]}" + else: + return host.lower() + elif not _IPV4_RE.match(host): + return to_str( + b".".join([_idna_encode(label) for label in host.split(".")]), + "ascii", + ) + return host + + +def _idna_encode(name: str) -> bytes: + if not name.isascii(): + try: + import idna + except ImportError: + raise LocationParseError( + "Unable to parse URL without the 'idna' module" + ) from None + + try: + return idna.encode(name.lower(), strict=True, std3_rules=True) + except idna.IDNAError: + raise LocationParseError( + f"Name '{name}' is not a valid IDNA label" + ) from None + + return name.lower().encode("ascii") + + +def _encode_target(target: str) -> str: + """Percent-encodes a request target so that there are no invalid characters + + Pre-condition for this function is that 'target' must start with '/'. + If that is the case then _TARGET_RE will always produce a match. + """ + match = _TARGET_RE.match(target) + if not match: # Defensive: + raise LocationParseError(f"{target!r} is not a valid request URI") + + path, query = match.groups() + encoded_target = _encode_invalid_chars(path, _PATH_CHARS) + if query is not None: + query = _encode_invalid_chars(query, _QUERY_CHARS) + encoded_target += "?" 
+ query + return encoded_target + + +def parse_url(url: str) -> Url: + """ + Given a url, return a parsed :class:`.Url` namedtuple. Best-effort is + performed to parse incomplete urls. Fields not provided will be None. + This parser is RFC 3986 and RFC 6874 compliant. + + The parser logic and helper functions are based heavily on + work done in the ``rfc3986`` module. + + :param str url: URL to parse into a :class:`.Url` namedtuple. + + Partly backwards-compatible with :mod:`urllib.parse`. + + Example: + + .. code-block:: python + + import urllib3 + + print( urllib3.util.parse_url('http://google.com/mail/')) + # Url(scheme='http', host='google.com', port=None, path='/mail/', ...) + + print( urllib3.util.parse_url('google.com:80')) + # Url(scheme=None, host='google.com', port=80, path=None, ...) + + print( urllib3.util.parse_url('/foo?bar')) + # Url(scheme=None, host=None, port=None, path='/foo', query='bar', ...) + """ + if not url: + # Empty + return Url() + + source_url = url + if not _SCHEME_RE.search(url): + url = "//" + url + + scheme: str | None + authority: str | None + auth: str | None + host: str | None + port: str | None + port_int: int | None + path: str | None + query: str | None + fragment: str | None + + try: + scheme, authority, path, query, fragment = _URI_RE.match(url).groups() # type: ignore[union-attr] + normalize_uri = scheme is None or scheme.lower() in _NORMALIZABLE_SCHEMES + + if scheme: + scheme = scheme.lower() + + if authority: + auth, _, host_port = authority.rpartition("@") + auth = auth or None + host, port = _HOST_PORT_RE.match(host_port).groups() # type: ignore[union-attr] + if auth and normalize_uri: + auth = _encode_invalid_chars(auth, _USERINFO_CHARS) + if port == "": + port = None + else: + auth, host, port = None, None, None + + if port is not None: + port_int = int(port) + if not (0 <= port_int <= 65535): + raise LocationParseError(url) + else: + port_int = None + + host = _normalize_host(host, scheme) + + if normalize_uri and 
path: + path = _remove_path_dot_segments(path) + path = _encode_invalid_chars(path, _PATH_CHARS) + if normalize_uri and query: + query = _encode_invalid_chars(query, _QUERY_CHARS) + if normalize_uri and fragment: + fragment = _encode_invalid_chars(fragment, _FRAGMENT_CHARS) + + except (ValueError, AttributeError) as e: + raise LocationParseError(source_url) from e + + # For the sake of backwards compatibility we put empty + # string values for path if there are any defined values + # beyond the path in the URL. + # TODO: Remove this when we break backwards compatibility. + if not path: + if query is not None or fragment is not None: + path = "" + else: + path = None + + return Url( + scheme=scheme, + auth=auth, + host=host, + port=port_int, + path=path, + query=query, + fragment=fragment, + ) diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/util.py b/.venv/lib/python3.13/site-packages/urllib3/util/util.py new file mode 100644 index 0000000000000000000000000000000000000000..35c77e4025842f548565334a3c04cba90f9283d6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/util/util.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import typing +from types import TracebackType + + +def to_bytes( + x: str | bytes, encoding: str | None = None, errors: str | None = None +) -> bytes: + if isinstance(x, bytes): + return x + elif not isinstance(x, str): + raise TypeError(f"not expecting type {type(x).__name__}") + if encoding or errors: + return x.encode(encoding or "utf-8", errors=errors or "strict") + return x.encode() + + +def to_str( + x: str | bytes, encoding: str | None = None, errors: str | None = None +) -> str: + if isinstance(x, str): + return x + elif not isinstance(x, bytes): + raise TypeError(f"not expecting type {type(x).__name__}") + if encoding or errors: + return x.decode(encoding or "utf-8", errors=errors or "strict") + return x.decode() + + +def reraise( + tp: type[BaseException] | None, + value: BaseException, + tb: 
TracebackType | None = None, +) -> typing.NoReturn: + try: + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + finally: + value = None # type: ignore[assignment] + tb = None diff --git a/.venv/lib/python3.13/site-packages/urllib3/util/wait.py b/.venv/lib/python3.13/site-packages/urllib3/util/wait.py new file mode 100644 index 0000000000000000000000000000000000000000..aeca0c7ad5b232eeb1ad9c43d315bd1d74eaed9a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/urllib3/util/wait.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import select +import socket +from functools import partial + +__all__ = ["wait_for_read", "wait_for_write"] + + +# How should we wait on sockets? +# +# There are two types of APIs you can use for waiting on sockets: the fancy +# modern stateful APIs like epoll/kqueue, and the older stateless APIs like +# select/poll. The stateful APIs are more efficient when you have a lots of +# sockets to keep track of, because you can set them up once and then use them +# lots of times. But we only ever want to wait on a single socket at a time +# and don't want to keep track of state, so the stateless APIs are actually +# more efficient. So we want to use select() or poll(). +# +# Now, how do we choose between select() and poll()? On traditional Unixes, +# select() has a strange calling convention that makes it slow, or fail +# altogether, for high-numbered file descriptors. The point of poll() is to fix +# that, so on Unixes, we prefer poll(). +# +# On Windows, there is no poll() (or at least Python doesn't provide a wrapper +# for it), but that's OK, because on Windows, select() doesn't have this +# strange calling convention; plain select() works fine. +# +# So: on Windows we use select(), and everywhere else we use poll(). We also +# fall back to select() in case poll() is somehow broken or missing. 
+ + +def select_wait_for_socket( + sock: socket.socket, + read: bool = False, + write: bool = False, + timeout: float | None = None, +) -> bool: + if not read and not write: + raise RuntimeError("must specify at least one of read=True, write=True") + rcheck = [] + wcheck = [] + if read: + rcheck.append(sock) + if write: + wcheck.append(sock) + # When doing a non-blocking connect, most systems signal success by + # marking the socket writable. Windows, though, signals success by marked + # it as "exceptional". We paper over the difference by checking the write + # sockets for both conditions. (The stdlib selectors module does the same + # thing.) + fn = partial(select.select, rcheck, wcheck, wcheck) + rready, wready, xready = fn(timeout) + return bool(rready or wready or xready) + + +def poll_wait_for_socket( + sock: socket.socket, + read: bool = False, + write: bool = False, + timeout: float | None = None, +) -> bool: + if not read and not write: + raise RuntimeError("must specify at least one of read=True, write=True") + mask = 0 + if read: + mask |= select.POLLIN + if write: + mask |= select.POLLOUT + poll_obj = select.poll() + poll_obj.register(sock, mask) + + # For some reason, poll() takes timeout in milliseconds + def do_poll(t: float | None) -> list[tuple[int, int]]: + if t is not None: + t *= 1000 + return poll_obj.poll(t) + + return bool(do_poll(timeout)) + + +def _have_working_poll() -> bool: + # Apparently some systems have a select.poll that fails as soon as you try + # to use it, either due to strange configuration or broken monkeypatching + # from libraries like eventlet/greenlet. + try: + poll_obj = select.poll() + poll_obj.poll(0) + except (AttributeError, OSError): + return False + else: + return True + + +def wait_for_socket( + sock: socket.socket, + read: bool = False, + write: bool = False, + timeout: float | None = None, +) -> bool: + # We delay choosing which implementation to use until the first time we're + # called. 
We could do it at import time, but then we might make the wrong + # decision if someone goes wild with monkeypatching select.poll after + # we're imported. + global wait_for_socket + if _have_working_poll(): + wait_for_socket = poll_wait_for_socket + elif hasattr(select, "select"): + wait_for_socket = select_wait_for_socket + return wait_for_socket(sock, read, write, timeout) + + +def wait_for_read(sock: socket.socket, timeout: float | None = None) -> bool: + """Waits for reading to be available on a given socket. + Returns True if the socket is readable, or False if the timeout expired. + """ + return wait_for_socket(sock, read=True, timeout=timeout) + + +def wait_for_write(sock: socket.socket, timeout: float | None = None) -> bool: + """Waits for writing to be available on a given socket. + Returns True if the socket is readable, or False if the timeout expired. + """ + return wait_for_socket(sock, write=True, timeout=timeout) diff --git a/.venv/lib/python3.13/site-packages/wandb/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1de19635c14b25480b693351592349580fe3070 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/__pycache__/_analytics.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/__pycache__/_analytics.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16cf71de44a7e2b2ef73712b1ed01acd8cf0ed97 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/__pycache__/_analytics.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/__pycache__/_iterutils.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/__pycache__/_iterutils.cpython-313.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..fb7d5e5fc28a6a8e19dfb936fa2c09add873a539 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/__pycache__/_iterutils.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/__pycache__/_strutils.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/__pycache__/_strutils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5474cd8322ee94f57f29b9978fbfe54210fd20d Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/__pycache__/_strutils.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/__pycache__/data_types.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/__pycache__/data_types.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c0b8aaafea67c78d8ef6c2b6230fd9b8a04082b Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/__pycache__/data_types.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/__pycache__/env.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/__pycache__/env.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0fe88a1e845c570f908e29efec87e531cca9093 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/__pycache__/env.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/__pycache__/trigger.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/__pycache__/trigger.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e0e5446319909a806111f85b2d533cf452c1813 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/__pycache__/trigger.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/__pycache__/util.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/__pycache__/util.cpython-313.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..e45b35e852f67b45b39f448e6d85f034227a9bbd Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/__pycache__/util.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/__pycache__/wandb_agent.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/__pycache__/wandb_agent.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f1e78c25be36f8982fe21c16ffe1c04d6413257 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/__pycache__/wandb_agent.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/_pydantic/__init__.py b/.venv/lib/python3.13/site-packages/wandb/_pydantic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1709cf5e268c76398fc808c55195d9c26bdf59ed --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/_pydantic/__init__.py @@ -0,0 +1,41 @@ +"""Internal utilities for working with pydantic.""" + +__all__ = [ + "IS_PYDANTIC_V2", + "CompatBaseModel", + "JsonableModel", + "GQLBase", + "GQLInput", + "GQLResult", + "Connection", + "ConnectionWithTotal", + "Edge", + "PageInfo", + "Typename", + "GQLId", + "AliasChoices", + "computed_field", + "field_validator", + "model_validator", + "pydantic_isinstance", + "to_camel", + "to_json", + "from_json", + "gql_typename", + "ValidationError", +] + +# Available in all supported Pydantic versions. 
+from pydantic import ValidationError + +from .base import CompatBaseModel, GQLBase, GQLInput, GQLResult, JsonableModel +from .field_types import GQLId, Typename +from .pagination import Connection, ConnectionWithTotal, Edge, PageInfo +from .utils import IS_PYDANTIC_V2, from_json, gql_typename, pydantic_isinstance, to_json +from .v1_compat import ( + AliasChoices, + computed_field, + field_validator, + model_validator, + to_camel, +) diff --git a/.venv/lib/python3.13/site-packages/wandb/_pydantic/base.py b/.venv/lib/python3.13/site-packages/wandb/_pydantic/base.py new file mode 100644 index 0000000000000000000000000000000000000000..373a008a856165f62571286420bae4e6b3d309f0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/_pydantic/base.py @@ -0,0 +1,151 @@ +"""Base classes and other customizations for generated pydantic types.""" + +# Older-style type annotations required for Pydantic v1 / python 3.8 compatibility. +# ruff: noqa: UP006 + +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Literal, overload + +from pydantic import BaseModel, ConfigDict +from typing_extensions import TypedDict, Unpack, override + +from .v1_compat import PydanticCompatMixin, to_camel + +if TYPE_CHECKING: + from pydantic.main import IncEx + + +class ModelDumpKwargs(TypedDict, total=False): + """Shared keyword arguments for `BaseModel.model_{dump,dump_json}`. + + Newer pydantic versions may accept more arguments than are listed here. + Last updated for pydantic v2.12.0. 
+ """ + + include: IncEx | None + exclude: IncEx | None + context: Any | None + by_alias: bool | None + exclude_unset: bool + exclude_defaults: bool + exclude_none: bool + exclude_computed_fields: bool + round_trip: bool + warnings: bool | Literal["none", "warn", "error"] + fallback: Callable[[Any], Any] | None + serialize_as_any: bool + + +# --------------------------------------------------------------------------- +# Base models and mixin classes. +# +# Extra info is provided for devs in inline comments, NOT docstrings. This +# prevents it from showing up in generated docs for subclasses. + + +# FOR INTERNAL USE ONLY: v1-compatible drop-in replacement for `pydantic.BaseModel`. +# If pydantic v2 is detected, this is just `pydantic.BaseModel`. +# +# Deliberately inherits ALL default configuration from `pydantic.BaseModel`. +class CompatBaseModel(PydanticCompatMixin, BaseModel): + __doc__ = None # Prevent subclasses from inheriting the BaseModel docstring + + +class JsonableModel(CompatBaseModel, ABC): + # Base class with sensible defaults for converting to and from JSON. + # + # Automatically parse or serialize "raw" API data (e.g. convert to and from + # camelCase keys): + # - `.model_{dump,dump_json}()` should return JSON-ready dicts or JSON + # strings. + # - `.model_{validate,validate_json}()` should accept JSON-ready dicts or + # JSON strings. + # + # Ensure round-trip serialization <-> deserialization between: + # - `model_dump()` <-> `model_validate()` + # - `model_dump_json()` <-> `model_validate_json()` + # + # These behaviors help models predictably handle GraphQL request or response + # data. + + model_config = ConfigDict( + # --------------------------------------------------------------------------- + # Discouraged in v2.11+, deprecated in v3. Kept here for compatibility. 
+ populate_by_name=True, + # --------------------------------------------------------------------------- + # Introduced in v2.11, ignored in earlier versions + validate_by_name=True, + validate_by_alias=True, + serialize_by_alias=True, + # --------------------------------------------------------------------------- + validate_assignment=True, + use_attribute_docstrings=True, + from_attributes=True, + ) + + # Custom default kwargs for `JsonableModel.model_{dump,dump_json}`: + # - by_alias: Convert keys to JSON-ready names and objects to JSON-ready + # dicts. + # - round_trip: Ensure the result can round-trip. + __DUMP_DEFAULTS: ClassVar[Dict[str, Any]] = dict(by_alias=True, round_trip=True) + + @overload # Actual signature + def model_dump( + self, *, mode: str, **kwargs: Unpack[ModelDumpKwargs] + ) -> dict[str, Any]: ... + @overload # In case pydantic adds more kwargs in future releases + def model_dump(self, **kwargs: Any) -> dict[str, Any]: ... + + @override + def model_dump(self, *, mode: str = "json", **kwargs: Any) -> dict[str, Any]: + kwargs = {**self.__DUMP_DEFAULTS, **kwargs} # allows overrides, if needed + return super().model_dump(mode=mode, **kwargs) + + @overload # Actual signature + def model_dump_json( + self, *, indent: int | None, **kwargs: Unpack[ModelDumpKwargs] + ) -> str: ... + @overload # In case pydantic adds more kwargs in future releases + def model_dump_json(self, **kwargs: Any) -> str: ... + + @override + def model_dump_json(self, *, indent: int | None = None, **kwargs: Any) -> str: + kwargs = {**self.__DUMP_DEFAULTS, **kwargs} # allows overrides, if needed + return super().model_dump_json(indent=indent, **kwargs) + + +# Base class for all GraphQL-derived types. +class GQLBase(JsonableModel, ABC): + model_config = ConfigDict( + validate_default=True, + revalidate_instances="always", + protected_namespaces=(), # Some GraphQL fields may begin with "model_" + ) + + +# Base class for GraphQL result types, i.e. parsed GraphQL response data. 
+class GQLResult(GQLBase, ABC): + model_config = ConfigDict( + alias_generator=to_camel, # Assume JSON names are camelCase, by default + frozen=True, # Keep the actual response data immutable + ) + + +# Base class for GraphQL input types, i.e. prepared variables or input objects +# for queries and mutations. +class GQLInput(GQLBase, ABC): + # For GraphQL inputs, exclude null values when preparing JSON-able request + # data. + __DUMP_DEFAULTS: ClassVar[Dict[str, Any]] = dict(exclude_none=True) + + @override + def model_dump(self, *, mode: str = "json", **kwargs: Any) -> dict[str, Any]: + kwargs = {**self.__DUMP_DEFAULTS, **kwargs} + return super().model_dump(mode=mode, **kwargs) + + @override + def model_dump_json(self, *, indent: int | None = None, **kwargs: Any) -> str: + kwargs = {**self.__DUMP_DEFAULTS, **kwargs} + return super().model_dump_json(indent=indent, **kwargs) diff --git a/.venv/lib/python3.13/site-packages/wandb/agents/__init__.py b/.venv/lib/python3.13/site-packages/wandb/agents/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/analytics/__init__.py b/.venv/lib/python3.13/site-packages/wandb/analytics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d18e7f7cb8d200e9fb7295de0a1be8c417ea1c99 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/analytics/__init__.py @@ -0,0 +1,3 @@ +__all__ = ("get_sentry",) + +from .sentry import get_sentry diff --git a/.venv/lib/python3.13/site-packages/wandb/analytics/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/analytics/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e3f10c4bb27cda2a45c9accdc8e2b839229745d Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/analytics/__pycache__/__init__.cpython-313.pyc differ diff --git 
a/.venv/lib/python3.13/site-packages/wandb/analytics/__pycache__/sentry.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/analytics/__pycache__/sentry.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..120378a6d869e1eb9395bb7c953873f47bc7c385 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/analytics/__pycache__/sentry.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/analytics/sentry.py b/.venv/lib/python3.13/site-packages/wandb/analytics/sentry.py new file mode 100644 index 0000000000000000000000000000000000000000..51dd327e6f34cc002809166b2ee9576e46090251 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/analytics/sentry.py @@ -0,0 +1,305 @@ +from __future__ import annotations + +import atexit +import contextlib +import functools +import os +import pathlib +import sys +import threading +from types import TracebackType +from typing import Any, Callable, Literal, TypeVar +from urllib.parse import quote + +from typing_extensions import Concatenate, Never, ParamSpec + +_P = ParamSpec("_P") +_T = TypeVar("_T") + +SENTRY_DEFAULT_DSN = ( + "https://2592b1968ea94cca9b5ef5e348e094a7@o151352.ingest.sentry.io/4504800232407040" +) + +SessionStatus = Literal["ok", "exited", "crashed", "abnormal"] + + +def _guard( + method: Callable[Concatenate[Sentry, _P], _T], +) -> Callable[Concatenate[Sentry, _P], _T | None]: + """Make a Sentry method safe, lazy, and non-raising. + + The wrapped method becomes a no-op if Sentry is disabled, + this instance belongs to a different PID, or lazy boot fails + """ + + @functools.wraps(method) + def wrapper( + self: Sentry, + *args: _P.args, + **kwargs: _P.kwargs, + ) -> _T | None: + if not self._enabled: + return None + + # If this instance belongs to a different process (fork happened), + # do nothing; get_sentry() will create a fresh instance for the child. 
+ if self._pid != os.getpid(): + return None + + if not self._booted and not self._boot(): + return None + + try: + return method(self, *args, **kwargs) + except Exception as e: + if method.__name__ != "exception": + # Best-effort logging of wrapper-level failures. + with contextlib.suppress(Exception): + self.exception(f"Error in {method.__name__}: {e}") + return None + + return wrapper + + +class Sentry: + def __init__(self, *, pid: int) -> None: + from wandb import env as _env + + self._pid: int = pid + self._enabled: bool = bool(_env.error_reporting_enabled()) + self._booted: bool = False + self._boot_lock = threading.Lock() + self._atexit_registered: bool = False + + self._sent_messages: set[str] = set() + self._sdk: Any | None = None # will hold the sentry_sdk module after boot + self.scope: Any | None = None + + self.dsn: str | None = os.environ.get(_env.SENTRY_DSN, SENTRY_DEFAULT_DSN) + + @property + def environment(self) -> str: + is_git = pathlib.Path(__file__).parent.parent.parent.joinpath(".git").exists() + return "development" if is_git else "production" + + def _boot(self) -> bool: + """Import sentry_sdk and set up client/scope.""" + from wandb import __version__ + + with self._boot_lock: + if not self._enabled: + return False + + if self._booted: + return True + + try: + import sentry_sdk # type: ignore + import sentry_sdk.scope # type: ignore + import sentry_sdk.utils # type: ignore + + self._sdk = sentry_sdk + + client = self._sdk.Client( + dsn=self.dsn, + default_integrations=False, + environment=self.environment, + release=__version__, + ) + scope = self._sdk.get_global_scope().fork() + scope.clear() + scope.set_client(client) + + self.scope = scope + self._booted = True + + if not self._atexit_registered: + atexit.register(self.end_session) + self._atexit_registered = True + + except Exception: + # Disable on any failure. 
+ self._enabled = False + self._booted = False + self._sdk = None + self.scope = None + + return False + + return True + + @_guard + def message( + self, + message: str, + repeat: bool = True, + level: str = "info", + ) -> str | None: + if not repeat and message in self._sent_messages: + return None + self._sent_messages.add(message) + with self._sdk.scope.use_isolation_scope(self.scope): # type: ignore + return self._sdk.capture_message(message, level=level) # type: ignore + + @_guard + def exception( + self, + exc: str + | BaseException + | tuple[ + type[BaseException] | None, + BaseException | None, + TracebackType | None, + ] + | None, + handled: bool = False, + status: SessionStatus | None = None, + ) -> str | None: + if isinstance(exc, str): + exc_info = self._sdk.utils.exc_info_from_error(Exception(exc)) # type: ignore + elif isinstance(exc, BaseException): + exc_info = self._sdk.utils.exc_info_from_error(exc) # type: ignore + else: + exc_info = sys.exc_info() + + event, _ = self._sdk.utils.event_from_exception( # type: ignore + exc_info, + client_options=self.scope.get_client().options, # type: ignore + mechanism={"type": "generic", "handled": handled}, + ) + event_id = None + with contextlib.suppress(Exception): + with self._sdk.scope.use_isolation_scope(self.scope): # type: ignore + event_id = self._sdk.capture_event(event) # type: ignore + + status = status or ("crashed" if not handled else "errored") # type: ignore + self.mark_session(status=status) + + client = self.scope.get_client() # type: ignore + if client is not None: + client.flush() + return event_id + + def reraise(self, exc: Any) -> Never: + """Re-raise after logging, preserving traceback. 
Safe if disabled.""" + try: + self.exception(exc) # @_guard applies here + finally: + _, _, tb = sys.exc_info() + if tb is not None and hasattr(exc, "with_traceback"): + raise exc.with_traceback(tb) + raise exc + + @_guard + def start_session(self) -> None: + if self.scope is None: + return + if self.scope._session is None: + self.scope.start_session() + + @_guard + def end_session(self) -> None: + if self.scope is None: + return + client = self.scope.get_client() + session = self.scope._session + if session is not None and client is not None: + self.scope.end_session() + client.flush() + + @_guard + def mark_session(self, status: SessionStatus | None = None) -> None: + if self.scope is None: + return + session = self.scope._session + if session is not None: + session.update(status=status) + + @_guard + def configure_scope( + self, + tags: dict[str, Any] | None = None, + process_context: str | None = None, + ) -> None: + import wandb.util + + if self.scope is None: + return + + settings_tags = ( + "entity", + "project", + "run_id", + "run_url", + "sweep_url", + "sweep_id", + "deployment", + "launch", + "_platform", + ) + + if process_context: + self.scope.set_tag("process_context", process_context) + + if tags is None: + return None + + for tag in settings_tags: + val = tags.get(tag, None) + if val not in (None, ""): + self.scope.set_tag(tag, val) + + if tags.get("_colab", None): + python_runtime = "colab" + elif tags.get("_jupyter", None): + python_runtime = "jupyter" + elif tags.get("_ipython", None): + python_runtime = "ipython" + else: + python_runtime = "python" + self.scope.set_tag("python_runtime", python_runtime) + + # Construct run_url and sweep_url given run_id and sweep_id. 
+ for obj in ("run", "sweep"): + obj_id, obj_url = f"{obj}_id", f"{obj}_url" + if tags.get(obj_url, None): + continue + try: + app_url = wandb.util.app_url(tags["base_url"]) # type: ignore[index] + entity, project = (quote(tags[k]) for k in ("entity", "project")) # type: ignore[index] + self.scope.set_tag( + obj_url, + f"{app_url}/{entity}/{project}/{obj}s/{tags[obj_id]}", + ) + except Exception: + pass + + email = tags.get("email") + if email: + self.scope.user = {"email": email} + + self.start_session() + + +_singleton: Sentry | None = None +_singleton_lock = threading.Lock() + + +def get_sentry() -> Sentry: + """Return the Sentry singleton for the current process (fork-aware). + + Creates a new instance in child processes after fork. + Thread-safe within each process. + """ + global _singleton + + pid = os.getpid() + + with _singleton_lock: + if _singleton is not None and _singleton._pid == pid: + return _singleton + + if _singleton is None or _singleton._pid != pid: + _singleton = Sentry(pid=pid) + + return _singleton diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/__init__.py b/.venv/lib/python3.13/site-packages/wandb/apis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..74e5509d29d30102f2d37586b707e7d0722fb34e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/__init__.py @@ -0,0 +1,50 @@ +"""api.""" + +from __future__ import annotations + +from typing import Callable + +import wandb +from wandb import env, util + + +def _disable_ssl() -> Callable[[], None]: + import requests + from urllib3.exceptions import InsecureRequestWarning + + # Because third party libraries may also use requests, we monkey patch it globally + # and turn off urllib3 warnings instead printing a global warning to the user. + wandb.termwarn( + "Disabling SSL verification. Connections to this server are not verified and may be insecure!" 
+ ) + + requests.packages.urllib3.disable_warnings(category=InsecureRequestWarning) + old_merge_environment_settings = requests.Session.merge_environment_settings + + def merge_environment_settings(self, url, proxies, stream, verify, cert): + settings = old_merge_environment_settings( + self, url, proxies, stream, verify, cert + ) + settings["verify"] = False + return settings + + requests.Session.merge_environment_settings = merge_environment_settings + + def reset(): + requests.Session.merge_environment_settings = old_merge_environment_settings + + return reset + + +if env.ssl_disabled(): + _disable_ssl() + + +reset_path = util.vendor_setup() + +from .internal import Api as InternalApi # noqa +from .public import Api as PublicApi # noqa + +reset_path() + +__all__ = ["InternalApi", "PublicApi"] diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..437a2a0428f4f232099a60c5541bc2400371ee36 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/__pycache__/attrs.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/__pycache__/attrs.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a927aa0e9f75bd6a959df4f806d413fbd8ae3682 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/__pycache__/attrs.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/__pycache__/internal.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/__pycache__/internal.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7c4850a85ce4eddaeafd2e742edba16b5a6c41d Binary files /dev/null and 
b/.venv/lib/python3.13/site-packages/wandb/apis/__pycache__/internal.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/__pycache__/normalize.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/__pycache__/normalize.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1dc336262a7339148f0754b34e870a9e62433e5a Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/__pycache__/normalize.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/__pycache__/paginator.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/__pycache__/paginator.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea5434ab6853df9968c35fe971dbfd3459b4d6fd Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/__pycache__/paginator.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/__init__.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3079052bd4dbf24a8e40231895084394d21052c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/__init__.py @@ -0,0 +1,96 @@ +# Generated by ariadne-codegen + +__all__ = [ + "CREATE_INVITE_GQL", + "CREATE_PROJECT_GQL", + "CREATE_SERVICE_ACCOUNT_GQL", + "CREATE_TEAM_GQL", + "CREATE_USER_FROM_ADMIN_GQL", + "DELETE_API_KEY_GQL", + "DELETE_INVITE_GQL", + "GENERATE_API_KEY_GQL", + "GET_DEFAULT_ENTITY_GQL", + "GET_PROJECTS_GQL", + "GET_PROJECT_GQL", + "GET_SWEEPS_GQL", + "GET_SWEEP_GQL", + "GET_SWEEP_LEGACY_GQL", + "GET_TEAM_ENTITY_GQL", + "GET_VIEWER_GQL", + "SEARCH_USERS_GQL", + "GetProjects", + "GetProject", + "CreateProject", + "GetSweeps", + "GetSweep", + "GetSweepLegacy", + "GetTeamEntity", + "CreateTeam", + "CreateInvite", + "DeleteInvite", + "CreateServiceAccount", + "SearchUsers", + "GetViewer", + "GetDefaultEntity", + 
"CreateUserFromAdmin", + "DeleteApiKey", + "GenerateApiKey", + "ArtifactTypeInput", + "RateLimitsInput", + "UpsertModelInput", + "ApiKeyFragment", + "CreatedProjectFragment", + "LegacySweepFragment", + "PageInfoFragment", + "ProjectFragment", + "SweepFragment", + "UserFragment", + "UserInfoFragment", +] +from .create_invite import CreateInvite +from .create_project import CreateProject +from .create_service_account import CreateServiceAccount +from .create_team import CreateTeam +from .create_user_from_admin import CreateUserFromAdmin +from .delete_api_key import DeleteApiKey +from .delete_invite import DeleteInvite +from .fragments import ( + ApiKeyFragment, + CreatedProjectFragment, + LegacySweepFragment, + PageInfoFragment, + ProjectFragment, + SweepFragment, + UserFragment, + UserInfoFragment, +) +from .generate_api_key import GenerateApiKey +from .get_default_entity import GetDefaultEntity +from .get_project import GetProject +from .get_projects import GetProjects +from .get_sweep import GetSweep +from .get_sweep_legacy import GetSweepLegacy +from .get_sweeps import GetSweeps +from .get_team_entity import GetTeamEntity +from .get_viewer import GetViewer +from .input_types import ArtifactTypeInput, RateLimitsInput, UpsertModelInput +from .operations import ( + CREATE_INVITE_GQL, + CREATE_PROJECT_GQL, + CREATE_SERVICE_ACCOUNT_GQL, + CREATE_TEAM_GQL, + CREATE_USER_FROM_ADMIN_GQL, + DELETE_API_KEY_GQL, + DELETE_INVITE_GQL, + GENERATE_API_KEY_GQL, + GET_DEFAULT_ENTITY_GQL, + GET_PROJECT_GQL, + GET_PROJECTS_GQL, + GET_SWEEP_GQL, + GET_SWEEP_LEGACY_GQL, + GET_SWEEPS_GQL, + GET_TEAM_ENTITY_GQL, + GET_VIEWER_GQL, + SEARCH_USERS_GQL, +) +from .search_users import SearchUsers diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/create_invite.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/create_invite.py new file mode 100644 index 0000000000000000000000000000000000000000..901a4f050746df4c2a8437e312380f925475d864 --- /dev/null +++ 
b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/create_invite.py @@ -0,0 +1,35 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import Optional + +from pydantic import Field + +from wandb._pydantic import GQLId, GQLResult + + +class CreateInvite(GQLResult): + result: Optional[CreateInviteResult] + + +class CreateInviteResult(GQLResult): + invite: Optional[CreateInviteResultInvite] + + +class CreateInviteResultInvite(GQLResult): + id: GQLId + name: str + email: Optional[str] + created_at: Optional[str] = Field(alias="createdAt") + to_user: Optional[CreateInviteResultInviteToUser] = Field(alias="toUser") + + +class CreateInviteResultInviteToUser(GQLResult): + name: str + + +CreateInvite.model_rebuild() +CreateInviteResult.model_rebuild() +CreateInviteResultInvite.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/create_project.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/create_project.py new file mode 100644 index 0000000000000000000000000000000000000000..c7ad33fba56a9b922ed6f03f65acc8c29ab985a2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/create_project.py @@ -0,0 +1,24 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import Optional + +from wandb._pydantic import GQLResult + +from .fragments import CreatedProjectFragment + + +class CreateProject(GQLResult): + result: Optional[CreateProjectResult] + + +class CreateProjectResult(GQLResult): + project: Optional[CreatedProjectFragment] + model: Optional[CreatedProjectFragment] + inserted: Optional[bool] + + +CreateProject.model_rebuild() +CreateProjectResult.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/create_service_account.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/create_service_account.py new file mode 100644 
index 0000000000000000000000000000000000000000..078f84d4a978691c31c5724acc5d7480289c8bc5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/create_service_account.py @@ -0,0 +1,24 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import Optional + +from wandb._pydantic import GQLId, GQLResult + + +class CreateServiceAccount(GQLResult): + result: Optional[CreateServiceAccountResult] + + +class CreateServiceAccountResult(GQLResult): + user: Optional[CreateServiceAccountResultUser] + + +class CreateServiceAccountResultUser(GQLResult): + id: GQLId + + +CreateServiceAccount.model_rebuild() +CreateServiceAccountResult.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/create_team.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/create_team.py new file mode 100644 index 0000000000000000000000000000000000000000..cb7556e980e3abbba5bc8f0981899f32e982022e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/create_team.py @@ -0,0 +1,30 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import Optional + +from pydantic import Field + +from wandb._pydantic import GQLId, GQLResult + + +class CreateTeam(GQLResult): + result: Optional[CreateTeamResult] + + +class CreateTeamResult(GQLResult): + entity: Optional[CreateTeamResultEntity] + + +class CreateTeamResultEntity(GQLResult): + id: GQLId + name: str + available: Optional[bool] + photo_url: Optional[str] = Field(alias="photoUrl") + limits: Optional[str] + + +CreateTeam.model_rebuild() +CreateTeamResult.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/create_user_from_admin.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/create_user_from_admin.py new file mode 100644 index 
0000000000000000000000000000000000000000..3fd2c95b68e586b37cb13d367911310beb4ba725 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/create_user_from_admin.py @@ -0,0 +1,22 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import Optional + +from wandb._pydantic import GQLResult + +from .fragments import UserInfoFragment + + +class CreateUserFromAdmin(GQLResult): + result: Optional[CreateUserFromAdminResult] + + +class CreateUserFromAdminResult(GQLResult): + user: Optional[UserInfoFragment] + + +CreateUserFromAdmin.model_rebuild() +CreateUserFromAdminResult.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/delete_api_key.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/delete_api_key.py new file mode 100644 index 0000000000000000000000000000000000000000..758a80e821c5ba987d9b7f73122d345de3a8d876 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/delete_api_key.py @@ -0,0 +1,19 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import Optional + +from wandb._pydantic import GQLResult + + +class DeleteApiKey(GQLResult): + result: Optional[DeleteApiKeyResult] + + +class DeleteApiKeyResult(GQLResult): + success: Optional[bool] + + +DeleteApiKey.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/delete_invite.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/delete_invite.py new file mode 100644 index 0000000000000000000000000000000000000000..2488598f6f11c837ca26ee2899dad219b40d8575 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/delete_invite.py @@ -0,0 +1,19 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import Optional + +from wandb._pydantic import GQLResult + + 
+class DeleteInvite(GQLResult): + result: Optional[DeleteInviteResult] + + +class DeleteInviteResult(GQLResult): + success: Optional[bool] + + +DeleteInvite.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/enums.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/enums.py new file mode 100644 index 0000000000000000000000000000000000000000..cc7d61d95669bab416e88c3acc378913704e64e8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/enums.py @@ -0,0 +1,4 @@ +# Generated by ariadne-codegen +# Source: core/api/graphql/schemas/schema-latest.graphql + +from __future__ import annotations diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/fragments.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/fragments.py new file mode 100644 index 0000000000000000000000000000000000000000..e8444b48dcd52e9b72feb5d8e8ad9e635d393032 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/fragments.py @@ -0,0 +1,123 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field +from typing_extensions import Literal + +from wandb._pydantic import GQLId, GQLResult, Typename + + +class ApiKeyFragment(GQLResult): + id: GQLId + name: str + description: Optional[str] + + +class CreatedProjectFragment(GQLResult): + id: GQLId + name: str + entity_name: str = Field(alias="entityName") + description: Optional[str] + access: Optional[str] + views: Optional[str] + + +class LegacySweepFragment(GQLResult): + typename__: Typename[Literal["Sweep"]] = "Sweep" + id: GQLId + name: str + state: str + best_loss: Optional[float] = Field(alias="bestLoss") + config: str + + +class PageInfoFragment(GQLResult): + typename__: Typename[Literal["PageInfo"]] = "PageInfo" + end_cursor: Optional[str] = Field(alias="endCursor") + has_next_page: bool = Field(alias="hasNextPage") + + 
+class ProjectFragment(GQLResult): + typename__: Typename[Literal["Project"]] = "Project" + id: GQLId + name: str + entity_name: str = Field(alias="entityName") + created_at: str = Field(alias="createdAt") + is_benchmark: bool = Field(alias="isBenchmark") + + +class SweepFragment(GQLResult): + typename__: Typename[Literal["Sweep"]] = "Sweep" + id: GQLId + name: str + display_name: Optional[str] = Field(alias="displayName") + method: str + state: str + description: Optional[str] + best_loss: Optional[float] = Field(alias="bestLoss") + config: str + created_at: str = Field(alias="createdAt") + updated_at: Optional[str] = Field(alias="updatedAt") + run_count: int = Field(alias="runCount") + run_count_expected: Optional[int] = Field(alias="runCountExpected") + + +class UserFragment(GQLResult): + id: GQLId + name: str + username: Optional[str] + email: Optional[str] + admin: Optional[bool] + flags: Optional[str] + entity: Optional[str] + deleted_at: Optional[str] = Field(alias="deletedAt") + api_keys: Optional[UserFragmentApiKeys] = Field(alias="apiKeys") + teams: Optional[UserFragmentTeams] + + +class UserFragmentApiKeys(GQLResult): + edges: List[UserFragmentApiKeysEdges] + + +class UserFragmentApiKeysEdges(GQLResult): + node: Optional[ApiKeyFragment] + + +class UserFragmentTeams(GQLResult): + edges: List[UserFragmentTeamsEdges] + + +class UserFragmentTeamsEdges(GQLResult): + node: Optional[UserFragmentTeamsEdgesNode] + + +class UserFragmentTeamsEdgesNode(GQLResult): + name: str + + +class UserInfoFragment(GQLResult): + id: GQLId + name: str + username: Optional[str] + email: Optional[str] + admin: Optional[bool] + + +ApiKeyFragment.model_rebuild() +CreatedProjectFragment.model_rebuild() +LegacySweepFragment.model_rebuild() +PageInfoFragment.model_rebuild() +ProjectFragment.model_rebuild() +SweepFragment.model_rebuild() +UserFragment.model_rebuild() +UserFragmentApiKeys.model_rebuild() +UserFragmentApiKeysEdges.model_rebuild() +ApiKeyFragment.model_rebuild() 
+UserFragmentTeams.model_rebuild() +UserFragmentTeamsEdges.model_rebuild() +UserFragmentTeamsEdgesNode.model_rebuild() +UserInfoFragment.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/generate_api_key.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/generate_api_key.py new file mode 100644 index 0000000000000000000000000000000000000000..790ed09a87d5eed0d84992c2fe2777f23e7feca7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/generate_api_key.py @@ -0,0 +1,24 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import Optional + +from pydantic import Field + +from wandb._pydantic import GQLResult + +from .fragments import ApiKeyFragment + + +class GenerateApiKey(GQLResult): + result: Optional[GenerateApiKeyResult] + + +class GenerateApiKeyResult(GQLResult): + api_key: Optional[ApiKeyFragment] = Field(alias="apiKey") + + +GenerateApiKey.model_rebuild() +GenerateApiKeyResult.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_default_entity.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_default_entity.py new file mode 100644 index 0000000000000000000000000000000000000000..88b8fc29317c668a489879403268bd8953788ac9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_default_entity.py @@ -0,0 +1,20 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import Optional + +from wandb._pydantic import GQLId, GQLResult + + +class GetDefaultEntity(GQLResult): + viewer: Optional[GetDefaultEntityViewer] + + +class GetDefaultEntityViewer(GQLResult): + id: GQLId + entity: Optional[str] + + +GetDefaultEntity.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_project.py 
b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_project.py new file mode 100644 index 0000000000000000000000000000000000000000..e1232e15da7dd6826053752a4c4a6868ed484ece --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_project.py @@ -0,0 +1,17 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import Optional + +from wandb._pydantic import GQLResult + +from .fragments import ProjectFragment + + +class GetProject(GQLResult): + project: Optional[ProjectFragment] + + +GetProject.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_projects.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_projects.py new file mode 100644 index 0000000000000000000000000000000000000000..6c83f3bcd220eee360c86aaf806e1d67b3402547 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_projects.py @@ -0,0 +1,30 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field + +from wandb._pydantic import GQLResult + +from .fragments import PageInfoFragment, ProjectFragment + + +class GetProjects(GQLResult): + models: Optional[GetProjectsModels] + + +class GetProjectsModels(GQLResult): + page_info: PageInfoFragment = Field(alias="pageInfo") + edges: List[GetProjectsModelsEdges] + + +class GetProjectsModelsEdges(GQLResult): + node: Optional[ProjectFragment] + + +GetProjects.model_rebuild() +GetProjectsModels.model_rebuild() +GetProjectsModelsEdges.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_sweep.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_sweep.py new file mode 100644 index 0000000000000000000000000000000000000000..92360640b7d99f20e9f3fba36294c0ea3774098c --- /dev/null +++ 
b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_sweep.py @@ -0,0 +1,22 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import Optional + +from wandb._pydantic import GQLResult + +from .fragments import SweepFragment + + +class GetSweep(GQLResult): + project: Optional[GetSweepProject] + + +class GetSweepProject(GQLResult): + sweep: Optional[SweepFragment] + + +GetSweep.model_rebuild() +GetSweepProject.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_sweep_legacy.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_sweep_legacy.py new file mode 100644 index 0000000000000000000000000000000000000000..18e4da464bbc7dd1f18f2902d2e6c6837a9dbcd5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_sweep_legacy.py @@ -0,0 +1,22 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import Optional + +from wandb._pydantic import GQLResult + +from .fragments import LegacySweepFragment + + +class GetSweepLegacy(GQLResult): + project: Optional[GetSweepLegacyProject] + + +class GetSweepLegacyProject(GQLResult): + sweep: Optional[LegacySweepFragment] + + +GetSweepLegacy.model_rebuild() +GetSweepLegacyProject.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_sweeps.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_sweeps.py new file mode 100644 index 0000000000000000000000000000000000000000..ff74e654bf31620aa7245cfa274050d13a27e87b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_sweeps.py @@ -0,0 +1,36 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field + +from wandb._pydantic import GQLResult + +from .fragments import 
PageInfoFragment, SweepFragment + + +class GetSweeps(GQLResult): + project: Optional[GetSweepsProject] + + +class GetSweepsProject(GQLResult): + total_sweeps: int = Field(alias="totalSweeps") + sweeps: Optional[GetSweepsProjectSweeps] + + +class GetSweepsProjectSweeps(GQLResult): + page_info: PageInfoFragment = Field(alias="pageInfo") + edges: List[GetSweepsProjectSweepsEdges] + + +class GetSweepsProjectSweepsEdges(GQLResult): + node: SweepFragment + + +GetSweeps.model_rebuild() +GetSweepsProject.model_rebuild() +GetSweepsProjectSweeps.model_rebuild() +GetSweepsProjectSweepsEdges.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_team_entity.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_team_entity.py new file mode 100644 index 0000000000000000000000000000000000000000..ac6647132735708fdbdfa6ee232f82161525db02 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_team_entity.py @@ -0,0 +1,46 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field + +from wandb._pydantic import GQLId, GQLResult + + +class GetTeamEntity(GQLResult): + entity: Optional[GetTeamEntityEntity] + + +class GetTeamEntityEntity(GQLResult): + id: GQLId + name: str + available: Optional[bool] + photo_url: Optional[str] = Field(alias="photoUrl") + read_only: Optional[bool] = Field(alias="readOnly") + read_only_admin: bool = Field(alias="readOnlyAdmin") + is_team: bool = Field(alias="isTeam") + private_only: bool = Field(alias="privateOnly") + storage_bytes: int = Field(alias="storageBytes") + code_saving_enabled: bool = Field(alias="codeSavingEnabled") + default_access: str = Field(alias="defaultAccess") + is_paid: Optional[bool] = Field(alias="isPaid") + members: List[GetTeamEntityEntityMembers] + + +class GetTeamEntityEntityMembers(GQLResult): + id: Optional[str] + admin: Optional[bool] + 
pending: Optional[bool] + email: Optional[str] + username: Optional[str] + name: str + photo_url: Optional[str] = Field(alias="photoUrl") + account_type: Optional[str] = Field(alias="accountType") + api_key: Optional[str] = Field(alias="apiKey") + + +GetTeamEntity.model_rebuild() +GetTeamEntityEntity.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_viewer.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..41ac3884942b3edb756611a33c6bac85cddeff30 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/get_viewer.py @@ -0,0 +1,17 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import Optional + +from wandb._pydantic import GQLResult + +from .fragments import UserFragment + + +class GetViewer(GQLResult): + viewer: Optional[UserFragment] + + +GetViewer.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/input_types.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/input_types.py new file mode 100644 index 0000000000000000000000000000000000000000..13d61aaf2748002467f7cdd5d310f9887d736ca3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/input_types.py @@ -0,0 +1,54 @@ +# Generated by ariadne-codegen +# Source: core/api/graphql/schemas/schema-latest.graphql + +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field + +from wandb._pydantic import GQLId, GQLInput + + +class UpsertModelInput(GQLInput): + name: Optional[str] = Field(default=None, max_length=128) + description: Optional[str] = None + id: Optional[str] = None + framework: Optional[str] = None + entity_name: Optional[str] = Field(alias="entityName", default=None) + docker_image: Optional[str] = Field( + alias="dockerImage", default=None, max_length=512 + ) + repo: 
Optional[str] = Field(default=None, max_length=256) + access: Optional[str] = None + views: Optional[str] = None + is_benchmark: Optional[bool] = Field(alias="isBenchmark", default=None) + linked_benchmark: Optional[GQLId] = Field(alias="linkedBenchmark", default=None) + is_published: Optional[bool] = Field(alias="isPublished", default=None) + owner: Optional[GQLId] = None + allow_all_artifact_types_in_registry: Optional[bool] = Field( + alias="allowAllArtifactTypesInRegistry", default=None + ) + rate_limits: Optional[RateLimitsInput] = Field(alias="rateLimits", default=None) + client_mutation_id: Optional[str] = Field(alias="clientMutationId", default=None) + artifact_types: Optional[List[ArtifactTypeInput]] = Field( + alias="artifactTypes", default=None + ) + + +class RateLimitsInput(GQLInput): + graphql: Optional[int] = None + sdk_graphql: Optional[int] = Field(alias="sdkGraphql", default=None) + filestream_count: Optional[int] = Field(alias="filestreamCount", default=None) + filestream_size: Optional[int] = Field(alias="filestreamSize", default=None) + sdk_graphql_query_seconds: Optional[float] = Field( + alias="sdkGraphqlQuerySeconds", default=None + ) + + +class ArtifactTypeInput(GQLInput): + name: str = Field(max_length=128, pattern="^[-\\w]+([ ]*[-.\\w]+)*$") + description: Optional[str] = None + + +UpsertModelInput.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/operations.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/operations.py new file mode 100644 index 0000000000000000000000000000000000000000..058d065b31d8aee8cc10003ec3254b0946cd4884 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/operations.py @@ -0,0 +1,394 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +__all__ = [ + "CREATE_INVITE_GQL", + "CREATE_PROJECT_GQL", + "CREATE_SERVICE_ACCOUNT_GQL", + "CREATE_TEAM_GQL", + "CREATE_USER_FROM_ADMIN_GQL", + "DELETE_API_KEY_GQL", + 
"DELETE_INVITE_GQL", + "GENERATE_API_KEY_GQL", + "GET_DEFAULT_ENTITY_GQL", + "GET_PROJECTS_GQL", + "GET_PROJECT_GQL", + "GET_SWEEPS_GQL", + "GET_SWEEP_GQL", + "GET_SWEEP_LEGACY_GQL", + "GET_TEAM_ENTITY_GQL", + "GET_VIEWER_GQL", + "SEARCH_USERS_GQL", +] + +GET_PROJECTS_GQL = """ +query GetProjects($entity: String, $cursor: String, $perPage: Int = 50) { + models(entityName: $entity, after: $cursor, first: $perPage) { + pageInfo { + ...PageInfoFragment + } + edges { + node { + ...ProjectFragment + } + } + } +} + +fragment PageInfoFragment on PageInfo { + __typename + endCursor + hasNextPage +} + +fragment ProjectFragment on Project { + __typename + id + name + entityName + createdAt + isBenchmark +} +""" + +GET_PROJECT_GQL = """ +query GetProject($name: String!, $entity: String!) { + project(name: $name, entityName: $entity) { + ...ProjectFragment + } +} + +fragment ProjectFragment on Project { + __typename + id + name + entityName + createdAt + isBenchmark +} +""" + +CREATE_PROJECT_GQL = """ +mutation CreateProject($input: UpsertModelInput!) 
{ + result: upsertModel(input: $input) { + project { + ...CreatedProjectFragment + } + model { + ...CreatedProjectFragment + } + inserted + } +} + +fragment CreatedProjectFragment on Project { + id + name + entityName + description + access + views +} +""" + +GET_SWEEPS_GQL = """ +query GetSweeps($project: String!, $entity: String!, $cursor: String, $perPage: Int = 50) { + project(name: $project, entityName: $entity) { + totalSweeps + sweeps(after: $cursor, first: $perPage) { + pageInfo { + ...PageInfoFragment + } + edges { + node { + ...SweepFragment + } + } + } + } +} + +fragment PageInfoFragment on PageInfo { + __typename + endCursor + hasNextPage +} + +fragment SweepFragment on Sweep { + __typename + id + name + displayName + method + state + description + bestLoss + config + createdAt + updatedAt + runCount + runCountExpected +} +""" + +GET_SWEEP_GQL = """ +query GetSweep($name: String!, $project: String, $entity: String) { + project(name: $project, entityName: $entity) { + sweep(sweepName: $name) { + ...SweepFragment + } + } +} + +fragment SweepFragment on Sweep { + __typename + id + name + displayName + method + state + description + bestLoss + config + createdAt + updatedAt + runCount + runCountExpected +} +""" + +GET_SWEEP_LEGACY_GQL = """ +query GetSweepLegacy($name: String!, $project: String, $entity: String) { + project(name: $project, entityName: $entity) { + sweep(sweepName: $name) { + ...LegacySweepFragment + } + } +} + +fragment LegacySweepFragment on Sweep { + __typename + id + name + state + bestLoss + config +} +""" + +GET_TEAM_ENTITY_GQL = """ +query GetTeamEntity($name: String!) 
{ + entity(name: $name) { + id + name + available + photoUrl + readOnly + readOnlyAdmin + isTeam + privateOnly + storageBytes + codeSavingEnabled + defaultAccess + isPaid + members { + id + admin + pending + email + username + name + photoUrl + accountType + apiKey + } + } +} +""" + +CREATE_TEAM_GQL = """ +mutation CreateTeam($teamName: String!, $teamAdminUserName: String) { + result: createTeam( + input: {teamName: $teamName, teamAdminUserName: $teamAdminUserName} + ) { + entity { + id + name + available + photoUrl + limits + } + } +} +""" + +CREATE_INVITE_GQL = """ +mutation CreateInvite($entity: String!, $email: String, $username: String, $admin: Boolean) { + result: createInvite( + input: {entityName: $entity, email: $email, username: $username, admin: $admin} + ) { + invite { + id + name + email + createdAt + toUser { + name + } + } + } +} +""" + +DELETE_INVITE_GQL = """ +mutation DeleteInvite($id: String, $entity: String) { + result: deleteInvite(input: {id: $id, entityName: $entity}) { + success + } +} +""" + +CREATE_SERVICE_ACCOUNT_GQL = """ +mutation CreateServiceAccount($entity: String!, $description: String!) 
{ + result: createServiceAccount( + input: {description: $description, entityName: $entity} + ) { + user { + id + } + } +} +""" + +SEARCH_USERS_GQL = """ +query SearchUsers($query: String) { + users(query: $query) { + edges { + node { + ...UserFragment + } + } + } +} + +fragment ApiKeyFragment on ApiKey { + id + name + description +} + +fragment UserFragment on User { + id + name + username + email + admin + flags + entity + deletedAt + apiKeys { + edges { + node { + ...ApiKeyFragment + } + } + } + teams { + edges { + node { + name + } + } + } +} +""" + +GET_VIEWER_GQL = """ +query GetViewer { + viewer { + ...UserFragment + } +} + +fragment ApiKeyFragment on ApiKey { + id + name + description +} + +fragment UserFragment on User { + id + name + username + email + admin + flags + entity + deletedAt + apiKeys { + edges { + node { + ...ApiKeyFragment + } + } + } + teams { + edges { + node { + name + } + } + } +} +""" + +GET_DEFAULT_ENTITY_GQL = """ +query GetDefaultEntity { + viewer { + id + entity + } +} +""" + +CREATE_USER_FROM_ADMIN_GQL = """ +mutation CreateUserFromAdmin($email: String!, $admin: Boolean) { + result: createUser(input: {email: $email, admin: $admin}) { + user { + ...UserInfoFragment + } + } +} + +fragment UserInfoFragment on User { + id + name + username + email + admin +} +""" + +DELETE_API_KEY_GQL = """ +mutation DeleteApiKey($id: String!) 
{ + result: deleteApiKey(input: {id: $id}) { + success + } +} +""" + +GENERATE_API_KEY_GQL = """ +mutation GenerateApiKey($description: String) { + result: generateApiKey(input: {description: $description}) { + apiKey { + ...ApiKeyFragment + } + } +} + +fragment ApiKeyFragment on ApiKey { + id + name + description +} +""" diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/_generated/search_users.py b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/search_users.py new file mode 100644 index 0000000000000000000000000000000000000000..4bcce4995139fc80cf9477b9274ff8ef919bc079 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/_generated/search_users.py @@ -0,0 +1,27 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/api/ + +from __future__ import annotations + +from typing import List, Optional + +from wandb._pydantic import GQLResult + +from .fragments import UserFragment + + +class SearchUsers(GQLResult): + users: Optional[SearchUsersUsers] + + +class SearchUsersUsers(GQLResult): + edges: List[SearchUsersUsersEdges] + + +class SearchUsersUsersEdges(GQLResult): + node: Optional[UserFragment] + + +SearchUsers.model_rebuild() +SearchUsersUsers.model_rebuild() +SearchUsersUsersEdges.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/attrs.py b/.venv/lib/python3.13/site-packages/wandb/apis/attrs.py new file mode 100644 index 0000000000000000000000000000000000000000..36f2cb627a41355ea6a4d9d79d3f2b477f99f883 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/attrs.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from typing import Any, Mapping + +import wandb + +from ..sdk.lib import ipython + + +class Attrs: + def __init__(self, attrs: Mapping[str, Any]): + self._attrs = dict(attrs) + + def snake_to_camel(self, string): + camel = "".join([i.title() for i in string.split("_")]) + return camel[0].lower() + camel[1:] + + def display(self, height=420, hidden=False) -> bool: + 
"""Display this object in jupyter.""" + if wandb.run and wandb.run._settings.silent: + return False + + if not ipython.in_jupyter(): + return False + + html = self.to_html(height, hidden) + if html is None: + wandb.termwarn("This object does not support `.display()`") + return False + + try: + from IPython import display + except ImportError: + wandb.termwarn(".display() only works in jupyter environments") + return False + + display.display(display.HTML(html)) + return True + + def to_html(self, *args, **kwargs): + return None + + def __getattr__(self, name): + key = self.snake_to_camel(name) + if key == "user": + raise AttributeError + if key in self._attrs.keys(): + return self._attrs[key] + elif name in self._attrs.keys(): + return self._attrs[name] + else: + raise AttributeError(f"{repr(self)!r} object has no attribute {name!r}") diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/importers/__init__.py b/.venv/lib/python3.13/site-packages/wandb/apis/importers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0266b9f204d4a5487c5d66a6d2487c3670055082 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/importers/__init__.py @@ -0,0 +1 @@ +from .internals.util import Namespace diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/importers/internals/internal.py b/.venv/lib/python3.13/site-packages/wandb/apis/importers/internals/internal.py new file mode 100644 index 0000000000000000000000000000000000000000..10bf67ba20255deb63980e60500105b4b327b41b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/importers/internals/internal.py @@ -0,0 +1,375 @@ +import json +import logging +import math +import os +import queue +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, Optional + +import numpy as np +from tenacity import retry, stop_after_attempt, wait_random_exponential + +from wandb import Artifact +from wandb.proto import wandb_internal_pb2 as pb 
+from wandb.proto import wandb_telemetry_pb2 as telem_pb +from wandb.sdk.interface.interface import file_policy_to_enum +from wandb.sdk.interface.interface_queue import InterfaceQueue +from wandb.sdk.internal import context +from wandb.sdk.internal.sender import SendManager +from wandb.sdk.internal.settings_static import SettingsStatic +from wandb.util import coalesce, recursive_cast_dictlike_to_dict + +from .protocols import ImporterRun + +ROOT_DIR = "./wandb-importer" + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +if os.getenv("WANDB_IMPORTER_ENABLE_RICH_LOGGING"): + from rich.logging import RichHandler + + logger.addHandler(RichHandler(rich_tracebacks=True, tracebacks_show_locals=True)) +else: + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + + formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") + console_handler.setFormatter(formatter) + + logger.addHandler(console_handler) + + +exp_retry = retry( + wait=wait_random_exponential(multiplier=1, max=10), stop=stop_after_attempt(3) +) + + +class AlternateSendManager(SendManager): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._send_artifact = exp_retry(self._send_artifact) + + +@dataclass(frozen=True) +class SendManagerConfig: + """Configure which parts of SendManager tooling to use.""" + + use_artifacts: bool = False + log_artifacts: bool = False + metadata: bool = False + files: bool = False + media: bool = False + code: bool = False + history: bool = False + summary: bool = False + terminal_output: bool = False + + +@dataclass +class RecordMaker: + run: ImporterRun + interface: InterfaceQueue = InterfaceQueue() + + @property + def run_dir(self) -> str: + p = Path(f"{ROOT_DIR}/{self.run.run_id()}/wandb") + p.mkdir(parents=True, exist_ok=True) + return f"{ROOT_DIR}/{self.run.run_id()}" + + def make_artifacts_only_records( + self, + artifacts: Optional[Iterable[Artifact]] = None, + 
used_artifacts: Optional[Iterable[Artifact]] = None, + ) -> Iterable[pb.Record]: + """Only make records required to upload artifacts. + + Escape hatch for adding extra artifacts to a run. + """ + yield self._make_run_record() + + if used_artifacts: + for art in used_artifacts: + yield self._make_artifact_record(art, use_artifact=True) + + if artifacts: + for art in artifacts: + yield self._make_artifact_record(art) + + def make_records( + self, + config: SendManagerConfig, + ) -> Iterable[pb.Record]: + """Make all the records that constitute a run.""" + yield self._make_run_record() + yield self._make_telem_record() + + include_artifacts = config.log_artifacts or config.use_artifacts + yield self._make_files_record( + include_artifacts, config.files, config.media, config.code + ) + + if config.use_artifacts: + if (used_artifacts := self.run.used_artifacts()) is not None: + for artifact in used_artifacts: + yield self._make_artifact_record(artifact, use_artifact=True) + + if config.log_artifacts: + if (artifacts := self.run.artifacts()) is not None: + for artifact in artifacts: + yield self._make_artifact_record(artifact) + + if config.history: + yield from self._make_history_records() + + if config.summary: + yield self._make_summary_record() + + if config.terminal_output: + if (lines := self.run.logs()) is not None: + for line in lines: + yield self._make_output_record(line) + + def _make_run_record(self) -> pb.Record: + run = pb.RunRecord() + run.run_id = self.run.run_id() + run.entity = self.run.entity() + run.project = self.run.project() + run.display_name = coalesce(self.run.display_name()) + run.notes = coalesce(self.run.notes(), "") + run.tags.extend(coalesce(self.run.tags(), [])) + run.start_time.FromMilliseconds(self.run.start_time()) + + host = self.run.host() + if host is not None: + run.host = host + + runtime = self.run.runtime() + if runtime is not None: + run.runtime = runtime + + run_group = self.run.run_group() + if run_group is not None: + 
def _make_history_records(self) -> Iterable[pb.Record]:
    """Yield one history record per row returned by ``self.run.metrics()``.

    Every key/value pair of a row becomes one item of a ``pb.HistoryRecord``;
    the value is stored as a JSON string in ``value_json``.
    """
    for metrics in self.run.metrics():
        history = pb.HistoryRecord()
        for k, v in metrics.items():
            item = history.item.add()
            item.key = k

            # A NaN -- either a float NaN or the literal string "NaN" -- is
            # normalised to a float NaN so that json.dumps below emits the
            # bare token NaN. The original author noted that the row is
            # dropped on re-upload without this normalisation.
            # math.nan is used rather than the np.NaN alias, which NumPy 2.0
            # removed; both are plain Python floats and serialise identically.
            if (isinstance(v, float) and math.isnan(v)) or v == "NaN":
                v = math.nan

            if isinstance(v, bytes):
                # it's a json string encoded as bytes
                v = v.decode("utf-8")
            else:
                v = json.dumps(v)

            item.value_json = v
        yield self.interface._make_record(history=history)
def send_run(
    run: ImporterRun,
    *,
    extra_arts: Optional[Iterable[Artifact]] = None,
    extra_used_arts: Optional[Iterable[Artifact]] = None,
    config: Optional[SendManagerConfig] = None,
    overrides: Optional[Dict[str, Any]] = None,
    settings_override: Optional[Dict[str, Any]] = None,
) -> None:
    """Generate the records for ``run`` and push them through a send manager.

    Args:
        run: The run to send.
        extra_arts: Artifacts to log. When this or ``extra_used_arts`` is
            truthy, only the artifact records are generated and ``config``
            is not consulted.
        extra_used_arts: Artifacts to mark as used (see ``extra_arts``).
        config: Selects which records ``RecordMaker.make_records`` produces;
            a default ``SendManagerConfig()`` is used when omitted.
        overrides: Maps accessor names on ``run`` to fixed values; each named
            attribute is replaced by a zero-argument callable returning the
            value.
        settings_override: Extra settings forwarded to ``_make_settings``.
    """
    effective_config = SendManagerConfig() if config is None else config

    # The value is bound through a default argument; a plain `lambda: value`
    # would late-bind the loop variable and every accessor would return the
    # last override.
    # https://stackoverflow.com/questions/10802002/why-deepcopy-doesnt-create-new-references-to-lambda-function
    for attr_name, attr_value in (overrides or {}).items():
        setattr(run, attr_name, lambda bound=attr_value: bound)

    maker = RecordMaker(run)
    settings = _make_settings(maker.run_dir, settings_override)

    record_q = queue.Queue()
    result_q = queue.Queue()
    sender = AlternateSendManager(
        settings,
        record_q,
        result_q,
        InterfaceQueue(record_q=record_q),
        context.ContextKeeper(),
    )

    records = (
        maker.make_artifacts_only_records(extra_arts, extra_used_arts)
        if extra_arts or extra_used_arts
        else maker.make_records(effective_config)
    )

    # The loop variable must stay `r`: its name is part of the debug message.
    for r in records:
        logger.debug(f"Sending {r=}")
        # In a future update, it might be good to write to a transaction log
        # and have incremental uploads only send the missing records.
        sender.send(r)

    sender.finish()
# pragma: no cover + + def used_artifacts(self) -> Optional[Iterable[Artifact]]: ... # pragma: no cover + + def os_version(self) -> Optional[str]: ... # pragma: no cover + + def python_version(self) -> Optional[str]: ... # pragma: no cover + + def cuda_version(self) -> Optional[str]: ... # pragma: no cover + + def program(self) -> Optional[str]: ... # pragma: no cover + + def host(self) -> Optional[str]: ... # pragma: no cover + + def username(self) -> Optional[str]: ... # pragma: no cover + + def executable(self) -> Optional[str]: ... # pragma: no cover + + def gpus_used(self) -> Optional[str]: ... # pragma: no cover + + def cpus_used(self) -> Optional[int]: ... # pragma: no cover + + def memory_used(self) -> Optional[int]: ... # pragma: no cover + + def runtime(self) -> Optional[int]: ... # pragma: no cover + + def start_time(self) -> Optional[int]: ... # pragma: no cover + + def code_path(self) -> Optional[str]: ... # pragma: no cover + + def cli_version(self) -> Optional[str]: ... # pragma: no cover + + def files( + self, + ) -> Optional[Iterable[Tuple[PathStr, Policy]]]: ... # pragma: no cover + + def logs(self) -> Optional[Iterable[str]]: ... 
@dataclass(frozen=True)
class Namespace:
    """Configure an alternate entity/project at the dst server your data will end up in."""

    entity: str
    project: str

    @classmethod
    def from_path(cls, path: str):
        """Build a namespace from an ``"entity/project"`` string.

        Raises ``ValueError`` when ``path`` does not split into exactly two
        parts on ``"/"``.
        """
        owner, project_name = path.split("/")
        return cls(owner, project_name)

    @property
    def path(self):
        """The namespace rendered as ``"entity/project"``."""
        return "{}/{}".format(self.entity, self.project)

    @property
    def send_manager_overrides(self):
        """Overrides dict holding only the non-empty ``entity``/``project`` values."""
        candidates = (("entity", self.entity), ("project", self.project))
        return {key: value for key, value in candidates if value}
def for_each(
    func, iterable: Iterable, parallel: bool = True, max_workers: Optional[int] = None
):
    """Apply ``func`` to every element of ``iterable`` and return the results as a list.

    With ``parallel`` false the calls run sequentially in the current thread
    and the results keep the input order. Otherwise the work is delegated to
    ``parallelize`` with the given ``max_workers``.
    """
    if not parallel:
        return list(map(func, iterable))

    return parallelize(func, iterable, max_workers=max_workers)
def metrics(self) -> Iterable[Dict[str, float]]:
    """Yield one row per logged step, in ascending step order.

    The full history of every metric named in ``self.run.data.metrics`` is
    fetched from the MLflow client and the points are grouped by step. Each
    yielded row has the form ``{"_step": step, <metric>: <value>, ...}``.

    Rows are sorted by step: metrics logged at different sets of steps would
    otherwise be emitted in first-seen order and produce a non-monotonic
    ``_step`` sequence.
    """
    rows_by_step: Dict[int, Dict[str, float]] = defaultdict(dict)
    run_id = self.run.info.run_id
    for metric_name in self.run.data.metrics:
        for point in self.mlflow_client.get_metric_history(run_id, metric_name):
            rows_by_step[point.step][point.key] = point.value

    for step in sorted(rows_by_step):
        yield {"_step": step, **rows_by_step[step]}
def runtime(self) -> Optional[int]:
    """Return the run duration as end time minus ``self.start_time()``.

    ``run.info.end_time`` is floor-divided by 1000 before the subtraction.
    When the run has no end time, the start time stands in for it, so the
    result is 0.
    """
    info = self.run.info
    if info.end_time is None:
        finished = self.start_time()
    else:
        finished = info.end_time // 1000
    return finished - self.start_time()
@staticmethod
def _handle_incompatible_strings(s: str) -> str:
    """Replace each character outside ``[a-zA-Z0-9_.-]`` with a double underscore."""
    # The character class is negated: it matches the characters to replace.
    return re.sub(r"[^a-zA-Z0-9_\-\.]", "__", s)
"resumed": True, + } + + mlflow.set_tracking_uri(self.mlflow_tracking_uri) + internal.send_run( + run, + overrides=namespace.send_manager_overrides, + settings_override=settings_override, + config=config, + ) + + # in mlflow, the artifacts come with the runs, so import them together + if artifacts: + arts = list(run.artifacts()) + logger.debug(f"Importing history artifacts, {run=}") + internal.send_run( + run, + extra_arts=arts, + overrides=namespace.send_manager_overrides, + settings_override=settings_override, + config=internal.SendManagerConfig(log_artifacts=True), + ) + + def import_runs( + self, + runs: Iterable[MlflowRun], + *, + artifacts: bool = True, + namespace: Optional[Namespace] = None, + parallel: bool = True, + max_workers: Optional[int] = None, + ) -> None: + def _import_run_wrapped(run): + self._import_run(run, namespace=namespace, artifacts=artifacts) + + for_each(_import_run_wrapped, runs, parallel=parallel, max_workers=max_workers) diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/importers/validation.py b/.venv/lib/python3.13/site-packages/wandb/apis/importers/validation.py new file mode 100644 index 0000000000000000000000000000000000000000..219c40285935b6681c819c2c6de006252a806aa7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/importers/validation.py @@ -0,0 +1,108 @@ +import filecmp +import logging +import os + +import requests + +import wandb + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def _compare_artifact_manifests( + src_art: wandb.Artifact, dst_art: wandb.Artifact +) -> list: + problems = [] + if isinstance(dst_art, wandb.CommError): + return ["commError"] + + if src_art.digest != dst_art.digest: + problems.append(f"digest mismatch {src_art.digest=}, {dst_art.digest=}") + + for name, src_entry in src_art.manifest.entries.items(): + dst_entry = dst_art.manifest.entries.get(name) + if dst_entry is None: + problems.append(f"missing manifest entry {name=}, {src_entry=}") + continue 
def _compare_artifact_dirs(src_dir, dst_dir) -> "dict | None":
    """Recursively compare two directory trees with ``filecmp.dircmp``.

    Returns ``None`` when no difference is found. Otherwise returns a dict
    with these keys:

    - ``left_only``: names present only under ``src_dir``
    - ``right_only``: names present only under ``dst_dir``
    - ``diff_files``: files present on both sides whose contents differ
    - ``common_funny``: names present on both sides that ``dircmp`` could not
      classify as a matching file or directory pair (for example a directory
      on one side and a regular file on the other)
    - ``funny_files``: files present on both sides that could not be compared
    - ``subdir_differences``: the same structure for each differing
      subdirectory, keyed by subdirectory name
    """

    def compare(src_dir, dst_dir):
        comparison = filecmp.dircmp(src_dir, dst_dir)
        differences = {
            "left_only": comparison.left_only,
            "right_only": comparison.right_only,
            "diff_files": comparison.diff_files,
            # Without these two lists a type mismatch or an unreadable file
            # would be reported as "no difference".
            "common_funny": comparison.common_funny,
            "funny_files": comparison.funny_files,
            "subdir_differences": {},
        }

        # Recursively find differences in subdirectories
        for subdir in comparison.subdirs:
            nested = compare(
                os.path.join(src_dir, subdir), os.path.join(dst_dir, subdir)
            )
            # Only record subdirectories that actually differ
            if nested and any(nested.values()):
                differences["subdir_differences"][subdir] = nested

        if not any(differences.values()):
            return None

        return differences

    return compare(src_dir, dst_dir)
resp.headers.get("content-length", -1) + actual_size = int(actual_size) + + if expected_size == actual_size: + return True + + return False diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/importers/wandb.py b/.venv/lib/python3.13/site-packages/wandb/apis/importers/wandb.py new file mode 100644 index 0000000000000000000000000000000000000000..99a350c5a6b2deab56cac4eb3b42e701f2f0447b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/importers/wandb.py @@ -0,0 +1,1609 @@ +"""Tooling for the W&B Importer.""" + +import itertools +import json +import logging +import numbers +import os +import re +import shutil +from dataclasses import dataclass, field +from datetime import datetime as dt +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple +from unittest.mock import patch + +import filelock +import polars as pl +import requests +import urllib3 +import wandb_workspaces.reports.v1 as wr +import yaml +from wandb_gql import gql +from wandb_workspaces.reports.v1 import Report + +import wandb +from wandb.apis.public import ArtifactCollection, Run +from wandb.apis.public.files import File +from wandb.sdk.lib import json_util +from wandb.util import coalesce, remove_keys_with_none_values + +from . 
@dataclass
class ArtifactSequence:
    """An ordered group of artifact versions plus the names that identify it."""

    artifacts: Iterable[wandb.Artifact]
    entity: str
    project: str
    type_: str
    name: str

    def __iter__(self) -> Iterator:
        """Iterate over the contained artifacts."""
        return iter(self.artifacts)

    def __repr__(self) -> str:
        return f"ArtifactSequence({self.identifier})"

    @property
    def identifier(self) -> str:
        """``entity/project/type/name`` joined with forward slashes."""
        parts = (self.entity, self.project, self.type_, self.name)
        return "/".join(parts)

    @classmethod
    def from_collection(cls, collection: ArtifactCollection):
        """Build a sequence from a collection, sorted by numeric version.

        The sort key is the artifact's ``version`` with leading ``"v"``
        characters stripped, parsed as an integer.
        """

        def version_number(art):
            return int(art.version.lstrip("v"))

        ordered = sorted(collection.artifacts(), key=version_number)
        return ArtifactSequence(
            ordered,
            collection.entity,
            collection.project,
            collection.type,
            collection.name,
        )
def __repr__(self) -> str:
    """Return ``WandbRun(<entity>/<project>/<run_id>)``.

    The identifier is a W&B path rather than a filesystem path, so its parts
    are joined with a literal ``"/"``. ``os.path.join`` would use the
    platform separator and discard leading parts when a later part is
    absolute.
    """
    identifier = "/".join([self.entity(), self.project(), self.run_id()])
    return f"WandbRun({identifier})"
def artifacts(self) -> Optional[Iterable[Artifact]]:
    """Yield a clone of every artifact logged by the run.

    This is a generator function, so nothing runs until it is iterated. On
    first iteration each artifact from ``self.run.logged_artifacts()`` is
    passed through ``_clone_art`` and the resulting list is cached on
    ``self._artifacts``; later calls yield from the cache.
    """
    if self._artifacts is None:
        self._artifacts = [
            _clone_art(logged) for logged in self.run.logged_artifacts()
        ]

    yield from self._artifacts
def logs(self) -> Optional[Iterable[str]]:
    """Yield every line of each run file whose path ends with ``output.log``.

    Matching paths come from ``self._find_all_in_files_regex``. Lines are
    yielded as read, with their line endings. Each file is streamed line by
    line rather than loaded whole with ``readlines()``, so large logs are not
    held in memory.
    """
    for path in self._find_all_in_files_regex(r"^.*output\.log$"):
        with open(path) as f:
            # Iterating the file object yields the same lines as readlines().
            yield from f
def _find_in_files(self, name: str) -> Optional[str]:
    """Return the first path from ``self.files()`` that contains ``name``.

    Returns ``None`` when ``self.files()`` is falsy or no path matches. The
    match is a substring test against each entry's path.
    """
    entries = self.files()
    if not entries:
        return None
    return next((path for path, _ in entries if name in path), None)
def _delete_collection_in_dst(
    self,
    seq: ArtifactSequence,
    namespace: Optional[Namespace] = None,
):
    """Delete the artifact collection in the destination that mirrors `seq`.

    Intended to clear the destination when an uploaded artifact does not pass
    validation. Lookup and deletion failures (`wandb.CommError`, `ValueError`)
    are logged as warnings and swallowed.

    Args:
        seq: Source artifact sequence whose destination collection is removed.
        namespace: Destination entity/project. When omitted, the sequence's own
            entity and project are used.
    """
    # `namespace` defaults to None; dereferencing it directly raised
    # AttributeError. Fall back to the sequence's own namespace, the same way
    # the other import helpers in this class do.
    if namespace is None:
        namespace = Namespace(seq.entity, seq.project)

    entity = coalesce(namespace.entity, seq.entity)
    project = coalesce(namespace.project, seq.project)
    art_type = f"{entity}/{project}/{seq.type_}"
    art_name = seq.name

    logger.info(
        f"Deleting collection {entity=}, {project=}, {art_type=}, {art_name=}"
    )
    try:
        dst_collection = self.dst_api.artifact_collection(art_type, art_name)
    except (wandb.CommError, ValueError):
        logger.warning(f"Collection doesn't exist {art_type=}, {art_name=}")
        return

    try:
        dst_collection.delete()
    except (wandb.CommError, ValueError) as e:
        logger.warning(
            f"Collection can't be deleted, {art_type=}, {art_name=}, {e=}"
        )
        return
def _remove_placeholders(self, seq: ArtifactSequence) -> None:
    """Delete placeholder artifact versions of `seq` from the destination.

    Lists the destination versions for `seq.type_` / `seq.name` and deletes
    those whose description equals `ART_SEQUENCE_DUMMY_PLACEHOLDER`, skipping
    artifacts of type "wandb-history" and "job".

    Returns early (after logging) when listing raises `wandb.CommError` or
    `TypeError`.

    Raises:
        wandb.CommError: If a delete fails for any reason other than the
            artifact being system managed.
    """
    try:
        # The destination client is stored as `self.dst_api`; the previous
        # `self._dst_api` does not exist and raised AttributeError every time.
        retry_arts_func = internal.exp_retry(self.dst_api.artifacts)
        dst_arts = list(retry_arts_func(seq.type_, seq.name))
    except wandb.CommError:
        logger.warning(
            f"{seq=} does not exist in dst. Has it already been deleted?"
        )
        return
    except TypeError:
        logger.exception("Problem getting dst versions (try again later).")
        return

    for art in dst_arts:
        if art.description != ART_SEQUENCE_DUMMY_PLACEHOLDER:
            continue
        if art.type in ("wandb-history", "job"):
            continue

        try:
            art.delete(delete_aliases=True)
        except wandb.CommError as e:
            if "cannot delete system managed artifact" in str(e):
                logger.warning("Cannot delete system managed artifact")
            else:
                raise
def _compare_run_summary(self, src_run: Run, dst_run: Run) -> dict:
    """Return the summary entries of `src_run` that differ in `dst_run`.

    String values starting with "wandb-client-artifact://" and the keys
    "_wandb" / "_runtime" are ignored. When both sides of an entry are
    dict-like, the comparison is done per sub-key and mismatches are reported
    under "<key>-<sub_key>"; otherwise the whole value is compared.

    Returns:
        Mapping of mismatching key to `{"src": ..., "dst": ...}`.
    """
    artifact_ref_prefix = "wandb-client-artifact://"
    ignored_keys = ("_wandb", "_runtime")

    def is_artifact_ref(value) -> bool:
        # These won't match between systems and that's ok
        return isinstance(value, str) and value.startswith(artifact_ref_prefix)

    mismatches = {}
    for key, raw_src in src_run.summary.items():
        if is_artifact_ref(raw_src) or key in ignored_keys:
            continue

        src_val = _recursive_cast_to_dict(raw_src)
        dst_val = _recursive_cast_to_dict(dst_run.summary.get(key))

        if not (isinstance(src_val, dict) and isinstance(dst_val, dict)):
            if not _almost_equal(src_val, dst_val):
                mismatches[key] = {"src": src_val, "dst": dst_val}
            continue

        for sub_key, sub_src in src_val.items():
            if is_artifact_ref(sub_src):
                continue
            sub_dst = dst_val.get(sub_key)
            if not _almost_equal(sub_src, sub_dst):
                mismatches[f"{key}-{sub_key}"] = {"src": sub_src, "dst": sub_dst}

    return mismatches
def _cleanup_dummy_runs(
    self,
    *,
    namespaces: Optional[Iterable[Namespace]] = None,
    api: Optional[Api] = None,
    remapping: Optional[Dict[Namespace, Namespace]] = None,
) -> None:
    """Delete runs named `RUN_DUMMY_PLACEHOLDER` from each namespace.

    Args:
        namespaces: Namespaces to clean; defaults to `self._all_namespaces()`.
        api: API client to query and delete with; defaults to `self.dst_api`.
        remapping: Optional namespace translation applied before querying.

    A namespace whose run listing raises `ValueError` is logged and skipped;
    artifacts logged by the deleted runs are kept (`delete_artifacts=False`).
    """
    api = coalesce(api, self.dst_api)
    namespaces = coalesce(namespaces, self._all_namespaces())

    for ns in namespaces:
        if remapping and ns in remapping:
            ns = remapping[ns]

        logger.debug(f"Cleaning up, {ns=}")
        try:
            runs = list(
                api.runs(ns.path, filters={"displayName": RUN_DUMMY_PLACEHOLDER})
            )
        except ValueError as e:
            if "Could not find project" in str(e):
                logger.exception("Could not find project, does it exist?")
            else:
                logger.exception(f"Could not list dummy runs, skipping {ns=}")
            # Always skip this namespace after a failed listing: `runs` is
            # either unbound or still holds the previous namespace's runs.
            continue

        for run in runs:
            logger.debug(f"Deleting dummy {run=}")
            run.delete(delete_artifacts=False)
def _use_artifact_sequence(
    self,
    sequence: ArtifactSequence,
    *,
    namespace: Optional[Namespace] = None,
):
    """Re-send every run that used an artifact of `sequence` with `use_artifacts=True`.

    Args:
        sequence: Artifact sequence whose artifacts are iterated.
        namespace: Where the runs are sent; defaults to the sequence's own
            entity and project.
    """
    if namespace is None:
        namespace = Namespace(sequence.entity, sequence.project)

    settings_override = {
        "api_key": self.dst_api_key,
        "base_url": self.dst_base_url,
        "resume": "true",
        "resumed": True,
    }
    logger.debug(f"Using artifact sequence with {settings_override=}, {namespace=}")

    use_only_config = internal.SendManagerConfig(use_artifacts=True)

    for artifact in sequence:
        consumers = artifact.used_by()
        if consumers is None:
            # Nothing consumed this version, so there is no run to re-send.
            continue

        for consumer in consumers:
            internal.send_run(
                WandbRun(consumer, **self.run_api_kwargs),
                overrides=namespace.send_manager_overrides,
                settings_override=settings_override,
                config=use_only_config,
            )
Optional[Dict[Namespace, Namespace]] = None, + parallel: bool = True, + incremental: bool = True, + max_workers: Optional[int] = None, + limit: Optional[int] = None, + metadata: bool = True, + files: bool = True, + media: bool = True, + code: bool = True, + history: bool = True, + summary: bool = True, + terminal_output: bool = True, + ): + logger.info("START: Import runs") + + logger.info("Setting up for import") + _create_files_if_not_exists() + _clear_fname(RUN_ERRORS_FNAME) + + logger.info("Collecting runs") + runs = list(self._collect_runs(namespaces=namespaces, limit=limit)) + + logger.info(f"Validating runs, {len(runs)=}") + self._validate_runs( + runs, + skip_previously_validated=incremental, + remapping=remapping, + ) + + logger.info("Collecting failed runs") + runs = list(self._collect_failed_runs()) + + logger.info(f"Importing runs, {len(runs)=}") + + def _import_run_wrapped(run): + namespace = Namespace(run.entity(), run.project()) + if remapping is not None and namespace in remapping: + namespace = remapping[namespace] + + config = internal.SendManagerConfig( + metadata=metadata, + files=files, + media=media, + code=code, + history=history, + summary=summary, + terminal_output=terminal_output, + ) + + logger.debug(f"Importing {run=}, {namespace=}, {config=}") + self._import_run(run, namespace=namespace, config=config) + logger.debug(f"Finished importing {run=}, {namespace=}, {config=}") + + for_each(_import_run_wrapped, runs, max_workers=max_workers, parallel=parallel) + logger.info("END: Importing runs") + + def import_reports( + self, + *, + namespaces: Optional[Iterable[Namespace]] = None, + limit: Optional[int] = None, + remapping: Optional[Dict[Namespace, Namespace]] = None, + ): + logger.info("START: Importing reports") + + logger.info("Collecting reports") + reports = self._collect_reports(namespaces=namespaces, limit=limit) + + logger.info("Importing reports") + + def _import_report_wrapped(report): + namespace = Namespace(report.entity, 
report.project) + if remapping is not None and namespace in remapping: + namespace = remapping[namespace] + + logger.debug(f"Importing {report=}, {namespace=}") + self._import_report(report, namespace=namespace) + logger.debug(f"Finished importing {report=}, {namespace=}") + + for_each(_import_report_wrapped, reports) + + logger.info("END: Importing reports") + + def import_artifact_sequences( + self, + *, + namespaces: Optional[Iterable[Namespace]] = None, + incremental: bool = True, + max_workers: Optional[int] = None, + remapping: Optional[Dict[Namespace, Namespace]] = None, + ): + """Import all artifact sequences from `namespaces`. + + Note: There is a known bug with the AWS backend where artifacts > 2048MB will fail to upload. This seems to be related to multipart uploads, but we don't have a fix yet. + """ + logger.info("START: Importing artifact sequences") + _clear_fname(ARTIFACT_ERRORS_FNAME) + + logger.info("Collecting artifact sequences") + seqs = list(self._collect_artifact_sequences(namespaces=namespaces)) + + logger.info("Validating artifact sequences") + self._validate_artifact_sequences( + seqs, + incremental=incremental, + remapping=remapping, + ) + + logger.info("Collecting failed artifact sequences") + seqs = list(self._collect_failed_artifact_sequences()) + + logger.info(f"Importing artifact sequences, {len(seqs)=}") + + def _import_artifact_sequence_wrapped(seq): + namespace = Namespace(seq.entity, seq.project) + if remapping is not None and namespace in remapping: + namespace = remapping[namespace] + + logger.debug(f"Importing artifact sequence {seq=}, {namespace=}") + self._import_artifact_sequence(seq, namespace=namespace) + logger.debug(f"Finished importing artifact sequence {seq=}, {namespace=}") + + for_each(_import_artifact_sequence_wrapped, seqs, max_workers=max_workers) + + # it's safer to just use artifact on all seqs to make sure we don't miss anything + # For seqs that have already been used, this is a no-op. 
def import_all(
    self,
    *,
    runs: bool = True,
    artifacts: bool = True,
    reports: bool = True,
    namespaces: Optional[Iterable[Namespace]] = None,
    incremental: bool = True,
    remapping: Optional[Dict[Namespace, Namespace]] = None,
):
    """Run the enabled import stages in order: runs, then reports, then artifacts.

    `namespaces` and `remapping` are forwarded to every stage; `incremental`
    is forwarded to the run and artifact-sequence stages only.
    """
    logger.info(f"START: Importing all, {runs=}, {artifacts=}, {reports=}")

    shared = {"namespaces": namespaces, "remapping": remapping}
    stages = (
        (runs, self.import_runs, {"incremental": incremental}),
        (reports, self.import_reports, {}),
        (artifacts, self.import_artifact_sequences, {"incremental": incremental}),
    )
    for enabled, stage, extra in stages:
        if enabled:
            stage(**shared, **extra)

    logger.info("END: Importing all")
def _validate_artifact(
    self,
    src_art: Artifact,
    dst_entity: str,
    dst_project: str,
    download_files_and_compare: bool = False,
    check_entries_are_downloadable: bool = True,
):
    """Compare a source artifact with its counterpart in the destination.

    Args:
        src_art: Artifact from the source instance.
        dst_entity: Destination entity to look the artifact up in.
        dst_project: Destination project to look the artifact up in.
        download_files_and_compare: Also download both artifacts and compare
            the resulting directories.
        check_entries_are_downloadable: Run the downloadable-entries check on
            the destination artifact.

    Returns:
        Tuple `(src_art, dst_entity, dst_project, problems)` where `problems`
        is a list of human-readable strings; empty means nothing was found.
    """
    problems = []

    # These patterns of artifacts are special and should not be validated
    ignore_patterns = [
        r"^job-(.*?)\.py(:v\d+)?$",
        # r"^run-.*-history(?:\:v\d+)?$$",
    ]
    if any(re.search(pattern, src_art.name) for pattern in ignore_patterns):
        return (src_art, dst_entity, dst_project, problems)

    try:
        dst_art = self._get_dst_art(src_art, dst_entity, dst_project)
    except Exception:
        problems.append("destination artifact not found")
        return (src_art, dst_entity, dst_project, problems)

    # The comparison itself must sit inside the `try`; previously only the
    # debug log was guarded, so a failing comparison escaped the handler that
    # was written to record it.
    try:
        logger.debug("Comparing artifact manifests")
        problems += validation._compare_artifact_manifests(src_art, dst_art)
    except Exception as e:
        problems.append(
            f"Problem getting problems! problem with {src_art.entity=}, {src_art.project=}, {src_art.name=} {e=}"
        )

    if check_entries_are_downloadable:
        # validation._check_entries_are_downloadable(src_art)
        validation._check_entries_are_downloadable(dst_art)

    if download_files_and_compare:
        # Pre-bind both directories so a failed source download can no longer
        # leave `src_dir` unbound when the comparison is reached.
        src_dir = None
        dst_dir = None

        logger.debug(f"Downloading {src_art=}")
        try:
            src_dir = _download_art(src_art, root=f"{SRC_ART_PATH}/{src_art.name}")
        except requests.HTTPError as e:
            problems.append(
                f"Invalid download link for src {src_art.entity=}, {src_art.project=}, {src_art.name=}, {e}"
            )
        else:
            # `_download_art` is annotated Optional[str]; None means no path.
            if src_dir is None:
                problems.append(
                    f"Problem downloading src {src_art.entity=}, {src_art.project=}, {src_art.name=}"
                )

        logger.debug(f"Downloading {dst_art=}")
        try:
            dst_dir = _download_art(dst_art, root=f"{DST_ART_PATH}/{dst_art.name}")
        except requests.HTTPError as e:
            problems.append(
                f"Invalid download link for dst {dst_art.entity=}, {dst_art.project=}, {dst_art.name=}, {e}"
            )
        else:
            if dst_dir is None:
                problems.append(
                    f"Problem downloading dst {dst_art.entity=}, {dst_art.project=}, {dst_art.name=}"
                )

        if src_dir is not None and dst_dir is not None:
            logger.debug(f"Comparing artifact dirs {src_dir=}, {dst_dir=}")
            if problem := validation._compare_artifact_dirs(src_dir, dst_dir):
                problems.append(problem)

    return (src_art, dst_entity, dst_project, problems)
def _collect_failed_runs(self):
    """Yield a `WandbRun` for every distinct run recorded in `RUN_ERRORS_FNAME`.

    Yields nothing when `_read_ndjson` returns None for that file. Runs are
    re-fetched from the source API by "<src_entity>/<src_project>/<run_id>".
    """
    failures = _read_ndjson(RUN_ERRORS_FNAME)
    if failures is None:
        logger.debug(f"{RUN_ERRORS_FNAME=} is empty, returning nothing")
        return

    key_columns = ["src_entity", "src_project", "dst_entity", "dst_project", "run_id"]
    distinct_failures = failures[key_columns].unique()

    for record in distinct_failures.iter_rows(named=True):
        # Only the source coordinates are needed to fetch the run again.
        run_path = f"{record['src_entity']}/{record['src_project']}/{record['run_id']}"
        yield WandbRun(self.src_api.run(run_path), **self.run_api_kwargs)
download_files_and_compare: bool = False, + check_entries_are_downloadable: bool = True, + remapping: Optional[Dict[Namespace, Namespace]] = None, + ): + if incremental: + logger.info("Validating in incremental mode") + + def filtered_sequences(): + for seq in seqs: + if not seq.artifacts: + continue + + art = seq.artifacts[0] + try: + logged_by = _get_run_or_dummy_from_art(art, self.src_api) + except requests.HTTPError: + logger.exception( + f"Validate Artifact http error: {art.entity=}," + f" {art.project=}, {art.name=}" + ) + continue + + if art.type == "wandb-history" and isinstance(logged_by, _DummyRun): + # We can never upload valid history for a deleted run, so skip it + continue + + yield seq + + artifacts = self._filter_previously_checked_artifacts(filtered_sequences()) + else: + logger.info("Validating in non-incremental mode") + artifacts = [art for seq in seqs for art in seq.artifacts] + + def _validate_artifact_wrapped(args): + art, entity, project = args + if ( + remapping is not None + and (namespace := Namespace(entity, project)) in remapping + ): + remapped_ns = remapping[namespace] + entity = remapped_ns.entity + project = remapped_ns.project + + logger.debug(f"Validating {art=}, {entity=}, {project=}") + result = self._validate_artifact( + art, + entity, + project, + download_files_and_compare=download_files_and_compare, + check_entries_are_downloadable=check_entries_are_downloadable, + ) + logger.debug(f"Finished validating {art=}, {entity=}, {project=}") + return result + + args = ((art, art.entity, art.project) for art in artifacts) + art_problems = for_each(_validate_artifact_wrapped, args) + for art, dst_entity, dst_project, problems in art_problems: + name, ver = _get_art_name_ver(art) + d = { + "src_entity": art.entity, + "src_project": art.project, + "dst_entity": dst_entity, + "dst_project": dst_project, + "name": name, + "version": ver, + "type": art.type, + } + + if problems: + d["problems"] = problems + fname = ARTIFACT_ERRORS_FNAME + 
def _collect_runs(
    self,
    *,
    namespaces: Optional[Iterable[Namespace]] = None,
    limit: Optional[int] = None,
    skip_ids: Optional[List[str]] = None,
    start_date: Optional[str] = None,
    api: Optional[Api] = None,
) -> Iterable[WandbRun]:
    """Lazily yield up to `limit` runs from `namespaces`, wrapped as `WandbRun`.

    Args:
        namespaces: Namespaces to read; defaults to `self._all_namespaces()`.
        limit: Maximum number of runs to yield; None means no cap.
        skip_ids: Values excluded through a `name` / `$nin` filter.
        start_date: Lower bound applied through a `createdAt` / `$gte` filter.
        api: API client to query; defaults to `self.src_api`.
    """
    api = coalesce(api, self.src_api)
    namespaces = coalesce(namespaces, self._all_namespaces())

    optional_filters = {
        "name": ("$nin", skip_ids),
        "createdAt": ("$gte", start_date),
    }
    filters: Dict[str, Any] = {
        field: {operator: value}
        for field, (operator, value) in optional_filters.items()
        if value is not None
    }

    def _iter_wrapped_runs():
        for ns in namespaces:
            logger.debug(f"Collecting runs from {ns=}")
            for raw_run in api.runs(ns.path, filters=filters):
                yield WandbRun(raw_run, **self.run_api_kwargs)

    yield from itertools.islice(_iter_wrapped_runs(), limit)
def _get_art_name_ver(art: Artifact) -> Tuple[str, int]:
    """Split a versioned artifact name such as "model:v3" into ("model", 3).

    Args:
        art: Object whose `name` attribute ends with ":v<integer>".

    Returns:
        The name without its version suffix, and the version as an int.

    Raises:
        ValueError: If the name has no ":v" separator or the version part is
            not an integer.
    """
    # Split on the LAST ":v" only, so a name that itself contains ":v" still
    # yields exactly one (name, version) pair instead of failing to unpack.
    name, sep, ver = art.name.rpartition(":v")
    if not sep:
        raise ValueError(f"Artifact name has no ':v<N>' version suffix: {art.name!r}")
    return name, int(ver)
def _get_run_or_dummy_from_art(art: Artifact, api=None):
    """Return the run that logged `art`, or a `_DummyRun` stand-in when there is none.

    Args:
        art: Artifact whose logging run is wanted.
        api: Client used to query the artifact's creator when `art.logged_by()`
            yields no run. Required on that path only.

    Returns:
        The value of `art.logged_by()` when it is not None; otherwise a
        `_DummyRun` carrying the artifact's entity/project and, if the query
        returns one, the creator's name as run id.

    Raises:
        ValueError: If the fallback query is needed but `api` is None.
    """
    run = None

    try:
        run = art.logged_by()
    except ValueError as e:
        logger.warning(
            f"Can't log artifact because run doesn't exist, {art=}, {run=}, {e=}"
        )

    if run is not None:
        return run

    # Fail with a clear message instead of AttributeError on `None.client`.
    if api is None:
        raise ValueError(
            "An `api` client is required to look up the creator of an artifact "
            "whose logging run is unavailable"
        )

    query = gql(
        """
        query ArtifactCreatedBy(
            $id: ID!
        ) {
            artifact(id: $id) {
                createdBy {
                    ... on Run {
                        name
                        project {
                            name
                            entityName
                        }
                    }
                }
            }
        }
        """
    )
    response = api.client.execute(query, variable_values={"id": art.id})
    # A key that is present but null defeats `.get(key, {})`, so fall back to
    # an empty dict with `or` at each level.
    creator = (response.get("artifact") or {}).get("createdBy") or {}
    run = _DummyRun(
        entity=art.entity,
        project=art.project,
        run_id=creator.get("name", RUN_DUMMY_PLACEHOLDER),
        id=creator.get("name", RUN_DUMMY_PLACEHOLDER),
    )
    return run
class Api:
    """Internal proxy to the official internal API.

    The wrapped `InternalApi` is built lazily on first access to `api`, using
    the positional and keyword arguments captured by the constructor. Every
    other member simply forwards to that wrapped instance.
    """

    # TODO: Move these methods to PublicApi.

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Record the constructor arguments; the real API is created on demand."""
        self._api_args = args
        self._api_kwargs = kwargs
        self._api = None

    def __getstate__(self):
        """Use for serializing.

        self._api is not serializable, so it's dropped
        """
        state = self.__dict__.copy()
        # Tolerate an instance whose `_api` slot was never populated.
        state.pop("_api", None)
        return state

    def __setstate__(self, state):
        """Used for deserializing.

        Don't need to set self._api because it's constructed when needed.
        """
        self.__dict__.update(state)
        self._api = None

    @property
    def api(self) -> InternalApi:
        """The wrapped `InternalApi`, constructed on first use."""
        # This is a property in order to delay construction of Internal API
        # for as long as possible. If constructed in constructor, then the
        # whole InternalAPI is started when simply importing wandb.
        if self._api is None:
            self._api = InternalApi(*self._api_args, **self._api_kwargs)
        return self._api

    # --- read-only attributes forwarded from the wrapped API ---------------

    @property
    def api_key(self):
        return self.api.api_key

    @property
    def is_authenticated(self):
        """True when the wrapped API holds an access token or an API key."""
        return self.api.access_token is not None or self.api.api_key is not None

    @property
    def api_url(self):
        return self.api.api_url

    @property
    def app_url(self):
        return self.api.app_url

    @property
    def default_entity(self):
        return self.api.default_entity

    # --- pass-through methods ----------------------------------------------

    def validate_api_key(self) -> bool:
        """Returns whether the API key stored on initialization is valid."""
        return self.api.validate_api_key()

    def file_current(self, *args, **kwargs):
        # Forward keyword arguments too, like every other pass-through here.
        return self.api.file_current(*args, **kwargs)

    def download_file(self, *args, **kwargs):
        return self.api.download_file(*args, **kwargs)

    def download_write_file(self, *args, **kwargs):
        return self.api.download_write_file(*args, **kwargs)

    def set_current_run_id(self, run_id):
        return self.api.set_current_run_id(run_id)

    def viewer(self):
        return self.api.viewer()

    def max_cli_version(self):
        return self.api.max_cli_version()

    def viewer_server_info(self):
        return self.api.viewer_server_info()

    def list_projects(self, entity=None):
        return self.api.list_projects(entity=entity)

    def format_project(self, project):
        return self.api.format_project(project)

    def upsert_project(self, project, id=None, description=None, entity=None):
        return self.api.upsert_project(
            project, id=id, description=description, entity=entity
        )

    def upsert_run(self, *args, **kwargs):
        return self.api.upsert_run(*args, **kwargs)

    def settings(self, *args, **kwargs):
        return self.api.settings(*args, **kwargs)

    def clear_setting(self, key: str) -> None:
        return self.api.clear_setting(key)

    def set_setting(self, key: str, value: Any) -> None:
        return self.api.set_setting(key, value)

    def parse_slug(self, *args, **kwargs):
        return self.api.parse_slug(*args, **kwargs)

    def download_url(self, *args, **kwargs):
        return self.api.download_url(*args, **kwargs)

    def download_urls(self, *args, **kwargs):
        return self.api.download_urls(*args, **kwargs)

    def push(self, *args, **kwargs):
        return self.api.push(*args, **kwargs)

    def sweep(self, *args, **kwargs):
        return self.api.sweep(*args, **kwargs)

    def upsert_sweep(self, *args, **kwargs):
        return self.api.upsert_sweep(*args, **kwargs)

    def set_sweep_state(self, *args, **kwargs):
        return self.api.set_sweep_state(*args, **kwargs)

    def get_sweep_state(self, *args, **kwargs):
        return self.api.get_sweep_state(*args, **kwargs)

    def stop_sweep(self, *args, **kwargs):
        return self.api.stop_sweep(*args, **kwargs)

    def cancel_sweep(self, *args, **kwargs):
        return self.api.cancel_sweep(*args, **kwargs)

    def pause_sweep(self, *args, **kwargs):
        return self.api.pause_sweep(*args, **kwargs)

    def resume_sweep(self, *args, **kwargs):
        return self.api.resume_sweep(*args, **kwargs)

    def register_agent(self, *args, **kwargs):
        return self.api.register_agent(*args, **kwargs)

    def agent_heartbeat(self, *args, **kwargs):
        return self.api.agent_heartbeat(*args, **kwargs)

    def use_artifact(self, *args, **kwargs):
        return self.api.use_artifact(*args, **kwargs)

    def create_artifact(self, *args, **kwargs):
        return self.api.create_artifact(*args, **kwargs)

    def complete_multipart_upload_artifact(self, *args, **kwargs):
        return self.api.complete_multipart_upload_artifact(*args, **kwargs)

    def run_config(self, *args, **kwargs):
        return self.api.run_config(*args, **kwargs)

    def upload_file_retry(self, *args, **kwargs):
        return self.api.upload_file_retry(*args, **kwargs)

    def upload_multipart_file_chunk_retry(self, *args, **kwargs):
        return self.api.upload_multipart_file_chunk_retry(*args, **kwargs)

    def get_run_info(self, *args, **kwargs):
        return self.api.get_run_info(*args, **kwargs)

    def get_run_state(self, *args, **kwargs):
        return self.api.get_run_state(*args, **kwargs)

    def entity_is_team(self, *args, **kwargs):
        return self.api.entity_is_team(*args, **kwargs)

    def get_project_run_queues(self, *args, **kwargs):
        return self.api.get_project_run_queues(*args, **kwargs)

    def push_to_run_queue(self, *args, **kwargs):
        return self.api.push_to_run_queue(*args, **kwargs)

    def pop_from_run_queue(self, *args, **kwargs):
        return self.api.pop_from_run_queue(*args, **kwargs)

    def ack_run_queue_item(self, *args, **kwargs):
        return self.api.ack_run_queue_item(*args, **kwargs)

    def create_launch_agent(self, *args, **kwargs):
        return self.api.create_launch_agent(*args, **kwargs)

    def create_default_resource_config(self, *args, **kwargs):
        return self.api.create_default_resource_config(*args, **kwargs)

    def create_run_queue(self, *args, **kwargs):
        return self.api.create_run_queue(*args, **kwargs)

    def upsert_run_queue(self, *args, **kwargs):
        return self.api.upsert_run_queue(*args, **kwargs)

    def create_custom_chart(self, *args, **kwargs):
        return self.api.create_custom_chart(*args, **kwargs)

    def update_launch_agent_status(self, *args, **kwargs):
        return self.api.update_launch_agent_status(*args, **kwargs)

    def launch_agent_introspection(self, *args, **kwargs):
        return self.api.launch_agent_introspection(*args, **kwargs)

    def fail_run_queue_item_introspection(self, *args, **kwargs):
        return self.api.fail_run_queue_item_introspection(*args, **kwargs)

    def fail_run_queue_item(self, *args, **kwargs):
        return self.api.fail_run_queue_item(*args, **kwargs)

    def update_run_queue_item_warning(self, *args, **kwargs):
        return self.api.update_run_queue_item_warning(*args, **kwargs)

    def get_launch_agent(self, *args, **kwargs):
        return self.api.get_launch_agent(*args, **kwargs)

    def stop_run(self, *args, **kwargs):
        return self.api.stop_run(*args, **kwargs)
+ try: + return func(*args, **kwargs) + + except requests.HTTPError as error: + errors = parse_backend_error_messages(error.response) + status = error.response.status_code + + if errors: + message = f"HTTP {status}: {'; '.join(errors)}" + elif error.response.text: + message = f"HTTP {status}: {error.response.text}" + elif error.response.reason: + # Visually different to distinguish backend errors from + # standard HTTP status descriptions. + message = f"HTTP {status} ({error.response.reason})" + else: + message = f"HTTP {status}" + + raise CommError(message, error) + + except RetryError as err: + if ( + "response" in dir(err.last_exception) + and err.last_exception.response is not None + ): + try: + message = err.last_exception.response.json().get( + "errors", [{"message": message}] + )[0]["message"] + except ValueError: + message = err.last_exception.response.text + else: + message = err.last_exception + + if env.is_debug(): + raise err.last_exception.with_traceback(sys.exc_info()[2]) + else: + raise CommError(message, err.last_exception).with_traceback( + sys.exc_info()[2] + ) + except Error: + raise + except Exception as err: + # gql raises server errors with dict's as strings... 
class Paginator(Iterator[_WandbT], ABC):
    """An iterator for paginated objects from GraphQL requests.

    Pages are fetched on demand through `client.execute(QUERY, ...)`. The
    objects produced by `convert_objects()` for each page are appended to
    `self.objects`, which backs both iteration and indexing.

    Subclasses provide `QUERY`, `more`, `cursor` and `convert_objects`.
    """

    QUERY: Document | ClassVar[Document | None]

    def __init__(
        self,
        client: RetryingClient,
        variables: Mapping[str, Any],
        per_page: int = 50,  # We don't allow unbounded paging
    ):
        """Store the client and query variables; no request is made here.

        Args:
            client: Object whose `execute` method runs the GraphQL query.
            variables: Initial query variables (copied, not stored by reference).
            per_page: Page size sent as the `perPage` query variable.
        """
        self.client = client

        # shallow copy partly guards against mutating the original input
        self.variables: dict[str, Any] = dict(variables)

        self.per_page: int = per_page
        self.objects: list[_WandbT] = []
        self.index: int = -1
        self.last_response: object | None = None

    def __iter__(self) -> Iterator[_WandbT]:
        # Restart from the first object; already-loaded pages are reused.
        self.index = -1
        return self

    @property
    @abstractmethod
    def more(self) -> bool:
        """Whether there are more pages to be fetched."""
        raise NotImplementedError

    @property
    @abstractmethod
    def cursor(self) -> str | None:
        """The start cursor to use for the next fetched page."""
        raise NotImplementedError

    @abstractmethod
    def convert_objects(self) -> Iterable[_WandbT]:
        """Convert the last fetched response data into the iterated objects."""
        raise NotImplementedError

    def update_variables(self) -> None:
        """Update the query variables for the next page fetch."""
        self.variables.update({"perPage": self.per_page, "cursor": self.cursor})

    def _update_response(self) -> None:
        """Fetch and store the response data for the next page."""
        self.last_response = self.client.execute(
            self.QUERY, variable_values=self.variables
        )

    def _load_page(self) -> bool:
        """Fetch the next page, if any, returning True and storing the response if there was one."""
        if not self.more:
            return False
        self.update_variables()
        self._update_response()
        self.objects.extend(self.convert_objects())
        return True

    @staticmethod
    def _objects_needed(index: int | slice) -> int | None:
        """Return how many objects must be loaded before `index` can be resolved.

        `None` means the answer depends on the total number of objects, so
        every remaining page has to be fetched: an open-ended slice, or any
        negative index, bound or step (all of which count from the end).
        """
        if isinstance(index, slice):
            parts = (index.start, index.stop, index.step)
            counts_from_end = any(p is not None and p < 0 for p in parts)
            if index.stop is None or counts_from_end:
                return None
            # `stop + 1` keeps the historical fetch threshold
            # (`stop > len(objects) - 1`) for bounded slices.
            return index.stop + 1
        return None if index < 0 else index + 1

    @overload
    def __getitem__(self, index: int) -> _WandbT: ...
    @overload
    def __getitem__(self, index: slice) -> list[_WandbT]: ...

    def __getitem__(self, index: int | slice) -> _WandbT | list[_WandbT]:
        """Return the object(s) at `index`, fetching pages as required.

        Raises:
            IndexError: If an integer index is out of range once no more
                pages are available.
        """
        needed = self._objects_needed(index)
        # Keep fetching until enough objects are present or pages run out.
        while needed is None or len(self.objects) < needed:
            if not self._load_page():
                break
        return self.objects[index]

    def __next__(self) -> _WandbT:
        self.index += 1
        if len(self.objects) <= self.index:
            if not self._load_page():
                raise StopIteration
            # The fetched page may have produced no objects at all.
            if len(self.objects) <= self.index:
                raise StopIteration
        return self.objects[self.index]

    next = __next__
class RelayPaginator(Paginator[_WandbT], Generic[_NodeT, _WandbT], ABC):
    """A Paginator for GQL relay-style nodes parsed via Pydantic.

    Paging state is read from the last parsed connection: its `has_next`,
    `next_cursor` and `nodes()` members drive `more`, `cursor` and
    `convert_objects` respectively.
    """

    last_response: Connection[_NodeT] | None

    @property
    def more(self) -> bool:
        """Whether another page should be requested."""
        page = self.last_response
        if page is None:
            # Nothing has been fetched yet, so the first page is still pending.
            return True
        return page.has_next

    @property
    def cursor(self) -> str | None:
        """Cursor for the next page, or None when there is no usable response."""
        page = self.last_response
        if not page:
            return None
        return page.next_cursor

    @abstractmethod
    def _convert(self, node: _NodeT) -> _WandbT | Any:
        """Convert a parsed GraphQL node into the iterated object.

        If a falsey value is returned, it will be skipped during iteration.
        """
        raise NotImplementedError

    def convert_objects(self) -> Iterable[_WandbT]:
        """Lazily yield the truthy results of `_convert` for the last page's nodes."""
        # Default implementation. Subclasses can override this if more complex
        # logic is needed, but ideally most shouldn't need to.
        page = self.last_response
        if not page:
            return
        for node in page.nodes():
            converted = self._convert(node)
            if converted:
                yield converted
+ + + """ + + last_response: Connection[_NodeT] | None + + def __len__(self) -> int: + """Returns the total number of objects to expect.""" + # If the first page hasn't been fetched yet, do that first + if self.last_response is None: + self._load_page() + if (conn := self.last_response) and (total := conn.total_count) is not None: + return total + raise NotImplementedError(f"{nameof(type(self))!r} doesn't provide length") diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__init__.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f38ba5f4b3065ed04a2017d461677b92c8cd192c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/__init__.py @@ -0,0 +1,86 @@ +__all__ = ( + "Api", + "RetryingClient", # doc:exclude + "requests", # doc:exclude + "ArtifactCollection", + "ArtifactCollections", + "ArtifactFiles", + "Artifacts", + "ArtifactType", + "ArtifactTypes", + "DownloadHistoryResult", + "RunArtifacts", + "Automations", + "File", + "Files", + "HistoryScan", # doc:exclude + "IncompleteRunHistoryError", + "SampledHistoryScan", # doc:exclude + "SlackIntegrations", # doc:exclude + "WebhookIntegrations", # doc:exclude + "Job", # doc:exclude + "QueuedRun", # doc:exclude + "RunQueue", # doc:exclude + "RunQueueAccessType", # doc:exclude + "RunQueuePrioritizationMode", # doc:exclude + "RunQueueResourceType", # doc:exclude + "Project", + "Projects", + "Sweeps", + "QueryGenerator", # doc:exclude + "Registry", + "Registries", # doc:exclude + "BetaReport", + "PanelMetricsHelper", # doc:exclude + "PythonMongoishQueryGenerator", # doc:exclude + "Reports", + "Run", + "Runs", + "Sweep", + "Member", + "Team", + "User", +) + + +from wandb.apis.public.api import Api, RetryingClient +from wandb.apis.public.artifacts import ( + ArtifactCollection, + ArtifactCollections, + ArtifactFiles, + Artifacts, + ArtifactType, + ArtifactTypes, + RunArtifacts, +) +from 
wandb.apis.public.automations import Automations +from wandb.apis.public.files import FILE_FRAGMENT, File, Files +from wandb.apis.public.history import BetaHistoryScan, HistoryScan, SampledHistoryScan +from wandb.apis.public.integrations import SlackIntegrations, WebhookIntegrations +from wandb.apis.public.jobs import ( + Job, + QueuedRun, + RunQueue, + RunQueueAccessType, + RunQueuePrioritizationMode, + RunQueueResourceType, +) +from wandb.apis.public.projects import Project, Projects, Sweeps +from wandb.apis.public.query_generator import QueryGenerator +from wandb.apis.public.registries import Registries, Registry +from wandb.apis.public.reports import ( + BetaReport, + PanelMetricsHelper, + PythonMongoishQueryGenerator, + Reports, +) +from wandb.apis.public.runs import ( + RUN_FRAGMENT, + DownloadHistoryResult, + IncompleteRunHistoryError, + Run, + Runs, +) +from wandb.apis.public.sweeps import Sweep +from wandb.apis.public.teams import Member, Team +from wandb.apis.public.users import User diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6183e4c492d7bef376f472848ab702b19f10c63e Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/api.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/api.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8424560a4662f043b6167a9fcc1123c925e4ffc Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/api.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/artifacts.cpython-313.pyc 
b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/artifacts.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c9229e7a919f09f39eaff8b86dd14743cd7fb15 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/artifacts.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/automations.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/automations.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d05e358d9aaa2d82f15372122a590a93bbe10126 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/automations.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/const.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/const.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79a2f6353a05ec066ff8bab09fc1e8d36fee3b44 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/const.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/files.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/files.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d3ad0d3c120347b3de34e61aa23d5e1095a33e5 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/files.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/history.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/history.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78d73cb5ef66d51c2791c150b67a8db82bd861ef Binary files /dev/null and 
b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/history.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/integrations.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/integrations.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bcdb2fd19bb1b5844facc850f6e9bd80c9e028d Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/integrations.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/jobs.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/jobs.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fb5592333277ec7921062d13b0e4554846fc0f0 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/jobs.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/projects.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/projects.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adb43587140a85b02a9b52a7b6c5590918c58067 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/projects.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/query_generator.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/query_generator.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92b2e7fccf942b36baa23c2505bc2b4f66d669f8 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/query_generator.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/reports.cpython-313.pyc 
b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/reports.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4637979514663c94d8629bb229ab31f62445c629 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/reports.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/runs.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/runs.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8283d10b9c647a14a0f4357b5e143014c5f980c3 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/runs.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/sweeps.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/sweeps.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4730db98e80f75dce3b9b3c99e20898bddb7c494 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/sweeps.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/teams.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/teams.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1fc4c22ededc7c175cfe291609c55da72738e72 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/teams.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/users.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/users.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2f88671c24934d8f5293af1c49a2d290612e6c2 Binary files /dev/null and 
b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/users.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/utils.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..350c17cd9f41d181571a8b5a507bd8d41a5a9ef9 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/__pycache__/utils.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/api.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/api.py new file mode 100644 index 0000000000000000000000000000000000000000..a3964e1dc334b05632c1ad5cfd65405f1ff79529 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/api.py @@ -0,0 +1,2459 @@ +"""Use the Public API to export or update data that you have saved to W&B. + +Before using this API, you'll want to log data from your script — check the +[Quickstart](https://docs.wandb.ai/quickstart) for more details. + +You might use the Public API to + - update metadata or metrics for an experiment after it has been completed, + - pull down your results as a dataframe for post-hoc analysis in a Jupyter notebook, or + - check your saved model artifacts for those tagged as `ready-to-deploy`. + +For more on using the Public API, check out [our guide](https://docs.wandb.com/guides/track/public-api-guide). 
+""" + +from __future__ import annotations + +import json +import logging +import os +import urllib +from http import HTTPStatus +from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal + +from pydantic import ValidationError +from typing_extensions import Unpack, overload +from wandb_gql import Client, gql +from wandb_gql.client import RetryError + +import wandb +from wandb import env, util +from wandb._analytics import tracked +from wandb._iterutils import one +from wandb._strutils import nameof +from wandb.apis import public +from wandb.apis.normalize import normalize_exceptions +from wandb.apis.public.const import RETRY_TIMEDELTA +from wandb.apis.public.registries import Registries, Registry +from wandb.apis.public.registries._utils import fetch_org_entity_from_organization +from wandb.apis.public.utils import ( + PathType, + fetch_org_from_settings_or_entity, + gql_compat, + parse_org_from_registry_path, +) +from wandb.errors import UsageError +from wandb.proto import wandb_internal_pb2 as pb +from wandb.proto.wandb_api_pb2 import ApiRequest, ApiResponse +from wandb.proto.wandb_telemetry_pb2 import Deprecated +from wandb.sdk import wandb_login, wandb_setup +from wandb.sdk.artifacts._gqlutils import resolve_org_entity_name, server_supports +from wandb.sdk.internal.internal_api import Api as InternalApi +from wandb.sdk.launch.utils import LAUNCH_DEFAULT_PROJECT +from wandb.sdk.lib import retry, runid, wbauth +from wandb.sdk.lib.deprecation import warn_and_record_deprecation +from wandb.sdk.lib.gql_request import GraphQLSession + +if TYPE_CHECKING: + from wandb.automations import ( + ActionType, + Automation, + EventType, + Integration, + NewAutomation, + SlackIntegration, + WebhookIntegration, + ) + from wandb.automations._utils import WriteAutomationsKwargs + from wandb.sdk.artifacts.artifact import Artifact + + from .artifacts import ( + ArtifactCollection, + ArtifactCollections, + Artifacts, + ArtifactType, + ArtifactTypes, + ) + from .teams import 
class RetryingClient:
    """A GraphQL client that retries requests on failure.

    Wraps a `Client` so that `execute` goes through `retry.retriable`. The
    retrying callable is built on the first `execute` call and then reused.
    """

    INFO_QUERY = gql(
        """
        query ServerInfo{
            serverInfo {
                cliVersionInfo
                latestLocalVersionInfo {
                    outOfDate
                    latestVersionString
                    versionOnThisInstanceString
                }
            }
        }
        """
    )

    def __init__(self, client: Client):
        self._server_info = None
        self._client = client
        # Built lazily by `execute`; see `_build_execute_wrapper`.
        self._execute_decorated: Callable[..., Any] | None = None

    def execute(self, *args, **kwargs):
        """Run a request through the retrying wrapper, creating it if needed."""
        runner = self._execute_decorated
        if runner is None:
            runner = self._execute_decorated = self._build_execute_wrapper()
        return runner(*args, **kwargs)

    def _build_execute_wrapper(self) -> Callable[..., Any]:
        """Return `self._client.execute` wrapped with the retry policy."""
        import requests

        with_retries = retry.retriable(
            retry_timedelta=RETRY_TIMEDELTA,
            check_retry_fn=util.no_retry_auth,
            retryable_exceptions=(RetryError, requests.RequestException),
        )

        def _wrapped(*args, **kwargs):
            try:
                return self._client.execute(*args, **kwargs)
            except requests.exceptions.ReadTimeout:
                # Only hint at the API-level timeout when the caller did not
                # pass an explicit per-request one.
                if "timeout" not in kwargs:
                    timeout = self._client.transport.default_timeout
                    wandb.termwarn(
                        f"A graphql request initiated by the public wandb API timed out (timeout={timeout} sec). "
                        f"Create a new API with an integer timeout larger than {timeout}, e.g., "
                        f"`api = wandb.Api(timeout={timeout + 10})` to increase the graphql timeout."
                    )
                raise

        return with_retries(_wrapped)

    @property
    def app_url(self):
        """App URL derived from the transport URL with "/graphql" removed."""
        graphql_url = self._client.transport.url
        return util.app_url(graphql_url.replace("/graphql", "")) + "/"

    @property
    def server_info(self):
        """The `serverInfo` payload, requested once and cached while not None."""
        cached = self._server_info
        if cached is None:
            cached = self._server_info = self.execute(self.INFO_QUERY).get(
                "serverInfo"
            )
        return cached

    def version_supported(
        self, min_version: str
    ) -> bool:  # User not encouraged to use this class directly
        """Return whether `min_version` is at most the server's `max_cli_version`."""
        from packaging.version import parse

        wanted = parse(min_version)
        ceiling = parse(self.server_info["cliVersionInfo"]["max_cli_version"])
        return wanted <= ceiling
+ """ + self.settings = InternalApi().settings() + self.settings.update(overrides or {}) + self.settings["base_url"] = self.settings["base_url"].rstrip("/") + + if api_key: + self.api_key = api_key + else: + self.api_key = self._load_api_key( + base_url=self.settings["base_url"], + ) + + wandb_login._verify_login( + key=self.api_key, + base_url=self.settings["base_url"], + ) + + self._viewer = None + self._projects = {} + self._runs = {} + self._sweeps = {} + self._reports = {} + self._default_entity = None + self._timeout = timeout if timeout is not None else self._HTTP_TIMEOUT + proxies = self.settings.get("_proxies") or json.loads( + os.environ.get("WANDB__PROXIES", "{}") + ) + self._base_client = Client( + transport=GraphQLSession( + headers={ + "User-Agent": self.user_agent, + "Use-Admin-Privileges": "true", + }, + use_json=True, + # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s + # https://bugs.python.org/issue22889 + timeout=self._timeout, + auth=("api", self.api_key), + url="{}/graphql".format(self.settings["base_url"]), + proxies=proxies, + ) + ) + self._client = RetryingClient(self._base_client) + self._sentry = wandb.analytics.sentry.Sentry(pid=os.getpid()) + self._configure_sentry() + + self._backend: wandb.sdk.backend.backend.Backend | None = None + self._service = None + + def _start_backend_service(self): + """Starts the backend service and initializes resources to enable handling API requests.""" + from wandb.sdk import wandb_setup + + self._stream_id = str(runid.generate_id()) + singleton = wandb_setup.singleton() + self._settings = singleton.settings.model_copy() + self._settings.base_url = self.settings["base_url"] + self._settings.silent = True + + self._service = singleton.ensure_service() + self._service.api_init_request(self._settings.to_proto()) + + def _load_api_key(self, base_url: str) -> str: + """Load or prompt for an API key.""" + auth = wbauth.authenticate_session( + host=base_url, + 
source="wandb.Api()", + no_offline=True, + input_timeout=wandb_setup.singleton().settings.login_timeout, + ) + + if not auth: + raise UsageError("No API key configured. Use `wandb login` to log in.") + if not isinstance(auth, wbauth.AuthApiKey): + message = ( + "wandb.Api() can only use API key authentication, but you have" + " another form of credentials configured." + " Check if you have set WANDB_IDENTITY_TOKEN_FILE." + f" Current credentials: {auth}" + ) + raise UsageError(message) + + return auth.api_key + + def _configure_sentry(self) -> None: + if not env.error_reporting_enabled(): + return + + import requests + + try: + viewer = self.viewer + except (ValueError, requests.RequestException): + # we need the viewer to configure the entity, and user email + return + + email = viewer.email if viewer else None + entity = self.default_entity + + self._sentry.configure_scope( + tags={ + "entity": entity, + "email": email, + }, + ) + + def _send_api_request( + self, + request: ApiRequest, + timeout: float | None = None, + ) -> ApiResponse: + """Sends an API request to the backend service. + + Creates the backend service attribute if it has not been created yet. + + TODO: remove this helper function once all requests are routed through wandb-core. + The backend service should be created and initalized + during the instantiation of the Api object. + """ + if self._service is None: + self._start_backend_service() + + assert self._service is not None + return self._service.api_request(request, timeout=timeout) + + def create_project(self, name: str, entity: str) -> None: + """Create a new project. + + Args: + name: The name of the new project. + entity: The entity of the new project. 
+ """ + from wandb.apis._generated import CREATE_PROJECT_GQL, UpsertModelInput + + gql_input = UpsertModelInput(name=name, entity_name=entity) + self.client.execute(gql(CREATE_PROJECT_GQL), {"input": gql_input.model_dump()}) + + def create_run( + self, + *, + run_id: str | None = None, + project: str | None = None, + entity: str | None = None, + ) -> public.Run: + """Create a new run. + + Args: + run_id: The ID to assign to the run. If not specified, W&B + creates a random ID. + project: The project where to log the run to. If no project is specified, + log the run to a project called "Uncategorized". + entity: The entity that owns the project. If no entity is + specified, log the run to the default entity. + + Returns: + The newly created `Run`. + """ + if entity is None: + entity = self.default_entity + return public.Run.create(self, run_id=run_id, project=project, entity=entity) + + def create_run_queue( + self, + name: str, + type: public.RunQueueResourceType, + entity: str | None = None, + prioritization_mode: public.RunQueuePrioritizationMode | None = None, + config: dict | None = None, + template_variables: dict | None = None, + ) -> public.RunQueue: + """Create a new run queue in W&B Launch. + + Args: + name: Name of the queue to create + type: Type of resource to be used for the queue. One of + "local-container", "local-process", "kubernetes","sagemaker", + or "gcp-vertex". + entity: Name of the entity to create the queue. If `None`, use + the configured or default entity. + prioritization_mode: Version of prioritization to use. + Either "V0" or `None`. + config: Default resource configuration to be used for the queue. + Use handlebars (eg. `{{var}}`) to specify template variables. + template_variables: A dictionary of template variable schemas to + use with the config. + + Returns: + The newly created `RunQueue`. 
+ + Raises: + `ValueError` if any of the parameters are invalid + `wandb.Error` on wandb API errors + """ + # TODO(np): Need to check server capabilities for this feature + # 0. assert params are valid/normalized + if entity is None: + entity = self.settings["entity"] or self.default_entity + if entity is None: + raise ValueError( + "entity must be passed as a parameter, or set in settings" + ) + + if len(name) == 0: + raise ValueError("name must be non-empty") + if len(name) > 64: + raise ValueError("name must be less than 64 characters") + + if type not in [ + "local-container", + "local-process", + "kubernetes", + "sagemaker", + "gcp-vertex", + ]: + raise ValueError( + "resource_type must be one of 'local-container', 'local-process', 'kubernetes', 'sagemaker', or 'gcp-vertex'" + ) + + if prioritization_mode: + prioritization_mode = prioritization_mode.upper() + if prioritization_mode not in ["V0"]: + raise ValueError("prioritization_mode must be 'V0' if present") + + if config is None: + config = {} + + # 1. create required default launch project in the entity + self.create_project(LAUNCH_DEFAULT_PROJECT, entity) + + api = InternalApi( + default_settings={ + "entity": entity, + "project": self.project(LAUNCH_DEFAULT_PROJECT), + }, + retry_timedelta=RETRY_TIMEDELTA, + ) + + # 2. create default resource config, receive config id + config_json = json.dumps({"resource_args": {type: config}}) + create_config_result = api.create_default_resource_config( + entity, type, config_json, template_variables + ) + if not create_config_result["success"]: + raise wandb.Error("failed to create default resource config") + config_id = create_config_result["defaultResourceConfigID"] + + # 3. 
create run queue + create_queue_result = api.create_run_queue( + entity, + LAUNCH_DEFAULT_PROJECT, + name, + "PROJECT", + prioritization_mode, + config_id, + ) + if not create_queue_result["success"]: + raise wandb.Error("failed to create run queue") + + return public.RunQueue( + client=self.client, + name=name, + entity=entity, + prioritization_mode=prioritization_mode, + _access="PROJECT", + _default_resource_config_id=config_id, + _default_resource_config=config, + ) + + def create_custom_chart( + self, + entity: str, + name: str, + display_name: str, + spec_type: Literal["vega2"], + access: Literal["private", "public"], + spec: str | dict, + ) -> str: + """Create a custom chart preset and return its id. + + Args: + entity: The entity (user or team) that owns the chart + name: Unique identifier for the chart preset + display_name: Human-readable name shown in the UI + spec_type: Type of specification. Must be "vega2" for Vega-Lite v2 specifications. + access: Access level for the chart: + - "private": Chart is only accessible to the entity that created it + - "public": Chart is publicly accessible + spec: The Vega/Vega-Lite specification as a dictionary or JSON string + + Returns: + The ID of the created chart preset in the format "entity/name" + + Raises: + wandb.Error: If chart creation fails + UnsupportedError: If the server doesn't support custom charts + + Example: + ```python + import wandb + + api = wandb.Api() + + # Define a simple bar chart specification + vega_spec = { + "$schema": "https://vega.github.io/schema/vega-lite/v6.json", + "mark": "bar", + "data": {"name": "wandb"}, + "encoding": { + "x": {"field": "${field:x}", "type": "ordinal"}, + "y": {"field": "${field:y}", "type": "quantitative"}, + }, + } + + # Create the custom chart + chart_id = api.create_custom_chart( + entity="my-team", + name="my-bar-chart", + display_name="My Custom Bar Chart", + spec_type="vega2", + access="private", + spec=vega_spec, + ) + + # Use with wandb.plot_table() + 
chart = wandb.plot_table( + vega_spec_name=chart_id, + data_table=my_table, + fields={"x": "category", "y": "value"}, + ) + ``` + """ + # Convert user-facing lowercase access to backend uppercase + backend_access = access.upper() + + api = InternalApi(retry_timedelta=RETRY_TIMEDELTA) + result = api.create_custom_chart( + entity=entity, + name=name, + display_name=display_name, + spec_type=spec_type, + access=backend_access, + spec=spec, + ) + if result is None or result.get("chart") is None: + raise wandb.Error("failed to create custom chart") + return result["chart"]["id"] + + def upsert_run_queue( + self, + name: str, + resource_config: dict, + resource_type: public.RunQueueResourceType, + entity: str | None = None, + template_variables: dict | None = None, + external_links: dict | None = None, + prioritization_mode: public.RunQueuePrioritizationMode | None = None, + ): + """Upsert a run queue in W&B Launch. + + Args: + name: Name of the queue to create + entity: Optional name of the entity to create the queue. If `None`, + use the configured or default entity. + resource_config: Optional default resource configuration to be used + for the queue. Use handlebars (eg. `{{var}}`) to specify + template variables. + resource_type: Type of resource to be used for the queue. One of + "local-container", "local-process", "kubernetes", "sagemaker", + or "gcp-vertex". + template_variables: A dictionary of template variable schemas to + be used with the config. + external_links: Optional dictionary of external links to be used + with the queue. + prioritization_mode: Optional version of prioritization to use. + Either "V0" or None + + Returns: + The upserted `RunQueue`. 
+ + Raises: + ValueError if any of the parameters are invalid + wandb.Error on wandb API errors + """ + if entity is None: + entity = self.settings["entity"] or self.default_entity + if entity is None: + raise ValueError( + "entity must be passed as a parameter, or set in settings" + ) + + if len(name) == 0: + raise ValueError("name must be non-empty") + if len(name) > 64: + raise ValueError("name must be less than 64 characters") + + prioritization_mode = prioritization_mode or "DISABLED" + prioritization_mode = prioritization_mode.upper() + if prioritization_mode not in ["V0", "DISABLED"]: + raise ValueError( + "prioritization_mode must be 'V0' or 'DISABLED' if present" + ) + + if resource_type not in [ + "local-container", + "local-process", + "kubernetes", + "sagemaker", + "gcp-vertex", + ]: + raise ValueError( + "resource_type must be one of 'local-container', 'local-process', 'kubernetes', 'sagemaker', or 'gcp-vertex'" + ) + + self.create_project(LAUNCH_DEFAULT_PROJECT, entity) + api = InternalApi( + default_settings={ + "entity": entity, + "project": self.project(LAUNCH_DEFAULT_PROJECT), + }, + retry_timedelta=RETRY_TIMEDELTA, + ) + # User provides external_links as a dict with name: url format + # but backend stores it as a list of dicts with url and label keys. 
+ external_links = external_links or {} + external_links = { + "links": [ + { + "label": key, + "url": value, + } + for key, value in external_links.items() + ] + } + upsert_run_queue_result = api.upsert_run_queue( + name, + entity, + resource_type, + {"resource_args": {resource_type: resource_config}}, + template_variables=template_variables, + external_links=external_links, + prioritization_mode=prioritization_mode, + ) + if not upsert_run_queue_result["success"]: + raise wandb.Error("failed to create run queue") + schema_errors = ( + upsert_run_queue_result.get("configSchemaValidationErrors") or [] + ) + for error in schema_errors: + wandb.termwarn(f"resource config validation: {error}") + + return public.RunQueue( + client=self.client, + name=name, + entity=entity, + ) + + def create_user(self, email: str, admin: bool | None = False) -> User: + """Create a new user. + + Args: + email: The email address of the user. + admin: Set user as a global instance administrator. + + Returns: + A `User` object. 
+ """ + from .users import User + + return User.create(self, email, admin) + + def sync_tensorboard(self, root_dir, run_id=None, project=None, entity=None): + """Sync a local directory containing tfevent files to wandb.""" + from wandb.sync import SyncManager # TODO: circular import madness + + run_id = run_id or runid.generate_id() + project = project or self.settings.get("project") or "uncategorized" + entity = entity or self.default_entity + # TODO: pipe through log_path to inform the user how to debug + sm = SyncManager( + project=project, + entity=entity, + run_id=run_id, + mark_synced=False, + app_url=self.client.app_url, + view=False, + verbose=False, + sync_tensorboard=True, + ) + sm.add(root_dir) + sm.start() + while not sm.is_done(): + _ = sm.poll() + return self.run("/".join([entity, project, run_id])) + + @property + def client(self) -> RetryingClient: + """Returns the client object.""" + return self._client + + @property + def user_agent(self) -> str: + """Returns W&B public user agent.""" + return "W&B Public Client {}".format(wandb.__version__) + + @property + def default_entity(self) -> str | None: + """Returns the default W&B entity.""" + from wandb.apis._generated import GET_DEFAULT_ENTITY_GQL, GetDefaultEntity + + if self._default_entity is None: + data = self._client.execute(gql(GET_DEFAULT_ENTITY_GQL)) + result = GetDefaultEntity.model_validate(data) + if (viewer := result.viewer) and (entity := viewer.entity): + self._default_entity = entity + return self._default_entity + + @property + def viewer(self) -> User: + """Returns the viewer object. + + Raises: + ValueError: If viewer data is not able to be fetched from W&B. + requests.RequestException: If an error occurs while making the graphql request. 
+ """ + from wandb.apis._generated import GET_VIEWER_GQL, GetViewer + + from .users import User + + if self._viewer is None: + data = self._client.execute(gql(GET_VIEWER_GQL)) + result = GetViewer.model_validate(data) + if (viewer := result.viewer) is None: + msg = "Unable to fetch user data from W&B, please verify your API key is valid." + raise ValueError(msg) + self._viewer = User(self._client, viewer.model_dump()) + self._default_entity = self._viewer.entity + return self._viewer + + def flush(self): + """Flush the local cache. + + The api object keeps a local cache of runs, so if the state of the run + may change while executing your script you must clear the local cache + with `api.flush()` to get the latest values associated with the run. + """ + self._runs = {} + + def from_path(self, path: str): + """Return a run, sweep, project or report from a path. + + Args: + path: The path to the project, run, sweep or report + + Returns: + A `Project`, `Run`, `Sweep`, or `BetaReport` instance. + + Raises: + `wandb.Error` if path is invalid or the object doesn't exist. + + Examples: + In the proceeding code snippets "project", "team", "run_id", "sweep_id", + and "report_name" are placeholders for the project, team, run ID, + sweep ID, and the name of a specific report, respectively. 
+ + ```python + import wandb + + api = wandb.Api() + + project = api.from_path("project") + team_project = api.from_path("team/project") + run = api.from_path("team/project/runs/run_id") + sweep = api.from_path("team/project/sweeps/sweep_id") + report = api.from_path("team/project/reports/report_name") + ``` + """ + parts = path.strip("/ ").split("/") + if len(parts) == 1: + return self.project(path) + elif len(parts) == 2: + return self.project(parts[1], parts[0]) + elif len(parts) == 3: + return self.run(path) + elif len(parts) == 4: + if parts[2].startswith("run"): + return self.run(path) + elif parts[2].startswith("sweep"): + return self.sweep(path) + elif parts[2].startswith("report"): + if "--" not in parts[-1]: + if "-" in parts[-1]: + raise wandb.Error( + "Invalid report path, should be team/project/reports/Name--XXXX" + ) + else: + parts[-1] = "--" + parts[-1] + name, id = parts[-1].split("--") + return public.BetaReport( + self.client, + { + "displayName": urllib.parse.unquote(name.replace("-", " ")), + "id": id, + "spec": "{}", + }, + parts[0], + parts[1], + ) + raise wandb.Error( + "Invalid path, should be TEAM/PROJECT/TYPE/ID where TYPE is runs, sweeps, or reports" + ) + + def _parse_project_path(self, path): + """Return project and entity for project specified by path.""" + project = self.settings["project"] or "uncategorized" + entity = self.settings["entity"] or self.default_entity + if path is None: + return entity, project + parts = path.split("/", 1) + if len(parts) == 1: + return entity, path + return parts + + def _parse_path(self, path): + """Parse url, filepath, or docker paths. + + Allows paths in the following formats: + - url: entity/project/runs/id + - path: entity/project/id + - docker: entity/project:id + + Entity is optional and will fall back to the current logged-in user. 
+ """ + project = self.settings["project"] or "uncategorized" + entity = self.settings["entity"] or self.default_entity + parts = ( + path.replace("/runs/", "/").replace("/sweeps/", "/").strip("/ ").split("/") + ) + if ":" in parts[-1]: + id = parts[-1].split(":")[-1] + parts[-1] = parts[-1].split(":")[0] + elif parts[-1]: + id = parts[-1] + if len(parts) == 1 and project != "uncategorized": + pass + elif len(parts) > 1: + project = parts[1] + if entity and id == project: + project = parts[0] + else: + entity = parts[0] + if len(parts) == 3: + entity = parts[0] + else: + project = parts[0] + return entity, project, id + + @overload + def _parse_artifact_path(self, path: None) -> tuple[str | None, str]: ... + @overload + def _parse_artifact_path(self, path: str) -> tuple[str | None, str, str]: ... + + def _parse_artifact_path(self, path: str | None) -> tuple[str | None, ...]: + """Return project, entity and artifact name for project specified by path.""" + from wandb.sdk.artifacts._validators import ArtifactPath + + project = self.settings["project"] or "uncategorized" + entity = self.settings["entity"] or self.default_entity + if path is None: + return entity, project + + parsed = ArtifactPath.from_str(path) + parsed = parsed.with_defaults(prefix=entity, project=project) + return parsed.prefix, parsed.project, parsed.name + + def projects( + self, entity: str | None = None, per_page: int = 200 + ) -> public.Projects: + """Get projects for a given entity. + + Args: + entity: Name of the entity requested. If None, will fall back to + the default entity passed to `Api`. If no default entity, + will raise a `ValueError`. + per_page: Sets the page size for query pagination. + Usually there is no reason to change this. + + Returns: + A `Projects` object which is an iterable collection of `Project`objects. 
+ """ + if entity is None: + entity = self.settings["entity"] or self.default_entity + if entity is None: + raise ValueError( + "entity must be passed as a parameter, or set in settings" + ) + if entity not in self._projects: + self._projects[entity] = public.Projects( + self.client, entity, per_page=per_page + ) + return self._projects[entity] + + def project(self, name: str, entity: str | None = None) -> public.Project: + """Return the `Project` with the given name (and entity, if given). + + Args: + name: The project name. + entity: Name of the entity requested. If None, will fall back to the + default entity passed to `Api`. If no default entity, will + raise a `ValueError`. + + Returns: + A `Project` object. + """ + from wandb.sdk.artifacts._validators import is_artifact_registry_project + + # For registry artifacts, capture potential org user inputted before resolving entity + org = entity if is_artifact_registry_project(name) else "" + + if entity is None: + entity = self.settings["entity"] or self.default_entity + + # For registry artifacts, resolve org-based entity + if is_artifact_registry_project(name): + settings_entity = self.settings["entity"] or self.default_entity + entity = resolve_org_entity_name( + self.client, non_org_entity=settings_entity, org_or_entity=org + ) + return public.Project(self.client, entity, name, {}) + + def reports( + self, path: str = "", name: str | None = None, per_page: int = 50 + ) -> public.Reports: + """Get reports for a given project path. + + Note: `wandb.Api.reports()` API is in beta and will likely change in + future releases. + + Args: + path: The path to the project the report resides in. Specify the + entity that created the project as a prefix followed by a + forward slash. + name: Name of the report requested. + per_page: Sets the page size for query pagination. + Usually there is no reason to change this. + + Returns: + A `Reports` object which is an iterable collection of + `BetaReport` objects. 
+ + Examples: + ```python + import wandb + + wandb.Api.reports("entity/project") + ``` + """ + entity, project, _ = self._parse_path(path + "/fake_run") + + if name: + name = urllib.parse.unquote(name) + key = "/".join([entity, project, str(name)]) + else: + key = "/".join([entity, project]) + + if key not in self._reports: + self._reports[key] = public.Reports( + self.client, + public.Project(self.client, entity, project, {}), + name=name, + per_page=per_page, + ) + return self._reports[key] + + def create_team(self, team: str, admin_username: str | None = None) -> Team: + """Create a new team. + + Args: + team: The name of the team + admin_username: Username of the admin user of the team. + Defaults to the current user. + + Returns: + A `Team` object. + """ + from .teams import Team + + return Team.create(self, team, admin_username) + + def team(self, team: str) -> Team: + """Return the matching `Team` with the given name. + + Args: + team: The name of the team. + + Returns: + A `Team` object. + """ + from .teams import Team + + return Team(self.client, team) + + def user(self, username_or_email: str) -> User | None: + """Return a user from a username or email address. + + This function only works for local administrators. Use `api.viewer` + to get your own user object. + + Args: + username_or_email: The username or email address of the user. + + Returns: + A `User` object or None if a user is not found. 
+ """ + from wandb.apis._generated import SEARCH_USERS_GQL, SearchUsers + + from .users import User + + data = self._client.execute(gql(SEARCH_USERS_GQL), {"query": username_or_email}) + result = SearchUsers.model_validate(data) + if not (conn := result.users) or not (edges := conn.edges): + return None + if len(edges) > 1: + msg = f"Found multiple users, returning the first user matching {username_or_email!r}" + wandb.termwarn(msg) + return User(self._client, edges[0].node.model_dump()) + + def users(self, username_or_email: str) -> list[User]: + """Return all users from a partial username or email address query. + + This function only works for local administrators. Use `api.viewer` + to get your own user object. + + Args: + username_or_email: The prefix or suffix of the user you want to find. + + Returns: + An array of `User` objects. + """ + from wandb.apis._generated import SEARCH_USERS_GQL, SearchUsers + + from .users import User + + data = self._client.execute(gql(SEARCH_USERS_GQL), {"query": username_or_email}) + result = SearchUsers.model_validate(data) + if not ((conn := result.users) and (edges := conn.edges)): + return [] + return [User(self._client, edge.node.model_dump()) for edge in edges] + + def runs( + self, + path: str | None = None, + filters: dict[str, Any] | None = None, + order: str = "+created_at", + per_page: int = 50, + include_sweeps: bool = True, + lazy: bool = True, + ): + """Returns a `Runs` object, which lazily iterates over `Run` objects. + + Fields you can filter by include: + - `createdAt`: The timestamp when the run was created. (in ISO 8601 format, e.g. "2023-01-01T12:00:00Z") + - `displayName`: The human-readable display name of the run. (e.g. "eager-fox-1") + - `duration`: The total runtime of the run in seconds. + - `group`: The group name used to organize related runs together. + - `host`: The hostname where the run was executed. + - `jobType`: The type of job or purpose of the run. 
+ - `name`: The unique identifier of the run. (e.g. "a1b2cdef") + - `state`: The current state of the run. + - `tags`: The tags associated with the run. + - `username`: The username of the user who initiated the run + + Additionally, you can filter by items in the run config or summary metrics. + Such as `config.experiment_name`, `summary_metrics.loss`, etc. + + For more complex filtering, you can use MongoDB query operators. + For details, see: https://docs.mongodb.com/manual/reference/operator/query + The following operations are supported: + - `$and` + - `$or` + - `$nor` + - `$eq` + - `$ne` + - `$gt` + - `$gte` + - `$lt` + - `$lte` + - `$in` + - `$nin` + - `$exists` + - `$regex` + + + + Args: + path: (str) path to project, should be in the form: "entity/project" + filters: (dict) queries for specific runs using the MongoDB query language. + You can filter by run properties such as config.key, summary_metrics.key, state, entity, createdAt, etc. + For example: `{"config.experiment_name": "foo"}` would find runs with a config entry + of experiment name set to "foo" + order: (str) Order can be `created_at`, `heartbeat_at`, `config.*.value`, or `summary_metrics.*`. + If you prepend order with a + order is ascending (default). + If you prepend order with a - order is descending. + The default order is run.created_at from oldest to newest. + per_page: (int) Sets the page size for query pagination. + include_sweeps: (bool) Whether to include the sweep runs in the results. + lazy: (bool) Whether to use lazy loading for faster performance. + When True (default), only essential run metadata is loaded initially. + Heavy fields like config, summaryMetrics, and systemMetrics are loaded + on-demand when accessed. Set to False for full data upfront. + + Returns: + A `Runs` object, which is an iterable collection of `Run` objects. 
+ + Examples: + ```python + import wandb + from wandb.apis.public import Api + + # Find runs in project where config.experiment_name has been set to "foo" + Api.runs(path="my_entity/project", filters={"config.experiment_name": "foo"}) + ``` + + ```python + # Find runs in project where config.experiment_name has been set to "foo" or "bar" + Api.runs( + path="my_entity/project", + filters={ + "$or": [ + {"config.experiment_name": "foo"}, + {"config.experiment_name": "bar"}, + ] + }, + ) + ``` + + ```python + # Find runs in project where config.experiment_name matches a regex + # (anchors are not supported) + Api.runs( + path="my_entity/project", + filters={"config.experiment_name": {"$regex": "b.*"}}, + ) + ``` + + ```python + # Find runs in project where the run name matches a regex + # (anchors are not supported) + Api.runs( + path="my_entity/project", filters={"display_name": {"$regex": "^foo.*"}} + ) + ``` + + ```python + # Find runs in project sorted by ascending loss + Api.runs(path="my_entity/project", order="+summary_metrics.loss") + ``` + """ + entity, project = self._parse_project_path(path) + filters = filters or {} + key = (path or "") + str(filters) + str(order) + + # Check if we have cached results + if self._runs.get(key): + cached_runs = self._runs[key] + # If requesting full data but cached data is lazy, upgrade it + if not lazy and cached_runs._lazy: + cached_runs.upgrade_to_full() + return cached_runs + + # Create new Runs object + self._runs[key] = public.Runs( + self.client, + entity, + project, + api=self, + filters=filters, + order=order, + per_page=per_page, + include_sweeps=include_sweeps, + lazy=lazy, + ) + return self._runs[key] + + @normalize_exceptions + def run(self, path=""): + """Return a single run by parsing path in the form `entity/project/run_id`. + + Args: + path: Path to run in the form `entity/project/run_id`. 
+ If `api.entity` is set, this can be in the form `project/run_id` + and if `api.project` is set this can just be the run_id. + + Returns: + A `Run` object. + """ + entity, project, run_id = self._parse_path(path) + if not self._runs.get(path): + # Individual runs should load full data by default + self._runs[path] = public.Run( + self.client, + entity, + project, + run_id, + api=self, + lazy=False, + ) + return self._runs[path] + + def queued_run( + self, + entity: str, + project: str, + queue_name: str, + run_queue_item_id: str, + project_queue=None, + priority=None, + ): + """Return a single queued run based on the path. + + Parses paths of the form `entity/project/queue_id/run_queue_item_id`. + """ + return public.QueuedRun( + self.client, + entity, + project, + queue_name, + run_queue_item_id, + project_queue=project_queue, + priority=priority, + ) + + def run_queue( + self, + entity: str, + name: str, + ): + """Return the named `RunQueue` for entity. + + See `Api.create_run_queue` for more information on how to create a run queue. + """ + return public.RunQueue( + self.client, + name, + entity, + ) + + @normalize_exceptions + def sweep(self, path=""): + """Return a sweep by parsing path in the form `entity/project/sweep_id`. + + Args: + path: Path to sweep in the form entity/project/sweep_id. + If `api.entity` is set, this can be in the form + project/sweep_id and if `api.project` is set + this can just be the sweep_id. + + Returns: + A `Sweep` object. + """ + entity, project, sweep_id = self._parse_path(path) + if not self._sweeps.get(path): + self._sweeps[path] = public.Sweep(self.client, entity, project, sweep_id) + return self._sweeps[path] + + @normalize_exceptions + def artifact_types(self, project: str | None = None) -> ArtifactTypes: + """Returns a collection of matching artifact types. + + Args: + project: The project name or path to filter on. + + Returns: + An iterable `ArtifactTypes` object. 
+ """ + from wandb.sdk.artifacts._validators import is_artifact_registry_project + + from .artifacts import ArtifactTypes + + project_path = project + entity, project = self._parse_project_path(project_path) + # If its a Registry project, the entity is considered to be an org instead + if is_artifact_registry_project(project): + settings_entity = self.settings["entity"] or self.default_entity + org = parse_org_from_registry_path(project_path, PathType.PROJECT) + entity = resolve_org_entity_name( + self.client, non_org_entity=settings_entity, org_or_entity=org + ) + return ArtifactTypes(self.client, entity, project) + + @normalize_exceptions + def artifact_type(self, type_name: str, project: str | None = None) -> ArtifactType: + """Returns the matching `ArtifactType`. + + Args: + type_name: The name of the artifact type to retrieve. + project: If given, a project name or path to filter on. + + Returns: + An `ArtifactType` object. + """ + from wandb.sdk.artifacts._validators import is_artifact_registry_project + + from .artifacts import ArtifactType + + project_path = project + entity, project = self._parse_project_path(project_path) + # If its an Registry artifact, the entity is an org instead + if is_artifact_registry_project(project): + org = parse_org_from_registry_path(project_path, PathType.PROJECT) + settings_entity = self.settings["entity"] or self.default_entity + entity = resolve_org_entity_name( + self.client, non_org_entity=settings_entity, org_or_entity=org + ) + return ArtifactType(self.client, entity, project, type_name) + + @normalize_exceptions + def artifact_collections( + self, project_name: str, type_name: str, per_page: int = 50 + ) -> ArtifactCollections: + """Returns a collection of matching artifact collections. + + Args: + project_name: The name of the project to filter on. + type_name: The name of the artifact type to filter on. + per_page: Sets the page size for query pagination. + Usually there is no reason to change this. 
+ + Returns: + An iterable `ArtifactCollections` object. + """ + from wandb.sdk.artifacts._validators import is_artifact_registry_project + + from .artifacts import ArtifactCollections + + entity, project = self._parse_project_path(project_name) + # If iterating through Registry project, the entity is considered to be an org instead + if is_artifact_registry_project(project): + org = parse_org_from_registry_path(project_name, PathType.PROJECT) + settings_entity = self.settings["entity"] or self.default_entity + entity = resolve_org_entity_name( + self.client, non_org_entity=settings_entity, org_or_entity=org + ) + return ArtifactCollections( + self.client, entity, project, type_name, per_page=per_page + ) + + @normalize_exceptions + def artifact_collection(self, type_name: str, name: str) -> ArtifactCollection: + """Returns a single artifact collection by type. + + You can use the returned `ArtifactCollection` object to retrieve + information about specific artifacts in that collection, and more. + + Args: + type_name: The type of artifact collection to fetch. + name: An artifact collection name. Optionally append the entity + that logged the artifact as a prefix followed by a forward + slash. + + Returns: + An `ArtifactCollection` object. + + Examples: + In the proceeding code snippet "type", "entity", "project", and + "artifact_name" are placeholders for the collection type, your W&B + entity, name of the project the artifact is in, and the name of + the artifact, respectively. + + ```python + import wandb + + collections = wandb.Api().artifact_collection( + type_name="type", name="entity/project/artifact_name" + ) + + # Get the first artifact in the collection + artifact_example = collections.artifacts()[0] + + # Download the contents of the artifact to the specified root directory. 
+ artifact_example.download() + ``` + """ + from wandb.sdk.artifacts._validators import is_artifact_registry_project + + from .artifacts import ArtifactCollection + + entity, project, collection_name = self._parse_artifact_path(name) + # If its an Registry artifact, the entity is considered to be an org instead + if is_artifact_registry_project(project): + org = parse_org_from_registry_path(name, PathType.ARTIFACT) + settings_entity = self.settings["entity"] or self.default_entity + entity = resolve_org_entity_name( + self.client, non_org_entity=settings_entity, org_or_entity=org + ) + + if entity is None: + raise ValueError( + "Could not determine entity. Please include the entity as part of the collection name path." + ) + + return ArtifactCollection( + self.client, entity, project, collection_name, type_name + ) + + @normalize_exceptions + def artifact_versions(self, type_name, name, per_page=50): + """Deprecated. Use `Api.artifacts(type_name, name)` method instead.""" + warn_and_record_deprecation( + feature=Deprecated(api__artifact_versions=True), + message=( + "Api.artifact_versions(type_name, name) is deprecated, " + "use Api.artifacts(type_name, name) instead." + ), + ) + return self.artifacts(type_name, name, per_page=per_page) + + @normalize_exceptions + def artifacts( + self, + type_name: str, + name: str, + per_page: int = 50, + tags: list[str] | None = None, + ) -> Artifacts: + """Return an `Artifacts` collection. + + Args: + type_name: The type of artifacts to fetch. + name: The artifact's collection name. Optionally append the + entity that logged the artifact as a prefix followed by + a forward slash. + per_page: Sets the page size for query pagination. Usually + there is no reason to change this. + tags: Only return artifacts with all of these tags. + + Returns: + An iterable `Artifacts` object. 
+ + Examples: + In the proceeding code snippet, "type", "entity", "project", and + "artifact_name" are placeholders for the artifact type, W&B entity, + name of the project the artifact was logged to, + and the name of the artifact, respectively. + + ```python + import wandb + + wandb.Api().artifacts(type_name="type", name="entity/project/artifact_name") + ``` + """ + from wandb.sdk.artifacts._validators import is_artifact_registry_project + + from .artifacts import Artifacts + + entity, project, collection_name = self._parse_artifact_path(name) + # If its an Registry project, the entity is considered to be an org instead + if is_artifact_registry_project(project): + org = parse_org_from_registry_path(name, PathType.ARTIFACT) + settings_entity = self.settings["entity"] or self.default_entity + entity = resolve_org_entity_name( + self.client, non_org_entity=settings_entity, org_or_entity=org + ) + return Artifacts( + self.client, + entity, + project, + collection_name, + type_name, + per_page=per_page, + tags=tags, + ) + + @normalize_exceptions + def _artifact( + self, name: str, type: str | None = None, enable_tracking: bool = False + ) -> Artifact: + from wandb.sdk.artifacts._validators import ( + FullArtifactPath, + is_artifact_registry_project, + ) + from wandb.sdk.artifacts.artifact import Artifact + + if name is None: + raise ValueError("You must specify name= to fetch an artifact.") + entity, project, artifact_name = self._parse_artifact_path(name) + + # If its an Registry artifact, the entity is an org instead + if is_artifact_registry_project(project): + organization = ( + name.split("/")[0] + if name.count("/") == 2 + else self.settings["organization"] + ) + # set entity to match the settings since in above code it was potentially set to an org + settings_entity = self.settings["entity"] or self.default_entity + # Registry artifacts are under the org entity. 
Because we offer a shorthand and alias for this path, + # we need to fetch the org entity to for the user behind the scenes. + entity = resolve_org_entity_name( + self.client, non_org_entity=settings_entity, org_or_entity=organization + ) + + if entity is None: + raise ValueError( + "Could not determine entity. Please include the entity as part of the artifact name path." + ) + + path = FullArtifactPath(prefix=entity, project=project, name=artifact_name) + artifact = Artifact._from_name( + path=path, + client=self.client, + enable_tracking=enable_tracking, + ) + if type is not None and artifact.type != type: + raise ValueError( + f"type {type} specified but this artifact is of type {artifact.type}" + ) + return artifact + + @normalize_exceptions + def artifact(self, name: str, type: str | None = None): + """Returns a single artifact. + + Args: + name: The artifact's name. The name of an artifact resembles a + filepath that consists, at a minimum, the name of the project + the artifact was logged to, the name of the artifact, and the + artifact's version or alias. Optionally append the entity that + logged the artifact as a prefix followed by a forward slash. + If no entity is specified in the name, the Run or API + setting's entity is used. + type: The type of artifact to fetch. + + Returns: + An `Artifact` object. + + Raises: + ValueError: If the artifact name is not specified. + ValueError: If the artifact type is specified but does not + match the type of the fetched artifact. + + Examples: + In the proceeding code snippets "entity", "project", "artifact", + "version", and "alias" are placeholders for your W&B entity, name + of the project the artifact is in, the name of the artifact, + and artifact's version, respectively. 
+ + ```python + import wandb + + # Specify the project, artifact's name, and the artifact's alias + wandb.Api().artifact(name="project/artifact:alias") + + # Specify the project, artifact's name, and a specific artifact version + wandb.Api().artifact(name="project/artifact:version") + + # Specify the entity, project, artifact's name, and the artifact's alias + wandb.Api().artifact(name="entity/project/artifact:alias") + + # Specify the entity, project, artifact's name, and a specific artifact version + wandb.Api().artifact(name="entity/project/artifact:version") + ``` + + Note: + This method is intended for external use only. Do not call `api.artifact()` within the wandb repository code. + """ + return self._artifact(name=name, type=type, enable_tracking=True) + + @normalize_exceptions + def job(self, name: str | None, path: str | None = None) -> public.Job: + """Return a `Job` object. + + Args: + name: The name of the job. + path: The root path to download the job artifact. + + Returns: + A `Job` object. + """ + if name is None: + raise ValueError("You must specify name= to fetch a job.") + elif name.count("/") != 2 or ":" not in name: + raise ValueError( + "Invalid job specification. A job must be of the form: //:" + ) + return public.Job(self, name, path) + + @normalize_exceptions + def list_jobs(self, entity: str, project: str) -> list[dict[str, Any]]: + """Return a list of jobs, if any, for the given entity and project. + + Args: + entity: The entity for the listed jobs. + project: The project for the listed jobs. + + Returns: + A list of matching jobs. 
+ """ + import requests + + if entity is None: + raise ValueError("Specify an entity when listing jobs") + if project is None: + raise ValueError("Specify a project when listing jobs") + + query = gql( + """ + query ArtifactOfType( + $entityName: String!, + $projectName: String!, + $artifactTypeName: String!, + ) { + project(name: $projectName, entityName: $entityName) { + artifactType(name: $artifactTypeName) { + artifactCollections { + edges { + node { + artifacts { + edges { + node { + id + state + aliases { + alias + } + artifactSequence { + name + } + } + } + } + } + } + } + } + } + } + """ + ) + + try: + artifact_query = self._client.execute( + query, + { + "projectName": project, + "entityName": entity, + "artifactTypeName": "job", + }, + ) + + if not artifact_query or not artifact_query["project"]: + wandb.termerror( + f"Project: '{project}' not found in entity: '{entity}' or access denied." + ) + return [] + + if artifact_query["project"]["artifactType"] is None: + return [] + + artifacts = artifact_query["project"]["artifactType"][ + "artifactCollections" + ]["edges"] + + return [x["node"]["artifacts"] for x in artifacts] + except requests.exceptions.HTTPError: + return False + + @normalize_exceptions + def artifact_exists(self, name: str, type: str | None = None) -> bool: + """Whether an artifact version exists within the specified project and entity. + + Args: + name: The name of artifact. Add the artifact's entity and project + as a prefix. Append the version or the alias of the artifact + with a colon. If the entity or project is not specified, + W&B uses override parameters if populated. Otherwise, the + entity is pulled from the user settings and the project is + set to "Uncategorized". + type: The type of artifact. + + Returns: + True if the artifact version exists, False otherwise. 
+ + Examples: + In the proceeding code snippets "entity", "project", "artifact", + "version", and "alias" are placeholders for your W&B entity, name of + the project the artifact is in, the name of the artifact, and + artifact's version, respectively. + + ```python + import wandb + + wandb.Api().artifact_exists("entity/project/artifact:version") + wandb.Api().artifact_exists("entity/project/artifact:alias") + ``` + + """ + import requests + + try: + self._artifact(name, type) + except wandb.errors.CommError as e: + if isinstance(e.exc, requests.Timeout): + raise + return False + return True + + @normalize_exceptions + def artifact_collection_exists(self, name: str, type: str) -> bool: + """Whether an artifact collection exists within a specified project and entity. + + Args: + name: An artifact collection name. Optionally append the + entity that logged the artifact as a prefix followed by + a forward slash. If entity or project is not specified, + infer the collection from the override params if they exist. + Otherwise, entity is pulled from the user settings and project + will default to "uncategorized". + type: The type of artifact collection. + + Returns: + True if the artifact collection exists, False otherwise. + + Examples: + In the proceeding code snippet "type", and "collection_name" refer to the type + of the artifact collection and the name of the collection, respectively. + + ```python + import wandb + + wandb.Api.artifact_collection_exists(type="type", name="collection_name") + ``` + """ + import requests + + try: + self.artifact_collection(type, name) + except wandb.errors.CommError as e: + if isinstance(e.exc, requests.Timeout): + raise + return False + return True + + @tracked + def registries( + self, + organization: str | None = None, + filter: dict[str, Any] | None = None, + per_page: int = 100, + ) -> Registries: + """Returns a lazy iterator of `Registry` objects. 
+ + Use the iterator to search and filter registries, collections, + or artifact versions across your organization's registry. + + Args: + organization: (str, optional) The organization of the registry to fetch. + If not specified, use the organization specified in the user's settings. + filter: (dict, optional) MongoDB-style filter to apply to each object in the lazy registry iterator. + Fields available to filter for registries are + `name`, `description`, `created_at`, `updated_at`. + Fields available to filter for collections are + `name`, `tag`, `description`, `created_at`, `updated_at` + Fields available to filter for versions are + `tag`, `alias`, `created_at`, `updated_at`, `metadata` + per_page: Sets the page size for query pagination. + + Returns: + A lazy iterator of `Registry` objects. + + Examples: + Find all registries with the names that contain "model" + + ```python + import wandb + + api = wandb.Api() # specify an org if your entity belongs to multiple orgs + api.registries(filter={"name": {"$regex": "model"}}) + ``` + + Find all collections in the registries with the name "my_collection" and the tag "my_tag" + + ```python + api.registries().collections(filter={"name": "my_collection", "tag": "my_tag"}) + ``` + + Find all artifact versions in the registries with a collection name that contains "my_collection" and a version that has the alias "best" + + ```python + api.registries().collections( + filter={"name": {"$regex": "my_collection"}} + ).versions(filter={"alias": "best"}) + ``` + + Find all artifact versions in the registries that contain "model" and have the tag "prod" or alias "best" + + ```python + api.registries(filter={"name": {"$regex": "model"}}).versions( + filter={"$or": [{"tag": "prod"}, {"alias": "best"}]} + ) + ``` + """ + if not server_supports(self.client, pb.ARTIFACT_REGISTRY_SEARCH): + raise RuntimeError( + "Registry search API is not enabled on this wandb server version. 
" + "Please upgrade your server version or contact support at support@wandb.com." + ) + + organization = organization or fetch_org_from_settings_or_entity( + self.settings, self.default_entity + ) + return Registries( + self.client, organization=organization, filter=filter, per_page=per_page + ) + + @tracked + def registry(self, name: str, organization: str | None = None) -> Registry: + """Return a registry given a registry name. + + Args: + name: The name of the registry. This is without the `wandb-registry-` + prefix. + organization: The organization of the registry. + If no organization is set in the settings, the organization will be + fetched from the entity if the entity only belongs to one + organization. + + Returns: + A registry object. + + Examples: + Fetch and update a registry + + ```python + import wandb + + api = wandb.Api() + registry = api.registry(name="my-registry", organization="my-org") + registry.description = "This is an updated description" + registry.save() + ``` + """ + if not server_supports(self.client, pb.ARTIFACT_REGISTRY_SEARCH): + raise RuntimeError( + "api.registry() is not enabled on this wandb server version. " + "Please upgrade your server version or contact support at support@wandb.com." + ) + organization = organization or fetch_org_from_settings_or_entity( + self.settings, self.default_entity + ) + org_entity = fetch_org_entity_from_organization(self.client, organization) + registry = Registry(self.client, organization, org_entity, name) + registry.load() + return registry + + @tracked + def create_registry( + self, + name: str, + visibility: Literal["organization", "restricted"], + organization: str | None = None, + description: str | None = None, + artifact_types: list[str] | None = None, + ) -> Registry: + """Create a new registry. + + Args: + name: The name of the registry. Name must be unique within the organization. + visibility: The visibility of the registry. 
+ organization: Anyone in the organization can view this registry. You can + edit their roles later from the settings in the UI. + restricted: Only invited members via the UI can access this registry. + Public sharing is disabled. + organization: The organization of the registry. + If no organization is set in the settings, the organization will be + fetched from the entity if the entity only belongs to one organization. + description: The description of the registry. + artifact_types: The accepted artifact types of the registry. A type is no + more than 128 characters and do not include characters `/` or `:`. If + not specified, all types are accepted. + Allowed types added to the registry cannot be removed later. + + Returns: + A registry object. + + Examples: + ```python + import wandb + + api = wandb.Api() + registry = api.create_registry( + name="my-registry", + visibility="restricted", + organization="my-org", + description="This is a test registry", + artifact_types=["model"], + ) + ``` + """ + if not server_supports( + self.client, pb.INCLUDE_ARTIFACT_TYPES_IN_REGISTRY_CREATION + ): + raise RuntimeError( + "create_registry api is not enabled on this wandb server version. " + "Please upgrade your server version or contact support at support@wandb.com." + ) + + organization = organization or fetch_org_from_settings_or_entity( + self.settings, self.default_entity + ) + + try: + existing_registry = self.registry(name=name, organization=organization) + except ValueError: + existing_registry = None + if existing_registry: + raise ValueError( + f"Registry {name!r} already exists in organization {organization!r}," + " please use a different name." + ) + + return Registry.create( + self.client, + organization, + name, + visibility, + description, + artifact_types, + ) + + @tracked + def integrations( + self, + entity: str | None = None, + *, + per_page: int = 50, + ) -> Iterator[Integration]: + """Return an iterator of all integrations for an entity. 
+ + Args: + entity: The entity (e.g. team name) for which to + fetch integrations. If not provided, the user's default entity + will be used. + per_page: Number of integrations to fetch per page. + Defaults to 50. Usually there is no reason to change this. + + Yields: + Iterator[SlackIntegration | WebhookIntegration]: An iterator of any supported integrations. + """ + from wandb.apis.public.integrations import Integrations + + variables = {"entity": entity or self.default_entity} + return Integrations(self.client, variables=variables, per_page=per_page) + + @tracked + def webhook_integrations( + self, entity: str | None = None, *, per_page: int = 50 + ) -> Iterator[WebhookIntegration]: + """Returns an iterator of webhook integrations for an entity. + + Args: + entity: The entity (e.g. team name) for which to + fetch integrations. If not provided, the user's default entity + will be used. + per_page: Number of integrations to fetch per page. + Defaults to 50. Usually there is no reason to change this. + + Yields: + Iterator[WebhookIntegration]: An iterator of webhook integrations. + + Examples: + Get all registered webhook integrations for the team "my-team": + + ```python + import wandb + + api = wandb.Api() + webhook_integrations = api.webhook_integrations(entity="my-team") + ``` + + Find only webhook integrations that post requests to "https://my-fake-url.com": + + ```python + webhook_integrations = api.webhook_integrations(entity="my-team") + my_webhooks = [ + ig + for ig in webhook_integrations + if ig.url_endpoint.startswith("https://my-fake-url.com") + ] + ``` + """ + from wandb.apis.public.integrations import WebhookIntegrations + + variables = {"entity": entity or self.default_entity} + return WebhookIntegrations(self.client, variables=variables, per_page=per_page) + + @tracked + def slack_integrations( + self, *, entity: str | None = None, per_page: int = 50 + ) -> Iterator[SlackIntegration]: + """Returns an iterator of Slack integrations for an entity. 
+ + Args: + entity: The entity (e.g. team name) for which to + fetch integrations. If not provided, the user's default entity + will be used. + per_page: Number of integrations to fetch per page. + Defaults to 50. Usually there is no reason to change this. + + Yields: + Iterator[SlackIntegration]: An iterator of Slack integrations. + + Examples: + Get all registered Slack integrations for the team "my-team": + + ```python + import wandb + + api = wandb.Api() + slack_integrations = api.slack_integrations(entity="my-team") + ``` + + Find only Slack integrations that post to channel names starting with "team-alerts-": + + ```python + slack_integrations = api.slack_integrations(entity="my-team") + team_alert_integrations = [ + ig + for ig in slack_integrations + if ig.channel_name.startswith("team-alerts-") + ] + ``` + """ + from wandb.apis.public.integrations import SlackIntegrations + + variables = {"entity": entity or self.default_entity} + return SlackIntegrations(self.client, variables=variables, per_page=per_page) + + def _supports_automation( + self, + *, + event: EventType | None = None, + action: ActionType | None = None, + ) -> bool: + """Returns whether the server recognizes the automation event and/or action.""" + from wandb.automations._utils import ( + ALWAYS_SUPPORTED_ACTIONS, + ALWAYS_SUPPORTED_EVENTS, + ) + + supports_event = ( + (event is None) + or (event in ALWAYS_SUPPORTED_EVENTS) + or server_supports(self.client, f"AUTOMATION_EVENT_{event.value}") + ) + supports_action = ( + (action is None) + or (action in ALWAYS_SUPPORTED_ACTIONS) + or server_supports(self.client, f"AUTOMATION_ACTION_{action.value}") + ) + return supports_event and supports_action + + def _omitted_automation_fragments(self) -> set[str]: + """Returns the names of unsupported automation-related fragments. + + Older servers won't recognize newer GraphQL types, so a valid request may + unnecessarily error out because it won't recognize fragments defined on those types. + + So e.g. 
if a server does not support `NO_OP` action types, then the following need to be + removed from the body of the GraphQL request: + + - Fragment definition: + ``` + fragment NoOpActionFields on NoOpTriggeredAction { + noOp + } + ``` + + - Fragment spread in selection set: + ``` + { + ...NoOpActionFields + # ... other fields ... + } + ``` + """ + from wandb.automations import ActionType + from wandb.automations._generated import ( + GenericWebhookActionFields, + NoOpActionFields, + NotificationActionFields, + QueueJobActionFields, + ) + + # Note: we can't currently define this as a constant outside the method + # and still keep it nearby in this module, because it relies on pydantic v2-only imports + fragment_names: dict[ActionType, str] = { + ActionType.NO_OP: nameof(NoOpActionFields), + ActionType.QUEUE_JOB: nameof(QueueJobActionFields), + ActionType.NOTIFICATION: nameof(NotificationActionFields), + ActionType.GENERIC_WEBHOOK: nameof(GenericWebhookActionFields), + } + + return set( + name + for action in ActionType + if (not self._supports_automation(action=action)) + and (name := fragment_names.get(action)) + ) + + @tracked + def automation( + self, + name: str, + *, + entity: str | None = None, + ) -> Automation: + """Returns the only Automation matching the parameters. + + Args: + name: The name of the automation to fetch. + entity: The entity to fetch the automation for. + + Raises: + ValueError: If zero or multiple Automations match the search criteria. 
+ + Examples: + Get an existing automation named "my-automation": + + ```python + import wandb + + api = wandb.Api() + automation = api.automation(name="my-automation") + ``` + + Get an existing automation named "other-automation", from the entity "my-team": + + ```python + automation = api.automation(name="other-automation", entity="my-team") + ``` + """ + return one( + self.automations(entity=entity, name=name), + too_short=ValueError("No automations found"), + too_long=ValueError("Multiple automations found"), + ) + + @tracked + def automations( + self, + entity: str | None = None, + *, + name: str | None = None, + per_page: int = 50, + ) -> Iterator[Automation]: + """Returns an iterator over all Automations that match the given parameters. + + If no parameters are provided, the returned iterator will contain all + Automations that the user has access to. + + Args: + entity: The entity to fetch the automations for. + name: The name of the automation to fetch. + per_page: The number of automations to fetch per page. + Defaults to 50. Usually there is no reason to change this. + + Returns: + A list of automations. 
+ + Examples: + Fetch all existing automations for the entity "my-team": + + ```python + import wandb + + api = wandb.Api() + automations = api.automations(entity="my-team") + ``` + """ + from wandb.apis.public.automations import Automations + from wandb.automations._generated import ( + GET_AUTOMATIONS_BY_ENTITY_GQL, + GET_AUTOMATIONS_GQL, + ) + + # For now, we need to use different queries depending on whether entity is given + variables = {"entity": entity} + if entity is None: + gql_str = GET_AUTOMATIONS_GQL # Automations for viewer + else: + gql_str = GET_AUTOMATIONS_BY_ENTITY_GQL # Automations for entity + + # If needed, rewrite the GraphQL field selection set to omit unsupported fields/fragments/types + omit_fragments = self._omitted_automation_fragments() + query = gql_compat(gql_str, omit_fragments=omit_fragments) + iterator = Automations( + self.client, variables=variables, per_page=per_page, _query=query + ) + + # FIXME: this is crude, move this client-side filtering logic into backend + if name is not None: + iterator = filter(lambda x: x.name == name, iterator) + yield from iterator + + @normalize_exceptions + @tracked + def create_automation( + self, + obj: NewAutomation, + *, + fetch_existing: bool = False, + **kwargs: Unpack[WriteAutomationsKwargs], + ) -> Automation: + """Create a new Automation. + + Args: + obj: + The automation to create. + fetch_existing: + If True, and a conflicting automation already exists, attempt + to fetch the existing automation instead of raising an error. + **kwargs: + Any additional values to assign to the automation before + creating it. If given, these will override any values that may + already be set on the automation: + - `name`: The name of the automation. + - `description`: The description of the automation. + - `enabled`: Whether the automation is enabled. + - `scope`: The scope of the automation. + - `event`: The event that triggers the automation. + - `action`: The action that is triggered by the automation. 
+ + Returns: + The saved Automation. + + Examples: + Create a new automation named "my-automation" that sends a Slack notification + when a run within a specific project logs a metric exceeding a custom threshold: + + ```python + import wandb + from wandb.automations import OnRunMetric, RunEvent, SendNotification + + api = wandb.Api() + + project = api.project("my-project", entity="my-team") + + # Use the first Slack integration for the team + slack_hook = next(api.slack_integrations(entity="my-team")) + + event = OnRunMetric( + scope=project, + filter=RunEvent.metric("custom-metric") > 10, + ) + action = SendNotification.from_integration(slack_hook) + + automation = api.create_automation( + event >> action, + name="my-automation", + description="Send a Slack message whenever 'custom-metric' exceeds 10.", + ) + ``` + """ + import requests + + from wandb.automations import Automation + from wandb.automations._generated import CREATE_AUTOMATION_GQL, CreateAutomation + from wandb.automations._utils import prepare_to_create + + gql_input = prepare_to_create(obj, **kwargs) + + if not self._supports_automation( + event=(event := gql_input.triggering_event_type), + action=(action := gql_input.triggered_action_type), + ): + raise ValueError( + f"Automation event or action ({event!r} -> {action!r}) " + "is not supported on this wandb server version. " + "Please upgrade your server version, or contact support at " + "support@wandb.com." 
@normalize_exceptions
@tracked
def update_automation(
    self,
    obj: Automation,
    *,
    create_missing: bool = False,
    **kwargs: Unpack[WriteAutomationsKwargs],
) -> Automation:
    """Update an existing automation.

    Args:
        obj: The automation to update.  Must be an existing automation.
        create_missing (bool):
            If True, and the automation does not exist, create it.
        **kwargs:
            Any additional values to assign to the automation before
            updating it.  If given, these will override any values that may
            already be set on the automation:
            - `name`: The name of the automation.
            - `description`: The description of the automation.
            - `enabled`: Whether the automation is enabled.
            - `scope`: The scope of the automation.
            - `event`: The event that triggers the automation.
            - `action`: The action that is triggered by the automation.

    Returns:
        The updated automation.

    Raises:
        RuntimeError: If the server does not support updating automations,
            or if the mutation response is invalid or empty.
        ValueError: If the event/action pair is not supported by the server,
            or if the server answers 404 and `create_missing` is False.
        requests.HTTPError: Re-raised for any other HTTP error status, after
            the status and response text have been logged.

    Examples:
        Disable and edit the description of an existing automation ("my-automation"):

        ```python
        import wandb

        api = wandb.Api()

        automation = api.automation(name="my-automation")
        automation.enabled = False
        automation.description = "Kept for reference, but no longer used."

        updated_automation = api.update_automation(automation)
        ```

        OR

        ```python
        import wandb

        api = wandb.Api()

        automation = api.automation(name="my-automation")

        updated_automation = api.update_automation(
            automation,
            enabled=False,
            description="Kept for reference, but no longer used.",
        )
        ```
    """
    import requests

    from wandb.automations import ActionType, Automation
    from wandb.automations._generated import UPDATE_AUTOMATION_GQL, UpdateAutomation
    from wandb.automations._utils import prepare_to_update

    # Check if the server even supports updating automations.
    #
    # NOTE: Unfortunately, there is no current server feature flag for this.  As a workaround,
    # we check whether the server supports the NO_OP action, which is a reasonably safe proxy
    # for whether it supports updating automations.
    if not self._supports_automation(action=ActionType.NO_OP):
        raise RuntimeError(
            "Updating existing automations is not enabled on this wandb server version. "
            "Please upgrade your server version, or contact support at support@wandb.com."
        )

    # Fold the keyword overrides into the mutation input.
    gql_input = prepare_to_update(obj, **kwargs)

    # The walrus bindings keep `event`/`action` available for the error message below.
    if not self._supports_automation(
        event=(event := gql_input.triggering_event_type),
        action=(action := gql_input.triggered_action_type),
    ):
        raise ValueError(
            f"Automation event or action ({event.value} -> {action.value}) "
            "is not supported on this wandb server version. "
            "Please upgrade your server version, or contact support at "
            "support@wandb.com."
        )

    # If needed, rewrite the GraphQL field selection set to omit unsupported fields/fragments/types
    omit_fragments = self._omitted_automation_fragments()
    mutation = gql_compat(UPDATE_AUTOMATION_GQL, omit_fragments=omit_fragments)
    variables = {"input": gql_input.model_dump()}

    name = gql_input.name
    try:
        data = self.client.execute(mutation, variable_values=variables)
    except requests.HTTPError as e:
        status = HTTPStatus(e.response.status_code)
        if status is HTTPStatus.NOT_FOUND:  # 404
            if create_missing:
                wandb.termlog(f"Automation {name!r} not found. Creating it.")
                # NOTE(review): the fallback receives the original `obj`, not
                # `gql_input`, so the **kwargs overrides are not forwarded here
                # (unless `prepare_to_update` mutates `obj`) -- confirm intended.
                return self.create_automation(obj)

            raise ValueError(
                f"Automation {name!r} not found. Unable to edit it."
            ) from e

        # Not a (known) recoverable HTTP error
        wandb.termerror(f"Got response status {status!r}: {e.response.text!r}")
        raise

    # Validate the raw response before trusting any of its fields.
    try:
        result = UpdateAutomation.model_validate(data).result
    except ValidationError as e:
        msg = f"Invalid response while updating automation {name!r}"
        raise RuntimeError(msg) from e

    if (result is None) or (result.trigger is None):
        msg = f"Empty response while updating automation {name!r}"
        raise RuntimeError(msg)

    return Automation.model_validate(result.trigger)


@normalize_exceptions
@tracked
def delete_automation(self, obj: Automation | str) -> Literal[True]:
    """Delete an automation.

    Args:
        obj: The automation to delete, or its ID.

    Returns:
        True if the automation was deleted successfully.

    Raises:
        RuntimeError: If the mutation response is invalid or empty, or if the
            response does not report success.
    """
    from wandb.automations._generated import DELETE_AUTOMATION_GQL, DeleteAutomation
    from wandb.automations._utils import extract_id

    # `extract_id` yields the ID whether `obj` is an automation or already an ID.
    id_ = extract_id(obj)
    mutation = gql(DELETE_AUTOMATION_GQL)
    variables = {"id": id_}

    data = self.client.execute(mutation, variable_values=variables)

    try:
        result = DeleteAutomation.model_validate(data).result
    except ValidationError as e:
        msg = f"Invalid response while deleting automation {id_!r}"
        raise RuntimeError(msg) from e

    if result is None:
        msg = f"Empty response while deleting automation {id_!r}"
        raise RuntimeError(msg)

    # A well-formed response can still report failure; surface it as an error.
    if not result.success:
        raise RuntimeError(f"Failed to delete automation: {id_!r}")

    return result.success
+""" + +from __future__ import annotations + +import json +from copy import copy +from functools import lru_cache +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Collection, + Iterable, + List, + Literal, + Mapping, + Sequence, + TypeVar, +) + +from typing_extensions import override +from wandb_gql import gql + +from wandb._iterutils import always_list +from wandb._pydantic import Connection, ConnectionWithTotal, Edge +from wandb._strutils import nameof +from wandb.apis.normalize import normalize_exceptions +from wandb.apis.paginator import RelayPaginator, SizedRelayPaginator +from wandb.errors.term import termlog +from wandb.proto import wandb_internal_pb2 as pb +from wandb.proto.wandb_telemetry_pb2 import Deprecated +from wandb.sdk.artifacts._models import ArtifactCollectionData +from wandb.sdk.lib.deprecation import warn_and_record_deprecation + +from .files import File +from .utils import gql_compat + +if TYPE_CHECKING: + from wandb_graphql.language.ast import Document + + from wandb.apis.public.api import RetryingClient + from wandb.sdk.artifacts._generated import ( + ArtifactAliasFragment, + ArtifactCollectionFragment, + ArtifactFragment, + ArtifactTypeFragment, + FileFragment, + ) + from wandb.sdk.artifacts._models.pagination import ( + ArtifactCollectionConnection, + ArtifactFileConnection, + ArtifactTypeConnection, + ) + from wandb.sdk.artifacts.artifact import Artifact + + from . import Run + + +TNode = TypeVar("TNode") + + +@lru_cache(maxsize=1) +def _run_artifacts_mode_to_gql() -> dict[Literal["logged", "used"], str]: + """Lazily import and cache the run artifact GQL query strings. + + This keeps import-time light and only loads the generated GQL + when RunArtifacts is actually used. 
class ArtifactTypes(RelayPaginator["ArtifactTypeFragment", "ArtifactType"]):
    """A lazy iterator of `ArtifactType` objects for a specific project.

    Args:
        client: The client instance to use for querying W&B.
        entity: The entity (user or team) that owns the project.
        project: The name of the project whose artifact types are listed.
        per_page: The number of artifact types to fetch per page.
    """

    QUERY: ClassVar[Document | None] = None
    last_response: ArtifactTypeConnection | None

    def __init__(
        self,
        client: RetryingClient,
        entity: str,
        project: str,
        per_page: int = 50,
    ):
        # The query document is parsed on first use and cached on the class,
        # so later instances reuse it.
        if self.QUERY is None:
            from wandb.sdk.artifacts._generated import PROJECT_ARTIFACT_TYPES_GQL

            cls = type(self)
            cls.QUERY = gql(PROJECT_ARTIFACT_TYPES_GQL)

        self.entity = entity
        self.project = project
        super().__init__(
            client,
            variables={"entity": entity, "project": project},
            per_page=per_page,
        )

    @override
    def _update_response(self) -> None:
        """Fetch the current page and store its validated connection."""
        from wandb.sdk.artifacts._generated import ProjectArtifactTypes
        from wandb.sdk.artifacts._models.pagination import ArtifactTypeConnection

        raw = self.client.execute(self.QUERY, variable_values=self.variables)
        parsed = ProjectArtifactTypes.model_validate(raw)

        # Walk down to the inner `*Connection`; a missing project or a missing
        # connection both count as an unparseable response.
        project_data = parsed.project
        connection = project_data.artifact_types if project_data else None
        if not connection:
            raise ValueError(f"Unable to parse {nameof(type(self))!r} response data")

        self.last_response = ArtifactTypeConnection.model_validate(connection)

    def _convert(self, node: ArtifactTypeFragment) -> ArtifactType:
        """Wrap one fetched node in an `ArtifactType`, reusing the node as its attrs."""
        artifact_type = ArtifactType(
            client=self.client,
            entity=self.entity,
            project=self.project,
            type_name=node.name,
            attrs=node,
        )
        return artifact_type
@normalize_exceptions
def collections(self, per_page: int = 50) -> ArtifactCollections:
    """Get all artifact collections associated with this artifact type.

    Args:
        per_page (int): The number of artifact collections to fetch per page.
            Default is 50.

    Returns:
        An `ArtifactCollections` paginator for this type's entity, project,
        and type name.
    """
    return ArtifactCollections(
        self.client,
        entity=self.entity,
        project=self.project,
        type_name=self.type,
        # Forward the page size: it was previously accepted and documented
        # but never passed on, so the paginator always used its own default.
        per_page=per_page,
    )
def __repr__(self) -> str:
    """Return a short identifying representation, e.g. ``<ArtifactType model>``."""
    # The previous implementation returned an empty string, which made
    # instances invisible in logs, debuggers, and interactive sessions.
    return f"<ArtifactType {self.type}>"
def _convert(self, node: ArtifactCollectionFragment) -> ArtifactCollection | None:
    """Build an `ArtifactCollection` from a fetched node.

    Returns None when the node has no project attached; otherwise the
    entity, project, name, and type are all read from the node itself.
    """
    if node.project:
        return ArtifactCollection(
            client=self.client,
            entity=node.project.entity.name,
            project=node.project.name,
            name=node.name,
            type=node.type.name,
            attrs=node,
        )
    return None
+ + + """ + + _saved: ArtifactCollectionData + """The saved artifact collection data as last fetched from the W&B server.""" + + _current: ArtifactCollectionData + """The local, editable artifact collection data.""" + + def __init__( + self, + client: RetryingClient, + entity: str, + project: str, + name: str, + type: str, + organization: str | None = None, + attrs: ArtifactCollectionFragment | None = None, + ): + self.client = client + + # FIXME: Make this lazy, so we don't (re-)fetch the attributes until they are needed + self._update_data(attrs or self.load(entity, project, type, name)) + + self.organization = organization + + def _update_data(self, fragment: ArtifactCollectionFragment) -> None: + """Update the saved/current state of this collection with the given fragment. + + Can be used after receiving a GraphQL response with ArtifactCollection data. + """ + # Separate "saved" vs "current" copies of the artifact collection data + validated = ArtifactCollectionData.from_fragment(fragment) + self._saved = validated + self._current = validated.model_copy(deep=True) + + @property + def id(self) -> str: + """The unique identifier of the artifact collection.""" + return self._current.id + + @property + def entity(self) -> str: + """The entity (user or team) that owns the project.""" + return self._current.entity + + @property + def project(self) -> str: + """The project that contains the artifact collection.""" + return self._current.project + + @normalize_exceptions + def artifacts(self, per_page: int = 50) -> Artifacts: + """Get all artifacts in the collection.""" + return Artifacts( + client=self.client, + entity=self.entity, + project=self.project, + # Use the saved name and type, as they're mutable attributes + # and may have been edited locally. 
def load(
    self, entity: str, project: str, type_: str, name: str
) -> ArtifactCollectionFragment:
    """Fetch and return the validated artifact collection data from W&B.

    Args:
        entity: The entity (user or team) that owns the project.
        project: The project that contains the collection.
        type_: The artifact type the collection belongs to.
        name: The name of the artifact collection.

    Returns:
        The collection fragment taken from the validated response.

    Raises:
        ValueError: If the response contains no project or artifact type, or
            if the artifact type contains no collection named `name`.
    """
    from wandb.sdk.artifacts._generated import (
        PROJECT_ARTIFACT_COLLECTION_GQL,
        ProjectArtifactCollection,
    )

    gql_op = gql(PROJECT_ARTIFACT_COLLECTION_GQL)
    gql_vars = {"entity": entity, "project": project, "type": type_, "name": name}
    data = self.client.execute(gql_op, variable_values=gql_vars)
    result = ProjectArtifactCollection.model_validate(data)

    # Missing project or type: keep the original message for these cases.
    if not ((proj := result.project) and (artifact_type := proj.artifact_type)):
        raise ValueError(f"Could not find artifact type {type_!r}")

    # The type exists but the collection does not: report what is actually
    # missing rather than blaming the artifact type.
    if not (collection := artifact_type.artifact_collection):
        raise ValueError(
            f"Could not find artifact collection {name!r} of type {type_!r}"
        )
    return collection
@normalize_exceptions
def delete(self) -> None:
    """Delete the entire artifact collection.

    The mutation is chosen by `is_sequence()`: one for sequences, another
    for portfolios.  The collection is identified by its ID.
    """
    from wandb.sdk.artifacts._generated import (
        DELETE_ARTIFACT_PORTFOLIO_GQL,
        DELETE_ARTIFACT_SEQUENCE_GQL,
    )

    if self.is_sequence():
        mutation_str = DELETE_ARTIFACT_SEQUENCE_GQL
    else:
        mutation_str = DELETE_ARTIFACT_PORTFOLIO_GQL

    mutation = gql(mutation_str)
    self.client.execute(mutation, variable_values={"id": self.id})
name(self) -> str: + """The name of the artifact collection.""" + return self._current.name + + @name.setter + def name(self, name: str) -> None: + """Set the name of the artifact collection.""" + self._current.name = name + + @property + def type(self): + """Returns the type of the artifact collection.""" + return self._current.type + + @type.setter + def type(self, type: str) -> None: + """Set the type of the artifact collection.""" + if not self.is_sequence(): + raise ValueError( + "Type can only be changed if the artifact collection is a sequence." + ) + self._current.type = type + + def _update_collection(self) -> None: + from wandb.sdk.artifacts._generated import ( + UPDATE_ARTIFACT_PORTFOLIO_GQL, + UPDATE_ARTIFACT_SEQUENCE_GQL, + UpdateArtifactPortfolioInput, + UpdateArtifactSequenceInput, + ) + + if self.is_sequence(): + gql_op = gql(UPDATE_ARTIFACT_SEQUENCE_GQL) + gql_input = UpdateArtifactSequenceInput( + artifact_sequence_id=self.id, + name=self.name, + description=self.description, + ) + else: + gql_op = gql(UPDATE_ARTIFACT_PORTFOLIO_GQL) + gql_input = UpdateArtifactPortfolioInput( + artifact_portfolio_id=self.id, + name=self.name, + description=self.description, + ) + self.client.execute(gql_op, variable_values={"input": gql_input.model_dump()}) + self._saved.name = self._current.name + self._saved.description = self._current.description + + def _update_sequence_type(self) -> None: + from wandb.sdk.artifacts._generated import ( + UPDATE_ARTIFACT_SEQUENCE_TYPE_GQL, + MoveArtifactSequenceInput, + ) + + gql_op = gql(UPDATE_ARTIFACT_SEQUENCE_TYPE_GQL) + gql_input = MoveArtifactSequenceInput( + artifact_sequence_id=self.id, + destination_artifact_type_name=self.type, + ) + self.client.execute(gql_op, variable_values={"input": gql_input.model_dump()}) + self._saved.type = self._current.type + + def _add_tags(self, tag_names: Iterable[str]) -> None: + from wandb.sdk.artifacts._generated import ( + ADD_ARTIFACT_COLLECTION_TAGS_GQL, + 
def _delete_tags(self, tag_names: Iterable[str]) -> None:
    """Remove the given tag names from this collection on the server.

    The collection is addressed by entity, project, and the *saved* name,
    not the locally editable one.
    """
    from wandb.sdk.artifacts._generated import (
        DELETE_ARTIFACT_COLLECTION_TAGS_GQL,
        DeleteArtifactCollectionTagAssignmentsInput,
    )

    mutation = gql(DELETE_ARTIFACT_COLLECTION_TAGS_GQL)
    payload = DeleteArtifactCollectionTagAssignmentsInput(
        entity_name=self.entity,
        project_name=self.project,
        artifact_collection_name=self._saved.name,
        tags=[{"tagName": tag_name} for tag_name in tag_names],
    )
    variables = {"input": payload.model_dump()}
    self.client.execute(mutation, variable_values=variables)
def __repr__(self) -> str:
    """Return a short identifying representation.

    Example: ``<ArtifactCollection my-data (dataset)>``.  Uses the current
    (possibly locally edited) name and type.
    """
    # The previous implementation returned an empty string, which made
    # instances invisible in logs, debuggers, and interactive sessions.
    return f"<ArtifactCollection {self.name} ({self.type})>"
@override
def _update_response(self) -> None:
    """Fetch the current page and store its validated artifact connection.

    Raises:
        ValueError: If any level of the nested response (project, artifact
            type, artifact collection, artifacts) is missing or falsy.
    """
    from wandb.sdk.artifacts._generated import ArtifactFragment, ProjectArtifacts

    raw = self.client.execute(self.QUERY, variable_values=self.variables)
    current = ProjectArtifacts.model_validate(raw)

    # Descend one level at a time towards the inner `*Connection`, stopping
    # at the first level that is missing or empty.
    for attr in ("project", "artifact_type", "artifact_collection", "artifacts"):
        current = getattr(current, attr)
        if not current:
            raise ValueError(f"Unable to parse {nameof(type(self))!r} response data")

    connection_model = _ArtifactConnectionGeneric[ArtifactFragment]
    self.last_response = connection_model.model_validate(current)
class RunArtifacts(SizedRelayPaginator["ArtifactFragment", "Artifact"]):
    """An iterable collection of artifacts associated with a specific run.

    Args:
        client: The client instance to use for querying W&B.
        run: The run whose artifacts are listed; its `entity`, `project`,
            and `id` become the query variables.
        mode: "logged" for the run's output artifacts, "used" for its input
            artifacts.
        per_page: The number of artifacts to fetch per page.

    Raises:
        ValueError: If `mode` is neither "logged" nor "used".
    """

    QUERY: Document  # Must be set per-instance
    last_response: ConnectionWithTotal[ArtifactFragment] | None

    def __init__(
        self,
        client: RetryingClient,
        run: Run,
        mode: Literal["logged", "used"] = "logged",
        per_page: int = 50,
    ):
        try:
            query_str = _run_artifacts_mode_to_gql()[mode]
        except LookupError:
            # `from None`: the internal KeyError is an implementation detail
            # and would only clutter the traceback shown to the caller.
            raise ValueError("mode must be logged or used") from None

        self.QUERY = gql(query_str)

        self.run = run
        variables = {"entity": run.entity, "project": run.project, "run": run.id}
        super().__init__(client, variables=variables, per_page=per_page)

    @override
    def _update_response(self) -> None:
        """Fetch the current page and store its validated connection.

        Raises:
            ValueError: If the response lacks the project/run/artifacts nesting.
        """
        from wandb.sdk.artifacts._models.pagination import RunArtifactConnection

        data = self.client.execute(self.QUERY, variable_values=self.variables)

        # Extract the inner `*Connection` result for faster/easier access.
        # A missing key or a null project/run is reported with the same
        # ValueError the sibling paginators raise, not a bare
        # KeyError/TypeError from the subscript chain.
        try:
            inner_data = data["project"]["run"]["artifacts"]
        except (LookupError, TypeError) as e:
            raise ValueError(
                f"Unable to parse {nameof(type(self))!r} response data"
            ) from e

        self.last_response = RunArtifactConnection.model_validate(inner_data)

    def _convert(self, node: ArtifactFragment) -> Artifact | None:
        """Build an `Artifact` from a node, or None if its sequence has no project."""
        from wandb.sdk.artifacts._validators import FullArtifactPath
        from wandb.sdk.artifacts.artifact import Artifact

        if node.artifact_sequence.project is None:
            return None
        return Artifact._from_attrs(
            path=FullArtifactPath(
                prefix=node.artifact_sequence.project.entity.name,
                project=node.artifact_sequence.project.name,
                name=f"{node.artifact_sequence.name}:v{node.version_index}",
            ),
            src_art=node,
            client=self.client,
        )
def __repr__(self) -> str:
    """Return a short identifying representation.

    The artifact path is always shown; the file count is appended only when
    `len(self)` is available.
    """
    path_str = "/".join(self.path)
    try:
        total = len(self)
    except NotImplementedError:
        # Older server versions don't correctly support totalCount
        return f"<ArtifactFiles {path_str}>"
    else:
        # Both branches previously returned an empty string, leaving
        # `path_str` and `total` computed but unused.
        return f"<ArtifactFiles {path_str} ({total})>"
RelayPaginator + +if TYPE_CHECKING: + from wandb_graphql.language.ast import Document + + from wandb._pydantic import Connection + from wandb.apis.public.api import RetryingClient + from wandb.automations import Automation + from wandb.automations._generated import ProjectTriggersFields + + +class Automations(RelayPaginator["ProjectTriggersFields", "Automation"]): + """A lazy iterator of `Automation` objects. + + + """ + + QUERY: Document # Must be set per-instance + last_response: Connection[ProjectTriggersFields] | None + + def __init__( + self, + client: RetryingClient, + variables: Mapping[str, Any], + per_page: int = 50, + *, + _query: Document, # internal use only, but required + ): + self.QUERY = _query + super().__init__(client, variables=variables, per_page=per_page) + + @override + def _update_response(self) -> None: + """Fetch the raw response data for the current page.""" + from wandb._pydantic import Connection + from wandb.automations._generated import ProjectTriggersFields + + data = self.client.execute(self.QUERY, variable_values=self.variables) + try: + conn_data = data["scope"]["projects"] + conn = Connection[ProjectTriggersFields].model_validate(conn_data) + self.last_response = conn + except (LookupError, AttributeError, ValidationError) as e: + raise ValueError("Unexpected response data") from e + + @override + def _convert(self, node: ProjectTriggersFields) -> Iterator[Automation]: + from wandb.automations import Automation + + return (Automation.model_validate(obj) for obj in node.triggers) + + @override + def convert_objects(self) -> Iterator[Automation]: + return chain.from_iterable(super().convert_objects()) diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/const.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/const.py new file mode 100644 index 0000000000000000000000000000000000000000..e75e387177d8b2c65d468ded64f4080b5b67a5bd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/const.py @@ 
-0,0 +1,6 @@ +from __future__ import annotations + +import datetime + +# Only retry requests for 20 seconds in the public api +RETRY_TIMEDELTA = datetime.timedelta(seconds=20) diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/files.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/files.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9b6152c65bea25f01a6dacc890ce53ee4a3982 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/files.py @@ -0,0 +1,435 @@ +"""W&B Public API for File objects. + +This module provides classes for interacting with files stored in W&B. + +Example: +```python +from wandb.apis.public import Api + +# Get files from a specific run +run = Api().run("entity/project/run_id") +files = run.files() + +# Work with files +for file in files: + print(f"File: {file.name}") + print(f"Size: {file.size} bytes") + print(f"Type: {file.mimetype}") + + # Download file + if file.size < 1000000: # Less than 1MB + file.download(root="./downloads") + + # Get S3 URI for large files + if file.size >= 1000000: + print(f"S3 URI: {file.path_uri}") +``` + +Note: + This module is part of the W&B Public API and provides methods to access, + download, and manage files stored in W&B. Files are typically associated + with specific runs and can include model weights, datasets, visualizations, + and other artifacts. 
+""" + +from __future__ import annotations + +import io +import os +from typing import TYPE_CHECKING, Any, Callable + +from wandb_gql import gql +from wandb_gql.client import RetryError + +import wandb +from wandb._strutils import nameof +from wandb.apis.attrs import Attrs +from wandb.apis.normalize import normalize_exceptions +from wandb.apis.paginator import SizedPaginator +from wandb.apis.public import utils +from wandb.apis.public.const import RETRY_TIMEDELTA +from wandb.apis.public.runs import Run, _server_provides_internal_id_for_project +from wandb.sdk.lib import retry +from wandb.util import POW_2_BYTES, download_file_from_url, no_retry_auth, to_human_size + +if TYPE_CHECKING: + from wandb_graphql.language.ast import Document + + from wandb.apis.public import Api, RetryingClient + +FILE_FRAGMENT = """fragment RunFilesFragment on Run { + files(names: $fileNames, after: $fileCursor, first: $fileLimit, pattern: $pattern) { + edges { + node { + id + name + url(upload: $upload) + directUrl + sizeBytes + mimetype + updatedAt + md5 + } + cursor + } + pageInfo { + endCursor + hasNextPage + } + } +}""" + + +class Files(SizedPaginator["File"]): + """A lazy iterator over a collection of `File` objects. + + Access and manage files uploaded to W&B during a run. Handles pagination + automatically when iterating through large collections of files. 
+ + Example: + ```python + from wandb.apis.public.files import Files + from wandb.apis.public.api import Api + + # Example run object + run = Api().run("entity/project/run-id") + + # Create a Files object to iterate over files in the run + files = Files(api.client, run) + + # Iterate over files + for file in files: + print(file.name) + print(file.url) + print(file.size) + + # Download the file + file.download(root="download_directory", replace=True) + ``` + """ + + def _get_query(self) -> Document: + """Generate query dynamically based on server capabilities.""" + with_internal_id = _server_provides_internal_id_for_project(self.client) + return gql( + f""" + query RunFiles($project: String!, $entity: String!, $name: String!, $fileCursor: String, + $fileLimit: Int = 50, $fileNames: [String] = [], $upload: Boolean = false, $pattern: String) {{ + project(name: $project, entityName: $entity) {{ + {"internalId" if with_internal_id else ""} + run(name: $name) {{ + fileCount + ...RunFilesFragment + }} + }} + }} + {FILE_FRAGMENT} + """ + ) + + def __init__( + self, + client: RetryingClient, + run: Run, + names: list[str] | None = None, + per_page: int = 50, + upload: bool = False, + pattern: str | None = None, + ): + """Initialize a lazy iterator over a collection of `File` objects. + + Files are retrieved in pages from the W&B server as needed. + + Args: + client: The run object that contains the files + run: The run object that contains the files + names (list, optional): A list of file names to filter the files + per_page (int, optional): The number of files to fetch per page + upload (bool, optional): If `True`, fetch the upload URL for each file + pattern (str, optional): Pattern to match when returning files from W&B + This pattern uses mySQL's LIKE syntax, + so matching all files that end with .json would be "%.json". + If both names and pattern are provided, a ValueError will be raised. 
+ """ + if names and pattern: + raise ValueError( + "Querying for files by both names and pattern is not supported." + " Please provide either a list of names or a pattern to match.", + ) + + self.run = run + variables = { + "project": run.project, + "entity": run.entity, + "name": run.id, + "fileNames": names or [], + "upload": upload, + "pattern": pattern, + } + super().__init__(client, variables, per_page) + + def _update_response(self) -> None: + """Fetch and store the response data for the next page using dynamic query.""" + self.last_response = self.client.execute( + self._get_query(), variable_values=self.variables + ) + + @property + def _length(self) -> int: + """ + Returns total number of files. + + + """ + if not self.last_response: + self._load_page() + + return self.last_response["project"]["run"]["fileCount"] + + @property + def more(self) -> bool: + """Returns whether there are more files to fetch. + + + """ + if self.last_response: + return self.last_response["project"]["run"]["files"]["pageInfo"][ + "hasNextPage" + ] + else: + return True + + @property + def cursor(self) -> str | None: + """Returns the cursor position for pagination of file results. + + + """ + if self.last_response: + return self.last_response["project"]["run"]["files"]["edges"][-1]["cursor"] + else: + return None + + def update_variables(self) -> None: + """Updates the GraphQL query variables for pagination. + + + """ + self.variables.update({"fileLimit": self.per_page, "fileCursor": self.cursor}) + + def convert_objects(self) -> list[File]: + """Converts GraphQL edges to File objects. + + + """ + return [ + File(self.client, r["node"], self.run) + for r in self.last_response["project"]["run"]["files"]["edges"] + ] + + def __repr__(self) -> str: + return f"<{nameof(type(self))} {'/'.join(self.run.path)} ({len(self)})>" + + +class File(Attrs): + """File saved to W&B. + + Represents a single file stored in W&B. Includes access to file metadata. 
+ Files are associated with a specific run and + can include text files, model weights, datasets, visualizations, and other + artifacts. You can download the file, delete the file, and access file + properties. + + Specify one or more attributes in a dictionary to fine a specific + file logged to a specific run. You can search using the following keys: + + - id (str): The ID of the run that contains the file + - name (str): Name of the file + - url (str): path to file + - direct_url (str): path to file in the bucket + - sizeBytes (int): size of file in bytes + - md5 (str): md5 of file + - mimetype (str): mimetype of file + - updated_at (str): timestamp of last update + - path_uri (str): path to file in the bucket, currently only available for S3 objects and reference files + + Args: + client: The run object that contains the file + attrs (dict): A dictionary of attributes that define the file + run: The run object that contains the file + + + """ + + def __init__( + self, + client: RetryingClient, + attrs: dict[str, Any], + run: Run | None = None, + ): + self.client = client + self._attrs = attrs + self.run = run + self.server_supports_delete_file_with_project_id: bool | None = None + self._download_decorated: Callable[..., Any] | None = None + super().__init__(dict(attrs)) + + @property + def size(self) -> int: + """Returns the size of the file in bytes.""" + size_bytes = self._attrs["sizeBytes"] + if size_bytes is not None: + return int(size_bytes) + return 0 + + @property + def path_uri(self) -> str: + """Returns the URI path to the file in the storage bucket. + + Returns: + str: The S3 URI (e.g., 's3://bucket/path/to/file') if the file is stored in S3, + the direct URL if it's a reference file, or an empty string if unavailable. 
+ """ + if not (direct_url := self._attrs.get("directUrl")): + wandb.termwarn("Unable to find direct_url of file") + return "" + + # For reference files, both the directUrl and the url are just the path to the file in the bucket + if direct_url == self._attrs.get("url"): + return direct_url + + try: + return utils.parse_s3_url_to_s3_uri(direct_url) + except ValueError: + wandb.termwarn("path_uri is only available for files stored in S3") + return "" + + def _build_download_wrapper(self) -> Callable[..., io.TextIOWrapper]: + import requests + + @retry.retriable( + retry_timedelta=RETRY_TIMEDELTA, + check_retry_fn=no_retry_auth, + retryable_exceptions=(RetryError, requests.RequestException), + ) + def _impl( + root: str = ".", + replace: bool = False, + exist_ok: bool = False, + api: Api | None = None, + ) -> io.TextIOWrapper: + if api is None: + api = wandb.Api() + + path = os.path.join(root, self.name) + if os.path.exists(path) and not replace: + if exist_ok: + return open(path) + raise ValueError( + "File already exists, pass replace=True to overwrite " + "or exist_ok=True to leave it as is and don't error." + ) + + download_file_from_url(path, self.url, api.api_key) + return open(path) + + return _impl + + @normalize_exceptions + def download( + self, + root: str = ".", + replace: bool = False, + exist_ok: bool = False, + api: Api | None = None, + ) -> io.TextIOWrapper: + """Downloads a file previously saved by a run from the wandb server. + + Args: + root: Local directory to save the file. Defaults to the + current working directory ("."). + replace: If `True`, download will overwrite a local file + if it exists. Defaults to `False`. + exist_ok: If `True`, will not raise ValueError if file already + exists and will not re-download unless replace=True. + Defaults to `False`. + api: If specified, the `Api` instance used to download the file. + + Raises: + `ValueError` if file already exists, `replace=False` and + `exist_ok=False`. 
+ """ + if self._download_decorated is None: + self._download_decorated = self._build_download_wrapper() + return self._download_decorated(root, replace, exist_ok, api) + + @normalize_exceptions + def delete(self) -> None: + """Delete the file from the W&B server.""" + project_id_mutation_fragment = "" + project_id_variable_fragment = "" + variable_values = { + "files": [self.id], + } + + # Add projectId to mutation and variables if the server supports it. + # Otherwise, do not include projectId in mutation for older server versions which do not support it. + if self._server_accepts_project_id_for_delete_file(): + variable_values["projectId"] = self.run._project_internal_id + project_id_variable_fragment = ", $projectId: Int" + project_id_mutation_fragment = "projectId: $projectId" + + mutation_string = """ + mutation deleteFiles($files: [ID!]!{}) {{ + deleteFiles(input: {{ + files: $files + {} + }}) {{ + success + }} + }} + """.format(project_id_variable_fragment, project_id_mutation_fragment) + mutation = gql(mutation_string) + + self.client.execute( + mutation, + variable_values=variable_values, + ) + + def __repr__(self) -> str: + classname = nameof(type(self)) + size = to_human_size(self.size, units=POW_2_BYTES) + return f"<{classname} {self.name} ({self.mimetype}) {size}>" + + @normalize_exceptions + def _server_accepts_project_id_for_delete_file(self) -> bool: + """Returns True if the server supports deleting files with a projectId. + + This check is done by utilizing GraphQL introspection in the available fields on the DeleteFiles API. 
+ """ + query_string = """ + query ProbeDeleteFilesProjectIdInput { + DeleteFilesProjectIdInputType: __type(name:"DeleteFilesInput") { + inputFields{ + name + } + } + } + """ + + # Only perform the query once to avoid extra network calls + if self.server_supports_delete_file_with_project_id is None: + query = gql(query_string) + res = self.client.execute(query) + + # If projectId is in the inputFields, the server supports deleting files with a projectId + self.server_supports_delete_file_with_project_id = "projectId" in [ + x["name"] + for x in ( + res.get("DeleteFilesProjectIdInputType", {}).get( + "inputFields", [{}] + ) + ) + ] + + return self.server_supports_delete_file_with_project_id diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/history.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/history.py new file mode 100644 index 0000000000000000000000000000000000000000..69208cbc55aea99e9e6c7548756a43a703d9b441 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/history.py @@ -0,0 +1,336 @@ +"""W&B Public API for Run History. + +This module provides classes for efficiently scanning and sampling run +history data. + +Note: + This module is part of the W&B Public API and provides methods + to access run history data. It handles pagination automatically and offers + both complete and sampled access to metrics logged during training runs. +""" + +from __future__ import annotations + +import contextlib +import json +import weakref +from typing import TYPE_CHECKING, Any, Dict, Iterator + +from typing_extensions import Self, TypeAlias +from wandb_gql import gql + +from wandb.apis.normalize import normalize_exceptions +from wandb.proto import wandb_api_pb2 as pb +from wandb.sdk.mailbox.mailbox import MailboxClosedError + +if TYPE_CHECKING: + from . 
import runs + from .api import Api, RetryingClient + +_RowDict: TypeAlias = Dict[str, Any] +"""Type alias for a single history row as a dict.""" + + +class BetaHistoryScan(Iterator[_RowDict]): + """Iterator for scanning complete run history. + + + """ + + def __init__( + self, + api: Api, + run: runs.Run, + min_step: int, + max_step: int, + keys: list[str] | None = None, + page_size: int = 1_000, + use_cache: bool = True, + ): + self.run = run + self.min_step = min_step + self.max_step = max_step + self.keys = keys + self.page_size = page_size + self._api = api + + # Tell wandb-core to initialize resources to scan the run's history. + scan_run_history_init = pb.ScanRunHistoryInit( + entity=self.run.entity, + project=self.run.project, + run_id=self.run.id, + keys=self.keys, + use_cache=use_cache, + ) + scan_run_history_init_request = pb.ReadRunHistoryRequest( + scan_run_history_init=scan_run_history_init + ) + api_request = pb.ApiRequest( + read_run_history_request=scan_run_history_init_request + ) + response: pb.ApiResponse = self._api._send_api_request(api_request) + + self._scan_request_id = ( + response.read_run_history_response.scan_run_history_init.request_id + ) + + self.scan_offset = 0 + self.rows: list[_RowDict] = [] + self.keys = keys + + # Add cleanup hook to clean up resources in wandb-core + # when this scan object is deleted. + # + # Using weakref.finalize ensures that references to objects needed during cleanup + # are not garbage collected before being used. 
+ # see: https://docs.python.org/3/library/weakref.html#comparing-finalizers-with-del-methods + weakref.finalize( + self, + self.cleanup, + self._api, + self._scan_request_id, + ) + + def __iter__(self) -> Self: + self.scan_offset = 0 + self.page_offset = self.min_step + self.rows = [] + return self + + def __next__(self) -> _RowDict: + while True: + if self.scan_offset < len(self.rows): + row = self.rows[self.scan_offset] + self.scan_offset += 1 + return row + if self.page_offset >= self.max_step: + raise StopIteration() + self._load_next() + + def _load_next(self) -> None: + from wandb.proto import wandb_api_pb2 as pb + + max_step = min(self.page_offset + self.page_size, self.max_step) + + read_run_history_request = pb.ReadRunHistoryRequest( + scan_run_history=pb.ScanRunHistory( + min_step=self.page_offset, + max_step=max_step, + request_id=self._scan_request_id, + ), + ) + api_request = pb.ApiRequest(read_run_history_request=read_run_history_request) + + response: pb.ApiResponse = self._api._send_api_request(api_request) + run_history: pb.RunHistoryResponse = ( + response.read_run_history_response.run_history + ) + self.rows = [ + self._convert_history_row_to_dict(row) for row in run_history.history_rows + ] + self.page_offset += self.page_size + self.scan_offset = 0 + + @staticmethod + def _convert_history_row_to_dict(history_row: pb.HistoryRow) -> _RowDict: + return { + item.key: json.loads(item.value_json) for item in history_row.history_items + } + + @staticmethod + def cleanup(api: Api, request_id: int) -> None: + scan_run_history_cleanup = pb.ScanRunHistoryCleanup( + request_id=request_id, + ) + scan_run_history_cleanup_request = pb.ReadRunHistoryRequest( + scan_run_history_cleanup=scan_run_history_cleanup + ) + + with contextlib.suppress(ConnectionResetError, MailboxClosedError): + api._send_api_request( + pb.ApiRequest(read_run_history_request=scan_run_history_cleanup_request) + ) + + +class HistoryScan(Iterator[_RowDict]): + """Iterator for scanning 
complete run history. + + + """ + + QUERY = gql( + """ + query HistoryPage($entity: String!, $project: String!, $run: String!, $minStep: Int64!, $maxStep: Int64!, $pageSize: Int!) { + project(name: $project, entityName: $entity) { + run(name: $run) { + history(minStep: $minStep, maxStep: $maxStep, samples: $pageSize) + } + } + } + """ + ) + + def __init__( + self, + client: RetryingClient, + run: runs.Run, + min_step: int, + max_step: int, + page_size: int = 1_000, + ): + """Initialize a HistoryScan instance. + + Args: + client: The client instance to use for making API calls to the W&B backend. + run: The run object whose history is to be scanned. + min_step: The minimum step to start scanning from. + max_step: The maximum step to scan up to. + page_size: Number of history rows to fetch per page. + Default page_size is 1000. + """ + self.client = client + self.run = run + self.page_size = page_size + self.min_step = min_step + self.max_step = max_step + self.page_offset = min_step # minStep for next page + self.scan_offset = 0 # index within current page of rows + self.rows: list[_RowDict] = [] # current page of rows + + def __iter__(self) -> Self: + self.page_offset = self.min_step + self.scan_offset = 0 + self.rows = [] + return self + + def __next__(self) -> _RowDict: + """Return the next row of history data with automatic pagination. 
+ + + """ + while True: + if self.scan_offset < len(self.rows): + row = self.rows[self.scan_offset] + self.scan_offset += 1 + return row + if self.page_offset >= self.max_step: + raise StopIteration() + self._load_next() + + next = __next__ + + @normalize_exceptions + def _load_next(self) -> None: + max_step = self.page_offset + self.page_size + if max_step > self.max_step: + max_step = self.max_step + variables = { + "entity": self.run.entity, + "project": self.run.project, + "run": self.run.id, + "minStep": int(self.page_offset), + "maxStep": int(max_step), + "pageSize": int(self.page_size), + } + + res = self.client.execute(self.QUERY, variable_values=variables) + res = res["project"]["run"]["history"] + self.rows = [json.loads(row) for row in res] + self.page_offset += self.page_size + self.scan_offset = 0 + + +class SampledHistoryScan(Iterator[_RowDict]): + """Iterator for sampling run history data. + + + """ + + QUERY = gql( + """ + query SampledHistoryPage($entity: String!, $project: String!, $run: String!, $spec: JSONString!) { + project(name: $project, entityName: $entity) { + run(name: $run) { + sampledHistory(specs: [$spec]) + } + } + } + """ + ) + + def __init__( + self, + client: RetryingClient, + run: runs.Run, + keys: list[str], + min_step: int, + max_step: int, + page_size: int = 1_000, + ): + """Initialize a SampledHistoryScan instance. + + Args: + client: The client instance to use for making API calls to the W&B backend. + run: The run object whose history is to be sampled. + keys: List of keys to sample from the history. + min_step: The minimum step to start sampling from. + max_step: The maximum step to sample up to. + page_size: Number of sampled history rows to fetch per page. + Default page_size is 1000. 
+ """ + self.client = client + self.run = run + self.keys = keys + self.page_size = page_size + self.min_step = min_step + self.max_step = max_step + self.page_offset = min_step # minStep for next page + self.scan_offset = 0 # index within current page of rows + self.rows: list[_RowDict] = [] # current page of rows + + def __iter__(self) -> Self: + self.page_offset = self.min_step + self.scan_offset = 0 + self.rows = [] + return self + + def __next__(self) -> _RowDict: + """Return the next row of sampled history data with automatic pagination. + + + """ + while True: + if self.scan_offset < len(self.rows): + row = self.rows[self.scan_offset] + self.scan_offset += 1 + return row + if self.page_offset >= self.max_step: + raise StopIteration() + self._load_next() + + next = __next__ + + @normalize_exceptions + def _load_next(self) -> None: + max_step = self.page_offset + self.page_size + if max_step > self.max_step: + max_step = self.max_step + variables = { + "entity": self.run.entity, + "project": self.run.project, + "run": self.run.id, + "spec": json.dumps( + { + "keys": self.keys, + "minStep": int(self.page_offset), + "maxStep": int(max_step), + "samples": int(self.page_size), + } + ), + } + + res = self.client.execute(self.QUERY, variable_values=variables) + res = res["project"]["run"]["sampledHistory"] + self.rows = res[0] + self.page_offset += self.page_size + self.scan_offset = 0 diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/integrations.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/integrations.py new file mode 100644 index 0000000000000000000000000000000000000000..70d54150cb7e9cddf84f7bbf479631136a523cbc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/integrations.py @@ -0,0 +1,90 @@ +"""W&B Public API for integrations. + +This module provides classes for interacting with W&B integrations. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar, Union + +from typing_extensions import override +from wandb_gql import gql + +from wandb.apis.paginator import RelayPaginator + +if TYPE_CHECKING: + from wandb_graphql.language.ast import Document + + from wandb._pydantic import Connection + from wandb.apis.public.api import RetryingClient + from wandb.automations import Integration, SlackIntegration, WebhookIntegration + from wandb.automations._generated import ( + SlackIntegrationFields, + WebhookIntegrationFields, + ) + + IntegrationFields = Union[SlackIntegrationFields, WebhookIntegrationFields] + + +class Integrations(RelayPaginator["IntegrationFields", "Integration"]): + """A lazy iterator of `Integration` objects. + + + """ + + QUERY: ClassVar[Document | None] = None + last_response: Connection[IntegrationFields] | None + + def __init__( + self, + client: RetryingClient, + variables: dict[str, Any], + per_page: int = 50, + ): + if self.QUERY is None: + from wandb.automations._generated import INTEGRATIONS_BY_ENTITY_GQL + + type(self).QUERY = gql(INTEGRATIONS_BY_ENTITY_GQL) + + super().__init__(client, variables=variables, per_page=per_page) + + @override + def _update_response(self) -> None: + """Fetch and parse the response data for the current page.""" + from wandb._pydantic import Connection + from wandb.automations._generated import IntegrationsByEntity + + data = self.client.execute(self.QUERY, variable_values=self.variables) + result = IntegrationsByEntity.model_validate(data) + if not ((entity := result.entity) and (conn := entity.integrations)): + raise ValueError("Unexpected response data") + self.last_response = Connection.model_validate(conn) + + def _convert(self, node: IntegrationFields) -> Integration: + from wandb.automations.integrations import IntegrationAdapter + + return IntegrationAdapter.validate_python(node) + + +# The paginators below filter on `typename__` since the GQL response still +# 
includes all `Integration` types. Applying a `@skip/@include` directive +# does not change this. Restricting results to a single type requires +# a client-side filter. +class WebhookIntegrations(Integrations): + """A lazy iterator of `WebhookIntegration` objects. + + + """ + + def _convert(self, node: IntegrationFields) -> WebhookIntegration: + return node if (node.typename__ == "GenericWebhookIntegration") else None + + +class SlackIntegrations(Integrations): + """A lazy iterator of `SlackIntegration` objects. + + + """ + + def _convert(self, node: IntegrationFields) -> SlackIntegration: + return node if (node.typename__ == "SlackIntegration") else None diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/jobs.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/jobs.py new file mode 100644 index 0000000000000000000000000000000000000000..c2349d3dfafcf2c6919a826a7fc9fc10c1eec1b7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/jobs.py @@ -0,0 +1,748 @@ +"""W&B Public API for management Launch Jobs and Launch Queues. + +This module provides classes for managing W&B jobs, queued runs, and run +queues. 
+""" + +from __future__ import annotations + +import json +import os +import shutil +import time +from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping + +from wandb_gql import gql + +import wandb +from wandb import util +from wandb.apis import public +from wandb.apis.normalize import normalize_exceptions +from wandb.errors import CommError +from wandb.sdk.data_types._dtypes import InvalidType, Type, TypeRegistry +from wandb.sdk.launch.errors import LaunchError +from wandb.sdk.launch.utils import ( + LAUNCH_DEFAULT_PROJECT, + _fetch_git_repo, + apply_patch, + convert_jupyter_notebook_to_script, +) + +if TYPE_CHECKING: + from wandb.apis.public import Api, RetryingClient + from wandb.sdk.launch._project_spec import LaunchProject + + +class Job: + _name: str + _input_types: Type + _output_types: Type + _entity: str + _project: str + _entrypoint: list[str] + _notebook_job: bool + _partial: bool + + def __init__(self, api: Api, name, path: str | None = None) -> None: + try: + self._job_artifact = api._artifact(name, type="job") + except CommError: + raise CommError(f"Job artifact {name} not found") + if path: + self._fpath = path + self._job_artifact.download(root=path) + else: + self._fpath = self._job_artifact.download() + self._name = name + self._api = api + self._entity = api.default_entity + + with open(os.path.join(self._fpath, "wandb-job.json")) as f: + self._job_info: Mapping[str, Any] = json.load(f) + source_info = self._job_info.get("source", {}) + # only use notebook job if entrypoint not set and notebook is set + self._notebook_job = source_info.get("notebook", False) + self._entrypoint = source_info.get("entrypoint") + self._dockerfile = source_info.get("dockerfile") + self._build_context = source_info.get("build_context") + self._base_image = source_info.get("base_image") + self._args = source_info.get("args") + self._partial = self._job_info.get("_partial", False) + self._requirements_file = os.path.join(self._fpath, 
"requirements.frozen.txt") + self._input_types = TypeRegistry.type_from_dict( + self._job_info.get("input_types") + ) + self._output_types = TypeRegistry.type_from_dict( + self._job_info.get("output_types") + ) + if self._job_info.get("source_type") == "artifact": + self._set_configure_launch_project(self._configure_launch_project_artifact) + if self._job_info.get("source_type") == "repo": + self._set_configure_launch_project(self._configure_launch_project_repo) + if self._job_info.get("source_type") == "image": + self._set_configure_launch_project(self._configure_launch_project_container) + + @property + def name(self) -> str: + """The name of the job.""" + return self._name + + def _set_configure_launch_project(self, func: Callable[[LaunchProject], None]): + self.configure_launch_project = func + + def _get_code_artifact(self, artifact_string): + from wandb.sdk.artifacts.artifact_state import ArtifactState + + artifact_string, base_url, is_id = util.parse_artifact_string(artifact_string) + if is_id: + code_artifact = wandb.Artifact._from_id(artifact_string, self._api._client) + else: + code_artifact = self._api._artifact(name=artifact_string, type="code") + if code_artifact is None: + raise LaunchError("No code artifact found") + if code_artifact.state == ArtifactState.DELETED: + raise LaunchError( + f"Job {self.name} references deleted code artifact {code_artifact.name}" + ) + return code_artifact + + def _configure_launch_project_notebook(self, launch_project: LaunchProject) -> None: + new_fname = convert_jupyter_notebook_to_script( + self._entrypoint[-1], launch_project.project_dir + ) + new_entrypoint = self._entrypoint + new_entrypoint[-1] = new_fname + launch_project.set_job_entry_point(new_entrypoint) + + def _configure_launch_project_repo(self, launch_project: LaunchProject) -> None: + git_info = self._job_info.get("source", {}).get("git", {}) + _fetch_git_repo( + launch_project.project_dir, + git_info["remote"], + git_info["commit"], + ) + if 
os.path.exists(os.path.join(self._fpath, "diff.patch")): + with open(os.path.join(self._fpath, "diff.patch")) as f: + apply_patch(f.read(), launch_project.project_dir) + shutil.copy(self._requirements_file, launch_project.project_dir) + launch_project.python_version = self._job_info.get("runtime") + if self._notebook_job: + self._configure_launch_project_notebook(launch_project) + else: + launch_project.set_job_entry_point(self._entrypoint) + + if self._dockerfile: + launch_project.set_job_dockerfile(self._dockerfile) + if self._build_context: + launch_project.set_job_build_context(self._build_context) + if self._base_image: + launch_project.set_job_base_image(self._base_image) + + def _configure_launch_project_artifact(self, launch_project: LaunchProject) -> None: + artifact_string = self._job_info.get("source", {}).get("artifact") + if artifact_string is None: + raise LaunchError(f"Job {self.name} had no source artifact") + + code_artifact = self._get_code_artifact(artifact_string) + launch_project.python_version = self._job_info.get("runtime") + shutil.copy(self._requirements_file, launch_project.project_dir) + + code_artifact.download(launch_project.project_dir) + + if self._notebook_job: + self._configure_launch_project_notebook(launch_project) + else: + launch_project.set_job_entry_point(self._entrypoint) + + if self._dockerfile: + launch_project.set_job_dockerfile(self._dockerfile) + if self._build_context: + launch_project.set_job_build_context(self._build_context) + if self._base_image: + launch_project.set_job_base_image(self._base_image) + + def _configure_launch_project_container( + self, launch_project: LaunchProject + ) -> None: + launch_project.docker_image = self._job_info.get("source", {}).get("image") + if launch_project.docker_image is None: + raise LaunchError( + "Job had malformed source dictionary without an image key" + ) + if self._entrypoint: + launch_project.set_job_entry_point(self._entrypoint) + + def set_entrypoint(self, entrypoint: 
list[str]) -> None: + """Set the entrypoint for the job.""" + self._entrypoint = entrypoint + + def call( + self, + config, + project=None, + entity=None, + queue=None, + resource="local-container", + resource_args=None, + template_variables=None, + project_queue=None, + priority=None, + ): + """Call the job with the given configuration. + + Args: + config (dict): The configuration to pass to the job. + This should be a dictionary containing key-value pairs that + match the input types defined in the job. + project (str, optional): The project to log the run to. Defaults + to the job's project. + entity (str, optional): The entity to log the run under. Defaults + to the job's entity. + queue (str, optional): The name of the queue to enqueue the job to. + Defaults to None. + resource (str, optional): The resource type to use for execution. + Defaults to "local-container". + resource_args (dict, optional): Additional arguments for the + resource type. Defaults to None. + template_variables (dict, optional): Template variables to use for + the job. Defaults to None. + project_queue (str, optional): The project that manages the queue. + Defaults to None. + priority (int, optional): The priority of the queued run. + Defaults to None. 
+ """ + from wandb.sdk.launch import _launch_add + + run_config = {} + for key, item in config.items(): + if util._is_artifact_object(item): + if isinstance(item, wandb.Artifact) and item.is_draft(): + raise ValueError("Cannot queue jobs with unlogged artifacts") + run_config[key] = util.artifact_to_json(item) + + run_config.update(config) + + assigned_config_type = self._input_types.assign(run_config) + if self._partial: + wandb.termwarn( + "Launching manually created job for the first time, can't verify types" + ) + else: + if isinstance(assigned_config_type, InvalidType): + raise TypeError(self._input_types.explain(run_config)) + + queued_run = _launch_add.launch_add( + job=self._name, + config={"overrides": {"run_config": run_config}}, + template_variables=template_variables, + project=project or self._project, + entity=entity or self._entity, + queue_name=queue, + resource=resource, + project_queue=project_queue, + resource_args=resource_args, + priority=priority, + ) + return queued_run + + +class QueuedRun: + """A single queued run associated with an entity and project. + + Args: + entity: The entity associated with the queued run. + project (str): The project where runs executed by the queue are logged to. + queue_name (str): The name of the queue. + run_queue_item_id (int): The id of the run queue item. + project_queue (str): The project that manages the queue. + priority (str): The priority of the queued run. + + Call `run = queued_run.wait_until_running()` or + `run = queued_run.wait_until_finished()` to access the run. 
+ """ + + def __init__( + self, + client: RetryingClient, + entity: str, + project: str, + queue_name: str, + run_queue_item_id: str, + project_queue: str = LAUNCH_DEFAULT_PROJECT, + priority: int | None = None, + ): + self.client = client + self._entity = entity + self._project = project + self._queue_name = queue_name + self._run_queue_item_id = run_queue_item_id + self.sweep = None + self._run: public.Run | None = None + self.project_queue = project_queue + self.priority = priority + + @property + def queue_name(self) -> str: + """The name of the queue.""" + return self._queue_name + + @property + def id(self) -> str: + """The id of the queued run.""" + return self._run_queue_item_id + + @property + def project(self) -> str: + """The project associated with the queued run.""" + return self._project + + @property + def entity(self) -> str: + """The entity associated with the queued run.""" + return self._entity + + @property + def state(self) -> str: + """The state of the queued run.""" + item = self._get_item() + if item: + return item["state"].lower() + + raise ValueError( + f"Could not find QueuedRunItem associated with id: {self.id} on queue {self.queue_name} at itemId: {self.id}" + ) + + @normalize_exceptions + def _get_run_queue_item_legacy(self) -> dict[str, Any]: + query = gql( + """ + query GetRunQueueItem($projectName: String!, $entityName: String!, $runQueue: String!) 
{ + project(name: $projectName, entityName: $entityName) { + runQueue(name:$runQueue) { + runQueueItems { + edges { + node { + id + state + associatedRunId + } + } + } + } + } + } + """ + ) + variable_values = { + "projectName": self.project_queue, + "entityName": self._entity, + "runQueue": self.queue_name, + } + res = self.client.execute(query, variable_values) + + for item in res["project"]["runQueue"]["runQueueItems"]["edges"]: + if str(item["node"]["id"]) == str(self.id): + return item["node"] + + @normalize_exceptions + def _get_item(self) -> dict[str, Any]: + query = gql( + """ + query GetRunQueueItem($projectName: String!, $entityName: String!, $runQueue: String!, $itemId: ID!) { + project(name: $projectName, entityName: $entityName) { + runQueue(name: $runQueue) { + runQueueItem(id: $itemId) { + id + state + associatedRunId + } + } + } + } + """ + ) + variable_values = { + "projectName": self.project_queue, + "entityName": self._entity, + "runQueue": self.queue_name, + "itemId": self.id, + } + try: + res = self.client.execute(query, variable_values) # exception w/ old server + if res["project"]["runQueue"].get("runQueueItem") is not None: + return res["project"]["runQueue"]["runQueueItem"] + except Exception as e: + if "Cannot query field" not in str(e): + raise LaunchError(f"Unknown exception: {e}") + + return self._get_run_queue_item_legacy() + + @normalize_exceptions + def wait_until_finished(self) -> public.Run: + """Wait for the queued run to complete and return the finished run.""" + if not self._run: + self.wait_until_running() + + self._run.wait_until_finished() + # refetch run to get updated summary + self._run.load(force=True) + return self._run + + @normalize_exceptions + def delete(self, delete_artifacts: bool = False) -> None: + """Delete the given queued run from the wandb backend.""" + query = gql( + """ + query fetchRunQueuesFromProject($entityName: String!, $projectName: String!, $runQueueName: String!) 
{ + project(name: $projectName, entityName: $entityName) { + runQueue(name: $runQueueName) { + id + } + } + } + """ + ) + + res = self.client.execute( + query, + variable_values={ + "entityName": self.entity, + "projectName": self.project_queue, + "runQueueName": self.queue_name, + }, + ) + + if res["project"].get("runQueue") is not None: + queue_id = res["project"]["runQueue"]["id"] + + mutation = gql( + """ + mutation DeleteFromRunQueue( + $queueID: ID!, + $runQueueItemId: ID! + ) { + deleteFromRunQueue(input: { + queueID: $queueID + runQueueItemId: $runQueueItemId + }) { + success + clientMutationId + } + } + """ + ) + self.client.execute( + mutation, + variable_values={ + "queueID": queue_id, + "runQueueItemId": self._run_queue_item_id, + }, + ) + + @normalize_exceptions + def wait_until_running(self) -> public.Run: + """Wait until the queued run is running and return the run.""" + if self._run is not None: + return self._run + + while True: + # sleep here to hide an ugly warning + time.sleep(2) + item = self._get_item() + if item and item["associatedRunId"] is not None: + try: + self._run = public.Run( + self.client, + self._entity, + self.project, + item["associatedRunId"], + None, + ) + self._run_id = item["associatedRunId"] + except ValueError as e: + wandb.termwarn(str(e)) + else: + return self._run + elif item: + wandb.termlog("Waiting for run to start") + + time.sleep(3) + + def __repr__(self) -> str: + return f" None: + self._name: str = name + self._client = client + self._entity = entity + self._prioritization_mode = prioritization_mode + self._access = _access + self._default_resource_config_id = _default_resource_config_id + self._default_resource_config = _default_resource_config + self._template_variables = None + self._type = None + self._items: list[QueuedRun] | None = None + self._id: str | None = None + + @property + def name(self) -> str: + """The name of the queue.""" + return self._name + + @property + def entity(self) -> str: + """The 
entity that owns the queue.""" + return self._entity + + @property + def prioritization_mode(self) -> RunQueuePrioritizationMode: + """The prioritization mode of the queue. + + Can be set to "DISABLED" or "V0". + """ + if self._prioritization_mode is None: + self._get_metadata() + return self._prioritization_mode + + @property + def access(self) -> RunQueueAccessType: + """The access level of the queue.""" + if self._access is None: + self._get_metadata() + return self._access + + @property + def external_links(self) -> dict[str, str]: + """External resource links for the queue.""" + if self._external_links is None: + self._get_metadata() + return self._external_links + + @property + def type(self) -> RunQueueResourceType: + """The resource type for execution.""" + if self._type is None: + if self._default_resource_config_id is None: + self._get_metadata() + self._get_default_resource_config() + return self._type + + @property + def default_resource_config(self) -> dict[str, Any]: + """The default configuration for resources.""" + if self._default_resource_config is None: + if self._default_resource_config_id is None: + self._get_metadata() + self._get_default_resource_config() + return self._default_resource_config + + @property + def template_variables(self) -> dict[str, Any]: + """Variables for resource templates.""" + if self._template_variables is None: + if self._default_resource_config_id is None: + self._get_metadata() + self._get_default_resource_config() + return self._template_variables + + @property + def id(self) -> str: + """The id of the queue.""" + if self._id is None: + self._get_metadata() + return self._id + + @property + def items(self) -> list[QueuedRun]: + """Up to the first 100 queued runs. 
Modifying this list will not modify the queue or any enqueued items!""" + # TODO(np): Add a paginated interface + if self._items is None: + self._get_items() + return self._items + + @normalize_exceptions + def delete(self) -> None: + """Delete the run queue from the wandb backend.""" + query = gql( + """ + mutation DeleteRunQueue($id: ID!) { + deleteRunQueues(input: {queueIDs: [$id]}) { + success + clientMutationId + } + } + """ + ) + variable_values = {"id": self.id} + res = self._client.execute(query, variable_values) + if res["deleteRunQueues"]["success"]: + self._id = None + self._access = None + self._default_resource_config_id = None + self._default_resource_config = None + self._items = None + else: + raise CommError(f"Failed to delete run queue {self.name}") + + def __repr__(self) -> str: + return f"" + + @normalize_exceptions + def _get_metadata(self) -> None: + query = gql( + """ + query GetRunQueueMetadata($projectName: String!, $entityName: String!, $runQueue: String!) { + project(name: $projectName, entityName: $entityName) { + runQueue(name: $runQueue) { + id + access + defaultResourceConfigID + prioritizationMode + externalLinks + } + } + } + """ + ) + variable_values = { + "projectName": LAUNCH_DEFAULT_PROJECT, + "entityName": self._entity, + "runQueue": self._name, + } + res = self._client.execute(query, variable_values) + self._id = res["project"]["runQueue"]["id"] + self._access = res["project"]["runQueue"]["access"] + self._default_resource_config_id = res["project"]["runQueue"][ + "defaultResourceConfigID" + ] + self._external_links = res["project"]["runQueue"]["externalLinks"] + if self._default_resource_config_id is None: + self._default_resource_config = {} + self._prioritization_mode = res["project"]["runQueue"]["prioritizationMode"] + + @normalize_exceptions + def _get_default_resource_config(self) -> None: + query = gql( + """ + query GetDefaultResourceConfig($entityName: String!, $id: ID!) 
{ + entity(name: $entityName) { + defaultResourceConfig(id: $id) { + config + resource + templateVariables { + name + schema + } + } + } + } + """ + ) + variable_values = { + "entityName": self._entity, + "id": self._default_resource_config_id, + } + res = self._client.execute(query, variable_values) + self._type = res["entity"]["defaultResourceConfig"]["resource"] + self._default_resource_config = res["entity"]["defaultResourceConfig"]["config"] + self._template_variables = res["entity"]["defaultResourceConfig"][ + "templateVariables" + ] + + @normalize_exceptions + def _get_items(self) -> None: + query = gql( + """ + query GetRunQueueItems($projectName: String!, $entityName: String!, $runQueue: String!) { + project(name: $projectName, entityName: $entityName) { + runQueue(name: $runQueue) { + runQueueItems(first: 100) { + edges { + node { + id + } + } + } + } + } + } + """ + ) + variable_values = { + "projectName": LAUNCH_DEFAULT_PROJECT, + "entityName": self._entity, + "runQueue": self._name, + } + res = self._client.execute(query, variable_values) + self._items = [] + for item in res["project"]["runQueue"]["runQueueItems"]["edges"]: + self._items.append( + QueuedRun( + self._client, + self._entity, + LAUNCH_DEFAULT_PROJECT, + self._name, + item["node"]["id"], + ) + ) + + @classmethod + def create( + cls, + name: str, + resource: RunQueueResourceType, + entity: str | None = None, + prioritization_mode: RunQueuePrioritizationMode | None = None, + config: dict | None = None, + template_variables: dict | None = None, + ) -> RunQueue: + """Create a RunQueue. + + Args: + name: The name of the run queue to create. + resource: The resource type for execution. + entity: The entity (user or team) that will own the queue. + Defaults to the default entity of the API client. + prioritization_mode: The prioritization mode for the queue. + Can be "DISABLED" or "V0". Defaults to None. + config: Optional dictionary for the default resource + configuration. Defaults to None. 
+ template_variables: Optional dictionary for template variables + used in the resource configuration. + """ + public_api = Api() + return public_api.create_run_queue( + name, resource, entity, prioritization_mode, config, template_variables + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/projects.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/projects.py new file mode 100644 index 0000000000000000000000000000000000000000..98998fec86d0c583d95ec9a8e55d8a658c2a4140 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/projects.py @@ -0,0 +1,243 @@ +"""W&B Public API for Project objects. + +This module provides classes for interacting with W&B projects and their +associated data. + +Example: +```python +from wandb.apis.public import Api + +# Get all projects for an entity +projects = Api().projects("entity") + +# Access project data +for project in projects: + print(f"Project: {project.name}") + print(f"URL: {project.url}") + + # Get artifact types + for artifact_type in project.artifacts_types(): + print(f"Artifact Type: {artifact_type.name}") + + # Get sweeps + for sweep in project.sweeps(): + print(f"Sweep ID: {sweep.id}") + print(f"State: {sweep.state}") +``` + +Note: + This module is part of the W&B Public API and provides methods to access + and manage projects. For creating new projects, use wandb.init() + with a new project name. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar, Mapping + +from typing_extensions import override +from wandb_gql import gql + +from wandb._strutils import nameof +from wandb.apis import public +from wandb.apis.attrs import Attrs +from wandb.apis.normalize import normalize_exceptions +from wandb.apis.paginator import RelayPaginator +from wandb.apis.public.api import RetryingClient +from wandb.apis.public.sweeps import Sweeps +from wandb.sdk.lib import ipython + +if TYPE_CHECKING: + from wandb_graphql.language.ast import Document + + from wandb._pydantic import Connection + from wandb.apis._generated import ProjectFragment + + +class Projects(RelayPaginator["ProjectFragment", "Project"]): + """An lazy iterator of `Project` objects. + + An iterable interface to access projects created and saved by the entity. + + Args: + client (`wandb.apis.internal.Api`): The API client instance to use. + entity (str): The entity name (username or team) to fetch projects for. + per_page (int): Number of projects to fetch per request (default is 50). + + Example: + ```python + from wandb.apis.public.api import Api + + # Find projects that belong to this entity + projects = Api().projects(entity="entity") + + # Iterate over files + for project in projects: + print(f"Project: {project.name}") + print(f"- URL: {project.url}") + print(f"- Created at: {project.created_at}") + print(f"- Is benchmark: {project.is_benchmark}") + ``` + """ + + QUERY: ClassVar[Document | None] = None + last_response: Connection[ProjectFragment] | None + + def __init__( + self, + client: RetryingClient, + entity: str, + per_page: int = 50, + ) -> Projects: + """An iterable collection of `Project` objects. + + Args: + client: The API client used to query W&B. + entity: The entity which owns the projects. + per_page: The number of projects to fetch per request to the API. 
+ """ + if self.QUERY is None: + from wandb.apis._generated import GET_PROJECTS_GQL + + type(self).QUERY = gql(GET_PROJECTS_GQL) + + self.entity = entity + super().__init__(client, variables={"entity": entity}, per_page=per_page) + + @override + def _update_response(self) -> None: + """Fetch and validate the response data for the current page.""" + from wandb._pydantic import Connection + from wandb.apis._generated import GetProjects, ProjectFragment + + data = self.client.execute(self.QUERY, variable_values=self.variables) + result = GetProjects.model_validate(data) + if not (conn := result.models): + raise ValueError(f"Unable to parse {nameof(type(self))!r} response data") + self.last_response = Connection[ProjectFragment].model_validate(conn) + + @property + def length(self) -> None: + """Returns the total number of projects. + + Note: This property is not available for projects. + + + """ + # For backwards compatibility, even though this isn't a SizedPaginator + return None + + def _convert(self, node: ProjectFragment) -> Project: + return Project(self.client, self.entity, node.name, node.model_dump()) + + def __repr__(self): + return f"" + + +class Project(Attrs): + """A project is a namespace for runs. + + Args: + client: W&B API client instance. + name (str): The name of the project. + entity (str): The entity name that owns the project. + """ + + def __init__( + self, + client: RetryingClient, + entity: str, + project: str, + attrs: Mapping[str, Any], + ) -> Project: + """A single project associated with an entity. + + Args: + client: The API client used to query W&B. + entity: The entity which owns the project. + project: The name of the project to query. + attrs: The attributes of the project. 
+ """ + super().__init__(attrs) + self._is_loaded = bool(attrs) + self.client = client + self.name = project + self.entity = entity + + def _load(self) -> None: + from requests import HTTPError + + from wandb.apis._generated import GET_PROJECT_GQL, GetProject + + gql_vars = {"name": self.name, "entity": self.entity} + try: + data = self.client.execute(gql(GET_PROJECT_GQL), gql_vars) + except HTTPError as e: + raise ValueError(f"Unable to fetch project ID: {gql_vars!r}") from e + + project = GetProject.model_validate(data).project + self._attrs = project.model_dump() if project else {} + self._is_loaded = True + + @property + def path(self) -> list[str]: + """Returns the path of the project. The path is a list containing the + entity and project name.""" + return [self.entity, self.name] + + @property + def url(self) -> str: + """Returns the URL of the project.""" + return self.client.app_url + "/".join(self.path + ["workspace"]) + + def to_html(self, height: int = 420, hidden: bool = False) -> str: + """Generate HTML containing an iframe displaying this project. + + + """ + url = self.url + "?jupyter=true" + style = f"border:none;width:100%;height:{height}px;" + prefix = "" + if hidden: + style += "display:none;" + prefix = ipython.toggle_button("project") + return prefix + f"" + + def _repr_html_(self) -> str: + return self.to_html() + + def __repr__(self): + return "".format("/".join(self.path)) + + @normalize_exceptions + def artifacts_types(self, per_page: int = 50) -> public.ArtifactTypes: + """Returns all artifact types associated with this project.""" + return public.ArtifactTypes(self.client, self.entity, self.name) + + @normalize_exceptions + def sweeps(self, per_page: int = 50) -> Sweeps: + """Return a paginated collection of sweeps in this project. + + Args: + per_page: The number of sweeps to fetch per request to the API. + + Returns: + A `Sweeps` object, which is an iterable collection of `Sweep` objects. 
+ """ + return Sweeps(self.client, self.entity, self.name, per_page=per_page) + + @property + def id(self) -> str: + if not self._is_loaded: + self._load() + + if "id" not in self._attrs: + raise ValueError(f"Project {self.name} not found") + + return self._attrs["id"] + + @override + def __getattr__(self, name: str) -> Any: + if not self._is_loaded: + self._load() + return super().__getattr__(name) diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/query_generator.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/query_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..def8ce18b74d9b452864bc5f25bbfeec147070ce --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/query_generator.py @@ -0,0 +1,179 @@ +from __future__ import annotations + + +class QueryGenerator: + """QueryGenerator is a helper object to write filters for runs. + + + """ + + INDIVIDUAL_OP_TO_MONGO = { + "!=": "$ne", + ">": "$gt", + ">=": "$gte", + "<": "$lt", + "<=": "$lte", + "IN": "$in", + "NIN": "$nin", + "REGEX": "$regex", + } + MONGO_TO_INDIVIDUAL_OP = {v: k for k, v in INDIVIDUAL_OP_TO_MONGO.items()} + + GROUP_OP_TO_MONGO = {"AND": "$and", "OR": "$or"} + MONGO_TO_GROUP_OP = {v: k for k, v in GROUP_OP_TO_MONGO.items()} + + def __init__(self): + pass + + @classmethod + def format_order_key(cls, key: str): + """Format a key for sorting.""" + if key.startswith(("+", "-")): + direction = key[0] + key = key[1:] + else: + direction = "-" + parts = key.split(".") + if len(parts) == 1: + # Assume the user meant summary_metrics if not a run column + if parts[0] not in ["createdAt", "updatedAt", "name", "sweep"]: + return direction + "summary_metrics." 
+ parts[0] + # Assume summary metrics if prefix isn't known + elif parts[0] not in ["config", "summary_metrics", "tags"]: + return direction + ".".join(["summary_metrics"] + parts) + else: + return direction + ".".join(parts) + + def _is_group(self, op): + return op.get("filters") is not None + + def _is_individual(self, op): + return op.get("key") is not None + + def _to_mongo_op_value(self, op, value): + if op == "=": + return value + else: + return {self.INDIVIDUAL_OP_TO_MONGO[op]: value} + + def key_to_server_path(self, key): + """Convert a key dictionary to the corresponding server path string.""" + if key["section"] == "config": + return "config." + key["name"] + elif key["section"] == "summary": + return "summary_metrics." + key["name"] + elif key["section"] == "keys_info": + return "keys_info.keys." + key["name"] + elif key["section"] == "run": + return key["name"] + elif key["section"] == "tags": + return "tags." + key["name"] + raise ValueError("Invalid key: {}".format(key)) + + def server_path_to_key(self, path): + """Convert a server path string to the corresponding key dictionary.""" + if path.startswith("config."): + return {"section": "config", "name": path.split("config.", 1)[1]} + elif path.startswith("summary_metrics."): + return {"section": "summary", "name": path.split("summary_metrics.", 1)[1]} + elif path.startswith("keys_info.keys."): + return {"section": "keys_info", "name": path.split("keys_info.keys.", 1)[1]} + elif path.startswith("tags."): + return {"section": "tags", "name": path.split("tags.", 1)[1]} + else: + return {"section": "run", "name": path} + + def keys_to_order(self, keys): + """Convert a list of key dictionaries to an order string.""" + orders = [] + for key in keys["keys"]: + order = self.key_to_server_path(key["key"]) + if key.get("ascending"): + order = "+" + order + else: + order = "-" + order + orders.append(order) + # return ",".join(orders) + return orders + + def order_to_keys(self, order): + """Convert an order 
string to a list of key dictionaries.""" + keys = [] + for k in order: # orderstr.split(","): + name = k[1:] + if k[0] == "+": + ascending = True + elif k[0] == "-": + ascending = False + else: + raise Exception("you must sort by ascending(+) or descending(-)") + + key = {"key": {"section": "run", "name": name}, "ascending": ascending} + keys.append(key) + + return {"keys": keys} + + def _to_mongo_individual(self, filter): + if filter["key"]["name"] == "": + return None + + if filter.get("value") is None and filter["op"] != "=" and filter["op"] != "!=": + return None + + if filter.get("disabled") is not None and filter["disabled"]: + return None + + if filter["key"]["section"] == "tags": + if filter["op"] == "IN": + return {"tags": {"$in": filter["value"]}} + if filter["value"] is False: + return { + "$or": [{"tags": None}, {"tags": {"$ne": filter["key"]["name"]}}] + } + else: + return {"tags": filter["key"]["name"]} + path = self.key_to_server_path(filter["key"]) + if path is None: + return path + return {path: self._to_mongo_op_value(filter["op"], filter["value"])} + + def filter_to_mongo(self, filter): + """Returns dictionary with filter format converted to MongoDB filter.""" + if self._is_individual(filter): + return self._to_mongo_individual(filter) + elif self._is_group(filter): + return { + self.GROUP_OP_TO_MONGO[filter["op"]]: [ + self.filter_to_mongo(f) for f in filter["filters"] + ] + } + + def mongo_to_filter(self, filter): + """Returns dictionary with MongoDB filter converted to filter format.""" + # Returns {"op": "OR", "filters": [{"op": "AND", "filters": []}]} + if filter is None: + return None # this covers the case where self.filter_to_mongo returns None. 
+ + group_op = None + for key in filter.keys(): + # if self.MONGO_TO_GROUP_OP[key]: + if key in self.MONGO_TO_GROUP_OP: + group_op = key + break + if group_op is not None: + return { + "op": self.MONGO_TO_GROUP_OP[group_op], + "filters": [self.mongo_to_filter(f) for f in filter[group_op]], + } + else: + for k, v in filter.items(): + if isinstance(v, dict): + # TODO: do we always have one key in this case? + op = next(iter(v.keys())) + return { + "key": self.server_path_to_key(k), + "op": self.MONGO_TO_INDIVIDUAL_OP[op], + "value": v[op], + } + else: + return {"key": self.server_path_to_key(k), "op": "=", "value": v} diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__init__.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8ba451b425be56a42905463a35c0cbe8c16b56fe --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__init__.py @@ -0,0 +1,7 @@ +__all__ = [ + "Registry", # doc:exclude + "Registries", # doc:exclude +] + +from .registries_search import Registries +from .registry import Registry diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6949ff10ae62a980d70203b99a7ab008afa22608 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/_freezable_list.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/_freezable_list.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c48844d332183c40d3002bca28877e1699899de5 Binary files /dev/null and 
b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/_freezable_list.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/_members.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/_members.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6527950bfa1eb90d64c972c4a017720466216a5 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/_members.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/_utils.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/_utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b1d5c199ba16597fda894eb680a21ad25cc8b10 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/_utils.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/registries_search.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/registries_search.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7781c7a66f50b4ad362f615844919a7c06c3ad4 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/registries_search.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/registry.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/registry.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e456700dc649dcd1119c0eabfe77280325a3083 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/__pycache__/registry.cpython-313.pyc differ diff 
--git a/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/_freezable_list.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/_freezable_list.py new file mode 100644 index 0000000000000000000000000000000000000000..bff670158c6133dfcc066b7ff551bed5999924cb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/_freezable_list.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +from itertools import chain +from typing import ( + Any, + Iterable, + Iterator, + MutableSequence, + Sequence, + TypeVar, + final, + overload, +) + +from wandb._strutils import nameof + +T = TypeVar("T") + + +@final +class FreezableList(MutableSequence[T]): + """A list-like container type that only allows adding new items. + + It tracks "saved" (immutable) and "draft" (mutable) items. + Items can be added, inserted, and removed while in draft state, but once frozen, + they become immutable. Unlike a set, duplicate items are allowed in the draft + state but duplicates already present in the saved state cannot be added. + Any initial items passed to the constructor are saved. + """ + + def __init__(self, iterable: Iterable[T] | None = None, /) -> None: + self._frozen: tuple[T, ...] = tuple(iterable or ()) + self._draft: list[T] = [] + + def append(self, value: T) -> None: + """Append an item to the draft list. 
No duplicates are allowed.""" + if (value in self._frozen) or (value in self._draft): + return + self._draft.append(value) + + def remove(self, value: T) -> None: + """Remove the first occurrence of value from the draft list.""" + if value in self._frozen: + raise ValueError(f"Cannot remove item from frozen list: {value!r}") + self._draft.remove(value) + + def freeze(self) -> None: + """Freeze any draft items by adding them to the saved tuple.""" + # Filter out duplicates already in saved before extending + new_items = tuple(item for item in self._draft if item not in self._frozen) + self._frozen = self._frozen + new_items + self._draft.clear() + + def __eq__(self, value: object) -> bool: + if not isinstance(value, Sequence): + return NotImplemented + return list(self) == list(value) + + def __contains__(self, value: Any) -> bool: + return value in self._frozen or value in self._draft + + def __len__(self) -> int: + return len(self._frozen) + len(self._draft) + + def __iter__(self) -> Iterator[T]: + return iter(chain(self._frozen, self._draft)) + + @overload + def __getitem__(self, index: int) -> T: ... + + @overload + def __getitem__(self, index: slice) -> Sequence[T]: ... + + def __getitem__(self, index: int | slice) -> T | Sequence[T]: + return [*self._frozen, *self._draft][index] + + @overload + def __setitem__(self, index: int, value: T) -> None: ... + + @overload + def __setitem__(self, index: slice, value: Iterable[T]) -> None: ... 
+ + def __setitem__(self, index: int | slice, value: T | Iterable[T]) -> None: + if isinstance(index, slice): + # Setting slices might affect saved items, disallow for simplicity + raise TypeError(f"{nameof(type(self))!r} does not support slice assignment") + else: + if value in self._frozen or value in self._draft: + return + + # The frozen items are sequentially first and protected from changes + len_frozen = len(self._frozen) + size = len(self) + + if (index >= size) or (index < -size): + raise IndexError("Index out of range") + + draft_index = (index % size) - len_frozen + if draft_index < 0: + raise ValueError(f"Cannot assign to saved item at index {index!r}") + self._draft[draft_index] = value + + @overload + def __delitem__(self, index: int) -> None: ... + + @overload + def __delitem__(self, index: slice) -> None: ... + + def __delitem__(self, index: int | slice) -> None: + if isinstance(index, slice): + raise TypeError(f"{nameof(type(self))!r} does not support slice deletion") + else: + # The frozen items are sequentially first and protected from changes + len_frozen = len(self._frozen) + size = len(self) + + if (index >= size) or (index < -size): + raise IndexError("Index out of range") + + draft_index = (index % size) - len_frozen + if draft_index < 0: + raise ValueError(f"Cannot delete saved item at index {index!r}") + del self._draft[draft_index] + + def insert(self, index: int, value: T) -> None: + """Insert item before index. + + Insertion is only allowed at indices corresponding to the draft portion + of the list (i.e., index >= len(frozen_items)). Negative indices are + interpreted relative to the combined length of frozen and draft items. + """ + if value in self._frozen or value in self._draft: + # Silently ignore duplicates, similar to append + return + + # The frozen items are sequentially first and protected from changes + len_frozen = len(self._frozen) + size = len(self) + + # Follow `list.insert()`'s behavior when the index is out of bounds. 
+ # - Negative out-of-bounds index: prepend (only works if frozen items + # are empty). + if index < -size and not self._frozen: + return self._draft.insert(0, value) + + # - positive out-of-bounds index: append. + if index >= size: + return self._draft.append(value) + + # - in-bounds index: insert only if into the draft portion. + draft_index = (index % size) - len_frozen + if draft_index < 0: + raise IndexError( + f"Cannot insert into the frozen list (index < {len_frozen})" + ) + return self._draft.insert(draft_index, value) + + def __repr__(self) -> str: + return f"{nameof(type(self))}(frozen={list(self._frozen)!r}, draft={list(self._draft)!r})" + + @property + def draft(self) -> tuple[T, ...]: + """A read-only, tuple copy of the current draft items.""" + return tuple(self._draft) + + +class AddOnlyArtifactTypesList(FreezableList[str]): + def remove(self, value: str) -> None: + try: + super().remove(value) + except ValueError: + raise ValueError( + f"Cannot remove artifact type: {value!r} that has been saved to the registry" + ) + + def __repr__(self) -> str: + return f"{nameof(type(self))}(saved={list(self._frozen)!r}, draft={list(self._draft)!r})" diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/_members.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/_members.py new file mode 100644 index 0000000000000000000000000000000000000000..3823049c7ae74400a9c52e427e3d3d9c9e7269c9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/_members.py @@ -0,0 +1,110 @@ +"""Types and helpers for managing registry members.""" + +from __future__ import annotations + +from collections import defaultdict +from enum import Enum +from functools import singledispatchmethod +from typing import Iterable, Literal, Union + +from pydantic.dataclasses import dataclass as pydantic_dataclass + +from wandb._strutils import b64decode_ascii, b64encode_ascii, nameof +from wandb.sdk.artifacts._models import 
ArtifactsBase + +from ..teams import Team +from ..users import User + + +class MemberKind(str, Enum): + """Identifies what kind of object a registry member is.""" + + USER = "User" + ENTITY = "Entity" + + TEAM = ENTITY # Convenience alias + + +class MemberRole(str, Enum): + """Identifies the role of a member.""" + + ADMIN = "admin" + MEMBER = "member" + VIEWER = "viewer" + RESTRICTED_VIEWER = "restricted_viewer" + + +class UserMember(ArtifactsBase, arbitrary_types_allowed=True): + kind: Literal[MemberKind.USER] = MemberKind.USER + + user: User + role: Union[MemberRole, str] # noqa: UP007 + + +class TeamMember(ArtifactsBase, arbitrary_types_allowed=True): + kind: Literal[MemberKind.ENTITY] = MemberKind.ENTITY + + team: Team + role: Union[MemberRole, str] # noqa: UP007 + + +MemberOrId = Union[User, Team, UserMember, TeamMember, str] +"""Type hint for a registry member argument that accepts a User, Team, or their ID.""" + + +def parse_member_ids(members: Iterable[MemberOrId]) -> tuple[list[str], list[str]]: + """Returns a tuple of (user_ids, team_ids) from parsing the given objects.""" + ids_by_kind: dict[MemberKind, set[str]] = defaultdict(set) + + for parsed in map(MemberId.from_obj, members): + ids_by_kind[parsed.kind].add(parsed.encode()) + + user_ids = ids_by_kind[MemberKind.USER] + team_ids = ids_by_kind[MemberKind.ENTITY] + + # Ordering shouldn't matter, but sort anyway for reproducibility and testing + return sorted(user_ids), sorted(team_ids) + + +@pydantic_dataclass +class MemberId: + kind: MemberKind + index: int + + def encode(self) -> str: + """Converts this parsed ID to a base64-encoded GraphQL ID.""" + return b64encode_ascii(f"{self.kind.value}:{self.index}") + + @singledispatchmethod + @classmethod + def from_obj(cls, obj: MemberOrId, /) -> MemberId: + """Parses `User` or `Team` ID from the argument.""" + # Fallback for unexpected types + raise TypeError( + f"Member arg must be a {nameof(User)!r}, {nameof(Team)!r}, or a user/team ID. 
" + f"Got: {nameof(type(obj))!r}" + ) + + @from_obj.register(User) + @from_obj.register(Team) + @classmethod + def _from_obj_with_id(cls, obj: User | Team, /) -> MemberId: + # Use the object's string (base64-encoded) GraphQL ID + return cls._from_id(obj.id) + + @from_obj.register(UserMember) + @classmethod + def _from_user_member(cls, member: UserMember, /) -> MemberId: + return cls._from_id(member.user.id) + + @from_obj.register(TeamMember) + @classmethod + def _from_team_member(cls, member: TeamMember, /) -> MemberId: + return cls._from_id(member.team.id) + + @from_obj.register(str) + @classmethod + def _from_id(cls, id_: str, /) -> MemberId: + # Parse the ID to figure out if it's a team or user ID + kind, index = b64decode_ascii(id_).split(":", maxsplit=1) + return cls(kind=kind, index=index) diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/_utils.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c6264e008e013f6ad6c91f05ba6b3bbe25496b11 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/_utils.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from enum import Enum +from functools import lru_cache, partial +from typing import TYPE_CHECKING, Any, Collection + +from wandb_gql import gql + +from wandb._strutils import ensureprefix + +if TYPE_CHECKING: + from wandb.apis.public.api import RetryingClient + + +class Visibility(str, Enum): + # names are what users see/pass into Python methods + # values are what's expected by backend API + organization = "PRIVATE" + restricted = "RESTRICTED" + + @classmethod + def _missing_(cls, value: object) -> Any: + # Allow instantiation from enum names too (e.g. 
"organization" or "restricted") + return cls.__members__.get(value) + + @classmethod + def from_gql(cls, value: str) -> Visibility: + """Convert a GraphQL `visibility` value to a Visibility enum.""" + try: + return cls(value) + except ValueError: + expected = ",".join(repr(e.value) for e in cls) + raise ValueError( + f"Invalid visibility {value!r} from backend. Expected one of: {expected}" + ) from None + + @classmethod + def from_python(cls, name: str) -> Visibility: + """Convert a visibility string to a `Visibility` enum.""" + try: + return cls(name) + except ValueError: + expected = ",".join(repr(e.name) for e in cls) + raise ValueError( + f"Invalid visibility {name!r}. Expected one of: {expected}" + ) from None + + +def prepare_artifact_types_input( + artifact_types: Collection[str] | None, +) -> list[dict[str, str]] | None: + """Format the artifact types for the GQL input. + + Args: + artifact_types: The artifact types to add to the registry. + + Returns: + The artifact types for the GQL input. + """ + from wandb.sdk.artifacts._validators import validate_artifact_types + + if artifact_types: + return [{"name": typ} for typ in validate_artifact_types(artifact_types)] + return None + + +def ensure_registry_prefix_on_names(query: Any, in_name: bool = False) -> Any: + """Recursively the registry prefix to values under "name" keys, excluding regex ops. + + - in_name: True if we are under a "name" key (or propagating from one). + + EX: {"name": "model"} -> {"name": "wandb-registry-model"} + """ + from wandb.sdk.artifacts._validators import REGISTRY_PREFIX + + if isinstance((txt := query), str): + return ensureprefix(txt, REGISTRY_PREFIX) if in_name else txt + if isinstance((dct := query), dict): + new_dict = {} + for key, obj in dct.items(): + if key == "$regex": + # For regex operator, we skip transformation of its value. 
+ new_dict[key] = obj + elif key == "name": + new_dict[key] = ensure_registry_prefix_on_names(obj, in_name=True) + else: + # For any other key, propagate flags as-is. + new_dict[key] = ensure_registry_prefix_on_names(obj, in_name=in_name) + return new_dict + if isinstance((seq := query), (list, tuple)): + return list(map(partial(ensure_registry_prefix_on_names, in_name=in_name), seq)) + return query + + +@lru_cache(maxsize=10) +def fetch_org_entity_from_organization( + client: RetryingClient, organization: str +) -> str: + """Fetch the org entity from the organization. + + Args: + client (Client): Graphql client. + organization (str): The organization to fetch the org entity for. + """ + from wandb.sdk.artifacts._generated import ( + FETCH_ORG_ENTITY_FROM_ORGANIZATION_GQL, + FetchOrgEntityFromOrganization, + ) + + gql_op = gql(FETCH_ORG_ENTITY_FROM_ORGANIZATION_GQL) + try: + data = client.execute(gql_op, variable_values={"organization": organization}) + except Exception as e: + msg = f"Error fetching org entity for organization: {organization!r}" + raise ValueError(msg) from e + + result = FetchOrgEntityFromOrganization.model_validate(data) + if ( + not (org := result.organization) + or not (org_entity := org.org_entity) + or not (org_name := org_entity.name) + ): + raise ValueError(f"Organization entity for {organization!r} not found.") + + return org_name diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/registries_search.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/registries_search.py new file mode 100644 index 0000000000000000000000000000000000000000..7ae99f6725dc5d84466aae035d2634744f73ad5f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/registries_search.py @@ -0,0 +1,311 @@ +"""Public API: registries search.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any, ClassVar + +from pydantic import PositiveInt, ValidationError +from 
typing_extensions import override +from wandb_gql import gql + +from wandb._analytics import tracked +from wandb.apis.paginator import RelayPaginator, SizedRelayPaginator + +from ._utils import ensure_registry_prefix_on_names + +if TYPE_CHECKING: + from wandb_graphql.language.ast import Document + + from wandb.apis.public import ArtifactCollection, RetryingClient + from wandb.apis.public.registries.registry import Registry + from wandb.sdk.artifacts._generated import ( + ArtifactMembershipFragment, + RegistryCollectionFragment, + RegistryFragment, + ) + from wandb.sdk.artifacts._models.pagination import ( + ArtifactMembershipConnection, + RegistryCollectionConnection, + RegistryConnection, + ) + from wandb.sdk.artifacts.artifact import Artifact + + +class Registries(RelayPaginator["RegistryFragment", "Registry"]): + """A lazy iterator of `Registry` objects.""" + + QUERY: ClassVar[Document | None] = None + last_response: RegistryConnection | None + + def __init__( + self, + client: RetryingClient, + organization: str, + filter: dict[str, Any] | None = None, + per_page: PositiveInt = 100, + ): + if self.QUERY is None: + from wandb.sdk.artifacts._generated import FETCH_REGISTRIES_GQL + + type(self).QUERY = gql(FETCH_REGISTRIES_GQL) + + self.client = client + self.organization = organization + self.filter = ensure_registry_prefix_on_names(filter or {}) + + variables = {"organization": organization, "filters": json.dumps(self.filter)} + super().__init__(client, variables=variables, per_page=per_page) + + def __next__(self): + # Implement custom next since its possible to load empty pages because of auth + self.index += 1 + while len(self.objects) <= self.index: + if not self._load_page(): + raise StopIteration + return self.objects[self.index] + + @tracked + def collections( + self, filter: dict[str, Any] | None = None, per_page: PositiveInt = 100 + ) -> Collections: + return Collections( + client=self.client, + organization=self.organization, + 
registry_filter=self.filter, + collection_filter=filter, + per_page=per_page, + ) + + @tracked + def versions( + self, filter: dict[str, Any] | None = None, per_page: PositiveInt = 100 + ) -> Versions: + return Versions( + client=self.client, + organization=self.organization, + registry_filter=self.filter, + collection_filter=None, + artifact_filter=filter, + per_page=per_page, + ) + + @property + def length(self): + if self.last_response is None: + return None + return len(self.last_response.edges) + + @override + def _update_response(self) -> None: + from wandb.sdk.artifacts._generated import FetchRegistries + from wandb.sdk.artifacts._models.pagination import RegistryConnection + + data = self.client.execute(self.QUERY, variable_values=self.variables) + result = FetchRegistries.model_validate(data) + if not ((org := result.organization) and (org_entity := org.org_entity)): + raise ValueError( + f"Organization {self.organization!r} not found. Please verify the organization name is correct." 
+ ) + + try: + conn = org_entity.projects + self.last_response = RegistryConnection.model_validate(conn) + except (LookupError, AttributeError, ValidationError) as e: + raise ValueError("Unexpected response data") from e + + def _convert(self, node: RegistryFragment) -> Registry: + from wandb.apis.public.registries.registry import Registry + from wandb.sdk.artifacts._validators import remove_registry_prefix + + return Registry( + client=self.client, + organization=self.organization, + entity=node.entity.name, + name=remove_registry_prefix(node.name), + attrs=node, + ) + + +class Collections( + SizedRelayPaginator["RegistryCollectionFragment", "ArtifactCollection"] +): + """An lazy iterator of `ArtifactCollection` objects in a Registry.""" + + QUERY: ClassVar[Document | None] = None + last_response: RegistryCollectionConnection | None + + def __init__( + self, + client: RetryingClient, + organization: str, + registry_filter: dict[str, Any] | None = None, + collection_filter: dict[str, Any] | None = None, + per_page: PositiveInt = 100, + ): + if self.QUERY is None: + from wandb.sdk.artifacts._generated import REGISTRY_COLLECTIONS_GQL + + type(self).QUERY = gql(REGISTRY_COLLECTIONS_GQL) + + self.client = client + self.organization = organization + self.registry_filter = registry_filter + self.collection_filter = collection_filter or {} + + variables = { + "registryFilter": json.dumps(f) if (f := registry_filter) else None, + "collectionFilter": json.dumps(f) if (f := collection_filter) else None, + "organization": organization, + "perPage": per_page, + } + super().__init__(client, variables=variables, per_page=per_page) + + def __next__(self): + # Implement custom next since its possible to load empty pages because of auth + self.index += 1 + while len(self.objects) <= self.index: + if not self._load_page(): + raise StopIteration + return self.objects[self.index] + + @tracked + def versions( + self, filter: dict[str, Any] | None = None, per_page: PositiveInt = 100 + ) 
-> Versions: + return Versions( + client=self.client, + organization=self.organization, + registry_filter=self.registry_filter, + collection_filter=self.collection_filter, + artifact_filter=filter, + per_page=per_page, + ) + + @override + def _update_response(self) -> None: + from wandb.sdk.artifacts._generated import RegistryCollections + from wandb.sdk.artifacts._models.pagination import RegistryCollectionConnection + + data = self.client.execute(self.QUERY, variable_values=self.variables) + result = RegistryCollections.model_validate(data) + if not ((org := result.organization) and (org_entity := org.org_entity)): + raise ValueError( + f"Organization {self.organization!r} not found. Please verify the organization name is correct." + ) + + try: + conn = org_entity.artifact_collections + self.last_response = RegistryCollectionConnection.model_validate(conn) + except (LookupError, AttributeError, ValidationError) as e: + raise ValueError("Unexpected response data") from e + + def _convert(self, node: RegistryCollectionFragment) -> ArtifactCollection | None: + from wandb._pydantic import gql_typename + from wandb.apis.public import ArtifactCollection + from wandb.sdk.artifacts._generated import ArtifactSequenceTypeFields + + if not ( + # We don't _expect_ any registry collections to be + # ArtifactSequences, but defensively filter them out anyway. 
+ node.project + and (node.typename__ != gql_typename(ArtifactSequenceTypeFields)) + ): + return None + return ArtifactCollection( + client=self.client, + entity=node.project.entity.name, + project=node.project.name, + name=node.name, + type=node.type.name, + organization=self.organization, + attrs=node, + ) + + +class Versions(RelayPaginator["ArtifactMembershipFragment", "Artifact"]): + """An lazy iterator of `Artifact` objects in a Registry.""" + + QUERY: Document # Must be set per-instance + last_response: ArtifactMembershipConnection | None + + def __init__( + self, + client: RetryingClient, + organization: str, + registry_filter: dict[str, Any] | None = None, + collection_filter: dict[str, Any] | None = None, + artifact_filter: dict[str, Any] | None = None, + per_page: PositiveInt = 100, + ): + from wandb.sdk.artifacts._generated import REGISTRY_VERSIONS_GQL + + self.QUERY = gql(REGISTRY_VERSIONS_GQL) + + self.client = client + self.organization = organization + self.registry_filter = registry_filter + self.collection_filter = collection_filter + self.artifact_filter = artifact_filter or {} + + variables = { + "registryFilter": json.dumps(f) if (f := registry_filter) else None, + "collectionFilter": json.dumps(f) if (f := collection_filter) else None, + "artifactFilter": json.dumps(f) if (f := artifact_filter) else None, + "organization": organization, + } + super().__init__(client, variables=variables, per_page=per_page) + + @override + def __next__(self): + # Implement custom next since its possible to load empty pages because of auth + self.index += 1 + while len(self.objects) <= self.index: + if not self._load_page(): + raise StopIteration + return self.objects[self.index] + + @property + def length(self) -> int | None: + if self.last_response is None: + return None + return len(self.last_response.edges) + + @override + def _update_response(self) -> None: + from wandb.sdk.artifacts._generated import RegistryVersions + from 
wandb.sdk.artifacts._models.pagination import ArtifactMembershipConnection + + data = self.client.execute(self.QUERY, variable_values=self.variables) + result = RegistryVersions.model_validate(data) + if not ((org := result.organization) and (org_entity := org.org_entity)): + raise ValueError( + f"Organization {self.organization!r} not found. Please verify the organization name is correct." + ) + + try: + conn = org_entity.artifact_memberships + self.last_response = ArtifactMembershipConnection.model_validate(conn) + except (LookupError, AttributeError, ValidationError) as e: + raise ValueError("Unexpected response data") from e + + def _convert(self, node: ArtifactMembershipFragment) -> Artifact | None: + from wandb.sdk.artifacts._validators import FullArtifactPath + from wandb.sdk.artifacts.artifact import Artifact + + if not ( + (collection := node.artifact_collection) + and (project := collection.project) + and node.artifact + and (version_idx := node.version_index) is not None + ): + return None + return Artifact._from_membership( + membership=node, + target=FullArtifactPath( + prefix=project.entity.name, + project=project.name, + name=f"{collection.name}:v{version_idx}", + ), + client=self.client, + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/registry.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..ce9ec4da7930184e4a80d5ca5ce6816ff9ce7958 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/registries/registry.py @@ -0,0 +1,661 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal + +from pydantic import PositiveInt +from typing_extensions import Self, assert_never +from wandb_gql import gql + +import wandb +from wandb._analytics import tracked +from wandb._strutils import nameof +from wandb.apis.public.teams import Team +from wandb.apis.public.users import User +from 
wandb.proto import wandb_internal_pb2 as pb +from wandb.sdk.artifacts._models import RegistryData + +from ._freezable_list import AddOnlyArtifactTypesList +from ._members import ( + MemberId, + MemberKind, + MemberRole, + TeamMember, + UserMember, + parse_member_ids, +) +from ._utils import ( + Visibility, + fetch_org_entity_from_organization, + prepare_artifact_types_input, +) +from .registries_search import Collections, Versions + +if TYPE_CHECKING: + from wandb.apis.public.api import RetryingClient + from wandb.sdk.artifacts._generated import RegistryFragment + + +class Registry: + """A single registry in the Registry.""" + + _saved: RegistryData + """The saved registry data as last fetched from the W&B server.""" + + _current: RegistryData + """The local, editable registry data.""" + + def __init__( + self, + client: RetryingClient, + organization: str, + entity: str, + name: str, + attrs: RegistryFragment | None = None, + ): + self.client = client + + if attrs is None: + # FIXME: This is awkward and bypasses validation which seems shaky. + # Reconsider the init signature of `Registry` so this isn't necessary? 
+ draft = RegistryData.model_construct( + organization=organization, entity=entity, name=name + ) + self._saved = draft + self._current = draft.model_copy(deep=True) + else: + self._update_attributes(attrs) + + def _update_attributes(self, fragment: RegistryFragment) -> None: + """Update instance attributes from a GraphQL fragment.""" + saved = RegistryData.from_fragment(fragment) + self._saved = saved + self._current = saved.model_copy(deep=True) + + @property + def id(self) -> str: + """The unique ID for this registry.""" + return self._current.id + + @property + def full_name(self) -> str: + """Full name of the registry including the `wandb-registry-` prefix.""" + return self._current.full_name + + @property + def name(self) -> str: + """Name of the registry without the `wandb-registry-` prefix.""" + return self._current.name + + @name.setter + def name(self, value: str): + self._current.name = value + + @property + def entity(self) -> str: + """Organization entity of the registry.""" + return self._current.entity + + @property + def organization(self) -> str: + """Organization name of the registry.""" + return self._current.organization + + @property + def description(self) -> str | None: + """Description of the registry.""" + return self._current.description + + @description.setter + def description(self, value: str) -> None: + """Set the description of the registry.""" + self._current.description = value + + @property + def allow_all_artifact_types(self) -> bool: + """Return whether all artifact types are allowed in the registry. + + If `True`, artifacts of any type can be added. If `False`, artifacts are + restricted to the types listed in `artifact_types`. 
+ """ + return self._current.allow_all_artifact_types + + @allow_all_artifact_types.setter + def allow_all_artifact_types(self, value: bool) -> None: + """Set whether all artifact types are allowed in the registry.""" + self._current.allow_all_artifact_types = value + + @property + def artifact_types(self) -> AddOnlyArtifactTypesList: + """Returns the artifact types allowed in the registry. + + If `allow_all_artifact_types` is `True` then `artifact_types` reflects the + types previously saved or currently used in the registry. + If `allow_all_artifact_types` is `False` then artifacts are restricted to the + types in `artifact_types`. + + Note: + Previously saved artifact types cannot be removed. + + Example: + ```python + import wandb + + registry = wandb.Api().create_registry() + registry.artifact_types.append("model") + registry.save() # once saved, the artifact type `model` cannot be removed + registry.artifact_types.append("accidentally_added") + registry.artifact_types.remove( + "accidentally_added" + ) # Types can only be removed if it has not been saved yet + ``` + """ + return self._current.artifact_types + + @property + def created_at(self) -> str: + """Timestamp of when the registry was created.""" + return self._current.created_at + + @property + def updated_at(self) -> str: + """Timestamp of when the registry was last updated.""" + return self._current.updated_at + + @property + def path(self) -> list[str]: + return [self.entity, self.full_name] + + @property + def visibility(self) -> Literal["organization", "restricted"]: + """Visibility of the registry. + + Returns: + Literal["organization", "restricted"]: The visibility level. + - "organization": Anyone in the organization can view this registry. + You can edit their roles later from the settings in the UI. + - "restricted": Only invited members via the UI can access this registry. + Public sharing is disabled. 
+ """ + return self._current.visibility.name + + @visibility.setter + def visibility(self, value: Literal["organization", "restricted"]): + """Set the visibility of the registry. + + Args: + value: The visibility level. Options are: + - "organization": Anyone in the organization can view this registry. + You can edit their roles later from the settings in the UI. + - "restricted": Only invited members via the UI can access this registry. + Public sharing is disabled. + """ + self._current.visibility = value + + @tracked + def collections( + self, filter: dict[str, Any] | None = None, per_page: PositiveInt = 100 + ) -> Collections: + """Returns the collections belonging to the registry.""" + return Collections( + client=self.client, + organization=self.organization, + registry_filter={"name": self.full_name}, + collection_filter=filter, + per_page=per_page, + ) + + @tracked + def versions( + self, filter: dict[str, Any] | None = None, per_page: PositiveInt = 100 + ) -> Versions: + """Returns the versions belonging to the registry.""" + return Versions( + client=self.client, + organization=self.organization, + registry_filter={"name": self.full_name}, + collection_filter=None, + artifact_filter=filter, + per_page=per_page, + ) + + @classmethod + @tracked + def create( + cls, + client: RetryingClient, + organization: str, + name: str, + visibility: Literal["organization", "restricted"], + description: str | None = None, + artifact_types: list[str] | None = None, + ) -> Self: + """Create a new registry. + + The registry name must be unique within the organization. + This function should be called using `api.create_registry()` + + Args: + client: The GraphQL client. + organization: The name of the organization. + name: The name of the registry (without the `wandb-registry-` prefix). + visibility: The visibility level ('organization' or 'restricted'). + description: An optional description for the registry. + artifact_types: An optional list of allowed artifact types. 
+ + Returns: + Registry: The newly created Registry object. + + Raises: + ValueError: If a registry with the same name already exists in the + organization or if the creation fails. + """ + from wandb.sdk.artifacts._generated import ( + UPSERT_REGISTRY_GQL, + UpsertModelInput, + UpsertRegistry, + ) + from wandb.sdk.artifacts._validators import ( + REGISTRY_PREFIX, + validate_project_name, + ) + + failed_msg = ( + f"Failed to create registry {name!r} in organization {organization!r}." + ) + + org_entity = fetch_org_entity_from_organization(client, organization) + + gql_op = gql(UPSERT_REGISTRY_GQL) + gql_input = UpsertModelInput( + description=description, + entity_name=org_entity, + name=validate_project_name(f"{REGISTRY_PREFIX}{name}"), + access=Visibility.from_python(visibility).value, + allow_all_artifact_types_in_registry=not artifact_types, + artifact_types=prepare_artifact_types_input(artifact_types), + ) + gql_vars = {"input": gql_input.model_dump()} + try: + data = client.execute(gql_op, variable_values=gql_vars) + result = UpsertRegistry.model_validate(data).upsert_model + except Exception as e: + raise ValueError(failed_msg) from e + if not (result and result.inserted and (registry_project := result.project)): + raise ValueError(failed_msg) + + return cls( + client, + organization=organization, + entity=org_entity, + name=name, + attrs=registry_project, + ) + + @tracked + def delete(self) -> None: + """Delete the registry. 
This is irreversible.""" + from wandb.sdk.artifacts._generated import DELETE_REGISTRY_GQL, DeleteRegistry + + failed_msg = f"Failed to delete registry {self.name!r} in organization {self.organization!r}" + + gql_op = gql(DELETE_REGISTRY_GQL) + gql_vars = {"id": self.id} + try: + data = self.client.execute(gql_op, variable_values=gql_vars) + result = DeleteRegistry.model_validate(data).delete_model + except Exception as e: + raise ValueError(failed_msg) from e + if not (result and result.success): + raise ValueError(failed_msg) + + @tracked + def load(self) -> None: + """Load registry attributes from the backend.""" + from wandb.sdk.artifacts._generated import FETCH_REGISTRY_GQL, FetchRegistry + + failed_msg = ( + f"Failed to load registry {self.name!r} in organization" + f" {self.organization!r}." + ) + + gql_op = gql(FETCH_REGISTRY_GQL) + gql_vars = {"name": self.full_name, "entity": self.entity} + try: + data = self.client.execute(gql_op, variable_values=gql_vars) + result = FetchRegistry.model_validate(data) + except Exception as e: + raise ValueError(failed_msg) from e + + if not ((entity := result.entity) and (registry_project := entity.project)): + raise ValueError(failed_msg) + + self._update_attributes(registry_project) + + @tracked + def save(self) -> None: + """Save registry attributes to the backend.""" + from wandb.sdk.artifacts._generated import ( + RENAME_REGISTRY_GQL, + UPSERT_REGISTRY_GQL, + RenameProjectInput, + RenameRegistry, + UpsertModelInput, + UpsertRegistry, + ) + from wandb.sdk.artifacts._gqlutils import server_supports + from wandb.sdk.artifacts._validators import validate_project_name + + if not server_supports( + self.client, pb.INCLUDE_ARTIFACT_TYPES_IN_REGISTRY_CREATION + ): + raise RuntimeError( + "Saving the registry is not enabled on this wandb server version. " + "Please upgrade your server version or contact support at support@wandb.com." 
+ ) + + # If `artifact_types.draft` has items, the user added types that are not + # yet saved. + if ( + new_artifact_types := self.artifact_types.draft + ) and self.allow_all_artifact_types: + raise ValueError( + f"Cannot update artifact types when `allows_all_artifact_types` is {True!r}. Set it to {False!r} first." + ) + + failed_msg = f"Failed to save registry {self.name!r} in organization {self.organization!r}" + + old_project_name = validate_project_name(self._saved.full_name) + new_project_name = validate_project_name(self._current.full_name) + + upsert_op = gql(UPSERT_REGISTRY_GQL) + upsert_input = UpsertModelInput( + description=self.description, + entity_name=self.entity, + name=old_project_name, + access=self._current.visibility.value, + allow_all_artifact_types_in_registry=self.allow_all_artifact_types, + artifact_types=prepare_artifact_types_input(new_artifact_types), + ) + upsert_vars = {"input": upsert_input.model_dump()} + try: + data = self.client.execute(upsert_op, variable_values=upsert_vars) + result = UpsertRegistry.model_validate(data).upsert_model + except Exception as e: + raise ValueError(failed_msg) from e + + if result and result.inserted: + # This should only trigger if `_saved_name` was modified unexpectedly. 
+ wandb.termlog( + f"Created registry {self.name!r} in organization {self.organization!r} on save" + ) + + if not (result and (registry_project := result.project)): + raise ValueError(failed_msg) + + self._update_attributes(registry_project) + + # Update the name of the registry if it has changed + if old_project_name != new_project_name: + rename_op = gql(RENAME_REGISTRY_GQL) + rename_input = RenameProjectInput( + entity_name=self.entity, + old_project_name=old_project_name, + new_project_name=new_project_name, + ) + rename_vars = {"input": rename_input.model_dump()} + data = self.client.execute(rename_op, variable_values=rename_vars) + result = RenameRegistry.model_validate(data).rename_project + if not (result and (registry_project := result.project)): + raise ValueError(failed_msg) + + if result.inserted: + # This should only trigger if `_saved_name` was modified unexpectedly. + wandb.termlog(f"Created new registry {self.name!r} on save") + + self._update_attributes(registry_project) + + def members(self) -> list[UserMember | TeamMember]: + """Returns the current members (users and teams) of this registry.""" + return [*self.user_members(), *self.team_members()] + + def user_members(self) -> list[UserMember]: + """Returns the current member users of this registry.""" + from wandb.sdk.artifacts._generated import ( + REGISTRY_USER_MEMBERS_GQL, + RegistryUserMembers, + ) + + gql_op = gql(REGISTRY_USER_MEMBERS_GQL) + gql_vars = {"project": self.full_name, "entity": self.entity} + data = self.client.execute(gql_op, variable_values=gql_vars) + result = RegistryUserMembers.model_validate(data) + + if not (project := result.project): + raise ValueError(f"Failed to fetch user members for registry {self.name!r}") + + return [ + UserMember( + user=User( + client=self.client, + # The `User` class requires an unstructured attribute dict. + # Exclude `.role`, which is specific to this registry membership. 
+ attrs=m.model_dump(exclude_none=True, exclude={"role"}), + ), + role=m.role.name, + ) + for m in project.members + ] + + def team_members(self) -> list[TeamMember]: + """Returns the current member teams of this registry.""" + from wandb.sdk.artifacts._generated import ( + REGISTRY_TEAM_MEMBERS_GQL, + RegistryTeamMembers, + ) + + gql_op = gql(REGISTRY_TEAM_MEMBERS_GQL) + gql_vars = {"project": self.full_name, "entity": self.entity} + data = self.client.execute(gql_op, variable_values=gql_vars) + result = RegistryTeamMembers.model_validate(data) + + if not (project := result.project): + raise ValueError(f"Failed to fetch team members for registry {self.name!r}") + + return [ + TeamMember( + team=Team( + client=self.client, + name=m.team.name, + # The `Team` class currently requires an unstructured attribute dict. + attrs=m.team.model_dump(exclude_none=True), + ), + role=m.role.name, + ) + for m in project.team_members + ] + + def add_members( + self, *members: User | UserMember | Team | TeamMember | str + ) -> Self: + """Adds users or teams to this registry. + + Args: + members: The users or teams to add to the registry. Accepts + `User` objects, `Team` objects, or their string IDs. + + Returns: + This registry for further method chaining, if needed. + + Raises: + TypeError: If no members are passed as arguments. + ValueError: If unable to infer or parse the user or team IDs. 
def remove_members(
    self, *members: User | UserMember | Team | TeamMember | str
) -> Self:
    """Removes users or teams from this registry.

    Args:
        members: The users or teams to remove from the registry. Accepts
            `User` objects, `Team` objects, or their string IDs.

    Returns:
        This registry for further method chaining, if needed.

    Raises:
        TypeError: If no members are passed as arguments.
        ValueError: If unable to infer or parse the user or team IDs, or if
            the server does not report the removal as successful.

    Examples:
    ```python
    import wandb

    api = wandb.Api()

    # Fetch an existing registry
    registry = api.registry(name="my-registry", organization="my-org")

    user1 = api.user(username="some-user")
    user2 = api.user(username="other-user")
    registry.remove_members(user1, user2)

    old_team = api.team(name="old-team")
    registry.remove_members(old_team)
    ```
    """
    from wandb.sdk.artifacts._generated import (
        DELETE_REGISTRY_MEMBERS_GQL,
        DeleteProjectMembersInput,
        DeleteRegistryMembers,
    )

    if not members:
        # Name *this* method in the message (it previously named `add_members`).
        raise TypeError(
            f"Must provide at least one member to {nameof(self.remove_members)!r}."
        )
    user_ids, team_ids = parse_member_ids(members)

    gql_op = gql(DELETE_REGISTRY_MEMBERS_GQL)
    gql_input = DeleteProjectMembersInput(
        user_ids=user_ids, team_ids=team_ids, project_id=self.id
    )
    gql_vars = {"input": gql_input.model_dump()}
    data = self.client.execute(gql_op, variable_values=gql_vars)
    result = DeleteRegistryMembers.model_validate(data).result

    if not (result and result.success):
        raise ValueError(f"Failed to remove members from registry {self.name!r}")
    return self
+ + Examples: + Make all users in the registry admins: + ```python + import wandb + + api = wandb.Api() + + # Fetch an existing registry + registry = api.registry(name="my-registry", organization="my-org") + + for member in registry.user_members(): + registry.update_member(member.user, role="admin") + ``` + """ + from wandb.sdk.artifacts._generated import ( + UPDATE_TEAM_REGISTRY_ROLE_GQL, + UPDATE_USER_REGISTRY_ROLE_GQL, + UpdateProjectMemberInput, + UpdateProjectTeamMemberInput, + UpdateTeamRegistryRole, + UpdateUserRegistryRole, + ) + + id_ = MemberId.from_obj(member) + + if id_.kind is MemberKind.USER: + gql_op = gql(UPDATE_USER_REGISTRY_ROLE_GQL) + gql_input = UpdateProjectMemberInput( + user_id=id_.encode(), project_id=self.id, user_project_role=role + ) + result_cls = UpdateUserRegistryRole + elif id_.kind is MemberKind.ENTITY: + gql_op = gql(UPDATE_TEAM_REGISTRY_ROLE_GQL) + gql_input = UpdateProjectTeamMemberInput( + team_id=id_.encode(), project_id=self.id, team_project_role=role + ) + result_cls = UpdateTeamRegistryRole + else: + assert_never(id_.kind) + + gql_vars = {"input": gql_input.model_dump()} + data = self.client.execute(gql_op, variable_values=gql_vars) + result = result_cls.model_validate(data).result + + if not (result and result.success): + raise ValueError( + f"Failed to update member {member!r} role to {role!r} in registry {self.name!r}" + ) + return self diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/reports.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/reports.py new file mode 100644 index 0000000000000000000000000000000000000000..c3608c9cfaee595350c10d83295bd01ce3f15342 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/reports.py @@ -0,0 +1,618 @@ +"""W&B Public API for Report objects. + +This module provides classes for interacting with W&B reports and +managing report-related data. 
class Reports(SizedPaginator["BetaReport"]):
    """Reports is a lazy iterator of `BetaReport` objects.

    Args:
        client (`wandb.apis.internal.Api`): The API client instance to use.
        project (`wandb.sdk.internal.Project`): The project to fetch reports from.
        name (str, optional): The name of the report to filter by. If `None`,
            fetches all reports.
        entity (str, optional): Accepted for backward compatibility but not
            read; the entity is always taken from `project`.
        per_page (int): Number of reports to fetch per page (default is 50).
    """

    QUERY = gql(
        """
        query ProjectViews($project: String!, $entity: String!, $reportCursor: String,
            $reportLimit: Int!, $viewType: String = "runs", $viewName: String) {
            project(name: $project, entityName: $entity) {
                allViews(viewType: $viewType, viewName: $viewName, first:
                    $reportLimit, after: $reportCursor) {
                    edges {
                        node {
                            id
                            name
                            displayName
                            description
                            user {
                                username
                                photoUrl
                                email
                            }
                            spec
                            updatedAt
                            createdAt
                        }
                        cursor
                    }
                    pageInfo {
                        endCursor
                        hasNextPage
                    }

                }
            }
        }
        """
    )

    def __init__(
        self,
        client: RetryingClient,
        project: Project,
        name: str | None = None,
        entity: str | None = None,
        per_page: int = 50,
    ):
        self.project = project
        self.name = name
        variables = {
            "project": project.name,
            "entity": project.entity,
            "viewName": self.name,
        }
        super().__init__(client, variables, per_page)

    @property
    def _length(self) -> int | None:
        """The number of reports loaded so far, or `None` before the first fetch.

        The backend does not expose a total count yet, so this reflects only
        the objects already fetched.
        """
        # TODO: Add the count the backend
        if self.last_response:
            return len(self.objects)
        return None

    @property
    def more(self) -> bool:
        """Returns whether there are more reports to fetch."""
        if self.last_response:
            return bool(
                self.last_response["project"]["allViews"]["pageInfo"]["hasNextPage"]
            )
        return True

    @property
    def cursor(self) -> str | None:
        """Returns the cursor position for pagination of report results.

        Returns `None` before the first fetch. If the last page contained no
        edges, falls back to the page's `endCursor` instead of raising
        `IndexError`.
        """
        if not self.last_response:
            return None
        views = self.last_response["project"]["allViews"]
        edges = views["edges"]
        if edges:
            return edges[-1]["cursor"]
        return views["pageInfo"].get("endCursor")

    def update_variables(self) -> None:
        """Updates the GraphQL query variables for pagination."""
        self.variables.update(
            {"reportCursor": self.cursor, "reportLimit": self.per_page}
        )

    def convert_objects(self) -> list[BetaReport]:
        """Converts GraphQL edges to `BetaReport` objects.

        Raises:
            ValueError: If the response contains no project, i.e. the project
                does not exist under the queried entity.
        """
        if self.last_response["project"] is None:
            raise ValueError(
                f"Project {self.variables['project']} does not exist under entity {self.variables['entity']}"
            )
        return [
            BetaReport(
                self.client,
                r["node"],
                entity=self.project.entity,
                project=self.project.name,
            )
            for r in self.last_response["project"]["allViews"]["edges"]
        ]

    def __repr__(self) -> str:
        return f"<{nameof(type(self))} {'/'.join(self.project.path)}>"
+ url (string): The URL of the report. + updated_at (string): Timestamp of last update. + created_at (string): Timestamp when the report was created. + """ + + def __init__( + self, + client: RetryingClient, + attrs: dict, + entity: str | None = None, + project: str | None = None, + ): + self.client = client + self.project = project + self.entity = entity + self.query_generator = public.QueryGenerator() + super().__init__(dict(attrs)) + + if "spec" in self._attrs: + if isinstance(self._attrs["spec"], str): + self._attrs["spec"] = json.loads(self._attrs["spec"]) + else: + self._attrs["spec"] = {} + + @property + def spec(self) -> dict[str, Any]: + return self._attrs["spec"] + + @property + def sections(self): + """Get the panel sections (groups) from the report.""" + return self.spec["panelGroups"] + + def runs( + self, + section: dict[str, Any], + per_page: int = 50, + only_selected: bool = True, + ) -> public.Runs: + """Get runs associated with a section of the report.""" + run_set_idx = section.get("openRunSet", 0) + run_set = section["runSets"][run_set_idx] + order = self.query_generator.key_to_server_path(run_set["sort"]["key"]) + if run_set["sort"].get("ascending"): + order = "+" + order + else: + order = "-" + order + filters = self.query_generator.filter_to_mongo(run_set["filters"]) + if only_selected: + # TODO: handle this not always existing + filters["$or"][0]["$and"].append( + {"name": {"$in": run_set["selections"]["tree"]}} + ) + return public.Runs( + self.client, + self.entity, + self.project, + filters=filters, + order=order, + per_page=per_page, + ) + + @property + def id(self) -> str: + return self._attrs.get("id") + + @property + def name(self) -> str | None: + return self._attrs.get("name") + + @property + def display_name(self) -> str | None: + return self._attrs.get("displayName") + + @property + def description(self) -> str | None: + return self._attrs.get("description") + + @property + def user(self): + return self._attrs.get("user") + + 
@property + def updated_at(self): + return self._attrs.get("updatedAt") + + @property + def created_at(self): + return self._attrs.get("createdAt") + + @property + def url(self) -> str | None: + if ( + not self.client + or not self.entity + or not self.project + or not self.display_name + or not self.id + ): + return None + return self.client.app_url + "/".join( + [ + self.entity, + self.project, + "reports", + "--".join( + [ + # made this more closely match the url creation in the frontend (https://github.com/wandb/core/blob/76943979c8e967f7a62dae8bef0a001a2672584c/frontends/app/src/util/report/urls.ts#L19) + urllib.parse.quote( + re.sub( + r"-+", "-", re.sub(r"\W", "-", self.display_name) + ).strip("-") + ), + self.id.replace("=", ""), + ] + ), + ] + ) + + def to_html(self, height: int = 1024, hidden: bool = False) -> str: + """Generate HTML containing an iframe displaying this report.""" + url = self.url + if url is None: + return "
Report URL not available
class PythonMongoishQueryGenerator:
    """Converts Python-style query expressions to MongoDB-style queries for W&B reports."""

    # Placeholder substituted for "." in dotted field names so that the
    # expression parses as a plain Python identifier.
    SPACER = "----------"
    # Temporary placeholder protecting decimal points in numeric literals
    # while dotted field names are being rewritten.
    DECIMAL_SPACER = ";;;"
    FRONTEND_NAME_MAPPING = {
        "ID": "name",
        "Name": "displayName",
        "Tags": "tags",
        "State": "state",
        "CreatedTimestamp": "createdAt",
        "Runtime": "duration",
        "User": "username",
        "Sweep": "sweep",
        "Group": "group",
        "JobType": "jobType",
        "Hostname": "host",
        "UsingArtifact": "inputArtifacts",
        "OutputtingArtifact": "outputArtifacts",
        "Step": "_step",
        "Relative Time (Wall)": "_absolute_runtime",
        "Relative Time (Process)": "_runtime",
        "Wall Time": "_timestamp",
        # "GroupedRuns": "__wb_group_by_all"
    }
    FRONTEND_NAME_MAPPING_REVERSED = {v: k for k, v in FRONTEND_NAME_MAPPING.items()}
    AST_OPERATORS = {
        ast.Lt: "$lt",
        ast.LtE: "$lte",
        ast.Gt: "$gt",
        ast.GtE: "$gte",
        ast.Eq: "=",
        ast.Is: "=",
        ast.NotEq: "$ne",
        ast.IsNot: "$ne",
        ast.In: "$in",
        ast.NotIn: "$nin",
        ast.And: "$and",
        ast.Or: "$or",
        ast.Not: "$not",
    }

    # Name of the attribute holding the payload for each supported AST node.
    AST_FIELDS = {
        ast.Constant: "value",
        ast.Name: "id",
        ast.List: "elts",
        ast.Tuple: "elts",
    }

    def __init__(self, run_set):
        self.run_set = run_set
        self.panel_metrics_helper = PanelMetricsHelper()

    def _handle_compare(self, node):
        """Convert a single comparison node into a `{field: condition}` dict."""
        # only left side can be a col
        left = self.front_to_back(self._handle_fields(node.left))
        op = self._handle_ops(node.ops[0])
        right = self._handle_fields(node.comparators[0])

        # Eq has no op for some reason
        if op == "=":
            return {left: right}
        else:
            return {left: {op: right}}

    def _handle_fields(self, node):
        """Extract the Python value (name, constant, or list of them) from a node."""
        result = getattr(node, self.AST_FIELDS.get(type(node)))
        if isinstance(result, list):
            return [self._handle_fields(elt) for elt in result]
        elif isinstance(result, str):
            return self._unconvert(result)
        return result

    def _handle_ops(self, node):
        """Return the query operator for an AST operator node, or `None` if unknown."""
        return self.AST_OPERATORS.get(type(node))

    def _replace_numeric_dots(self, s):
        """Replace decimal points in numeric literals with `DECIMAL_SPACER`.

        A "." counts as numeric when it sits between two digits (`1.2`),
        between a digit and a space (`1. `), between a space and a digit
        (` .2`), or directly after a digit at the very end of the string.
        Strings shorter than two characters are returned unchanged.
        """
        numeric_dots = []
        for i, (left, mid, right) in enumerate(zip(s, s[1:], s[2:]), 1):
            if mid == ".":
                if (
                    left.isdigit()
                    and right.isdigit()  # 1.2
                    or left.isdigit()
                    and right == " "  # 1.
                    or left == " "
                    and right.isdigit()  # .2
                ):
                    numeric_dots.append(i)
        # Edge: Catch number ending in dot at end of string.
        # The length guard avoids an IndexError on empty/one-character input.
        if len(s) >= 2 and s[-2].isdigit() and s[-1] == ".":
            numeric_dots.append(len(s) - 1)
        numeric_dots = [-1] + numeric_dots + [len(s)]

        # Re-join the pieces between the detected dots with the placeholder.
        pieces = [s[start + 1 : stop] for start, stop in zip(numeric_dots, numeric_dots[1:])]
        return self.DECIMAL_SPACER.join(pieces)

    def _convert(self, filterstr):
        """Rewrite a filter string so dotted field names parse as identifiers."""
        _conversion = (
            self._replace_numeric_dots(filterstr)  # temporarily sub numeric dots
            .replace(".", self.SPACER)  # Allow dotted fields
            .replace(self.DECIMAL_SPACER, ".")  # add them back
        )
        return "(" + _conversion + ")"

    def _unconvert(self, field_name):
        """Restore the dots that `_convert` replaced with `SPACER`."""
        return field_name.replace(self.SPACER, ".")  # Allow dotted fields

    def python_to_mongo(self, filterstr):
        """Convert Python expression to MongoDB filter.

        Raises:
            ValueError: If `filterstr` is not syntactically valid Python.
        """
        try:
            tree = ast.parse(self._convert(filterstr), mode="eval")
        except SyntaxError as e:
            raise ValueError(
                "Invalid python comparison expression; form something like `my_col == 123`"
            ) from e

        # Boolean operations (`a and b`) carry an `op`; a lone comparison does not.
        multiple_filters = hasattr(tree.body, "op")

        if multiple_filters:
            op = self.AST_OPERATORS.get(type(tree.body.op))
            values = [self._handle_compare(v) for v in tree.body.values]
        else:
            op = "$and"
            values = [self._handle_compare(tree.body)]
        return {"$or": [{op: values}]}

    def front_to_back(self, name):
        """Convert frontend metric names to backend field names."""
        name, *rest = name.split(".")
        rest = "." + ".".join(rest) if rest else ""

        if name in self.FRONTEND_NAME_MAPPING:
            return self.FRONTEND_NAME_MAPPING[name]
        elif name in self.FRONTEND_NAME_MAPPING_REVERSED:
            return name
        elif name in self.run_set._runs_config:
            return f"config.{name}.value{rest}"
        else:  # assume summary metrics
            return f"summary_metrics.{name}{rest}"

    def back_to_front(self, name):
        """Convert backend field names to frontend metric names."""
        if name in self.FRONTEND_NAME_MAPPING_REVERSED:
            return self.FRONTEND_NAME_MAPPING_REVERSED[name]
        elif name in self.FRONTEND_NAME_MAPPING:
            return name
        elif (
            name.startswith("config.") and ".value" in name
        ):  # may be brittle: originally "endswith", but that doesn't work with nested keys...
            # strip is weird sometimes (??)
            return name.replace("config.", "").replace(".value", "")
        elif name.startswith("summary_metrics."):
            return name.replace("summary_metrics.", "")
        wandb.termerror(f"Unknown token: {name}")
        return name

    # These are only used for ParallelCoordinatesPlot because it has weird backend names...
    def pc_front_to_back(self, name):
        """Convert ParallelCoordinatesPlot names to backend field names.

        Returns `None` when `name` is `None`.
        """
        # Check for None before splitting; otherwise `None.split` raises.
        if name is None:
            return None
        name, *rest = name.split(".")
        rest = "." + ".".join(rest) if rest else ""
        if name in self.panel_metrics_helper.FRONTEND_NAME_MAPPING:
            return "summary:" + self.panel_metrics_helper.FRONTEND_NAME_MAPPING[name]
        elif name in self.FRONTEND_NAME_MAPPING:
            return self.FRONTEND_NAME_MAPPING[name]
        elif name in self.FRONTEND_NAME_MAPPING_REVERSED:
            return name
        elif name in self.run_set._runs_config:
            return f"config:{name}.value{rest}"
        else:  # assume summary metrics
            return f"summary:{name}{rest}"

    def pc_back_to_front(self, name):
        """Convert backend field names to ParallelCoordinatesPlot names."""
        if name is None:
            return None
        elif "summary:" in name:
            name = name.replace("summary:", "")
            return self.panel_metrics_helper.FRONTEND_NAME_MAPPING_REVERSED.get(
                name, name
            )
        elif name in self.FRONTEND_NAME_MAPPING_REVERSED:
            return self.FRONTEND_NAME_MAPPING_REVERSED[name]
        elif name in self.FRONTEND_NAME_MAPPING:
            return name
        elif name.startswith("config:") and ".value" in name:
            return name.replace("config:", "").replace(".value", "")
        elif name.startswith("summary_metrics."):
            return name.replace("summary_metrics.", "")
        return name


class PanelMetricsHelper:
    """Translates panel metric names between frontend labels and backend keys."""

    FRONTEND_NAME_MAPPING = {
        "Step": "_step",
        "Relative Time (Wall)": "_absolute_runtime",
        "Relative Time (Process)": "_runtime",
        "Wall Time": "_timestamp",
    }
    FRONTEND_NAME_MAPPING_REVERSED = {v: k for k, v in FRONTEND_NAME_MAPPING.items()}

    RUN_MAPPING = {"Created Timestamp": "createdAt", "Latest Timestamp": "heartbeatAt"}
    RUN_MAPPING_REVERSED = {v: k for k, v in RUN_MAPPING.items()}

    def front_to_back(self, name):
        """Convert frontend metric names to backend field names."""
        if name in self.FRONTEND_NAME_MAPPING:
            return self.FRONTEND_NAME_MAPPING[name]
        return name

    def back_to_front(self, name):
        """Convert backend field names to frontend metric names."""
        if name in self.FRONTEND_NAME_MAPPING_REVERSED:
            return self.FRONTEND_NAME_MAPPING_REVERSED[name]
        return name

    # ScatterPlot and ParallelCoords have weird conventions
    def special_front_to_back(self, name):
        """Convert frontend metric names to prefixed backend field names.

        `c::` names become `config:<key>.value<rest>`, `s::` names become
        `summary:<name>`, run-level labels become `run:<field>`, and `"Index"`
        is returned unchanged. Anything else is treated as a summary metric.
        """
        if name is None:
            return name

        name, *rest = name.split(".")
        rest = "." + ".".join(rest) if rest else ""

        # special case for config
        if name.startswith("c::"):
            name = name[3:]
            return f"config:{name}.value{rest}"

        # special case for summary
        if name.startswith("s::"):
            name = name[3:] + rest
            return f"summary:{name}"

        name = name + rest
        if name in self.RUN_MAPPING:
            return "run:" + self.RUN_MAPPING[name]
        if name in self.FRONTEND_NAME_MAPPING:
            return "summary:" + self.FRONTEND_NAME_MAPPING[name]
        if name == "Index":
            return name
        return "summary:" + name

    def special_back_to_front(self, name):
        """Convert prefixed backend field names to frontend metric names.

        Names without a `kind:` prefix (such as `"Index"`) are returned
        unchanged.

        Raises:
            ValueError: If a `config:` name has no `.value` segment.
            KeyError: If a `run:` name is not a known run field.
        """
        if name is None:
            return name

        kind, sep, rest = name.partition(":")
        if not sep:
            # No prefix to strip; `special_front_to_back` emits "Index" this way.
            return name

        if kind == "config":
            pieces = rest.split(".")
            if len(pieces) <= 1:
                raise ValueError(f"Invalid name: {name}")
            # Drop the ".value" segment that follows the config key.
            return "c::" + ".".join(pieces[:1] + pieces[2:])

        if kind == "summary":
            return f"s::{rest}"

        if "summary:" in name:
            name = name.replace("summary:", "")
            return self.FRONTEND_NAME_MAPPING_REVERSED.get(name, name)
        elif "run:" in name:
            name = name.replace("run:", "")
            return self.RUN_MAPPING_REVERSED[name]
        return name
+ +Example: +```python +from wandb.apis.public import Api + +# Get runs matching filters +runs = Api().runs( + path="entity/project", filters={"state": "finished", "config.batch_size": 32} +) + +# Access run data +for run in runs: + print(f"Run: {run.name}") + print(f"Config: {run.config}") + print(f"Metrics: {run.summary}") + + # Get history with pandas + history_df = run.history(keys=["loss", "accuracy"], pandas=True) + + # Work with artifacts + for artifact in run.logged_artifacts(): + print(f"Artifact: {artifact.name}") +``` + +Note: + This module is part of the W&B Public API and provides read/write access + to run data. For logging new runs, use the wandb.init() function from + the main wandb package. +""" + +from __future__ import annotations + +import json +import os +import pathlib +import tempfile +import time +import urllib +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Collection, Iterator, Literal, Mapping + +from wandb_gql import gql + +import wandb +from wandb import env, util +from wandb._strutils import nameof +from wandb.apis import public +from wandb.apis.attrs import Attrs +from wandb.apis.internal import Api as InternalApi +from wandb.apis.normalize import normalize_exceptions +from wandb.apis.paginator import SizedPaginator +from wandb.apis.public.const import RETRY_TIMEDELTA +from wandb.proto import wandb_api_pb2 as apb +from wandb.sdk.lib import ipython, json_util, runid +from wandb.sdk.lib.paths import LogicalPath +from wandb.sdk.lib.service.service_connection import WandbApiFailedError + +if TYPE_CHECKING: + import pandas as pd + import polars as pl + from typing_extensions import Self + from wandb_graphql.language.ast import Document + + from wandb.apis.public import RetryingClient + from wandb.old.summary import HTTPSummary + +WANDB_INTERNAL_KEYS = {"_wandb", "wandb_version"} + +RUN_FRAGMENT = """fragment RunFragment on Run { + id + tags + name + displayName + sweepName + state + config + group + jobType + 
@dataclass(frozen=True)
class DownloadHistoryResult:
    """Result of downloading a run's history exports.

    Attributes:
        paths: The paths to the downloaded history files.
        contains_live_data: Whether the run contains live data that has
            not yet been exported to parquet files.
        errors: A dictionary, keyed by path, of errors that occurred while
            downloading the history files. Defaults to `None`.
    """

    paths: list[pathlib.Path]
    contains_live_data: bool
    errors: dict[pathlib.Path, str] | None = None
class Runs(SizedPaginator["Run"]):
    """A lazy iterator of `Run` objects associated with a project and optional filter.

    Runs are retrieved in pages from the W&B server as needed.

    This is generally used indirectly using the `Api.runs` namespace.

    Args:
        client: (`wandb.apis.public.RetryingClient`) The API client to use
            for requests.
        entity: (str) The entity (username or team) that owns the project.
        project: (str) The name of the project to fetch runs from.
        filters: (Optional[Dict[str, Any]]) A dictionary of filters to apply
            to the runs query.
        order: (str) Order can be `created_at`, `heartbeat_at`, `config.*.value`, or `summary_metrics.*`.
            If you prepend order with a + order is ascending (default).
            If you prepend order with a - order is descending.
            The default order is run.created_at from oldest to newest.
        per_page: (int) The number of runs to fetch per request (default is 50).
        include_sweeps: (bool) Whether to include sweep information in the
            runs. Defaults to True.
        lazy: (bool) Whether to query the lightweight run fragment instead of
            the full one. Defaults to True. See `upgrade_to_full`.
        api: (Optional[public.Api]) Passed through to every `Run` created by
            this collection.
    """

    def __init__(
        self,
        client: RetryingClient,
        entity: str,
        project: str,
        filters: dict[str, Any] | None = None,
        order: str = "+created_at",
        per_page: int = 50,
        include_sweeps: bool = True,
        lazy: bool = True,
        api: public.Api | None = None,
    ):
        if not order:
            order = "+created_at"

        self.QUERY = _create_runs_query(
            lazy=lazy,
            with_internal_id=_server_provides_internal_id_for_project(client),
            with_project_id=_server_provides_project_id_for_run(client),
        )

        self.entity = entity
        self.project = project
        self._project_internal_id = None
        self.filters = filters or {}
        self.order = order
        # Cache of sweeps already fetched, keyed by sweep name.
        self._sweeps: dict[str, public.Sweep] = {}
        self._include_sweeps = include_sweeps
        self._lazy = lazy
        self._api = api
        variables = {
            "project": self.project,
            "entity": self.entity,
            "order": self.order,
            "filters": json.dumps(self.filters),
        }
        super().__init__(client, variables, per_page)

    @property
    def _length(self) -> int:
        """Returns the total number of runs, loading the first page if needed."""
        if not self.last_response:
            self._load_page()
        return self.last_response["project"]["runCount"]

    @property
    def more(self) -> bool:
        """Returns whether there are more runs to fetch."""
        if self.last_response:
            return bool(
                self.last_response["project"]["runs"]["pageInfo"]["hasNextPage"]
            )
        else:
            return True

    @property
    def cursor(self):
        """Returns the cursor position for pagination of runs results.

        Returns `None` before the first fetch. If the last page contained no
        edges, falls back to the page's `endCursor` instead of raising
        `IndexError`.
        """
        if not self.last_response:
            return None
        runs = self.last_response["project"]["runs"]
        edges = runs["edges"]
        if edges:
            return edges[-1]["cursor"]
        return runs["pageInfo"].get("endCursor")

    def convert_objects(self) -> list[Run]:
        """Converts GraphQL edges to `Run` objects.

        When sweeps are included, each run with a sweep name gets its `sweep`
        attribute set; sweeps are fetched once per name and cached.

        Raises:
            ValueError: If the response is missing or contains no project.
        """
        objs = []
        if self.last_response is None or self.last_response.get("project") is None:
            raise ValueError(f"Could not find project {self.project}")
        for run_response in self.last_response["project"]["runs"]["edges"]:
            run = Run(
                self.client,
                self.entity,
                self.project,
                run_response["node"]["name"],
                run_response["node"],
                include_sweeps=self._include_sweeps,
                lazy=self._lazy,
                api=self._api,
            )
            objs.append(run)

            if self._include_sweeps and run.sweep_name:
                if run.sweep_name in self._sweeps:
                    sweep = self._sweeps[run.sweep_name]
                else:
                    sweep = public.Sweep.get(
                        self.client,
                        self.entity,
                        self.project,
                        run.sweep_name,
                        withRuns=False,
                    )
                    self._sweeps[run.sweep_name] = sweep

                if sweep is None:
                    continue
                run.sweep = sweep

        return objs

    def _iter_run_histories(
        self,
        samples: int,
        keys: list[str] | None,
        x_axis: str,
        stream: str,
    ) -> Iterator[tuple[Run, list[dict[str, Any]]]]:
        """Yield `(run, history_rows)` for every run with non-empty sampled history."""
        for run in self:
            history_data = run.history(
                samples=samples,
                keys=keys,
                x_axis=x_axis,
                pandas=False,
                stream=stream,
            )
            if history_data:
                yield run, history_data

    @normalize_exceptions
    def histories(
        self,
        samples: int = 500,
        keys: list[str] | None = None,
        x_axis: str = "_step",
        format: Literal["default", "pandas", "polars"] = "default",
        stream: Literal["default", "system"] = "default",
    ) -> list[dict[str, Any]] | pd.DataFrame | pl.DataFrame:
        """Return sampled history metrics for all runs that fit the filters conditions.

        Args:
            samples: The number of samples to return per run
            keys: Only return metrics for specific keys
            x_axis: Use this metric as the xAxis defaults to _step
            format: Format to return data in, options are "default", "pandas",
                "polars"
            stream: "default" for metrics, "system" for machine metrics
        Returns:
            pandas.DataFrame: If `format="pandas"`, returns a `pandas.DataFrame`
                of history metrics.
            polars.DataFrame: If `format="polars"`, returns a `polars.DataFrame`
                of history metrics.
            list of dicts: If `format="default"`, returns a list of dicts
                containing history metrics with a `run_id` key.

        Raises:
            ValueError: If `format` is not one of the supported options.
        """
        if format not in ("default", "pandas", "polars"):
            raise ValueError(
                f"Invalid format: {format}. Must be one of 'default', 'pandas', 'polars'"
            )

        histories = []

        if format == "default":
            for run, history_data in self._iter_run_histories(
                samples, keys, x_axis, stream
            ):
                for entry in history_data:
                    entry["run_id"] = run.id
                histories.extend(history_data)

            return histories

        if format == "pandas":
            pd = util.get_module(
                "pandas", required="Exporting pandas DataFrame requires pandas"
            )
            for run, history_data in self._iter_run_histories(
                samples, keys, x_axis, stream
            ):
                df = pd.DataFrame.from_records(history_data)
                df["run_id"] = run.id
                histories.append(df)
            if not histories:
                return pd.DataFrame()
            combined_df = pd.concat(histories)
            combined_df.reset_index(drop=True, inplace=True)
            # sort columns for consistency
            combined_df = combined_df[(sorted(combined_df.columns))]

            return combined_df

        if format == "polars":
            pl = util.get_module(
                "polars", required="Exporting polars DataFrame requires polars"
            )
            for run, history_data in self._iter_run_histories(
                samples, keys, x_axis, stream
            ):
                df = pl.from_records(history_data)
                df = df.with_columns(pl.lit(run.id).alias("run_id"))
                histories.append(df)
            if not histories:
                return pl.DataFrame()
            combined_df = pl.concat(histories, how="vertical")
            # sort columns for consistency
            combined_df = combined_df.select(sorted(combined_df.columns))

            return combined_df

    def __repr__(self) -> str:
        return f"<{nameof(type(self))} {self.entity}/{self.project}>"

    def upgrade_to_full(self) -> None:
        """Upgrade this Runs collection from lazy to full mode.

        This switches to fetching full run data and
        upgrades any already-loaded Run objects to have full data.
        Already-loaded runs are upgraded concurrently on a thread pool.
        """
        if not self._lazy:
            return  # Already in full mode

        # Switch to full mode
        self._lazy = False

        # Regenerate query with full fragment
        self.QUERY = _create_runs_query(
            lazy=False,
            with_internal_id=_server_provides_internal_id_for_project(self.client),
            with_project_id=_server_provides_project_id_for_run(self.client),
        )

        # Upgrade any existing runs that have been loaded - use parallel loading for performance
        lazy_runs = [run for run in self.objects if run._lazy]
        if lazy_runs:
            from concurrent.futures import ThreadPoolExecutor

            # Limit workers to avoid overwhelming the server
            max_workers = min(len(lazy_runs), 10)
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = [executor.submit(run.load_full_data) for run in lazy_runs]
                # Wait for all to complete; `result()` re-raises any failure.
                for future in futures:
                    future.result()
+ project (str): the project associated with the run + entity (str): the name of the entity associated with the run + project_internal_id (int): the internal id of the project + user (str): the name of the user who created the run + path (str): Unique identifier [entity]/[project]/[run_id] + notes (str): Notes about the run + read_only (boolean): Whether the run is editable + history_keys (str): History metric keys logged with `wandb.Run.log({"key": "value"})` + metadata (str): Metadata about the run from wandb-metadata.json + """ + + def __init__( + self, + client: RetryingClient, + entity: str, + project: str, + run_id: str, + attrs: Mapping | None = None, + include_sweeps: bool = True, + lazy: bool = True, + api: public.Api | None = None, + ): + """Initialize a Run object. + + Run is always initialized by calling api.runs() where api is an instance of + wandb.Api. + """ + _attrs = attrs or {} + super().__init__(dict(_attrs)) + self.client = client + self._entity = entity + self.project = project + self._files = {} + self._base_dir = env.get_dir(tempfile.gettempdir()) + self.id = run_id + self.sweep = None + self._include_sweeps = include_sweeps + self._lazy = lazy + self._full_data_loaded = False # Track if we've loaded full data + self.dir = os.path.join(self._base_dir, *self.path) + try: + os.makedirs(self.dir) + except OSError: + pass + self._summary = None + self._metadata: dict[str, Any] | None = None + self._state = _attrs.get("state", "not found") + self.server_provides_internal_id_field: bool | None = None + self._server_provides_project_id_field: bool | None = None + self._is_loaded: bool = False + self._api: public.Api | None = api + + self.load(force=not _attrs) + + @property + def state(self) -> str: + """The state of the run. 
Can be one of: Finished, Failed, Crashed, or Running.""" + return self._state + + @property + def entity(self) -> str: + """The entity associated with the run.""" + return self._entity + + @property + def username(self) -> str: + """This API is deprecated. Use `entity` instead.""" + wandb.termwarn("Run.username is deprecated. Please use Run.entity instead.") + return self._entity + + @property + def storage_id(self) -> str: + """The unique storage identifier for the run.""" + # For compatibility with wandb.Run, which has storage IDs + # in self.storage_id and names in self.id. + + return self._attrs.get("id") + + @property + def id(self) -> str: + """The unique identifier for the run.""" + return self._attrs.get("name") + + @id.setter + def id(self, new_id: str) -> None: + """Set the unique identifier for the run.""" + self._attrs["name"] = new_id + + @property + def name(self) -> str | None: + """The name of the run.""" + return self._attrs.get("displayName") + + @name.setter + def name(self, new_name: str) -> None: + """Set the name of the run.""" + self._attrs["displayName"] = new_name + + @classmethod + def create( + cls, + api: public.Api, + run_id: str | None = None, + project: str | None = None, + entity: str | None = None, + state: Literal["running", "pending"] = "running", + ) -> Self: + """Create a run for the given project.""" + api._sentry.message("Invoking Run.create", level="info") + run_id = run_id or runid.generate_id() + project = project or api.settings.get("project") or "uncategorized" + mutation = gql( + """ + mutation UpsertBucket($project: String, $entity: String, $name: String!, $state: String) { + upsertBucket(input: {modelName: $project, entityName: $entity, name: $name, state: $state}) { + bucket { + project { + name + entity { name } + } + id + name + } + inserted + } + } + """ + ) + variables = { + "entity": entity, + "project": project, + "name": run_id, + "state": state, + } + res = api.client.execute(mutation, 
variable_values=variables) + res = res["upsertBucket"]["bucket"] + return cls( + api.client, + res["project"]["entity"]["name"], + res["project"]["name"], + res["name"], + { + "id": res["id"], + "config": "{}", + "systemMetrics": "{}", + "summaryMetrics": "{}", + "tags": [], + "description": None, + "notes": None, + "state": state, + }, + lazy=False, # Created runs should have full data available immediately + ) + + def _load_with_fragment( + self, fragment: str, fragment_name: str, force: bool = False + ) -> dict[str, Any]: + """Load run data using specified GraphQL fragment.""" + # Cache the server capability check to avoid repeated network calls + if self._server_provides_project_id_field is None: + self._server_provides_project_id_field = ( + _server_provides_project_id_for_run(self.client) + ) + + query = gql( + f""" + query Run($project: String!, $entity: String!, $name: String!) {{ + project(name: $project, entityName: $entity) {{ + run(name: $name) {{ + {"projectId" if self._server_provides_project_id_field else ""} + ...{fragment_name} + }} + }} + }} + {fragment} + """ + ) + + if force or not self._attrs: + response = self._exec(query) + if ( + response is None + or response.get("project") is None + or response["project"].get("run") is None + ): + raise ValueError("Could not find run {}".format(self)) + self._attrs = response["project"]["run"] + + self._state = self._attrs["state"] + if self._attrs.get("user"): + self.user = public.User(self.client, self._attrs["user"]) + + if self._include_sweeps and self.sweep_name and not self.sweep: + # There may be a lot of runs. Don't bother pulling them all + # just for the sake of this one. 
+ self.sweep = public.Sweep.get( + self.client, + self.entity, + self.project, + self.sweep_name, + withRuns=False, + ) + + if not self._is_loaded or force: + # Always set _project_internal_id if projectId is available, regardless of fragment type + if "projectId" in self._attrs: + self._project_internal_id = int(self._attrs["projectId"]) + else: + self._project_internal_id = None + + # Always call _load_from_attrs when using the full fragment or when the fields are actually present + if fragment_name == RUN_FRAGMENT_NAME or ( + "config" in self._attrs + or "summaryMetrics" in self._attrs + or "systemMetrics" in self._attrs + ): + self._load_from_attrs() + + # Only mark as loaded for lightweight fragments, not full fragments + if fragment_name == LIGHTWEIGHT_RUN_FRAGMENT_NAME: + self._is_loaded = True + + return self._attrs + + def _load_from_attrs(self) -> dict[str, Any]: + self._state = self._attrs.get("state", None) + + # Only convert fields if they exist in _attrs + if "config" in self._attrs: + self._attrs["config"] = _convert_to_dict(self._attrs.get("config")) + if "summaryMetrics" in self._attrs: + self._attrs["summaryMetrics"] = _convert_to_dict( + self._attrs.get("summaryMetrics") + ) + if "systemMetrics" in self._attrs: + self._attrs["systemMetrics"] = _convert_to_dict( + self._attrs.get("systemMetrics") + ) + + # Only check for sweeps if sweep_name is available (not in lazy mode or if it exists) + if self._include_sweeps and self._attrs.get("sweepName") and not self.sweep: + # There may be a lot of runs. 
Don't bother pulling them all + self.sweep = public.Sweep.get( + self.client, + self.entity, + self.project, + self._attrs["sweepName"], + withRuns=False, + ) + + config_user, config_raw = {}, {} + if self._attrs.get("config"): + try: + # config is already converted to dict by _convert_to_dict + for key, value in self._attrs.get("config", {}).items(): + config = config_raw if key in WANDB_INTERNAL_KEYS else config_user + if isinstance(value, dict) and "value" in value: + config[key] = value["value"] + else: + config[key] = value + except (TypeError, AttributeError): + # Handle case where config is malformed or not a dict + pass + + config_raw.update(config_user) + self._attrs["config"] = config_user + self._attrs["rawconfig"] = config_raw + + return self._attrs + + def load(self, force: bool = False) -> dict[str, Any]: + """Load run data using appropriate fragment based on lazy mode.""" + if self._lazy: + return self._load_with_fragment( + LIGHTWEIGHT_RUN_FRAGMENT, LIGHTWEIGHT_RUN_FRAGMENT_NAME, force + ) + else: + return self._load_with_fragment(RUN_FRAGMENT, RUN_FRAGMENT_NAME, force) + + @normalize_exceptions + def wait_until_finished(self) -> None: + """Check the state of the run until it is finished.""" + query = gql( + """ + query RunState($project: String!, $entity: String!, $name: String!) 
{ + project(name: $project, entityName: $entity) { + run(name: $name) { + state + } + } + } + """ + ) + while True: + res = self._exec(query) + state = res["project"]["run"]["state"] + if state in ["finished", "crashed", "failed"]: + self._attrs["state"] = state + self._state = state + return + time.sleep(5) + + @normalize_exceptions + def update(self) -> None: + """Persist changes to the run object to the wandb backend.""" + mutation = gql( + """ + mutation UpsertBucket($id: String!, $description: String, $display_name: String, $notes: String, $tags: [String!], $config: JSONString!, $groupName: String, $jobType: String) {{ + upsertBucket(input: {{id: $id, description: $description, displayName: $display_name, notes: $notes, tags: $tags, config: $config, groupName: $groupName, jobType: $jobType}}) {{ + bucket {{ + ...RunFragment + }} + }} + }} + {} + """.format(RUN_FRAGMENT) + ) + _ = self._exec( + mutation, + id=self.storage_id, + tags=self.tags, + description=self.description, + notes=self.notes, + display_name=self.display_name, + config=self.json_config, + groupName=self.group, + jobType=self.job_type, + ) + self.summary.update() + + @normalize_exceptions + def delete(self, delete_artifacts: bool = False) -> None: + """Delete the given run from the wandb backend. + + Args: + delete_artifacts (bool, optional): Whether to delete the artifacts + associated with the run. 
+ """ + mutation = gql( + """ + mutation DeleteRun( + $id: ID!, + {} + ) {{ + deleteRun(input: {{ + id: $id, + {} + }}) {{ + clientMutationId + }} + }} + """.format( + "$deleteArtifacts: Boolean" if delete_artifacts else "", + "deleteArtifacts: $deleteArtifacts" if delete_artifacts else "", + ) + ) + + self.client.execute( + mutation, + variable_values={ + "id": self.storage_id, + "deleteArtifacts": delete_artifacts, + }, + ) + + def save(self) -> None: + """Persist changes to the run object to the W&B backend.""" + self.update() + + @normalize_exceptions + def update_state(self, state: Literal["pending"]) -> bool: + """Update the state of a run. + + Allows transitioning runs from 'failed' or 'crashed' to 'pending'. + + Args: + state: The target run state. Only `"pending"` is supported. + + Returns: + `True` if the state was successfully updated. + + Raises: + `wandb.Error`: If the requested state transition is not allowed, or the server + does not support this operation. + """ + mutation = gql( + """ + mutation UpdateRunState($input: UpdateRunStateInput!) { + updateRunState(input: $input) { + success + } + } + """ + ) + + try: + result = self.client.execute( + mutation, + variable_values={ + "input": { + "id": self.storage_id, + "state": state, + } + }, + ) + except Exception as e: + error_msg = str(e) + if "UpdateRunStateInput" in error_msg or "updateRunState" in error_msg: + raise wandb.Error( + "The server does not support the update_state operation. " + "Please ensure your W&B server is updated to a version that " + "supports run state transitions." + ) from e + if "invalid state transition" in error_msg.lower(): + raise wandb.Error( + f"Invalid state transition: cannot change run from '{self.state}' " + f"to '{state}'. Only runs in 'failed' or 'crashed' state can be " + "transitioned to 'pending'." 
+ ) from e + raise + + if result.get("updateRunState", {}).get("success"): + self._attrs["state"] = state + self._state = state + return True + return False + + @property + def json_config(self) -> str: + """Return the run config as a JSON string. + + + """ + config = {} + if "_wandb" in self.rawconfig: + config["_wandb"] = {"value": self.rawconfig["_wandb"], "desc": None} + for k, v in self.config.items(): + config[k] = {"value": v, "desc": None} + return json.dumps(config) + + def _exec(self, query: Document, **kwargs: Any) -> dict[str, Any]: + """Execute a query against the cloud backend.""" + variables = {"entity": self.entity, "project": self.project, "name": self.id} + variables.update(kwargs) + return self.client.execute(query, variable_values=variables) + + def _sampled_history( + self, + keys: list[str], + x_axis: str = "_step", + samples: int = 500, + ) -> list[dict[str, Any]]: + spec = {"keys": [x_axis] + keys, "samples": samples} + query = gql( + """ + query RunSampledHistory($project: String!, $entity: String!, $name: String!, $specs: [JSONString!]!) 
{ + project(name: $project, entityName: $entity) { + run(name: $name) { sampledHistory(specs: $specs) } + } + } + """ + ) + + response = self._exec(query, specs=[json.dumps(spec)]) + # sampledHistory returns one list per spec, we only send one spec + return response["project"]["run"]["sampledHistory"][0] + + def _full_history( + self, + samples: int = 500, + stream: Literal["default", "system"] = "default", + ) -> list[dict[str, Any]]: + node = "history" if stream == "default" else "events" + query = gql( + """ + query RunFullHistory($project: String!, $entity: String!, $name: String!, $samples: Int) {{ + project(name: $project, entityName: $entity) {{ + run(name: $name) {{ {}(samples: $samples) }} + }} + }} + """.format(node) + ) + + response = self._exec(query, samples=samples) + return [json.loads(line) for line in response["project"]["run"][node]] + + @normalize_exceptions + def files( + self, + names: list[str] | None = None, + pattern: str | None = None, + per_page: int = 50, + ) -> public.Files: + """Returns a `Files` object for all files in the run which match the given criteria. + + You can specify a list of exact file names to match, or a pattern to match against. + If both are provided, the pattern will be ignored. + + Args: + names (list): names of the requested files, if empty returns all files + pattern (str, optional): Pattern to match when returning files from W&B. + This pattern uses mySQL's LIKE syntax, + so matching all files that end with .json would be "%.json". + If both names and pattern are provided, a ValueError will be raised. + per_page (int): number of results per page. + + Returns: + A `Files` object, which is an iterator over `File` objects. + """ + return public.Files( + self.client, + self, + names or [], + pattern=pattern, + per_page=per_page, + ) + + @normalize_exceptions + def file(self, name: str) -> public.File: + """Return the path of a file with a given name in the artifact. + + Args: + name (str): name of requested file. 
+ + Returns: + A `File` matching the name argument. + """ + return public.Files(self.client, self, [name])[0] + + @normalize_exceptions + def upload_file(self, path: str, root: str = ".") -> public.File: + """Upload a local file to W&B, associating it with this run. + + Args: + path (str): Path to the file to upload. Can be absolute or relative. + root (str): The root path to save the file relative to. For example, + if you want to have the file saved in the run as "my_dir/file.txt" + and you're currently in "my_dir" you would set root to "../". + Defaults to current directory ("."). + + Returns: + A `File` object representing the uploaded file. + """ + api = InternalApi( + default_settings={"entity": self.entity, "project": self.project}, + retry_timedelta=RETRY_TIMEDELTA, + ) + api.set_current_run_id(self.id) + root = os.path.abspath(root) + name = os.path.relpath(path, root) + upload_path = util.make_file_path_upload_safe(name) + with open(os.path.join(root, name), "rb") as f: + api.push({LogicalPath(upload_path): f}) + return public.Files(self.client, self, [name])[0] + + @normalize_exceptions + def history( + self, + samples: int = 500, + keys: list[str] | None = None, + x_axis: str = "_step", + pandas: bool = True, + stream: Literal["default", "system"] = "default", + ) -> list[dict[str, Any]] | pd.DataFrame: + """Return sampled history metrics for a run. + + This is simpler and faster if you are ok with the history records being sampled. + + Args: + samples : (int, optional) The number of samples to return + pandas : (bool, optional) Return a pandas dataframe + keys : (list, optional) Only return metrics for specific keys + x_axis : (str, optional) Use this metric as the xAxis defaults to _step + stream : (str, optional) "default" for metrics, "system" for machine metrics + + Returns: + pandas.DataFrame: If pandas=True returns a `pandas.DataFrame` of history + metrics. + list of dicts: If pandas=False returns a list of dicts of history metrics. 
+ """ + if keys is not None and not isinstance(keys, list): + wandb.termerror("keys must be specified in a list") + return [] + if keys is not None and len(keys) > 0 and not isinstance(keys[0], str): + wandb.termerror("keys argument must be a list of strings") + return [] + + if keys and stream != "default": + wandb.termerror("stream must be default when specifying keys") + return [] + elif keys: + lines = self._sampled_history(keys=keys, x_axis=x_axis, samples=samples) + else: + lines = self._full_history(samples=samples, stream=stream) + if pandas: + pd = util.get_module("pandas") + if pd: + lines = pd.DataFrame.from_records(lines) + else: + wandb.termwarn("Unable to load pandas, call history with pandas=False") + return lines + + @normalize_exceptions + def scan_history( + self, + keys: list[str] | None = None, + page_size: int = 1_000, + min_step: int | None = None, + max_step: int | None = None, + ) -> Iterator[dict[str, Any]]: + """Returns an iterable collection of all history records for a run. + + Args: + keys ([str], optional): only fetch these keys, and only fetch rows that have all of keys defined. + page_size (int, optional): size of pages to fetch from the api. + min_step (int, optional): the minimum number of pages to scan at a time. + max_step (int, optional): the maximum number of pages to scan at a time. + + Returns: + An iterable collection over history records (dict). 
+ + Example: + Export all the loss values for an example run + + ```python + run = api.run("entity/project-name/run-id") + history = run.scan_history(keys=["Loss"]) + losses = [row["Loss"] for row in history] + ``` + """ + if keys is not None and not isinstance(keys, list): + wandb.termerror("keys must be specified in a list") + return [] + if keys is not None and len(keys) > 0 and not isinstance(keys[0], str): + wandb.termerror("keys argument must be a list of strings") + return [] + + last_step = self.lastHistoryStep + # set defaults for min/max step + if min_step is None: + min_step = 0 + if max_step is None: + max_step = last_step + 1 + # if the max step is past the actual last step, clamp it down + if max_step > last_step: + max_step = last_step + 1 + if keys is None: + return public.HistoryScan( + run=self, + client=self.client, + page_size=page_size, + min_step=min_step, + max_step=max_step, + ) + else: + return public.SampledHistoryScan( + run=self, + client=self.client, + keys=keys, + page_size=page_size, + min_step=min_step, + max_step=max_step, + ) + + @normalize_exceptions + def logged_artifacts(self, per_page: int = 100) -> public.RunArtifacts: + """Fetches all artifacts logged by this run. + + Retrieves all output artifacts that were logged during the run. Returns a + paginated result that can be iterated over or collected into a single list. + + Args: + per_page: Number of artifacts to fetch per API request. + + Returns: + An iterable collection of all Artifact objects logged as outputs during this run. 
+ + Example: + ```python + import wandb + import tempfile + + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as tmp: + tmp.write("This is a test artifact") + tmp_path = tmp.name + run = wandb.init(project="artifact-example") + artifact = wandb.Artifact("test_artifact", type="dataset") + artifact.add_file(tmp_path) + run.log_artifact(artifact) + run.finish() + + api = wandb.Api() + + finished_run = api.run(f"{run.entity}/{run.project}/{run.id}") + + for logged_artifact in finished_run.logged_artifacts(): + print(logged_artifact.name) + ``` + + """ + return public.RunArtifacts(self.client, self, mode="logged", per_page=per_page) + + @normalize_exceptions + def used_artifacts(self, per_page: int = 100) -> public.RunArtifacts: + """Fetches artifacts explicitly used by this run. + + Retrieves only the input artifacts that were explicitly declared as used + during the run, typically via `run.use_artifact()`. Returns a paginated + result that can be iterated over or collected into a single list. + + Args: + per_page: Number of artifacts to fetch per API request. + + Returns: + An iterable collection of Artifact objects explicitly used as inputs in this run. + + Example: + ```python + import wandb + + run = wandb.init(project="artifact-example") + run.use_artifact("test_artifact:latest") + run.finish() + + api = wandb.Api() + finished_run = api.run(f"{run.entity}/{run.project}/{run.id}") + for used_artifact in finished_run.used_artifacts(): + print(used_artifact.name) + test_artifact + ``` + """ + return public.RunArtifacts(self.client, self, mode="used", per_page=per_page) + + @normalize_exceptions + def use_artifact( + self, + artifact: wandb.Artifact, + use_as: str | None = None, + ) -> wandb.Artifact: + """Declare an artifact as an input to a run. + + Args: + artifact (`Artifact`): An artifact returned from + `wandb.Api().artifact(name)` + use_as (string, optional): A string identifying + how the artifact is used in the script. 
Used + to easily differentiate artifacts used in a + run, when using the beta wandb launch + feature's artifact swapping functionality. + + Returns: + An `Artifact` object. + """ + api = InternalApi( + default_settings={"entity": self.entity, "project": self.project}, + retry_timedelta=RETRY_TIMEDELTA, + ) + api.set_current_run_id(self.id) + + if isinstance(artifact, wandb.Artifact) and not artifact.is_draft(): + api.use_artifact( + artifact.id, + use_as=use_as or artifact.name, + artifact_entity_name=artifact.entity, + artifact_project_name=artifact.project, + ) + return artifact + elif isinstance(artifact, wandb.Artifact) and artifact.is_draft(): + raise ValueError( + "Only existing artifacts are accepted by this api. " + "Manually create one with `wandb artifact put`" + ) + else: + raise ValueError("You must pass a wandb.Api().artifact() to use_artifact") + + @normalize_exceptions + def log_artifact( + self, + artifact: wandb.Artifact, + aliases: Collection[str] | None = None, + tags: Collection[str] | None = None, + ) -> wandb.Artifact: + """Declare an artifact as output of a run. + + Args: + artifact (`Artifact`): An artifact returned from + `wandb.Api().artifact(name)`. + aliases (list, optional): Aliases to apply to this artifact. + tags: (list, optional) Tags to apply to this artifact, if any. + + Returns: + A `Artifact` object. + """ + api = InternalApi( + default_settings={"entity": self.entity, "project": self.project}, + retry_timedelta=RETRY_TIMEDELTA, + ) + api.set_current_run_id(self.id) + + if not isinstance(artifact, wandb.Artifact): + raise TypeError("You must pass a wandb.Api().artifact() to use_artifact") + if artifact.is_draft(): + raise ValueError( + "Only existing artifacts are accepted by this api. 
" + "Manually create one with `wandb artifact put`" + ) + if ( + self.entity != artifact.source_entity + or self.project != artifact.source_project + ): + raise ValueError("A run can't log an artifact to a different project.") + + artifact_collection_name = artifact.source_name.split(":")[0] + api.create_artifact( + artifact.type, + artifact_collection_name, + artifact.digest, + entity_name=self.entity, + project_name=self.project, + aliases=aliases, + tags=tags, + ) + return artifact + + def load_full_data(self, force: bool = False) -> dict[str, Any]: + """Load full run data including heavy fields like config, systemMetrics, summaryMetrics. + + This method is useful when you initially used lazy=True for listing runs, + but need access to the full data for specific runs. + + Args: + force: Force reload even if data is already loaded + + Returns: + The loaded run attributes + """ + if not self._lazy and not force: + # Already in full mode, no need to reload + return self._attrs + + # Load full data and mark as loaded + result = self._load_with_fragment(RUN_FRAGMENT, RUN_FRAGMENT_NAME, force=True) + self._full_data_loaded = True + return result + + @property + def config(self) -> dict[str, Any]: + """Get run config. Auto-loads full data if in lazy mode.""" + if self._lazy and not self._full_data_loaded and "config" not in self._attrs: + self.load_full_data() + + # Ensure config is always converted to dict (defensive against conversion issues) + config_value = self._attrs.get("config", {}) + # _convert_to_dict handles dict inputs (noop) and converts str/bytes/bytearray to dict + config_value = _convert_to_dict(config_value) + self._attrs["config"] = config_value + return config_value + + @property + def summary(self) -> HTTPSummary: + """Get run summary metrics. 
Auto-loads full data if in lazy mode.""" + if ( + self._lazy + and not self._full_data_loaded + and "summaryMetrics" not in self._attrs + ): + self.load_full_data() + if self._summary is None: + from wandb.old.summary import HTTPSummary + + # TODO: fix the outdir issue + self._summary = HTTPSummary(self, self.client, summary=self.summary_metrics) + return self._summary + + @property + def system_metrics(self) -> dict[str, Any]: + """Get run system metrics. Auto-loads full data if in lazy mode.""" + if ( + self._lazy + and not self._full_data_loaded + and "systemMetrics" not in self._attrs + ): + self.load_full_data() + + # Ensure systemMetrics is always converted to dict (defensive against conversion issues) + system_metrics_value = self._attrs.get("systemMetrics", {}) + # _convert_to_dict handles dict inputs (noop) and converts str/bytes/bytearray to dict + system_metrics_value = _convert_to_dict(system_metrics_value) + self._attrs["systemMetrics"] = system_metrics_value + return system_metrics_value + + @property + def summary_metrics(self) -> dict[str, Any]: + """Get run summary metrics. Auto-loads full data if in lazy mode.""" + if ( + self._lazy + and not self._full_data_loaded + and "summaryMetrics" not in self._attrs + ): + self.load_full_data() + + # Ensure summaryMetrics is always converted to dict (defensive against conversion issues) + summary_metrics_value = self._attrs.get("summaryMetrics", {}) + # _convert_to_dict handles dict inputs (noop) and converts str/bytes/bytearray to dict + summary_metrics_value = _convert_to_dict(summary_metrics_value) + self._attrs["summaryMetrics"] = summary_metrics_value + return summary_metrics_value + + @property + def rawconfig(self) -> dict[str, Any]: + """Get raw run config including internal keys. 
Auto-loads full data if in lazy mode.""" + if self._lazy and not self._full_data_loaded and "rawconfig" not in self._attrs: + self.load_full_data() + return self._attrs.get("rawconfig", {}) + + @property + def sweep_name(self) -> str | None: + """Get sweep name. Always available since sweepName is in lightweight fragment.""" + # sweepName is included in lightweight fragment, so no need to load full data + return self._attrs.get("sweepName") + + @property + def path(self) -> list[str]: + """The path of the run. The path is a list containing the entity, project, and run_id.""" + return [ + urllib.parse.quote_plus(str(self.entity)), + urllib.parse.quote_plus(str(self.project)), + urllib.parse.quote_plus(str(self.id)), + ] + + @property + def url(self) -> str: + """The URL of the run. + + The run URL is generated from the entity, project, and run_id. For + SaaS users, it takes the form of `https://wandb.ai/entity/project/run_id`. + """ + path = self.path + path.insert(2, "runs") + return self.client.app_url + "/".join(path) + + @property + def metadata(self) -> dict[str, Any] | None: + """Metadata about the run from wandb-metadata.json. + + Metadata includes the run's description, tags, start time, memory + usage and more. + """ + if self._metadata is None: + try: + f = self.file("wandb-metadata.json") + session = self.client._client.transport.session + response = session.get(f.url, timeout=5) + response.raise_for_status() + contents = response.content + self._metadata = json_util.loads(contents) + except: # noqa: E722 + # file doesn't exist, or can't be downloaded, or can't be parsed + pass + return self._metadata + + @property + def lastHistoryStep(self) -> int: # noqa: N802 + """Returns the last step logged in the run's history.""" + query = gql( + """ + query RunHistoryKeys($project: String!, $entity: String!, $name: String!) 
{ + project(name: $project, entityName: $entity) { + run(name: $name) { historyKeys } + } + } + """ + ) + response = self._exec(query) + if ( + response is None + or response.get("project") is None + or response["project"].get("run") is None + or response["project"]["run"].get("historyKeys") is None + ): + return -1 + history_keys = response["project"]["run"]["historyKeys"] + return history_keys["lastStep"] if "lastStep" in history_keys else -1 + + def to_html(self, height: int = 420, hidden: bool = False) -> str: + """Generate HTML containing an iframe displaying this run.""" + url = self.url + "?jupyter=true" + style = f"border:none;width:100%;height:{height}px;" + prefix = "" + if hidden: + style += "display:none;" + prefix = ipython.toggle_button() + return prefix + f"" + + def _repr_html_(self) -> str: + if ipython.in_vscode_notebook(): + import html + + return html.escape(self._string_representation()) + + return self.to_html() + + def __repr__(self) -> str: + return self._string_representation() + + def _string_representation(self) -> str: + return f"<{nameof(type(self))} {'/'.join(self.path)} ({self.state})>" + + def beta_scan_history( + self, + keys: list[str] | None = None, + page_size: int = 1_000, + min_step: int = 0, + max_step: int | None = None, + use_cache: bool = True, + ) -> public.BetaHistoryScan: + """Returns an iterable collection of all history records for a run. + + This function is still in development and may not work as expected. + It uses wandb-core to read history from a run's exported + parquet history locally. + + Args: + keys: list of metrics to read from the run's history. + if no keys are provided then all metrics will be returned. + page_size: the number of history records to read at a time. + min_step: The minimum step to start reading history from (inclusive). + max_step: The maximum step to read history up to (exclusive). + use_cache: When set to True, checks the WANDB_CACHE_DIR for a run history. 
+ If the run history is not found in the cache, it will be downloaded from the server. + If set to False, the run history will be downloaded every time. + + Returns: + A BetaHistoryScan object, + which can be iterator over to get history records. + """ + if self._api is None: + self._api = public.Api() + + beta_history_scan = public.BetaHistoryScan( + api=self._api, + run=self, + min_step=min_step, + max_step=max_step or self.lastHistoryStep + 1, + keys=keys, + page_size=page_size, + use_cache=use_cache, + ) + return beta_history_scan + + def download_history_exports( + self, + download_dir: pathlib.Path | str, + require_complete_history: bool = True, + ) -> DownloadHistoryResult: + """Download any parquet history files for the run to the provided directory. + + Args: + download_dir: The directory to download the history files to. + require_complete_history: Whether to require the complete history to be downloaded. + If true, and the run contains data that has not been exported to parquet files yet, + an IncompleteRunHistoryError will be raised. + + Returns: + A DownloadHistoryResult. 
+ """ + if self._api is None: + self._api = public.Api() + + api_request = apb.ApiRequest( + read_run_history_request=apb.ReadRunHistoryRequest( + download_run_history=apb.DownloadRunHistory( + entity=self.entity, + project=self.project, + run_id=self.id, + download_dir=str(download_dir), + require_complete_history=require_complete_history, + ) + ) + ) + + response: apb.ApiResponse | None = None + try: + response = self._api._send_api_request(api_request) + except WandbApiFailedError as e: + if ( + e.response is not None + and e.response.error_type is not None + and e.response.error_type == apb.ErrorType.INCOMPLETE_RUN_HISTORY_ERROR + ): + raise IncompleteRunHistoryError() from None + + if response is None: + raise wandb.Error( + "Failed to download run history exports, no response from server" + ) + + contains_live_data: bool = ( + response.download_run_history_response.contains_live_data + ) + file_names: list[pathlib.Path] = [] + for file_name in response.download_run_history_response.file_names: + file_names.append(pathlib.Path(download_dir, file_name)) + + return DownloadHistoryResult( + paths=file_names, + contains_live_data=contains_live_data, + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/sweeps.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/sweeps.py new file mode 100644 index 0000000000000000000000000000000000000000..eb720566e64bff53797428c645a6d3e110908161 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/sweeps.py @@ -0,0 +1,380 @@ +"""W&B Public API for Sweeps. + +This module provides classes for interacting with W&B hyperparameter +optimization sweeps. 
+ +Example: +```python +from wandb.apis.public import Api + +# Get a specific sweep +sweep = Api().sweep("entity/project/sweep_id") + +# Access sweep properties +print(f"Sweep: {sweep.name}") +print(f"State: {sweep.state}") +print(f"Best Loss: {sweep.best_loss}") + +# Get best performing run +best_run = sweep.best_run() +print(f"Best Run: {best_run.name}") +print(f"Metrics: {best_run.summary}") +``` + +Note: + This module is part of the W&B Public API and provides read-only access + to sweep data. For creating and controlling sweeps, use the wandb.sweep() + and wandb.agent() functions from the main wandb package. +""" + +from __future__ import annotations + +import urllib +from typing import TYPE_CHECKING, Any, ClassVar, Mapping + +from typing_extensions import override +from wandb_gql import gql +from wandb_graphql.language.ast import Document + +import wandb +from wandb import util +from wandb.apis import public +from wandb.apis.attrs import Attrs +from wandb.apis.paginator import SizedPaginator +from wandb.sdk.lib import ipython + +if TYPE_CHECKING: + from wandb.apis._generated import GetSweeps + from wandb.apis.public.api import RetryingClient + + +class Sweeps(SizedPaginator["Sweep"]): + """A lazy iterator over a collection of `Sweep` objects. + + Examples: + ```python + from wandb.apis.public import Api + + sweeps = Api().project(name="project_name", entity="entity").sweeps() + + # Iterate over sweeps and print details + for sweep in sweeps: + print(f"Sweep name: {sweep.name}") + print(f"Sweep ID: {sweep.id}") + print(f"Sweep URL: {sweep.url}") + print("----------") + ``` + """ + + QUERY: ClassVar[Document | None] = None + last_response: GetSweeps | None + + def __init__( + self, + client: RetryingClient, + entity: str, + project: str, + per_page: int = 50, + ) -> Sweeps: + """An iterable collection of `Sweep` objects. + + Args: + client: The API client used to query W&B. + entity: The entity which owns the sweeps. 
+ project: The project which contains the sweeps. + per_page: The number of sweeps to fetch per request to the API. + """ + if self.QUERY is None: + from wandb.apis._generated import GET_SWEEPS_GQL + + type(self).QUERY = gql(GET_SWEEPS_GQL) + + self.entity = entity + self.project = project + variables = {"project": self.project, "entity": self.entity} + super().__init__(client, variables, per_page) + + @override + def _update_response(self) -> None: + """Fetch and validate the response data for the current page.""" + from wandb.apis._generated import GetSweeps + + data = self.client.execute(self.QUERY, variable_values=self.variables) + self.last_response = GetSweeps.model_validate(data) + + @property + @override + def _length(self) -> int: + """The total number of sweeps in the project. + + + """ + if self.last_response is None: + self._load_page() + return ( + total + if (total := self.last_response.project.total_sweeps) is not None + else 0 + ) + + @property + @override + def more(self) -> bool: + """Returns whether there are more sweeps to fetch. + + + """ + if self.last_response: + return self.last_response.project.sweeps.page_info.has_next_page + return True + + @property + @override + def cursor(self) -> str | None: + """Returns the cursor for the next page of sweeps. + + + """ + if self.last_response: + return self.last_response.project.sweeps.page_info.end_cursor + return None + + @override + def convert_objects(self) -> list[Sweep]: + """Converts the last GraphQL response into a list of `Sweep` objects. 
+ + + """ + from wandb._pydantic import Connection + from wandb.apis._generated import SweepFragment + + if (rsp := self.last_response) is None or (project := rsp.project) is None: + msg = f"Could not find project {self.project!r}" + raise ValueError(msg) + + if project.total_sweeps < 1: + return [] + return [ + Sweep( + self.client, + self.entity, + self.project, + node.name, + ) + for node in Connection[SweepFragment].model_validate(project.sweeps).nodes() + ] + + def __repr__(self): + return f"" + + +class Sweep(Attrs): + """The set of runs associated with the sweep. + + Attributes: + runs (Runs): List of runs + id (str): Sweep ID + project (str): The name of the project the sweep belongs to + config (dict): Dictionary containing the sweep configuration + state (str): The state of the sweep. Can be "Finished", "Failed", + "Crashed", or "Running". + expected_run_count (int): The number of expected runs for the sweep + """ + + def __init__( + self, + client: RetryingClient, + entity: str, + project: str, + sweep_id: str, + attrs: Mapping[str, Any] | None = None, + ): + # TODO: Add agents / flesh this out. + super().__init__(dict(attrs or {})) + self.client = client + self._entity = entity + self.project = project + self.id = sweep_id + self.runs = [] + + self.load(force=not attrs) + + @property + def entity(self) -> str: + """The entity associated with the sweep.""" + return self._entity + + @property + def username(self) -> str: + """Deprecated. Use `Sweep.entity` instead.""" + wandb.termwarn("Sweep.username is deprecated. please use Sweep.entity instead.") + return self._entity + + @property + def config(self): + """The sweep configuration used for the sweep.""" + return util.load_yaml(self._attrs["config"]) + + def load(self, force: bool = False): + """Fetch and update sweep data logged to the run from GraphQL database. 
+ + + """ + if force or not self._attrs: + if not (sweep := self.get(self.client, self.entity, self.project, self.id)): + raise ValueError(f"Could not find sweep {self!r}") + self._attrs = sweep._attrs + self.runs = sweep.runs + + return self._attrs + + @property + def order(self): + """Return the order key for the sweep.""" + if self._attrs.get("config") and self.config.get("metric"): + sort_order = self.config["metric"].get("goal", "minimize") + prefix = "+" if sort_order == "minimize" else "-" + return public.QueryGenerator.format_order_key( + prefix + self.config["metric"]["name"] + ) + + def best_run(self, order=None): + """Return the best run sorted by the metric defined in config or the order passed in.""" + if order is None: + order = self.order + else: + order = public.QueryGenerator.format_order_key(order) + if order is None: + wandb.termwarn( + "No order specified and couldn't find metric in sweep config, returning most recent run" + ) + else: + wandb.termlog("Sorting runs by {}".format(order)) + filters = {"$and": [{"sweep": self.id}]} + try: + return public.Runs( + self.client, + self.entity, + self.project, + order=order, + filters=filters, + per_page=1, + )[0] + except IndexError: + return None + + @property + def expected_run_count(self) -> int | None: + """Return the number of expected runs in the sweep or None for infinite runs.""" + return self._attrs.get("runCountExpected") + + @property + def path(self): + """Returns the path of the project. + + The path is a list containing the entity, project name, and sweep ID.""" + return [ + urllib.parse.quote_plus(self.entity), + urllib.parse.quote_plus(self.project), + urllib.parse.quote_plus(self.id), + ] + + @property + def url(self): + """The URL of the sweep. + + The sweep URL is generated from the entity, project, the term + "sweeps", and the sweep ID.run_id. For + SaaS users, it takes the form + of `https://wandb.ai/entity/project/sweeps/sweeps_ID`. 
+ """ + path = self.path + path.insert(2, "sweeps") + return self.client.app_url + "/".join(path) + + @property + def name(self): + """The name of the sweep. + + Returns the first name that exists in the following priority order: + + 1. User-edited display name + 2. Name configured at creation time + 3. Sweep ID + """ + return self._attrs.get("displayName") or self.config.get("name") or self.id + + @classmethod + def get( + cls, + client: RetryingClient, + entity: str | None = None, + project: str | None = None, + sid: str | None = None, + order: str | None = None, + query: Document | None = None, + **kwargs, + ): + """Execute a query against the cloud backend. + + Args: + client: The client to use to execute the query. + entity: The entity (username or team) that owns the project. + project: The name of the project to fetch sweep from. + sid: The sweep ID to query. + order: The order in which the sweep's runs are returned. + query: The query to use to execute the query. + **kwargs: Additional keyword arguments to pass to the query. + """ + from wandb.apis._generated import GET_SWEEP_GQL, GET_SWEEP_LEGACY_GQL + + if not order: + order = "+created_at" + + variables = {"entity": entity, "project": project, "name": sid, **kwargs} + if query is None: + query = gql(GET_SWEEP_GQL) + try: + data = client.execute(query, variable_values=variables) + except Exception: + # Don't handle exception, rely on legacy query + # TODO(gst): Implement updated introspection workaround + query = gql(GET_SWEEP_LEGACY_GQL) + data = client.execute(query, variable_values=variables) + + # FIXME: looks like this method allows passing arbitrary GQL queries, so for now + # we'll have to skip trying to validate the result with a generated pydantic model. 
+ if not ( + data + and (proj_dict := data.get("project")) + and (sweep_dict := proj_dict.get("sweep")) + ): + return None + sweep = cls(client, entity, project, sid, attrs=sweep_dict) + sweep.runs = public.Runs( + client, + entity, + project, + order=order, + per_page=10, + filters={"$and": [{"sweep": sweep.id}]}, + ) + return sweep + + def to_html(self, height: int = 420, hidden: bool = False) -> str: + """Generate HTML containing an iframe displaying this sweep.""" + url = self.url + "?jupyter=true" + style = f"border:none;width:100%;height:{height}px;" + prefix = "" + if hidden: + style += "display:none;" + prefix = ipython.toggle_button("sweep") + return prefix + f"" + + def _repr_html_(self) -> str: + return self.to_html() + + def __repr__(self) -> str: + pathstr = "/".join(self.path) + state = self._attrs.get("state", "Unknown State") + return f"" diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/teams.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/teams.py new file mode 100644 index 0000000000000000000000000000000000000000..4784006367490b0c153d61e149375436ffcd80ce --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/teams.py @@ -0,0 +1,182 @@ +"""W&B Public API for managing teams and team members. + +This module provides classes for managing W&B teams and their members. + +Note: + This module is part of the W&B Public API and provides methods to manage + teams and their members. Team management operations require appropriate + permissions. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Mapping + +from typing_extensions import Self +from wandb_gql import gql + +from wandb.apis.attrs import Attrs + +if TYPE_CHECKING: + from .api import Api, RetryingClient + + +class Member(Attrs): + """A member of a team. 
+ + Args: + client (`wandb.apis.internal.Api`): The client instance to use + team (str): The name of the team this member belongs to + attrs (dict): The member attributes + """ + + def __init__(self, client: RetryingClient, team: str, attrs: Mapping[str, Any]): + super().__init__(attrs) + self._client = client + self.team = team + + def delete(self): + """Remove a member from a team. + + Returns: + Boolean indicating success + """ + from requests import HTTPError + + from wandb.apis._generated import DELETE_INVITE_GQL, DeleteInvite + + try: + data = self._client.execute( + gql(DELETE_INVITE_GQL), {"id": self.id, "entity": self.team} + ) + except HTTPError: + return False + else: + result = DeleteInvite.model_validate(data).result + return (result is not None) and result.success + + def __repr__(self): + return f"" + + +class Team(Attrs): + """A class that represents a W&B team. + + This class provides methods to manage W&B teams, including creating teams, + inviting members, and managing service accounts. It inherits from Attrs + to handle team attributes. + + Args: + client (`wandb.apis.public.Api`): The api instance to use + name (str): The name of the team + attrs (dict): Optional dictionary of team attributes + + Note: + Team management requires appropriate permissions. + """ + + def __init__( + self, + client: RetryingClient, + name: str, + attrs: Mapping[str, Any] | None = None, + ): + super().__init__(attrs or {}) + self._client = client + self.name = name + self.load() + + @classmethod + def create(cls, api: Api, team: str, admin_username: str | None = None) -> Self: + """Create a new team. + + Args: + api: (`Api`) The api instance to use + team: (str) The name of the team + admin_username: (str) optional username of the admin user of the team, defaults to the current user. 
+ + Returns: + A `Team` object + """ + from requests import HTTPError + + from wandb.apis._generated import CREATE_TEAM_GQL + + try: + api.client.execute( + gql(CREATE_TEAM_GQL), + {"teamName": team, "teamAdminUserName": admin_username}, + ) + except HTTPError: + pass + return cls(api.client, team) + + def invite(self, username_or_email: str, admin: bool = False) -> bool: + """Invite a user to a team. + + Args: + username_or_email: (str) The username or email address of the user + you want to invite. + admin: (bool) Whether to make this user a team admin. + Defaults to `False`. + + Returns: + `True` on success, `False` if user was already invited or didn't exist. + """ + from requests import HTTPError + + from wandb.apis._generated import CREATE_INVITE_GQL + + variables = { + "entity": self.name, + "admin": admin, + ("email" if ("@" in username_or_email) else "username"): username_or_email, + } + try: + self._client.execute(gql(CREATE_INVITE_GQL), variables) + except HTTPError: + return False + return True + + def create_service_account(self, description: str) -> Member | None: + """Create a service account for the team. + + Args: + description: (str) A description for this service account + + Returns: + The service account `Member` object, or None on failure + """ + from requests import HTTPError + + from wandb.apis._generated import CREATE_SERVICE_ACCOUNT_GQL + + try: + self._client.execute( + gql(CREATE_SERVICE_ACCOUNT_GQL), + {"entity": self.name, "description": description}, + ) + self.load(True) + return self.members[-1] + except HTTPError: + return None + + def load(self, force: bool = False) -> dict[str, Any]: + """Return members that belong to a team. 
+ + + """ + from wandb.apis._generated import GET_TEAM_ENTITY_GQL, GetTeamEntity + + if force or not self._attrs: + data = self._client.execute(gql(GET_TEAM_ENTITY_GQL), {"name": self.name}) + result = GetTeamEntity.model_validate(data) + self._attrs = entity.model_dump() if (entity := result.entity) else {} + self._attrs["members"] = [ + Member(self._client, self.name, member) + for member in self._attrs["members"] + ] + return self._attrs + + def __repr__(self) -> str: + return f"" diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/users.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/users.py new file mode 100644 index 0000000000000000000000000000000000000000..665d6ee3ccb8504feb199f0d83cb4e734d16036a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/users.py @@ -0,0 +1,156 @@ +"""W&B Public API for managing users and API keys. + +This module provides classes for managing W&B users and their API keys. + +Note: + This module is part of the W&B Public API and provides methods to manage + users and their authentication. Some operations require admin privileges. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, MutableMapping + +from typing_extensions import Self +from wandb_gql import gql + +import wandb +from wandb.apis.attrs import Attrs + +if TYPE_CHECKING: + from .api import Api, RetryingClient + + +class User(Attrs): + """A class representing a W&B user with authentication and management capabilities. + + This class provides methods to manage W&B users, including creating users, + managing API keys, and accessing team memberships. It inherits from Attrs + to handle user attributes. 
+ + Args: + client: (`wandb.apis.internal.Api`) The client instance to use + attrs: (dict) The user attributes + + Note: + Some operations require admin privileges + """ + + def __init__(self, client: RetryingClient, attrs: MutableMapping[str, Any]): + super().__init__(attrs) + self._client = client + self._user_api: Api | None = None + + @property + def user_api(self) -> Api | None: + """An instance of the api using credentials from the user.""" + if self._user_api is None and self.api_keys: + self._user_api = wandb.Api(api_key=self.api_keys[0]) + return self._user_api + + @classmethod + def create(cls, api: Api, email: str, admin: bool = False) -> Self: + """Create a new user. + + Args: + api (`Api`): The api instance to use + email (str): The name of the team + admin (bool): Whether this user should be a global instance admin + + Returns: + A `User` object + """ + from wandb.apis._generated import ( + CREATE_USER_FROM_ADMIN_GQL, + CreateUserFromAdmin, + ) + + gql_op = gql(CREATE_USER_FROM_ADMIN_GQL) + data = api.client.execute(gql_op, {"email": email, "admin": admin}) + user = CreateUserFromAdmin.model_validate(data).result.user + return cls(api.client, user.model_dump()) + + @property + def api_keys(self) -> list[str]: + """List of API key names associated with the user. + + Returns: + Names of API keys associated with the user. Empty list if user + has no API keys or if API key data hasn't been loaded. + """ + if self._attrs.get("apiKeys") is None: + return [] + return [k["node"]["name"] for k in self._attrs["apiKeys"]["edges"]] + + @property + def teams(self) -> list[str]: + """List of team names that the user is a member of. + + Returns: + Names of teams the user belongs to. Empty list if user has no + team memberships or if teams data hasn't been loaded. + """ + if self._attrs.get("teams") is None: + return [] + return [k["node"]["name"] for k in self._attrs["teams"]["edges"]] + + def delete_api_key(self, api_key: str) -> bool: + """Delete a user's api key. 
+ + Args: + api_key (str): The name of the API key to delete. This should be + one of the names returned by the `api_keys` property. + + Returns: + Boolean indicating success + + Raises: + ValueError if the api_key couldn't be found + """ + from requests import HTTPError + + from wandb.apis._generated import DELETE_API_KEY_GQL + + idx = self.api_keys.index(api_key) + api_key_id = self._attrs["apiKeys"]["edges"][idx]["node"]["id"] + try: + self._client.execute(gql(DELETE_API_KEY_GQL), {"id": api_key_id}) + except HTTPError: + return False + return True + + def generate_api_key(self, description: str | None = None) -> str | None: + """Generate a new api key. + + Args: + description (str, optional): A description for the new API key. This can be + used to identify the purpose of the API key. + + Returns: + The new api key, or None on failure + """ + from requests import HTTPError + + from wandb.apis._generated import GENERATE_API_KEY_GQL, GenerateApiKey + + try: + # We must make this call using credentials from the original user + gql_op = gql(GENERATE_API_KEY_GQL) + data = self.user_api.client.execute(gql_op, {"description": description}) + key_fragment = GenerateApiKey.model_validate(data).result.api_key + self._attrs["apiKeys"]["edges"].append({"node": key_fragment.model_dump()}) + except (HTTPError, AttributeError): + return None + else: + return key_fragment.name + + def __repr__(self) -> str: + if email := self._attrs.get("email"): + return f"" + if username := self._attrs.get("username"): + return f"" + if id_ := self._attrs.get("id"): + return f"" + if name := self._attrs.get("name"): + return f"" + return "" diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/public/utils.py b/.venv/lib/python3.13/site-packages/wandb/apis/public/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..57e9750d26ce86e28d1613bc3e7f179276f4f276 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/public/utils.py @@ -0,0 +1,235 @@ +from 
__future__ import annotations + +import re +from enum import Enum +from typing import Any, Iterable, Mapping +from urllib.parse import urlparse + +from wandb_gql import gql +from wandb_graphql import TypeInfo +from wandb_graphql.language import ast, visitor +from wandb_graphql.validation.validation import ValidationContext + +from wandb._iterutils import one +from wandb.sdk.internal.internal_api import Api as InternalApi + + +def parse_s3_url_to_s3_uri(url) -> str: + """Convert an S3 HTTP(S) URL to an S3 URI. + + Arguments: + url (str): The S3 URL to convert, in the format + 'http(s)://.s3..amazonaws.com/'. + or 'http(s)://.s3.amazonaws.com/' + + Returns: + str: The corresponding S3 URI in the format 's3:///'. + + Raises: + ValueError: If the provided URL is not a valid S3 URL. + """ + # Regular expression to match S3 URL pattern + s3_pattern = r"^https?://.*s3.*amazonaws\.com.*" + parsed_url = urlparse(url) + + # Check if it's an S3 URL + match = re.match(s3_pattern, parsed_url.geturl()) + if not match: + raise ValueError("Invalid S3 URL") + + # Extract bucket name and key + bucket_name, *_ = parsed_url.netloc.split(".") + key = parsed_url.path.lstrip("/") + + # Construct the S3 URI + s3_uri = f"s3://{bucket_name}/{key}" + + return s3_uri + + +class PathType(Enum): + """We have lots of different paths users pass in to fetch artifacts, projects, etc. + + This enum is used for specifying what format the path is in given a string path. + """ + + PROJECT = "PROJECT" + ARTIFACT = "ARTIFACT" + + +def parse_org_from_registry_path(path: str, path_type: PathType) -> str: + """Parse the org from a registry path. + + Essentially fetching the "entity" from the path but for Registries the entity is actually the org. + + Args: + path (str): The path to parse. Can be a project path / or or an + artifact path like // or / or + path_type (PathType): The type of path to parse. 
+ """ + from wandb.sdk.artifacts._validators import is_artifact_registry_project + + parts = path.split("/") + expected_parts = 3 if path_type == PathType.ARTIFACT else 2 + + if len(parts) >= expected_parts: + org, project = parts[:2] + if is_artifact_registry_project(project): + return org + return "" + + +def fetch_org_from_settings_or_entity( + settings: dict, default_entity: str | None = None +) -> str: + """Fetch the org from either the settings or deriving it from the entity. + + Returns the org from the settings if available. If no org is passed in or set, the entity is used to fetch the org. + + Args: + organization (str | None): The organization to fetch the org for. + settings (dict): The settings to fetch the org for. + default_entity (str | None): The default entity to fetch the org for. + """ + if (organization := settings.get("organization")) is None: + # Fetch the org via the Entity. Won't work if default entity is a personal entity and belongs to multiple orgs + entity = settings.get("entity") or default_entity + if entity is None: + raise ValueError( + "No entity specified and can't fetch organization from the entity" + ) + entity_orgs = InternalApi()._fetch_orgs_and_org_entities_from_entity(entity) + entity_org = one( + entity_orgs, + too_short=ValueError( + "No organizations found for entity. Please specify an organization in the settings." + ), + too_long=ValueError( + "Multiple organizations found for entity. Please specify an organization in the settings." 
+ ), + ) + organization = entity_org.display_name + return organization + + +class _GQLCompatRewriter(visitor.Visitor): + """GraphQL AST visitor to rewrite queries/mutations to be compatible with older server versions.""" + + def __init__( + self, + omit_variables: Iterable[str] | None = None, + omit_fragments: Iterable[str] | None = None, + omit_fields: Iterable[str] | None = None, + rename_fields: Mapping[str, str] | None = None, + ): + self.omit_variables = set(omit_variables or ()) + self.omit_fragments = set(omit_fragments or ()) + self.omit_fields = set(omit_fields or ()) + self.rename_fields = dict(rename_fields or {}) + + def leave_Document(self, node: ast.Document, *_, **__) -> Any: # noqa: N802 + # After rewriting the GQL document, prune "orphan" (unused) fragment definitions. + # Note: The ValidationContext doesn't require a schema here, as we only use it to check for reachable fragments. + ctx = ValidationContext(schema=None, ast=node, type_info=TypeInfo(schema=None)) + operation_defns = { + dfn for dfn in node.definitions if isinstance(dfn, ast.OperationDefinition) + } + used_fragment_defns = { + frag + for op in operation_defns + for frag in ctx.get_recursively_referenced_fragments(op) + } + # Preserve original defintion order + allowed_defns = operation_defns | used_fragment_defns + node.definitions = [dfn for dfn in node.definitions if (dfn in allowed_defns)] + + def enter_Variable(self, node: ast.Variable, *_, **__) -> Any: # noqa: N802 + if node.name.value in self.omit_variables: + return visitor.REMOVE + + def leave_VariableDefinition(self, node: ast.VariableDefinition, *_, **__) -> Any: # noqa: N802 + # For context, consider the `$varName: String` variable definition below: + # (..., $varName: String, ...) + # + # On ENTERING, the AST looks like: + # VariableDefinition(variable=Variable(name=Name(value='varName')), ...) + # + # On LEAVING, if `$varName` was removed, the AST looks like: + # VariableDefinition(variable=REMOVE, ...) 
+ if node.variable is visitor.REMOVE: + return visitor.REMOVE + + def leave_ObjectField(self, node: ast.ObjectField, *_, **__) -> Any: # noqa: N802 + # For context, consider `argName: $varName` in the input args below: + # input: {..., argName: $varName, ...} + # + # On ENTERING, the AST for `argName: $varName` looks like: + # ObjectField( + # name=Name(value='argName'), value=Variable(name=Name(value='varName')), + # ) + # + # On LEAVING, if `$varName` was removed, the AST looks like: + # ObjectField( + # name=Name(value='argName'), value=REMOVE, + # ) + if node.value is visitor.REMOVE: + return visitor.REMOVE + + def enter_Argument(self, node: ast.Argument, *_, **__) -> Any: # noqa: N802 + if node.name.value in self.omit_variables: + return visitor.REMOVE + + def enter_FragmentDefinition(self, node: ast.FragmentDefinition, *_, **__) -> Any: # noqa: N802 + if node.name.value in self.omit_fragments: + return visitor.REMOVE + + def enter_FragmentSpread(self, node: ast.FragmentSpread, *_, **__) -> Any: # noqa: N802 + if node.name.value in self.omit_fragments: + return visitor.REMOVE + + def enter_Field(self, node: ast.Field, *_, **__) -> Any: # noqa: N802 + if node.name.value in self.omit_fields: + return visitor.REMOVE + if new_name := self.rename_fields.get(node.name.value): + node.name.value = new_name + + def leave_Field(self, node: ast.Field, *_, **__) -> Any: # noqa: N802 + # If the field had a selection set, but now it's empty, remove the field entirely + if (node.selection_set is not None) and (not node.selection_set.selections): + return visitor.REMOVE + + +def gql_compat( + request_string: str, + omit_variables: Iterable[str] | None = None, + omit_fragments: Iterable[str] | None = None, + omit_fields: Iterable[str] | None = None, + rename_fields: Mapping[str, str] | None = None, +) -> ast.Document: + """Rewrite a GraphQL request string to ensure compatibility with older server versions. 
+ + Args: + request_string (str): The GraphQL request string to rewrite. + omit_variables (Iterable[str] | None): Names of variables to remove from the request string. + omit_fragments (Iterable[str] | None): Names of fragments to remove from the request string. + omit_fields (Iterable[str] | None): Names of fields to remove from the request string. + rename_fields (Mapping[str, str] | None): + A mapping of fields to rename in the request string, given as `{old_name -> new_name}`. + + Returns: + str: Modified GraphQL request string with fragments on omitted types removed. + """ + # Parse the request into a GraphQL AST + doc = gql(request_string) + + if not (omit_variables or omit_fragments or omit_fields or rename_fields): + return doc + + # Visit the AST with our visitor to filter out unwanted fragments + rewriter = _GQLCompatRewriter( + omit_variables=omit_variables, + omit_fragments=omit_fragments, + omit_fields=omit_fields, + rename_fields=rename_fields, + ) + return visitor.visit(doc, rewriter) diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/reports/__init__.py b/.venv/lib/python3.13/site-packages/wandb/apis/reports/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6d2b9ad9385524f2b868e498d3030de028fac546 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/reports/__init__.py @@ -0,0 +1 @@ +from .v2 import * # noqa: F403 diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/reports/v1/__init__.py b/.venv/lib/python3.13/site-packages/wandb/apis/reports/v1/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd41baac9f44dd1f7dc5145779d33fe95d10763 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/reports/v1/__init__.py @@ -0,0 +1,8 @@ +import wandb + +try: + from wandb_workspaces.reports.v1 import * # noqa: F403 +except ImportError: + wandb.termerror( + "Failed to import wandb_workspaces. 
To edit reports programmatically, please install it using `pip install wandb[workspaces]`." + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/reports/v2/__init__.py b/.venv/lib/python3.13/site-packages/wandb/apis/reports/v2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..532129620940c7deb87ae483a08059cb792d15f4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/reports/v2/__init__.py @@ -0,0 +1,8 @@ +import wandb + +try: + from wandb_workspaces.reports.v2 import * # noqa: F403 +except ImportError: + wandb.termerror( + "Failed to import wandb_workspaces. To edit reports programmatically, please install it using `pip install wandb[workspaces]`." + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/apis/workspaces/__init__.py b/.venv/lib/python3.13/site-packages/wandb/apis/workspaces/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..269a35473d08b01200c593d2c8958c85349a0749 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/apis/workspaces/__init__.py @@ -0,0 +1,8 @@ +import wandb + +try: + from wandb_workspaces.workspaces import * # noqa: F403 +except ImportError: + wandb.termerror( + "Failed to import wandb_workspaces. To edit workspaces programmatically, please install it using `pip install wandb[workspaces]`." 
+ ) diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/__init__.py b/.venv/lib/python3.13/site-packages/wandb/automations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b9d8738c6777c9205d8b6be6781ff9feb58e7bf2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/__init__.py @@ -0,0 +1,79 @@ +import wandb +from wandb._pydantic import IS_PYDANTIC_V2 + +from .actions import ActionType, DoNothing, SendNotification, SendWebhook +from .automations import Automation, NewAutomation +from .events import ( + ArtifactEvent, + EventType, + MetricChangeFilter, + MetricThresholdFilter, + MetricZScoreFilter, + OnAddArtifactAlias, + OnCreateArtifact, + OnLinkArtifact, + OnRunMetric, + OnRunState, + RunEvent, + RunStateFilter, +) +from .integrations import Integration, SlackIntegration, WebhookIntegration +from .scopes import ArtifactCollectionScope, ProjectScope, ScopeType + +# ---------------------------------------------------------------------------- +# WARNINGS on import +if not IS_PYDANTIC_V2: + # Raises an error in Pydantic v1 environments, where the Automations API + # has not been tested and is unlikely to work as expected. + # + # Remove this when we either: + # - Drop support for Pydantic v1 + # - Are able to implement (limited) Pydantic v1 support + raise ImportError( + "The W&B Automations API requires Pydantic v2. " + "We recommend upgrading `pydantic` to use this feature." + ) + +else: + # If Pydantic v2 is available, we can use the full Automations API + # but communicate to users that the API is still experimental and + # may change rapidly. + wandb.termwarn( + "The W&B Automations API is experimental and the implementation is subject to change." + "Review the release notes before upgrading. 
We recommend pinning your " + f"package version to `{wandb.__package__}=={wandb.__version__}` to reduce the risk of disruption.", + repeat=False, + ) +# ---------------------------------------------------------------------------- + +__all__ = [ + # Scopes + "ScopeType", # doc:exclude + "ArtifactCollectionScope", # doc:exclude + "ProjectScope", # doc:exclude + # Events + "EventType", # doc:exclude + "OnAddArtifactAlias", + "OnCreateArtifact", + "OnLinkArtifact", + "OnRunMetric", + "OnRunState", + "ArtifactEvent", # doc:exclude + "RunEvent", # doc:exclude + "MetricThresholdFilter", + "MetricChangeFilter", + "RunStateFilter", + "MetricZScoreFilter", + # Actions + "ActionType", # doc:exclude + "SendNotification", + "SendWebhook", + "DoNothing", + # Automations + "Automation", + "NewAutomation", + # Integrations + "Integration", # doc:exclude + "SlackIntegration", # doc:exclude + "WebhookIntegration", # doc:exclude +] diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_filters/__init__.py b/.venv/lib/python3.13/site-packages/wandb/automations/_filters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56d551d4ed507da432151e8eb90167495fdb5acf --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_filters/__init__.py @@ -0,0 +1,40 @@ +from .expressions import FilterExpr, MongoLikeFilter +from .operators import ( + And, + Contains, + Eq, + Exists, + Gt, + Gte, + In, + Lt, + Lte, + Ne, + Nor, + Not, + NotIn, + Op, + Or, + Regex, +) + +__all__ = [ + "And", + "Or", + "Nor", + "Not", + "Op", + "Gt", + "Lt", + "Gte", + "Lte", + "Eq", + "Ne", + "In", + "NotIn", + "Contains", + "Exists", + "Regex", + "FilterExpr", + "MongoLikeFilter", +] diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_filters/expressions.py b/.venv/lib/python3.13/site-packages/wandb/automations/_filters/expressions.py new file mode 100644 index 0000000000000000000000000000000000000000..94a40ad843f66f4faee99b6a4d7233397449ae3e --- 
/dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_filters/expressions.py @@ -0,0 +1,180 @@ +"""Pydantic-compatible representations of MongoDB expressions.""" + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any, Dict, Union + +from pydantic import ConfigDict, model_serializer +from typing_extensions import Self, TypeAlias + +from wandb._pydantic import CompatBaseModel, model_validator +from wandb._strutils import nameof + +from .operators import ( + And, + Contains, + Eq, + Exists, + Gt, + Gte, + In, + Lt, + Lte, + Ne, + Nor, + Not, + NotIn, + Op, + Or, + Regex, + RichReprResult, + Scalar, + SupportsBitwiseLogicalOps, +) + + +class FilterableField: + """A descriptor that can be used to define a "filterable" field on a class. + + Internal helper to support syntactic sugar for defining event filters. + """ + + _python_name: str #: The name of the field this descriptor was assigned to in the Python class. + _server_name: str | None #: If set, the actual server-side field name to filter on. + + def __init__(self, server_name: str | None = None): + self._server_name = server_name + + def __set_name__(self, owner: type, name: str) -> None: + self._python_name = name + + def __get__(self, obj: Any, objtype: type) -> Self: + # By default, if we didn't explicitly provide a backend name for + # filtering, assume the field has the same name in the backend as + # the python attribute. 
+ return self + + @property + def _name(self) -> str: + return self._server_name or self._python_name + + def __str__(self) -> str: + return self._name + + def __repr__(self) -> str: + return f"{nameof(type(self))}({self._name!r})" + + # Methods to define filter expressions through chaining + def matches_regex(self, pattern: str, /) -> FilterExpr: + return FilterExpr(field=self._name, op=Regex(val=pattern)) + + def contains(self, text: str, /) -> FilterExpr: + return FilterExpr(field=self._name, op=Contains(val=text)) + + def exists(self, exists: bool = True, /) -> FilterExpr: + return FilterExpr(field=self._name, op=Exists(val=exists)) + + def lt(self, value: Scalar, /) -> FilterExpr: + return FilterExpr(field=self._name, op=Lt(val=value)) + + def gt(self, value: Scalar, /) -> FilterExpr: + return FilterExpr(field=self._name, op=Gt(val=value)) + + def lte(self, value: Scalar, /) -> FilterExpr: + return FilterExpr(field=self._name, op=Lte(val=value)) + + def gte(self, value: Scalar, /) -> FilterExpr: + return FilterExpr(field=self._name, op=Gte(val=value)) + + def ne(self, value: Scalar, /) -> FilterExpr: + return FilterExpr(field=self._name, op=Ne(val=value)) + + def eq(self, value: Scalar, /) -> FilterExpr: + return FilterExpr(field=self._name, op=Eq(val=value)) + + def in_(self, values: Iterable[Scalar], /) -> FilterExpr: + return FilterExpr(field=self._name, op=In(val=values)) + + def not_in(self, values: Iterable[Scalar], /) -> FilterExpr: + return FilterExpr(field=self._name, op=NotIn(val=values)) + + # Deliberately override the default behavior of comparison operator symbols, + # (`<`, `>`, `<=`, `>=`, `==`, `!=`), to allow defining filter expressions + # idiomatically, e.g. `field == "value"`. + # + # See similar overrides of built-in dunder methods in common libraries like + # `sqlalchemy`, `polars`, `pandas`, `numpy`, etc. 
+ # + # As an illustrative example from `sqlalchemy`, see: + # https://github.com/sqlalchemy/sqlalchemy/blob/f21ae633486380a26dc0b67b70ae1c0efc6b4dc4/lib/sqlalchemy/orm/descriptor_props.py#L808-L812 + def __lt__(self, other: Any) -> FilterExpr: + return self.lt(other) + + def __gt__(self, other: Any) -> FilterExpr: + return self.gt(other) + + def __le__(self, other: Any) -> FilterExpr: + return self.lte(other) + + def __ge__(self, other: Any) -> FilterExpr: + return self.gte(other) + + def __eq__(self, other: Any) -> FilterExpr: + return self.eq(other) + + def __ne__(self, other: Any) -> FilterExpr: + return self.ne(other) + + +# ------------------------------------------------------------------------------ +class FilterExpr(CompatBaseModel, SupportsBitwiseLogicalOps): + """A MongoDB filter expression on a specific field.""" + + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + + field: str + op: Union[Op, Dict[str, Any]] + + def __repr__(self) -> str: + return f"{nameof(type(self))}({self.field!s}: {self.op!r})" + + def __rich_repr__(self) -> RichReprResult: + # https://rich.readthedocs.io/en/stable/pretty.html + yield self.field, self.op + + @model_validator(mode="before") + @classmethod + def _validate(cls, data: Any) -> Any: + """Parse a MongoDB dict representation of the filter expression.""" + if ( + isinstance(data, dict) + and len(data) == 1 + and not any(key.startswith("$") for key in data) + ): + # This looks like a MongoDB filter expression on a single field. 
E.g.: + # - in: `{"display_name": {"$contains": "my-run"}}` + # - out: `FilterExpr(field="display_name", op=Contains(val="my-run"))` + ((field, op),) = data.items() + return {"field": field, "op": op} + return data + + @model_serializer(mode="plain") + def _to_mongo_dict(self) -> dict[str, Any]: + """Return a MongoDB dict representation of the expression.""" + from pydantic_core import to_jsonable_python # Only valid in pydantic v2 + + return {self.field: to_jsonable_python(self.op, by_alias=True, round_trip=True)} + + +# Some of the MongoDB op types need to be rebuilt after defining FilterExpr, +# due to forward references. +And.model_rebuild() +Or.model_rebuild() +Nor.model_rebuild() +Not.model_rebuild() + +# for type annotations +MongoLikeFilter: TypeAlias = Union[Op, FilterExpr, Dict[str, Any]] diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_filters/filterutils.py b/.venv/lib/python3.13/site-packages/wandb/automations/_filters/filterutils.py new file mode 100644 index 0000000000000000000000000000000000000000..bbef82ec8a33d0974a28035544c64cbb83266d5e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_filters/filterutils.py @@ -0,0 +1,91 @@ +"""Helpers for parsing and transforming MongoDB expressions. + +If a function is defined here, it's an internal helper that we deliberately +don't expose as instnace methods on filter types for now. 
+""" + +from __future__ import annotations + +from functools import singledispatch +from typing import Iterator + +from .expressions import FilterExpr, MongoLikeFilter +from .operators import ( + BaseVariadicLogicalOp, + Eq, + Exists, + Gt, + Gte, + In, + Lt, + Lte, + Ne, + Nor, + Not, + NotIn, + Op, + Or, +) + + +@singledispatch +def simplify_expr(expr: MongoLikeFilter) -> MongoLikeFilter: + """Simplify a MongoDB filter by removing and unnesting redundant operators.""" + return expr # default implementation is a no-op + + +@simplify_expr.register +def _(op: BaseVariadicLogicalOp) -> MongoLikeFilter: + """Simplify an `And/Or/Nor` operator by removing and unnesting redundant expressions. + + This will flatten the operator's inner expressions and simplify them recursively, + e.g.: + - `And(op1, And(op2, ...)) -> And(op1, op2, ...)` + - `Or(op1, Or(op2, ...)) -> Or(op1, op2, ...)` + + Note that unnested empty operators are preserved, e.g. + - `And() -> And()` + - `Or() -> Or()` + + However, nested empty operators are flattened, e.g.: + - `And(And(), And()) -> And()` + - `Or(Or(), Or()) -> Or()` + + Single inner expressions are unnested, e.g.: + - `And(a) -> a` + - `Or(a) -> a` + """ + cls = type(op) + # Flatten and simplify the operator's inner expressions. + if len(exprs := [simplify_expr(x) for x in flatten_inner(op, cls)]) == 1: + return exprs[0] # Unnest single inner expressions. + return cls(exprs=exprs) + + +@simplify_expr.register +def _(op: Not) -> MongoLikeFilter: + """Simplify a `Not` operator by removing and unnesting redundant expressions. 
+ + This will invert the inner expression if possible and otherwise remove nested + `Not` operators, e.g.: + - `Not(Not(a)) -> a` + - `Not(Or(a, b)) -> Nor(a, b)` + - `Not(Nor(a, b)) -> Or(a, b)` + - `Not(In(a, b)) -> NotIn(a, b)` + - `Not(NotIn(a, b)) -> In(a, b)` + """ + # TODO: Find a more efficient way to apply custom __invert__ impls + if isinstance( + expr := op.expr, (Not, Or, Nor, In, NotIn, Eq, Ne, Lt, Lte, Gt, Gte, Exists) + ): + return simplify_expr(~expr) + return Not(expr=simplify_expr(expr)) + + +def flatten_inner( + op: BaseVariadicLogicalOp, + parent_cls: type[BaseVariadicLogicalOp], +) -> Iterator[FilterExpr | Op]: + """Iterates over an `And/Or/Nor` operator's flattened inner expressions.""" + for x in op.exprs: + yield from (flatten_inner(x, parent_cls) if isinstance(x, parent_cls) else (x,)) diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_filters/operators.py b/.venv/lib/python3.13/site-packages/wandb/automations/_filters/operators.py new file mode 100644 index 0000000000000000000000000000000000000000..dfade7b6ab8412338aa548e3139c621537e16d9f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_filters/operators.py @@ -0,0 +1,277 @@ +"""Types that represent operators in MongoDB filter expressions.""" + +from __future__ import annotations + +from abc import ABC +from typing import TYPE_CHECKING, Any, Iterable, Tuple, TypeVar, Union + +from pydantic import ConfigDict, Field, StrictBool, StrictFloat, StrictInt, StrictStr +from typing_extensions import Self, TypeAlias, get_args, override + +from wandb._pydantic import GQLBase +from wandb._strutils import nameof + +if TYPE_CHECKING: + from .expressions import FilterExpr + +# for type annotations +Scalar = Union[StrictStr, StrictInt, StrictFloat, StrictBool] +# for runtime type checks +ScalarTypes: tuple[type, ...] 
= tuple(t.__origin__ for t in get_args(Scalar)) + +# See: https://rich.readthedocs.io/en/stable/pretty.html#rich-repr-protocol +RichReprResult: TypeAlias = Iterable[ + Union[ + Any, + Tuple[Any], + Tuple[str, Any], + Tuple[str, Any, Any], + ] +] + +T = TypeVar("T") +TupleOf: TypeAlias = Tuple[T, ...] + + +# NOTE: Wherever class descriptions that are not docstrings, this is deliberate. +# This is done to ensure the descriptions are omitted from generated API docs. + + +# Mixin class to support building MongoDB expressions idiomatically +# with bitwise logical operators, e.g.: +# `a | b` -> `{"$or": [a, b]}` +# `~a` -> `{"$not": a}` +class SupportsBitwiseLogicalOps: + def __or__(self, other: Any) -> Or: + """Implements default `|` behavior: `a | b -> Or(a, b)`.""" + return Or(exprs=(self, other)) + + def __and__(self, other: Any) -> And: + """Implements default `&` behavior: `a & b -> And(a, b)`.""" + from .expressions import FilterExpr + + if isinstance(other, (BaseOp, FilterExpr)): + return And(exprs=(self, other)) + return NotImplemented + + def __invert__(self) -> Not: + """Implements default `~` behavior: `~a -> Not(a)`.""" + return Not(expr=self) + + +# Base type for parsing MongoDB filter operators, e.g. from dicts like +# `{"$and": [...]}`, `{"$or": [...]}`, `{"$gt": 1.0}`, etc. +# Instances are frozen for easier comparison and more predictable behavior. +class BaseOp(GQLBase, SupportsBitwiseLogicalOps, ABC): + model_config = ConfigDict( + extra="forbid", + frozen=True, + ) + + def __repr__(self) -> str: + """Returns the operator's repr string, with operand(s) as positional args. + + Note that BaseModels implement `__iter__()`: + https://docs.pydantic.dev/latest/concepts/serialization/#iterating-over-models + """ + return f"{nameof(type(self))}({', '.join(repr(v) for _, v in self)})" + + def __rich_repr__(self) -> RichReprResult: + """Returns the operator's rich repr, if pretty-printing via `rich`. 
+ + See: https://rich.readthedocs.io/en/stable/pretty.html + """ + # Display field values as positional args: + yield from ((None, v) for _, v in self) + + +# Base type for logical operators that take a variable number of expressions. +class BaseVariadicLogicalOp(BaseOp, ABC): + exprs: TupleOf[Union[FilterExpr, Op]] + + @classmethod + def wrap(cls, expr: Any) -> Self: + return expr if isinstance(expr, cls) else cls(exprs=(expr,)) + + +# Logical operator(s) +# https://www.mongodb.com/docs/manual/reference/operator/query/and/ +# https://www.mongodb.com/docs/manual/reference/operator/query/or/ +# https://www.mongodb.com/docs/manual/reference/operator/query/nor/ +# https://www.mongodb.com/docs/manual/reference/operator/query/not/ +class And(BaseVariadicLogicalOp): + exprs: TupleOf[Union[FilterExpr, Op]] = Field(default=(), alias="$and") + + +class Or(BaseVariadicLogicalOp): + exprs: TupleOf[Union[FilterExpr, Op]] = Field(default=(), alias="$or") + + @override + def __invert__(self) -> Nor: + """Implements `~Or(a, b) -> Nor(a, b)`.""" + return Nor(exprs=self.exprs) + + +class Nor(BaseVariadicLogicalOp): + exprs: TupleOf[Union[FilterExpr, Op]] = Field(default=(), alias="$nor") + + @override + def __invert__(self) -> Or: + """Implements `~Nor(a, b) -> Or(a, b)`.""" + return Or(exprs=self.exprs) + + +class Not(BaseOp): + expr: Union[FilterExpr, Op] = Field(alias="$not") + + @override + def __invert__(self) -> Union[FilterExpr, Op]: + """Implements `~Not(a) -> a`.""" + return self.expr + + +# Comparison operator(s) +# https://www.mongodb.com/docs/manual/reference/operator/query/lt/ +# https://www.mongodb.com/docs/manual/reference/operator/query/gt/ +# https://www.mongodb.com/docs/manual/reference/operator/query/lte/ +# https://www.mongodb.com/docs/manual/reference/operator/query/gte/ +# https://www.mongodb.com/docs/manual/reference/operator/query/eq/ +# https://www.mongodb.com/docs/manual/reference/operator/query/ne/ +# 
https://www.mongodb.com/docs/manual/reference/operator/query/in/ +# https://www.mongodb.com/docs/manual/reference/operator/query/nin/ +class Lt(BaseOp): + val: Scalar = Field(alias="$lt") + + @override + def __invert__(self) -> Gte: + """Implements `~Lt(a) -> Gte(a)`.""" + return Gte(val=self.val) + + +class Gt(BaseOp): + val: Scalar = Field(alias="$gt") + + @override + def __invert__(self) -> Lte: + """Implements `~Gt(a) -> Lte(a)`.""" + return Lte(val=self.val) + + +class Lte(BaseOp): + val: Scalar = Field(alias="$lte") + + @override + def __invert__(self) -> Gt: + """Implements `~Lte(a) -> Gt(a)`.""" + return Gt(val=self.val) + + +class Gte(BaseOp): + val: Scalar = Field(alias="$gte") + + @override + def __invert__(self) -> Lt: + """Implements `~Gte(a) -> Lt(a)`.""" + return Lt(val=self.val) + + +class Eq(BaseOp): + val: Scalar = Field(alias="$eq") + + @override + def __invert__(self) -> Ne: + """Implements `~Eq(a) -> Ne(a)`.""" + return Ne(val=self.val) + + +class Ne(BaseOp): + val: Scalar = Field(alias="$ne") + + @override + def __invert__(self) -> Eq: + """Implements `~Ne(a) -> Eq(a)`.""" + return Eq(val=self.val) + + +class In(BaseOp): + val: TupleOf[Scalar] = Field(default=(), alias="$in") + + @override + def __invert__(self) -> NotIn: + """Implements `~In(a) -> NotIn(a)`.""" + return NotIn(val=self.val) + + +class NotIn(BaseOp): + val: TupleOf[Scalar] = Field(default=(), alias="$nin") + + @override + def __invert__(self) -> In: + """Implements `~NotIn(a) -> In(a)`.""" + return In(val=self.val) + + +# Element operator(s) +# https://www.mongodb.com/docs/manual/reference/operator/query/exists/ +class Exists(BaseOp): + val: bool = Field(alias="$exists") + + @override + def __invert__(self) -> Exists: + """Implements `~Exists(True) -> Exists(False)` and vice versa.""" + return Exists(val=not self.val) + + +# Evaluation operator(s) +# https://www.mongodb.com/docs/manual/reference/operator/query/regex/ +# +# Note: `$contains` is NOT a formal MongoDB operator, but 
the W&B backend +# recognizes and executes it as a substring-match filter. +class Regex(BaseOp): + val: str = Field(alias="$regex") #: The regex expression to match against. + + +class Contains(BaseOp): + val: str = Field(alias="$contains") #: The substring to match against. + + +# ------------------------------------------------------------------------------ +# Convenience helpers, constants, and utils for supported MongoDB operators +# ------------------------------------------------------------------------------ +KEY_TO_OP: dict[str, type[BaseOp]] = { + "$and": And, + "$or": Or, + "$nor": Nor, + "$not": Not, + "$lt": Lt, + "$gt": Gt, + "$lte": Lte, + "$gte": Gte, + "$eq": Eq, + "$ne": Ne, + "$in": In, + "$nin": NotIn, + "$exists": Exists, + "$regex": Regex, + "$contains": Contains, +} + + +# Known, implemented MongoDB operators for type annotations. +Op = Union[ + And, + Or, + Nor, + Not, + Lt, + Gt, + Lte, + Gte, + Eq, + Ne, + In, + NotIn, + Exists, + Regex, + Contains, +] diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_filters/run_metrics.py b/.venv/lib/python3.13/site-packages/wandb/automations/_filters/run_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..8886a960afa29a0a56b980429eafe9433ab12514 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_filters/run_metrics.py @@ -0,0 +1,508 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Final, Literal, Optional, Union, overload + +from pydantic import ( + Field, + PositiveFloat, + PositiveInt, + StrictFloat, + StrictInt, + field_validator, +) +from typing_extensions import Annotated, TypeAlias, override + +from wandb._pydantic import GQLBase +from wandb.automations._validators import LenientStrEnum + +from .expressions import FilterExpr +from .operators import BaseOp, RichReprResult + +if TYPE_CHECKING: + from wandb.automations.events import RunMetricFilter + +# Maps MongoDB 
comparison operators -> Python literal (str) representations +MONGO2PY_OPS: Final[dict[str, str]] = { + "$eq": "==", + "$ne": "!=", + "$gt": ">", + "$lt": "<", + "$gte": ">=", + "$lte": "<=", +} +# Reverse mapping from Python literal (str) -> MongoDB operator key +PY2MONGO_OPS: Final[dict[str, str]] = {v: k for k, v in MONGO2PY_OPS.items()} + +# Type hint for positive numbers (int or float) +PosNum: TypeAlias = Union[PositiveInt, PositiveFloat] + + +class Agg(LenientStrEnum): # from: Aggregation + """Supported run metric aggregation operations.""" + + MAX = "MAX" + MIN = "MIN" + AVERAGE = "AVERAGE" + + # Shorter aliases for convenience + AVG = AVERAGE + + +class ChangeType(LenientStrEnum): # from: RunMetricChangeType + """Describes the type of metric change as absolute or relative. + + ABSOLUTE: The arithmetic difference between the current vs. prior values. + RELATIVE: The percentage change between the current vs. prior values. + """ + + ABSOLUTE = "ABSOLUTE" + RELATIVE = "RELATIVE" + + # Shorter aliases for convenience + ABS = ABSOLUTE + REL = RELATIVE + + +class ChangeDir(LenientStrEnum): # from: RunMetricChangeDirection + """Describes the direction of the metric change.""" + + INCREASE = "INCREASE" + DECREASE = "DECREASE" + ANY = "ANY" + + # Shorter aliases for convenience + INC = INCREASE + DEC = DECREASE + + +class BaseMetricFilter(GQLBase, ABC, extra="forbid"): + name: str + """Name of the observed metric.""" + + agg: Optional[Agg] + """Aggregate operation, if any, to apply over the window size.""" + + window: PositiveInt + """Size of the metric aggregation window (ignored if `agg` is ``None``).""" + + # ------------------------------------------------------------------------------ + cmp: Optional[str] + """Comparison operator between the metric expression (left) vs. 
the threshold or target (right).""" # noqa: W505 + + # ------------------------------------------------------------------------------ + threshold: Union[StrictInt, StrictFloat] + """Threshold value to compare against.""" + + def __and__(self, other: Any) -> RunMetricFilter: + """Returns `(metric_filter & run_filter)` as a `RunMetricFilter`.""" + from wandb.automations.events import RunMetricFilter + + if isinstance(run_filter := other, (BaseOp, FilterExpr)): + # Treat `other` as a run filter and build a RunMetricEvent. Let the + # metric filter validators wrap or nest as appropriate. + return RunMetricFilter(run=run_filter, metric=self) + return NotImplemented + + def __rand__(self, other: BaseOp | FilterExpr) -> RunMetricFilter: + """Ensures `&` is commutative for run and metric filters. + + I.e. `(run_filter & metric_filter) == (metric_filter & run_filter)`. + """ + return self.__and__(other) + + @abstractmethod + def __repr__(self) -> str: + """Returns the text representation of the metric filter.""" + raise NotImplementedError + + @override + def __rich_repr__(self) -> RichReprResult: + """Returns the `rich` pretty-print representation of the metric filter.""" + # See: https://rich.readthedocs.io/en/stable/pretty.html#rich-repr-protocol + yield None, repr(self) + + +class MetricThresholdFilter(BaseMetricFilter): # from: RunMetricThresholdFilter + """Filter that compares an **absolute** metric value against a user-defined threshold. + + The value may be a single value or an aggregated result over a window of + multiple values. + """ + + name: str + agg: Annotated[Optional[Agg], Field(alias="agg_op")] = None + window: Annotated[PositiveInt, Field(alias="window_size")] = 1 + + cmp: Annotated[Literal["$gte", "$gt", "$lt", "$lte"], Field(alias="cmp_op")] + """Comparison operator between the metric value (left) vs. 
the threshold (right).""" + + threshold: Union[StrictInt, StrictFloat] + + @field_validator("cmp", mode="before") + def _validate_cmp(cls, v: Any) -> Any: + # Be helpful: e.g. ">" -> "$gt" + return PY2MONGO_OPS.get(v.strip(), v) if isinstance(v, str) else v + + def __repr__(self) -> str: + metric = f"{self.agg.value}({self.name})" if self.agg else self.name + op = MONGO2PY_OPS.get(self.cmp, self.cmp) + return repr(rf"{metric} {op} {self.threshold}") + + +class MetricChangeFilter(BaseMetricFilter): # from: RunMetricChangeFilter + """Filter that compares a **change** in a metric value to a user-defined threshold. + + The change is calculated over "tumbling" windows, i.e. the difference + between the current window and the non-overlapping prior window. + """ + + name: str + agg: Annotated[Optional[Agg], Field(alias="agg_op")] = None + window: Annotated[PositiveInt, Field(alias="current_window_size")] = 1 + + # `prior_window` is only for `RUN_METRIC_CHANGE` events + prior_window: Annotated[ + PositiveInt, + # By default, set `window -> prior_window` if the latter wasn't provided. + Field(alias="prior_window_size", default_factory=lambda data: data["window"]), + ] + """Size of the "prior" metric aggregation window (ignored if `agg` is ``None``). + + If omitted, defaults to the size of the current window. + """ + + # ------------------------------------------------------------------------------ + # NOTE: + # - The "comparison" operator isn't actually part of the backend schema, + # but it's defined here for consistency -- and ignored otherwise. + # - In the backend, it's effectively "$gte" or "$lte", depending on the sign + # (change_dir), though again, this is not explicit in the schema. 
+ cmp: Annotated[None, Field(frozen=True, exclude=True, repr=False)] = None + """Ignored.""" + + # ------------------------------------------------------------------------------ + change_type: ChangeType + change_dir: ChangeDir + threshold: Annotated[PosNum, Field(alias="change_amount")] + + def __repr__(self) -> str: + metric = f"{self.agg.value}({self.name})" if self.agg else self.name + verb = ( + "changes" + if (self.change_dir is ChangeDir.ANY) + else f"{self.change_dir.value.lower()}s" + ) + + fmt_spec = ".2%" if (self.change_type is ChangeType.REL) else "" + amt = f"{self.threshold:{fmt_spec}}" + return repr(rf"{metric} {verb} {amt}") + + +class BaseMetricOperand(GQLBase, ABC, extra="forbid"): + def gt(self, value: int | float, /) -> MetricThresholdFilter: + """Returns a filter that watches for `metric_expr > threshold`.""" + return self > value + + def lt(self, value: int | float, /) -> MetricThresholdFilter: + """Returns a filter that watches for `metric_expr < threshold`.""" + return self < value + + def gte(self, value: int | float, /) -> MetricThresholdFilter: + """Returns a filter that watches for `metric_expr >= threshold`.""" + return self >= value + + def lte(self, value: int | float, /) -> MetricThresholdFilter: + """Returns a filter that watches for `metric_expr <= threshold`.""" + return self <= value + + # Overloads to implement: + # - `(metric_operand > threshold) -> MetricThresholdFilter` + # - `(metric_operand < threshold) -> MetricThresholdFilter` + # - `(metric_operand >= threshold) -> MetricThresholdFilter` + # - `(metric_operand <= threshold) -> MetricThresholdFilter` + def __gt__(self, other: Any) -> MetricThresholdFilter: + if isinstance(other, (int, float)): + return MetricThresholdFilter(**dict(self), cmp="$gt", threshold=other) + return NotImplemented + + def __lt__(self, other: Any) -> MetricThresholdFilter: + if isinstance(other, (int, float)): + return MetricThresholdFilter(**dict(self), cmp="$lt", threshold=other) + return 
NotImplemented + + def __ge__(self, other: Any) -> MetricThresholdFilter: + if isinstance(other, (int, float)): + return MetricThresholdFilter(**dict(self), cmp="$gte", threshold=other) + return NotImplemented + + def __le__(self, other: Any) -> MetricThresholdFilter: + if isinstance(other, (int, float)): + return MetricThresholdFilter(**dict(self), cmp="$lte", threshold=other) + return NotImplemented + + @overload + def changes_by(self, *, diff: PosNum, frac: None) -> MetricChangeFilter: ... + + @overload + def changes_by(self, *, diff: None, frac: PosNum) -> MetricChangeFilter: ... + + @overload # NOTE: This overload is for internal use only. + def changes_by( + self, *, diff: PosNum | None, frac: PosNum | None, _dir: ChangeDir + ) -> MetricChangeFilter: ... + + def changes_by( + self, + *, + diff: PosNum | None = None, + frac: PosNum | None = None, + _dir: ChangeDir = ChangeDir.ANY, + ) -> MetricChangeFilter: + """Returns a filter that watches for a numerical increase OR decrease in a metric. + + Exactly one of `frac` or `diff` must be provided. + + Args: + diff: If given, arithmetic difference that must be observed in the metric. + Must be positive. + frac: If given, fractional (relative) change that must be observed in the + metric. Must be positive. For example, `frac=0.1` denotes a 10% relative + increase or decrease. 
+ """ + if ( + # Enforce mutually exclusive keyword args + ((frac is None) and (diff is None)) + or ((frac is not None) and (diff is not None)) + ): + raise ValueError("Must provide exactly one of `frac` or `diff`") + + # Enforce positive values + if (frac is not None) and (frac <= 0): + raise ValueError(f"Expected positive threshold, got: {frac=}") + if (diff is not None) and (diff <= 0): + raise ValueError(f"Expected positive threshold, got: {diff=}") + + if diff is None: + kws = dict(change_dir=_dir, change_type=ChangeType.REL, threshold=frac) + else: + kws = dict(change_dir=_dir, change_type=ChangeType.ABS, threshold=diff) + return MetricChangeFilter(**dict(self), **kws) + + @overload + def increases_by(self, *, diff: PosNum, frac: None) -> MetricChangeFilter: ... + + @overload + def increases_by(self, *, diff: None, frac: PosNum) -> MetricChangeFilter: ... + + def increases_by( + self, *, diff: PosNum | None = None, frac: PosNum | None = None + ) -> MetricChangeFilter: + """Returns a filter that watches for a numerical increase in a metric. + + Arguments mirror those of `.changes_by()`. + """ + return self.changes_by(diff=diff, frac=frac, _dir=ChangeDir.INC) + + @overload + def decreases_by(self, *, diff: PosNum, frac: None) -> MetricChangeFilter: ... + + @overload + def decreases_by(self, *, diff: None, frac: PosNum) -> MetricChangeFilter: ... + + def decreases_by( + self, *, diff: PosNum | None = None, frac: PosNum | None = None + ) -> MetricChangeFilter: + """Returns a filter that watches for a numerical decrease in a metric. + + Arguments mirror those of `.changes_by()`. + """ + return self.changes_by(diff=diff, frac=frac, _dir=ChangeDir.DEC) + + +class MetricVal(BaseMetricOperand): + """Represents a single metric value when defining metric event filters.""" + + name: str + + # Allow conversion of a single-value metric into an aggregated expression. 
+ def max(self, window: int) -> MetricAgg: + return MetricAgg(name=self.name, agg=Agg.MAX, window=window) + + def min(self, window: int) -> MetricAgg: + return MetricAgg(name=self.name, agg=Agg.MIN, window=window) + + def avg(self, window: int) -> MetricAgg: + return MetricAgg(name=self.name, agg=Agg.AVG, window=window) + + # Aliased method for users familiar with e.g. torch/tf/numpy/pandas/polars/etc. + def mean(self, window: int) -> MetricAgg: + return self.avg(window=window) + + def zscore(self, window: int) -> ZScoreMetricOperand: + """Returns a z-score metric builder for fluent filter construction. + + Use with comparison operators to create z-score filters: + - `metric.zscore(30) > 3` - detects z-score increases above 3 std devs + - `metric.zscore(30) < -3` - detects z-score decreases below -3 std devs + - `metric.zscore(30).abs() > 3` - detects abs z-score deviations above 3 std devs + + Note: + - The `>=` operator behaves the same as `>`, and `<=` behaves the same as `<`. + """ + return ZScoreMetricOperand(name=self.name, window=window) + + +class MetricAgg(BaseMetricOperand): + """Represents an aggregated metric value when defining metric event filters.""" + + name: str + agg: Annotated[Agg, Field(alias="agg_op")] + window: Annotated[PositiveInt, Field(alias="window_size")] + + +class ZScoreMetricOperand(GQLBase, extra="forbid"): + """Helper class to build z-score metric filters with comparison operators. + + This class enables fluent construction of z-score filters using Python + comparison operators (>, <, >=, <=) and the builtin abs() function. + + Note: When defining a z-score threshold, the `>` and `>=` operators are + interchangeable, as are the `<=` and `<` operators, since the z-score defines + a threshold on a continuous value. At runtime, the filter is evaluated + using the inclusive operators (`>=` or `<=`). 
+ """ + + name: str + """Name of the metric to monitor.""" + + window: PositiveInt + """Size of the window to calculate the metric mean and standard deviation over.""" + + is_absolute: bool = Field(default=False, repr=False) + """Whether to check the absolute value of the z-score (ignoring direction).""" + + def lt(self, value: int | float, /) -> MetricZScoreFilter: + """Returns a filter that watches for `zscore(metric) < -threshold`. + + Args: + value: The z-score threshold value to compare against. + The absolute value is used as the threshold. + """ + if self.is_absolute: + raise ValueError("Cannot use absolute z-score with < operator") + + if value >= 0: + raise ValueError("Negative z-score threshold required") + + return MetricZScoreFilter( + name=self.name, + window=self.window, + change_dir=ChangeDir.DECREASE, + threshold=abs(value), + ) + + def __lt__(self, value: int | float, /) -> MetricZScoreFilter: + return self.lt(value) + + def __le__(self, value: int | float, /) -> MetricZScoreFilter: + """Alias for `<` operator - behaves identically to `__lt__`. + + Returns a filter that watches for `zscore(metric) < -threshold`. + Note: `<=` and `<` are treated as equivalent for z-score filters. + """ + return self.lt(value) + + def gt(self, value: int | float, /) -> MetricZScoreFilter: + """Returns a filter that watches for `zscore(metric) > threshold`. + + If `is_absolute` is True, watches for `abs(zscore(metric)) > threshold`. + + Args: + value: The z-score threshold value to compare against. + The absolute value is used as the threshold. 
+ """ + if value <= 0: + raise ValueError(f"Expected positive threshold, got: {value=}") + + return MetricZScoreFilter( + name=self.name, + window=self.window, + change_dir=ChangeDir.ANY if self.is_absolute else ChangeDir.INCREASE, + threshold=abs(value), + ) + + def __gt__(self, value: int | float, /) -> MetricZScoreFilter: + return self.gt(value) + + def __ge__(self, value: int | float, /) -> MetricZScoreFilter: + """Alias for `>` operator - behaves identically to `__gt__`. + + Returns a filter that watches for `zscore(metric) > threshold`. + If `is_absolute` is True, watches for `abs(zscore(metric)) > threshold`. + Note: `>=` and `>` are treated as equivalent for z-score filters. + """ + return self.gt(value) + + def __abs__(self) -> ZScoreMetricOperand: + """Returns a z-score filter that checks the absolute value. + + This allows watching for z-score deviations in either direction. + Use with comparison operators: `abs(metric.zscore(window)) > threshold`. + """ + return self.model_copy(update={"is_absolute": True}) + + def abs(self) -> ZScoreMetricOperand: + """Returns a z-score filter that checks the absolute value. + + Alias for `__abs__()` that can be called as a method. + Allows using either `abs(zscore)` or `zscore.abs()`. 
+ """ + return self.__abs__() + + +class MetricZScoreFilter(GQLBase, extra="forbid"): + """Filter that compares a metric's z-score against a user-defined threshold.""" + + name: str + """Name of the observed metric.""" + + window: Annotated[PositiveInt, Field(alias="window_size")] = 30 + """Size of the window to calculate the metric mean and standard deviation over.""" + + threshold: PosNum = 3.0 + """Threshold for the z-score.""" + + change_dir: ChangeDir = ChangeDir.ANY + """Direction of the z-score change to watch for.""" + + def __and__(self, other: Any) -> RunMetricFilter: + """Returns `(metric_filter & run_filter)` as a `RunMetricFilter`.""" + from wandb.automations.events import RunMetricFilter + + if isinstance(run_filter := other, (BaseOp, FilterExpr)): + # Treat `other` as a run filter and build a RunMetricEvent. Let the + # metric filter validators wrap or nest as appropriate. + return RunMetricFilter(run=run_filter, metric=self) + return NotImplemented + + def __rand__(self, other: BaseOp | FilterExpr) -> RunMetricFilter: + """Ensures `&` is commutative for run and metric filters. + + I.e. `(run_filter & metric_filter) == (metric_filter & run_filter)`. 
+ """ + return self.__and__(other) + + def __repr__(self) -> str: + if self.change_dir is ChangeDir.ANY: + return repr(rf"abs(zscore({self.name!r})) > {self.threshold}") + elif self.change_dir is ChangeDir.DECREASE: + return repr(rf"zscore({self.name!r}) < -{self.threshold}") + else: # ChangeDir.INCREASE + return repr(rf"zscore({self.name!r}) > +{self.threshold}") + + @override + def __rich_repr__(self) -> RichReprResult: + """Returns the `rich` pretty-print representation of the metric filter.""" + # See: https://rich.readthedocs.io/en/stable/pretty.html#rich-repr-protocol + yield None, repr(self) diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_filters/run_states.py b/.venv/lib/python3.13/site-packages/wandb/automations/_filters/run_states.py new file mode 100644 index 0000000000000000000000000000000000000000..190c9f7cd869b4fc400ce270447bbc54336ca8a1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_filters/run_states.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Iterable + +from pydantic import BeforeValidator +from typing_extensions import Annotated + +from wandb._iterutils import always_list +from wandb._pydantic import GQLBase, field_validator +from wandb.automations._validators import LenientStrEnum + +from .expressions import FilterExpr +from .operators import BaseOp + +if TYPE_CHECKING: + from wandb.automations.events import EventType, RunStateFilter + + +class ReportedRunState(LenientStrEnum): # from: StateToReport + RUNNING = "RUNNING" + FINISHED = "FINISHED" + FAILED = "FAILED" + + # Convenience aliases that are equivalent when *creating* or *editing* + # the triggering event for a run state automation. + # NOTE: These may still be reported as distinct values from an *executed* automation. 
+ CRASHED = FAILED + + +class StateFilter(GQLBase): # from: RunStateFilter + states: Annotated[ + list[ReportedRunState], + BeforeValidator(always_list), # Coerce x -> [x] if passed a single value + ] + + @property + def event_type(self) -> EventType: + return EventType.RUN_STATE + + @field_validator("states", mode="after") + @classmethod + def _dedup_and_order(cls, v: list[ReportedRunState]) -> list[ReportedRunState]: + """Ensure states are deduplicated and predictably ordered.""" + return sorted(set(v)) + + def __and__(self, other: Any) -> RunStateFilter: + """Returns `(state_filter & run_filter)` as a `RunStateFilter`.""" + from wandb.automations.events import RunStateFilter + + if isinstance(run_filter := other, (BaseOp, FilterExpr)): + # Treat `other` as a run filter and build a RunStateFilter. Let the + # metric filter validators wrap or nest as appropriate. + return RunStateFilter(run=run_filter, state=self) + return NotImplemented + + def __rand__(self, other: BaseOp | FilterExpr) -> RunStateFilter: + """Ensures `&` is commutative for run and state filters. + + I.e. `(run_filter & state_filter) == (state_filter & run_filter)`. + """ + return self.__and__(other) + + +class StateOperand(GQLBase): + """Descriptor type, returned on accessing `RunEvent.state`. + + Necessary in order to handle constructing the custom structure for run state filters. 
+ """ + + def __get__(self, obj: Any, objtype: type) -> StateOperand: + return self + + def eq(self, state: str | ReportedRunState, /) -> StateFilter: + """Returns a filter that watches for `run_state == state`.""" + return StateFilter(states=[state]) + + def in_(self, states: Iterable[str | ReportedRunState], /) -> StateFilter: + """Returns a filter that watches for `run_state in states`.""" + return StateFilter(states=states) + + def __eq__(self, other: Any) -> StateFilter: + if isinstance(other, (str, ReportedRunState)): + return self.eq(other) + raise TypeError(f"Invalid operand type in run state filter: {type(other)!r}") diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_generated/__init__.py b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b0342618e569afb4235b5d6df6b190b72a53752a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/__init__.py @@ -0,0 +1,90 @@ +# Generated by ariadne-codegen + +__all__ = [ + "CREATE_AUTOMATION_GQL", + "CREATE_GENERIC_WEBHOOK_INTEGRATION_GQL", + "DELETE_AUTOMATION_GQL", + "GET_AUTOMATIONS_BY_ENTITY_GQL", + "GET_AUTOMATIONS_GQL", + "INTEGRATIONS_BY_ENTITY_GQL", + "UPDATE_AUTOMATION_GQL", + "GetAutomations", + "GetAutomationsByEntity", + "CreateAutomation", + "UpdateAutomation", + "DeleteAutomation", + "IntegrationsByEntity", + "CreateGenericWebhookIntegration", + "CreateFilterTriggerInput", + "CreateGenericWebhookIntegrationInput", + "GenericWebhookActionInput", + "NoOpTriggeredActionInput", + "NotificationActionInput", + "QueueJobActionInput", + "TriggeredActionConfig", + "UpdateFilterTriggerInput", + "ArtifactPortfolioScopeFields", + "ArtifactSequenceScopeFields", + "FilterEventFields", + "GenericWebhookActionFields", + "NoOpActionFields", + "NotificationActionFields", + "PageInfoFields", + "ProjectScopeFields", + "ProjectTriggersFields", + "QueueJobActionFields", + 
"SlackIntegrationFields", + "TriggerFields", + "WebhookIntegrationFields", + "AlertSeverity", + "EventTriggeringConditionType", + "TriggerScopeType", + "TriggeredActionType", +] +from .create_automation import CreateAutomation +from .create_generic_webhook_integration import CreateGenericWebhookIntegration +from .delete_automation import DeleteAutomation +from .enums import ( + AlertSeverity, + EventTriggeringConditionType, + TriggeredActionType, + TriggerScopeType, +) +from .fragments import ( + ArtifactPortfolioScopeFields, + ArtifactSequenceScopeFields, + FilterEventFields, + GenericWebhookActionFields, + NoOpActionFields, + NotificationActionFields, + PageInfoFields, + ProjectScopeFields, + ProjectTriggersFields, + QueueJobActionFields, + SlackIntegrationFields, + TriggerFields, + WebhookIntegrationFields, +) +from .get_automations import GetAutomations +from .get_automations_by_entity import GetAutomationsByEntity +from .input_types import ( + CreateFilterTriggerInput, + CreateGenericWebhookIntegrationInput, + GenericWebhookActionInput, + NoOpTriggeredActionInput, + NotificationActionInput, + QueueJobActionInput, + TriggeredActionConfig, + UpdateFilterTriggerInput, +) +from .integrations_by_entity import IntegrationsByEntity +from .operations import ( + CREATE_AUTOMATION_GQL, + CREATE_GENERIC_WEBHOOK_INTEGRATION_GQL, + DELETE_AUTOMATION_GQL, + GET_AUTOMATIONS_BY_ENTITY_GQL, + GET_AUTOMATIONS_GQL, + INTEGRATIONS_BY_ENTITY_GQL, + UPDATE_AUTOMATION_GQL, +) +from .update_automation import UpdateAutomation diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_generated/create_automation.py b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/create_automation.py new file mode 100644 index 0000000000000000000000000000000000000000..b755ab86cf0e5eaf247a0c4e4534106c0120bcf9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/create_automation.py @@ -0,0 +1,22 @@ +# Generated by ariadne-codegen +# Source: 
tools/graphql_codegen/automations/ + +from __future__ import annotations + +from typing import Optional + +from wandb._pydantic import GQLResult + +from .fragments import TriggerFields + + +class CreateAutomation(GQLResult): + result: Optional[CreateAutomationResult] + + +class CreateAutomationResult(GQLResult): + trigger: Optional[TriggerFields] + + +CreateAutomation.model_rebuild() +CreateAutomationResult.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_generated/create_generic_webhook_integration.py b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/create_generic_webhook_integration.py new file mode 100644 index 0000000000000000000000000000000000000000..89fc048a53653ec3e3d9b438bf106eaf63137797 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/create_generic_webhook_integration.py @@ -0,0 +1,38 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/automations/ + +from __future__ import annotations + +from typing import Optional, Union + +from pydantic import Field +from typing_extensions import Literal + +from wandb._pydantic import GQLResult, Typename + +from .fragments import WebhookIntegrationFields + + +class CreateGenericWebhookIntegration(GQLResult): + create_generic_webhook_integration: Optional[ + CreateGenericWebhookIntegrationCreateGenericWebhookIntegration + ] = Field(alias="createGenericWebhookIntegration") + + +class CreateGenericWebhookIntegrationCreateGenericWebhookIntegration(GQLResult): + integration: Union[ + CreateGenericWebhookIntegrationCreateGenericWebhookIntegrationIntegrationIntegration, + WebhookIntegrationFields, + ] = Field(discriminator="typename__") + + +class CreateGenericWebhookIntegrationCreateGenericWebhookIntegrationIntegrationIntegration( + GQLResult +): + typename__: Typename[ + Literal["GitHubOAuthIntegration", "Integration", "SlackIntegration"] + ] + + +CreateGenericWebhookIntegration.model_rebuild() 
+CreateGenericWebhookIntegrationCreateGenericWebhookIntegration.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_generated/delete_automation.py b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/delete_automation.py new file mode 100644 index 0000000000000000000000000000000000000000..d8d9dc39ac66107759b19c03a4b19d276c650b63 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/delete_automation.py @@ -0,0 +1,17 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/automations/ + +from __future__ import annotations + +from wandb._pydantic import GQLResult + + +class DeleteAutomation(GQLResult): + result: DeleteAutomationResult + + +class DeleteAutomationResult(GQLResult): + success: bool + + +DeleteAutomation.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_generated/enums.py b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/enums.py new file mode 100644 index 0000000000000000000000000000000000000000..e28838f2743265fd5e32160b929556343de170e1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/enums.py @@ -0,0 +1,36 @@ +# Generated by ariadne-codegen +# Source: core/api/graphql/schemas/schema-latest.graphql + +from __future__ import annotations + +from enum import Enum + + +class AlertSeverity(str, Enum): + INFO = "INFO" + WARN = "WARN" + ERROR = "ERROR" + + +class TriggerScopeType(str, Enum): + PROJECT = "PROJECT" + ARTIFACT_COLLECTION = "ARTIFACT_COLLECTION" + + +class EventTriggeringConditionType(str, Enum): + CREATE_ARTIFACT = "CREATE_ARTIFACT" + UPDATE_ARTIFACT_ALIAS = "UPDATE_ARTIFACT_ALIAS" + ADD_ARTIFACT_ALIAS = "ADD_ARTIFACT_ALIAS" + ADD_ARTIFACT_TAG = "ADD_ARTIFACT_TAG" + LINK_MODEL = "LINK_MODEL" + RUN_METRIC = "RUN_METRIC" + RUN_METRIC_CHANGE = "RUN_METRIC_CHANGE" + RUN_STATE = "RUN_STATE" + RUN_METRIC_ZSCORE = "RUN_METRIC_ZSCORE" + + +class TriggeredActionType(str, Enum): + 
QUEUE_JOB = "QUEUE_JOB" + NOTIFICATION = "NOTIFICATION" + GENERIC_WEBHOOK = "GENERIC_WEBHOOK" + NO_OP = "NO_OP" diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_generated/fragments.py b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/fragments.py new file mode 100644 index 0000000000000000000000000000000000000000..51477586a0e0114087c7935d9080d20794815560 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/fragments.py @@ -0,0 +1,165 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/automations/ + +from __future__ import annotations + +from datetime import datetime +from typing import List, Optional, Union + +from pydantic import Field +from typing_extensions import Literal + +from wandb._pydantic import GQLId, GQLResult, Typename + +from .enums import AlertSeverity, EventTriggeringConditionType + + +class ArtifactPortfolioScopeFields(GQLResult): + typename__: Typename[Literal["ArtifactPortfolio"]] = "ArtifactPortfolio" + id: GQLId + name: str + + +class ArtifactSequenceScopeFields(GQLResult): + typename__: Typename[Literal["ArtifactSequence"]] = "ArtifactSequence" + id: GQLId + name: str + + +class FilterEventFields(GQLResult): + typename__: Typename[Literal["FilterEventTriggeringCondition"]] = ( + "FilterEventTriggeringCondition" + ) + event_type: EventTriggeringConditionType = Field(alias="eventType") + filter: str + + +class WebhookIntegrationFields(GQLResult): + typename__: Typename[Literal["GenericWebhookIntegration"]] = ( + "GenericWebhookIntegration" + ) + id: GQLId + name: str + url_endpoint: str = Field(alias="urlEndpoint") + + +class GenericWebhookActionFields(GQLResult): + typename__: Typename[Literal["GenericWebhookTriggeredAction"]] = ( + "GenericWebhookTriggeredAction" + ) + integration: Union[ + GenericWebhookActionFieldsIntegrationIntegration, WebhookIntegrationFields + ] = Field(discriminator="typename__") + request_payload: Optional[str] = 
Field(alias="requestPayload") + + +class GenericWebhookActionFieldsIntegrationIntegration(GQLResult): + typename__: Typename[ + Literal["GitHubOAuthIntegration", "Integration", "SlackIntegration"] + ] + + +class NoOpActionFields(GQLResult): + typename__: Typename[Literal["NoOpTriggeredAction"]] = "NoOpTriggeredAction" + no_op: Optional[bool] = Field(alias="noOp") + + +class SlackIntegrationFields(GQLResult): + typename__: Typename[Literal["SlackIntegration"]] = "SlackIntegration" + id: GQLId + team_name: str = Field(alias="teamName") + channel_name: str = Field(alias="channelName") + + +class NotificationActionFields(GQLResult): + typename__: Typename[Literal["NotificationTriggeredAction"]] = ( + "NotificationTriggeredAction" + ) + integration: Union[ + NotificationActionFieldsIntegrationIntegration, SlackIntegrationFields + ] = Field(discriminator="typename__") + title: Optional[str] + message: Optional[str] + severity: Optional[AlertSeverity] + + +class NotificationActionFieldsIntegrationIntegration(GQLResult): + typename__: Typename[ + Literal["GenericWebhookIntegration", "GitHubOAuthIntegration", "Integration"] + ] + + +class PageInfoFields(GQLResult): + end_cursor: Optional[str] = Field(alias="endCursor") + has_next_page: bool = Field(alias="hasNextPage") + + +class ProjectScopeFields(GQLResult): + typename__: Typename[Literal["Project"]] = "Project" + id: GQLId + name: str + + +class QueueJobActionFields(GQLResult): + typename__: Typename[Literal["QueueJobTriggeredAction"]] = "QueueJobTriggeredAction" + queue: Optional[QueueJobActionFieldsQueue] + template: str + + +class QueueJobActionFieldsQueue(GQLResult): + id: GQLId + name: str + + +class TriggerFields(GQLResult): + typename__: Typename[Literal["Trigger"]] = "Trigger" + id: GQLId + created_at: datetime = Field(alias="createdAt") + updated_at: Optional[datetime] = Field(alias="updatedAt") + name: str + description: Optional[str] + enabled: bool + scope: Union[ + ProjectScopeFields, 
ArtifactSequenceScopeFields, ArtifactPortfolioScopeFields + ] = Field(discriminator="typename__") + event: FilterEventFields + action: Union[ + QueueJobActionFields, + NotificationActionFields, + GenericWebhookActionFields, + NoOpActionFields, + ] = Field(discriminator="typename__") + + +class ProjectTriggersFields(GQLResult): + typename__: Typename[Literal["Project"]] = "Project" + triggers: List[TriggerFields] + + +ArtifactPortfolioScopeFields.model_rebuild() +ArtifactSequenceScopeFields.model_rebuild() +FilterEventFields.model_rebuild() +WebhookIntegrationFields.model_rebuild() +GenericWebhookActionFields.model_rebuild() +GenericWebhookActionFieldsIntegrationIntegration.model_rebuild() +WebhookIntegrationFields.model_rebuild() +NoOpActionFields.model_rebuild() +SlackIntegrationFields.model_rebuild() +NotificationActionFields.model_rebuild() +NotificationActionFieldsIntegrationIntegration.model_rebuild() +SlackIntegrationFields.model_rebuild() +PageInfoFields.model_rebuild() +ProjectScopeFields.model_rebuild() +QueueJobActionFields.model_rebuild() +QueueJobActionFieldsQueue.model_rebuild() +TriggerFields.model_rebuild() +ProjectScopeFields.model_rebuild() +ArtifactSequenceScopeFields.model_rebuild() +ArtifactPortfolioScopeFields.model_rebuild() +FilterEventFields.model_rebuild() +QueueJobActionFields.model_rebuild() +NotificationActionFields.model_rebuild() +GenericWebhookActionFields.model_rebuild() +NoOpActionFields.model_rebuild() +ProjectTriggersFields.model_rebuild() +TriggerFields.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_generated/get_automations.py b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/get_automations.py new file mode 100644 index 0000000000000000000000000000000000000000..cf799f1ba0a368afc8c3801b69a8b0348c88918b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/get_automations.py @@ -0,0 +1,35 @@ +# Generated by ariadne-codegen +# Source: 
tools/graphql_codegen/automations/ + +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field + +from wandb._pydantic import GQLResult + +from .fragments import PageInfoFields, ProjectTriggersFields + + +class GetAutomations(GQLResult): + scope: Optional[GetAutomationsScope] + + +class GetAutomationsScope(GQLResult): + projects: Optional[GetAutomationsScopeProjects] + + +class GetAutomationsScopeProjects(GQLResult): + page_info: PageInfoFields = Field(alias="pageInfo") + edges: List[GetAutomationsScopeProjectsEdges] + + +class GetAutomationsScopeProjectsEdges(GQLResult): + node: Optional[ProjectTriggersFields] + + +GetAutomations.model_rebuild() +GetAutomationsScope.model_rebuild() +GetAutomationsScopeProjects.model_rebuild() +GetAutomationsScopeProjectsEdges.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_generated/get_automations_by_entity.py b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/get_automations_by_entity.py new file mode 100644 index 0000000000000000000000000000000000000000..2faa3a8a123d0f3b57fc411f1a1152b491ed93cc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/get_automations_by_entity.py @@ -0,0 +1,35 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/automations/ + +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field + +from wandb._pydantic import GQLResult + +from .fragments import PageInfoFields, ProjectTriggersFields + + +class GetAutomationsByEntity(GQLResult): + scope: Optional[GetAutomationsByEntityScope] + + +class GetAutomationsByEntityScope(GQLResult): + projects: Optional[GetAutomationsByEntityScopeProjects] + + +class GetAutomationsByEntityScopeProjects(GQLResult): + page_info: PageInfoFields = Field(alias="pageInfo") + edges: List[GetAutomationsByEntityScopeProjectsEdges] + + +class GetAutomationsByEntityScopeProjectsEdges(GQLResult): + 
node: Optional[ProjectTriggersFields] + + +GetAutomationsByEntity.model_rebuild() +GetAutomationsByEntityScope.model_rebuild() +GetAutomationsByEntityScopeProjects.model_rebuild() +GetAutomationsByEntityScopeProjectsEdges.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_generated/input_types.py b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/input_types.py new file mode 100644 index 0000000000000000000000000000000000000000..6864ac97c08ddec67d13a8f23e16561a1176db05 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/input_types.py @@ -0,0 +1,104 @@ +# Generated by ariadne-codegen +# Source: core/api/graphql/schemas/schema-latest.graphql + +from __future__ import annotations + +from typing import Optional + +from pydantic import Field + +from wandb._pydantic import GQLId, GQLInput + +from .enums import ( + AlertSeverity, + EventTriggeringConditionType, + TriggeredActionType, + TriggerScopeType, +) + + +class CreateGenericWebhookIntegrationInput(GQLInput): + entity_name: str = Field(alias="entityName") + url_endpoint: str = Field(alias="urlEndpoint") + name: str = Field(max_length=64, pattern="^[-\\w]+([ ]+[-\\w]+)*$") + secret_ref: Optional[str] = Field(alias="secretRef", default=None) + access_token_ref: Optional[str] = Field(alias="accessTokenRef", default=None) + client_mutation_id: Optional[str] = Field(alias="clientMutationId", default=None) + + +class QueueJobActionInput(GQLInput): + queue_id: GQLId = Field(alias="queueID") + template: str + + +class NotificationActionInput(GQLInput): + integration_id: GQLId = Field(alias="integrationID") + title: Optional[str] = None + message: Optional[str] = None + severity: Optional[AlertSeverity] = None + + +class GenericWebhookActionInput(GQLInput): + integration_id: GQLId = Field(alias="integrationID") + request_payload: Optional[str] = Field(alias="requestPayload", default=None) + + +class NoOpTriggeredActionInput(GQLInput): + no_op: 
Optional[bool] = Field(alias="noOp", default=None) + + +class TriggeredActionConfig(GQLInput): + queue_job_action_input: Optional[QueueJobActionInput] = Field( + alias="queueJobActionInput", default=None + ) + notification_action_input: Optional[NotificationActionInput] = Field( + alias="notificationActionInput", default=None + ) + generic_webhook_action_input: Optional[GenericWebhookActionInput] = Field( + alias="genericWebhookActionInput", default=None + ) + no_op_action_input: Optional[NoOpTriggeredActionInput] = Field( + alias="noOpActionInput", default=None + ) + + +class CreateFilterTriggerInput(GQLInput): + name: str = Field(max_length=255) + description: Optional[str] = None + triggering_event_type: EventTriggeringConditionType = Field( + alias="triggeringEventType" + ) + scope_type: TriggerScopeType = Field(alias="scopeType") + scope_id: GQLId = Field(alias="scopeID") + event_filter: str = Field(alias="eventFilter") + triggered_action_type: TriggeredActionType = Field(alias="triggeredActionType") + triggered_action_config: TriggeredActionConfig = Field( + alias="triggeredActionConfig" + ) + enabled: bool + client_mutation_id: Optional[str] = Field(alias="clientMutationId", default=None) + + +class UpdateFilterTriggerInput(GQLInput): + id: GQLId + name: Optional[str] = Field(default=None, max_length=255) + description: Optional[str] = None + triggering_event_type: Optional[EventTriggeringConditionType] = Field( + alias="triggeringEventType", default=None + ) + scope_type: Optional[TriggerScopeType] = Field(alias="scopeType", default=None) + scope_id: Optional[GQLId] = Field(alias="scopeID", default=None) + event_filter: Optional[str] = Field(alias="eventFilter", default=None) + triggered_action_type: Optional[TriggeredActionType] = Field( + alias="triggeredActionType", default=None + ) + triggered_action_config: Optional[TriggeredActionConfig] = Field( + alias="triggeredActionConfig", default=None + ) + enabled: Optional[bool] = None + client_mutation_id: 
Optional[str] = Field(alias="clientMutationId", default=None) + + +TriggeredActionConfig.model_rebuild() +CreateFilterTriggerInput.model_rebuild() +UpdateFilterTriggerInput.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_generated/integrations_by_entity.py b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/integrations_by_entity.py new file mode 100644 index 0000000000000000000000000000000000000000..fc4842ffda6de4edd90a11b08c0474449747c47b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/integrations_by_entity.py @@ -0,0 +1,49 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/automations/ + +from __future__ import annotations + +from typing import List, Optional, Union + +from pydantic import Field +from typing_extensions import Annotated, Literal + +from wandb._pydantic import GQLResult, Typename + +from .fragments import PageInfoFields, SlackIntegrationFields, WebhookIntegrationFields + + +class IntegrationsByEntity(GQLResult): + entity: Optional[IntegrationsByEntityEntity] + + +class IntegrationsByEntityEntity(GQLResult): + integrations: Optional[IntegrationsByEntityEntityIntegrations] + + +class IntegrationsByEntityEntityIntegrations(GQLResult): + page_info: PageInfoFields = Field(alias="pageInfo") + edges: List[IntegrationsByEntityEntityIntegrationsEdges] + + +class IntegrationsByEntityEntityIntegrationsEdges(GQLResult): + node: Optional[ + Annotated[ + Union[ + IntegrationsByEntityEntityIntegrationsEdgesNodeIntegration, + WebhookIntegrationFields, + SlackIntegrationFields, + ], + Field(discriminator="typename__"), + ] + ] + + +class IntegrationsByEntityEntityIntegrationsEdgesNodeIntegration(GQLResult): + typename__: Typename[Literal["GitHubOAuthIntegration", "Integration"]] + + +IntegrationsByEntity.model_rebuild() +IntegrationsByEntityEntity.model_rebuild() +IntegrationsByEntityEntityIntegrations.model_rebuild() 
+IntegrationsByEntityEntityIntegrationsEdges.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_generated/operations.py b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/operations.py new file mode 100644 index 0000000000000000000000000000000000000000..48c13dc14e02ac088dc92741805a406338aa3112 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/operations.py @@ -0,0 +1,530 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/automations/ + +__all__ = [ + "CREATE_AUTOMATION_GQL", + "CREATE_GENERIC_WEBHOOK_INTEGRATION_GQL", + "DELETE_AUTOMATION_GQL", + "GET_AUTOMATIONS_BY_ENTITY_GQL", + "GET_AUTOMATIONS_GQL", + "INTEGRATIONS_BY_ENTITY_GQL", + "UPDATE_AUTOMATION_GQL", +] + +GET_AUTOMATIONS_GQL = """ +query GetAutomations($cursor: String, $perPage: Int) { + scope: viewer { + projects(after: $cursor, first: $perPage) { + pageInfo { + ...PageInfoFields + } + edges { + node { + ...ProjectTriggersFields + } + } + } + } +} + +fragment ArtifactPortfolioScopeFields on ArtifactPortfolio { + __typename + id + name +} + +fragment ArtifactSequenceScopeFields on ArtifactSequence { + __typename + id + name +} + +fragment FilterEventFields on FilterEventTriggeringCondition { + __typename + eventType + filter +} + +fragment GenericWebhookActionFields on GenericWebhookTriggeredAction { + __typename + integration { + ...WebhookIntegrationFields + } + requestPayload +} + +fragment NoOpActionFields on NoOpTriggeredAction { + __typename + noOp +} + +fragment NotificationActionFields on NotificationTriggeredAction { + __typename + integration { + ...SlackIntegrationFields + } + title + message + severity +} + +fragment PageInfoFields on PageInfo { + endCursor + hasNextPage +} + +fragment ProjectScopeFields on Project { + __typename + id + name +} + +fragment ProjectTriggersFields on Project { + __typename + triggers { + ...TriggerFields + } +} + +fragment QueueJobActionFields on 
QueueJobTriggeredAction { + __typename + queue { + id + name + } + template +} + +fragment SlackIntegrationFields on SlackIntegration { + __typename + id + teamName + channelName +} + +fragment TriggerFields on Trigger { + __typename + id + createdAt + updatedAt + name + description + enabled + scope { + ...ProjectScopeFields + ...ArtifactPortfolioScopeFields + ...ArtifactSequenceScopeFields + } + event: triggeringCondition { + ...FilterEventFields + } + action: triggeredAction { + ...QueueJobActionFields + ...NotificationActionFields + ...GenericWebhookActionFields + ...NoOpActionFields + } +} + +fragment WebhookIntegrationFields on GenericWebhookIntegration { + __typename + id + name + urlEndpoint +} +""" + +GET_AUTOMATIONS_BY_ENTITY_GQL = """ +query GetAutomationsByEntity($entity: String!, $cursor: String, $perPage: Int) { + scope: entity(name: $entity) { + projects(after: $cursor, first: $perPage) { + pageInfo { + ...PageInfoFields + } + edges { + node { + ...ProjectTriggersFields + } + } + } + } +} + +fragment ArtifactPortfolioScopeFields on ArtifactPortfolio { + __typename + id + name +} + +fragment ArtifactSequenceScopeFields on ArtifactSequence { + __typename + id + name +} + +fragment FilterEventFields on FilterEventTriggeringCondition { + __typename + eventType + filter +} + +fragment GenericWebhookActionFields on GenericWebhookTriggeredAction { + __typename + integration { + ...WebhookIntegrationFields + } + requestPayload +} + +fragment NoOpActionFields on NoOpTriggeredAction { + __typename + noOp +} + +fragment NotificationActionFields on NotificationTriggeredAction { + __typename + integration { + ...SlackIntegrationFields + } + title + message + severity +} + +fragment PageInfoFields on PageInfo { + endCursor + hasNextPage +} + +fragment ProjectScopeFields on Project { + __typename + id + name +} + +fragment ProjectTriggersFields on Project { + __typename + triggers { + ...TriggerFields + } +} + +fragment QueueJobActionFields on 
QueueJobTriggeredAction { + __typename + queue { + id + name + } + template +} + +fragment SlackIntegrationFields on SlackIntegration { + __typename + id + teamName + channelName +} + +fragment TriggerFields on Trigger { + __typename + id + createdAt + updatedAt + name + description + enabled + scope { + ...ProjectScopeFields + ...ArtifactPortfolioScopeFields + ...ArtifactSequenceScopeFields + } + event: triggeringCondition { + ...FilterEventFields + } + action: triggeredAction { + ...QueueJobActionFields + ...NotificationActionFields + ...GenericWebhookActionFields + ...NoOpActionFields + } +} + +fragment WebhookIntegrationFields on GenericWebhookIntegration { + __typename + id + name + urlEndpoint +} +""" + +CREATE_AUTOMATION_GQL = """ +mutation CreateAutomation($input: CreateFilterTriggerInput!) { + result: createFilterTrigger(input: $input) { + trigger { + ...TriggerFields + } + } +} + +fragment ArtifactPortfolioScopeFields on ArtifactPortfolio { + __typename + id + name +} + +fragment ArtifactSequenceScopeFields on ArtifactSequence { + __typename + id + name +} + +fragment FilterEventFields on FilterEventTriggeringCondition { + __typename + eventType + filter +} + +fragment GenericWebhookActionFields on GenericWebhookTriggeredAction { + __typename + integration { + ...WebhookIntegrationFields + } + requestPayload +} + +fragment NoOpActionFields on NoOpTriggeredAction { + __typename + noOp +} + +fragment NotificationActionFields on NotificationTriggeredAction { + __typename + integration { + ...SlackIntegrationFields + } + title + message + severity +} + +fragment ProjectScopeFields on Project { + __typename + id + name +} + +fragment QueueJobActionFields on QueueJobTriggeredAction { + __typename + queue { + id + name + } + template +} + +fragment SlackIntegrationFields on SlackIntegration { + __typename + id + teamName + channelName +} + +fragment TriggerFields on Trigger { + __typename + id + createdAt + updatedAt + name + description + enabled + scope { + 
...ProjectScopeFields + ...ArtifactPortfolioScopeFields + ...ArtifactSequenceScopeFields + } + event: triggeringCondition { + ...FilterEventFields + } + action: triggeredAction { + ...QueueJobActionFields + ...NotificationActionFields + ...GenericWebhookActionFields + ...NoOpActionFields + } +} + +fragment WebhookIntegrationFields on GenericWebhookIntegration { + __typename + id + name + urlEndpoint +} +""" + +UPDATE_AUTOMATION_GQL = """ +mutation UpdateAutomation($input: UpdateFilterTriggerInput!) { + result: updateFilterTrigger(input: $input) { + trigger { + ...TriggerFields + } + } +} + +fragment ArtifactPortfolioScopeFields on ArtifactPortfolio { + __typename + id + name +} + +fragment ArtifactSequenceScopeFields on ArtifactSequence { + __typename + id + name +} + +fragment FilterEventFields on FilterEventTriggeringCondition { + __typename + eventType + filter +} + +fragment GenericWebhookActionFields on GenericWebhookTriggeredAction { + __typename + integration { + ...WebhookIntegrationFields + } + requestPayload +} + +fragment NoOpActionFields on NoOpTriggeredAction { + __typename + noOp +} + +fragment NotificationActionFields on NotificationTriggeredAction { + __typename + integration { + ...SlackIntegrationFields + } + title + message + severity +} + +fragment ProjectScopeFields on Project { + __typename + id + name +} + +fragment QueueJobActionFields on QueueJobTriggeredAction { + __typename + queue { + id + name + } + template +} + +fragment SlackIntegrationFields on SlackIntegration { + __typename + id + teamName + channelName +} + +fragment TriggerFields on Trigger { + __typename + id + createdAt + updatedAt + name + description + enabled + scope { + ...ProjectScopeFields + ...ArtifactPortfolioScopeFields + ...ArtifactSequenceScopeFields + } + event: triggeringCondition { + ...FilterEventFields + } + action: triggeredAction { + ...QueueJobActionFields + ...NotificationActionFields + ...GenericWebhookActionFields + ...NoOpActionFields + } +} + +fragment 
WebhookIntegrationFields on GenericWebhookIntegration { + __typename + id + name + urlEndpoint +} +""" + +DELETE_AUTOMATION_GQL = """ +mutation DeleteAutomation($id: ID!) { + result: deleteTrigger(input: {triggerID: $id}) { + success + } +} +""" + +INTEGRATIONS_BY_ENTITY_GQL = """ +query IntegrationsByEntity($entity: String!, $cursor: String, $perPage: Int) { + entity(name: $entity) { + integrations(after: $cursor, first: $perPage) { + pageInfo { + ...PageInfoFields + } + edges { + node { + __typename + ...SlackIntegrationFields + ...WebhookIntegrationFields + } + } + } + } +} + +fragment PageInfoFields on PageInfo { + endCursor + hasNextPage +} + +fragment SlackIntegrationFields on SlackIntegration { + __typename + id + teamName + channelName +} + +fragment WebhookIntegrationFields on GenericWebhookIntegration { + __typename + id + name + urlEndpoint +} +""" + +CREATE_GENERIC_WEBHOOK_INTEGRATION_GQL = """ +mutation CreateGenericWebhookIntegration($input: CreateGenericWebhookIntegrationInput!) 
{ + createGenericWebhookIntegration(input: $input) { + integration { + __typename + ...WebhookIntegrationFields + } + } +} + +fragment WebhookIntegrationFields on GenericWebhookIntegration { + __typename + id + name + urlEndpoint +} +""" diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_generated/update_automation.py b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/update_automation.py new file mode 100644 index 0000000000000000000000000000000000000000..37415b55228207587c6f9d7ff757fcb0d852365a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_generated/update_automation.py @@ -0,0 +1,22 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/automations/ + +from __future__ import annotations + +from typing import Optional + +from wandb._pydantic import GQLResult + +from .fragments import TriggerFields + + +class UpdateAutomation(GQLResult): + result: Optional[UpdateAutomationResult] + + +class UpdateAutomationResult(GQLResult): + trigger: Optional[TriggerFields] + + +UpdateAutomation.model_rebuild() +UpdateAutomationResult.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_utils.py b/.venv/lib/python3.13/site-packages/wandb/automations/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1b37601656f1558fb08d21389d800e20a487e675 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/_utils.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +from typing import Any, Collection, Final, Optional, Protocol, TypedDict + +from pydantic import Field +from typing_extensions import Annotated, Self, Unpack + +from wandb._pydantic import GQLId, GQLInput, computed_field, model_validator, to_json + +from ._filters import MongoLikeFilter +from ._generated import ( + CreateFilterTriggerInput, + QueueJobActionInput, + TriggeredActionConfig, + UpdateFilterTriggerInput, +) +from ._validators import parse_input_action +from .actions 
import ( + ActionType, + DoNothing, + InputAction, + SavedAction, + SendNotification, + SendWebhook, +) +from .automations import Automation, NewAutomation +from .events import EventType, InputEvent, RunMetricFilter, _WrappedSavedEventFilter +from .scopes import AutomationScope, ScopeType + +INVALID_INPUT_EVENTS: Final[Collection[EventType]] = (EventType.UPDATE_ARTIFACT_ALIAS,) +"""Event types that should NOT be allowed as new values on new or edited automations. + +While we forbid new/edited automations from assigning these event types, +they're defined so that we can still parse existing automations that may use them. +""" + +INVALID_INPUT_ACTIONS: Final[Collection[ActionType]] = (ActionType.QUEUE_JOB,) +"""Action types that should NOT be allowed as new values on new or edited automations. + +While we forbid new/edited automations from assigning these action types, +they're defined so that we can still parse existing automations that may use them. +""" + +ALWAYS_SUPPORTED_EVENTS: Final[Collection[EventType]] = frozenset( + { + EventType.CREATE_ARTIFACT, + EventType.LINK_ARTIFACT, + EventType.ADD_ARTIFACT_ALIAS, + } +) +"""Event types that should be supported by all current, non-EOL server versions.""" + +ALWAYS_SUPPORTED_ACTIONS: Final[Collection[ActionType]] = frozenset( + { + ActionType.NOTIFICATION, + ActionType.GENERIC_WEBHOOK, + } +) +"""Action types that should be supported by all current, non-EOL server versions.""" + + +class HasId(Protocol): + id: str + + +def extract_id(obj: HasId | str) -> str: + return obj.id if hasattr(obj, "id") else obj + + +# --------------------------------------------------------------------------- +ACTION_CONFIG_KEYS: dict[ActionType, str] = { + ActionType.NOTIFICATION: "notification_action_input", + ActionType.GENERIC_WEBHOOK: "generic_webhook_action_input", + ActionType.NO_OP: "no_op_action_input", + ActionType.QUEUE_JOB: "queue_job_action_input", +} + + +class InputActionConfig(TriggeredActionConfig): + """Prepares action 
configuration data for saving an automation.""" + + # NOTE: `QueueJobActionInput` for defining a Launch job is deprecated, + # so while it's allowed here to update EXISTING mutations, we don't + # currently expose it through the public API to create NEW automations. + queue_job_action_input: Optional[QueueJobActionInput] = None + + notification_action_input: Optional[SendNotification] = None + generic_webhook_action_input: Optional[SendWebhook] = None + no_op_action_input: Optional[DoNothing] = None + + +def prepare_action_config_input(obj: SavedAction | InputAction) -> dict[str, Any]: + """Nests the action input under the correct key for `TriggeredActionConfig`. + + This is necessary to conform to the schemas for: + - `CreateFilterTriggerInput` + - `UpdateFilterTriggerInput` + """ + # Delegate to inner validators to convert SavedAction -> InputAction types, if needed. + obj = parse_input_action(obj) + return InputActionConfig(**{ACTION_CONFIG_KEYS[obj.action_type]: obj}).model_dump() + + +def prepare_event_filter_input( + obj: _WrappedSavedEventFilter | MongoLikeFilter | RunMetricFilter, +) -> str: + """Unnests (if needed) and serializes an `EventFilter` input to JSON. + + This is necessary to conform to the schemas for: + - `CreateFilterTriggerInput` + - `UpdateFilterTriggerInput` + """ + # Input event filters are nested one level deeper than saved event filters. + # Note that this is NOT the case for run/run metric filters. + # + # Yes, this is confusing. It's also necessary to conform to under-the-hood + # schemas and logic in the backend. 
+ if isinstance(obj, _WrappedSavedEventFilter): + return to_json(obj.filter) + return to_json(obj) + + +class WriteAutomationsKwargs(TypedDict, total=False): + """Keyword arguments that can be passed to create or update an automation.""" + + name: str + description: str + enabled: bool + scope: AutomationScope + event: InputEvent + action: InputAction + + +class ValidatedCreateInput(GQLInput, extra="forbid", frozen=True): + """Validated automation parameters, prepared for creating a new automation. + + Note: Users should never need to instantiate this class directly. + """ + + name: str + description: Optional[str] = None + enabled: bool = True + + # ------------------------------------------------------------------------------ + # Set on instantiation, but used to derive other fields and deliberately + # EXCLUDED from the final GraphQL request vars + event: Annotated[InputEvent, Field(exclude=True)] + action: Annotated[InputAction, Field(exclude=True)] + + # ------------------------------------------------------------------------------ + # Derived fields to match the input schemas + @computed_field + def scope_type(self) -> ScopeType: + return self.event.scope.scope_type + + @computed_field + def scope_id(self) -> GQLId: + return self.event.scope.id + + @computed_field + def triggering_event_type(self) -> EventType: + return self.event.event_type + + @computed_field + def event_filter(self) -> str: + return prepare_event_filter_input(self.event.filter) + + @computed_field + def triggered_action_type(self) -> ActionType: + return self.action.action_type + + @computed_field + def triggered_action_config(self) -> dict[str, Any]: + return prepare_action_config_input(self.action) + + # ------------------------------------------------------------------------------ + # Custom validation + @model_validator(mode="after") + def _forbid_legacy_event_types(self) -> Self: + if (type_ := self.event.event_type) in INVALID_INPUT_EVENTS: + raise ValueError(f"{type_!r} events 
cannot be assigned to automations.") + return self + + @model_validator(mode="after") + def _forbid_legacy_action_types(self) -> Self: + if (type_ := self.action.action_type) in INVALID_INPUT_ACTIONS: + raise ValueError(f"{type_!r} actions cannot be assigned to automations.") + return self + + +def prepare_to_create( + obj: NewAutomation | None = None, + /, + **kwargs: Unpack[WriteAutomationsKwargs], +) -> CreateFilterTriggerInput: + """Prepares the payload to create an automation in a GraphQL request.""" + # Validate all input variables, and prepare as expected by the GraphQL request. + # - if an object is provided, override its fields with any keyword args + # - otherwise, instantiate from the keyword args + obj_dict = {**obj.model_dump(), **kwargs} if obj else kwargs + vobj = ValidatedCreateInput(**obj_dict) + return CreateFilterTriggerInput.model_validate(vobj) + + +def prepare_to_update( + obj: Automation | None = None, + /, + **kwargs: Unpack[WriteAutomationsKwargs], +) -> UpdateFilterTriggerInput: + """Prepares the payload to update an automation in a GraphQL request.""" + # Validate all values: + # - if an object is provided, override its fields with any keyword args + # - otherwise, instantiate from the keyword args + vobj = Automation(**{**dict(obj or {}), **kwargs}) + return UpdateFilterTriggerInput( + id=vobj.id, + name=vobj.name, + description=vobj.description, + enabled=vobj.enabled, + scope_type=vobj.scope.scope_type, + scope_id=vobj.scope.id, + triggering_event_type=vobj.event.event_type, + event_filter=prepare_event_filter_input(vobj.event.filter), + triggered_action_type=vobj.action.action_type, + triggered_action_config=prepare_action_config_input(vobj.action), + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/_validators.py b/.venv/lib/python3.13/site-packages/wandb/automations/_validators.py new file mode 100644 index 0000000000000000000000000000000000000000..daaaa4d7478340fd426beaceaf7b65836b514e69 --- /dev/null +++ 
b/.venv/lib/python3.13/site-packages/wandb/automations/_validators.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +from enum import Enum +from typing import Any, TypeVar + +from pydantic import BeforeValidator, Json, PlainSerializer +from pydantic_core import PydanticUseDefault +from typing_extensions import Annotated + +from wandb._pydantic import to_json + +from ._filters import And, MongoLikeFilter, Or +from ._filters.filterutils import simplify_expr + +T = TypeVar("T") + + +def ensure_json(v: Any) -> Any: + """In case the incoming value isn't serialized JSON, reserialize it. + + This lets us use `Json[...]` fields with values that are already deserialized. + """ + # NOTE: Assumes that the deserialized type is not itself a string. + # Revisit this if we need to support deserialized types that are str/bytes. + return v if isinstance(v, (str, bytes)) else to_json(v) + + +JsonEncoded = Annotated[Json[T], BeforeValidator(ensure_json), PlainSerializer(to_json)] +"""A Pydantic type that's always serialized to a JSON string. + +Unlike `pydantic.Json[T]`, this is more lenient on validation and instantiation. +It doesn't strictly require the incoming value to be an encoded JSON string, and +accepts values that may _already_ be deserialized from JSON (e.g. a dict). +""" + + +class LenientStrEnum(str, Enum): + """A string enum allowing for case-insensitive lookups by value. + + May include other internal customizations if needed. + + Note: This is a bespoke, internal implementation and NOT intended as a + backport of `enum.StrEnum` from Python 3.11+. + """ + + def __repr__(self) -> str: + return self.name + + @classmethod + def _missing_(cls, value: object) -> Any: + # Accept case-insensitive enum values + if isinstance(value, str): + v = value.lower() + return next((e for e in cls if e.value.lower() == v), None) + return None + + +def default_if_none(v: Any) -> Any: + """A "before"-mode field validator that coerces `None` to the field default. 
+ + See: https://docs.pydantic.dev/2.11/api/pydantic_core/#pydantic_core.PydanticUseDefault + """ + if v is None: + raise PydanticUseDefault + return v + + +def upper_if_str(v: Any) -> Any: + return v.strip().upper() if isinstance(v, str) else v + + +# ---------------------------------------------------------------------------- +def parse_scope(v: Any) -> Any: + """Convert eligible objects (including wandb types) to an automation scope.""" + from wandb.apis.public import ArtifactCollection, Project + + from .scopes import ProjectScope, _ArtifactPortfolioScope, _ArtifactSequenceScope + + if isinstance(v, Project): + return ProjectScope.model_validate(v) + if isinstance(v, ArtifactCollection): + typ = _ArtifactSequenceScope if v.is_sequence() else _ArtifactPortfolioScope + return typ.model_validate(v) + return v + + +def parse_saved_action(v: Any) -> Any: + """If necessary (and possible), convert the object to a saved action.""" + from .actions import ( + DoNothing, + SavedNoOpAction, + SavedNotificationAction, + SavedWebhookAction, + SendNotification, + SendWebhook, + ) + + if isinstance(v, SendNotification): + return SavedNotificationAction( + integration={"id": v.integration_id}, **v.model_dump() + ) + if isinstance(v, SendWebhook): + return SavedWebhookAction( + integration={"id": v.integration_id}, **v.model_dump() + ) + if isinstance(v, DoNothing): + return SavedNoOpAction(**v.model_dump()) + return v + + +def parse_input_action(v: Any) -> Any: + """If necessary (and possible), convert the object to an input action.""" + from .actions import ( + DoNothing, + SavedNoOpAction, + SavedNotificationAction, + SavedWebhookAction, + SendNotification, + SendWebhook, + ) + + if isinstance(v, SavedNotificationAction): + return SendNotification(integration_id=v.integration.id, **v.model_dump()) + if isinstance(v, SavedWebhookAction): + return SendWebhook(integration_id=v.integration.id, **v.model_dump()) + if isinstance(v, SavedNoOpAction): + return 
DoNothing(**v.model_dump()) + return v + + +# ---------------------------------------------------------------------------- +def wrap_run_event_run_filter(f: MongoLikeFilter) -> MongoLikeFilter: + """Wrap a run filter in an `And` operator if it's not already. + + This is a necessary constraint imposed elsewhere by backend/frontend code. + """ + return And.wrap(simplify_expr(f)) # simplify/flatten first if needed + + +def wrap_mutation_event_filter(f: MongoLikeFilter) -> MongoLikeFilter: + """Wrap filters as `{"$or": [{"$and": []}]}`. + + This awkward format is necessary because the frontend expects it. + """ + return Or.wrap(And.wrap(simplify_expr(f))) # simplify/flatten first if needed diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/actions.py b/.venv/lib/python3.13/site-packages/wandb/automations/actions.py new file mode 100644 index 0000000000000000000000000000000000000000..8bf48bbfb298646a174122517796e5651673acf7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/actions.py @@ -0,0 +1,228 @@ +"""Actions that are triggered by W&B Automations.""" + +from __future__ import annotations + +from typing import Any, Literal, Optional, Union + +from pydantic import BeforeValidator, Field +from typing_extensions import Annotated, Self, TypeVar, get_args + +from wandb._pydantic import GQLBase, GQLId +from wandb._strutils import nameof + +from ._generated import ( + AlertSeverity, + GenericWebhookActionFields, + GenericWebhookActionInput, + NoOpActionFields, + NoOpTriggeredActionInput, + NotificationActionFields, + NotificationActionInput, + QueueJobActionFields, +) +from ._validators import ( + JsonEncoded, + LenientStrEnum, + default_if_none, + parse_input_action, + parse_saved_action, + upper_if_str, +) +from .integrations import SlackIntegration, WebhookIntegration + +T = TypeVar("T") + + +# NOTE: Name shortened for readability and defined publicly for easier access +class ActionType(LenientStrEnum): + """The type of action 
triggered by an automation.""" + + QUEUE_JOB = "QUEUE_JOB" # NOTE: Deprecated for creation + NOTIFICATION = "NOTIFICATION" + GENERIC_WEBHOOK = "GENERIC_WEBHOOK" + NO_OP = "NO_OP" + + +# ------------------------------------------------------------------------------ +# Saved types: for parsing response data from saved automations + + +# NOTE: `QueueJobActionInput` for defining a Launch job is deprecated, +# so while we allow parsing it from previously saved Automations, we deliberately +# don't currently expose it in the API for creating automations. +class SavedLaunchJobAction(QueueJobActionFields): + action_type: Literal[ActionType.QUEUE_JOB] = ActionType.QUEUE_JOB + + +# FIXME: Find a better place to put these OR a better way to handle the +# conversion from `InputAction` -> `SavedAction`. +# +# Necessary placeholder class defs for converting: +# - `SendNotification -> SavedNotificationAction` +# - `SendWebhook -> SavedWebhookAction` +# +# The "input" types (`Send{Notification,Webhook}`) will only have an `integration_id`, +# and we don't want/need to fetch the other `{Slack,Webhook}Integration` fields if +# we can avoid it. 
+class _SlackIntegrationStub(GQLBase): + typename__: Annotated[ + Literal["SlackIntegration"], + Field(alias="__typename", frozen=True, repr=False), + ] = "SlackIntegration" + id: GQLId + + +class _WebhookIntegrationStub(GQLBase): + typename__: Annotated[ + Literal["GenericWebhookIntegration"], + Field(alias="__typename", frozen=True, repr=False), + ] = "GenericWebhookIntegration" + id: GQLId + + +class SavedNotificationAction(NotificationActionFields, frozen=False): + action_type: Literal[ActionType.NOTIFICATION] = ActionType.NOTIFICATION + integration: _SlackIntegrationStub + + title: Optional[str] + message: Optional[str] + severity: Optional[AlertSeverity] + + +class SavedWebhookAction(GenericWebhookActionFields, frozen=False): + action_type: Literal[ActionType.GENERIC_WEBHOOK] = ActionType.GENERIC_WEBHOOK + integration: _WebhookIntegrationStub + + # We override the type of the `requestPayload` field since the original GraphQL + # schema (and generated class) effectively defines it as a string, when we know + # and need to anticipate the expected structure of the JSON-serialized data. + request_payload: Optional[JsonEncoded[dict[str, Any]]] = None # type: ignore[assignment] + + +class SavedNoOpAction(NoOpActionFields, frozen=False): + action_type: Literal[ActionType.NO_OP] = ActionType.NO_OP + + no_op: Annotated[ + bool, + BeforeValidator(default_if_none), + Field(repr=False, frozen=True), + ] = True + """Placeholder field, only needed to conform to schema requirements. + + There should never be a need to set this field explicitly, as its value is ignored. + """ + + +# for type annotations +SavedAction = Annotated[ + Union[ + SavedLaunchJobAction, + SavedNotificationAction, + SavedWebhookAction, + SavedNoOpAction, + ], + BeforeValidator(parse_saved_action), + Field(discriminator="typename__"), +] +# for runtime type checks +SavedActionTypes: tuple[type, ...] 
= get_args(SavedAction.__origin__) # type: ignore[attr-defined] + + +# ------------------------------------------------------------------------------ +# Input types: for creating or updating automations +class _BaseActionInput(GQLBase): + action_type: Annotated[ActionType, Field(frozen=True)] + """The kind of action to be triggered.""" + + +class SendNotification(_BaseActionInput, NotificationActionInput): + """Defines an automation action that sends a (Slack) notification.""" + + action_type: Literal[ActionType.NOTIFICATION] = ActionType.NOTIFICATION + + integration_id: GQLId + """The ID of the Slack integration that will be used to send the notification.""" + + # Note: Validation aliases preserve continuity with the prior `wandb.alert()` API. + title: str = "" + """The title of the sent notification.""" + + message: Annotated[str, Field(validation_alias="text")] = "" + """The message body of the sent notification.""" + + severity: Annotated[ + AlertSeverity, + BeforeValidator(upper_if_str), # Be helpful by ensuring uppercase strings + Field(validation_alias="level"), + ] = AlertSeverity.INFO + """The severity (`INFO`, `WARN`, `ERROR`) of the sent notification.""" + + @classmethod + def from_integration( + cls, + integration: SlackIntegration, + *, + title: str = "", + text: str = "", + level: AlertSeverity = AlertSeverity.INFO, + ) -> Self: + """Define a notification action that sends to the given (Slack) integration.""" + return cls( + integration_id=integration.id, title=title, message=text, severity=level + ) + + +class SendWebhook(_BaseActionInput, GenericWebhookActionInput): + """Defines an automation action that sends a webhook request.""" + + action_type: Literal[ActionType.GENERIC_WEBHOOK] = ActionType.GENERIC_WEBHOOK + + integration_id: GQLId + """The ID of the webhook integration that will be used to send the request.""" + + # overrides the generated field type to parse/serialize JSON strings + request_payload: Optional[JsonEncoded[dict[str, Any]]] = 
Field( # type: ignore[assignment] + default=None, alias="requestPayload" + ) + """The payload, possibly with template variables, to send in the webhook request.""" + + @classmethod + def from_integration( + cls, + integration: WebhookIntegration, + *, + payload: Optional[JsonEncoded[dict[str, Any]]] = None, + ) -> Self: + """Define a webhook action that sends to the given (webhook) integration.""" + return cls(integration_id=integration.id, request_payload=payload) + + +class DoNothing(_BaseActionInput, NoOpTriggeredActionInput, frozen=True): + """Defines an automation action that intentionally does nothing.""" + + action_type: Literal[ActionType.NO_OP] = ActionType.NO_OP + + no_op: Annotated[bool, BeforeValidator(default_if_none)] = True + """Placeholder field which exists only to satisfy backend schema requirements. + + There should never be a need to set this field explicitly, as its value is ignored. + """ + + +# for type annotations +InputAction = Annotated[ + Union[ + SendNotification, + SendWebhook, + DoNothing, + ], + BeforeValidator(parse_input_action), + Field(discriminator="action_type"), +] +# for runtime type checks +InputActionTypes: tuple[type, ...] 
= get_args(InputAction.__origin__) # type: ignore[attr-defined] + +__all__ = [ + "ActionType", + *(nameof(cls) for cls in InputActionTypes), +] diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/automations.py b/.venv/lib/python3.13/site-packages/wandb/automations/automations.py new file mode 100644 index 0000000000000000000000000000000000000000..faba6ad85b6db820ca490374ae3ab51b280221b8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/automations.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Optional + +from pydantic import Field +from typing_extensions import Annotated + +from wandb._pydantic import GQLId, GQLInput + +from ._generated import TriggerFields +from .actions import InputAction, SavedAction +from .events import InputEvent, SavedEvent +from .scopes import AutomationScope + + +# ------------------------------------------------------------------------------ +# Saved types: for parsing response data from saved automations while allowing +# local editing. +class Automation(TriggerFields, frozen=False): + """A local instance of a saved W&B automation that supports editing.""" + + id: GQLId + + created_at: Annotated[datetime, Field(repr=False, frozen=True, alias="createdAt")] + """The date and time when this automation was created.""" + + updated_at: Annotated[ + Optional[datetime], Field(repr=False, frozen=True, alias="updatedAt") + ] = None + """The date and time when this automation was last updated, if applicable.""" + + name: str + """The name of this automation.""" + + description: Optional[str] + """An optional description of this automation.""" + + enabled: bool + """Whether this automation is enabled. 
class NewAutomation(GQLInput, extra="forbid", validate_default=False):
    """A new automation to be created."""

    # Every declared field below is optional and defaults to `None`.
    # `extra="forbid"` is passed as a class keyword, so unknown keys are
    # expected to be rejected rather than silently stored.

    name: Optional[str] = None
    """The name of this automation."""

    description: Optional[str] = None
    """An optional description of this automation."""

    enabled: Optional[bool] = None
    """Whether this automation is enabled. Only enabled automations will trigger."""

    event: Optional[InputEvent] = None
    """The event that will trigger this automation."""

    # Ensure that the event and its scope are always consistent, if the event is set.
    # `scope` is not stored on this object: it is read from, and written to,
    # `self.event.scope`.
    @property
    def scope(self) -> Optional[AutomationScope]:
        """The scope in which the triggering event must occur."""
        # No event (or a falsy one) means there is no scope to report.
        return self.event.scope if self.event else None

    @scope.setter
    def scope(self, value: AutomationScope) -> None:
        # The scope lives on the event, so it cannot be set before an event
        # exists; raise instead of dropping the value.
        if self.event is None:
            raise ValueError("Cannot set `scope` for an automation with no `event`")
        self.event.scope = value

    action: Optional[InputAction] = None
    """The action that will execute when this automation is triggered."""
class EventType(LenientStrEnum):
    """The type of event that triggers an automation."""

    # Member names are the public, readable identifiers; member values are the
    # strings assigned below. The two differ for some members (see the notes).

    # ---------------------------------------------------------------------------
    # Events triggered by GraphQL mutations
    UPDATE_ARTIFACT_ALIAS = "UPDATE_ARTIFACT_ALIAS"  # NOTE: Avoid in new automations

    CREATE_ARTIFACT = "CREATE_ARTIFACT"
    ADD_ARTIFACT_ALIAS = "ADD_ARTIFACT_ALIAS"
    LINK_ARTIFACT = "LINK_MODEL"
    # Note: "LINK_MODEL" is the (legacy) value expected by the backend, but we
    # name it "LINK_ARTIFACT" here in the public API for clarity and consistency.

    # ---------------------------------------------------------------------------
    # Events triggered by Run conditions
    RUN_METRIC_THRESHOLD = "RUN_METRIC"
    # Note: the value is "RUN_METRIC", without a "_THRESHOLD" suffix, unlike
    # the other run-metric members whose values match their names.
    # NOTE(review): presumably this is the backend's name for it - confirm.
    RUN_METRIC_CHANGE = "RUN_METRIC_CHANGE"
    RUN_STATE = "RUN_STATE"
    RUN_METRIC_ZSCORE = "RUN_METRIC_ZSCORE"
+class _WrappedSavedEventFilter(GQLBase): # from: TriggeringFilterEvent + filter: JsonEncoded[MongoLikeFilter] = And() + + +class _WrappedMetricThresholdFilter(GQLBase): # from: RunMetricFilter + event_type: Annotated[ + Literal[EventType.RUN_METRIC_THRESHOLD], + Field(exclude=True, repr=False), + ] = EventType.RUN_METRIC_THRESHOLD + + threshold_filter: MetricThresholdFilter + + @model_validator(mode="before") + @classmethod + def _nest_inner_filter(cls, v: Any) -> Any: + # Yeah, we've got a lot of nesting due to backend schema constraints. + if pydantic_isinstance(v, MetricThresholdFilter): + return cls(threshold_filter=v) + return v + + +class _WrappedMetricChangeFilter(GQLBase): # from: RunMetricFilter + event_type: Annotated[ + Literal[EventType.RUN_METRIC_CHANGE], + Field(exclude=True, repr=False), + ] = EventType.RUN_METRIC_CHANGE + + change_filter: MetricChangeFilter + + @model_validator(mode="before") + @classmethod + def _nest_inner_filter(cls, v: Any) -> Any: + # Yeah, we've got a lot of nesting due to backend schema constraints. 
# Wrapper that nests a `MetricZScoreFilter` under the `zscore_filter` key and
# tags it with the fixed event type `EventType.RUN_METRIC_ZSCORE`.
class _WrappedMetricZScoreFilter(GQLBase):  # from: RunMetricFilter
    # Constant tag for this wrapper. `exclude=True` and `repr=False` keep it
    # out of serialized output and out of the repr.
    event_type: Annotated[
        Literal[EventType.RUN_METRIC_ZSCORE],
        Field(exclude=True, repr=False),
    ] = EventType.RUN_METRIC_ZSCORE

    # The wrapped z-score filter.
    zscore_filter: MetricZScoreFilter

    @model_validator(mode="before")
    @classmethod
    def _nest_inner_filter(cls, v: Any) -> Any:
        # Runs before field validation. An input that `pydantic_isinstance`
        # recognises as a bare `MetricZScoreFilter` is nested under
        # `zscore_filter`; any other input is returned unchanged for regular
        # validation.
        if pydantic_isinstance(v, MetricZScoreFilter):
            return cls(zscore_filter=v)
        return v
+ if pydantic_isinstance( + v, (MetricThresholdFilter, MetricChangeFilter, MetricZScoreFilter) + ): + return cls(metric=v) + return v + + +class RunStateFilter(GQLBase): # from: TriggeringRunStateEvent + """Represents a filter for triggering events based on changes in run states.""" + + run: Annotated[ + JsonEncoded[MongoLikeFilter], + AfterValidator(wrap_run_event_run_filter), + Field(alias="run_filter"), + ] = And() + """Filters that must match any runs that will trigger this event.""" + + state: Annotated[StateFilter, Field(alias="run_state_filter")] + """Run state condition(s) that must be satisfied for this event to trigger.""" + + @model_validator(mode="before") + @classmethod + def _nest_state_filter(cls, v: Any) -> Any: + # If no run filter is given, automatically nest the state filter and + # let inner validators reshape further as needed. + if pydantic_isinstance(v, StateFilter): + return cls(state=v) + return v + + +class SavedEvent(FilterEventFields): # from: FilterEventTriggeringCondition + """A triggering event from a saved automation.""" + + event_type: Annotated[EventType, Field(frozen=True)] # type: ignore[assignment] + + # We override the type of the `filter` field in order to enforce the expected + # structure for the JSON data when validating and serializing. + filter: JsonEncoded[ + Union[_WrappedSavedEventFilter, RunMetricFilter, RunStateFilter] + ] + """The condition(s) under which this event triggers an automation.""" + + +# ------------------------------------------------------------------------------ +# Input types: for creating or updating automations + + +# Note: The GQL input for `eventFilter` does NOT wrap the filter in an extra +# `filter` key, unlike the `eventFilter` in GQL responses for saved automations. 
class _BaseEventInput(GQLBase):
    event_type: EventType

    scope: AutomationScope
    """The scope of the event."""

    filter: JsonEncoded[Any]

    def then(self, action: InputAction) -> NewAutomation:
        """Pair this event with `action` and return the resulting new automation.

        Raises:
            TypeError: If `action` is not an instance of one of the input or
                saved action types.
        """
        # Imported here rather than at module level: `.automations` imports
        # from this module, so a top-level import would be circular.
        from .automations import NewAutomation

        accepted_types = (InputActionTypes, SavedActionTypes)
        if not isinstance(action, accepted_types):
            raise TypeError(f"Expected a valid action, got: {nameof(type(action))!r}")

        return NewAutomation(event=self, action=action)

    def __rshift__(self, other: InputAction) -> NewAutomation:
        """Support the `event >> action` shorthand for `event.then(action)`."""
        return self.then(other)
+ + Examples: + Define an event that triggers whenever the alias "prod" is assigned to + any artifact in the collection "my-collection": + + ```python + from wandb import Api + from wandb.automations import OnAddArtifactAlias, ArtifactEvent + + api = Api() + collection = api.artifact_collection(name="my-collection", type_name="model") + + event = OnAddArtifactAlias( + scope=collection, + filter=ArtifactEvent.alias.eq("prod"), + ) + ``` + """ + + event_type: Literal[EventType.ADD_ARTIFACT_ALIAS] = EventType.ADD_ARTIFACT_ALIAS + + +class OnCreateArtifact(_BaseMutationEventInput): + """A new artifact is created. + + Examples: + Define an event that triggers when a new artifact is created in the + collection "my-collection": + + ```python + from wandb import Api + from wandb.automations import OnCreateArtifact + + api = Api() + collection = api.artifact_collection(name="my-collection", type_name="model") + + event = OnCreateArtifact(scope=collection) + ``` + """ + + event_type: Literal[EventType.CREATE_ARTIFACT] = EventType.CREATE_ARTIFACT + + scope: ArtifactCollectionScope + """The scope of the event: must be an artifact collection.""" + + +# ------------------------------------------------------------------------------ +# Events that trigger on run conditions +class _BaseRunEventInput(_BaseEventInput): + scope: ProjectScope + """The scope of the event: must be a project.""" + + +class OnRunMetric(_BaseRunEventInput): + """A run metric satisfies a user-defined condition. 
class OnRunState(_BaseRunEventInput):
    """A run state changes.

    Examples:
        Define an event that triggers for any run in project "my-project" when
        its state changes to "finished" (i.e. succeeded) or "failed":

        ```python
        from wandb import Api
        from wandb.automations import OnRunState, RunEvent

        api = Api()
        project = api.project(name="my-project")

        event = OnRunState(
            scope=project,
            filter=RunEvent.state.in_(["finished", "failed"]),
        )
        ```
    """

    # Fixed to RUN_STATE; the default means it need not be passed explicitly.
    event_type: Literal[EventType.RUN_STATE] = EventType.RUN_STATE

    filter: JsonEncoded[RunStateFilter]
    """Run state condition(s) that must be satisfied for this event to trigger."""
+ + state = StateOperand() + + @staticmethod + def metric(name: str) -> MetricVal: + """Define a metric filter condition.""" + return MetricVal(name=name) + + +class ArtifactEvent: + alias = FilterableField() + + +MetricThresholdFilter.model_rebuild() +RunMetricFilter.model_rebuild() +_WrappedSavedEventFilter.model_rebuild() + +OnLinkArtifact.model_rebuild() +OnAddArtifactAlias.model_rebuild() +OnCreateArtifact.model_rebuild() +OnRunMetric.model_rebuild() + +__all__ = [ + "EventType", + *(nameof(cls) for cls in InputEventTypes), + "RunEvent", + "ArtifactEvent", + "MetricThresholdFilter", + "MetricChangeFilter", + "MetricZScoreFilter", +] diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/integrations.py b/.venv/lib/python3.13/site-packages/wandb/automations/integrations.py new file mode 100644 index 0000000000000000000000000000000000000000..025fdeb5eb1c13d04aa76e6eab3ccf3832dc30fa --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/integrations.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from typing import Union + +from pydantic import Field, TypeAdapter +from typing_extensions import Annotated + +from ._generated import SlackIntegrationFields, WebhookIntegrationFields + + +class SlackIntegration(SlackIntegrationFields): + team_name: str + """Slack workspace (not W&B team) where this integration will post messages.""" + + channel_name: str + """Slack channel where this integration will post messages.""" + + +class WebhookIntegration(WebhookIntegrationFields): + name: str + """The name of this webhook integration.""" + + url_endpoint: str + """The URL that this webhook will POST events to.""" + + +Integration = Annotated[ + Union[SlackIntegration, WebhookIntegration], + Field(discriminator="typename__"), +] + +# INTERNAL USE ONLY: For parsing integrations from paginated responses +IntegrationAdapter: TypeAdapter[Integration] = TypeAdapter(Integration) + + +__all__ = [ + "Integration", + "SlackIntegration", + 
"WebhookIntegration", +] diff --git a/.venv/lib/python3.13/site-packages/wandb/automations/scopes.py b/.venv/lib/python3.13/site-packages/wandb/automations/scopes.py new file mode 100644 index 0000000000000000000000000000000000000000..aaf6440129f393549a982ee2d6dde523863b947d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/automations/scopes.py @@ -0,0 +1,77 @@ +"""Scopes in which a W&B Automation can be triggered.""" + +from __future__ import annotations + +from typing import Literal, Union + +from pydantic import BeforeValidator, Field +from typing_extensions import Annotated, TypeAlias, get_args + +from wandb._pydantic import GQLBase + +from ._generated import ( + ArtifactPortfolioScopeFields, + ArtifactSequenceScopeFields, + ProjectScopeFields, +) +from ._validators import LenientStrEnum, parse_scope + + +# NOTE: Re-defined publicly with a more readable name for easier access +class ScopeType(LenientStrEnum): + """The kind of scope that triggers an automation.""" + + PROJECT = "PROJECT" + ARTIFACT_COLLECTION = "ARTIFACT_COLLECTION" + + +class _BaseScope(GQLBase): + scope_type: Annotated[ScopeType, Field(frozen=True)] + + +class _ArtifactSequenceScope(_BaseScope, ArtifactSequenceScopeFields): + """An automation scope defined by a specific `ArtifactSequence`.""" + + scope_type: Literal[ScopeType.ARTIFACT_COLLECTION] = ScopeType.ARTIFACT_COLLECTION + + +class _ArtifactPortfolioScope(_BaseScope, ArtifactPortfolioScopeFields): + """Automation scope defined by an `ArtifactPortfolio` (e.g. a registry collection).""" + + scope_type: Literal[ScopeType.ARTIFACT_COLLECTION] = ScopeType.ARTIFACT_COLLECTION + + +# for type annotations +ArtifactCollectionScope = Annotated[ + Union[_ArtifactSequenceScope, _ArtifactPortfolioScope], + BeforeValidator(parse_scope), + Field(discriminator="typename__"), +] +"""An automation scope defined by a specific `ArtifactCollection`.""" + +# for runtime type checks +ArtifactCollectionScopeTypes: tuple[type, ...] 
= get_args( + ArtifactCollectionScope.__origin__ # type: ignore[attr-defined] +) + + +class ProjectScope(_BaseScope, ProjectScopeFields): + """An automation scope defined by a specific `Project`.""" + + scope_type: Literal[ScopeType.PROJECT] = ScopeType.PROJECT + + +# for type annotations +AutomationScope: TypeAlias = Annotated[ + Union[_ArtifactSequenceScope, _ArtifactPortfolioScope, ProjectScope], + BeforeValidator(parse_scope), + Field(discriminator="typename__"), +] +# for runtime type checks +AutomationScopeTypes: tuple[type, ...] = get_args(AutomationScope.__origin__) # type: ignore[attr-defined] + +__all__ = [ + "ScopeType", + "ArtifactCollectionScope", + "ProjectScope", +] diff --git a/.venv/lib/python3.13/site-packages/wandb/cli/__init__.py b/.venv/lib/python3.13/site-packages/wandb/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/cli/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/cli/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1337c8fa16ab36f8ae4eac469d97ec4cd7152ce9 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/cli/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/cli/__pycache__/beta.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/cli/__pycache__/beta.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad37b506812fd592d4b26121fa0b40f73dcf7f30 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/cli/__pycache__/beta.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/cli/beta.py b/.venv/lib/python3.13/site-packages/wandb/cli/beta.py new file mode 100644 index 0000000000000000000000000000000000000000..40bd9fcacc7015d458d43638abbd6715a884d494 --- /dev/null +++ 
b/.venv/lib/python3.13/site-packages/wandb/cli/beta.py @@ -0,0 +1,172 @@ +"""Beta versions of wandb CLI commands. + +These commands are experimental and may change or be removed in future versions. +""" + +from __future__ import annotations + +import pathlib + +import click + +from wandb.analytics import get_sentry +from wandb.errors import WandbCoreNotAvailableError +from wandb.util import get_core_path + + +@click.group() +def beta(): + """Beta versions of wandb CLI commands. + + These commands may change or even completely break in any release of wandb. + """ + get_sentry().configure_scope(process_context="wandb_beta") + + try: + get_core_path() + except WandbCoreNotAvailableError as e: + get_sentry().exception(f"using `wandb beta`. failed with {e}") + click.secho( + (e), + fg="red", + err=True, + ) + + +@beta.command() +@click.argument("path", nargs=1, type=click.Path(exists=True), required=False) +@click.option( + "--pprof", + default="", + hidden=True, + help="""Run with pprof enabled at a specified address, e.g. --pprof=127.0.0.1:6060. + + If set, serves /debug/pprof/* on this address, e.g. 127.0.0.1:6060/debug/pprof. + """, +) +def leet( + path: str | None = None, + pprof: str = "", +) -> None: + """Launch W&B LEET: the Lightweight Experiment Exploration Tool. + + LEET is a terminal UI for viewing a W&B run specified by an optional PATH. + + PATH can include a .wandb file or a run directory containing a .wandb file. + If PATH is not provided, the command will look for the latest run. + """ + from . import beta_leet + + beta_leet.launch(path, pprof) + + +@beta.command() +@click.argument("paths", type=click.Path(exists=True), nargs=-1) +@click.option( + "--live", + is_flag=True, + default=False, + help="""Sync a run while it's still being logged. + + This may hang if the process generating the run crashes uncleanly. 
+ """, +) +@click.option( + "-e", + "--entity", + default="", + help="An entity override to use for all runs being synced.", +) +@click.option( + "-p", + "--project", + default="", + help="A project override to use for all runs being synced.", +) +@click.option( + "--id", + "run_id", + default="", + help="""A run ID override to use for all runs being synced. + + If setting this and syncing multiple files (with the same entity + and project), the files will be synced in order of start time. + This is intended to work with syncing multiple resumed fragments + of the same run. + """, +) +@click.option( + "--skip-synced/--no-skip-synced", + is_flag=True, + default=True, + help="Skip runs that have already been synced with this command.", +) +@click.option( + "--dry-run", + is_flag=True, + default=False, + help="Print what would happen without uploading anything.", +) +@click.option( + "-v", + "--verbose", + is_flag=True, + default=False, + help="Print more information.", +) +@click.option( + "-n", + default=5, + help="""Max number of runs to sync at a time. + + When syncing multiple files that are part of the same run, + the files are synced sequentially in order of start time + regardless of this setting. This happens for resumed runs + or when using the --id parameter. + """, +) +def sync( + paths: tuple[str, ...], + live: bool, + entity: str, + project: str, + run_id: str, + skip_synced: bool, + dry_run: bool, + verbose: bool, + n: int, +) -> None: + """Upload .wandb files specified by PATHS. + + This is a beta re-implementation of `wandb sync`. + It is not feature complete, not guaranteed to work, and may change + in backward-incompatible ways in any release of wandb. + + PATHS can include .wandb files, run directories containing .wandb files, + and "wandb" directories containing run directories. 
def _wandb_file_path(path: str | None) -> str:
    """Returns absolute path to the .wandb file to display with LEET.

    If `path` is not provided, looks for the latest W&B run.

    Prints an error and exits if a valid path is not found.

    Args:
        path: A .wandb file or a run directory containing one. If empty or
            None, the "latest-run" entry under the configured wandb
            directory is used instead.

    Returns:
        The resolved path of the single .wandb file found, as a string.
    """
    if path:
        wandb_run_path = pathlib.Path(path).resolve()
    else:
        wandb_dir = wandb_setup.singleton().settings.wandb_dir
        wandb_run_path = (pathlib.Path(wandb_dir) / "latest-run").resolve()

    # Already-synced runs are still viewable, so nothing is filtered out here.
    wandb_files = list(_find_wandb_files(wandb_run_path, skip_synced=False))

    # Exactly one file is required; `_fatal` exits the process in both
    # failure cases, so the return below only runs with one file.
    if not wandb_files:
        _fatal(f"Could not find a .wandb file in {wandb_run_path}.")
    elif len(wandb_files) > 1:
        _fatal(f"Found multiple .wandb files in {wandb_run_path}.")

    # `_find_wandb_files` yields `pathlib.Path` objects; convert so the value
    # matches the declared `str` return type.
    return str(wandb_files[0])
wandb.sdk.lib.service.service_connection import ServiceConnection +from wandb.sdk.mailbox.mailbox_handle import MailboxHandle + +_MAX_LIST_LINES = 20 +_POLL_WAIT_SECONDS = 0.1 +_SLEEP = asyncio.sleep # patched in tests + + +def sync( + paths: list[pathlib.Path], + *, + live: bool, + entity: str, + project: str, + run_id: str, + dry_run: bool, + skip_synced: bool, + verbose: bool, + parallelism: int, +) -> None: + """Replay one or more .wandb files. + + Args: + live: Whether to enable 'live' mode, which indefinitely retries reading + incomplete transaction logs. + entity: The entity override for all paths, or an empty string. + project: The project override for all paths, or an empty string. + run_id: The run ID override for all paths, or an empty string. + paths: One or more .wandb files, run directories containing + .wandb files, and wandb directories containing run directories. + dry_run: If true, just prints what it would do and exits. + skip_synced: If true, skips files that have already been synced + as indicated by a .wandb.synced marker file in the same directory. + verbose: Verbose mode for printing more info. + parallelism: Max number of runs to sync at a time. 
def _to_unique_files(
    paths: Iterator[pathlib.Path],
    *,
    verbose: bool,
) -> set[pathlib.Path]:
    """Returns paths with duplicates removed.

    Determines file equality the same way as os.path.samefile(): two paths
    are the same file when their inode and device numbers match.

    Args:
        paths: Candidate paths. Paths that cannot be stat'ed are reported
            with an error message and left out of the result.
        verbose: If true, logs a message for each path that duplicates
            another one.

    Returns:
        One path per distinct file. When several paths refer to the same
        file, the one that sorts first is kept.
    """
    path_by_file_id: dict[tuple[int, int], pathlib.Path] = {}

    # Sort in reverse so that the last path written to the map is
    # alphabetically earliest.
    for path in sorted(paths, reverse=True):
        try:
            file_stat = path.stat()
        except OSError as e:
            term.termerror(f"Failed to stat {path}: {e}")
            continue

        # Named `file_id` rather than `id` to avoid shadowing the builtin.
        file_id = (file_stat.st_ino, file_stat.st_dev)

        if verbose and (other_path := path_by_file_id.get(file_id)):
            term.termlog(f"{path} is the same as {other_path}")

        # Later (alphabetically earlier) paths overwrite earlier entries.
        path_by_file_id[file_id] = path

    return set(path_by_file_id.values())
level=msg.severity) + self._done.set() + + async def _show_progress_until_done(self) -> None: + """Show rate-limited status updates until _done is set.""" + with progress_printer(self._printer, "Syncing...") as progress: + while not await self._rate_limit_check_done(): + handle = await self._service.sync_status(self._id) + response = await handle.wait_async(timeout=None) + + for msg in response.new_messages: + self._printer.display(msg.content, level=msg.severity) + progress.update(list(response.stats)) + + async def _rate_limit_check_done(self) -> bool: + """Wait for rate limit and return whether _done is set.""" + now = time.monotonic() + last_time = self._rate_limit_last_time + self._rate_limit_last_time = now + + if last_time and (time_since_last := now - last_time) < _POLL_WAIT_SECONDS: + await asyncio_compat.race( + _SLEEP(_POLL_WAIT_SECONDS - time_since_last), + self._done.wait(), + ) + + return self._done.is_set() + + +def _find_wandb_files( + path: pathlib.Path, + *, + skip_synced: bool, +) -> Iterator[pathlib.Path]: + """Returns paths to the .wandb files to sync.""" + if skip_synced: + yield from filterfalse(_is_synced, _expand_wandb_files(path)) + else: + yield from _expand_wandb_files(path) + + +def _expand_wandb_files( + path: pathlib.Path, +) -> Iterator[pathlib.Path]: + """Iterate over .wandb files selected by the path.""" + if path.suffix == ".wandb": + yield path + return + + files_in_run_directory = path.glob("*.wandb") + try: + first_file = next(files_in_run_directory) + except StopIteration: + pass + else: + yield first_file + yield from files_in_run_directory + return + + yield from path.glob("*/*.wandb") + + +def _is_synced(path: pathlib.Path) -> bool: + """Returns whether the .wandb file is synced.""" + return path.with_suffix(".wandb.synced").exists() + + +def _print_sorted_paths( + paths: Iterable[pathlib.Path], + verbose: bool, + *, + root: pathlib.Path | None, +) -> None: + """Print file paths, sorting them and truncating the list if 
needed. + + Args: + paths: Paths to print. Must be absolute with symlinks resolved. + verbose: If true, doesn't truncate paths. + root: A root directory for making paths relative. + """ + # Prefer to print paths relative to the current working directory. + formatted_paths: list[str] = [] + for path in paths: + formatted_path = str(path) + + if root: + with contextlib.suppress(ValueError): + formatted_path = str(path.relative_to(root)) + + formatted_paths.append(formatted_path) + + sorted_paths = sorted(formatted_paths) + max_lines = len(sorted_paths) if verbose else _MAX_LIST_LINES + + for i in range(min(len(sorted_paths), max_lines)): + term.termlog(f" {sorted_paths[i]}") + + if len(sorted_paths) > max_lines: + remaining = len(sorted_paths) - max_lines + term.termlog(f" +{remaining:,d} more (pass --verbose to see all)") diff --git a/.venv/lib/python3.13/site-packages/wandb/cli/cli.py b/.venv/lib/python3.13/site-packages/wandb/cli/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..314720a93c89b942f2a58b749f1be46a7e178e7c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/cli/cli.py @@ -0,0 +1,2958 @@ +import asyncio +import datetime +import getpass +import json +import logging +import os +import shlex +import shutil +import subprocess +import sys +import tempfile +import textwrap +import time +import traceback +from functools import wraps +from typing import Any, Dict, Optional, Tuple + +import click +import yaml +from click.exceptions import ClickException + +import wandb +import wandb.errors +import wandb.sdk.verify.verify as wandb_verify +from wandb import Config, Error, env, util, wandb_agent +from wandb.analytics import get_sentry +from wandb.apis import InternalApi, PublicApi +from wandb.apis.public import RunQueue +from wandb.errors.links import url_registry +from wandb.old import core as old_core +from wandb.sdk import wandb_setup, wandb_sweep +from wandb.sdk.artifacts._validators import is_artifact_registry_project +from 
# Send cli logs to wandb/debug-cli.{username}.log by default and fall back to
# the system temp dir when the wandb dir is missing or not writable.
_wandb_dir = old_core.wandb_dir(env.get_dir())
if not os.path.exists(_wandb_dir) or not os.access(_wandb_dir, os.W_OK):
    _wandb_dir = tempfile.gettempdir()

try:
    _username = getpass.getuser()
except (KeyError, OSError):
    # getuser() can fail in restricted environments like chroot jails or
    # docker containers. Older Pythons surface that as KeyError; Python
    # 3.13+ raises OSError instead, so both are handled. Fall back to the
    # numeric user id, or to a fixed placeholder on platforms that do not
    # provide os.getuid().
    _username = str(os.getuid()) if hasattr(os, "getuid") else "unknown"

_wandb_log_path = os.path.join(_wandb_dir, f"debug-cli.{_username}.log")
logger = logging.getLogger("wandb")
def parse_service_config(
    ctx: Optional[click.Context],
    param: Optional[click.Parameter],
    value: Optional[Tuple[str, ...]],
) -> Dict[str, str]:
    """Turn 'serviceName=policy' option values into a name-to-policy dict.

    Args:
        ctx: Unused by this function.
        param: Unused by this function.
        value: The raw option values, or nothing at all.

    Returns:
        A mapping from stripped service name to stripped policy. Empty
        when no values were given. A later entry for the same service
        replaces an earlier one.

    Raises:
        click.BadParameter: If an entry has no '=', has an empty service
            name, or names a policy other than 'always' or 'never'.
    """
    policies: Dict[str, str] = {}

    for entry in value or ():
        # Split on the first '=' only; the rest belongs to the policy.
        raw_name, separator, raw_policy = entry.partition("=")
        if not separator:
            raise click.BadParameter(
                f"Service must be in format 'serviceName=policy', got '{entry}'"
            )

        name = raw_name.strip()
        if not name:
            raise click.BadParameter("Service name cannot be empty")

        policy = raw_policy.strip()
        if policy not in ("always", "never"):
            raise click.BadParameter(
                f"Policy must be 'always' or 'never', got '{policy}'"
            )

        policies[name] = policy

    return policies
def prompt_for_project(ctx, entity):
    """Ask the user for a project, creating one if necessary.

    Args:
        ctx: The click context, used to invoke the `projects` command.
        entity: The entity whose projects are listed and under which a
            new project is upserted.

    Returns:
        The chosen project name, or the "name" field of the upserted
        project when a new one is created.

    Raises:
        ClickException: If a wandb.errors.CommError occurs while prompting
            or upserting. The original error is attached as the cause.
    """
    create_new = "Create New"
    existing = ctx.invoke(projects, entity=entity, display=False)
    api = _get_cling_api()
    try:
        if len(existing) == 0:
            name = click.prompt("Enter a name for your first project")
            project = api.upsert_project(name, entity=entity)["name"]
        else:
            choices = [item["name"] for item in existing] + [create_new]
            wandb.termlog("Which project should we use?")
            # An empty selection is treated as a request to create a project.
            project = util.prompt_choices(choices) or create_new
            # TODO: check with the server if the project exists
            if project == create_new:
                name = click.prompt(
                    "Enter a name for your new project", value_proc=api.format_project
                )
                project = api.upsert_project(name, entity=entity)["name"]
    except wandb.errors.CommError as e:
        # Chain explicitly so the communication error stays attached.
        raise ClickException(str(e)) from e

    return project
@cli.command(context_settings=CONTEXT, help="List projects", hidden=True)
@click.option(
    "--entity",
    "-e",
    default=None,
    envvar=env.ENTITY,
    help="The entity to scope the listing to.",
)
@display_error
def projects(entity, display=True):
    """Fetch the projects for an entity and optionally print them.

    Args:
        entity: The entity passed to the API's project listing.
        display: If true, echo a header plus one line per project with
            its name and the first line of its description.

    Returns:
        Whatever the API's list_projects call returned.
    """
    found = _get_cling_api().list_projects(entity=entity)

    header = (
        f"No projects found for {entity}"
        if len(found) == 0
        else f'Latest projects for "{entity}"'
    )

    if display:
        click.echo(click.style(header, bold=True))
        for item in found:
            styled_name = click.style(item["name"], fg="blue", bold=True)
            # Only the first line of a (possibly missing) description is shown.
            first_line = str(item["description"] or "").split("\n")[0]
            click.echo("".join((styled_name, " - ", first_line)))

    return found
@cli.command(
    context_settings=CONTEXT, help="Configure a directory with Weights & Biases"
)
@click.option("--project", "-p", help="The project to use.")
@click.option("--entity", "-e", help="The entity to scope the project to.")
# TODO(jhr): Enable these with settings rework
# @click.option("--setting", "-s", help="enable an arbitrary setting.", multiple=True)
# @click.option('--show', is_flag=True, help="Show settings")
@click.option("--reset", is_flag=True, help="Reset settings")
@click.option(
    "--mode",
    "-m",
    help=' Can be "online", "offline" or "disabled". Defaults to online.',
)
@click.pass_context
@display_error
def init(ctx, project, entity, reset, mode):
    """Configure the current directory for W&B.

    When any of `reset`, `project`, `entity` or `mode` is given, the
    system settings are updated and saved without prompting. Otherwise
    the command runs interactively: it makes sure a viewer can be
    fetched (invoking `login` if needed), asks for a team and a project,
    saves them, and writes a .gitignore into the wandb directory.

    Args:
        ctx: The click context, used to invoke `login` and `projects`.
        project: Project name to store, if any.
        entity: Entity name to store, if any.
        reset: If true, clear the stored entity, project and mode first.
        mode: Mode value to store, if any.

    Raises:
        ClickException: If prompting for a project fails with a
            ClickWandbException.
    """
    # Load settings from environment variables and other normal sources.
    global_settings = wandb_setup.singleton().settings

    # non-interactive init
    if reset or project or entity or mode:
        system_settings = global_settings.read_system_settings()

        if reset:
            system_settings.clear("entity")
            system_settings.clear("project")
            system_settings.clear("mode")
        if entity:
            system_settings.set("entity", entity)
        if project:
            system_settings.set("project", project)
        if mode:
            system_settings.set("mode", mode)

        system_settings.save()
        return

    if os.path.exists(global_settings.settings_workspace):
        click.confirm(
            click.style(
                "This directory has been configured previously, should we re-configure it?",
                bold=True,
            ),
            abort=True,
        )
    else:
        click.echo(
            click.style("Let's setup this directory for W&B!", fg="green", bold=True)
        )
    api = _get_cling_api()
    if api.api_key is None:
        ctx.invoke(login)
        api = _get_cling_api(reset=True)

    viewer = api.viewer()

    # Viewer can be `None` in case your API information became invalid, or
    # in testing if you switch hosts.
    if not viewer:
        click.echo(
            click.style(
                "Your login information seems to be invalid: can you log in again please?",
                fg="red",
                bold=True,
            )
        )
        ctx.invoke(login)
        api = _get_cling_api(reset=True)

    # This shouldn't happen.
    viewer = api.viewer()
    if not viewer:
        click.echo(
            click.style(
                "We're sorry, there was a problem logging you in. "
                "Please send us a note at support@wandb.com and tell us how this happened.",
                fg="red",
                bold=True,
            )
        )
        sys.exit(1)

    # At this point we should be logged in successfully.
    if len(viewer["teams"]["edges"]) > 1:
        # Use one constant for both the menu label and the comparison
        # below. They previously differed in case ("Manual entry" vs
        # "Manual Entry"), so picking the menu item never triggered the
        # manual prompt.
        manual_choice = "Manual entry"
        team_names = [e["node"]["name"] for e in viewer["teams"]["edges"]] + [
            manual_choice
        ]
        wandb.termlog(
            "Which team should we use?",
        )
        # result can be empty on click; treat that as manual entry too.
        entity = util.prompt_choices(team_names) or manual_choice
        if entity == manual_choice:
            entity = click.prompt("Enter the name of the team you want to use")
    else:
        entity = viewer.get("entity") or click.prompt(
            "What username or team should we use?"
        )

    # TODO: this error handling sucks and the output isn't pretty
    try:
        project = prompt_for_project(ctx, entity)
    except ClickWandbException as e:
        raise ClickException(f"Could not find team: {entity}") from e

    system_settings = global_settings.read_system_settings()
    system_settings.set("entity", entity)
    system_settings.set("project", project)
    system_settings.save()

    filesystem.mkdir_exists_ok(global_settings.wandb_dir)
    # Ignore everything in the wandb directory except the settings file.
    with open(os.path.join(global_settings.wandb_dir, ".gitignore"), "w") as file:
        file.write("*\n!settings")

    click.echo(
        click.style("This directory is configured! Next, track a run:\n", fg="green")
        + textwrap.dedent(
            """\
            * In your training script:
                {code1}
                {code2}
            * then `{run}`.
            """
        ).format(
            code1=click.style("import wandb", bold=True),
            code2=click.style(f'wandb.init(project="{project}")', bold=True),
            run=click.style("python ", bold=True),
        )
    )
To be used alongside --clean flag.", + type=int, +) +@click.option( + "--clean-force", + is_flag=True, + default=False, + help="Clean without confirmation prompt.", +) +@click.option("--ignore", hidden=True) +@click.option("--show", default=5, help="Number of runs to show") +@click.option("--append", is_flag=True, default=False, help="Append run") +@click.option("--skip-console", is_flag=True, default=False, help="Skip console logs") +@click.option( + "--replace-tags", + help="Replace tags in the format 'old_tag1=new_tag1,old_tag2=new_tag2'", +) +@display_error +def sync( + ctx, + path=None, + view=None, + verbose=None, + run_id=None, + project=None, + entity=None, + job_type=None, # trace this back to SyncManager + sync_tensorboard=None, + include_globs=None, + exclude_globs=None, + include_online=None, + include_offline=None, + include_synced=None, + mark_synced=None, + sync_all=None, + ignore=None, + show=None, + clean=None, + clean_old_hours=24, + clean_force=None, + append=None, + skip_console=None, + replace_tags=None, +): + """Synchronize W&B run data to the cloud. + + If PATH is provided, sync runs found at the given path. If a path + is not specified, search for `./wandb` first, then search for a + `wandb/` subdirectory. 
+ + To sync a specific run: + + wandb sync ./wandb/run-20250813_124246-n67z9ude + + Or equivalently: + + wandb sync ./wandb/run-20250813_124246-n67z9ude/run-n67z9ude.wandb + """ + api = _get_cling_api() + if not api.is_authenticated: + wandb.termlog("Login to W&B to sync runs") + ctx.invoke(login, no_offline=True) + api = _get_cling_api(reset=True) + + if ignore: + exclude_globs = ignore + if include_globs: + include_globs = include_globs.split(",") + if exclude_globs: + exclude_globs = exclude_globs.split(",") + + replace_tags_dict = _parse_sync_replace_tags(replace_tags) + if replace_tags and replace_tags_dict is None: + return # Error already printed by helper function + + def _summary(): + all_items = get_runs( + include_online=True, + include_offline=True, + include_synced=True, + include_unsynced=True, + ) + sync_items = get_runs( + include_online=include_online if include_online is not None else True, + include_offline=include_offline if include_offline is not None else True, + include_synced=include_synced if include_synced is not None else False, + include_unsynced=True, + exclude_globs=exclude_globs, + include_globs=include_globs, + ) + synced = [] + unsynced = [] + for item in all_items: + (synced if item.synced else unsynced).append(item) + if sync_items: + wandb.termlog(f"Number of runs to be synced: {len(sync_items)}") + if show and show < len(sync_items): + wandb.termlog(f"Showing {show} runs to be synced:") + for item in sync_items[: (show or len(sync_items))]: + wandb.termlog(f" {item}") + else: + wandb.termlog("No runs to be synced.") + if synced: + clean_cmd = click.style("wandb sync --clean", fg="yellow") + wandb.termlog( + f"NOTE: use {clean_cmd} to delete {len(synced)} synced runs from local directory." + ) + if unsynced: + sync_cmd = click.style("wandb sync --sync-all", fg="yellow") + wandb.termlog( + f"NOTE: use {sync_cmd} to sync {len(unsynced)} unsynced runs from local directory." 
+ ) + + def _sync_path(_path, _sync_tensorboard): + if run_id and len(_path) > 1: + wandb.termerror("id can only be set for a single run.") + sys.exit(1) + sm = SyncManager( + project=project, + entity=entity, + run_id=run_id, + job_type=job_type, + mark_synced=mark_synced, + app_url=api.app_url, + view=view, + verbose=verbose, + sync_tensorboard=_sync_tensorboard, + log_path=_wandb_log_path, + append=append, + skip_console=skip_console, + replace_tags=replace_tags_dict, + ) + for p in _path: + sm.add(p) + sm.start() + while not sm.is_done(): + _ = sm.poll() + + def _sync_all(): + sync_items = get_runs( + include_online=include_online if include_online is not None else True, + include_offline=include_offline if include_offline is not None else True, + include_synced=include_synced if include_synced is not None else False, + include_unsynced=True, + exclude_globs=exclude_globs, + include_globs=include_globs, + ) + if not sync_items: + wandb.termerror("Nothing to sync.") + else: + # When syncing run directories, default to not syncing tensorboard + sync_tb = sync_tensorboard if sync_tensorboard is not None else False + _sync_path(sync_items, sync_tb) + + def _clean(): + if path: + runs = list(map(get_run_from_path, path)) + if not clean_force: + click.confirm( + click.style( + f"Are you sure you want to remove {len(runs)} runs?", + bold=True, + ), + abort=True, + ) + for run in runs: + shutil.rmtree(run.path) + click.echo(click.style("Success!", fg="green")) + return + runs = get_runs( + include_online=include_online if include_online is not None else True, + include_offline=include_offline if include_offline is not None else True, + include_synced=include_synced if include_synced is not None else True, + include_unsynced=False, + exclude_globs=exclude_globs, + include_globs=include_globs, + ) + since = datetime.datetime.now() - datetime.timedelta(hours=clean_old_hours) + old_runs = [run for run in runs if run.datetime < since] + old_runs.sort(key=lambda _run: 
_run.datetime) + if old_runs: + click.echo( + f"Found {len(runs)} runs, {len(old_runs)} are older than {clean_old_hours} hours" + ) + for run in old_runs: + click.echo(run.path) + if not clean_force: + click.confirm( + click.style( + f"Are you sure you want to remove {len(old_runs)} runs?", + bold=True, + ), + abort=True, + ) + for run in old_runs: + shutil.rmtree(run.path) + click.echo(click.style("Success!", fg="green")) + else: + click.echo( + click.style( + f"No runs older than {clean_old_hours} hours found", fg="red" + ) + ) + + if sync_all: + _sync_all() + elif clean: + _clean() + elif path: + # When syncing a specific path, default to syncing tensorboard + sync_tb = sync_tensorboard if sync_tensorboard is not None else True + _sync_path(path, sync_tb) + else: + _summary() + + +def _parse_sync_replace_tags(replace_tags: str) -> Optional[Dict[str, str]]: + """Parse replace_tags string into a dictionary. + + Args: + replace_tags: String in format 'old_tag1=new_tag1,old_tag2=new_tag2' + + Returns: + Mapping of old tags to new tags, or None if format is invalid + """ + if not replace_tags: + return {} + + replace_tags_dict = {} + for pair in replace_tags.split(","): + if "=" not in pair: + wandb.termerror( + f"Invalid replace-tags format: {pair}. Use 'old_tag=new_tag' format." + ) + return None + old_tag, new_tag = pair.split("=", 1) + replace_tags_dict[old_tag.strip()] = new_tag.strip() + + return replace_tags_dict + + +@cli.command( + context_settings=CONTEXT, + help="Initialize a hyperparameter sweep. Search for hyperparameters that optimizes a cost function of a machine learning model by testing various combinations.", +) +@click.option( + "--project", + "-p", + default=None, + help="""The name of the project where W&B runs created from the sweep are sent to. 
If the project is not specified, the run is sent to a project labeled Uncategorized.""", +) +@click.option( + "--entity", + "-e", + default=None, + help="""The username or team name where you want to send W&B runs created by the sweep to. Ensure that the entity you specify already exists. If you don't specify an entity, the run will be sent to your default entity, which is usually your username.""", +) +@click.option("--controller", is_flag=True, default=False, help="Run local controller") +@click.option("--verbose", is_flag=True, default=False, help="Display verbose output") +@click.option( + "--name", + default=None, + help="The name of the sweep. The sweep ID is used if no name is specified.", +) +@click.option("--program", default=None, help="Set sweep program") +@click.option("--settings", default=None, help="Set sweep settings", hidden=True) +@click.option("--update", default=None, help="Update pending sweep") +@click.option( + "--stop", + is_flag=True, + default=False, + help="Finish a sweep to stop running new runs and let currently running runs finish.", +) +@click.option( + "--cancel", + is_flag=True, + default=False, + help="Cancel a sweep to kill all running runs and stop running new runs.", +) +@click.option( + "--pause", + is_flag=True, + default=False, + help="Pause a sweep to temporarily stop running new runs.", +) +@click.option( + "--resume", + is_flag=True, + default=False, + help="Resume a sweep to continue running new runs.", +) +@click.option( + "--prior_run", + "-R", + "prior_runs", + multiple=True, + default=None, + help="ID of an existing run to add to this sweep", +) +@click.argument("config_yaml_or_sweep_id") +@click.pass_context +@display_error +def sweep( + ctx, + project, + entity, + controller, + verbose, + name, + program, + settings, + update, + stop, + cancel, + pause, + resume, + prior_runs, + config_yaml_or_sweep_id, +): + state_args = "stop", "cancel", "pause", "resume" + lcls = locals() + is_state_change_command = sum(lcls[k] 
for k in state_args) + if is_state_change_command > 1: + raise Exception("Only one state flag (stop/cancel/pause/resume) is allowed.") + elif is_state_change_command == 1: + sweep_id = config_yaml_or_sweep_id + api = _get_cling_api() + if not api.is_authenticated: + wandb.termlog("Login to W&B to use the sweep feature") + ctx.invoke(login, no_offline=True) + api = _get_cling_api(reset=True) + parts = dict(entity=entity, project=project, name=sweep_id) + err = sweep_utils.parse_sweep_id(parts) + if err: + wandb.termerror(err) + return + entity = parts.get("entity") or entity + project = parts.get("project") or project + sweep_id = parts.get("name") or sweep_id + state = [s for s in state_args if lcls[s]][0] + ings = { + "stop": "Stopping", + "cancel": "Cancelling", + "pause": "Pausing", + "resume": "Resuming", + } + wandb.termlog(f"{ings[state]} sweep {entity}/{project}/{sweep_id}") + getattr(api, f"{state}_sweep")(sweep_id, entity=entity, project=project) + wandb.termlog("Done.") + return + else: + config_yaml = config_yaml_or_sweep_id + + def _parse_settings(settings): + """Parse settings from json or comma separated assignments.""" + ret = {} + # TODO(jhr): merge with magic:_parse_magic + if settings.find("=") > 0: + for item in settings.split(","): + kv = item.split("=") + if len(kv) != 2: + wandb.termwarn( + "Unable to parse sweep settings key value pair", repeat=False + ) + ret.update(dict([kv])) + return ret + wandb.termwarn("Unable to parse settings parameter", repeat=False) + return ret + + api = _get_cling_api() + if not api.is_authenticated: + wandb.termlog("Login to W&B to use the sweep feature") + ctx.invoke(login, no_offline=True) + api = _get_cling_api(reset=True) + + sweep_obj_id = None + if update: + parts = dict(entity=entity, project=project, name=update) + err = sweep_utils.parse_sweep_id(parts) + if err: + wandb.termerror(err) + return + entity = parts.get("entity") or entity + project = parts.get("project") or project + sweep_id = 
parts.get("name") or update + + has_project = (project or api.settings("project")) is not None + has_entity = (entity or api.settings("entity")) is not None + + termerror_msg = ( + "Sweep lookup requires a valid %s, and none was specified. \n" + "Either set a default %s in wandb/settings, or, if invoking \n`wandb sweep` " + "from the command line, specify the full sweep path via: \n\n" + " wandb sweep {username}/{projectname}/{sweepid}\n\n" + ) + + if not has_entity: + wandb.termerror(termerror_msg % (("entity",) * 2)) + return + + if not has_project: + wandb.termerror(termerror_msg % (("project",) * 2)) + return + + found = api.sweep(sweep_id, "{}", entity=entity, project=project) + if not found: + wandb.termerror(f"Could not find sweep {entity}/{project}/{sweep_id}") + return + sweep_obj_id = found["id"] + + action = "Updating" if sweep_obj_id else "Creating" + wandb.termlog(f"{action} sweep from: {config_yaml}") + config = sweep_utils.load_sweep_config(config_yaml) + + # Set or override parameters + if name: + config["name"] = name + if program: + config["program"] = program + if settings: + settings = _parse_settings(settings) + if settings: + config.setdefault("settings", {}) + config["settings"].update(settings) + if controller: + config.setdefault("controller", {}) + config["controller"]["type"] = "local" + + is_local = config.get("controller", {}).get("type") == "local" + if is_local: + from wandb import controller as wandb_controller + + tuner = wandb_controller() + err = tuner._validate(config) + if err: + wandb.termerror(f"Error in sweep file: {err}") + return + + env = os.environ + entity = ( + entity + or env.get("WANDB_ENTITY") + or config.get("entity") + or api.settings("entity") + ) + project = ( + project + or env.get("WANDB_PROJECT") + or config.get("project") + or api.settings("project") + or util.auto_project_name(config.get("program")) + ) + + sweep_id, warnings = api.upsert_sweep( + config, + project=project, + entity=entity, + 
obj_id=sweep_obj_id, + prior_runs=prior_runs, + ) + sweep_utils.handle_sweep_config_violations(warnings) + + # Log nicely formatted sweep information + styled_id = click.style(sweep_id, fg="yellow") + wandb.termlog(f"{action} sweep with ID: {styled_id}") + + sweep_url = wandb_sweep._get_sweep_url(api, sweep_id) + if sweep_url: + styled_url = click.style(sweep_url, underline=True, fg="blue") + wandb.termlog(f"View sweep at: {styled_url}") + + # re-probe entity and project if it was auto-detected by upsert_sweep + entity = entity or env.get("WANDB_ENTITY") + project = project or env.get("WANDB_PROJECT") + + if entity and project: + sweep_path = f"{entity}/{project}/{sweep_id}" + elif project: + sweep_path = f"{project}/{sweep_id}" + else: + sweep_path = sweep_id + + if sweep_path.find(" ") >= 0: + sweep_path = f"{sweep_path!r}" + + styled_path = click.style(f"wandb agent {sweep_path}", fg="yellow") + wandb.termlog(f"Run sweep agent with: {styled_path}") + if controller: + wandb.termlog("Starting wandb controller...") + from wandb import controller as wandb_controller + + tuner = wandb_controller(sweep_id) + tuner.run(verbose=verbose) + + +@cli.command( + context_settings=CONTEXT, + no_args_is_help=True, + help="Run a W&B launch sweep (Experimental).", +) +@click.option( + "--queue", + "-q", + default=None, + help="The name of a queue to push the sweep to", +) +@click.option( + "--project", + "-p", + default=None, + help="Name of the project which the agent will watch. " + "If passed in, will override the project value passed in using a config file", +) +@click.option( + "--entity", + "-e", + default=None, + help="The entity to use. Defaults to current logged-in user", +) +@click.option( + "--resume_id", + "-r", + default=None, + help="Resume a launch sweep by passing an 8-char sweep id. 
Queue required", +) +@click.option( + "--prior_run", + "-R", + "prior_runs", + multiple=True, + default=None, + help="ID of an existing run to add to this sweep", +) +@click.argument("config", required=False, type=click.Path(exists=True)) +@click.pass_context +@display_error +def launch_sweep( + ctx, + project, + entity, + queue, + config, + resume_id, + prior_runs, +): + api = _get_cling_api() + env = os.environ + if not api.is_authenticated: + wandb.termlog("Login to W&B to use the sweep feature") + ctx.invoke(login, no_offline=True) + api = _get_cling_api(reset=True) + + entity = entity or env.get("WANDB_ENTITY") or api.settings("entity") + if entity is None: + wandb.termerror("Must specify entity when using launch") + return + + project = project or env.get("WANDB_PROJECT") or api.settings("project") + if project is None: + wandb.termerror("A project must be configured when using launch") + return + + # get personal username, not team name or service account, default to entity + author = api.viewer().get("username") or entity + + # if not sweep_config XOR resume_id + if not (config or resume_id): + wandb.termerror("'config' and/or 'resume_id' required") + return + + parsed_user_config = sweep_utils.load_launch_sweep_config(config) + # Rip special keys out of config, store in scheduler run_config + launch_args: Dict[str, Any] = parsed_user_config.pop("launch", {}) + scheduler_args: Dict[str, Any] = parsed_user_config.pop("scheduler", {}) + settings: Dict[str, Any] = scheduler_args.pop("settings", {}) + + scheduler_job: Optional[str] = scheduler_args.get("job") + if scheduler_job: + wandb.termwarn( + "Using a scheduler job for launch sweeps is *experimental* and may change without warning" + ) + queue: Optional[str] = queue or launch_args.get("queue") + + sweep_config, sweep_obj_id = None, None + if not resume_id: + sweep_config = parsed_user_config + + # check method + method = sweep_config.get("method") + if scheduler_job and not method: + 
sweep_config["method"] = "custom" + elif scheduler_job and method != "custom": + # TODO(gst): Check if using Anaconda2 + wandb.termwarn( + "Use 'method': 'custom' in the sweep config when using scheduler jobs, " + "or omit it entirely. For jobs using the wandb optimization engine (WandbScheduler), " + "set the method in the sweep config under scheduler.settings.method " + ) + settings["method"] = method + + if settings.get("method"): + # assume WandbScheduler, and user is using this right + sweep_config["method"] = settings["method"] + + else: # Resuming an existing sweep + found = api.sweep(resume_id, "{}", entity=entity, project=project) + if not found: + wandb.termerror(f"Could not find sweep {entity}/{project}/{resume_id}") + return + + if found.get("state") == "RUNNING": + wandb.termerror( + f"Cannot resume sweep {entity}/{project}/{resume_id}, it is already running" + ) + return + + sweep_obj_id = found["id"] + sweep_config = yaml.safe_load(found["config"]) + wandb.termlog(f"Resuming from existing sweep {entity}/{project}/{resume_id}") + if len(parsed_user_config.keys()) > 0: + wandb.termwarn( + "Sweep parameters loaded from resumed sweep, ignoring provided config" + ) + + prev_scheduler = json.loads(found.get("scheduler") or "{}") + run_spec = json.loads(prev_scheduler.get("run_spec", "{}")) + if ( + scheduler_job + and run_spec.get("job") + and run_spec.get("job") != scheduler_job + ): + wandb.termerror( + f"Resuming a launch sweep with a different scheduler job is not supported. 
Job loaded from sweep: {run_spec.get('job')}, job in config: {scheduler_job}" + ) + return + + prev_scheduler_args, prev_settings = sweep_utils.get_previous_args(run_spec) + # Passed in scheduler_args and settings override previous + scheduler_args.update(prev_scheduler_args) + settings.update(prev_settings) + if not queue: + wandb.termerror( + "Launch-sweeps require setting a 'queue', use --queue option or a 'queue' key in the 'launch' section in the config" + ) + return + + entrypoint = Scheduler.ENTRYPOINT if not scheduler_job else None + args = sweep_utils.construct_scheduler_args( + return_job=scheduler_job is not None, + sweep_config=sweep_config, + queue=queue, + project=project, + author=author, + ) + if not args: + return + + # validate training job existence + if not sweep_utils.check_job_exists(PublicApi(), sweep_config.get("job")): + return False + + # validate scheduler job existence, if present + if not sweep_utils.check_job_exists(PublicApi(), scheduler_job): + return False + + # Set run overrides for the Scheduler + overrides = {"run_config": {}} + if launch_args: + overrides["run_config"]["launch"] = launch_args + if scheduler_args: + overrides["run_config"]["scheduler"] = scheduler_args + if settings: + overrides["run_config"]["settings"] = settings + + if scheduler_job: + overrides["run_config"]["sweep_args"] = args + else: + overrides["args"] = args + + # configure scheduler job resource + resource = scheduler_args.get("resource") + if resource: + if resource == "local-process" and scheduler_job: + wandb.termerror( + "Scheduler jobs cannot be run with the 'local-process' resource" + ) + return + if resource == "local-process" and scheduler_args.get("docker_image"): + wandb.termerror( + "Scheduler jobs cannot be run with the 'local-process' resource and a docker image" + ) + return + else: # no resource set, default local-process if not scheduler job, else container + resource = "local-process" if not scheduler_job else "local-container" + + # 
Launch job spec for the Scheduler + launch_scheduler_spec = launch_utils.construct_launch_spec( + uri=Scheduler.PLACEHOLDER_URI, + api=api, + name="Scheduler.WANDB_SWEEP_ID", + project=project, + entity=entity, + docker_image=scheduler_args.get("docker_image"), + resource=resource, + entry_point=entrypoint, + resource_args=scheduler_args.get("resource_args", {}), + repository=launch_args.get("registry", {}).get("url", None), + job=scheduler_job, + version=None, + launch_config={"overrides": overrides}, + run_id="WANDB_SWEEP_ID", # scheduler inits run with sweep_id=run_id + author=None, # author gets passed into scheduler override args + ) + launch_scheduler_with_queue = json.dumps( + { + "queue": queue, + "run_queue_project": launch_utils.LAUNCH_DEFAULT_PROJECT, + "run_spec": json.dumps(launch_scheduler_spec), + } + ) + + sweep_id, warnings = api.upsert_sweep( + sweep_config, + project=project, + entity=entity, + obj_id=sweep_obj_id, # if resuming + launch_scheduler=launch_scheduler_with_queue, + state="PENDING", + prior_runs=prior_runs, + template_variable_values=scheduler_args.get("template_variables", None), + ) + sweep_utils.handle_sweep_config_violations(warnings) + # Log nicely formatted sweep information + styled_id = click.style(sweep_id, fg="yellow") + wandb.termlog(f"{'Resumed' if resume_id else 'Created'} sweep with ID: {styled_id}") + sweep_url = wandb_sweep._get_sweep_url(api, sweep_id) + if sweep_url: + styled_url = click.style(sweep_url, underline=True, fg="blue") + wandb.termlog(f"View sweep at: {styled_url}") + wandb.termlog(f"Scheduler added to launch queue ({queue})") + + +@cli.command(help=f"Launch or queue a W&B Job. See {url_registry.url('wandb-launch')}") +@click.option( + "--uri", + "-u", + metavar="(str)", + default=None, + help="Local path or git repo uri to launch. 
If provided this command will " + "create a job from the specified uri.", +) +@click.option( + "--job", + "-j", + metavar="(str)", + default=None, + help="Name of the job to launch. If passed in, launch does not require a uri.", +) +@click.option( + "--entry-point", + "-E", + metavar="NAME", + default=None, + help="""Entry point within project. [default: main]. If the entry point is not found, + attempts to run the project file with the specified name as a script, + using 'python' to run .py files and the default shell (specified by + environment variable $SHELL) to run .sh files. If passed in, will override the entrypoint value passed in using a config file.""", +) +@click.option( + "--git-version", + "-g", + metavar="GIT-VERSION", + hidden=True, + help="Version of the project to run, as a Git commit reference for Git projects.", +) +@click.option( + "--build-context", + metavar="(str)", + help="Path to the build context within the source code. Defaults to the " + "root of the source code. Compatible only with -u.", +) +@click.option( + "--job-name", + "-J", + metavar="(str)", + default=None, + hidden=True, + help="Name for the job created if the -u,--uri flag is passed in.", +) +@click.option( + "--name", + envvar="WANDB_NAME", + help="""Name of the run under which to launch the run. If not + specified, a random run name will be used to launch run. If passed in, will override the name passed in using a config file.""", +) +@click.option( + "--entity", + "-e", + metavar="(str)", + default=None, + help="""Name of the target entity which the new run will be sent to. Defaults to using the entity set by local wandb/settings folder. + If passed in, will override the entity value passed in using a config file.""", +) +@click.option( + "--project", + "-p", + metavar="(str)", + default=None, + help="""Name of the target project which the new run will be sent to. Defaults to using the project name given by the source uri + or for github runs, the git repo name. 
If passed in, will override the project value passed in using a config file.""", +) +@click.option( + "--resource", + "-r", + metavar="BACKEND", + default=None, + help="""Execution resource to use for run. Supported values: 'local-process', 'local-container', 'kubernetes', 'sagemaker', 'gcp-vertex'. + This is now a required parameter if pushing to a queue with no resource configuration. + If passed in, will override the resource value passed in using a config file.""", +) +@click.option( + "--docker-image", + "-d", + default=None, + metavar="DOCKER IMAGE", + help="""Specific docker image you'd like to use. In the form name:tag. + If passed in, will override the docker image value passed in using a config file.""", +) +@click.option( + "--base-image", + "-B", + default=None, + metavar="BASE IMAGE", + help="""Docker image to run job code in. Incompatible with --docker-image.""", +) +@click.option( + "--config", + "-c", + metavar="FILE", + help="""Path to JSON file (must end in '.json') or JSON string which will be passed + as a launch config. Dictation how the launched run will be configured.""", +) +@click.option( + "--set-var", + "-v", + "cli_template_vars", + default=None, + multiple=True, + help="""Set template variable values for queues with allow listing enabled, + as key-value pairs e.g. `--set-var key1=value1 --set-var key2=value2`""", +) +@click.option( + "--queue", + "-q", + is_flag=False, + flag_value="default", + default=None, + help="""Name of run queue to push to. If none, launches single run directly. If supplied without + an argument (`--queue`), defaults to queue 'default'. Else, if name supplied, specified run queue must exist under the + project and entity supplied.""", +) +@click.option( + "--async", + "run_async", + is_flag=True, + help="""Flag to run the job asynchronously. Defaults to false, i.e. unless --async is set, wandb launch will wait for + the job to finish. 
This option is incompatible with --queue; asynchronous options when running with an agent should be + set on wandb launch-agent.""", +) +@click.option( + "--resource-args", + "-R", + metavar="FILE", + help="""Path to JSON file (must end in '.json') or JSON string which will be passed + as resource args to the compute resource. The exact content which should be + provided is different for each execution backend. See documentation for layout of this file.""", +) +@click.option( + "--build", + "-b", + is_flag=True, + hidden=True, + help="Flag to build an associated job and push to queue as an image job.", +) +@click.option( + "--repository", + "-rg", + is_flag=False, + default=None, + hidden=True, + help="Name of a remote repository. Will be used to push a built image to.", +) +# TODO: this is only included for back compat. But we should remove this in the future +@click.option( + "--project-queue", + "-pq", + default=None, + hidden=True, + help="Name of the project containing the queue to push to. If none, defaults to entity level queues.", +) +@click.option( + "--dockerfile", + "-D", + default=None, + help="Path to the Dockerfile used to build the job, relative to the job's root", +) +@click.option( + "--priority", + "-P", + default=None, + type=click.Choice(["critical", "high", "medium", "low"]), + help="""When --queue is passed, set the priority of the job. Launch jobs with higher priority + are served first. The order, from highest to lowest priority, is: critical, high, medium, low""", +) +@display_error +def launch( + uri, + job, + entry_point, + git_version, + build_context, + name, + resource, + entity, + project, + docker_image, + base_image, + config, + cli_template_vars, + queue, + run_async, + resource_args, + build, + repository, + project_queue, + dockerfile, + priority, + job_name, +): + """Start a W&B run from the given URI. + + The URI can bea wandb URI, a GitHub repo uri, or a local path). 
In the case of a + wandb URI the arguments used in the original run will be used by default. These + arguments can be overridden using the args option, or specifying those arguments in + the config's 'overrides' key, 'args' field as a list of strings. + + Running `wandb launch [URI]` will launch the run directly. To add the run to a + queue, run `wandb launch [URI] --queue [optional queuename]`. + """ + logger.info( + f"=== Launch called with kwargs {locals()} CLI Version: {wandb.__version__}===" + ) + from wandb.sdk.launch._launch import _launch + from wandb.sdk.launch.create_job import _create_job + from wandb.sdk.launch.utils import _is_git_uri + + api = _get_cling_api() + get_sentry().configure_scope(process_context="launch_cli") + + if run_async and queue is not None: + raise LaunchError( + "Cannot use both --async and --queue with wandb launch, see help for details." + ) + + if queue and docker_image and not project: + raise LaunchError( + "Cannot use --queue and --docker together without a project. Please specify a project with --project or -p." + ) + + if priority is not None and queue is None: + raise LaunchError("--priority flag requires --queue to be set") + + if resource_args is not None: + resource_args = util.load_json_yaml_dict(resource_args) + if resource_args is None: + raise LaunchError("Invalid format for resource-args") + else: + resource_args = {} + + if entry_point is not None: + entry_point = shlex.split(entry_point) + + if config is not None: + config = util.load_json_yaml_dict(config) + if config is None: + raise LaunchError("Invalid format for config") + else: + config = {} + + resource = resource or config.get("resource") + + if build and queue is None: + raise LaunchError("Build flag requires a queue to be set") + + try: + launch_utils.check_logged_in(api) + except Exception: + wandb.termerror(f"Error running job: {traceback.format_exc()}") + + run_id = config.get("run_id") + + # If URI was provided, we need to create a job from it. 
+ if uri: + if entry_point is None: + raise LaunchError( + "Cannot provide a uri without an entry point. Please provide an " + "entry point with --entry-point or -E." + ) + if job is not None: + raise LaunchError("Cannot provide both a uri and a job name.") + job_type = ( + "git" if _is_git_uri(uri) else "code" + ) # TODO: Add support for local URIs with git. + if entity is None: + entity = launch_utils.get_default_entity(api, config) + artifact, _, _ = _create_job( + api, + job_type, + uri, + entrypoint=" ".join(entry_point), + git_hash=git_version, + name=job_name, + project=project, + base_image=base_image, + build_context=build_context, + dockerfile=dockerfile, + entity=entity, + ) + if artifact is None: + raise LaunchError(f"Failed to create job from uri: {uri}") + job = f"{entity}/{project}/{artifact.name}" + + if dockerfile: + if "overrides" in config: + config["overrides"]["dockerfile"] = dockerfile + else: + config["overrides"] = {"dockerfile": dockerfile} + + if priority is not None: + priority_map = { + "critical": 0, + "high": 1, + "medium": 2, + "low": 3, + } + priority = priority_map[priority.lower()] + + template_variables = None + if cli_template_vars: + if queue is None: + raise LaunchError("'--set-var' flag requires queue to be set") + if entity is None: + entity = launch_utils.get_default_entity(api, config) + public_api = PublicApi() + runqueue = RunQueue(client=public_api.client, name=queue, entity=entity) + template_variables = launch_utils.fetch_and_validate_template_variables( + runqueue, cli_template_vars + ) + + if queue is None: + # direct launch + try: + run = asyncio.run( + _launch( + api, + job, + project=project, + entity=entity, + docker_image=docker_image, + name=name, + entry_point=entry_point, + version=git_version, + resource=resource, + resource_args=resource_args, + launch_config=config, + synchronous=(not run_async), + run_id=run_id, + repository=repository, + ) + ) + if asyncio.run(run.get_status()).state in [ + "failed", + 
"stopped", + "preempted", + ]: + wandb.termerror("Launched run exited with non-zero status") + sys.exit(1) + except LaunchError as e: + logger.exception("An error occurred.") + get_sentry().exception(e) + sys.exit(e) + except ExecutionError as e: + logger.exception("An error occurred.") + get_sentry().exception(e) + sys.exit(e) + except asyncio.CancelledError: + sys.exit(0) + else: + try: + _launch_add( + api, + job, + config, + template_variables, + project, + entity, + queue, + resource, + entry_point, + name, + git_version, + docker_image, + project_queue, + resource_args, + build=build, + run_id=run_id, + repository=repository, + priority=priority, + ) + + except Exception as e: + get_sentry().exception(e) + raise + + +@cli.command( + context_settings=CONTEXT, + help="Run a W&B launch agent.", +) +@click.pass_context +@click.option( + "--queue", + "-q", + "queues", + default=None, + multiple=True, + help="The name of a queue for the agent to watch. Multiple -q flags supported.", +) +@click.option( + "--entity", + "-e", + default=None, + help="The entity to use. Defaults to current logged-in user", +) +@click.option( + "--log-file", + "-l", + default=None, + help=( + "Destination for internal agent logs. Use - for stdout. " + "By default all agents logs will go to debug.log in your wandb/ " + "subdirectory or WANDB_DIR if set." + ), +) +@click.option( + "--max-jobs", + "-j", + default=None, + help="The maximum number of launch jobs this agent can run in parallel. Defaults to 1. 
Set to -1 for no upper limit", +) +@click.option( + "--config", "-c", default=None, help="path to the agent config yaml to use" +) +@click.option( + "--url", + "-u", + default=None, + hidden=True, + help="a wandb client registration URL, this is generated in the UI", +) +@click.option("--verbose", "-v", count=True, help="Display verbose output") +@display_error +def launch_agent( + ctx, + entity=None, + queues=None, + max_jobs=None, + config=None, + url=None, + log_file=None, + verbose=0, +): + logger.info( + f"=== Launch-agent called with kwargs {locals()} CLI Version: {wandb.__version__} ===" + ) + if url is not None: + raise LaunchError( + "--url is not supported in this version, upgrade with: pip install -u wandb" + ) + + import wandb.sdk.launch._launch as _launch + + if log_file is not None: + _launch.set_launch_logfile(log_file) + + api = _get_cling_api() + get_sentry().configure_scope(process_context="launch_agent") + agent_config, api = _launch.resolve_agent_config( + entity, max_jobs, queues, config, verbose + ) + + if len(agent_config.get("queues")) == 0: + raise LaunchError( + "To launch an agent please specify a queue or a list of queues in the configuration file or cli." + ) + + launch_utils.check_logged_in(api) + + wandb.termlog("Starting launch agent ✨") + try: + _launch.create_and_run_agent(api, agent_config) + except Exception as e: + get_sentry().exception(e) + raise + + +@cli.command(context_settings=CONTEXT, help="Run the W&B agent") +@click.pass_context +@click.option( + "--project", + "-p", + default=None, + help="""The name of the project where W&B runs created from the sweep are sent to. If the project is not specified, the run is sent to a project labeled 'Uncategorized'.""", +) +@click.option( + "--entity", + "-e", + default=None, + help="""The username or team name where you want to send W&B runs created by the sweep to. Ensure that the entity you specify already exists. 
If you don't specify an entity, the run will be sent to your default entity, which is usually your username.""", +) +@click.option( + "--count", default=None, type=int, help="The max number of runs for this agent." +) +@click.option( + "--forward-signals", + "-f", + is_flag=True, + default=False, + help="""Forward signals delivered to the agent (e.g. SIGINT/SIGTERM) to its child runs so they can shut down cleanly.""", +) +@click.argument("sweep_id") +@display_error +def agent(ctx, project, entity, count, forward_signals, sweep_id): + api = _get_cling_api() + if not api.is_authenticated: + wandb.termlog("Login to W&B to use the sweep agent feature") + ctx.invoke(login, no_offline=True) + api = _get_cling_api(reset=True) + + wandb.termlog("Starting wandb agent 🕵️") + wandb_agent.agent( + sweep_id, + entity=entity, + project=project, + count=count, + forward_signals=forward_signals, + ) + + # you can send local commands like so: + # agent_api.command({'type': 'run', 'program': 'train.py', + # 'args': ['--max_epochs=10']}) + + +@cli.command( + context_settings=RUN_CONTEXT, help="Run a W&B launch sweep scheduler (Experimental)" +) +@click.pass_context +@click.argument("sweep_id") +@display_error +def scheduler( + ctx, + sweep_id, +): + api = InternalApi() + if not api.is_authenticated: + wandb.termlog("Login to W&B to use the sweep scheduler feature") + ctx.invoke(login, no_offline=True) + api = InternalApi(reset=True) + + get_sentry().configure_scope(process_context="sweep_scheduler") + wandb.termlog("Starting a Launch Scheduler 🚀") + from wandb.sdk.launch.sweeps import load_scheduler + + # TODO(gst): remove this monstrosity + # Future-proofing hack to pull any kwargs that get passed in through the CLI + kwargs = {} + for i, _arg in enumerate(ctx.args): + if isinstance(_arg, str) and _arg.startswith("--"): + # convert input kwargs from hyphens to underscores + _key = _arg[2:].replace("-", "_") + _args = ctx.args[i + 1] + if str.isdigit(_args): + _args = int(_args) + 
kwargs[_key] = _args + try: + sweep_type = kwargs.get("sweep_type", "wandb") + _scheduler = load_scheduler(scheduler_type=sweep_type)( + api, + sweep_id=sweep_id, + **kwargs, + ) + _scheduler.start() + except Exception as e: + get_sentry().exception(e) + raise + + +@cli.group(help="Commands for managing and viewing W&B jobs") +def job() -> None: + pass + + +@job.command("list", help="List jobs in a project") +@click.option( + "--project", + "-p", + envvar=env.PROJECT, + help="The project you want to list jobs from.", +) +@click.option( + "--entity", + "-e", + default="models", + envvar=env.ENTITY, + help="The entity the jobs belong to", +) +def _list(project, entity): + wandb.termlog(f"Listing jobs in {entity}/{project}") + public_api = PublicApi() + try: + jobs = public_api.list_jobs(entity=entity, project=project) + except wandb.errors.CommError as e: + wandb.termerror(f"{e}") + return + + if len(jobs) == 0: + wandb.termlog("No jobs found") + return + + for job in jobs: + aliases = [] + if len(job["edges"]) == 0: + # deleted? + continue + + name = job["edges"][0]["node"]["artifactSequence"]["name"] + for version in job["edges"]: + aliases += [x["alias"] for x in version["node"]["aliases"]] + + # only list the most recent 10 job versions + aliases_str = ",".join(aliases[::-1]) + wandb.termlog(f"{name} -- versions ({len(aliases)}): {aliases_str}") + + +@job.command( + help="Describe a launch job. 
Provide the launch job in the form of: entity/project/job-name:alias-or-version" +) +@click.argument("job") +def describe(job): + public_api = PublicApi() + try: + job = public_api.job(name=job) + except wandb.errors.CommError as e: + wandb.termerror(f"{e}") + return + + for key in job._job_info: + if key.startswith("_"): + continue + wandb.termlog(f"{key}: {job._job_info[key]}") + + +@job.command( + no_args_is_help=True, +) +@click.option( + "--project", + "-p", + envvar=env.PROJECT, + help="The project you want to list jobs from.", +) +@click.option( + "--entity", + "-e", + envvar=env.ENTITY, + help="The entity the jobs belong to", +) +@click.option( + "--name", + "-n", + help="Name for the job", +) +@click.option( + "--description", + "-d", + help="Description for the job", +) +@click.option( + "--alias", + "-a", + "aliases", + help="Alias for the job", + multiple=True, + default=tuple(), +) +@click.option( + "--entry-point", + "-E", + "entrypoint", + help="Entrypoint to the script, including an executable and an entrypoint " + "file. Required for code or repo jobs. If --build-context is provided, " + "paths in the entrypoint command will be relative to the build context.", +) +@click.option( + "--git-hash", + "-g", + "git_hash", + type=str, + help="Commit reference to use as the source for git jobs", +) +@click.option( + "--runtime", + "-r", + type=str, + help="Python runtime to execute the job", +) +@click.option( + "--build-context", + "-b", + type=str, + help="Path to the build context from the root of the job source code. If " + "provided, this is used as the base path for the Dockerfile and entrypoint.", +) +@click.option( + "--base-image", + "-B", + type=str, + help="Base image to use for the job. Incompatible with image jobs.", +) +@click.option( + "--dockerfile", + "-D", + type=str, + help="Path to the Dockerfile for the job. 
If --build-context is provided, " + "the Dockerfile path will be relative to the build context.", +) +@click.argument( + "job_type", + type=click.Choice(("git", "code", "image")), +) +@click.option( + "--service", + "-s", + "services", + multiple=True, + callback=parse_service_config, + help="Service configurations in format serviceName=policy. Valid policies: always, never", + hidden=True, +) +@click.option( + "--schema", + type=str, + help="Path to the schema file for the job.", + hidden=True, +) +@click.argument("path") +def create( + path, + project, + entity, + name, + job_type, + description, + aliases, + entrypoint, + git_hash, + runtime, + build_context, + base_image, + dockerfile, + services, + schema, +): + """Create a job from a source, without a wandb run. + + Jobs can be of three types, git, code, or image. + + git: A git source, with an entrypoint either in the path or provided explicitly pointing to the main python executable. + code: A code path, containing a requirements.txt file. + image: A docker image. 
+ """ + from wandb.sdk.launch.create_job import _create_job + + api = _get_cling_api() + get_sentry().configure_scope(process_context="job_create") + + entity = entity or os.getenv("WANDB_ENTITY") or api.default_entity + if not entity: + wandb.termerror("No entity provided, use --entity or set WANDB_ENTITY") + return + + project = project or os.getenv("WANDB_PROJECT") + if not project: + wandb.termerror("No project provided, use --project or set WANDB_PROJECT") + return + + if entrypoint is None and job_type in ["git", "code"]: + wandb.termwarn( + f"No entrypoint provided for {job_type} job, defaulting to main.py" + ) + entrypoint = "main.py" + + if job_type == "image" and base_image: + wandb.termerror("Cannot provide --base-image/-B for an `image` job") + return + + if schema: + schema_dict = util.load_json_yaml_dict(schema) + if schema_dict is None: + wandb.termerror(f"Invalid format for schema file: {schema}") + return + + artifact, action, aliases = _create_job( + api=api, + path=path, + entity=entity, + project=project, + name=name, + job_type=job_type, + description=description, + aliases=list(aliases), + entrypoint=entrypoint, + git_hash=git_hash, + runtime=runtime, + build_context=build_context, + base_image=base_image, + dockerfile=dockerfile, + services=services, + schema=schema_dict if schema else None, + ) + if not artifact: + wandb.termerror("Job creation failed") + return + + artifact_path = f"{entity}/{project}/{artifact.name}" + msg = f"{action} job: {click.style(artifact_path, fg='yellow')}" + if len(aliases) == 1: + alias_str = click.style(aliases[0], fg="yellow") + msg += f", with alias: {alias_str}" + elif len(aliases) > 1: + alias_str = click.style(", ".join(aliases), fg="yellow") + msg += f", with aliases: {alias_str}" + + wandb.termlog(msg) + web_url = util.app_url(api.settings().get("base_url")) + url = click.style(f"{web_url}/{entity}/{project}/jobs", underline=True) + wandb.termlog(f"View all jobs in project '{project}' here: {url}\n") + + 
+@cli.command(context_settings=CONTEXT, help="Run the W&B local sweep controller") +@click.option("--verbose", is_flag=True, default=False, help="Display verbose output") +@click.argument("sweep_id") +@display_error +def controller(verbose, sweep_id): + click.echo("Starting wandb controller...") + from wandb import controller as wandb_controller + + tuner = wandb_controller(sweep_id) + tuner.run(verbose=verbose) + + +@cli.command(context_settings=RUN_CONTEXT, name="docker-run") +@click.pass_context +@click.argument("docker_run_args", nargs=-1) +def docker_run(ctx, docker_run_args): + """Wrap `docker run` and adds WANDB_API_KEY and WANDB_DOCKER environment variables. + + This will also set the runtime to nvidia if the nvidia-docker executable is present + on the system and --runtime wasn't set. + + See `docker run --help` for more details. + """ + import wandb.docker + + api = InternalApi() + args = list(docker_run_args) + if len(args) > 0 and args[0] == "run": + args.pop(0) + if len([a for a in args if a.startswith("--runtime")]) == 0 and _HAS_NVIDIA_DOCKER: + args = ["--runtime", "nvidia"] + args + # TODO: image_from_docker_args uses heuristics to find the docker image arg, there are likely cases + # where this won't work + image = util.image_from_docker_args(args) + resolved_image = None + if image: + resolved_image = wandb.docker.image_id(image) + if resolved_image: + args = ["-e", f"WANDB_DOCKER={resolved_image}"] + args + else: + wandb.termlog( + "Couldn't detect image argument, running command without the WANDB_DOCKER env variable" + ) + if api.api_key: + args = ["-e", f"WANDB_API_KEY={api.api_key}"] + args + else: + wandb.termlog( + "Not logged in, run `wandb login` from the host machine to enable result logging" + ) + subprocess.call(["docker", "run"] + args) + + +@cli.command(context_settings=RUN_CONTEXT) +@click.pass_context +@click.argument("docker_run_args", nargs=-1) +@click.argument("docker_image", required=False) +@click.option( + 
"--nvidia/--no-nvidia", + default=_HAS_NVIDIA_DOCKER, + help="Use the nvidia runtime, defaults to nvidia if nvidia-docker is present", +) +@click.option( + "--digest", is_flag=True, default=False, help="Output the image digest and exit" +) +@click.option( + "--jupyter/--no-jupyter", default=False, help="Run jupyter lab in the container" +) +@click.option( + "--dir", default="/app", help="Which directory to mount the code in the container" +) +@click.option("--no-dir", is_flag=True, help="Don't mount the current directory") +@click.option( + "--shell", default="/bin/bash", help="The shell to start the container with" +) +@click.option("--port", default="8888", help="The host port to bind jupyter on") +@click.option("--cmd", help="The command to run in the container") +@click.option( + "--no-tty", is_flag=True, default=False, help="Run the command without a tty" +) +@display_error +def docker( + ctx, + docker_run_args, + docker_image, + nvidia, + digest, + jupyter, + dir, + no_dir, + shell, + port, + cmd, + no_tty, +): + """Run your code in a docker container. + + W&B docker lets you run your code in a docker image ensuring wandb is configured. It + adds the WANDB_DOCKER and WANDB_API_KEY environment variables to your container and + mounts the current directory in /app by default. You can pass additional args which + will be added to `docker run` before the image name is declared, we'll choose a + default image for you if one isn't passed: + + ```sh + wandb docker -v /mnt/dataset:/app/data + wandb docker gcr.io/kubeflow-images-public/tensorflow-1.12.0-notebook-cpu:v0.4.0 --jupyter + wandb docker wandb/deepo:keras-gpu --no-tty --cmd "python train.py --epochs=5" + ``` + + By default, we override the entrypoint to check for the existence of wandb and + install it if not present. If you pass the --jupyter flag we will ensure jupyter is + installed and start jupyter lab on port 8888. If we detect nvidia-docker on your + system we will use the nvidia runtime. 
If you just want wandb to set environment + variable to an existing docker run command, see the wandb docker-run command. + """ + api = InternalApi() + if not _HAS_DOCKER: + raise ClickException("Docker not installed, install it from https://docker.com") + + import wandb.docker + + args = list(docker_run_args) + image = docker_image or "" + # remove run for users used to nvidia-docker + if len(args) > 0 and args[0] == "run": + args.pop(0) + if image == "" and len(args) > 0: + image = args.pop(0) + # If the user adds docker args without specifying an image (should be rare) + if not util.docker_image_regex(image.split("@")[0]): + if image: + args = args + [image] + image = wandb.docker.default_image(gpu=nvidia) + subprocess.call(["docker", "pull", image]) + _, repo_name, tag = wandb.docker.parse(image) + + resolved_image = wandb.docker.image_id(image) + if resolved_image is None: + raise ClickException( + f"Couldn't find image locally or in a registry, try running `docker pull {image}`" + ) + if digest: + sys.stdout.write(resolved_image) + exit(0) + + existing = wandb.docker.shell(["ps", "-f", f"ancestor={resolved_image}", "-q"]) + if existing: + if click.confirm( + "Found running container with the same image, do you want to attach?" 
+ ): + subprocess.call(["docker", "attach", existing.split("\n")[0]]) + exit(0) + cwd = os.getcwd() + command = [ + "docker", + "run", + "-e", + "LANG=C.UTF-8", + "-e", + f"WANDB_DOCKER={resolved_image}", + "--ipc=host", + "-v", + wandb.docker.entrypoint + ":/wandb-entrypoint.sh", + "--entrypoint", + "/wandb-entrypoint.sh", + ] + if nvidia: + command.extend(["--runtime", "nvidia"]) + if not no_dir: + # TODO: We should default to the working directory if defined + command.extend(["-v", cwd + ":" + dir, "-w", dir]) + if api.api_key: + command.extend(["-e", f"WANDB_API_KEY={api.api_key}"]) + else: + wandb.termlog( + "Couldn't find WANDB_API_KEY, run `wandb login` to enable streaming metrics" + ) + if jupyter: + command.extend(["-e", "WANDB_ENSURE_JUPYTER=1", "-p", port + ":8888"]) + no_tty = True + cmd = f"jupyter lab --no-browser --ip=0.0.0.0 --allow-root --NotebookApp.token= --notebook-dir {dir}" + command.extend(args) + if no_tty: + command.extend([image, shell, "-c", cmd]) + else: + if cmd: + command.extend(["-e", f"WANDB_COMMAND={cmd}"]) + command.extend(["-it", image, shell]) + wandb.termlog("Launching docker container \U0001f6a2") + subprocess.call(command) + + +@cli.command( + context_settings=RUN_CONTEXT, + help="Start a local W&B container (deprecated, see wandb server --help)", + hidden=True, +) +@click.pass_context +@click.option("--port", "-p", default="8080", help="The host port to bind W&B local on") +@click.option( + "--env", "-e", default=[], multiple=True, help="Env vars to pass to wandb/local" +) +@click.option( + "--daemon/--no-daemon", default=True, help="Run or don't run in daemon mode" +) +@click.option( + "--upgrade", is_flag=True, default=False, help="Upgrade to the most recent version" +) +@click.option( + "--edge", is_flag=True, default=False, help="Run the bleeding edge", hidden=True +) +@display_error +def local(ctx, *args, **kwargs): + wandb.termwarn("`wandb local` has been replaced with `wandb server start`.") + ctx.invoke(start, *args, 
@server.command(context_settings=RUN_CONTEXT, help="Start a local W&B server")
@click.pass_context
@click.option(
    "--port", "-p", default="8080", help="The host port to bind W&B server on"
)
@click.option(
    "--env", "-e", default=[], multiple=True, help="Env vars to pass to wandb/local"
)
@click.option(
    "--daemon/--no-daemon", default=True, help="Run or don't run in daemon mode"
)
@click.option(
    "--upgrade",
    is_flag=True,
    default=False,
    help="Upgrade to the most recent version",
    hidden=True,
)
@click.option(
    "--edge", is_flag=True, default=False, help="Run the bleeding edge", hidden=True
)
@display_error
def start(ctx, port, env, daemon, upgrade, edge):
    """Launch the `wandb-local` docker container and point base_url at it.

    Args:
        ctx: Click context, used to invoke `login` after a daemonized start.
        port: Host port (string) mapped to the container's port 8080.
        env: Extra `KEY=VALUE` strings forwarded to the container via `-e`.
        daemon: When true the container is started detached (`-d`).
        upgrade: Pull the newest image and replace a running container.
        edge: Use the bleeding-edge image instead of `wandb/local`.

    Raises:
        ClickException: If docker is not installed.
        SystemExit: If a `wandb-local` container is already running (without
            `--upgrade`) or the daemonized container fails to launch.
    """
    api = InternalApi()
    if not _HAS_DOCKER:
        raise ClickException("Docker not installed, install it from https://docker.com")

    import wandb.docker

    # Both lookups return None when docker cannot resolve the image (for
    # example offline or not yet pulled); calling .split() on that result
    # would crash, so the digest comparison is only done when both are known.
    local_image_id = wandb.docker.image_id("wandb/local")
    registry_image_id = wandb.docker.image_id_from_registry("wandb/local")
    if local_image_id is None or registry_image_id is None:
        if upgrade:
            # The versions cannot be compared; honor the explicit upgrade request.
            subprocess.call(["docker", "pull", "wandb/local"])
    else:
        local_image_sha = local_image_id.split("wandb/local")[-1]
        registry_image_sha = registry_image_id.split("wandb/local")[-1]
        if local_image_sha != registry_image_sha:
            if upgrade:
                subprocess.call(["docker", "pull", "wandb/local"])
            else:
                wandb.termlog(
                    "A new version of the W&B server is available, upgrade by calling `wandb server start --upgrade`"
                )

    running = subprocess.check_output(
        ["docker", "ps", "--filter", "name=^wandb-local$", "--format", "{{.ID}}"]
    )
    if running != b"":
        if upgrade:
            subprocess.call(["docker", "stop", "wandb-local"])
        else:
            wandb.termerror(
                "A container named wandb-local is already running, run `docker stop wandb-local` if you want to start a new instance"
            )
            sys.exit(1)

    image = "docker.pkg.github.com/wandb/core/local" if edge else "wandb/local"
    username = getpass.getuser()
    env_vars = ["-e", f"LOCAL_USERNAME={username}"]
    for entry in env:
        env_vars += ["-e", entry]
    command = [
        "docker",
        "run",
        "--rm",
        "-v",
        "wandb:/vol",
        "-p",
        port + ":8080",
        "--name",
        "wandb-local",
    ] + env_vars
    host = f"http://localhost:{port}"

    # Persist the local server as the global base_url. A failure to save is
    # reported but does not prevent the container from starting.
    system_settings = wandb_setup.singleton().settings.read_system_settings()
    system_settings.set("base_url", host, globally=True)

    try:
        system_settings.save()
    except settings_file.SaveSettingsError as e:
        msg = "Failed to update base_url setting"
        logger.exception(msg)
        wandb.termerror(f"{msg}: {e}")

    if daemon:
        command += ["-d"]
    command += [image]

    # subprocess.DEVNULL always exists on Python 3, so no file handle has to
    # be opened (and leaked) to discard the container's stdout.
    code = subprocess.call(command, stdout=subprocess.DEVNULL)
    if daemon:
        if code != 0:
            wandb.termerror(
                "Failed to launch the W&B server container, see the above error."
            )
            sys.exit(1)
        else:
            wandb.termlog(f"W&B server started at http://localhost:{port} \U0001f680")
            wandb.termlog("You can stop the server by running `wandb server stop`")
            if not api.api_key:
                # Let the server start before potentially launching a browser
                time.sleep(2)
                ctx.invoke(login, host=host)
@artifact.command(context_settings=CONTEXT, help="Download an artifact from wandb")
@click.argument("path")
@click.option("--root", help="The directory you want to download the artifact to")
@click.option("--type", help="The type of artifact you are downloading")
@display_error
def get(path, root, type):
    """Download the artifact named by `path` into `root`.

    `path` is parsed into entity, project and artifact name; the version
    defaults to "latest" when the name carries no `:version` suffix. Any
    ValueError raised while resolving or downloading is reported as a
    ClickException.
    """
    public_api = PublicApi()
    entity, project, artifact_name = public_api._parse_artifact_path(path)
    if project is None:
        project = click.prompt("Enter the name of the project you want to use")

    try:
        name_parts = artifact_name.split(":")
        if len(name_parts) <= 1:
            version = "latest"
        else:
            artifact_name, version = name_parts[0], name_parts[1]

        if is_artifact_registry_project(project):
            # A three-segment path carries the organization as its first segment.
            if path.count("/") == 2:
                organization = path.split("/")[0]
            else:
                organization = ""
            # Use the entity from settings: the parsed entity above may in
            # fact be an organization name.
            settings_entity = public_api.settings["entity"] or public_api.default_entity
            # Registry artifacts live under the org entity. Because a shorthand
            # and an alias are offered for this path, the org entity is
            # resolved for the user behind the scenes.
            entity = SDKInternalApi()._resolve_org_entity_name(
                entity=settings_entity, organization=organization
            )

        full_path = f"{entity}/{project}/{artifact_name}:{version}"
        shown_type = type or "dataset"
        wandb.termlog(
            "Downloading {type} artifact {full_path}".format(
                type=shown_type, full_path=full_path
            )
        )
        downloaded = public_api.artifact(full_path, type=type)
        path = downloaded.download(root=root)
        wandb.termlog(f"Artifact downloaded to {path}")
    except ValueError:
        raise ClickException("Unable to download artifact")
@cache.command(
    context_settings=CONTEXT,
    help="Clean up less frequently used files from the artifacts cache",
)
@click.argument("target_size")
@click.option("--remove-temp/--no-remove-temp", default=False, help="Remove temp files")
@display_error
def cleanup(target_size, remove_temp):
    """Shrink the artifact file cache and report how much space was reclaimed.

    `target_size` is a human-readable size string that is converted with
    `util.from_human_size` before being handed to the cache.
    """
    size_limit = util.from_human_size(target_size)
    file_cache = get_artifact_file_cache()
    freed = file_cache.cleanup(size_limit, remove_temp)
    wandb.termlog(f"Reclaimed {util.to_human_size(freed)} of space")
f.write(data) + bar.update(len(data)) + + +@cli.command( + context_settings=CONTEXT, + help="Restore code, config and docker state for a run. Retrieves code from latest commit if code was not saved with `wandb.save()` or `wandb.init(save_code=True)`.", +) +@click.pass_context +@click.argument("run", envvar=env.RUN_ID) +@click.option("--no-git", is_flag=True, default=False, help="Don't restore git state") +@click.option( + "--branch/--no-branch", + default=True, + help="Whether to create a branch or checkout detached", +) +@click.option( + "--project", "-p", envvar=env.PROJECT, help="The project you wish to upload to." +) +@click.option( + "--entity", "-e", envvar=env.ENTITY, help="The entity to scope the listing to." +) +@display_error +def restore(ctx, run, no_git, branch, project, entity): + from wandb.old.core import wandb_dir + from wandb.sdk.lib.gitlib import GitRepo + + api = _get_cling_api() + if ":" in run: + if "/" in run: + entity, rest = run.split("/", 1) + else: + rest = run + project, run = rest.split(":", 1) + elif run.count("/") > 1: + entity, run = run.split("/", 1) + + project, run = api.parse_slug(run, project=project) + commit, json_config, patch_content, metadata = api.run_config( + project, run=run, entity=entity + ) + repo = metadata.get("git", {}).get("repo") + image = metadata.get("docker") + restore_message = f"""`wandb restore` needs to be run from the same git repository as the original run. +Run `git clone {repo}` and restore from there or pass the --no-git flag.""" + + git = GitRepo(remote=api.settings("git_remote")) + + if no_git: + commit = None + elif not git.enabled: + if repo: + raise ClickException(restore_message) + elif image: + wandb.termlog( + "Original run has no git history. 
Just restoring config and docker" + ) + + if commit and git.enabled: + wandb.termlog(f"Fetching origin and finding commit: {commit}") + subprocess.check_call(["git", "fetch", "--all"]) + try: + git.repo.commit(commit) + except ValueError: + wandb.termlog(f"Couldn't find original commit: {commit}") + commit = None + files = api.download_urls(project, run=run, entity=entity) + for filename in files: + if filename.startswith("upstream_diff_") and filename.endswith( + ".patch" + ): + commit = filename[len("upstream_diff_") : -len(".patch")] + try: + git.repo.commit(commit) + except ValueError: + commit = None + else: + break + + if commit: + wandb.termlog(f"Falling back to upstream commit: {commit}") + patch_path, _ = api.download_write_file(files[filename]) + else: + raise ClickException(restore_message) + else: + if patch_content: + patch_path = os.path.join(wandb_dir(), "diff.patch") + with open(patch_path, "w") as f: + f.write(patch_content) + else: + patch_path = None + + branch_name = f"wandb/{run}" + if branch and branch_name not in git.repo.branches: + git.repo.git.checkout(commit, b=branch_name) + wandb.termlog(f"Created branch {click.style(branch_name, bold=True)}") + elif branch: + wandb.termlog( + f"Using existing branch, run `git branch -D {branch_name}` from master for a clean checkout" + ) + git.repo.git.checkout(branch_name) + else: + wandb.termlog(f"Checking out {commit} in detached mode") + git.repo.git.checkout(commit) + + if patch_path: + # we apply the patch from the repository root so git doesn't exclude + # things outside the current directory + root = git.root + patch_rel_path = os.path.relpath(patch_path, start=root) + # --reject is necessary or else this fails any time a binary file + # occurs in the diff + exit_code = subprocess.call( + ["git", "apply", "--reject", patch_rel_path], cwd=root + ) + if exit_code == 0: + wandb.termlog("Applied patch") + else: + wandb.termerror( + "Failed to apply patch, try un-staging any un-committed changes" + 
@cli.command("online")
@display_error
def online():
    """Undo `wandb offline`."""
    # Clearing the stored "mode" removes the override written by
    # `wandb offline` / `wandb disabled`.
    settings_store = wandb_setup.singleton().settings.read_system_settings()
    settings_store.clear("mode")
    settings_store.save()

    message = (
        "W&B online. Running your script from this directory will now sync to the cloud."
    )
    click.echo(message)
@cli.command("status", help="Show configuration settings")
@click.option(
    "--settings/--no-settings", help="Show the current settings", default=True
)
def status(settings):
    """Print the current API settings as sorted, indented JSON.

    Nothing is printed when `--no-settings` is passed.
    """
    api = _get_cling_api()
    if not settings:
        return

    click.echo(click.style("Current Settings", bold=True))
    current = api.settings()
    rendered = json.dumps(current, sort_keys=True, indent=2, separators=(",", ": "))
    click.echo(rendered)
+ + Saves and downloads artifacts to verify that the artifact storage and + retrieval system is working as expected (artifact check). + + Tests the GraphQL endpoint by uploading a file to ensure it can handle + signed URL uploads (GraphQL PUT check). + + Checks the ability to send large payloads through the proxy (large payload check). + + Verifies that the installed version of the W&B package is up-to-date and + compatible with the server (W&B version check). + + Creates and executes a sweep to ensure that sweep functionality is + working correctly (sweeps check). +""", +) +@click.option("--host", default=None, help="Test a specific instance of W&B") +def verify(host): + # TODO: (kdg) Build this all into a WandbVerify object, and clean this up. + os.environ["WANDB_SILENT"] = "true" + os.environ["WANDB_PROJECT"] = "verify" + api = _get_cling_api() + reinit = False + if host is None: + host = api.settings("base_url") + wandb.termlog(f"Default host selected: {host}") + # if the given host does not match the default host, re-run init + elif host != api.settings("base_url"): + reinit = True + + tmp_dir = tempfile.mkdtemp() + wandb.termlog( + "Find detailed logs for this test at: {}".format(os.path.join(tmp_dir, "wandb")) + ) + os.chdir(tmp_dir) + os.environ["WANDB_BASE_URL"] = host + wandb.login(host=host) + if reinit: + api = _get_cling_api(reset=True) + if not wandb_verify.check_host(host): + sys.exit(1) + if not wandb_verify.check_logged_in(api, host): + sys.exit(1) + url_success, url = wandb_verify.check_graphql_put(api, host) + large_post_success = wandb_verify.check_large_post() + wandb_verify.check_secure_requests( + api.settings("base_url"), + "Checking requests to base url", + "Connections are not made over https. SSL required for secure communications.", + ) + if url: + wandb_verify.check_secure_requests( + url, + "Checking requests made over signed URLs", + "Signed URL requests not made over https. 
SSL is required for secure communications.", + ) + wandb_verify.check_cors_configuration(url, host) + wandb_verify.check_wandb_version(api) + check_run_success = wandb_verify.check_run(api) + check_artifacts_success = wandb_verify.check_artifacts() + check_sweeps_success = wandb_verify.check_sweeps(api) + if not ( + check_artifacts_success + and check_run_success + and large_post_success + and url_success + and check_sweeps_success + ): + sys.exit(1) + + +cli.add_command(beta) diff --git a/.venv/lib/python3.13/site-packages/wandb/docker/__init__.py b/.venv/lib/python3.13/site-packages/wandb/docker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..761748bb8bece16f4b95a25995f6d05ef710748d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/docker/__init__.py @@ -0,0 +1,290 @@ +import json +import logging +import os +import shutil +import subprocess +from typing import Any, Dict, List, Optional, Tuple, Union + +from wandb.docker import names +from wandb.errors import Error + + +class DockerError(Error): + """Raised when attempting to execute a docker command.""" + + def __init__( + self, + command_launched: List[str], + return_code: int, + stdout: Optional[bytes] = None, + stderr: Optional[bytes] = None, + ) -> None: + command_launched_str = " ".join(command_launched) + error_msg = ( + f"The docker command executed was `{command_launched_str}`.\n" + f"It returned with code {return_code}\n" + ) + if stdout is not None: + error_msg += f"The content of stdout is '{stdout.decode()}'\n" + else: + error_msg += ( + "The content of stdout can be found above the " + "stacktrace (it wasn't captured).\n" + ) + if stderr is not None: + error_msg += f"The content of stderr is '{stderr.decode()}'\n" + else: + error_msg += ( + "The content of stderr can be found above the " + "stacktrace (it wasn't captured)." 
def shell(cmd: List[str]) -> Optional[str]:
    """Run `docker <cmd>` and return its combined output.

    stderr is merged into stdout, the result is decoded as UTF-8 and
    surrounding whitespace is stripped.

    Args:
        cmd: Arguments appended after the `docker` executable.

    Returns:
        The command output on success, or None when the command exits with a
        non-zero status or the docker executable cannot be launched at all.
    """
    try:
        output = subprocess.check_output(["docker", *cmd], stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as e:
        print(e)  # noqa: T201
        return None
    except OSError as e:
        # A missing or non-executable docker binary raises OSError
        # (FileNotFoundError / PermissionError). The contract is "None on
        # error", so it is reported the same way as a failing command.
        print(e)  # noqa: T201
        return None
    return output.decode("utf8").strip()
def run_command_live_output(args: List[Any]) -> str:
    """Run a command, echoing its output live, and return the captured text.

    stdout and stderr are merged and read from the pipe in chunks of up to
    4096 bytes. Chunks containing a carriage return are echoed but NOT added
    to the returned text; every other chunk is echoed and captured. The
    captured text is printed once more after the stream ends.

    Args:
        args: The full command line to execute.

    Returns:
        The captured output (chunks without a carriage return).

    Raises:
        DockerError: If the command exits with a non-zero return code.
    """
    import codecs

    # Chunk boundaries are arbitrary byte offsets, so a multi-byte UTF-8
    # character can be split across two reads. An incremental decoder carries
    # the partial bytes over instead of raising UnicodeDecodeError; invalid
    # bytes are replaced rather than aborting the command.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
    captured: List[str] = []

    # The pipe is read through its raw file descriptor, so no text-mode or
    # line-buffering options are needed on the Popen object.
    with subprocess.Popen(
        args,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    ) as process:
        fd = process.stdout.fileno()  # type: ignore
        while True:
            chunk = os.read(fd, 4096)
            if not chunk:
                break
            text = decoder.decode(chunk)
            if b"\r" in chunk:
                print(text, end="")  # noqa: T201
            else:
                captured.append(text)
                print(text, end="\r")  # noqa: T201

        # Flush any bytes of an incomplete trailing character.
        tail = decoder.decode(b"", final=True)
        if tail:
            captured.append(tail)

        stdout = "".join(captured)
        print(stdout)  # noqa: T201

    return_code = process.wait()
    if return_code != 0:
        raise DockerError(args, return_code, stdout.encode())

    return stdout
def parse_repository_tag(repo_name: str) -> Tuple[str, Optional[str]]:
    """Split an image reference into its repository and its tag or digest.

    A digest (text after the last "@") takes precedence. Otherwise the text
    after the last ":" is treated as a tag, unless it contains a "/" — in
    which case the colon belongs to a registry host:port and there is no tag.

    Returns:
        `(repository, tag_or_digest)`, with None as the second item when the
        reference carries neither.
    """
    repository, at_sign, digest = repo_name.rpartition("@")
    if at_sign:
        return repository, digest

    repository, colon, tag = repo_name.rpartition(":")
    if colon and "/" not in tag:
        return repository, tag

    return repo_name, None
def image_id(image_name: str) -> Optional[str]:
    """Retrieve the image id from the local docker daemon or remote registry."""
    # A name that already carries a sha256 digest is its own id.
    if "@sha256:" in image_name:
        return image_name

    raw_digests = shell(["inspect", image_name, "--format", "{{json .RepoDigests}}"])
    if raw_digests is not None:
        try:
            return json.loads(raw_digests)[0]
        except (ValueError, IndexError):
            # Unparseable output or an empty digest list: ask the registry.
            pass

    return image_id_from_registry(image_name)
def resolve_repository_name(repo_name: str) -> tuple[str, str]:
    """Resolve a repository name into ``(index_name, remote_name)``.

    Args:
        repo_name: Repository reference without a URL scheme.

    Returns:
        The index (registry) name as normalised by ``resolve_index_name``,
        followed by the remote repository name.

    Raises:
        InvalidRepositoryError: If the name contains a scheme (``://``) or the
            index name starts or ends with a hyphen.
    """
    has_scheme = "://" in repo_name
    if has_scheme:
        raise InvalidRepositoryError(
            f"Repository name cannot contain a scheme ({repo_name})"
        )

    index_name, remote_name = split_repo_name(repo_name)

    # Look at the first and last characters of the index name.
    edge_chars = (index_name[0], index_name[-1])
    if "-" in edge_chars:
        raise InvalidRepositoryError(
            f"Invalid index name ({index_name}). Cannot begin or end with a hyphen."
        )

    resolved_index = resolve_index_name(index_name)
    return resolved_index, remote_name
#!/bin/sh
# Container entrypoint: make sure python and wandb are available, optionally
# install JupyterLab, record $WANDB_COMMAND in the shell history, then hand
# control to the command passed as arguments.
set -e

# Coloured "wandb" prefix; the \x escapes are expanded by `/bin/echo -e`.
wandb="\x1b[34m\x1b[1mwandb\x1b[0m"
/bin/echo -e "${wandb}: Checking image for required packages."

if ! [ -x "$(command -v python)" ]; then
    /bin/echo -e "${wandb}: python not installed, can't use wandb with this image."
    exit 1
fi

if ! [ -x "$(command -v wandb)" ]; then
    /bin/echo -e "${wandb}: wandb not installed, installing."
    pip install wandb --upgrade
else
    ver=$(wandb --version)
    /bin/echo -e "${wandb}: Found $ver"
fi

if [ "$WANDB_ENSURE_JUPYTER" = "1" ]; then
    if ! [ -x "$(command -v jupyter-lab)" ]; then
        /bin/echo -e "${wandb}: jupyter not installed, installing."
        pip install jupyterlab
        /bin/echo -e "${wandb}: starting jupyter, you can access it at: http://127.0.0.1:8888"
    fi
fi

if [ -n "$WANDB_COMMAND" ]; then
    # Quote the command and use printf so it is stored verbatim: no word
    # splitting, no glob expansion, no option or backslash interpretation.
    printf '%s\n' "$WANDB_COMMAND" >> ~/.bash_history
    /bin/echo -e "${wandb}: Command added to history, press up arrow to access it."
    # Print the coloured prefix with echo -e, but the user command literally.
    /bin/echo -ne "${wandb}: "
    printf '%s\n' "$WANDB_COMMAND"
fi
exec "$@"
b/.venv/lib/python3.13/site-packages/wandb/filesync/__pycache__/step_checksum.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/filesync/__pycache__/step_prepare.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/filesync/__pycache__/step_prepare.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..383f5232a0f7dd2a605d577f50eed0b199730e66 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/filesync/__pycache__/step_prepare.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/filesync/__pycache__/step_upload.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/filesync/__pycache__/step_upload.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7565eca3e5e50a95b0066c80c5b178b9f74960a Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/filesync/__pycache__/step_upload.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/filesync/__pycache__/upload_job.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/filesync/__pycache__/upload_job.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a758fe5a434fb88cd1b56496c358a995f8573882 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/filesync/__pycache__/upload_job.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/filesync/dir_watcher.py b/.venv/lib/python3.13/site-packages/wandb/filesync/dir_watcher.py new file mode 100644 index 0000000000000000000000000000000000000000..ce78e51b987a21e2dd61345f7e4be39ed9f6a083 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/filesync/dir_watcher.py @@ -0,0 +1,404 @@ +import abc +import fnmatch +import glob +import logging +import os +import queue +import time +from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, MutableSet, Optional + +from wandb import util +from wandb.sdk.lib.filesystem import 
class PolicyNow(FileEventHandler):
    """Handler for the "now" policy: upload once, or again when forced."""

    @property
    def policy(self) -> "PolicyName":
        return "now"

    def on_modified(self, force: bool = False) -> None:
        """Push the file if it was never synced, or whenever ``force`` is set."""
        already_synced = self._last_sync is not None
        if already_synced and not force:
            # Synced once and nobody asked for a re-upload: nothing to do.
            return
        self._file_pusher.file_changed(self.save_name, self.file_path)
        # Record the file's mtime as the marker that it has been synced.
        self._last_sync = os.path.getmtime(self.file_path)

    def finish(self) -> None:
        """Nothing to flush at the end of the run for this policy."""
class PolicyLive(FileEventHandler):
    """Event handler that uploads respecting throttling.

    A changed file is re-uploaded only when enough time has passed since the
    previous upload and the file has grown enough; the required wait grows
    with the file size (see ``min_wait_for_size``).
    """

    # Minimum number of seconds between two uploads of the same file.
    RATE_LIMIT_SECONDS = 15
    # Byte-unit lookup built from util.POW_10_BYTES; read via the "MB"/"GB" keys.
    unit_dict = dict(util.POW_10_BYTES)
    # Wait to upload until size has increased 20% from last upload
    RATE_LIMIT_SIZE_INCREASE = 1.2

    def __init__(
        self,
        file_path: PathStr,
        save_name: LogicalPath,
        file_pusher: "FilePusher",
        settings: Optional["SettingsStatic"] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Create the handler, optionally overriding the throttling settings.

        Args:
            file_path: The file's actual path on disk.
            save_name: The name the file is saved under.
            file_pusher: Object whose ``file_changed`` is called to upload.
            settings: When given, ``x_live_policy_rate_limit`` (if not None)
                replaces ``RATE_LIMIT_SECONDS`` for this instance, and
                ``x_live_policy_wait_time`` replaces the size-based wait.
        """
        super().__init__(file_path, save_name, file_pusher, *args, **kwargs)
        self._last_uploaded_time: Optional[float] = None
        self._last_uploaded_size: int = 0
        if settings is not None:
            if settings.x_live_policy_rate_limit is not None:
                # Instance attribute shadows the class-level default.
                self.RATE_LIMIT_SECONDS = settings.x_live_policy_rate_limit
            self._min_wait_time: Optional[float] = settings.x_live_policy_wait_time
        else:
            self._min_wait_time = None

    @property
    def current_size(self) -> int:
        """Current size of the file on disk, in bytes."""
        return os.path.getsize(self.file_path)

    @classmethod
    def min_wait_for_size(cls, size: int) -> float:
        """Return the minimum number of seconds between uploads for ``size`` bytes."""
        if size < 10 * cls.unit_dict["MB"]:
            return 60
        elif size < 100 * cls.unit_dict["MB"]:
            return 5 * 60
        elif size < cls.unit_dict["GB"]:
            return 10 * 60
        else:
            return 20 * 60

    def should_update(self) -> bool:
        """Decide whether the file may be uploaded again right now.

        Returns:
            True if the file was never uploaded. Otherwise True only when the
            rate limit has elapsed, the file grew by at least
            ``RATE_LIMIT_SIZE_INCREASE`` and the minimum wait time has passed.
        """
        # if the file has never been uploaded, we'll upload it
        if self._last_uploaded_time is None:
            return True

        # Check rate limit by time elapsed
        time_elapsed = time.time() - self._last_uploaded_time
        if time_elapsed < self.RATE_LIMIT_SECONDS:
            return False

        # Check rate limit by size increase
        if float(self._last_uploaded_size) > 0:
            size_increase = self.current_size / float(self._last_uploaded_size)
            if size_increase < self.RATE_LIMIT_SIZE_INCREASE:
                return False

        # Use `is None`, not truthiness: an explicitly configured wait time of
        # 0 must be honoured instead of falling back to the size-based wait.
        min_wait = self._min_wait_time
        if min_wait is None:
            min_wait = self.min_wait_for_size(self.current_size)
        return time_elapsed > min_wait

    def on_modified(self, force: bool = False) -> None:
        """Upload the file if it is non-empty, changed and allowed (or forced)."""
        if self.current_size == 0:
            return
        # Same mtime as at the last sync: nothing new to upload.
        if self._last_sync == os.path.getmtime(self.file_path):
            return
        if force or self.should_update():
            self.save_file()

    def save_file(self) -> None:
        """Record the sync bookkeeping and hand the file to the pusher."""
        self._last_sync = os.path.getmtime(self.file_path)
        self._last_uploaded_time = time.time()
        self._last_uploaded_size = self.current_size
        self._file_pusher.file_changed(self.save_name, self.file_path)

    def finish(self) -> None:
        """Force a final upload attempt, bypassing the throttling checks."""
        self.on_modified(force=True)

    @property
    def policy(self) -> "PolicyName":
        return "live"
def update_policy(self, path: GlobStr, policy: "PolicyName") -> None:
    """Register ``policy`` for ``path`` and apply it to files that already match.

    Args:
        path: A file name or glob, relative to the watched directory.
        policy: The policy name to use for matching files.
    """
    pattern = os.path.join(self._dir, path)
    logical_name = LogicalPath(os.path.relpath(pattern, self._dir))

    # Our own media files never need a stored policy: _get_file_event_handler
    # always hands back PolicyNow for them. For other paths without glob
    # metacharacters the policy is stored by exact save name, which makes
    # syncing historic runs much faster than matching globs. In the future a
    # flag on "files" records could mark a policy as non-dynamic so it need
    # not be stored or checked at all.
    if not logical_name.startswith("media/"):
        if path == glob.escape(path):
            self._savename_file_policies[logical_name] = policy
        else:
            self._user_file_policies[policy].add(path)

    for src_path in glob.glob(pattern):
        save_name = LogicalPath(os.path.relpath(src_path, self._dir))
        handler = self._get_file_event_handler(src_path, save_name)
        if handler.policy != policy:
            # The policy changed: forget the cached handler and build a new one.
            # The entry may already be gone (e.g. a moved file), hence pop().
            # TODO: probably should do locking, but this handles moved files for now
            self._file_event_handlers.pop(save_name, None)
            handler = self._get_file_event_handler(src_path, save_name)
        handler.on_modified(force=True)
def _on_file_created(self, event: "wd_events.FileCreatedEvent") -> None:
    """Handle a "created" event: count the file and notify its handler."""
    src_path = event.src_path
    logger.info("file/dir created: %s", src_path)
    if os.path.isdir(src_path):
        return None

    self._file_count += 1
    # We do the directory scan less often as it grows: every 100th file the
    # emitter's polling timeout is raised to (files / 100) + 1.
    hundreds, leftover = divmod(self._file_count, 100)
    if leftover == 0:
        emitter = self.emitter
        if emitter:
            emitter._timeout = hundreds + 1

    save_name = LogicalPath(os.path.relpath(src_path, self._dir))
    handler = self._get_file_event_handler(src_path, save_name)
    handler.on_modified()
def _get_file_event_handler(
    self, file_path: PathStr, save_name: LogicalPath
) -> FileEventHandler:
    """Get or create an event handler for a particular file.

    Args:
        file_path: the file's actual path
        save_name: its path relative to the run directory (aka the watch directory)

    Returns:
        A fresh ``PolicyNow`` for media files (these are never cached);
        otherwise the cached handler for ``save_name``, created on first use
        from the built-in rules, the per-name policies, or the user globs.
    """
    # Always return PolicyNow for any of our media files.
    if save_name.startswith("media/"):
        return PolicyNow(file_path, save_name, self._file_pusher, self._settings)

    if save_name not in self._file_event_handlers:
        # TODO: we can use PolicyIgnore if there are files we never want to sync
        if "tfevents" in save_name or "graph.pbtxt" in save_name:
            make_handler = PolicyLive
        elif save_name in self._savename_file_policies:
            policy_name = self._savename_file_policies[save_name]
            if policy_name == "live":
                make_handler = PolicyLive
            elif policy_name == "now":
                make_handler = PolicyNow
            else:
                make_handler = PolicyEnd
        else:
            make_handler = PolicyEnd
            for policy, globs in self._user_file_policies.items():
                if policy == "end":
                    continue
                # Convert set to list to avoid RuntimeError's
                # TODO: we may need to add locks
                for g in list(globs):
                    # Compare the normalized relative path for equality. A
                    # substring test would let "a.txt" match ".../ba.txt" and
                    # can fail when OS separators differ from save_name's.
                    matched = any(
                        LogicalPath(os.path.relpath(p, self._dir)) == save_name
                        for p in glob.glob(os.path.join(self._dir, g))
                    )
                    if not matched:
                        continue
                    # No break: a later matching policy overrides an earlier one.
                    if policy == "live":
                        make_handler = PolicyLive
                    elif policy == "now":
                        make_handler = PolicyNow
        self._file_event_handlers[save_name] = make_handler(
            file_path, save_name, self._file_pusher, self._settings
        )
    return self._file_event_handlers[save_name]
class FileStats(NamedTuple):
    """Upload bookkeeping for a single file."""

    deduped: bool
    total: int
    uploaded: int
    failed: bool
    artifact_file: bool


class Summary(NamedTuple):
    """Byte totals aggregated over all tracked files."""

    uploaded_bytes: int
    total_bytes: int
    deduped_bytes: int


class FileCountsByCategory(NamedTuple):
    """Number of tracked files in each category."""

    artifact: int
    wandb: int
    media: int
    other: int


class Stats:
    """Lock-protected registry of per-file upload statistics, keyed by save name."""

    def __init__(self) -> None:
        self._stats: MutableMapping[str, FileStats] = {}
        self._lock = threading.Lock()

    def init_file(
        self, save_name: str, size: int, is_artifact_file: bool = False
    ) -> None:
        """Start (or restart) tracking ``save_name`` with nothing uploaded yet."""
        fresh = FileStats(
            deduped=False,
            total=size,
            uploaded=0,
            failed=False,
            artifact_file=is_artifact_file,
        )
        with self._lock:
            self._stats[save_name] = fresh

    def set_file_deduped(self, save_name: str) -> None:
        """Mark the file as deduplicated and count all its bytes as uploaded."""
        with self._lock:
            current = self._stats[save_name]
            self._stats[save_name] = current._replace(
                deduped=True, uploaded=current.total
            )

    def update_uploaded_file(self, save_name: str, total_uploaded: int) -> None:
        """Set the number of bytes uploaded so far for the file."""
        with self._lock:
            current = self._stats[save_name]
            self._stats[save_name] = current._replace(uploaded=total_uploaded)

    def update_failed_file(self, save_name: str) -> None:
        """Mark the file as failed and reset its uploaded byte count to zero."""
        with self._lock:
            current = self._stats[save_name]
            self._stats[save_name] = current._replace(uploaded=0, failed=True)

    def summary(self) -> Summary:
        """Return byte totals over a snapshot of all tracked files."""
        # Snapshot under the lock; other threads may modify the mapping while
        # the totals are being computed.
        with self._lock:
            snapshot = list(self._stats.values())

        uploaded = 0
        total = 0
        deduped = 0
        for entry in snapshot:
            uploaded += entry.uploaded
            total += entry.total
            if entry.deduped:
                deduped += entry.total
        return Summary(
            uploaded_bytes=uploaded, total_bytes=total, deduped_bytes=deduped
        )

    def file_counts_by_category(self) -> FileCountsByCategory:
        """Count tracked files per category: artifact, wandb, media, other."""
        # Snapshot under the lock; other threads may modify the mapping while
        # we iterate.
        with self._lock:
            snapshot = list(self._stats.items())

        counts = {"artifact": 0, "wandb": 0, "media": 0, "other": 0}
        for save_name, entry in snapshot:
            # First matching category wins, in this order.
            if entry.artifact_file:
                category = "artifact"
            elif filenames.is_wandb_file(save_name):
                category = "wandb"
            elif save_name.startswith("media"):
                category = "media"
            else:
                category = "other"
            counts[category] += 1
        return FileCountsByCategory(**counts)
class StepChecksum:
    """Worker thread that turns incoming requests into upload-step events.

    Requests are read from ``request_queue``. File sizes are registered with
    ``stats`` and the matching ``step_upload`` events are put on
    ``output_queue``. The thread stops on a ``RequestFinish``.
    """

    def __init__(
        self,
        api: "internal_api.Api",
        tempdir: "tempfile.TemporaryDirectory",
        request_queue: "queue.Queue[Event]",
        output_queue: "queue.Queue[step_upload.Event]",
        stats: "stats.Stats",
    ) -> None:
        self._api = api
        self._tempdir = tempdir
        self._request_queue = request_queue
        self._output_queue = output_queue
        self._stats = stats

        self._thread = threading.Thread(target=self._thread_body, daemon=True)

    def _local_path_for(self, req: "RequestUpload") -> str:
        """Return the path to upload, first copying into the temp dir if asked."""
        if not req.copy:
            return req.path

        staged = os.path.join(
            self._tempdir.name,
            f"{runid.generate_id()}-{req.save_name}",
        )
        filesystem.mkdir_exists_ok(os.path.dirname(staged))
        try:
            # certain linux distros throw an exception when copying
            # large files: https://bugs.python.org/issue43743
            shutil.copy2(req.path, staged)
        except OSError:
            # Retry once with the sendfile fast path disabled.
            shutil._USE_CP_SENDFILE = False  # type: ignore[attr-defined]
            shutil.copy2(req.path, staged)
        return staged

    def _thread_body(self) -> None:
        """Process requests until a ``RequestFinish`` arrives, then forward it."""
        while True:
            req = self._request_queue.get()
            if isinstance(req, RequestFinish):
                break

            if isinstance(req, RequestUpload):
                path = self._local_path_for(req)
                self._stats.init_file(req.save_name, os.path.getsize(path))
                self._output_queue.put(
                    step_upload.RequestUpload(
                        path=path,
                        save_name=req.save_name,
                        artifact_id=None,
                        md5=None,
                        copied=req.copy,
                        save_fn=None,
                        digest=None,
                    )
                )
            elif isinstance(req, RequestStoreManifestFiles):
                for entry in req.manifest.entries.values():
                    # Only entries with a local file are uploaded.
                    if not entry.local_path:
                        continue
                    self._stats.init_file(
                        entry.local_path,
                        cast(int, entry.size),
                        is_artifact_file=True,
                    )
                    self._output_queue.put(
                        step_upload.RequestUpload(
                            path=entry.local_path,
                            save_name=entry.path,
                            artifact_id=req.artifact_id,
                            md5=entry.digest,
                            copied=False,
                            save_fn=functools.partial(req.save_fn, entry),
                            digest=entry.digest,
                        )
                    )
            elif isinstance(req, RequestCommitArtifact):
                self._output_queue.put(
                    step_upload.RequestCommitArtifact(
                        artifact_id=req.artifact_id,
                        finalize=req.finalize,
                        before_commit=req.before_commit,
                        result_future=req.result_future,
                    )
                )
            else:
                raise TypeError

        # `req` is the RequestFinish that ended the loop; pass its callback on.
        self._output_queue.put(step_upload.RequestFinish(req.callback))

    def start(self) -> None:
        """Start the worker thread."""
        self._thread.start()

    def is_alive(self) -> bool:
        """Report whether the worker thread is running."""
        return self._thread.is_alive()

    def finish(self) -> None:
        """Ask the worker to stop by queueing a ``RequestFinish`` without callback."""
        self._request_queue.put(RequestFinish(None))
def gather_batch(
    request_queue: "queue.Queue[Request]",
    batch_time: float,
    inter_event_time: float,
    max_batch_size: int,
    clock: Callable[[], float] = time.monotonic,
) -> Tuple[bool, Sequence[RequestPrepare]]:
    """Collect one batch of prepare requests from the queue.

    Blocks until a first request arrives, then keeps collecting until the
    batch is full, the batch time budget is used up, no request arrives
    within the inter-event timeout, or a ``RequestFinish`` is seen.

    Args:
        request_queue: Queue to read requests from.
        batch_time: Time budget for one batch, in the clock's units.
        inter_event_time: Longest wait for each further request.
        max_batch_size: Maximum number of requests in a batch.
        clock: Time source; injectable for tests.

    Returns:
        ``(finished, batch)`` where ``finished`` is True if a ``RequestFinish``
        was encountered.
    """
    # NOTE(review): the clock starts before the blocking wait for the first
    # request, so idle time counts against the batch budget -- confirm intent.
    started_at = clock()
    time_left = batch_time

    head = request_queue.get()
    if isinstance(head, RequestFinish):
        return True, []

    gathered: List[RequestPrepare] = [head]

    while time_left > 0 and len(gathered) < max_batch_size:
        # Wait at most inter_event_time, capped by the remaining budget and
        # floored at a tiny positive value.
        wait = max(1e-12, min(inter_event_time, time_left))
        try:
            incoming = request_queue.get(timeout=wait)
        except queue.Empty:
            break

        if isinstance(incoming, RequestFinish):
            return True, gathered

        gathered.append(incoming)
        time_left = batch_time - (clock() - started_at)

    return False, gathered
class StepPrepare:
    """A thread that batches requests to our file prepare API.

    Any number of threads may call prepare() in parallel. The worker thread
    gathers the queued requests into batches, sends each batch to the backend
    in one call, and delivers every response on the requester's own queue.
    """

    def __init__(
        self,
        api: "Api",
        batch_time: float,
        inter_event_time: float,
        max_batch_size: int,
        request_queue: Optional["queue.Queue[Request]"] = None,
    ) -> None:
        self._api = api
        self._inter_event_time = inter_event_time
        self._batch_time = batch_time
        self._max_batch_size = max_batch_size
        self._request_queue: queue.Queue[Request] = request_queue or queue.Queue()
        self._thread = threading.Thread(target=self._thread_body, daemon=True)

    def _thread_body(self) -> None:
        """Gather and process batches until a finish request is seen."""
        finished = False
        while not finished:
            finished, batch = gather_batch(
                request_queue=self._request_queue,
                batch_time=self._batch_time,
                inter_event_time=self._inter_event_time,
                max_batch_size=self._max_batch_size,
            )
            if not batch:
                continue
            responses_by_name = self._prepare_batch(batch)
            # Route each response back to the caller that asked for that file.
            for pending in batch:
                file_name = pending.file_spec["name"]
                reply = prepare_response(responses_by_name[file_name])
                pending.response_channel.put(reply)

    def _prepare_batch(
        self, batch: Sequence[RequestPrepare]
    ) -> Mapping[str, "CreateArtifactFilesResponseFile"]:
        """Execute the prepareFiles API call.

        Args:
            batch: List of RequestPrepare objects
        Returns:
            dict of (save_name: ResponseFile) pairs where ResponseFile is a dict with
            an uploadUrl key. The value of the uploadUrl key is None if the file
            already exists, or a url string if the file should be uploaded.
        """
        file_specs = [request.file_spec for request in batch]
        return self._api.create_artifact_files(file_specs)

    def prepare(
        self, file_spec: "CreateArtifactFileSpecInput"
    ) -> "queue.Queue[ResponsePrepare]":
        """Queue ``file_spec`` for preparation and return the reply queue."""
        reply_queue: queue.Queue[ResponsePrepare] = queue.Queue()
        self._request_queue.put(RequestPrepare(file_spec, reply_queue))
        return reply_queue

    def start(self) -> None:
        """Start the batching thread."""
        self._thread.start()

    def finish(self) -> None:
        """Ask the batching thread to stop after the requests already queued."""
        self._request_queue.put(RequestFinish())

    def is_alive(self) -> bool:
        """Report whether the batching thread is running."""
        return self._thread.is_alive()

    def shutdown(self) -> None:
        """Request a stop and wait for the batching thread to exit."""
        self.finish()
        self._thread.join()
RequestUpload(NamedTuple): + path: str + save_name: LogicalPath + artifact_id: Optional[str] + md5: Optional[str] + copied: bool + save_fn: Optional[SaveFn] + digest: Optional[str] + + +class RequestCommitArtifact(NamedTuple): + artifact_id: str + finalize: bool + before_commit: PreCommitFn + result_future: "concurrent.futures.Future[None]" + + +class RequestFinish(NamedTuple): + callback: Optional[OnRequestFinishFn] + + +class EventJobDone(NamedTuple): + job: RequestUpload + exc: Optional[BaseException] + + +Event = Union[RequestUpload, RequestCommitArtifact, RequestFinish, EventJobDone] + + +class StepUpload: + def __init__( + self, + api: "internal_api.Api", + stats: "stats.Stats", + event_queue: "queue.Queue[Event]", + max_threads: int, + file_stream: "file_stream.FileStreamApi", + settings: Optional["SettingsStatic"] = None, + ) -> None: + self._api = api + self._stats = stats + self._event_queue = event_queue + self._file_stream = file_stream + + self._thread = threading.Thread(target=self._thread_body) + self._thread.daemon = True + + self._pool = concurrent.futures.ThreadPoolExecutor( + thread_name_prefix="wandb-upload", + max_workers=max_threads, + ) + + # Indexed by files' `save_name`'s, which are their ID's in the Run. + self._running_jobs: MutableMapping[LogicalPath, RequestUpload] = {} + self._pending_jobs: MutableSequence[RequestUpload] = [] + + self._artifacts: MutableMapping[str, ArtifactStatus] = {} + + self.silent = bool(settings.silent) if settings else False + + def _thread_body(self) -> None: + event: Optional[Event] + # Wait for event in the queue, and process one by one until a + # finish event is received + finish_callback = None + while True: + event = self._event_queue.get() + if isinstance(event, RequestFinish): + finish_callback = event.callback + break + self._handle_event(event) + + # We've received a finish event. At this point, further Upload requests + # are invalid. 
+ + # After a finish event is received, iterate through the event queue + # one by one and process all remaining events. + while True: + try: + event = self._event_queue.get(True, 0.2) + except queue.Empty: + event = None + if event: + self._handle_event(event) + elif not self._running_jobs: + # Queue was empty and no jobs left. + self._pool.shutdown(wait=False) + if finish_callback: + finish_callback() + break + + def _handle_event(self, event: Event) -> None: + if isinstance(event, EventJobDone): + job = event.job + + if event.exc is not None: + logger.exception( + "Failed to upload file: %s", job.path, exc_info=event.exc + ) + + if job.artifact_id: + if event.exc is None: + self._artifacts[job.artifact_id]["pending_count"] -= 1 + self._maybe_commit_artifact(job.artifact_id) + else: + if not self.silent: + termerror( + "Uploading artifact file failed. Artifact won't be committed." + ) + self._fail_artifact_futures(job.artifact_id, event.exc) + self._running_jobs.pop(job.save_name) + # If we have any pending jobs, start one now + if self._pending_jobs: + event = self._pending_jobs.pop(0) + self._start_upload_job(event) + elif isinstance(event, RequestCommitArtifact): + if event.artifact_id not in self._artifacts: + self._init_artifact(event.artifact_id) + self._artifacts[event.artifact_id]["commit_requested"] = True + self._artifacts[event.artifact_id]["finalize"] = event.finalize + self._artifacts[event.artifact_id]["pre_commit_callbacks"].add( + event.before_commit + ) + self._artifacts[event.artifact_id]["result_futures"].add( + event.result_future + ) + self._maybe_commit_artifact(event.artifact_id) + elif isinstance(event, RequestUpload): + if event.artifact_id is not None: + if event.artifact_id not in self._artifacts: + self._init_artifact(event.artifact_id) + self._artifacts[event.artifact_id]["pending_count"] += 1 + self._start_upload_job(event) + else: + raise TypeError(f"Event has unexpected type: {event!s}") + + def _start_upload_job(self, event: 
RequestUpload) -> None: + # Operations on a single backend file must be serialized. if + # we're already uploading this file, put the event on the + # end of the queue + if event.save_name in self._running_jobs: + self._pending_jobs.append(event) + return + + self._spawn_upload(event) + + def _spawn_upload(self, event: RequestUpload) -> None: + """Spawn an upload job, and handles the bookkeeping of `self._running_jobs`. + + Context: it's important that, whenever we add an entry to `self._running_jobs`, + we ensure that a corresponding `EventJobDone` message will eventually get handled; + otherwise, the `_running_jobs` entry will never get removed, and the StepUpload + will never shut down. + + The sole purpose of this function is to make sure that the code that adds an entry + to `self._running_jobs` is textually right next to the code that eventually enqueues + the `EventJobDone` message. This should help keep them in sync. + """ + # Adding the entry to `self._running_jobs` MUST happen in the main thread, + # NOT in the job that gets submitted to the thread-pool, to guard against + # this sequence of events: + # - StepUpload receives a RequestUpload + # ...and therefore spawns a thread to do the upload + # - StepUpload receives a RequestFinish + # ...and checks `self._running_jobs` to see if there are any tasks to wait for... + # ...and there are none, because the addition to `self._running_jobs` happens in + # the background thread, which the scheduler hasn't yet run... + # ...so the StepUpload shuts down. Even though we haven't uploaded the file! + # + # This would be very bad! + # So, this line has to happen _outside_ the `pool.submit()`. 
+ self._running_jobs[event.save_name] = event + + def run_and_notify() -> None: + try: + self._do_upload(event) + finally: + self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1])) + + self._pool.submit(run_and_notify) + + def _do_upload(self, event: RequestUpload) -> None: + job = upload_job.UploadJob( + self._stats, + self._api, + self._file_stream, + self.silent, + event.save_name, + event.path, + event.artifact_id, + event.md5, + event.copied, + event.save_fn, + event.digest, + ) + job.run() + + def _init_artifact(self, artifact_id: str) -> None: + self._artifacts[artifact_id] = { + "finalize": False, + "pending_count": 0, + "commit_requested": False, + "pre_commit_callbacks": set(), + "result_futures": set(), + } + + def _maybe_commit_artifact(self, artifact_id: str) -> None: + artifact_status = self._artifacts[artifact_id] + if ( + artifact_status["pending_count"] == 0 + and artifact_status["commit_requested"] + ): + try: + for pre_callback in artifact_status["pre_commit_callbacks"]: + pre_callback() + if artifact_status["finalize"]: + self._api.commit_artifact(artifact_id) + except Exception as exc: + termerror( + f"Committing artifact failed. Artifact {artifact_id} won't be finalized." 
+ ) + termerror(str(exc)) + self._fail_artifact_futures(artifact_id, exc) + else: + self._resolve_artifact_futures(artifact_id) + + def _fail_artifact_futures(self, artifact_id: str, exc: BaseException) -> None: + futures = self._artifacts[artifact_id]["result_futures"] + for result_future in futures: + result_future.set_exception(exc) + futures.clear() + + def _resolve_artifact_futures(self, artifact_id: str) -> None: + futures = self._artifacts[artifact_id]["result_futures"] + for result_future in futures: + result_future.set_result(None) + futures.clear() + + def start(self) -> None: + self._thread.start() + + def is_alive(self) -> bool: + return self._thread.is_alive() diff --git a/.venv/lib/python3.13/site-packages/wandb/filesync/upload_job.py b/.venv/lib/python3.13/site-packages/wandb/filesync/upload_job.py new file mode 100644 index 0000000000000000000000000000000000000000..120c8b44466a3c077596a98998ef82c69ae65f0e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/filesync/upload_job.py @@ -0,0 +1,143 @@ +import logging +import os +from typing import TYPE_CHECKING, Optional + +import wandb +from wandb.analytics import get_sentry +from wandb.sdk.lib.paths import LogicalPath + +if TYPE_CHECKING: + from wandb.filesync import dir_watcher, stats, step_upload + from wandb.sdk.internal import file_stream, internal_api + + +logger = logging.getLogger(__name__) + + +class UploadJob: + def __init__( + self, + stats: "stats.Stats", + api: "internal_api.Api", + file_stream: "file_stream.FileStreamApi", + silent: bool, + save_name: LogicalPath, + path: "dir_watcher.PathStr", + artifact_id: Optional[str], + md5: Optional[str], + copied: bool, + save_fn: Optional["step_upload.SaveFn"], + digest: Optional[str], + ) -> None: + """A file uploader. + + Args: + push_function: function(save_name, actual_path) which actually uploads + the file. + save_name: string logical location of the file relative to the run + directory. 
+ path: actual string path of the file to upload on the filesystem. + """ + self._stats = stats + self._api = api + self._file_stream = file_stream + self.silent = silent + self.save_name = save_name + self.save_path = path + self.artifact_id = artifact_id + self.md5 = md5 + self.copied = copied + self.save_fn = save_fn + self.digest = digest + super().__init__() + + def run(self) -> None: + success = False + try: + self.push() + success = True + finally: + if self.copied and os.path.isfile(self.save_path): + os.remove(self.save_path) + if success: + self._file_stream.push_success(self.artifact_id, self.save_name) # type: ignore + + def push(self) -> None: + if self.save_fn: + # Retry logic must happen in save_fn currently + try: + deduped = self.save_fn( + lambda _, t: self._stats.update_uploaded_file(self.save_path, t) + ) + except Exception as e: + self._stats.update_failed_file(self.save_path) + logger.exception("Failed to upload file: %s", self.save_path) + get_sentry().exception(e) + message = str(e) + # TODO: this is usually XML, but could be JSON + if hasattr(e, "response"): + message = e.response.content + wandb.termerror( + f'Error uploading "{self.save_path}": {type(e).__name__}, {message}' + ) + raise + + if deduped: + logger.info("Skipped uploading %s", self.save_path) + self._stats.set_file_deduped(self.save_path) + else: + logger.info("Uploaded file %s", self.save_path) + return + + if self.md5: + # This is the new artifact manifest upload flow, in which we create the + # database entry for the manifest file before creating it. This is used for + # artifact L0 files. Which now is only artifact_manifest.json + _, response = self._api.create_artifact_manifest( + self.save_name, self.md5, self.artifact_id + ) + upload_url = response["uploadUrl"] + upload_headers = response["uploadHeaders"] + else: + # The classic file upload flow. 
We get a signed url and upload the file + # then the backend handles the cloud storage metadata callback to create the + # file entry. This flow has aged like a fine wine. + project = self._api.get_project() + _, upload_headers, result = self._api.upload_urls(project, [self.save_name]) + file_info = result[self.save_name] + upload_url = file_info["uploadUrl"] + + if upload_url is None: + logger.info("Skipped uploading %s", self.save_path) + self._stats.set_file_deduped(self.save_name) + else: + extra_headers = self._api._extra_http_headers + for upload_header in upload_headers: + key, val = upload_header.split(":", 1) + extra_headers[key] = val + # Copied from push TODO(artifacts): clean up + # If the upload URL is relative, fill it in with the base URL, + # since its a proxied file store like the on-prem VM. + if upload_url.startswith("/"): + upload_url = f"{self._api.api_url}{upload_url}" + try: + with open(self.save_path, "rb") as f: + self._api.upload_file_retry( + upload_url, + f, + lambda _, t: self.progress(t), + extra_headers=extra_headers, + ) + logger.info("Uploaded file %s", self.save_path) + except Exception as e: + self._stats.update_failed_file(self.save_name) + logger.exception("Failed to upload file: %s", self.save_path) + get_sentry().exception(e) + if not self.silent: + wandb.termerror( + f'Error uploading "{self.save_name}": {type(e).__name__}, {e}' + ) + raise + + def progress(self, total_bytes: int) -> None: + self._stats.update_uploaded_file(self.save_name, total_bytes) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/integration/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..df14a8edf41550c8bdd0872015dabb39a920d043 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/integration/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/catboost/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/catboost/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ccd732b7e26cead03e12f0d6ab00f6d3165c2ba4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/catboost/__init__.py @@ -0,0 +1,5 @@ +"""W&B callback for CatBoost.""" + +from .catboost import WandbCallback, log_summary + +__all__ = ["log_summary", "WandbCallback"] diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/catboost/catboost.py b/.venv/lib/python3.13/site-packages/wandb/integration/catboost/catboost.py new file mode 100644 index 0000000000000000000000000000000000000000..09dc31b84ffaec7826d70536eb59ef56452a9062 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/catboost/catboost.py @@ -0,0 +1,182 @@ +"""catboost init.""" + +from pathlib import Path +from types import SimpleNamespace +from typing import List, Union + +from catboost import CatBoostClassifier, CatBoostRegressor # type: ignore + +import wandb +from wandb.sdk.lib import telemetry as wb_telemetry + + +class WandbCallback: + """`WandbCallback` automatically integrates CatBoost with wandb. + + Args: + - metric_period: (int) if you are passing `metric_period` to your CatBoost model please pass the same value here (default=1). 
+ + Passing `WandbCallback` to CatBoost will: + - log training and validation metrics at every `metric_period` + - log iteration at every `metric_period` + + Example: + ``` + train_pool = Pool( + train[features], label=train["label"], cat_features=cat_features + ) + test_pool = Pool(test[features], label=test["label"], cat_features=cat_features) + + model = CatBoostRegressor( + iterations=100, + loss_function="Cox", + eval_metric="Cox", + ) + + model.fit( + train_pool, + eval_set=test_pool, + callbacks=[WandbCallback()], + ) + ``` + """ + + def __init__(self, metric_period: int = 1): + if wandb.run is None: + raise wandb.Error("You must call `wandb.init()` before `WandbCallback()`") + + with wb_telemetry.context() as tel: + tel.feature.catboost_wandb_callback = True + + self.metric_period: int = metric_period + + def after_iteration(self, info: SimpleNamespace) -> bool: + if info.iteration % self.metric_period == 0: + for data, metric in info.metrics.items(): + for metric_name, log in metric.items(): + # todo: replace with wandb.run._log once available + wandb.log({f"{data}-{metric_name}": log[-1]}, commit=False) + # todo: replace with wandb.run._log once available + wandb.log({f"iteration@metric-period-{self.metric_period}": info.iteration}) + + return True + + +def _checkpoint_artifact( + model: Union[CatBoostClassifier, CatBoostRegressor], aliases: List[str] +) -> None: + """Upload model checkpoint as W&B artifact.""" + if wandb.run is None: + raise wandb.Error( + "You must call `wandb.init()` before `_checkpoint_artifact()`" + ) + + model_name = f"model_{wandb.run.id}" + # save the model in the default `cbm` format + model_path = Path(wandb.run.dir) / "model" + + model.save_model(model_path) + + model_artifact = wandb.Artifact(name=model_name, type="model") + model_artifact.add_file(str(model_path)) + wandb.log_artifact(model_artifact, aliases=aliases) + + +def _log_feature_importance( + model: Union[CatBoostClassifier, CatBoostRegressor], +) -> None: + """Log 
feature importance with default settings.""" + if wandb.run is None: + raise wandb.Error( + "You must call `wandb.init()` before `_checkpoint_artifact()`" + ) + + feat_df = model.get_feature_importance(prettified=True) + + fi_data = [ + [feat, feat_imp] + for feat, feat_imp in zip(feat_df["Feature Id"], feat_df["Importances"]) + ] + table = wandb.Table(data=fi_data, columns=["Feature", "Importance"]) + # todo: replace with wandb.run._log once available + wandb.log( + { + "Feature Importance": wandb.plot.bar( + table, "Feature", "Importance", title="Feature Importance" + ) + }, + commit=False, + ) + + +def log_summary( + model: Union[CatBoostClassifier, CatBoostRegressor], + log_all_params: bool = True, + save_model_checkpoint: bool = False, + log_feature_importance: bool = True, +) -> None: + """`log_summary` logs useful metrics about catboost model after training is done. + + Args: + model: it can be CatBoostClassifier or CatBoostRegressor. + log_all_params: (boolean) if True (default) log the model hyperparameters as W&B config. + save_model_checkpoint: (boolean) if True saves the model upload as W&B artifacts. + log_feature_importance: (boolean) if True (default) logs feature importance as W&B bar chart using the default setting of `get_feature_importance`. + + Using this along with `wandb_callback` will: + + - save the hyperparameters as W&B config, + - log `best_iteration` and `best_score` as `wandb.summary`, + - save and upload your trained model to Weights & Biases Artifacts (when `save_model_checkpoint = True`) + - log feature importance plot. 
+ + Example: + ```python + train_pool = Pool( + train[features], label=train["label"], cat_features=cat_features + ) + test_pool = Pool(test[features], label=test["label"], cat_features=cat_features) + + model = CatBoostRegressor( + iterations=100, + loss_function="Cox", + eval_metric="Cox", + ) + + model.fit( + train_pool, + eval_set=test_pool, + callbacks=[WandbCallback()], + ) + + log_summary(model) + ``` + """ + if wandb.run is None: + raise wandb.Error("You must call `wandb.init()` before `log_summary()`") + + if not (isinstance(model, (CatBoostClassifier, CatBoostRegressor))): + raise wandb.Error( + "Model should be an instance of CatBoostClassifier or CatBoostRegressor" + ) + + with wb_telemetry.context() as tel: + tel.feature.catboost_log_summary = True + + # log configs + params = model.get_all_params() + if log_all_params: + wandb.config.update(params) + + # log best score and iteration + wandb.run.summary["best_iteration"] = model.get_best_iteration() + wandb.run.summary["best_score"] = model.get_best_score() + + # log model + if save_model_checkpoint: + aliases = ["best"] if params["use_best_model"] else ["last"] + _checkpoint_artifact(model, aliases=aliases) + + # Feature importance + if log_feature_importance: + _log_feature_importance(model) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/cohere/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/cohere/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d367dc6988bda4251745dc6b3610a3e92b4c85e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/cohere/__init__.py @@ -0,0 +1,3 @@ +__all__ = ("autolog",) + +from .cohere import autolog diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/cohere/cohere.py b/.venv/lib/python3.13/site-packages/wandb/integration/cohere/cohere.py new file mode 100644 index 0000000000000000000000000000000000000000..91f9a43e23150a6882dbb87512e6fbe657a7b8d4 --- /dev/null +++ 
b/.venv/lib/python3.13/site-packages/wandb/integration/cohere/cohere.py @@ -0,0 +1,21 @@ +import logging + +from wandb.sdk.integration_utils.auto_logging import AutologAPI + +from .resolver import CohereRequestResponseResolver + +logger = logging.getLogger(__name__) + + +autolog = AutologAPI( + name="Cohere", + symbols=( + "Client.generate", + "Client.chat", + "Client.classify", + "Client.summarize", + "Client.rerank", + ), + resolver=CohereRequestResponseResolver(), + telemetry_feature="cohere_autolog", +) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/cohere/resolver.py b/.venv/lib/python3.13/site-packages/wandb/integration/cohere/resolver.py new file mode 100644 index 0000000000000000000000000000000000000000..6cfcdf020b87c8a5a3c15f164226ac50f8670d4d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/cohere/resolver.py @@ -0,0 +1,347 @@ +import logging +from datetime import datetime +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import wandb +from wandb.sdk.integration_utils.auto_logging import Response +from wandb.sdk.lib.runid import generate_id + +logger = logging.getLogger(__name__) + + +def subset_dict( + original_dict: Dict[str, Any], keys_subset: Sequence[str] +) -> Dict[str, Any]: + """Create a subset of a dictionary using a subset of keys. + + :param original_dict: The original dictionary. + :param keys_subset: The subset of keys to extract. + :return: A dictionary containing only the specified keys. + """ + return {key: original_dict[key] for key in keys_subset if key in original_dict} + + +def reorder_and_convert_dict_list_to_table( + data: List[Dict[str, Any]], order: List[str] +) -> Tuple[List[str], List[List[Any]]]: + """Convert a list of dictionaries to a pair of column names and corresponding values, with the option to order specific dictionaries. + + :param data: A list of dictionaries. + :param order: A list of keys specifying the desired order for specific dictionaries. 
The remaining dictionaries will be ordered based on their original order. + :return: A pair of column names and corresponding values. + """ + final_columns = [] + keys_present = set() + + # First, add all ordered keys to the final columns + for key in order: + if key not in keys_present: + final_columns.append(key) + keys_present.add(key) + + # Then, add any keys present in the dictionaries but not in the order + for d in data: + for key in d: + if key not in keys_present: + final_columns.append(key) + keys_present.add(key) + + # Then, construct the table of values + values = [] + for d in data: + row = [] + for key in final_columns: + row.append(d.get(key, None)) + values.append(row) + + return final_columns, values + + +def flatten_dict( + dictionary: Dict[str, Any], parent_key: str = "", sep: str = "-" +) -> Dict[str, Any]: + """Flatten a nested dictionary, joining keys using a specified separator. + + :param dictionary: The dictionary to flatten. + :param parent_key: The base key to prepend to each key. + :param sep: The separator to use when joining keys. + :return: A flattened dictionary. + """ + flattened_dict = {} + for key, value in dictionary.items(): + new_key = f"{parent_key}{sep}{key}" if parent_key else key + if isinstance(value, dict): + flattened_dict.update(flatten_dict(value, new_key, sep=sep)) + else: + flattened_dict[new_key] = value + return flattened_dict + + +def collect_common_keys(list_of_dicts: List[Dict[str, Any]]) -> Dict[str, List[Any]]: + """Collect the common keys of a list of dictionaries. For each common key, put its values into a list in the order they appear in the original dictionaries. + + :param list_of_dicts: The list of dictionaries to inspect. + :return: A dictionary with each common key and its corresponding list of values. 
+ """ + common_keys = set.intersection(*map(set, list_of_dicts)) + common_dict = {key: [] for key in common_keys} + for d in list_of_dicts: + for key in common_keys: + common_dict[key].append(d[key]) + return common_dict + + +class CohereRequestResponseResolver: + """Class to resolve the request/response from the Cohere API and convert it to a dictionary that can be logged.""" + + def __call__( + self, + args: Sequence[Any], + kwargs: Dict[str, Any], + response: Response, + start_time: float, + time_elapsed: float, + ) -> Optional[Dict[str, Any]]: + """Process the response from the Cohere API and convert it to a dictionary that can be logged. + + :param args: The arguments of the original function. + :param kwargs: The keyword arguments of the original function. + :param response: The response from the Cohere API. + :param start_time: The start time of the request. + :param time_elapsed: The time elapsed for the request. + :return: A dictionary containing the parsed response and timing information. 
+ """ + try: + # Each of the different endpoints map to one specific response type + # We want to 'type check' the response without directly importing the packages type + # It may make more sense to pass the invoked symbol from the AutologAPI instead + response_type = str(type(response)).split("'")[1].split(".")[-1] + + # Initialize parsed_response to None to handle the case where the response type is unsupported + parsed_response = None + if response_type == "Generations": + parsed_response = self._resolve_generate_response(response) + # TODO: Remove hard-coded default model name + table_column_order = [ + "start_time", + "query_id", + "model", + "prompt", + "text", + "token_likelihoods", + "likelihood", + "time_elapsed_(seconds)", + "end_time", + ] + default_model = "command" + elif response_type == "Chat": + parsed_response = self._resolve_chat_response(response) + table_column_order = [ + "start_time", + "query_id", + "model", + "conversation_id", + "response_id", + "query", + "text", + "prompt", + "preamble", + "chat_history", + "chatlog", + "time_elapsed_(seconds)", + "end_time", + ] + default_model = "command" + elif response_type == "Classifications": + parsed_response = self._resolve_classify_response(response) + kwargs = self._resolve_classify_kwargs(kwargs) + table_column_order = [ + "start_time", + "query_id", + "model", + "id", + "input", + "prediction", + "confidence", + "time_elapsed_(seconds)", + "end_time", + ] + default_model = "embed-english-v2.0" + elif response_type == "SummarizeResponse": + parsed_response = self._resolve_summarize_response(response) + table_column_order = [ + "start_time", + "query_id", + "model", + "response_id", + "text", + "additional_command", + "summary", + "time_elapsed_(seconds)", + "end_time", + "length", + "format", + ] + default_model = "summarize-xlarge" + elif response_type == "Reranking": + parsed_response = self._resolve_rerank_response(response) + table_column_order = [ + "start_time", + "query_id", + "model", 
+ "id", + "query", + "top_n", + # This is a nested dict key that got flattened + "document-text", + "relevance_score", + "index", + "time_elapsed_(seconds)", + "end_time", + ] + default_model = "rerank-english-v2.0" + else: + logger.info(f"Unsupported Cohere response object: {response}") + + return self._resolve( + args, + kwargs, + parsed_response, + start_time, + time_elapsed, + response_type, + table_column_order, + default_model, + ) + except Exception as e: + logger.warning(f"Failed to resolve request/response: {e}") + return None + + # These helper functions process the response from different endpoints of the Cohere API. + # Since the response objects for different endpoints have different structures, + # we need different logic to process them. + + def _resolve_generate_response(self, response: Response) -> List[Dict[str, Any]]: + return_list = [] + for _response in response: + # Built in Cohere.*.Generations function to color token_likelihoods and return a dict of response data + _response_dict = _response._visualize_helper() + try: + _response_dict["token_likelihoods"] = wandb.Html( + _response_dict["token_likelihoods"] + ) + except (KeyError, ValueError): + pass + return_list.append(_response_dict) + + return return_list + + def _resolve_chat_response(self, response: Response) -> List[Dict[str, Any]]: + return [ + subset_dict( + response.__dict__, + [ + "response_id", + "generation_id", + "query", + "text", + "conversation_id", + "prompt", + "chatlog", + "preamble", + ], + ) + ] + + def _resolve_classify_response(self, response: Response) -> List[Dict[str, Any]]: + # The labels key is a dict returning the scores for the classification probability for each label provided + # We flatten this nested dict for ease of consumption in the wandb UI + return [flatten_dict(_response.__dict__) for _response in response] + + def _resolve_classify_kwargs(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + # Example texts look strange when rendered in Wandb UI as it is 
a list of text and label + # We extract each value into its own column + example_texts = [] + example_labels = [] + for example in kwargs["examples"]: + example_texts.append(example.text) + example_labels.append(example.label) + kwargs.pop("examples") + kwargs["example_texts"] = example_texts + kwargs["example_labels"] = example_labels + return kwargs + + def _resolve_summarize_response(self, response: Response) -> List[Dict[str, Any]]: + return [{"response_id": response.id, "summary": response.summary}] + + def _resolve_rerank_response(self, response: Response) -> List[Dict[str, Any]]: + # The documents key contains a dict containing the content of the document which is at least "text" + # We flatten this nested dict for ease of consumption in the wandb UI + flattened_response_dicts = [ + flatten_dict(_response.__dict__) for _response in response + ] + # ReRank returns each document provided a top_n value so we aggregate into one view so users can paginate a row + # As opposed to each row being one of the top_n responses + return_dict = collect_common_keys(flattened_response_dicts) + return_dict["id"] = response.id + return [return_dict] + + def _resolve( + self, + args: Sequence[Any], + kwargs: Dict[str, Any], + parsed_response: List[Dict[str, Any]], + start_time: float, + time_elapsed: float, + response_type: str, + table_column_order: List[str], + default_model: str, + ) -> Dict[str, Any]: + """Convert a list of dictionaries to a pair of column names and corresponding values, with the option to order specific dictionaries. + + :param args: The arguments passed to the API client. + :param kwargs: The keyword arguments passed to the API client. + :param parsed_response: The parsed response from the API. + :param start_time: The start time of the API request. + :param time_elapsed: The time elapsed during the API request. + :param response_type: The type of the API response. + :param table_column_order: The desired order of columns in the resulting table. 
+ :param default_model: The default model to use if not specified in the response. + :return: A dictionary containing the formatted response. + """ + # Args[0] is the client object where we can grab specific metadata about the underlying API status + query_id = generate_id(length=16) + parsed_args = subset_dict( + args[0].__dict__, + ["api_version", "batch_size", "max_retries", "num_workers", "timeout"], + ) + + start_time_dt = datetime.fromtimestamp(start_time) + end_time_dt = datetime.fromtimestamp(start_time + time_elapsed) + + timings = { + "start_time": start_time_dt, + "end_time": end_time_dt, + "time_elapsed_(seconds)": time_elapsed, + } + + packed_data = [] + for _parsed_response in parsed_response: + _packed_dict = { + "query_id": query_id, + **kwargs, + **_parsed_response, + **timings, + **parsed_args, + } + if "model" not in _packed_dict: + _packed_dict["model"] = default_model + packed_data.append(_packed_dict) + + columns, data = reorder_and_convert_dict_list_to_table( + packed_data, table_column_order + ) + + request_response_table = wandb.Table(data=data, columns=columns) + + return {f"{response_type}": request_response_table} diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/diffusers/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/diffusers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5bcf7980133c41393fdc22db3bbff89be29be94e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/diffusers/__init__.py @@ -0,0 +1,3 @@ +from .autologger import autolog + +__all__ = ["autolog"] diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/diffusers/autologger.py b/.venv/lib/python3.13/site-packages/wandb/integration/diffusers/autologger.py new file mode 100644 index 0000000000000000000000000000000000000000..ad21a77edc6a1125f8d2f5621c046c6bbe1f0b9b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/diffusers/autologger.py @@ -0,0 +1,76 @@ 
+import logging + +from wandb.sdk.integration_utils.auto_logging import AutologAPI + +from .pipeline_resolver import DiffusersPipelineResolver + +logger = logging.getLogger(__name__) + +autolog = AutologAPI( + name="diffusers", + symbols=( + "DiffusionPipeline.__call__", + "AutoPipelineForText2Image.__call__", + "AutoPipelineForImage2Image.__call__", + "AutoPipelineForInpainting.__call__", + "StableDiffusionPipeline.__call__", + "KandinskyCombinedPipeline.__call__", + "KandinskyV22CombinedPipeline.__call__", + "LatentConsistencyModelPipeline.__call__", + "LDMTextToImagePipeline.__call__", + "StableDiffusionPanoramaPipeline.__call__", + "StableDiffusionParadigmsPipeline.__call__", + "PixArtAlphaPipeline.__call__", + "StableDiffusionSAGPipeline.__call__", + "SemanticStableDiffusionPipeline.__call__", + "WuerstchenCombinedPipeline.__call__", + "AltDiffusionPipeline.__call__", + "StableDiffusionAttendAndExcitePipeline.__call__", + "StableDiffusionXLPipeline.__call__", + "StableDiffusionXLImg2ImgPipeline.__call__", + "IFPipeline.__call__", + "BlipDiffusionPipeline.__call__", + "BlipDiffusionControlNetPipeline.__call__", + "StableDiffusionControlNetPipeline.__call__", + "StableDiffusionControlNetImg2ImgPipeline.__call__", + "StableDiffusionControlNetInpaintPipeline.__call__", + "CycleDiffusionPipeline.__call__", + "StableDiffusionInstructPix2PixPipeline.__call__", + "PaintByExamplePipeline.__call__", + "RePaintPipeline.__call__", + "KandinskyImg2ImgCombinedPipeline.__call__", + "KandinskyInpaintCombinedPipeline.__call__", + "KandinskyV22Img2ImgCombinedPipeline.__call__", + "KandinskyV22InpaintCombinedPipeline.__call__", + "Kandinsky3Pipeline.__call__", + "Kandinsky3Img2ImgPipeline.__call__", + "AnimateDiffPipeline.__call__", + "AudioLDMPipeline.__call__", + "AudioLDM2Pipeline.__call__", + "MusicLDMPipeline.__call__", + "StableDiffusionPix2PixZeroPipeline.__call__", + "PNDMPipeline.__call__", + "ShapEPipeline.__call__", + "StableDiffusionImg2ImgPipeline.__call__", + 
"StableDiffusionInpaintPipeline.__call__", + "StableDiffusionDepth2ImgPipeline.__call__", + "StableDiffusionImageVariationPipeline.__call__", + "StableDiffusionPipelineSafe.__call__", + "StableDiffusionUpscalePipeline.__call__", + "StableDiffusionAdapterPipeline.__call__", + "StableDiffusionGLIGENPipeline.__call__", + "StableDiffusionModelEditingPipeline.__call__", + "VersatileDiffusionTextToImagePipeline.__call__", + "VersatileDiffusionImageVariationPipeline.__call__", + "VersatileDiffusionDualGuidedPipeline.__call__", + "LDMPipeline.__call__", + "TextToVideoSDPipeline.__call__", + "TextToVideoZeroPipeline.__call__", + "StableVideoDiffusionPipeline.__call__", + "AmusedPipeline.__call__", + "StableDiffusionXLControlNetPipeline.__call__", + "StableDiffusionXLControlNetImg2ImgPipeline.__call__", + ), + resolver=DiffusersPipelineResolver(), + telemetry_feature="diffusers_autolog", +) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/diffusers/pipeline_resolver.py b/.venv/lib/python3.13/site-packages/wandb/integration/diffusers/pipeline_resolver.py new file mode 100644 index 0000000000000000000000000000000000000000..a5c4d73511c20f53e75f4b1cbf5a6cfc444c8875 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/diffusers/pipeline_resolver.py @@ -0,0 +1,50 @@ +from typing import Any, Dict, Sequence + +from wandb.sdk.integration_utils.auto_logging import Response + +from .resolvers import ( + SUPPORTED_MULTIMODAL_PIPELINES, + DiffusersMultiModalPipelineResolver, +) + + +class DiffusersPipelineResolver: + """Resolver for `DiffusionPipeline` request and responses from [HuggingFace Diffusers](https://huggingface.co/docs/diffusers/index), providing necessary data transformations, formatting, and logging. + + This is based off `wandb.sdk.integration_utils.auto_logging.RequestResponseResolver`. 
def __init__(self) -> None:
    """Start with no table and the pipeline call counter at 1."""
    self.wandb_table = None
    self.pipeline_call_count = 1


def __call__(
    self,
    args: Sequence[Any],
    kwargs: Dict[str, Any],
    response: Response,
    start_time: float,
    time_elapsed: float,
) -> Any:
    """Main call method for the `DiffusersPipelineResolver` class.

    Dispatches to a `DiffusersMultiModalPipelineResolver` when the class name
    of the called pipeline is listed in `SUPPORTED_MULTIMODAL_PIPELINES`.

    Args:
        args: (Sequence[Any]) List of arguments; `args[0]` is the pipeline.
        kwargs: (Dict[str, Any]) Dictionary of keyword arguments.
        response: (wandb.sdk.integration_utils.auto_logging.Response) The response from
            the request.
        start_time: (float) Time when request started.
        time_elapsed: (float) Time elapsed for the request.

    Returns:
        Packed data as a dictionary for logging to wandb, None if the pipeline
        is not supported or an exception occurred.
    """
    pipeline_name = args[0].__class__.__name__
    resolver = None
    if pipeline_name in SUPPORTED_MULTIMODAL_PIPELINES:
        resolver = DiffusersMultiModalPipelineResolver(
            pipeline_name, self.pipeline_call_count
        )
    self.pipeline_call_count += 1

    if resolver is None:
        # Previously `None` was called here, which raised an opaque
        # "'NoneType' object is not callable" TypeError.
        import logging  # local import: this module defines no logger of its own

        logging.getLogger(__name__).warning(
            "Pipeline %s is not supported for autologging; skipping this call.",
            pipeline_name,
        )
        return None

    return resolver(args, kwargs, response, start_time, time_elapsed)
+++ b/.venv/lib/python3.13/site-packages/wandb/integration/diffusers/resolvers/multimodal.py @@ -0,0 +1,881 @@ +import logging +from typing import Any, Dict, List, Sequence + +import wandb +from wandb.sdk.integration_utils.auto_logging import Response + +from .utils import ( + chunkify, + decode_sdxl_t2i_latents, + get_updated_kwargs, + postprocess_np_arrays_for_video, + postprocess_pils_to_np, +) + +logger = logging.getLogger(__name__) + + +SUPPORTED_MULTIMODAL_PIPELINES = { + "BlipDiffusionPipeline": { + "table-schema": [ + "Reference-Image", + "Prompt", + "Negative-Prompt", + "Source-Subject-Category", + "Target-Subject-Category", + "Generated-Image", + ], + "kwarg-logging": [ + "reference_image", + "prompt", + "neg_prompt", + "source_subject_category", + "target_subject_category", + ], + "kwarg-actions": [wandb.Image, None, None, None, None], + }, + "BlipDiffusionControlNetPipeline": { + "table-schema": [ + "Reference-Image", + "Control-Image", + "Prompt", + "Negative-Prompt", + "Source-Subject-Category", + "Target-Subject-Category", + "Generated-Image", + ], + "kwarg-logging": [ + "reference_image", + "condtioning_image", + "prompt", + "neg_prompt", + "source_subject_category", + "target_subject_category", + ], + "kwarg-actions": [wandb.Image, wandb.Image, None, None, None, None], + }, + "StableDiffusionControlNetPipeline": { + "table-schema": [ + "Control-Image", + "Prompt", + "Negative-Prompt", + "Generated-Image", + ], + "kwarg-logging": ["image", "prompt", "negative_prompt"], + "kwarg-actions": [wandb.Image, None, None], + }, + "StableDiffusionControlNetImg2ImgPipeline": { + "table-schema": [ + "Source-Image", + "Control-Image", + "Prompt", + "Negative-Prompt", + "Generated-Image", + ], + "kwarg-logging": ["image", "control_image", "prompt", "negative_prompt"], + "kwarg-actions": [wandb.Image, wandb.Image, None, None], + }, + "StableDiffusionControlNetInpaintPipeline": { + "table-schema": [ + "Source-Image", + "Mask-Image", + "Control-Image", + "Prompt", + 
"Negative-Prompt", + "Generated-Image", + ], + "kwarg-logging": [ + "image", + "mask_image", + "control_image", + "prompt", + "negative_prompt", + ], + "kwarg-actions": [wandb.Image, wandb.Image, wandb.Image, None, None], + }, + "CycleDiffusionPipeline": { + "table-schema": [ + "Source-Image", + "Prompt", + "Source-Prompt", + "Generated-Image", + ], + "kwarg-logging": [ + "image", + "prompt", + "source_prompt", + ], + "kwarg-actions": [wandb.Image, None, None], + }, + "StableDiffusionInstructPix2PixPipeline": { + "table-schema": [ + "Source-Image", + "Prompt", + "Negative-Prompt", + "Generated-Image", + ], + "kwarg-logging": [ + "image", + "prompt", + "negative_prompt", + ], + "kwarg-actions": [wandb.Image, None, None], + }, + "PaintByExamplePipeline": { + "table-schema": [ + "Source-Image", + "Example-Image", + "Mask-Prompt", + "Generated-Image", + ], + "kwarg-logging": [ + "image", + "example_image", + "mask_image", + ], + "kwarg-actions": [wandb.Image, wandb.Image, wandb.Image], + }, + "RePaintPipeline": { + "table-schema": [ + "Source-Image", + "Mask-Prompt", + "Generated-Image", + ], + "kwarg-logging": [ + "image", + "mask_image", + ], + "kwarg-actions": [wandb.Image, wandb.Image], + }, + "StableDiffusionPipeline": { + "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"], + "kwarg-logging": ["prompt", "negative_prompt"], + "kwarg-actions": [None, None], + }, + "KandinskyCombinedPipeline": { + "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"], + "kwarg-logging": ["prompt", "negative_prompt"], + "kwarg-actions": [None, None], + }, + "KandinskyV22CombinedPipeline": { + "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"], + "kwarg-logging": ["prompt", "negative_prompt"], + "kwarg-actions": [None, None], + }, + "LatentConsistencyModelPipeline": { + "table-schema": ["Prompt", "Generated-Image"], + "kwarg-logging": ["prompt"], + "kwarg-actions": [None], + }, + "LDMTextToImagePipeline": { + "table-schema": ["Prompt", 
"Generated-Image"], + "kwarg-logging": ["prompt"], + "kwarg-actions": [None], + }, + "StableDiffusionPanoramaPipeline": { + "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"], + "kwarg-logging": ["prompt", "negative_prompt"], + "kwarg-actions": [None, None], + }, + "PixArtAlphaPipeline": { + "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"], + "kwarg-logging": ["prompt", "negative_prompt"], + "kwarg-actions": [None, None], + }, + "StableDiffusionSAGPipeline": { + "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"], + "kwarg-logging": ["prompt", "negative_prompt"], + "kwarg-actions": [None, None], + }, + "SemanticStableDiffusionPipeline": { + "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"], + "kwarg-logging": ["prompt", "negative_prompt"], + "kwarg-actions": [None, None], + }, + "WuerstchenCombinedPipeline": { + "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"], + "kwarg-logging": ["prompt", "negative_prompt"], + "kwarg-actions": [None, None], + }, + "IFPipeline": { + "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"], + "kwarg-logging": ["prompt", "negative_prompt"], + "kwarg-actions": [None, None], + }, + "AltDiffusionPipeline": { + "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"], + "kwarg-logging": ["prompt", "negative_prompt"], + "kwarg-actions": [None, None], + }, + "StableDiffusionAttendAndExcitePipeline": { + "table-schema": ["Prompt", "Negative-Prompt", "Generated-Image"], + "kwarg-logging": ["prompt", "negative_prompt"], + "kwarg-actions": [None, None], + }, + "KandinskyImg2ImgCombinedPipeline": { + "table-schema": [ + "Source-Image", + "Prompt", + "Negative-Prompt", + "Generated-Image", + ], + "kwarg-logging": ["image", "prompt", "negative_prompt"], + "kwarg-actions": [wandb.Image, None, None], + }, + "KandinskyInpaintCombinedPipeline": { + "table-schema": [ + "Source-Image", + "Prompt", + "Negative-Prompt", + "Generated-Image", + ], + "kwarg-logging": 
["image", "prompt", "negative_prompt"], + "kwarg-actions": [wandb.Image, None, None], + }, + "KandinskyV22Img2ImgCombinedPipeline": { + "table-schema": [ + "Source-Image", + "Prompt", + "Negative-Prompt", + "Generated-Image", + ], + "kwarg-logging": ["image", "prompt", "negative_prompt"], + "kwarg-actions": [wandb.Image, None, None], + }, + "KandinskyV22InpaintCombinedPipeline": { + "table-schema": [ + "Source-Image", + "Prompt", + "Negative-Prompt", + "Generated-Image", + ], + "kwarg-logging": ["image", "prompt", "negative_prompt"], + "kwarg-actions": [wandb.Image, None, None], + }, + "AnimateDiffPipeline": { + "table-schema": [ + "Prompt", + "Negative-Prompt", + "Number-of-Frames", + "Generated-Video", + ], + "kwarg-logging": ["prompt", "negative_prompt", "num_frames"], + "kwarg-actions": [None, None, None], + "output-type": "video", + }, + "StableVideoDiffusionPipeline": { + "table-schema": [ + "Input-Image", + "Frames-Per-Second", + "Generated-Video", + ], + "kwarg-logging": ["image", "fps"], + "kwarg-actions": [wandb.Image, None], + "output-type": "video", + }, + "AudioLDMPipeline": { + "table-schema": [ + "Prompt", + "Negative-Prompt", + "Audio-Length-in-Seconds", + "Generated-Audio", + ], + "kwarg-logging": ["prompt", "negative_prompt", "audio_length_in_s"], + "kwarg-actions": [None, None, None], + "output-type": "audio", + }, + "AudioLDM2Pipeline": { + "table-schema": [ + "Prompt", + "Negative-Prompt", + "Audio-Length-in-Seconds", + "Generated-Audio", + ], + "kwarg-logging": ["prompt", "negative_prompt", "audio_length_in_s"], + "kwarg-actions": [None, None, None], + "output-type": "audio", + }, + "MusicLDMPipeline": { + "table-schema": [ + "Prompt", + "Negative-Prompt", + "Audio-Length-in-Seconds", + "Generated-Audio", + ], + "kwarg-logging": ["prompt", "negative_prompt", "audio_length_in_s"], + "kwarg-actions": [None, None, None], + "output-type": "audio", + }, + "StableDiffusionPix2PixZeroPipeline": { + "table-schema": [ + "Prompt", + "Negative-Prompt", + 
"Generated-Image", + ], + "kwarg-logging": ["prompt", "negative_prompt"], + "kwarg-actions": [None, None], + }, + "PNDMPipeline": { + "table-schema": [ + "Batch-Size", + "Number-of-Inference-Steps", + "Generated-Image", + ], + "kwarg-logging": ["batch_size", "num_inference_steps"], + "kwarg-actions": [None, None], + }, + "ShapEPipeline": { + "table-schema": [ + "Prompt", + "Generated-Video", + ], + "kwarg-logging": ["prompt"], + "kwarg-actions": [None], + "output-type": "video", + }, + "StableDiffusionImg2ImgPipeline": { + "table-schema": [ + "Source-Image", + "Prompt", + "Negative-Prompt", + "Generated-Image", + ], + "kwarg-logging": ["image", "prompt", "negative_prompt"], + "kwarg-actions": [wandb.Image, None, None], + }, + "StableDiffusionInpaintPipeline": { + "table-schema": [ + "Source-Image", + "Mask-Image", + "Prompt", + "Negative-Prompt", + "Generated-Image", + ], + "kwarg-logging": ["image", "mask_image", "prompt", "negative_prompt"], + "kwarg-actions": [wandb.Image, wandb.Image, None, None], + }, + "StableDiffusionDepth2ImgPipeline": { + "table-schema": [ + "Source-Image", + "Prompt", + "Negative-Prompt", + "Generated-Image", + ], + "kwarg-logging": ["image", "prompt", "negative_prompt"], + "kwarg-actions": [wandb.Image, None, None], + }, + "StableDiffusionImageVariationPipeline": { + "table-schema": [ + "Source-Image", + "Generated-Image", + ], + "kwarg-logging": [ + "image", + ], + "kwarg-actions": [wandb.Image], + }, + "StableDiffusionPipelineSafe": { + "table-schema": [ + "Prompt", + "Negative-Prompt", + "Generated-Image", + ], + "kwarg-logging": ["prompt", "negative_prompt"], + "kwarg-actions": [None, None], + }, + "StableDiffusionUpscalePipeline": { + "table-schema": [ + "Source-Image", + "Prompt", + "Negative-Prompt", + "Upscaled-Image", + ], + "kwarg-logging": ["image", "prompt", "negative_prompt"], + "kwarg-actions": [wandb.Image, None, None], + }, + "StableDiffusionAdapterPipeline": { + "table-schema": [ + "Source-Image", + "Prompt", + 
"Negative-Prompt", + "Generated-Image", + ], + "kwarg-logging": ["image", "prompt", "negative_prompt"], + "kwarg-actions": [wandb.Image, None, None], + }, + "StableDiffusionGLIGENPipeline": { + "table-schema": [ + "Prompt", + "GLIGEN-Phrases", + "GLIGEN-Boxes", + "GLIGEN-Inpaint-Image", + "Negative-Prompt", + "Generated-Image", + ], + "kwarg-logging": [ + "prompt", + "gligen_phrases", + "gligen_boxes", + "gligen_inpaint_image", + "negative_prompt", + ], + "kwarg-actions": [None, None, None, wandb.Image, None], + }, + "VersatileDiffusionTextToImagePipeline": { + "table-schema": [ + "Prompt", + "Negative-Prompt", + "Generated-Image", + ], + "kwarg-logging": ["prompt", "negative_prompt"], + "kwarg-actions": [None, None], + }, + "VersatileDiffusionImageVariationPipeline": { + "table-schema": [ + "Source-Image", + "Negative-Prompt", + "Generated-Image", + ], + "kwarg-logging": ["image", "negative_prompt"], + "kwarg-actions": [wandb.Image, None], + }, + "VersatileDiffusionDualGuidedPipeline": { + "table-schema": [ + "Source-Image", + "Prompt", + "Negative-Prompt", + "Generated-Image", + ], + "kwarg-logging": ["image", "prompt", "negative_prompt"], + "kwarg-actions": [wandb.Image, None, None], + }, + "LDMPipeline": { + "table-schema": [ + "Batch-Size", + "Number-of-Inference-Steps", + "Generated-Image", + ], + "kwarg-logging": ["batch_size", "num_inference_steps"], + "kwarg-actions": [None, None], + }, + "TextToVideoSDPipeline": { + "table-schema": [ + "Prompt", + "Negative-Prompt", + "Number-of-Frames", + "Generated-Video", + ], + "kwarg-logging": ["prompt", "negative_prompt", "num_frames"], + "output-type": "video", + }, + "TextToVideoZeroPipeline": { + "table-schema": [ + "Prompt", + "Negative-Prompt", + "Number-of-Frames", + "Generated-Video", + ], + "kwarg-logging": ["prompt", "negative_prompt", "video_length"], + }, + "AmusedPipeline": { + "table-schema": [ + "Prompt", + "Guidance Scale", + "Generated-Image", + ], + "kwarg-logging": [ + "prompt", + "guidance_scale", 
class DiffusersMultiModalPipelineResolver:
    """Resolver for request and responses from [HuggingFace Diffusers](https://huggingface.co/docs/diffusers/index) multi-modal Diffusion Pipelines, providing necessary data transformations, formatting, and logging.

    This resolver is internally involved in the
    `__call__` for `wandb.integration.diffusers.pipeline_resolver.DiffusersPipelineResolver`.
    This is based on `wandb.sdk.integration_utils.auto_logging.RequestResponseResolver`.

    Args:
        pipeline_name: (str) The name of the Diffusion Pipeline.
        pipeline_call_count: (int) Index of this pipeline call; used to name
            the logged stage, media panels and result table.
    """

    def __init__(self, pipeline_name: str, pipeline_call_count: int) -> None:
        self.pipeline_name = pipeline_name
        self.pipeline_call_count = pipeline_call_count
        columns = []
        if pipeline_name in SUPPORTED_MULTIMODAL_PIPELINES:
            columns += SUPPORTED_MULTIMODAL_PIPELINES[pipeline_name]["table-schema"]
        else:
            # The original built a `wandb.Error` here without raising it, so
            # the problem was silently ignored. Warn instead of raising to
            # keep construction non-fatal, as before.
            logger.warning("Pipeline not supported for logging: %s", pipeline_name)
        self.wandb_table = wandb.Table(columns=columns)

    @property
    def _pipeline_spec(self) -> Dict[str, Any]:
        """Return the `SUPPORTED_MULTIMODAL_PIPELINES` entry of this pipeline."""
        return SUPPORTED_MULTIMODAL_PIPELINES[self.pipeline_name]

    @staticmethod
    def _value_for_group(chunk: List, idx: int) -> Any:
        """Return the kwarg value belonging to output group `idx`.

        A kwarg passed as a single value is wrapped into a one-element list,
        so that value is shared by every group instead of raising IndexError.
        """
        return chunk[0] if len(chunk) == 1 else chunk[idx]

    def _caption_for(self, loggable_kwarg_chunks: List, idx: int) -> Any:
        """Build the media caption from the prompt kwarg(s), None if there is no prompt."""
        kwarg_names = self._pipeline_spec["kwarg-logging"]
        try:
            prompt = self._value_for_group(
                loggable_kwarg_chunks[kwarg_names.index("prompt")], idx
            )
            if self.pipeline_name in [
                "StableDiffusionXLPipeline",
                "StableDiffusionXLImg2ImgPipeline",
            ]:
                prompt_2 = self._value_for_group(
                    loggable_kwarg_chunks[kwarg_names.index("prompt_2")], idx
                )
                return f"Prompt-1: {prompt}\nPrompt-2: {prompt_2}"
            return prompt
        except ValueError:
            # `list.index` raises ValueError when the pipeline logs no prompt.
            return None

    def _wrap_media(self, media: Any, caption: Any = None) -> Any:
        """Wrap one output in its wandb media type.

        Returns:
            A `(panel_prefix, wandb_media)` tuple, or None when the
            pipeline's "output-type" is neither image, video nor audio.
        """
        output_type = self._pipeline_spec.get("output-type")
        if output_type is None:
            return "Generated-Image", wandb.Image(media, caption=caption)
        if output_type == "video":
            return "Generated-Video", wandb.Video(
                postprocess_pils_to_np(media), fps=4, caption=caption
            )
        if output_type == "audio":
            return "Generated-Audio", wandb.Audio(
                media, sample_rate=16000, caption=caption
            )
        return None

    def __call__(
        self,
        args: Sequence[Any],
        kwargs: Dict[str, Any],
        response: Response,
        start_time: float,
        time_elapsed: float,
    ) -> Any:
        """Main call method for the `DiffusersMultiModalPipelineResolver` class.

        Records the pipeline configuration and call parameters under the
        `workflow` key of `wandb.config`, then builds the loggable dictionary.

        Args:
            args: (Sequence[Any]) List of arguments; `args[0]` is the pipeline.
            kwargs: (Dict[str, Any]) Dictionary of keyword arguments.
            response: (wandb.sdk.integration_utils.auto_logging.Response) The response from
                the request.
            start_time: (float) Time when request started.
            time_elapsed: (float) Time elapsed for the request.

        Returns:
            Packed data as a dictionary for logging to wandb, None if an exception occurred.
        """
        try:
            # Get the pipeline and the args
            pipeline, args = args[0], args[1:]

            # Update the Kwargs so that they can be logged easily
            kwargs = get_updated_kwargs(pipeline, args, kwargs)

            # Get the pipeline configs
            pipeline_configs = dict(pipeline.config)
            pipeline_configs["pipeline-name"] = self.pipeline_name

            workflow_entry = {
                "pipeline": pipeline_configs,
                "params": kwargs,
                "stage": f"Pipeline-Call-{self.pipeline_call_count}",
            }
            if "workflow" not in wandb.config:
                wandb.config.update({"workflow": [workflow_entry]})
            else:
                # Extending an existing key requires allow_val_change.
                wandb.config.update(
                    {"workflow": wandb.config.workflow + [workflow_entry]},
                    allow_val_change=True,
                )

            # Return the WandB loggable dict
            return self.prepare_loggable_dict(pipeline, response, kwargs)
        except Exception as e:
            # Logging is best-effort: never let it break the user's pipeline call.
            logger.warning(e)
            return None

    def get_output_images(self, response: Response) -> List:
        """Unpack the generated images, audio, video, etc. from the Diffusion Pipeline's response.

        Args:
            response: (wandb.sdk.integration_utils.auto_logging.Response) The response from
                the request.

        Returns:
            List of generated images, audio, video, etc. None when the
            pipeline declares an "output-type" other than video or audio.
        """
        output_type = self._pipeline_spec.get("output-type")
        if output_type is None:
            return response.images
        if output_type == "video":
            # ShapE is a video pipeline that is read from `images`.
            if self.pipeline_name in ["ShapEPipeline"]:
                return response.images
            return response.frames
        if output_type == "audio":
            return response.audios
        return None

    def log_media(self, image: Any, loggable_kwarg_chunks: List, idx: int) -> None:
        """Log the generated images, audio, video, etc. from the Diffusion Pipeline's response along with an optional caption to a media panel in the run.

        Args:
            image: (Any) The generated images, audio, video, etc. from the Diffusion
                Pipeline's response.
            loggable_kwarg_chunks: (List) Loggable chunks of kwargs.
            idx: (int) Chunk index.
        """
        if self._pipeline_spec.get("output-type") not in (None, "video", "audio"):
            return
        caption = self._caption_for(loggable_kwarg_chunks, idx)
        panel_prefix, media = self._wrap_media(image, caption)
        wandb.log({f"{panel_prefix}/Pipeline-Call-{self.pipeline_call_count}": media})

    def add_data_to_table(
        self, image: Any, loggable_kwarg_chunks: List, idx: int
    ) -> None:
        """Populate the row of the `wandb.Table`.

        Args:
            image: (Any) The generated images, audio, video, etc. from the Diffusion
                Pipeline's response.
            loggable_kwarg_chunks: (List) Loggable chunks of kwargs.
            idx: (int) Chunk index.
        """
        table_row = []
        kwarg_actions = self._pipeline_spec["kwarg-actions"]
        for column_idx, loggable_kwarg_chunk in enumerate(loggable_kwarg_chunks):
            value = self._value_for_group(loggable_kwarg_chunk, idx)
            action = kwarg_actions[column_idx]
            if action is None:
                table_row.append(value if value is not None else "")
            elif value is None:
                # Optional media kwargs that were not supplied stay empty
                # instead of being passed to the media constructor.
                table_row.append(None)
            else:
                table_row.append(action(value))
        wrapped = self._wrap_media(image)
        if wrapped is not None:
            table_row.append(wrapped[1])
        self.wandb_table.add_data(*table_row)

    def prepare_loggable_dict(
        self, pipeline: Any, response: Response, kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Prepare the loggable dictionary, which is the packed data as a dictionary for logging to wandb.

        Args:
            pipeline: (Any) The Diffusion Pipeline.
            response: (wandb.sdk.integration_utils.auto_logging.Response) The response from
                the request.
            kwargs: (Dict[str, Any]) Dictionary of keyword arguments.

        Returns:
            Packed data as a dictionary for logging to wandb.
        """
        # Unpack the generated images, audio, video, etc. from the Diffusion Pipeline's response.
        images = self.get_output_images(response)
        if (
            self.pipeline_name == "StableDiffusionXLPipeline"
            and kwargs.get("output_type") == "latent"
        ):
            images = decode_sdxl_t2i_latents(pipeline, response.images)

        loggable_kwarg_ids = self._pipeline_spec["kwarg-logging"]

        # Account for exception pipelines for text-to-video
        if self.pipeline_name in ["TextToVideoSDPipeline", "TextToVideoZeroPipeline"]:
            video = postprocess_np_arrays_for_video(
                images, normalize=self.pipeline_name == "TextToVideoZeroPipeline"
            )
            wandb.log(
                {
                    f"Generated-Video/Pipeline-Call-{self.pipeline_call_count}": wandb.Video(
                        video, fps=4, caption=kwargs["prompt"]
                    )
                }
            )
            table_row = [kwargs[kwarg_id] for kwarg_id in loggable_kwarg_ids]
            table_row.append(wandb.Video(video, fps=4))
            self.wandb_table.add_data(*table_row)
        else:
            # chunkify loggable kwargs
            loggable_kwarg_chunks = [
                kwargs[kwarg_id]
                if isinstance(kwargs[kwarg_id], list)
                else [kwargs[kwarg_id]]
                for kwarg_id in loggable_kwarg_ids
            ]
            # One group of outputs per entry of the first logged kwarg.
            # `chunkify` takes a chunk *size*, so derive it from the number of
            # groups; passing the group count itself dropped outputs or raised
            # IndexError whenever it differed from the outputs per group.
            # NOTE(review): assumes outputs are ordered group-by-group
            # (all outputs of prompt 0, then prompt 1, ...) - confirm.
            num_groups = len(loggable_kwarg_chunks[0])
            if num_groups:
                group_size = max(1, (len(images) + num_groups - 1) // num_groups)
                grouped_media = chunkify(images, group_size)
                for idx, media_group in enumerate(grouped_media[:num_groups]):
                    for image in media_group:
                        # Log media to media panel
                        self.log_media(image, loggable_kwarg_chunks, idx)
                        # Populate the row of the wandb_table
                        self.add_data_to_table(image, loggable_kwarg_chunks, idx)
        return {
            f"Result-Table/Pipeline-Call-{self.pipeline_call_count}": self.wandb_table
        }
def chunkify(input_list, chunk_size) -> List:
    """Split `input_list` into consecutive slices of at most `chunk_size` items.

    A `chunk_size` below 1 is treated as 1.
    """
    chunk_size = max(1, chunk_size)
    return [
        input_list[i : i + chunk_size] for i in range(0, len(input_list), chunk_size)
    ]


def _describe_generator(generator: Any) -> Dict[str, Any]:
    """Summarise a random generator as a plain dict (seed, device, state)."""
    return {
        "seed": generator.initial_seed(),
        "device": generator.device,
        "random_state": generator.get_state().cpu().numpy().tolist(),
    }


def get_updated_kwargs(
    pipeline: Any, args: Sequence[Any], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
    """Merge positional arguments and signature defaults into `kwargs`.

    `kwargs` is updated in place and also returned. A `generator` entry is
    replaced by a loggable description, and a non-None `ip_adapter_image` is
    logged to wandb as a side effect.

    Args:
        pipeline: (Any) The pipeline whose `__call__` signature is inspected.
        args: (Sequence[Any]) Positional arguments of the call, without the pipeline.
        kwargs: (Dict[str, Any]) Keyword arguments of the call.

    Returns:
        The updated `kwargs` dictionary.
    """
    pipeline_call_parameters = list(
        inspect.signature(pipeline.__call__).parameters.items()
    )
    # zip stops at the shorter sequence, so surplus positionals cannot raise
    # IndexError the way indexing the parameter list did.
    for (name, _), arg in zip(pipeline_call_parameters, args):
        kwargs[name] = arg
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    for name, parameter in pipeline_call_parameters:
        # `*args` / `**kwargs` have no default; recording one would store the
        # `inspect.Parameter.empty` sentinel under their name.
        if name not in kwargs and parameter.kind not in variadic_kinds:
            kwargs[name] = parameter.default
    if "generator" in kwargs:
        generator = kwargs["generator"]
        if generator is None:
            kwargs["generator"] = None
        elif isinstance(generator, (list, tuple)):
            # One generator per sample may be supplied as a sequence.
            kwargs["generator"] = [_describe_generator(g) for g in generator]
        else:
            kwargs["generator"] = _describe_generator(generator)
    if kwargs.get("ip_adapter_image") is not None:
        wandb.log({"IP-Adapter-Image": wandb.Image(kwargs["ip_adapter_image"])})
    return kwargs


def postprocess_pils_to_np(image: List) -> "np_array":
    """Stack a list of images into one uint8 array with the channel axis moved first.

    Each item is converted with `np.array`, cast to uint8 and transposed with
    axes (2, 0, 1); the results are stacked along a new leading axis.
    """
    np = get_module(
        "numpy",
        required="Please ensure NumPy is installed. You can run `pip install numpy` to install it.",
    )
    return np.stack(
        [np.transpose(np.array(img).astype("uint8"), axes=(2, 0, 1)) for img in image],
        axis=0,
    )


def postprocess_np_arrays_for_video(
    images: List["np_array"], normalize: Optional[bool] = False
) -> "np_array":
    """Stack frame arrays and move the last axis to position 1.

    Args:
        images: Frame arrays to stack along a new leading axis.
        normalize: When true, each frame is multiplied by 255 and cast to
            uint8 before stacking.

    Returns:
        The stacked array transposed with axes (0, 3, 1, 2).
    """
    np = get_module(
        "numpy",
        required="Please ensure NumPy is installed. You can run `pip install numpy` to install it.",
    )
    images = [(img * 255).astype("uint8") for img in images] if normalize else images
    return np.transpose(np.stack((images), axis=0), axes=(0, 3, 1, 2))


def decode_sdxl_t2i_latents(pipeline: Any, latents: "torch_float_tensor") -> List:
    """Decode latents generated by [`diffusers.StableDiffusionXLPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#stable-diffusion-xl).

    Args:
        pipeline: (diffusers.DiffusionPipeline) The Diffusion Pipeline from
            [`diffusers`](https://huggingface.co/docs/diffusers).
        latents (torch.FloatTensor): The generated latents.

    Returns:
        List of `PIL` images corresponding to the generated latents.
    """
    torch = get_module(
        "torch",
        required="Please ensure PyTorch is installed. You can check out https://pytorch.org/get-started/locally/#start-locally for installation instructions.",
    )
    with torch.no_grad():
        # A float16 VAE flagged with `force_upcast` is upcast for decoding
        # and cast back to float16 afterwards.
        needs_upcasting = (
            pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast
        )
        if needs_upcasting:
            pipeline.upcast_vae()
            latents = latents.to(
                next(iter(pipeline.vae.post_quant_conv.parameters())).dtype
            )
        images = pipeline.vae.decode(
            latents / pipeline.vae.config.scaling_factor, return_dict=False
        )[0]
        if needs_upcasting:
            pipeline.vae.to(dtype=torch.float16)
        if pipeline.watermark is not None:
            images = pipeline.watermark.apply_watermark(images)
        images = pipeline.image_processor.postprocess(images, output_type="pil")
        pipeline.maybe_free_model_hooks()
        return images
package.""" + +from .dspy import WandbDSPyCallback + +__all__ = ["WandbDSPyCallback"] diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/dspy/dspy.py b/.venv/lib/python3.13/site-packages/wandb/integration/dspy/dspy.py new file mode 100644 index 0000000000000000000000000000000000000000..3bb06b23f4198bd80f6d3775d6bcef5f153dbf85 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/dspy/dspy.py @@ -0,0 +1,423 @@ +"""DSPy ↔ Weights & Biases integration.""" + +from __future__ import annotations + +import logging +import os +from collections.abc import Mapping, Sequence +from typing import Any, Literal + +import wandb +import wandb.util +from wandb.sdk.lib import telemetry +from wandb.sdk.wandb_run import Run + +dspy = wandb.util.get_module( + name="dspy", + required=( + "To use the W&B DSPy integration you need to have the `dspy` " + "python package installed. Install it with `uv pip install dspy`." + ), + lazy=False, +) +if dspy is not None: + assert dspy.__version__ >= "3.0.0", ( + "DSPy 3.0.0 or higher is required. You have " + dspy.__version__ + ) + + +logger = logging.getLogger(__name__) + + +def _flatten_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Flatten a list of nested row dicts into flat key/value dicts. + + Args: + rows (list[dict[str, Any]]): List of nested dictionaries to flatten. + + Returns: + list[dict[str, Any]]: List of flattened dictionaries. + + """ + + def _flatten( + d: dict[str, Any], parent_key: str = "", sep: str = "." + ) -> dict[str, Any]: + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(_flatten(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + return [_flatten(row) for row in rows] + + +class WandbDSPyCallback(dspy.utils.BaseCallback): + """W&B callback for tracking DSPy evaluation and optimization. 
+ + This callback logs evaluation scores, per-step predictions (optional), and + a table capturing the DSPy program signature over time. It can also save + the best program as a W&B Artifact for reproducibility. + + Examples: + Basic usage within DSPy settings: + + ```python + import dspy + import wandb + from wandb.integration.dspy import WandbDSPyCallback + + with wandb.init(project="dspy-optimization") as run: + dspy.settings.callbacks.append(WandbDSPyCallback(run=run)) + # Run your DSPy optimization/evaluation + ``` + """ + + def __init__(self, log_results: bool = True, run: Run | None = None) -> None: + """Initialize the callback. + + Args: + log_results (bool): Whether to log per-evaluation prediction tables. + run (Run | None): Optional W&B run to use. Defaults to the + current global run if available. + + Raises: + wandb.Error: If no active run is provided or found. + """ + # If no run is provided, use the current global run if available. + if run is None: + if wandb.run is None: + raise wandb.Error( + "You must call `wandb.init()` before instantiating WandbDSPyCallback()." + ) + run = wandb.run + + self.log_results = log_results + + with telemetry.context(run=run) as tel: + tel.feature.dspy_callback = True + + self._run = run + self._did_log_config: bool = False + self._program_info: dict[str, Any] = {} + self._program_table: wandb.Table | None = None + self._row_idx: int = 0 + + def _flatten_dict( + self, nested: Any, parent_key: str = "", sep: str = "." + ) -> dict[str, Any]: + """Recursively flatten arbitrarily nested mappings and sequences. + + Args: + nested (Any): Nested structure of mappings/lists to flatten. + parent_key (str): Prefix to prepend to keys in the flattened output. + sep (str): Key separator for nested fields. + + Returns: + dict[str, Any]: Flattened dictionary representation. 
+ """ + flat: dict[str, Any] = {} + + def _walk(obj: Any, base: str) -> None: + if isinstance(obj, Mapping): + for k, v in obj.items(): + new_key = f"{base}{sep}{k}" if base else str(k) + _walk(v, new_key) + elif isinstance(obj, Sequence) and not isinstance( + obj, (str, bytes, bytearray) + ): + for idx, v in enumerate(obj): + new_key = f"{base}{sep}{idx}" if base else str(idx) + _walk(v, new_key) + else: + # Base can be empty only if the top-level is a scalar; guard against that. + key = base if base else "" + if key: + flat[key] = obj + + _walk(nested, parent_key) + return flat + + def _extract_fields(self, fields: list[dict[str, Any]]) -> dict[str, str]: + """Convert signature fields to a flat mapping of strings. + + Note: + The input is expected to be a dict-like mapping from field names to + field metadata. Values are stringified for logging. + + Args: + fields (list[dict[str, Any]]): Mapping of field name to metadata. + + Returns: + dict[str, str]: Mapping of field name to string value. + """ + return {k: str(v) for k, v in fields.items()} + + def _extract_program_info(self, program_obj: Any) -> dict[str, Any]: + """Extract signature-related info from a DSPy program. + + Attempts to read the program signature, instructions, input and output + fields from a DSPy `Predict` parameter if available. + + Args: + program_obj (Any): DSPy program/module instance. + + Returns: + dict[str, Any]: Flattened dictionary of signature metadata. 
+ """ + info_dict = {} + + if program_obj is None: + return info_dict + + try: + sig = next( + param.signature + for _, param in program_obj.named_parameters() + if isinstance(param, dspy.Predict) + ) + + if getattr(sig, "signature", None): + info_dict["signature"] = sig.signature + if getattr(sig, "instructions", None): + info_dict["instructions"] = sig.instructions + if getattr(sig, "input_fields", None): + input_fields = sig.input_fields + info_dict["input_fields"] = self._extract_fields(input_fields) + if getattr(sig, "output_fields", None): + output_fields = sig.output_fields + info_dict["output_fields"] = self._extract_fields(output_fields) + + return self._flatten_dict(info_dict) + except Exception as e: + logger.warning( + "Failed to extract program info from Evaluate instance: %s", e + ) + return info_dict + + def on_evaluate_start( + self, + call_id: str, + instance: Any, + inputs: dict[str, Any], + ) -> None: + """Handle start of a DSPy evaluation call. + + Logs non-private fields from the evaluator instance to W&B config and + captures program signature info for later logging. + + Args: + call_id (str): Unique identifier for the evaluation call. + instance (Any): The evaluation instance (e.g., `dspy.Evaluate`). + inputs (dict[str, Any]): Inputs passed to the evaluation (may + include a `program` key with the DSPy program). 
+ """ + if not self._did_log_config: + instance_vars = vars(instance) if hasattr(instance, "__dict__") else {} + serializable = { + k: v for k, v in instance_vars.items() if not k.startswith("_") + } + if "devset" in serializable: + # we don't want to log the devset in the config + del serializable["devset"] + + self._run.config.update(serializable) + self._did_log_config = True + + # 2) Build/append program signature tables from the 'program' inputs + if program_obj := inputs.get("program"): + self._program_info = self._extract_program_info(program_obj) + + def on_evaluate_end( + self, + call_id: str, + outputs: Any | None, + exception: Exception | None = None, + ) -> None: + """Handle end of a DSPy evaluation call. + + If available, logs a numeric `score` metric and (optionally) per-step + prediction tables. Always appends a row to the program-signature table. + + Args: + call_id (str): Unique identifier for the evaluation call. + outputs (Any | None): Evaluation outputs; supports + `dspy.evaluate.evaluate.EvaluationResult`. + exception (Exception | None): Exception raised during evaluation, if any. + """ + # The `BaseCallback` does not define the interface for the `outputs` parameter, + # Currently, we know of `EvaluationResult` which is a subclass of `dspy.Prediction`. + # We currently support this type and will warn the user if a different type is passed. + score: float | None = None + if exception is None: + if isinstance(outputs, dspy.evaluate.evaluate.EvaluationResult): + # log the float score as a wandb metric + score = outputs.score + wandb.log({"score": float(score)}, step=self._row_idx) + + # Log the predictions as a separate table for each eval end. 
+ # We know that results if of type `list[tuple["dspy.Example", "dspy.Example", Any]]` + results = outputs.results + if self.log_results: + rows = self._parse_results(results) + if rows: + self._log_predictions_table(rows) + else: + wandb.termwarn( + f"on_evaluate_end received unexpected outputs type: {type(outputs)}. " + "Expected dspy.evaluate.evaluate.EvaluationResult; skipping logging score and `log_results`." + ) + else: + wandb.termwarn( + f"on_evaluate_end received exception: {exception}. " + "Skipping logging score and `log_results`." + ) + + # Log the program signature iteratively + if self._program_table is None: + columns = ["step", *self._program_info.keys()] + if isinstance(score, float): + columns.append("score") + self._program_table = wandb.Table(columns=columns, log_mode="INCREMENTAL") + + if self._program_table is not None: + values = list(self._program_info.values()) + if isinstance(score, float): + values.append(score) + + self._program_table.add_data( + self._row_idx, + *values, + ) + self._run.log( + {"program_signature": self._program_table}, step=self._row_idx + ) + + self._row_idx += 1 + + def _parse_results( + self, + results: list[tuple[dspy.Example, dspy.Prediction | dspy.Completions, bool]], + ) -> list[dict[str, Any]]: + """Normalize evaluation results into serializable row dicts. + + Args: + results (list[tuple]): Sequence of `(example, prediction, is_correct)` + tuples from DSPy evaluation. + + Returns: + list[dict[str, Any]]: Rows with `example`, `prediction`, `is_correct`. 
+ """ + _rows: list[dict[str, Any]] = [] + for example, prediction, is_correct in results: + if isinstance(prediction, dspy.Prediction): + prediction_dict = prediction.toDict() + if isinstance(prediction, dspy.Completions): + prediction_dict = prediction.items() + + row: dict[str, Any] = { + "example": example.toDict(), + "prediction": prediction_dict, + "is_correct": is_correct, + } + _rows.append(row) + + return _rows + + def _log_predictions_table(self, rows: list[dict[str, Any]]) -> None: + """Log a W&B Table of predictions for the current evaluation step. + + Args: + rows (list[dict[str, Any]]): Prediction rows to log. + """ + rows = _flatten_rows(rows) + columns = list(rows[0].keys()) + + data: list[list[Any]] = [list(row.values()) for row in rows] + + preds_table = wandb.Table(columns=columns, data=data, log_mode="IMMUTABLE") + self._run.log({f"predictions_{self._row_idx}": preds_table}, step=self._row_idx) + + def log_best_model( + self, + model: dspy.Module, + *, + save_program: bool = True, + save_dir: str | None = None, + filetype: Literal["json", "pkl"] = "json", + aliases: Sequence[str] = ("best", "latest"), + artifact_name: str = "dspy-program", + ) -> None: + """Save and log the best DSPy program as a W&B Artifact. + + You can choose to save the full program (architecture + state) or only + the state to a single file (JSON or pickle). + + Args: + model (dspy.Module): DSPy module to save. + save_program (bool): Save full program directory if True; otherwise + save only the state file. Defaults to `True`. + save_dir (str): Directory to store program files before logging. Defaults to a + subdirectory `dspy_program` within the active run's files directory + (i.e., `wandb.run.dir`). + filetype (Literal["json", "pkl"]): State file format when + `save_program` is False. Defaults to `json`. + aliases (Sequence[str]): Aliases for the logged Artifact version. Defaults to `("best", "latest")`. + artifact_name (str): Base name for the Artifact. 
Defaults to `dspy-program`. + + Examples: + Save the complete program and add aliases: + + ```python + callback.log_best_model( + optimized_program, save_program=True, aliases=("best", "production") + ) + ``` + + Save only the state as JSON: + + ```python + callback.log_best_model( + optimized_program, save_program=False, filetype="json" + ) + ``` + """ + # Derive metadata to help discoverability in the UI + info_dict = self._extract_program_info(model) + metadata = { + "dspy_version": getattr(dspy, "__version__", "unknown"), + "module_class": model.__class__.__name__, + **info_dict, + } + artifact = wandb.Artifact( + name=f"{artifact_name}-{self._run.id}", + type="model", + metadata=metadata, + ) + + # Resolve and normalize the save directory in a cross-platform way + if save_dir is None: + save_dir = os.path.join(self._run.dir, "dspy_program") + save_dir = os.path.normpath(save_dir) + + try: + os.makedirs(save_dir, exist_ok=True) + except Exception as exc: + wandb.termwarn( + f"Could not create or access directory '{save_dir}': {exc}. Skipping artifact logging." + ) + return + # Save per requested mode + if save_program: + model.save(save_dir, save_program=True) + artifact.add_dir(save_dir) + else: + filename = f"program.{filetype}" + file_path = os.path.join(save_dir, filename) + model.save(file_path, save_program=False) + artifact.add_file(file_path) + + self._run.log_artifact(artifact, aliases=list(aliases)) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/fastai/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/fastai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d12d71d58183b0ce679c2bbb0e4e917fed791e2c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/fastai/__init__.py @@ -0,0 +1,243 @@ +"""Hooks that add fast.ai v1 Learners to Weights & Biases through a callback. + +Requested logged data can be configured through the callback constructor. 
+ +Examples: + WandbCallback can be used when initializing the Learner:: + + ``` + from wandb.fastai import WandbCallback + [...] + learn = Learner(data, ..., callback_fns=WandbCallback) + learn.fit(epochs) + ``` + + Custom parameters can be given using functools.partial:: + + ``` + from wandb.fastai import WandbCallback + from functools import partial + [...] + learn = Learner(data, ..., callback_fns=partial(WandbCallback, ...)) + learn.fit(epochs) + ``` + + Finally, it is possible to use WandbCallback only when starting + training. In this case it must be instantiated:: + + ``` + learn.fit(..., callbacks=WandbCallback(learn)) + ``` + + or, with custom parameters:: + + ``` + learn.fit(..., callbacks=WandbCallback(learn, ...)) + ``` +""" + +import random +import sys +from pathlib import Path +from typing import Any, Literal, Optional + +import fastai +from fastai.callbacks import TrackerCallback + +import wandb +from wandb.sdk.lib import ipython + +try: + import matplotlib + + if not ipython.in_jupyter(): + matplotlib.use("Agg") # non-interactive backend (avoid tkinter issues) + import matplotlib.pyplot as plt +except ImportError: + wandb.termwarn("matplotlib required if logging sample image predictions") + + +class WandbCallback(TrackerCallback): + """Callback for saving model topology, losses & metrics. + + Optionally logs weights, gradients, sample predictions and best trained model. + + Args: + learn (fastai.basic_train.Learner): the fast.ai learner to hook. + log (str): "gradients", "parameters", "all", or None. Losses & metrics are always logged. + save_model (bool): save model at the end of each epoch. It will also load best model at the end of training. + monitor (str): metric to monitor for saving best model. None uses default TrackerCallback monitor value. + mode (str): "auto", "min" or "max" to compare "monitor" values and define best model. + input_type (str): "images" or None. Used to display sample predictions. 
+ validation_data (list): data used for sample predictions if input_type is set. + predictions (int): number of predictions to make if input_type is set and validation_data is None. + seed (int): initialize random generator for sample predictions if input_type is set and validation_data is None. + """ + + # Record if watch has been called previously (even in another instance) + _watch_called = False + + def __init__( + self, + learn: "fastai.basic_train.Learner", + log: Optional[Literal["gradients", "parameters", "all"]] = "gradients", + save_model: bool = True, + monitor: Optional[str] = None, + mode: Literal["auto", "min", "max"] = "auto", + input_type: Optional[Literal["images"]] = None, + validation_data: Optional[list] = None, + predictions: int = 36, + seed: int = 12345, + ) -> None: + # Check if wandb.init has been called + if wandb.run is None: + raise ValueError("You must call wandb.init() before WandbCallback()") + + # Adapted from fast.ai "SaveModelCallback" + if monitor is None: + # use default TrackerCallback monitor value + super().__init__(learn, mode=mode) + else: + super().__init__(learn, monitor=monitor, mode=mode) + self.save_model = save_model + self.model_path = Path(wandb.run.dir) / "bestmodel.pth" + + self.log = log + self.input_type = input_type + self.best = None + + # Select items for sample predictions to see evolution along training + self.validation_data = validation_data + if input_type and not self.validation_data: + wandb_random = random.Random(seed) # For repeatability + predictions = min(predictions, len(learn.data.valid_ds)) + indices = wandb_random.sample(range(len(learn.data.valid_ds)), predictions) + self.validation_data = [learn.data.valid_ds[i] for i in indices] + + def on_train_begin(self, **kwargs: Any) -> None: + """Call watch method to log model topology, gradients & weights.""" + # Set self.best, method inherited from "TrackerCallback" by "SaveModelCallback" + super().on_train_begin() + + # Ensure we don't call "watch" 
multiple times + if not WandbCallback._watch_called: + WandbCallback._watch_called = True + + # Logs model topology and optionally gradients and weights + wandb.watch(self.learn.model, log=self.log) + + def on_epoch_end( + self, epoch: int, smooth_loss: float, last_metrics: list, **kwargs: Any + ) -> None: + """Log training loss, validation loss and custom metrics & log prediction samples & save model.""" + if self.save_model: + # Adapted from fast.ai "SaveModelCallback" + current = self.get_monitor_value() + if current is not None and self.operator(current, self.best): + wandb.termlog( + f"Better model found at epoch {epoch} with {self.monitor} value: {current}." + ) + self.best = current + + # Save within wandb folder + with self.model_path.open("wb") as model_file: + self.learn.save(model_file) + + # Log sample predictions if learn.predict is available + if self.validation_data: + try: + self._wandb_log_predictions() + except FastaiError as e: + wandb.termwarn(e.message) + self.validation_data = None # prevent from trying again on next loop + except Exception as e: + wandb.termwarn(f"Unable to log prediction samples.\n{e}") + self.validation_data = None # prevent from trying again on next loop + + # Log losses & metrics + # Adapted from fast.ai "CSVLogger" + logs = { + name: stat + for name, stat in list( + zip(self.learn.recorder.names, [epoch, smooth_loss] + last_metrics) + ) + } + wandb.log(logs) + + def on_train_end(self, **kwargs: Any) -> None: + """Load the best model.""" + if self.save_model: + # Adapted from fast.ai "SaveModelCallback" + if self.model_path.is_file(): + with self.model_path.open("rb") as model_file: + self.learn.load(model_file, purge=False) + wandb.termlog(f"Loaded best saved model from {self.model_path}") + + def _wandb_log_predictions(self) -> None: + """Log prediction samples.""" + pred_log = [] + + if self.validation_data is None: + return + + for x, y in self.validation_data: + try: + pred = self.learn.predict(x) + except Exception: 
+ raise FastaiError( + 'Unable to run "predict" method from Learner to log prediction samples.' + ) + + # scalar -> likely to be a category + # tensor of dim 1 -> likely to be multicategory + if not pred[1].shape or pred[1].dim() == 1: + pred_log.append( + wandb.Image( + x.data, + caption=f"Ground Truth: {y}\nPrediction: {pred[0]}", + ) + ) + + # most vision datasets have a "show" function we can use + elif hasattr(x, "show"): + # log input data + pred_log.append(wandb.Image(x.data, caption="Input data", grouping=3)) + + # log label and prediction + for im, capt in ((pred[0], "Prediction"), (y, "Ground Truth")): + # Resize plot to image resolution + # from https://stackoverflow.com/a/13714915 + my_dpi = 100 + fig = plt.figure(frameon=False, dpi=my_dpi) + h, w = x.size + fig.set_size_inches(w / my_dpi, h / my_dpi) + ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0]) + ax.set_axis_off() + fig.add_axes(ax) + + # Superpose label or prediction to input image + x.show(ax=ax, y=im) + pred_log.append(wandb.Image(fig, caption=capt)) + plt.close(fig) + + # likely to be an image + elif hasattr(y, "shape") and ( + (len(y.shape) == 2) or (len(y.shape) == 3 and y.shape[0] in [1, 3, 4]) + ): + pred_log.extend( + [ + wandb.Image(x.data, caption="Input data", grouping=3), + wandb.Image(pred[0].data, caption="Prediction"), + wandb.Image(y.data, caption="Ground Truth"), + ] + ) + + # we just log input data + else: + pred_log.append(wandb.Image(x.data, caption="Input data")) + + wandb.log({"Prediction Samples": pred_log}, commit=False) + + +class FastaiError(wandb.Error): + pass diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/gym/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/gym/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b0624e4def25ab1e51044221f577066c7939108b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/gym/__init__.py @@ -0,0 +1,98 @@ +import re +import sys +from typing import Literal, 
Optional + +import wandb +import wandb.util + +_gym_version_lt_0_26: Optional[bool] = None +_gymnasium_version_lt_1_0_0: Optional[bool] = None + +_required_error_msg = ( + "Couldn't import the gymnasium python package, install with `pip install gymnasium`" +) +GymLib = Literal["gym", "gymnasium"] + + +def monitor(): + """Monitor a gym environment. + + Supports both gym and gymnasium. + """ + gym_lib: Optional[GymLib] = None + + # gym is not maintained anymore, gymnasium is the drop-in replacement - prefer it + if wandb.util.get_module("gymnasium") is not None: + gym_lib = "gymnasium" + elif wandb.util.get_module("gym") is not None: + gym_lib = "gym" + + if gym_lib is None: + raise wandb.Error(_required_error_msg) + + global _gym_version_lt_0_26 + global _gymnasium_version_lt_1_0_0 + + if _gym_version_lt_0_26 is None or _gymnasium_version_lt_1_0_0 is None: + if gym_lib == "gym": + import gym + else: + import gymnasium as gym # type: ignore + + from packaging.version import parse + + gym_lib_version = parse(gym.__version__) + _gym_version_lt_0_26 = gym_lib_version < parse("0.26.0") + _gymnasium_version_lt_1_0_0 = gym_lib_version < parse("1.0.0a1") + + path = "path" # Default path + if gym_lib == "gymnasium" and not _gymnasium_version_lt_1_0_0: + vcr_recorder_attribute = "RecordVideo" + wrappers = wandb.util.get_module( + f"{gym_lib}.wrappers", + required=_required_error_msg, + ) + recorder = getattr(wrappers, vcr_recorder_attribute) + else: + vcr = wandb.util.get_module( + f"{gym_lib}.wrappers.monitoring.video_recorder", + required=_required_error_msg, + ) + # Breaking change in gym 0.26.0 + if _gym_version_lt_0_26: + vcr_recorder_attribute = "ImageEncoder" + recorder = getattr(vcr, vcr_recorder_attribute) + path = "output_path" # Override path for older gym versions + else: + vcr_recorder_attribute = "VideoRecorder" + recorder = getattr(vcr, vcr_recorder_attribute) + + recorder.orig_close = recorder.close + + def close(self): + recorder.orig_close(self) + if not 
self.enabled: + return + if wandb.run: + m = re.match(r".+(video\.\d+).+", getattr(self, path)) + key = m.group(1) if m else "videos" + wandb.log({key: wandb.Video(getattr(self, path))}) + + def del_(self): + self.orig_close() + + if not _gym_version_lt_0_26: + recorder.__del__ = del_ + recorder.close = close + + if gym_lib == "gymnasium" and not _gymnasium_version_lt_1_0_0: + wrapper_name = vcr_recorder_attribute + else: + wrapper_name = f"monitoring.video_recorder.{vcr_recorder_attribute}" + + wandb.patched["gym"].append( + [ + f"{gym_lib}.wrappers.{wrapper_name}", + "close", + ] + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/huggingface/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/huggingface/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..943249ebe22686fcbdfaa6708c33e15ec875bccc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/huggingface/__init__.py @@ -0,0 +1,3 @@ +__all__ = ("autolog",) + +from .huggingface import autolog diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/huggingface/huggingface.py b/.venv/lib/python3.13/site-packages/wandb/integration/huggingface/huggingface.py new file mode 100644 index 0000000000000000000000000000000000000000..d44cf2cbb51d009821e26924ab944a953a1ff866 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/huggingface/huggingface.py @@ -0,0 +1,18 @@ +import logging + +from wandb.sdk.integration_utils.auto_logging import AutologAPI + +from .resolver import HuggingFacePipelineRequestResponseResolver + +logger = logging.getLogger(__name__) + +resolver = HuggingFacePipelineRequestResponseResolver() + +autolog = AutologAPI( + name="transformers", + symbols=("Pipeline.__call__",), + resolver=resolver, + telemetry_feature="hf_pipeline_autolog", +) + +autolog.get_latest_id = resolver.get_latest_id diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/huggingface/resolver.py 
b/.venv/lib/python3.13/site-packages/wandb/integration/huggingface/resolver.py new file mode 100644 index 0000000000000000000000000000000000000000..2acbdabe56e5e86c1d4442c99b7e9232d5621a48 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/huggingface/resolver.py @@ -0,0 +1,213 @@ +import logging +import os +from datetime import datetime +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import pytz + +import wandb +from wandb.sdk.integration_utils.auto_logging import Response +from wandb.sdk.lib.runid import generate_id + +logger = logging.getLogger(__name__) + +SUPPORTED_PIPELINE_TASKS = [ + "text-classification", + "sentiment-analysis", + "question-answering", + "summarization", + "translation", + "text2text-generation", + "text-generation", + # "conversational", +] + +PIPELINES_WITH_TOP_K = [ + "text-classification", + "sentiment-analysis", + "question-answering", +] + + +class HuggingFacePipelineRequestResponseResolver: + """Resolver for HuggingFace's pipeline request and responses, providing necessary data transformations and formatting. + + This is based off (from wandb.sdk.integration_utils.auto_logging import RequestResponseResolver) + """ + + autolog_id = None + + def __call__( + self, + args: Sequence[Any], + kwargs: Dict[str, Any], + response: Response, + start_time: float, + time_elapsed: float, + ) -> Optional[Dict[str, Any]]: + """Main call method for this class. 
+ + :param args: list of arguments + :param kwargs: dictionary of keyword arguments + :param response: the response from the request + :param start_time: time when request started + :param time_elapsed: time elapsed for the request + :returns: packed data as a dictionary for logging to wandb, None if an exception occurred + """ + try: + pipe, input_data = args[:2] + task = pipe.task + + # Translation tasks are in the form of `translation_x_to_y` + if task in SUPPORTED_PIPELINE_TASKS or task.startswith("translation"): + model = self._get_model(pipe) + if model is None: + return None + model_alias = model.name_or_path + timestamp = datetime.now(pytz.utc) + + input_data, response = self._transform_task_specific_data( + task, input_data, response + ) + formatted_data = self._format_data(task, input_data, response, kwargs) + packed_data = self._create_table( + formatted_data, model_alias, timestamp, time_elapsed + ) + table_name = os.environ.get("WANDB_AUTOLOG_TABLE_NAME", f"{task}") + # TODO: Let users decide the name in a way that does not use an environment variable + + return { + table_name: wandb.Table( + columns=packed_data[0], data=packed_data[1:] + ) + } + + logger.warning( + f"The task: `{task}` is not yet supported.\nPlease contact `wandb` to notify us if you would like support for this task" + ) + except Exception as e: + logger.warning(e) + return None + + # TODO: This should have a dependency on PreTrainedModel. i.e. isinstance(PreTrainedModel) + # from transformers.modeling_utils import PreTrainedModel + # We do not want this dependency explicitly in our codebase so we make a very general + # assumption about the structure of the pipeline which may have unintended consequences + def _get_model(self, pipe) -> Optional[Any]: + """Extracts model from the pipeline. 
+ + :param pipe: the HuggingFace pipeline + :returns: Model if available, None otherwise + """ + model = pipe.model + try: + return model.model + except AttributeError: + logger.info( + "Model does not have a `.model` attribute. Assuming `pipe.model` is the correct model." + ) + return model + + @staticmethod + def _transform_task_specific_data( + task: str, input_data: Union[List[Any], Any], response: Union[List[Any], Any] + ) -> Tuple[Union[List[Any], Any], Union[List[Any], Any]]: + """Transform input and response data based on specific tasks. + + :param task: the task name + :param input_data: the input data + :param response: the response data + :returns: tuple of transformed input_data and response + """ + if task == "question-answering": + input_data = input_data if isinstance(input_data, list) else [input_data] + input_data = [data.__dict__ for data in input_data] + elif task == "conversational": + # We only grab the latest input/output pair from the conversation + # Logging the whole conversation renders strangely. + input_data = input_data if isinstance(input_data, list) else [input_data] + input_data = [data.__dict__["past_user_inputs"][-1] for data in input_data] + + response = response if isinstance(response, list) else [response] + response = [data.__dict__["generated_responses"][-1] for data in response] + return input_data, response + + def _format_data( + self, + task: str, + input_data: Union[List[Any], Any], + response: Union[List[Any], Any], + kwargs: Dict[str, Any], + ) -> List[Dict[str, Any]]: + """Formats input data, response, and kwargs into a list of dictionaries. 
def _create_table(
    self,
    formatted_data: List[Dict[str, Any]],
    model_alias: str,
    timestamp: float,
    time_elapsed: float,
) -> List[List[Any]]:
    """Build a header-plus-rows table for one batch of formatted records.

    Every row produced by a single call shares the same freshly generated
    16-character ID, which is also stored on ``self.autolog_id``.

    :param formatted_data: list of dictionaries with "input", "response" and
        "kwargs" keys
    :param model_alias: alias of the model
    :param timestamp: timestamp of the data
    :param time_elapsed: time elapsed from the beginning
    :returns: list of lists; element 0 is the column names and each following
        element is one data row
    """
    columns = [
        "ID",
        "Model Alias",
        "Timestamp",
        "Elapsed Time",
        "Input",
        "Response",
        "Kwargs",
    ]
    batch_id = generate_id(length=16)

    rows = [
        [
            batch_id,
            model_alias,
            timestamp,
            time_elapsed,
            record["input"],
            record["response"],
            record["kwargs"],
        ]
        for record in formatted_data
    ]

    # Stored only after every row was built successfully.
    self.autolog_id = batch_id

    return [columns, *rows]
class WandbMetricsLogger(callbacks.Callback):
    """Logger that sends system metrics to W&B.

    `WandbMetricsLogger` automatically logs the `logs` dictionary that callback methods
    take as argument to wandb.

    This callback automatically logs the following to a W&B run page:
    * system (CPU/GPU/TPU) metrics,
    * train and validation metrics defined in `model.compile`,
    * learning rate (both for a fixed value or a learning rate scheduler)

    Notes:
    If you resume training by passing `initial_epoch` to `model.fit` and you are using a
    learning rate scheduler, make sure to pass `initial_global_step` to
    `WandbMetricsLogger`. The `initial_global_step` is `step_size * initial_step`, where
    `step_size` is number of training steps per epoch. `step_size` can be calculated as
    the product of the cardinality of the training dataset and the batch size.

    Args:
        log_freq: ("epoch", "batch", or int) if "epoch", logs metrics
            at the end of each epoch. If "batch", logs metrics at the end
            of each batch. If an integer, logs metrics at the end of that
            many batches. Defaults to "epoch".
        initial_global_step: (int) Use this argument to correctly log the
            learning rate when you resume training from some `initial_epoch`,
            and a learning rate scheduler is used. This can be computed as
            `step_size * initial_step`. Defaults to 0.

    Raises:
        wandb.Error: if no run is active (`wandb.init()` was not called).
    """

    def __init__(
        self,
        log_freq: Union[LogStrategy, int] = "epoch",
        initial_global_step: int = 0,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)

        if wandb.run is None:
            raise wandb.Error(
                "You must call `wandb.init()` before WandbMetricsLogger()"
            )

        with telemetry.context(run=wandb.run) as tel:
            tel.feature.keras_metrics_logger = True

        # "batch" is shorthand for "every single batch".
        if log_freq == "batch":
            log_freq = 1

        self.logging_batch_wise = isinstance(log_freq, int)
        self.log_freq: Any = log_freq if self.logging_batch_wise else None
        # x-axis value used for batch-wise metrics.
        self.global_batch = 0
        # Step handed to a learning-rate schedule when it has to be evaluated.
        self.global_step = initial_global_step

        if self.logging_batch_wise:
            # define custom x-axis for batch logging.
            wandb.define_metric("batch/batch_step")
            # set all batch metrics to be logged against batch_step.
            wandb.define_metric("batch/*", step_metric="batch/batch_step")
        else:
            # define custom x-axis for epoch-wise logging.
            wandb.define_metric("epoch/epoch")
            # set all epoch-wise metrics to be logged against epoch.
            wandb.define_metric("epoch/*", step_metric="epoch/epoch")

    def _get_lr(self) -> Union[float, None]:
        """Return the optimizer's current learning rate as a float.

        A variable/tensor (or any object with an empty `shape`) is read
        directly; anything else is called with `step=self.global_step`.
        Returns None, after reporting the error once, when that call fails.
        """
        if isinstance(
            self.model.optimizer.learning_rate,
            (tf.Variable, tf.Tensor),
        ) or (
            hasattr(self.model.optimizer.learning_rate, "shape")
            and self.model.optimizer.learning_rate.shape == ()
        ):
            return float(self.model.optimizer.learning_rate.numpy().item())
        try:
            return float(
                self.model.optimizer.learning_rate(step=self.global_step).numpy().item()
            )
        except Exception as e:
            # Best effort: a learning rate that cannot be read must not stop training.
            wandb.termerror(f"Unable to log learning rate: {e}", repeat=False)
            return None

    def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, Any]] = None) -> None:
        """Called at the end of an epoch."""
        logs = {} if logs is None else {f"epoch/{k}": v for k, v in logs.items()}

        logs["epoch/epoch"] = epoch

        lr = self._get_lr()
        if lr is not None:
            logs["epoch/learning_rate"] = lr

        wandb.log(logs)

    def on_batch_end(self, batch: int, logs: Optional[Dict[str, Any]] = None) -> None:
        """An alias for `on_train_batch_end` for backwards compatibility."""
        # The step counter advances on every batch, logged or not.
        self.global_step += 1
        if self.logging_batch_wise and batch % self.log_freq == 0:
            logs = {f"batch/{k}": v for k, v in logs.items()} if logs else {}
            logs["batch/batch_step"] = self.global_batch

            lr = self._get_lr()
            if lr is not None:
                logs["batch/learning_rate"] = lr

            wandb.log(logs)

            self.global_batch += self.log_freq

    def on_train_batch_end(
        self, batch: int, logs: Optional[Dict[str, Any]] = None
    ) -> None:
        """Called at the end of a training batch in `fit` methods."""
        self.on_batch_end(batch, logs if logs else {})
b/.venv/lib/python3.13/site-packages/wandb/integration/keras/callbacks/model_checkpoint.py @@ -0,0 +1,188 @@ +import os +import string +from typing import Any, Dict, List, Literal, Optional, Union + +import tensorflow as tf # type: ignore +from tensorflow.keras import callbacks # type: ignore + +import wandb +from wandb.sdk.lib import telemetry +from wandb.sdk.lib.paths import StrPath + +from ..keras import patch_tf_keras + +Mode = Literal["auto", "min", "max"] +SaveStrategy = Literal["epoch"] + +patch_tf_keras() + + +class WandbModelCheckpoint(callbacks.ModelCheckpoint): + """A checkpoint that periodically saves a Keras model or model weights. + + Saved weights are uploaded to W&B as a `wandb.Artifact`. + + Since this callback is subclassed from `tf.keras.callbacks.ModelCheckpoint`, the + checkpointing logic is taken care of by the parent callback. You can learn more + here: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint + + This callback is to be used in conjunction with training using `model.fit()` to save + a model or weights (in a checkpoint file) at some interval. The model checkpoints + will be logged as W&B Artifacts. You can learn more here: + https://docs.wandb.ai/guides/artifacts + + This callback provides the following features: + - Save the model that has achieved "best performance" based on "monitor". + - Save the model at the end of every epoch regardless of the performance. + - Save the model at the end of epoch or after a fixed number of training batches. + - Save only model weights, or save the whole model. + - Save the model either in SavedModel format or in `.h5` format. + + Args: + filepath: (Union[str, os.PathLike]) path to save the model file. `filepath` + can contain named formatting options, which will be filled by the value + of `epoch` and keys in `logs` (passed in `on_epoch_end`). 
For example: + if `filepath` is `model-{epoch:02d}-{val_loss:.2f}`, then the + model checkpoints will be saved with the epoch number and the + validation loss in the filename. + monitor: (str) The metric name to monitor. Default to "val_loss". + verbose: (int) Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1 + displays messages when the callback takes an action. + save_best_only: (bool) if `save_best_only=True`, it only saves when the model + is considered the "best" and the latest best model according to the + quantity monitored will not be overwritten. If `filepath` doesn't contain + formatting options like `{epoch}` then `filepath` will be overwritten by + each new better model locally. The model logged as an artifact will still be + associated with the correct `monitor`. Artifacts will be uploaded + continuously and versioned separately as a new best model is found. + save_weights_only: (bool) if True, then only the model's weights will be saved. + mode: (Mode) one of {'auto', 'min', 'max'}. For `val_acc`, this should be `max`, + for `val_loss` this should be `min`, etc. + save_freq: (Union[SaveStrategy, int]) `epoch` or integer. When using `'epoch'`, + the callback saves the model after each epoch. When using an integer, the + callback saves the model at end of this many batches. + Note that when monitoring validation metrics such as `val_acc` or `val_loss`, + save_freq must be set to "epoch" as those metrics are only available at the + end of an epoch. + initial_value_threshold: (Optional[float]) Floating point initial "best" value of the metric + to be monitored. 
+ """ + + def __init__( + self, + filepath: StrPath, + monitor: str = "val_loss", + verbose: int = 0, + save_best_only: bool = False, + save_weights_only: bool = False, + mode: Mode = "auto", + save_freq: Union[SaveStrategy, int] = "epoch", + initial_value_threshold: Optional[float] = None, + **kwargs: Any, + ) -> None: + super().__init__( + filepath=filepath, + monitor=monitor, + verbose=verbose, + save_best_only=save_best_only, + save_weights_only=save_weights_only, + mode=mode, + save_freq=save_freq, + initial_value_threshold=initial_value_threshold, + **kwargs, + ) + if wandb.run is None: + raise wandb.Error( + "You must call `wandb.init()` before `WandbModelCheckpoint()`" + ) + with telemetry.context(run=wandb.run) as tel: + tel.feature.keras_model_checkpoint = True + + self.save_weights_only = save_weights_only + + # User-friendly warning when trying to save the best model. + if self.save_best_only: + self._check_filepath() + + self._is_old_tf_keras_version: Optional[bool] = None + + def on_train_batch_end( + self, batch: int, logs: Optional[Dict[str, float]] = None + ) -> None: + if self._should_save_on_batch(batch): + if self.is_old_tf_keras_version: + # Save the model and get filepath + self._save_model(epoch=self._current_epoch, logs=logs) + filepath = self._get_file_path(epoch=self._current_epoch, logs=logs) + else: + # Save the model and get filepath + self._save_model(epoch=self._current_epoch, batch=batch, logs=logs) + filepath = self._get_file_path( + epoch=self._current_epoch, batch=batch, logs=logs + ) + # Log the model as artifact + aliases = ["latest", f"epoch_{self._current_epoch}_batch_{batch}"] + self._log_ckpt_as_artifact(filepath, aliases=aliases) + + def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None) -> None: + super().on_epoch_end(epoch, logs) + # Check if model checkpoint is created at the end of epoch. + if self.save_freq == "epoch": + # Get filepath where the model checkpoint is saved. 
+ if self.is_old_tf_keras_version: + filepath = self._get_file_path(epoch=epoch, logs=logs) + else: + filepath = self._get_file_path(epoch=epoch, batch=None, logs=logs) + # Log the model as artifact + aliases = ["latest", f"epoch_{epoch}"] + self._log_ckpt_as_artifact(filepath, aliases=aliases) + + def _log_ckpt_as_artifact( + self, filepath: str, aliases: Optional[List[str]] = None + ) -> None: + """Log model checkpoint as W&B Artifact.""" + try: + assert wandb.run is not None + model_checkpoint_artifact = wandb.Artifact( + f"run_{wandb.run.id}_model", type="model" + ) + if os.path.isfile(filepath): + model_checkpoint_artifact.add_file(filepath) + elif os.path.isdir(filepath): + model_checkpoint_artifact.add_dir(filepath) + else: + raise FileNotFoundError(f"No such file or directory {filepath}") + wandb.log_artifact(model_checkpoint_artifact, aliases=aliases or []) + except ValueError: + # This error occurs when `save_best_only=True` and the model + # checkpoint is not saved for that epoch/batch. Since TF/Keras + # is giving friendly log, we can avoid clustering the stdout. + pass + + def _check_filepath(self) -> None: + placeholders = [] + for tup in string.Formatter().parse(self.filepath): + if tup[1] is not None: + placeholders.append(tup[1]) + if len(placeholders) == 0: + wandb.termwarn( + "When using `save_best_only`, ensure that the `filepath` argument " + "contains formatting placeholders like `{epoch:02d}` or `{batch:02d}`. 
" + "This ensures correct interpretation of the logged artifacts.", + repeat=False, + ) + + @property + def is_old_tf_keras_version(self) -> Optional[bool]: + if self._is_old_tf_keras_version is None: + from packaging.version import parse + + try: + if parse(tf.keras.__version__) < parse("2.6.0"): + self._is_old_tf_keras_version = True + else: + self._is_old_tf_keras_version = False + except AttributeError: + self._is_old_tf_keras_version = False + + return self._is_old_tf_keras_version diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/keras/callbacks/tables_builder.py b/.venv/lib/python3.13/site-packages/wandb/integration/keras/callbacks/tables_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..bd19bfb314f4c06c703858aed6fe47ca12bbcefc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/keras/callbacks/tables_builder.py @@ -0,0 +1,228 @@ +import abc +from typing import Any, Dict, List, Optional + +from tensorflow.keras.callbacks import Callback # type: ignore + +import wandb +from wandb.sdk.lib import telemetry + + +class WandbEvalCallback(Callback, abc.ABC): + """Abstract base class to build Keras callbacks for model prediction visualization. + + You can build callbacks for visualizing model predictions `on_epoch_end` + that can be passed to `model.fit()` for classification, object detection, + segmentation, etc. tasks. + + To use this, inherit from this base callback class and implement the + `add_ground_truth` and `add_model_prediction` methods. + + The base class will take care of the following: + - Initialize `data_table` for logging the ground truth and + `pred_table` for predictions. + - The data uploaded to `data_table` is used as a reference for the + `pred_table`. This is to reduce the memory footprint. The `data_table_ref` + is a list that can be used to access the referenced data. + Check out the example below to see how it's done. + - Log the tables to W&B as W&B Artifacts. 
+ - Each new `pred_table` is logged as a new version with aliases. + + Example: + ```python + class WandbClfEvalCallback(WandbEvalCallback): + def __init__(self, validation_data, data_table_columns, pred_table_columns): + super().__init__(data_table_columns, pred_table_columns) + + self.x = validation_data[0] + self.y = validation_data[1] + + def add_ground_truth(self): + for idx, (image, label) in enumerate(zip(self.x, self.y)): + self.data_table.add_data(idx, wandb.Image(image), label) + + def add_model_predictions(self, epoch): + preds = self.model.predict(self.x, verbose=0) + preds = tf.argmax(preds, axis=-1) + + data_table_ref = self.data_table_ref + table_idxs = data_table_ref.get_index() + + for idx in table_idxs: + pred = preds[idx] + self.pred_table.add_data( + epoch, + data_table_ref.data[idx][0], + data_table_ref.data[idx][1], + data_table_ref.data[idx][2], + pred, + ) + + + model.fit( + x, + y, + epochs=2, + validation_data=(x, y), + callbacks=[ + WandbClfEvalCallback( + validation_data=(x, y), + data_table_columns=["idx", "image", "label"], + pred_table_columns=["epoch", "idx", "image", "label", "pred"], + ) + ], + ) + ``` + + To have more fine-grained control, you can override the `on_train_begin` and + `on_epoch_end` methods. If you want to log the samples after N batched, you + can implement `on_train_batch_end` method. + """ + + def __init__( + self, + data_table_columns: List[str], + pred_table_columns: List[str], + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + if wandb.run is None: + raise wandb.Error( + "You must call `wandb.init()` first before using this callback." 
+ ) + + with telemetry.context(run=wandb.run) as tel: + tel.feature.keras_wandb_eval_callback = True + + self.data_table_columns = data_table_columns + self.pred_table_columns = pred_table_columns + + def on_train_begin(self, logs: Optional[Dict[str, float]] = None) -> None: + # Initialize the data_table + self.init_data_table(column_names=self.data_table_columns) + # Log the ground truth data + self.add_ground_truth(logs) + # Log the data_table as W&B Artifacts + self.log_data_table() + + def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None) -> None: + # Initialize the pred_table + self.init_pred_table(column_names=self.pred_table_columns) + # Log the model prediction + self.add_model_predictions(epoch, logs) + # Log the pred_table as W&B Artifacts + self.log_pred_table() + + @abc.abstractmethod + def add_ground_truth(self, logs: Optional[Dict[str, float]] = None) -> None: + """Add ground truth data to `data_table`. + + Use this method to write the logic for adding validation/training data to + `data_table` initialized using `init_data_table` method. + + Example: + ```python + for idx, data in enumerate(dataloader): + self.data_table.add_data(idx, data) + ``` + This method is called once `on_train_begin` or equivalent hook. + """ + raise NotImplementedError(f"{self.__class__.__name__}.add_ground_truth") + + @abc.abstractmethod + def add_model_predictions( + self, epoch: int, logs: Optional[Dict[str, float]] = None + ) -> None: + """Add a prediction from a model to `pred_table`. + + Use this method to write the logic for adding model prediction for validation/ + training data to `pred_table` initialized using `init_pred_table` method. + + Example: + ```python + # Assuming the dataloader is not shuffling the samples. 
+ for idx, data in enumerate(dataloader): + preds = model.predict(data) + self.pred_table.add_data( + self.data_table_ref.data[idx][0], + self.data_table_ref.data[idx][1], + preds, + ) + ``` + This method is called `on_epoch_end` or equivalent hook. + """ + raise NotImplementedError(f"{self.__class__.__name__}.add_model_predictions") + + def init_data_table(self, column_names: List[str]) -> None: + """Initialize the W&B Tables for validation data. + + Call this method `on_train_begin` or equivalent hook. This is followed by adding + data to the table row or column wise. + + Args: + column_names: (list) Column names for W&B Tables. + """ + self.data_table = wandb.Table(columns=column_names, allow_mixed_types=True) + + def init_pred_table(self, column_names: List[str]) -> None: + """Initialize the W&B Tables for model evaluation. + + Call this method `on_epoch_end` or equivalent hook. This is followed by adding + data to the table row or column wise. + + Args: + column_names: (list) Column names for W&B Tables. + """ + self.pred_table = wandb.Table(columns=column_names) + + def log_data_table( + self, name: str = "val", type: str = "dataset", table_name: str = "val_data" + ) -> None: + """Log the `data_table` as W&B artifact and call `use_artifact` on it. + + This lets the evaluation table use the reference of already uploaded data + (images, text, scalar, etc.) without re-uploading. + + Args: + name: (str) A human-readable name for this artifact, which is how you can + identify this artifact in the UI or reference it in use_artifact calls. + (default is 'val') + type: (str) The type of the artifact, which is used to organize and + differentiate artifacts. (default is 'dataset') + table_name: (str) The name of the table as will be displayed in the UI. + (default is 'val_data'). + """ + data_artifact = wandb.Artifact(name, type=type) + data_artifact.add(self.data_table, table_name) + + # Calling `use_artifact` uploads the data to W&B. 
+ assert wandb.run is not None + wandb.run.use_artifact(data_artifact) + data_artifact.wait() + + # We get the reference table. + self.data_table_ref = data_artifact.get(table_name) + + def log_pred_table( + self, + type: str = "evaluation", + table_name: str = "eval_data", + aliases: Optional[List[str]] = None, + ) -> None: + """Log the W&B Tables for model evaluation. + + The table will be logged multiple times creating new version. Use this + to compare models at different intervals interactively. + + Args: + type: (str) The type of the artifact, which is used to organize and + differentiate artifacts. (default is 'evaluation') + table_name: (str) The name of the table as will be displayed in the UI. + (default is 'eval_data') + aliases: (List[str]) List of aliases for the prediction table. + """ + assert wandb.run is not None + pred_artifact = wandb.Artifact(f"run_{wandb.run.id}_pred", type=type) + pred_artifact.add(self.pred_table, table_name) + wandb.run.log_artifact(pred_artifact, aliases=aliases or ["latest"]) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/keras/keras.py b/.venv/lib/python3.13/site-packages/wandb/integration/keras/keras.py new file mode 100644 index 0000000000000000000000000000000000000000..204a17509d762b26c5181fda4fafbd49399f5728 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/keras/keras.py @@ -0,0 +1,1087 @@ +"""keras init.""" + +import logging +import operator +import os +import shutil +import sys +from itertools import chain + +import numpy as np +import tensorflow as tf +import tensorflow.keras.backend as K # noqa: N812 + +import wandb +from wandb.proto.wandb_telemetry_pb2 import Deprecated +from wandb.sdk.integration_utils.data_logging import ValidationDataLogger +from wandb.sdk.lib import telemetry +from wandb.sdk.lib.deprecation import warn_and_record_deprecation +from wandb.util import add_import_hook + + +def _check_keras_version(): + from keras import __version__ as keras_version + 
def is_dataset(data):
    """Return True when `data` is an instance of TensorFlow's dataset classes.

    The check is made against `DatasetV2`, plus `DatasetV1` when the
    `dataset_ops` module exposes it. Returns False when the module cannot be
    loaded or has no `DatasetV2`.
    """
    dataset_ops = wandb.util.get_module("tensorflow.python.data.ops.dataset_ops")
    if not dataset_ops or not hasattr(dataset_ops, "DatasetV2"):
        return False

    known_types = [dataset_ops.DatasetV2]
    if hasattr(dataset_ops, "DatasetV1"):
        known_types.append(dataset_ops.DatasetV1)
    return isinstance(data, tuple(known_types))
def patch_tf_keras():  # noqa: C901
    """Wrap Keras' fit entry points so `WandbCallback` sees the validation data.

    The training modules are imported from `keras.engine` for TensorFlow
    versions in [2.6.0, 2.13.0) and from `tensorflow.python.keras.engine`
    otherwise. If they cannot be imported, an error is reported and nothing
    is patched.

    Each wrapper inspects its keyword arguments, hands any validation data to
    every `WandbCallback` found in `callbacks`, then delegates to the original
    function. Patched locations are recorded in `wandb.patched["keras"]`.

    An explicit `callbacks=None` is treated as "no callbacks". Previously it
    raised `TypeError` as soon as validation data was also supplied.
    """
    from packaging.version import parse
    from tensorflow.python.eager import context

    if parse("2.6.0") <= parse(tf.__version__) < parse("2.13.0"):
        keras_engine = "keras.engine"
        try:
            from keras.engine import training
            from keras.engine import training_arrays_v1 as training_arrays
            from keras.engine import training_generator_v1 as training_generator
        except (ImportError, AttributeError):
            wandb.termerror("Unable to patch Tensorflow/Keras")
            logger.exception("exception while trying to patch_tf_keras")
            return
    else:
        keras_engine = "tensorflow.python.keras.engine"

        from tensorflow.python.keras.engine import training

        try:
            from tensorflow.python.keras.engine import (
                training_arrays_v1 as training_arrays,
            )
            from tensorflow.python.keras.engine import (
                training_generator_v1 as training_generator,
            )
        except (ImportError, AttributeError):
            # Fall back to the module names without the `_v1` suffix.
            try:
                from tensorflow.python.keras.engine import (
                    training_arrays,
                    training_generator,
                )
            except (ImportError, AttributeError):
                wandb.termerror("Unable to patch Tensorflow/Keras")
                logger.exception("exception while trying to patch_tf_keras")
                return

    # Tensorflow 2.1
    training_v2_1 = wandb.util.get_module("tensorflow.python.keras.engine.training_v2")
    # Tensorflow 2.2
    training_v2_2 = wandb.util.get_module(f"{keras_engine}.training_v1")

    # `old_v2` stays unbound when neither module exists; `new_v2` is only
    # installed under the same conditions, so it is never called unbound.
    if training_v2_1:
        old_v2 = training_v2_1.Loop.fit
    elif training_v2_2:
        old_v2 = training.Model.fit

    old_arrays = training_arrays.fit_loop
    old_generator = training_generator.fit_generator

    def set_wandb_attrs(cbk, val_data):
        # Only WandbCallback instances are touched; other callbacks are ignored.
        if isinstance(cbk, WandbCallback):
            if is_generator_like(val_data):
                cbk.generator = val_data
            elif is_dataset(val_data):
                if context.executing_eagerly():
                    cbk.generator = iter(val_data)
                else:
                    wandb.termwarn(
                        "Found a validation dataset in graph mode, can't patch Keras."
                    )
            elif isinstance(val_data, tuple) and isinstance(val_data[0], tf.Tensor):
                # Graph mode dataset generator
                def gen():
                    while True:
                        yield K.get_session().run(val_data)

                cbk.generator = gen()
            else:
                cbk.validation_data = val_data

    def new_arrays(*args, **kwargs):
        # `or []` also covers an explicit `callbacks=None`.
        cbks = kwargs.get("callbacks") or []
        val_inputs = kwargs.get("val_inputs")
        val_targets = kwargs.get("val_targets")
        # TODO: these could be generators, why index 0?
        if val_inputs and val_targets:
            for cbk in cbks:
                set_wandb_attrs(cbk, (val_inputs[0], val_targets[0]))
        return old_arrays(*args, **kwargs)

    def new_generator(*args, **kwargs):
        cbks = kwargs.get("callbacks") or []
        val_data = kwargs.get("validation_data")
        if val_data:
            for cbk in cbks:
                set_wandb_attrs(cbk, val_data)
        return old_generator(*args, **kwargs)

    def new_v2(*args, **kwargs):
        cbks = kwargs.get("callbacks") or []
        val_data = kwargs.get("validation_data")
        if val_data:
            for cbk in cbks:
                set_wandb_attrs(cbk, val_data)
        return old_v2(*args, **kwargs)

    # Keep the originals reachable on the patched modules.
    training_arrays.orig_fit_loop = old_arrays
    training_arrays.fit_loop = new_arrays
    training_generator.orig_fit_generator = old_generator
    training_generator.fit_generator = new_generator
    wandb.patched["keras"].append([f"{keras_engine}.training_arrays", "fit_loop"])
    wandb.patched["keras"].append(
        [f"{keras_engine}.training_generator", "fit_generator"]
    )

    if training_v2_1:
        training_v2_1.Loop.fit = new_v2
        wandb.patched["keras"].append(
            ["tensorflow.python.keras.engine.training_v2.Loop", "fit"]
        )
    elif training_v2_2:
        training.Model.fit = new_v2
        wandb.patched["keras"].append([f"{keras_engine}.training.Model", "fit"])
class _GradAccumulatorCallback(tf.keras.callbacks.Callback):
    """Sums, batch by batch, the values found in the model's trainable weights.

    Intended to be paired with the custom optimizer defined in this module,
    whose dense update assigns the gradient to the variable. After each batch
    the weights therefore hold gradients, which are added to a running total
    before the original weights are restored.
    """

    def set_model(self, model):
        """Snapshot the starting weights and allocate one zero buffer per trainable weight."""
        super().set_model(model)
        self.og_weights = model.get_weights()
        self.grads = [
            np.zeros(tuple(weight.shape)) for weight in model.trainable_weights
        ]

    def on_batch_end(self, batch, logs=None):
        """Add the current trainable-weight values to the totals, then reset the model."""
        trainable = self.model.trainable_weights
        for total, weight in zip(self.grads, trainable):
            # In-place add so the arrays stored in `self.grads` are updated.
            total += weight.numpy()
        self.model.set_weights(self.og_weights)

    def get_grads(self):
        """Return copies of the accumulated arrays."""
        return [total.copy() for total in self.grads]
+ save_model: + True - save a model when monitor beats all previous epochs + False - don't save models + save_graph: (boolean) if True save model graph to wandb (default to True). + save_weights_only: (boolean) if True, then only the model's weights will be + saved (`model.save_weights(filepath)`), else the full model + is saved (`model.save(filepath)`). + log_weights: (boolean) if True save histograms of the model's layer's weights. + log_gradients: (boolean) if True log histograms of the training gradients + training_data: (tuple) Same format `(X,y)` as passed to `model.fit`. This is needed + for calculating gradients - this is mandatory if `log_gradients` is `True`. + validation_data: (tuple) Same format `(X,y)` as passed to `model.fit`. A set of data + for wandb to visualize. If this is set, every epoch, wandb will + make a small number of predictions and save the results for later visualization. In case + you are working with image data, please also set `input_type` and `output_type` in order + to log correctly. + generator: (generator) a generator that returns validation data for wandb to visualize. This + generator should return tuples `(X,y)`. Either `validate_data` or generator should + be set for wandb to visualize specific data examples. In case you are working with image data, + please also set `input_type` and `output_type` in order to log correctly. + validation_steps: (int) if `validation_data` is a generator, how many + steps to run the generator for the full validation set. + labels: (list) If you are visualizing your data with wandb this list of labels + will convert numeric output to understandable string if you are building a + multiclass classifier. If you are making a binary classifier you can pass in + a list of two labels ["label for false", "label for true"]. If `validate_data` + and generator are both false, this won't do anything. + predictions: (int) the number of predictions to make for visualization each epoch, max + is 100. 
+ input_type: (string) type of the model input to help visualization. can be one of: + (`image`, `images`, `segmentation_mask`, `auto`). + output_type: (string) type of the model output to help visualization. can be one of: + (`image`, `images`, `segmentation_mask`, `label`). + log_evaluation: (boolean) if True, save a Table containing validation data and the + model's predictions at each epoch. See `validation_indexes`, + `validation_row_processor`, and `output_row_processor` for additional details. + class_colors: ([float, float, float]) if the input or output is a segmentation mask, + an array containing an rgb tuple (range 0-1) for each class. + log_batch_frequency: (integer) if None, callback will log every epoch. + If set to integer, callback will log training metrics every `log_batch_frequency` + batches. + log_best_prefix: (string) if None, no extra summary metrics will be saved. + If set to a string, the monitored metric and epoch will be prepended with this value + and stored as summary metrics. + validation_indexes: ([wandb.data_types._TableLinkMixin]) an ordered list of index keys to associate + with each validation example. If log_evaluation is True and `validation_indexes` is provided, + then a Table of validation data will not be created and instead each prediction will + be associated with the row represented by the `TableLinkMixin`. The most common way to obtain + such keys is to use `Table.get_index()` which will return a list of row keys. + validation_row_processor: (Callable) a function to apply to the validation data, commonly used to visualize the data. + The function will receive an `ndx` (int) and a `row` (dict). If your model has a single input, + then `row["input"]` will be the input data for the row. Else, it will be keyed based on the name of the + input slot. If your fit function takes a single target, then `row["target"]` will be the target data for the row. Else, + it will be keyed based on the name of the output slots. 
def __init__(
    self,
    monitor="val_loss",
    verbose=0,
    mode="auto",
    save_weights_only=False,
    log_weights=False,
    log_gradients=False,
    save_model=True,
    training_data=None,
    validation_data=None,
    labels=None,
    predictions=36,
    generator=None,
    input_type=None,
    output_type=None,
    log_evaluation=False,
    validation_steps=None,
    class_colors=None,
    log_batch_frequency=None,
    log_best_prefix="best_",
    save_graph=True,
    validation_indexes=None,
    validation_row_processor=None,
    prediction_row_processor=None,
    infer_missing_processors=True,
    log_evaluation_frequency=0,
    compute_flops=False,
    **kwargs,
):
    """Store the callback configuration; requires an active wandb run.

    Only the legacy `data_type` key is read from `**kwargs`; it is used as
    `input_type` when present. Note that the class docstring describes an
    `output_row_processor` argument, while the parameter accepted here is
    named `prediction_row_processor`.

    Raises:
        wandb.Error: if `wandb.run` is None.
        Exception: if `log_gradients` is set and the TensorFlow major
            version is below 2.
        ValueError: if `log_gradients` is set and `training_data` is None,
            or is a list/tuple whose length is not 2.
    """
    if wandb.run is None:
        raise wandb.Error("You must call wandb.init() before WandbCallback()")

    warn_and_record_deprecation(
        feature=Deprecated(keras_callback=True),
        message=(
            "WandbCallback is deprecated and will be removed in a future release. "
            "Please use the WandbMetricsLogger, WandbModelCheckpoint, and WandbEvalCallback "
            "callbacks instead. "
            "See https://docs.wandb.ai/guides/integrations/keras for more information."
        ),
    )

    with telemetry.context(run=wandb.run) as tel:
        tel.feature.keras = True
    self.validation_data = None
    # This is kept around for legacy reasons
    # A generator-like validation_data replaces the `generator` argument.
    if validation_data is not None:
        if is_generator_like(validation_data):
            generator = validation_data
        else:
            self.validation_data = validation_data
    if labels is None:
        labels = []
    self.labels = labels
    # The number of example predictions per epoch is capped at 100.
    self.predictions = min(predictions, 100)

    self.monitor = monitor
    self.verbose = verbose
    self.save_weights_only = save_weights_only
    self.save_graph = save_graph

    # Called regardless of `save_model`.
    wandb.save("model-best.h5")
    self.filepath = os.path.join(wandb.run.dir, "model-best.h5")
    self.save_model = save_model
    if save_model:
        warn_and_record_deprecation(
            feature=Deprecated(keras_callback__save_model=True),
            message=(
                "The save_model argument by default saves the model in the HDF5 format that cannot save "
                "custom objects like subclassed models and custom layers. This behavior will be deprecated "
                "in a future release in favor of the SavedModel format. Meanwhile, the HDF5 model is saved "
                "as W&B files and the SavedModel as W&B Artifacts."
            ),
        )

    # Always True here; there is no constructor argument to turn it off.
    self.save_model_as_artifact = True
    self.log_weights = log_weights
    self.log_gradients = log_gradients
    self.training_data = training_data
    self.generator = generator
    self._graph_rendered = False

    # Legacy alias: a non-None `data_type` overrides `input_type`.
    data_type = kwargs.get("data_type", None)
    if data_type is not None:
        warn_and_record_deprecation(
            feature=Deprecated(keras_callback__data_type=True),
            message=(
                "The data_type argument of wandb.keras.WandbCallback is deprecated "
                "and will be removed in a future release. Please use input_type instead.\n"
                "Setting input_type = data_type."
            ),
        )
        input_type = data_type
    self.input_type = input_type
    self.output_type = output_type
    self.log_evaluation = log_evaluation
    self.validation_steps = validation_steps
    self.class_colors = np.array(class_colors) if class_colors is not None else None
    self.log_batch_frequency = log_batch_frequency
    self.log_best_prefix = log_best_prefix
    self.compute_flops = compute_flops

    self._prediction_batch_size = None

    if self.log_gradients:
        if int(tf.__version__.split(".")[0]) < 2:
            raise Exception("Gradient logging requires tensorflow 2.0 or higher.")
        if self.training_data is None:
            raise ValueError(
                "training_data argument is required for gradient logging."
            )
        # A list/tuple is split into (x, y); anything else is passed as x only.
        if isinstance(self.training_data, (list, tuple)):
            if len(self.training_data) != 2:
                raise ValueError("training data must be a tuple of length two")
            self._training_data_x, self._training_data_y = self.training_data
        else:
            self._training_data_x = (
                self.training_data
            )  # generator, tf.data.Dataset etc
            self._training_data_y = None

    # From Keras
    if mode not in ["auto", "min", "max"]:
        wandb.termwarn(
            f"WandbCallback mode {mode} is unknown, fallback to auto mode."
        )
        mode = "auto"

    # `monitor_op(current, best)` decides whether a value is an improvement;
    # `best` starts at the worst possible value for the chosen direction.
    if mode == "min":
        self.monitor_op = operator.lt
        self.best = float("inf")
    elif mode == "max":
        self.monitor_op = operator.gt
        self.best = float("-inf")
    else:
        # auto: metrics containing "acc" or starting with "fmeasure" are
        # maximized, everything else is minimized.
        if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
            self.monitor_op = operator.gt
            self.best = float("-inf")
        else:
            self.monitor_op = operator.lt
            self.best = float("inf")
    # Get the previous best metric for resumed runs
    previous_best = wandb.run.summary.get(f"{self.log_best_prefix}{self.monitor}")
    if previous_best is not None:
        self.best = previous_best

    self._validation_data_logger = None
    self._validation_indexes = validation_indexes
    self._validation_row_processor = validation_row_processor
    self._prediction_row_processor = prediction_row_processor
    self._infer_missing_processors = infer_missing_processors
    self._log_evaluation_frequency = log_evaluation_frequency
    self._model_trained_since_last_eval = False
def on_epoch_end(self, epoch, logs=None):
    """Log the epoch's metrics and, on improvement, record/save the best model.

    Args:
        epoch: index of the epoch that just finished.
        logs: metric dict for the epoch; treated as empty when None.

    Optional weight histograms, gradient histograms, example images and the
    evaluation table are logged with ``commit=False`` first; the metrics in
    ``logs`` are logged last with ``commit=True``.
    """
    if logs is None:
        logs = {}
    if self.log_weights:
        wandb.log(self._log_weights(), commit=False)

    if self.log_gradients:
        wandb.log(self._log_gradients(), commit=False)

    image_types = ("image", "images", "segmentation_mask")
    if self.input_type in image_types or self.output_type in image_types:
        if self.generator:
            # Pull a fresh batch from the generator for this epoch's examples.
            self.validation_data = next(self.generator)
        if self.validation_data is None:
            wandb.termwarn(
                "No validation_data set, pass a generator to the callback."
            )
        elif self.validation_data and len(self.validation_data) > 0:
            wandb.log(
                {"examples": self._log_images(num_images=self.predictions)},
                commit=False,
            )

    if (
        self._log_evaluation_frequency > 0
        and epoch % self._log_evaluation_frequency == 0
    ):
        self._attempt_evaluation_log(commit=False)

    wandb.log({"epoch": epoch}, commit=False)
    wandb.log(logs, commit=True)

    self.current = logs.get(self.monitor)
    # Compare against None explicitly: a monitored value of exactly 0 (for
    # example a val_loss of 0.0) is a legitimate measurement and must still
    # be eligible as the new best.
    if self.current is not None and self.monitor_op(self.current, self.best):
        if self.log_best_prefix:
            wandb.run.summary[f"{self.log_best_prefix}{self.monitor}"] = (
                self.current
            )
            wandb.run.summary[f"{self.log_best_prefix}epoch"] = epoch
            # When the model is saved, _save_model prints the improvement
            # message itself, so only print here if nothing is being saved.
            if self.verbose and not self.save_model:
                wandb.termlog(
                    f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}"
                )
        if self.save_model:
            self._save_model(epoch)

        if self.save_model and self.save_model_as_artifact:
            self._save_model_as_artifact(epoch)

        self.best = self.current
commit=True) + + def on_test_begin(self, logs=None): + pass + + def on_test_end(self, logs=None): + pass + + def on_test_batch_begin(self, batch, logs=None): + pass + + def on_test_batch_end(self, batch, logs=None): + pass + + def on_train_begin(self, logs=None): + if self.log_evaluation: + try: + validation_data = None + if self.validation_data: + validation_data = self.validation_data + elif self.generator: + if not self.validation_steps: + wandb.termwarn( + "WandbCallback is unable to log validation data. " + "When using a generator for validation_data, you must pass validation_steps" + ) + else: + x = None + y_true = None + for _ in range(self.validation_steps): + bx, by_true = next(self.generator) + if x is None: + x, y_true = bx, by_true + else: + x, y_true = ( + np.append(x, bx, axis=0), + np.append(y_true, by_true, axis=0), + ) + validation_data = (x, y_true) + else: + wandb.termwarn( + "WandbCallback is unable to read validation_data from trainer " + "and therefore cannot log validation data. Ensure Keras is properly " + "patched by calling `from wandb.keras import WandbCallback` at the top of your script." + ) + if validation_data: + self._validation_data_logger = ValidationDataLogger( + inputs=validation_data[0], + targets=validation_data[1], + indexes=self._validation_indexes, + validation_row_processor=self._validation_row_processor, + prediction_row_processor=self._prediction_row_processor, + class_labels=self.labels, + infer_missing_processors=self._infer_missing_processors, + ) + except Exception as e: + wandb.termwarn( + "Error initializing ValidationDataLogger in WandbCallback. " + f"Skipping logging validation data. 
Error: {str(e)}" + ) + + if self.compute_flops and _can_compute_flops(): + try: + wandb.summary["GFLOPs"] = self.get_flops() + except Exception: + logger.exception("Error computing FLOPs") + wandb.termwarn("Unable to compute FLOPs for this model.") + + def on_train_end(self, logs=None): + if self._model_trained_since_last_eval: + self._attempt_evaluation_log() + + def on_predict_begin(self, logs=None): + pass + + def on_predict_end(self, logs=None): + pass + + def on_predict_batch_begin(self, batch, logs=None): + pass + + def on_predict_batch_end(self, batch, logs=None): + pass + + def _logits_to_captions(self, logits): + if logits[0].shape[-1] == 1: + # Scalar output from the model + # TODO: handle validation_y + if len(self.labels) == 2: + # User has named true and false + captions = [ + self.labels[1] if logits[0] > 0.5 else self.labels[0] + for logit in logits + ] + else: + if len(self.labels) != 0: + wandb.termwarn( + "keras model is producing a single output, " + 'so labels should be a length two array: ["False label", "True label"].' 
+ ) + captions = [logit[0] for logit in logits] + else: + # Vector output from the model + # TODO: handle validation_y + labels = np.argmax(np.stack(logits), axis=1) + + if len(self.labels) > 0: + # User has named the categories in self.labels + captions = [] + for label in labels: + try: + captions.append(self.labels[label]) + except IndexError: + captions.append(label) + else: + captions = labels + return captions + + def _masks_to_pixels(self, masks): + # if its a binary mask, just return it as grayscale instead of picking the argmax + if len(masks[0].shape) == 2 or masks[0].shape[-1] == 1: + return masks + class_colors = ( + self.class_colors + if self.class_colors is not None + else np.array(wandb.util.class_colors(masks[0].shape[2])) + ) + imgs = class_colors[np.argmax(masks, axis=-1)] + return imgs + + def _log_images(self, num_images=36): + validation_X = self.validation_data[0] # noqa: N806 + validation_y = self.validation_data[1] + + validation_length = len(validation_X) + + if validation_length > num_images: + # pick some data at random + indices = np.random.choice(validation_length, num_images, replace=False) + else: + indices = range(validation_length) + + test_data = [] + test_output = [] + for i in indices: + test_example = validation_X[i] + test_data.append(test_example) + test_output.append(validation_y[i]) + + if self.model.stateful: + predictions = self.model.predict(np.stack(test_data), batch_size=1) + self.model.reset_states() + else: + predictions = self.model.predict( + np.stack(test_data), batch_size=self._prediction_batch_size + ) + if len(predictions) != len(test_data): + self._prediction_batch_size = 1 + predictions = self.model.predict( + np.stack(test_data), batch_size=self._prediction_batch_size + ) + + if self.input_type == "label": + if self.output_type in ("image", "images", "segmentation_mask"): + captions = self._logits_to_captions(test_data) + output_image_data = ( + self._masks_to_pixels(predictions) + if self.output_type == 
"segmentation_mask" + else predictions + ) + reference_image_data = ( + self._masks_to_pixels(test_output) + if self.output_type == "segmentation_mask" + else test_output + ) + output_images = [ + wandb.Image(data, caption=captions[i], grouping=2) + for i, data in enumerate(output_image_data) + ] + reference_images = [ + wandb.Image(data, caption=captions[i]) + for i, data in enumerate(reference_image_data) + ] + return list(chain.from_iterable(zip(output_images, reference_images))) + elif self.input_type in ("image", "images", "segmentation_mask"): + input_image_data = ( + self._masks_to_pixels(test_data) + if self.input_type == "segmentation_mask" + else test_data + ) + if self.output_type == "label": + # we just use the predicted label as the caption for now + captions = self._logits_to_captions(predictions) + return [ + wandb.Image(data, caption=captions[i]) + for i, data in enumerate(test_data) + ] + elif self.output_type in ("image", "images", "segmentation_mask"): + output_image_data = ( + self._masks_to_pixels(predictions) + if self.output_type == "segmentation_mask" + else predictions + ) + reference_image_data = ( + self._masks_to_pixels(test_output) + if self.output_type == "segmentation_mask" + else test_output + ) + input_images = [ + wandb.Image(data, grouping=3) + for i, data in enumerate(input_image_data) + ] + output_images = [ + wandb.Image(data) for i, data in enumerate(output_image_data) + ] + reference_images = [ + wandb.Image(data) for i, data in enumerate(reference_image_data) + ] + return list( + chain.from_iterable( + zip(input_images, output_images, reference_images) + ) + ) + else: + # unknown output, just log the input images + return [wandb.Image(img) for img in test_data] + elif self.output_type in ("image", "images", "segmentation_mask"): + # unknown input, just log the predicted and reference outputs without captions + output_image_data = ( + self._masks_to_pixels(predictions) + if self.output_type == "segmentation_mask" + else 
def _log_gradients(self):
    """Run the gradient accumulator model and return gradient histograms.

    Returns:
        dict mapping ``"gradients/<weight name>.gradient"`` to a
        ``wandb.Histogram`` of the values collected by the accumulator
        callback, one entry per trainable weight of ``self.model``.
    """
    # Suppress callback warnings grad accumulator
    og_level = tf_logger.level
    tf_logger.setLevel("ERROR")
    try:
        self._grad_accumulator_model.fit(
            self._training_data_x,
            self._training_data_y,
            verbose=0,
            callbacks=[self._grad_accumulator_callback],
        )
    finally:
        # Restore the logger level even when fit() raises, so a failure here
        # does not leave TensorFlow logging silenced for the rest of the run.
        tf_logger.setLevel(og_level)

    weights = self.model.trainable_weights
    grads = self._grad_accumulator_callback.grads
    # weight.name looks like "<name>:<suffix>"; only the part before ":" is used.
    return {
        "gradients/" + weight.name.split(":")[0] + ".gradient": wandb.Histogram(grad)
        for weight, grad in zip(weights, grads)
    }
def _save_model(self, epoch):
    """Write the current model (or only its weights) to ``self.filepath``.

    Does nothing when the wandb run is disabled. Save failures of the listed
    exception types are logged and reported to the terminal, not re-raised.

    Args:
        epoch: epoch index, used only in the verbose message.
    """
    if wandb.run.disabled:
        return
    if self.verbose > 0:
        wandb.termlog(
            f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}, "
            f"saving model to {self.filepath}"
        )

    try:
        if self.save_weights_only:
            self.model.save_weights(self.filepath, overwrite=True)
        else:
            self.model.save(self.filepath, overwrite=True)
    # Was getting `RuntimeError: Unable to create link` in TF 1.13.1
    # also saw `TypeError: can't pickle _thread.RLock objects`
    except (ImportError, RuntimeError, TypeError, AttributeError):
        logger.exception("Error saving model in the h5py format")
        # Message fixed: it previously read "saved as as an W&B Artifact".
        wandb.termerror(
            "Can't save model in the h5py format. The model will be saved "
            "as a W&B Artifact in the 'tf' format."
        )
def get_flops(self) -> float:
    """Profile ``self.model`` for a single sample and return the result in GFLOPs.

    The model is traced with batch size 1, frozen into a constant graph and
    profiled with ``tf.compat.v1.profiler``; the profiler's float-op total is
    divided by 1e9 and then halved.

    Raises:
        wandb.Error: if this callback has no ``model`` attribute.
        TypeError: if the model is not a ``tf.keras`` ``Model``/``Sequential``.
    """
    if not hasattr(self, "model"):
        raise wandb.Error("self.model must be set before using this method.")

    supported_models = (tf.keras.models.Sequential, tf.keras.models.Model)
    if not isinstance(self.model, supported_models):
        raise TypeError(
            "Calculating FLOPS is only supported for "
            "`tf.keras.Model` and `tf.keras.Sequential` instances."
        )

    from tensorflow.python.framework.convert_to_constants import (
        convert_variables_to_constants_v2_as_graph,
    )

    # One TensorSpec per model input, with the batch dimension pinned to 1.
    batch_size = 1
    input_specs = []
    for model_input in self.model.inputs:
        spec_shape = [batch_size] + model_input.shape[1:]
        input_specs.append(tf.TensorSpec(spec_shape, model_input.dtype))

    # Freeze the traced function so the profiler sees inference-time ops only.
    concrete_fn = tf.function(self.model).get_concrete_function(input_specs)
    frozen_fn, _ = convert_variables_to_constants_v2_as_graph(concrete_fn)

    run_metadata = tf.compat.v1.RunMetadata()
    builder_cls = tf.compat.v1.profiler.ProfileOptionBuilder
    float_op_options = builder_cls().float_operation()
    profile_options = builder_cls(float_op_options).with_empty_output().build()

    profile = tf.compat.v1.profiler.profile(
        graph=frozen_fn.graph,
        run_meta=run_metadata,
        cmd="scope",
        options=profile_options,
    )

    # convert to GFLOPs
    gflops = profile.total_float_ops / 1e9
    return gflops / 2
def add_wandb_visualization(run, mlpipeline_ui_metadata_path):
    """Write KFP UI metadata that embeds the given W&B run as inline markdown.

    NOTE: To use this, you must modify your component to have an output called
    `mlpipeline_ui_metadata_path` AND call `wandb.init` yourself inside that
    component.

    Args:
        run: object exposing a ``url`` attribute (the W&B run).
        mlpipeline_ui_metadata_path: file path the metadata JSON is written to.

    Example usage:

    def my_component(..., mlpipeline_ui_metadata_path: OutputPath()):
        import wandb
        from wandb.integration.kfp.helpers import add_wandb_visualization

        with wandb.init() as run:
            add_wandb_visualization(run, mlpipeline_ui_metadata_path)

            ... # the rest of your code here
    """

    def get_iframe_html(run):
        # The previous body returned an empty string, so the metadata embedded
        # nothing; point an iframe at the run page instead.
        return (
            f'<iframe src="{run.url}?kfp=true" '
            'style="border:none;width:100%;height:100%;'
            'min-width:900px;min-height:600px;"></iframe>'
        )

    iframe_html = get_iframe_html(run)
    metadata = {
        "outputs": [{"type": "markdown", "storage": "inline", "source": iframe_html}]
    }

    with open(mlpipeline_ui_metadata_path, "w") as metadata_file:
        json.dump(metadata, metadata_file)
def full_path_exists(full_func):
    """Return True if every step of the dotted path `full_func` resolves.

    For each prefix of the path, the prefix is loaded with
    ``wandb.util.get_module`` and must expose the next component as an
    attribute whose value is not None.
    """
    parts = full_func.split(".")
    for depth in range(1, len(parts)):
        parent_name = ".".join(parts[:depth])
        child_name = parts[depth]
        parent_module = wandb.util.get_module(parent_name)
        if not parent_module:
            return False
        # Missing attribute and attribute-set-to-None both count as absent.
        if getattr(parent_module, child_name, None) is None:
            return False
    return True
+ ) + else: + wandb.patched.setdefault(module.__name__, []) + # if already patched, do not patch again + if [module, func.__name__] not in wandb.patched[module.__name__]: + setattr(module, f"orig_{func.__name__}", getattr(module, func.__name__)) + setattr(module, func.__name__, func) + wandb.patched[module.__name__].append([module, func.__name__]) + success = True + + return success + + +def unpatch(module_name): + if module_name in wandb.patched: + for module, func in wandb.patched[module_name]: + setattr(module, func, getattr(module, f"orig_{func}")) + wandb.patched[module_name] = [] + + +def unpatch_kfp(): + unpatch("kfp.components") + unpatch("kfp.components._python_op") + unpatch("wandb.integration.kfp") + + +def patch_kfp(): + to_patch = [ + ( + "kfp.components", + create_component_from_func, + ), + ( + "kfp.components._python_op", + create_component_from_func, + ), + ( + "kfp.components._python_op", + _get_function_source_definition, + ), + ("kfp.components._python_op", strip_type_hints), + ] + + successes = [] + for module_name, func in to_patch: + success = patch(module_name, func) + successes.append(success) + if not all(successes): + wandb.termerror( + "Failed to patch one or more kfp functions. Patching @wandb_log decorator to no-op." + ) + patch("wandb.integration.kfp", wandb_log) + + +def wandb_log( + func=None, + # /, # py38 only + log_component_file=True, +): + """Wrap a standard python function and log to W&B. + + NOTE: Because patching failed, this decorator is a no-op. + """ + from functools import wraps + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + if func is None: + return decorator + else: + return decorator(func) + + +def _get_function_source_definition(func: Callable) -> str: + """Get the source code of a function. + + This function is modified from KFP. 
The original source is below: + https://github.com/kubeflow/pipelines/blob/b6406b02f45cdb195c7b99e2f6d22bf85b12268b/sdk/python/kfp/components/_python_op.py#L300-L319. + """ + func_code = inspect.getsource(func) + + # Function might be defined in some indented scope (e.g. in another + # function). We need to handle this and properly dedent the function source + # code + func_code = textwrap.dedent(func_code) + func_code_lines = func_code.split("\n") + + # For wandb, allow decorators (so we can use the @wandb_log decorator) + func_code_lines = itertools.dropwhile( + lambda x: not (x.startswith(("def", "@wandb_log"))), + func_code_lines, + ) + + if not func_code_lines: + raise ValueError( + f'Failed to dedent and clean up the source of function "{func.__name__}". ' + "It is probably not properly indented." + ) + + return "\n".join(func_code_lines) + + +def create_component_from_func( + func: Callable, + output_component_file: Optional[str] = None, + base_image: Optional[str] = None, + packages_to_install: Optional[List[str]] = None, + annotations: Optional[Mapping[str, str]] = None, +): + '''Convert a Python function to a component and returns a task factory. + + The returned task factory accepts arguments and returns a task object. + + This function is modified from KFP. The original source is below: + https://github.com/kubeflow/pipelines/blob/b6406b02f45cdb195c7b99e2f6d22bf85b12268b/sdk/python/kfp/components/_python_op.py#L998-L1110. + + Args: + func: The python function to convert + base_image: Optional. Specify a custom Docker container image to use in the component. For lightweight components, the image needs to have python 3.5+. Default is the python image corresponding to the current python environment. + output_component_file: Optional. Write a component definition to a local file. The produced component file can be loaded back by calling :code:`load_component_from_file` or :code:`load_component_from_uri`. + packages_to_install: Optional. 
List of [versioned] python packages to pip install before executing the user function. + annotations: Optional. Allows adding arbitrary key-value data to the component specification. + + Returns: + A factory function with a strongly-typed signature taken from the python function. + Once called with the required arguments, the factory constructs a task instance that can run the original function in a container. + + Examples: + The function name and docstring are used as component name and description. Argument and return annotations are used as component input/output types:: + + def add(a: float, b: float) -> float: + """Return sum of two arguments""" + return a + b + + + # add_op is a task factory function that creates a task object when given arguments + add_op = create_component_from_func( + func=add, + base_image="python:3.7", # Optional + output_component_file="add.component.yaml", # Optional + packages_to_install=["pandas==0.24"], # Optional + ) + + # The component spec can be accessed through the .component_spec attribute: + add_op.component_spec.save("add.component.yaml") + + # The component function can be called with arguments to create a task: + add_task = add_op(1, 3) + + # The resulting task has output references, corresponding to the component outputs. 
+ # When the function only has a single anonymous return value, the output name is "Output": + sum_output_ref = add_task.outputs["Output"] + + # These task output references can be passed to other component functions, constructing a computation graph: + task2 = add_op(sum_output_ref, 5) + + + :code:`create_component_from_func` function can also be used as decorator:: + + @create_component_from_func + def add_op(a: float, b: float) -> float: + """Return sum of two arguments""" + return a + b + + To declare a function with multiple return values, use the :code:`NamedTuple` return annotation syntax:: + + from typing import NamedTuple + + + def add_multiply_two_numbers(a: float, b: float) -> NamedTuple( + "Outputs", [("sum", float), ("product", float)] + ): + """Return sum and product of two arguments""" + return (a + b, a * b) + + + add_multiply_op = create_component_from_func(add_multiply_two_numbers) + + # The component function can be called with arguments to create a task: + add_multiply_task = add_multiply_op(1, 3) + + # The resulting task has output references, corresponding to the component outputs: + sum_output_ref = add_multiply_task.outputs["sum"] + + # These task output references can be passed to other component functions, constructing a computation graph: + task2 = add_multiply_op(sum_output_ref, 5) + + Bigger data should be read from files and written to files. + Use the :py:class:`kfp.components.InputPath` parameter annotation to tell the system that the function wants to consume the corresponding input data as a file. The system will download the data, write it to a local file and then pass the **path** of that file to the function. + Use the :py:class:`kfp.components.OutputPath` parameter annotation to tell the system that the function wants to produce the corresponding output data as a file. The system will prepare and pass the **path** of a file where the function should write the output data. 
After the function exits, the system will upload the data to the storage system so that it can be passed to downstream components. + + You can specify the type of the consumed/produced data by specifying the type argument to :py:class:`kfp.components.InputPath` and :py:class:`kfp.components.OutputPath`. The type can be a python type or an arbitrary type name string. :code:`OutputPath('CatBoostModel')` means that the function states that the data it has written to a file has type :code:`CatBoostModel`. :code:`InputPath('CatBoostModel')` means that the function states that it expect the data it reads from a file to have type 'CatBoostModel'. When the pipeline author connects inputs to outputs the system checks whether the types match. + Every kind of data can be consumed as a file input. Conversely, bigger data should not be consumed by value as all value inputs pass through the command line. + + Example of a component function declaring file input and output:: + + def catboost_train_classifier( + training_data_path: InputPath( + "CSV" + ), # Path to input data file of type "CSV" + trained_model_path: OutputPath( + "CatBoostModel" + ), # Path to output data file of type "CatBoostModel" + number_of_trees: int = 100, # Small output of type "Integer" + ) -> NamedTuple( + "Outputs", + [ + ("Accuracy", float), # Small output of type "Float" + ("Precision", float), # Small output of type "Float" + ("JobUri", "URI"), # Small output of type "URI" + ], + ): + """Train CatBoost classification model""" + ... 
def wandb_log(  # noqa: C901
    func=None,
    # /,  # py38 only
    log_component_file=True,
):
    """Wrap a standard python function and log to W&B.

    Works both as a bare decorator (``@wandb_log``) and when called with
    arguments (``@wandb_log(log_component_file=False)``).

    Every call of the wrapped function runs inside a ``wandb.init`` run.
    Annotated inputs are recorded before the call (plain values in the run
    config, KFP input-type annotations as used artifacts) and outputs after it
    (the return value through ``run.log``, KFP output-type annotations as
    logged artifacts). The wrapper takes one extra parameter,
    ``mlpipeline_ui_metadata_path``, and writes a JSON metadata document to
    that path.

    Args:
        func: The function to wrap, or ``None`` when the decorator is called
            with arguments.
        log_component_file: When true, the component file generated from the
            wrapped function is logged as a ``kubeflow_component_file``
            artifact on every call.

    Returns:
        The wrapped function, or the decorator itself when ``func`` is
        ``None``.
    """
    import json
    import os
    from functools import wraps
    from inspect import Parameter, signature

    from kfp import components
    from kfp.components import (
        InputArtifact,
        InputBinaryFile,
        InputPath,
        InputTextFile,
        OutputArtifact,
        OutputBinaryFile,
        OutputPath,
        OutputTextFile,
    )

    import wandb
    from wandb.sdk.lib import telemetry as wb_telemetry

    output_types = (OutputArtifact, OutputBinaryFile, OutputPath, OutputTextFile)
    input_types = (InputArtifact, InputBinaryFile, InputPath, InputTextFile)

    def isinstance_namedtuple(x):
        # A namedtuple instance is a direct tuple subclass whose `_fields`
        # attribute is a tuple of strings.
        t = type(x)
        b = t.__bases__
        if len(b) != 1 or b[0] is not tuple:
            return False
        f = getattr(t, "_fields", None)
        if not isinstance(f, tuple):
            return False
        return all(isinstance(n, str) for n in f)

    def get_iframe_html(run):
        # NOTE(review): this returns an empty string and ignores `run`; the
        # markup may have been lost from this copy — confirm against upstream.
        return f''

    def get_link_back_to_kubeflow():
        wandb_kubeflow_url = os.getenv("WANDB_KUBEFLOW_URL")
        # The doubled braces emit a literal `{workflow.uid}` placeholder.
        return f"{wandb_kubeflow_url}/#/runs/details/{{workflow.uid}}"

    def log_input_scalar(name, data, run=None):
        run.config[name] = data
        wandb.termlog(f"Setting config: {name} to {data}")

    def log_input_artifact(name, data, type, run=None):
        artifact = wandb.Artifact(name, type=type)
        artifact.add_file(data)
        run.use_artifact(artifact)
        wandb.termlog(f"Using artifact: {name}")

    def log_output_scalar(name, data, func_name, run=None):
        # Namedtuple results are logged field by field, keyed by the wrapped
        # function's name; anything else is logged under `name`.
        if isinstance_namedtuple(data):
            for k, v in zip(data._fields, data):
                run.log({f"{func_name}.{k}": v})
        else:
            run.log({name: data})

    def log_output_artifact(name, data, type, run=None):
        artifact = wandb.Artifact(name, type=type)
        artifact.add_file(data)
        run.log_artifact(artifact)
        wandb.termlog(f"Logging artifact: {name}")

    def _log_component_file(func, run=None):
        name = func.__name__
        output_component_file = f"{name}.yml"
        components._python_op.func_to_component_file(func, output_component_file)
        artifact = wandb.Artifact(name, type="kubeflow_component_file")
        artifact.add_file(output_component_file)
        run.log_artifact(artifact)
        wandb.termlog(f"Logging component file: {output_component_file}")

    def decorator(func):
        # Everything derived from the wrapped function is computed here, not
        # in the enclosing scope: when `wandb_log` is called with arguments
        # the outer `func` is None and only this inner `func` is the target.

        # Add `mlpipeline_ui_metadata_path` to signature to show W&B run in
        # "ML Visualizations tab"
        sig = signature(func)
        no_default = []
        has_default = []

        for param in sig.parameters.values():
            if param.default is param.empty:
                no_default.append(param)
            else:
                has_default.append(param)

        # The new parameter has no default, so it is placed after the other
        # required parameters and before the defaulted ones.
        new_params = (
            *no_default,
            Parameter(
                "mlpipeline_ui_metadata_path",
                annotation=OutputPath(),
                kind=Parameter.POSITIONAL_OR_KEYWORD,
            ),
            *has_default,
        )
        new_sig = sig.replace(parameters=new_params)
        new_anns = {param.name: param.annotation for param in new_params}
        if "return" in func.__annotations__:
            new_anns["return"] = func.__annotations__["return"]

        input_scalars = {}
        input_artifacts = {}
        output_scalars = {}
        output_artifacts = {}

        # Classify each annotation so the wrapper knows what to log and how.
        for name, ann in func.__annotations__.items():
            if name == "return":
                output_scalars[name] = ann
            elif isinstance(ann, output_types):
                output_artifacts[name] = ann
            elif isinstance(ann, input_types):
                input_artifacts[name] = ann
            else:
                input_scalars[name] = ann

        @wraps(func)
        def wrapper(*args, **kwargs):
            bound = new_sig.bind(*args, **kwargs)
            bound.apply_defaults()

            # The metadata path belongs to the wrapper only; remove it before
            # forwarding the remaining arguments to `func`.
            mlpipeline_ui_metadata_path = bound.arguments.pop(
                "mlpipeline_ui_metadata_path"
            )

            # Look values up by parameter name in the bound arguments so they
            # are found whether passed positionally, by keyword, or defaulted.
            call_args = dict(bound.arguments)

            with wandb.init(
                job_type=func.__name__,
                group="{{workflow.annotations.pipelines.kubeflow.org/run_name}}",
            ) as run:
                # Link back to the kfp UI
                kubeflow_url = get_link_back_to_kubeflow()
                run.notes = kubeflow_url
                run.config["LINK_TO_KUBEFLOW_RUN"] = kubeflow_url

                iframe_html = get_iframe_html(run)
                metadata = {
                    "outputs": [
                        {
                            "type": "markdown",
                            "storage": "inline",
                            "source": iframe_html,
                        }
                    ]
                }

                with open(mlpipeline_ui_metadata_path, "w") as metadata_file:
                    json.dump(metadata, metadata_file)

                if log_component_file:
                    _log_component_file(func, run=run)

                for name in input_scalars:
                    log_input_scalar(name, call_args[name], run)

                for name, ann in input_artifacts.items():
                    log_input_artifact(name, call_args[name], ann.type, run)

                with wb_telemetry.context(run=run) as tel:
                    tel.feature.kfp_wandb_log = True

                result = func(*bound.args, **bound.kwargs)

                for name in output_scalars:
                    log_output_scalar(name, result, func.__name__, run)

                for name, ann in output_artifacts.items():
                    log_output_artifact(name, call_args[name], ann.type, run)

                return result

        wrapper.__signature__ = new_sig
        wrapper.__annotations__ = new_anns
        return wrapper

    if func is None:
        return decorator
    else:
        return decorator(func)
class WandbTracer(WandbTracer):
    """Deprecated shim over the ``WandbTracer`` imported from LangChain.

    Subclasses the class of the same name imported from
    ``langchain.callbacks.tracers`` and adds only a deprecation notice, issued
    each time an instance is constructed.
    """

    def __init__(self, *args, **kwargs):
        # Construct the LangChain tracer first, then record the deprecation.
        super().__init__(*args, **kwargs)
        deprecation_notice = (
            "This feature is deprecated and has been moved to `langchain`. Enable tracing by setting "
            "LANGCHAIN_WANDB_TRACING=true in your environment. See the documentation at "
            "https://python.langchain.com/docs/ecosystem/integrations/agent_with_wandb_tracing for guidance. "
            "Replace your current import with `from langchain.callbacks.tracers import WandbTracer`."
        )
        deprecated_feature = Deprecated(langchain_tracer=True)
        warn_and_record_deprecation(
            feature=deprecated_feature,
            message=deprecation_notice,
        )
class _WandbCallback:
    """Callable object implementing the `wandb_callback` behaviour.

    This callback is adapted from the LightGBM's `_RecordEvaluationCallback`.
    """

    def __init__(self, log_params: bool = True, define_metric: bool = True) -> None:
        # `order` and `before_iteration` are presumably read by LightGBM's
        # callback machinery — confirm against the LightGBM callback docs.
        self.order = 20
        self.before_iteration = False
        self.log_params = log_params
        self.define_metric_bool = define_metric

    def _init(self, env: "CallbackEnv") -> None:
        """Do the one-time setup: telemetry flag, config, metric summaries."""
        with wb_telemetry.context() as tel:
            tel.feature.lightgbm_wandb_callback = True

        # log the params as W&B config.
        if self.log_params:
            wandb.config.update(env.params)

        if not self.define_metric_bool:
            return

        # use `define_metric` to set the wandb summary to the best metric value.
        for entry in env.evaluation_result_list:
            if len(entry) == 4:
                dataset, metric = entry[:2]
                _define_metric(dataset, metric)
                continue
            # Any other length: the second field holds "<data> <metric>" and
            # the entry carries both a mean and a standard deviation.
            dataset, metric = entry[1].split()
            for suffix in ("mean", "stdv"):
                _define_metric(dataset, f"{metric}-{suffix}")

    def __call__(self, env: "CallbackEnv") -> None:
        """Log the evaluation results of the current iteration to W&B."""
        is_first_iteration = env.iteration == env.begin_iteration  # type: ignore
        if is_first_iteration:
            self._init(env)

        for entry in env.evaluation_result_list:
            if len(entry) == 4:
                dataset, metric, value = entry[:3]
                payload = {dataset + "_" + metric: value}
            else:
                dataset, metric = entry[1].split()
                key = dataset + "_" + metric
                payload = {key + "-mean": entry[2], key + "-stdv": entry[4]}
            wandb.log(payload, commit=False)

        # call `commit=True` to log the data as a single W&B step.
        wandb.log({"iteration": env.iteration}, commit=True)
def log_summary(
    model: Booster, feature_importance: bool = True, save_model_checkpoint: bool = False
) -> None:
    """Log useful metrics about lightgbm model after training is done.

    Args:
        model: (Booster) is an instance of lightgbm.basic.Booster.
        feature_importance: (boolean) if True (default), logs the feature importance plot.
        save_model_checkpoint: (boolean) if True saves the best model and upload as W&B artifacts.

    Raises:
        wandb.Error: If there is no active run (``wandb.run`` is ``None``) or
            ``model`` is not a ``Booster`` instance.

    Using this along with `wandb_callback` will:

    - log `best_iteration` and `best_score` as `wandb.summary`.
    - log feature importance plot.
    - save and upload your best trained model to Weights & Biases Artifacts (when `save_model_checkpoint = True`)

    Example:
        ```python
        params = {
            "boosting_type": "gbdt",
            "objective": "regression",
        }
        gbm = lgb.train(
            params,
            lgb_train,
            num_boost_round=10,
            valid_sets=lgb_eval,
            valid_names=["validation"],
            callbacks=[wandb_callback()],
        )

        log_summary(gbm)
        ```
    """
    if wandb.run is None:
        # Name this function in the message so the user knows which call
        # needs an active run.
        raise wandb.Error("You must call wandb.init() before log_summary()")

    if not isinstance(model, Booster):
        raise wandb.Error("Model should be an instance of lightgbm.basic.Booster")

    wandb.run.summary["best_iteration"] = model.best_iteration
    wandb.run.summary["best_score"] = model.best_score

    # Log feature importance
    if feature_importance:
        _log_feature_importance(model)

    if save_model_checkpoint:
        _checkpoint_artifact(model, model.best_iteration, aliases=["best"])

    with wb_telemetry.context() as tel:
        tel.feature.lightgbm_log_summary = True
# Import the optional lightning/torch dependencies this module is built on.
try:
    import lightning
    import torch.nn as nn
    from lightning.fabric.loggers.logger import Logger, rank_zero_experiment
    from lightning.fabric.utilities.exceptions import MisconfigurationException
    from lightning.fabric.utilities.logger import (
        _add_prefix,
        _convert_params,
        _sanitize_callable_params,
    )
    from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
    from lightning.fabric.utilities.types import _PATH
    from torch import Tensor
    from torch.nn import Module

    # Warn (once) when running against a lightning release newer than 2.1.3.
    if version.parse(lightning.__version__) > version.parse("2.1.3"):
        wandb.termwarn(
            """This integration is tested and supported for lightning Fabric 2.1.3.
            Please report any issues to https://github.com/wandb/wandb/issues with the tag `lightning-fabric`.""",
            repeat=False,
        )

    if TYPE_CHECKING:
        from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

except ImportError as e:
    # Raise the error instead of building and discarding it. The class
    # defined below subclasses `Logger`, so continuing past a failed import
    # would only fail later with a less helpful NameError.
    raise wandb.Error(
        "The lightning Fabric integration requires `lightning` and `torch` to be "
        f"importable: {e}"
    ) from e
code-block:: python + + # log gradients and model topology + wandb_logger.watch(model) + + # log gradients, parameter histogram and model topology + wandb_logger.watch(model, log="all") + + # change log frequency of gradients and parameters (100 steps by default) + wandb_logger.watch(model, log_freq=500) + + # do not log graph (in case of errors) + wandb_logger.watch(model, log_graph=False) + + The `watch` method adds hooks to the model which can be removed at the end of training: + + .. code-block:: python + + wandb_logger.experiment.unwatch(model) + + **Log model checkpoints** + + Log model checkpoints at the end of training: + + .. code-block:: python + + wandb_logger = WandbLogger(log_model=True) + + Log model checkpoints as they get created during training: + + .. code-block:: python + + wandb_logger = WandbLogger(log_model="all") + + Custom checkpointing can be set up through :class:`~lightning.pytorch.callbacks.ModelCheckpoint`: + + .. code-block:: python + + # log model only if `val_accuracy` increases + wandb_logger = WandbLogger(log_model="all") + checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max") + trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback]) + + `latest` and `best` aliases are automatically set to easily retrieve a model checkpoint: + + .. code-block:: python + + # reference can be retrieved in artifacts panel + # "VERSION" can be a version (ex: "v2") or an alias ("latest or "best") + checkpoint_reference = "USER/PROJECT/MODEL-RUN_ID:VERSION" + + # download checkpoint locally (if not already cached) + run = wandb.init(project="MNIST") + artifact = run.use_artifact(checkpoint_reference, type="model") + artifact_dir = artifact.download() + + # load checkpoint + model = LitModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt") + + **Log media** + + Log text with: + + .. 
code-block:: python + + # using columns and data + columns = ["input", "label", "prediction"] + data = [["cheese", "english", "english"], ["fromage", "french", "spanish"]] + wandb_logger.log_text(key="samples", columns=columns, data=data) + + # using a pandas DataFrame + wandb_logger.log_text(key="samples", dataframe=my_dataframe) + + Log images with: + + .. code-block:: python + + # using tensors, numpy arrays or PIL images + wandb_logger.log_image(key="samples", images=[img1, img2]) + + # adding captions + wandb_logger.log_image( + key="samples", images=[img1, img2], caption=["tree", "person"] + ) + + # using file path + wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"]) + + More arguments can be passed for logging segmentation masks and bounding boxes. Refer to + `Image Overlays documentation `_. + + **Log Tables** + + `W&B Tables `_ can be used to log, + query and analyze tabular data. + + They support any type of media (text, image, video, audio, molecule, html, etc) and are great for storing, + understanding and sharing any form of data, from datasets to model predictions. + + .. code-block:: python + + columns = ["caption", "image", "sound"] + data = [ + ["cheese", wandb.Image(img_1), wandb.Audio(snd_1)], + ["wine", wandb.Image(img_2), wandb.Audio(snd_2)], + ] + wandb_logger.log_table(key="samples", columns=columns, data=data) + + + **Downloading and Using Artifacts** + + To download an artifact without starting a run, call the ``download_artifact`` + function on the class: + + .. code-block:: python + + artifact_dir = wandb_logger.download_artifact(artifact="path/to/artifact") + + To download an artifact and link it to an ongoing run call the ``download_artifact`` + function on the logger instance: + + .. 
code-block:: python + + class MyModule(LightningModule): + def any_lightning_module_function_or_hook(self): + self.logger.download_artifact(artifact="path/to/artifact") + + To link an artifact from a previous run you can use ``use_artifact`` function: + + .. code-block:: python + + wandb_logger.use_artifact(artifact="path/to/artifact") + + See Also: + - `Demo in Google Colab `__ with hyperparameter search and model logging + - `W&B Documentation `__ + + Args: + name: Display name for the run. + save_dir: Path where data is saved. + version: Sets the version, mainly used to resume a previous run. + offline: Run offline (data can be streamed later to wandb servers). + dir: Same as save_dir. + id: Same as version. + anonymous: Enables or explicitly disables anonymous logging. + project: The name of the project to which this run will belong. If not set, the environment variable + `WANDB_PROJECT` will be used as a fallback. If both are not set, it defaults to ``'lightning_logs'``. + log_model: Log checkpoints created by :class:`~lightning.pytorch.callbacks.ModelCheckpoint` + as W&B artifacts. `latest` and `best` aliases are automatically set. + + * if ``log_model == 'all'``, checkpoints are logged during training. + * if ``log_model == True``, checkpoints are logged at the end of training, except when + `~lightning.pytorch.callbacks.ModelCheckpoint.save_top_k` ``== -1`` + which also logs every checkpoint during training. + * if ``log_model == False`` (default), no checkpoint is logged. + + prefix: A string to put at the beginning of metric keys. + experiment: WandB experiment object. Automatically set when creating a run. + checkpoint_name: Name of the model checkpoint artifact being logged. + log_checkpoint_on: When to log model checkpoints as W&B artifacts. Only used if ``log_model`` is ``True``. + Options: ``"success"``, ``"all"``. Default: ``"success"``. + \**kwargs: Arguments passed to :func:`wandb.init` like `entity`, `group`, `tags`, etc. 
+ + Raises: + ModuleNotFoundError: + If required WandB package is not installed on the device. + MisconfigurationException: + If both ``log_model`` and ``offline`` is set to ``True``. + + """ + + LOGGER_JOIN_CHAR = "-" + + def __init__( + self, + name: Optional[str] = None, + save_dir: _PATH = ".", + version: Optional[str] = None, + offline: bool = False, + dir: Optional[_PATH] = None, + id: Optional[str] = None, + anonymous: Optional[bool] = None, + project: Optional[str] = None, + log_model: Union[Literal["all"], bool] = False, + experiment: Optional["wandb.Run"] = None, + prefix: str = "", + checkpoint_name: Optional[str] = None, + log_checkpoint_on: Union[Literal["success"], Literal["all"]] = "success", + **kwargs: Any, + ) -> None: + if offline and log_model: + raise MisconfigurationException( + f"Providing log_model={log_model} and offline={offline} is an invalid configuration" + " since model checkpoints cannot be uploaded in offline mode.\n" + "Hint: Set `offline=False` to log your model." 
+ ) + + super().__init__() + self._offline = offline + self._log_model = log_model + self._prefix = prefix + self._experiment = experiment + self._logged_model_time: Dict[str, float] = {} + self._checkpoint_callback: Optional[ModelCheckpoint] = None + + # paths are processed as strings + if save_dir is not None: + save_dir = os.fspath(save_dir) + elif dir is not None: + dir = os.fspath(dir) + + project = project or os.environ.get("WANDB_PROJECT", "lightning_fabric_logs") + + # set wandb init arguments + self._wandb_init: Dict[str, Any] = { + "name": name, + "project": project, + "dir": save_dir or dir, + "id": version or id, + "resume": "allow", + "anonymous": ("allow" if anonymous else None), + } + self._wandb_init.update(**kwargs) + # extract parameters + self._project = self._wandb_init.get("project") + self._save_dir = self._wandb_init.get("dir") + self._name = self._wandb_init.get("name") + self._id = self._wandb_init.get("id") + self._checkpoint_name = checkpoint_name + self._log_checkpoint_on = log_checkpoint_on + + def __getstate__(self) -> Dict[str, Any]: + # Hack: If the 'spawn' launch method is used, the logger will get pickled and this `__getstate__` gets called. + # We create an experiment here in the main process, and attach to it in the worker process. + # Using wandb-service, we persist the same experiment even if multiple `Trainer.fit/test/validate` calls + # are made. + _ = self.experiment + + state = self.__dict__.copy() + # args needed to reload correct experiment + if self._experiment is not None: + state["_id"] = getattr(self._experiment, "id", None) + state["_attach_id"] = getattr(self._experiment, "_attach_id", None) + state["_name"] = self._experiment.name + + # cannot be pickled + state["_experiment"] = None + return state + + @property + @rank_zero_experiment + def experiment(self) -> "wandb.Run": + r"""Actual wandb object. + + To use wandb features in your :class:`~lightning.pytorch.core.LightningModule`, do the + following. 
+ + Example:: + + .. code-block:: python + + self.logger.experiment.some_wandb_function() + + """ + if self._experiment is None: + if self._offline: + os.environ["WANDB_MODE"] = "dryrun" + + attach_id = getattr(self, "_attach_id", None) + if wandb.run is not None: + # wandb process already created in this instance + rank_zero_warn( + "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse" + " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`." + ) + self._experiment = wandb.run + elif attach_id is not None and hasattr(wandb, "_attach"): + # attach to wandb process referenced + self._experiment = wandb._attach(attach_id) + else: + # create new wandb process + self._experiment = wandb.init(**self._wandb_init) + + # define default x-axis + if isinstance(self._experiment, wandb.Run) and getattr( + self._experiment, "define_metric", None + ): + self._experiment.define_metric("trainer/global_step") + self._experiment.define_metric( + "*", step_metric="trainer/global_step", step_sync=True + ) + + self._experiment._label(repo="lightning_fabric_logger") # pylint: disable=protected-access + with telemetry.context(run=self._experiment) as tel: + tel.feature.lightning_fabric_logger = True + return self._experiment + + def watch( + self, + model: nn.Module, + log: str = "gradients", + log_freq: int = 100, + log_graph: bool = True, + ) -> None: + self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph) + + @override + @rank_zero_only + def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: # type: ignore[override] + params = _convert_params(params) + params = _sanitize_callable_params(params) + self.experiment.config.update(params, allow_val_change=True) + + @override + @rank_zero_only + def log_metrics( + self, metrics: Mapping[str, float], step: Optional[int] = None + ) -> None: + assert rank_zero_only.rank == 0, "experiment tried to log from 
global_rank != 0" + + metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) + if step is not None: + self.experiment.log(dict(metrics, **{"trainer/global_step": step})) + else: + self.experiment.log(metrics) + + @rank_zero_only + def log_table( + self, + key: str, + columns: Optional[List[str]] = None, + data: Optional[List[List[Any]]] = None, + dataframe: Any = None, + step: Optional[int] = None, + ) -> None: + """Log a Table containing any object type (text, image, audio, video, molecule, html, etc). + + Can be defined either with `columns` and `data` or with `dataframe`. + + """ + metrics = {key: wandb.Table(columns=columns, data=data, dataframe=dataframe)} + self.log_metrics(metrics, step) + + @rank_zero_only + def log_text( + self, + key: str, + columns: Optional[List[str]] = None, + data: Optional[List[List[str]]] = None, + dataframe: Any = None, + step: Optional[int] = None, + ) -> None: + """Log text as a Table. + + Can be defined either with `columns` and `data` or with `dataframe`. + + """ + self.log_table(key, columns, data, dataframe, step) + + @rank_zero_only + def log_html( + self, key: str, htmls: List[Any], step: Optional[int] = None, **kwargs: Any + ) -> None: + """Log html files. + + Optional kwargs are lists passed to each html (ex: inject). + + """ + if not isinstance(htmls, list): + raise TypeError(f'Expected a list as "htmls", found {type(htmls)}') + n = len(htmls) + for k, v in kwargs.items(): + if len(v) != n: + raise ValueError(f"Expected {n} items but only found {len(v)} for {k}") + kwarg_list = [{k: kwargs[k][i] for k in kwargs} for i in range(n)] + + metrics = { + key: [wandb.Html(html, **kwarg) for html, kwarg in zip(htmls, kwarg_list)] + } + self.log_metrics(metrics, step) # type: ignore[arg-type] + + @rank_zero_only + def log_image( + self, key: str, images: List[Any], step: Optional[int] = None, **kwargs: Any + ) -> None: + """Log images (tensors, numpy arrays, PIL Images or file paths). 
+ + Optional kwargs are lists passed to each image (ex: caption, masks, boxes). + + """ + if not isinstance(images, list): + raise TypeError(f'Expected a list as "images", found {type(images)}') + n = len(images) + for k, v in kwargs.items(): + if len(v) != n: + raise ValueError(f"Expected {n} items but only found {len(v)} for {k}") + kwarg_list = [{k: kwargs[k][i] for k in kwargs} for i in range(n)] + + metrics = { + key: [wandb.Image(img, **kwarg) for img, kwarg in zip(images, kwarg_list)] + } + self.log_metrics(metrics, step) # type: ignore[arg-type] + + @rank_zero_only + def log_audio( + self, key: str, audios: List[Any], step: Optional[int] = None, **kwargs: Any + ) -> None: + r"""Log audios (numpy arrays, or file paths). + + Args: + key: The key to be used for logging the audio files + audios: The list of audio file paths, or numpy arrays to be logged + step: The step number to be used for logging the audio files + \**kwargs: Optional kwargs are lists passed to each ``Wandb.Audio`` instance (ex: caption, sample_rate). + + Optional kwargs are lists passed to each audio (ex: caption, sample_rate). + + """ + if not isinstance(audios, list): + raise TypeError(f'Expected a list as "audios", found {type(audios)}') + n = len(audios) + for k, v in kwargs.items(): + if len(v) != n: + raise ValueError(f"Expected {n} items but only found {len(v)} for {k}") + kwarg_list = [{k: kwargs[k][i] for k in kwargs} for i in range(n)] + + metrics = { + key: [ + wandb.Audio(audio, **kwarg) for audio, kwarg in zip(audios, kwarg_list) + ] + } + self.log_metrics(metrics, step) # type: ignore[arg-type] + + @rank_zero_only + def log_video( + self, key: str, videos: List[Any], step: Optional[int] = None, **kwargs: Any + ) -> None: + """Log videos (numpy arrays, or file paths). 
+ + Args: + key: The key to be used for logging the video files + videos: The list of video file paths, or numpy arrays to be logged + step: The step number to be used for logging the video files + **kwargs: Optional kwargs are lists passed to each Wandb.Video instance (ex: caption, fps, format). + + Optional kwargs are lists passed to each video (ex: caption, fps, format). + + """ + if not isinstance(videos, list): + raise TypeError(f'Expected a list as "videos", found {type(videos)}') + n = len(videos) + for k, v in kwargs.items(): + if len(v) != n: + raise ValueError(f"Expected {n} items but only found {len(v)} for {k}") + kwarg_list = [{k: kwargs[k][i] for k in kwargs} for i in range(n)] + + metrics = { + key: [ + wandb.Video(video, **kwarg) for video, kwarg in zip(videos, kwarg_list) + ] + } + self.log_metrics(metrics, step) # type: ignore[arg-type] + + @property + @override + def save_dir(self) -> Optional[str]: + """Gets the save directory. + + Returns: + The path to the save directory. + + """ + return self._save_dir + + @property + @override + def name(self) -> Optional[str]: + """The project name of this experiment. + + Returns: + The name of the project the current experiment belongs to. This name is not the same as `wandb.Run`'s + name. To access wandb's internal experiment name, use ``logger.experiment.name`` instead. + + """ + return self._project + + @property + @override + def version(self) -> Optional[str]: + """Gets the id of the experiment. + + Returns: + The id of the experiment if the experiment exists else the id given to the constructor. + + """ + # don't create an experiment if we don't have one + return self._experiment.id if self._experiment else self._id + + @property + def log_dir(self) -> Optional[str]: + """Gets the save directory. + + Returns: + The path to the save directory. 
+ + """ + return self.save_dir + + @property + def group_separator(self) -> str: + """Return the default separator used by the logger to group the data into subfolders.""" + return self.LOGGER_JOIN_CHAR + + @property + def root_dir(self) -> Optional[str]: + """Return the root directory. + + Return the root directory where all versions of an experiment get saved, or `None` if the logger does not + save data locally. + """ + return self.save_dir.parent if self.save_dir else None + + def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None: + """Record model graph. + + Args: + model: the model with an implementation of ``forward``. + input_array: input passes to `model.forward` + + This is a noop function and does not perform any operation. + """ + return + + @override + def after_save_checkpoint(self, checkpoint_callback: "ModelCheckpoint") -> None: + # log checkpoints as artifacts + if ( + self._log_model == "all" + or self._log_model is True + and checkpoint_callback.save_top_k == -1 + ): + # TODO: Replace with new Fabric Checkpoints system + self._scan_and_log_pytorch_checkpoints(checkpoint_callback) + elif self._log_model is True: + self._checkpoint_callback = checkpoint_callback + + @staticmethod + @rank_zero_only + def download_artifact( + artifact: str, + save_dir: Optional[_PATH] = None, + artifact_type: Optional[str] = None, + use_artifact: Optional[bool] = True, + ) -> str: + """Downloads an artifact from the wandb server. + + Args: + artifact: The path of the artifact to download. + save_dir: The directory to save the artifact to. + artifact_type: The type of artifact to download. + use_artifact: Whether to add an edge between the artifact graph. + + Returns: + The path to the downloaded artifact. 
+ + """ + if wandb.run is not None and use_artifact: + artifact = wandb.run.use_artifact(artifact) + else: + api = wandb.Api() + artifact = api.artifact(artifact, type=artifact_type) + + save_dir = None if save_dir is None else os.fspath(save_dir) + return artifact.download(root=save_dir) + + def use_artifact( + self, artifact: str, artifact_type: Optional[str] = None + ) -> "Artifact": + """Logs to the wandb dashboard that the mentioned artifact is used by the run. + + Args: + artifact: The path of the artifact. + artifact_type: The type of artifact being used. + + Returns: + wandb Artifact object for the artifact. + + """ + return self.experiment.use_artifact(artifact, type=artifact_type) + + @override + @rank_zero_only + def save(self) -> None: + """Save log data.""" + self.experiment.log({}, commit=True) + + @override + @rank_zero_only + def finalize(self, status: str) -> None: + if self._log_checkpoint_on == "success" and status != "success": + # Currently, checkpoints only get logged on success + return + # log checkpoints as artifacts + if ( + self._checkpoint_callback + and self._experiment is not None + and self._log_checkpoint_on in ["success", "all"] + ): + self._scan_and_log_pytorch_checkpoints(self._checkpoint_callback) + + def _scan_and_log_pytorch_checkpoints( + self, checkpoint_callback: "ModelCheckpoint" + ) -> None: + from lightning.pytorch.loggers.utilities import _scan_checkpoints + + # get checkpoints to be saved with associated score + checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time) + + # log iteratively all new checkpoints + for t, p, s, _ in checkpoints: + metadata = { + "score": s.item() if isinstance(s, Tensor) else s, + "original_filename": Path(p).name, + checkpoint_callback.__class__.__name__: { + k: getattr(checkpoint_callback, k) + for k in [ + "monitor", + "mode", + "save_last", + "save_top_k", + "save_weights_only", + "_every_n_train_steps", + ] + # ensure it does not break if `ModelCheckpoint` args 
change + if hasattr(checkpoint_callback, k) + }, + } + if not self._checkpoint_name: + self._checkpoint_name = f"model-{self.experiment.id}" + artifact = wandb.Artifact( + name=self._checkpoint_name, type="model", metadata=metadata + ) + artifact.add_file(p, name="model.ckpt") + aliases = ( + ["latest", "best"] + if p == checkpoint_callback.best_model_path + else ["latest"] + ) + self.experiment.log_model(artifact, aliases=aliases) + # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name) + self._logged_model_time[p] = t diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a7dc33daec78edbe8e6c91dc5f1f0b468acabbb7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/__init__.py @@ -0,0 +1,9 @@ +"""W&B Integration for Metaflow. + +Defines a custom step and flow decorator `wandb_log` that automatically logs +flow parameters and artifacts to W&B. +""" + +from .metaflow import wandb_log, wandb_track, wandb_use + +__all__ = ["wandb_log", "wandb_track", "wandb_use"] diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/data_pandas.py b/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/data_pandas.py new file mode 100644 index 0000000000000000000000000000000000000000..9ad5528152ebed557341df0dc8a48d4eea2acde0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/data_pandas.py @@ -0,0 +1,74 @@ +"""Support for Pandas datatypes. + +May raise MissingDependencyError on import. +""" + +from __future__ import annotations + +from typing_extensions import Any, TypeIs + +import wandb + +from . import errors + +try: + import pandas as pd +except ImportError as e: + warning = ( + "`pandas` not installed >>" + " @wandb_log(datasets=True) may not auto log your dataset!" 
+ ) + raise errors.MissingDependencyError(warning=warning) from e + + +def is_dataframe(data: Any) -> TypeIs[pd.DataFrame]: + """Returns whether the data is a Pandas DataFrame.""" + return isinstance(data, pd.DataFrame) + + +def use_dataframe( + name: str, + run: wandb.Run | None, + testing: bool = False, +) -> str | None: + """Log a dependency on a DataFrame input. + + Args: + name: Name of the input. + run: The run to update. + testing: True in unit tests. + """ + if testing: + return "datasets" + assert run + + wandb.termlog(f"Using artifact: {name} (Pandas DataFrame)") + run.use_artifact(f"{name}:latest") + return None + + +def track_dataframe( + name: str, + data: pd.DataFrame, + run: wandb.Run | None, + testing: bool = False, +) -> str | None: + """Log a DataFrame output as an artifact. + + Args: + name: The output's name. + data: The output's value. + run: The run to update. + testing: True in unit tests. + """ + if testing: + return "pd.DataFrame" + assert run + + artifact = wandb.Artifact(name, type="dataset") + with artifact.new_file(f"{name}.parquet", "wb") as f: + data.to_parquet(f, engine="pyarrow") + + wandb.termlog(f"Logging artifact: {name} (Pandas DataFrame)") + run.log_artifact(artifact) + return None diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/data_pytorch.py b/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/data_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..ef2977ea7e2343aa7f6dcb23a622845ff6692729 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/data_pytorch.py @@ -0,0 +1,75 @@ +"""Support for PyTorch datatypes. + +May raise MissingDependencyError on import. +""" + +from __future__ import annotations + +from typing_extensions import Any, TypeIs + +import wandb + +from . 
import errors + +try: + import torch + import torch.nn as nn +except ImportError as e: + warning = ( + "`torch` (PyTorch) not installed >>" + " @wandb_log(models=True) may not auto log your model!" + ) + raise errors.MissingDependencyError(warning=warning) from e + + +def is_nn_module(data: Any) -> TypeIs[nn.Module]: + """Returns whether the data is a PyTorch nn.Module.""" + return isinstance(data, nn.Module) + + +def use_nn_module( + name: str, + run: wandb.Run | None, + testing: bool = False, +) -> str | None: + """Log a dependency on a PyTorch model input. + + Args: + name: Name of the input. + run: The run to update. + testing: True in unit tests. + """ + if testing: + return "models" + assert run + + wandb.termlog(f"Using artifact: {name} (PyTorch nn.Module)") + run.use_artifact(f"{name}:latest") + return None + + +def track_nn_module( + name: str, + data: nn.Module, + run: wandb.Run | None, + testing: bool = False, +) -> str | None: + """Log a PyTorch model output as an artifact. + + Args: + name: The output's name. + data: The output's value. + run: The run to update. + testing: True in unit tests. + """ + if testing: + return "nn.Module" + assert run + + artifact = wandb.Artifact(name, type="model") + with artifact.new_file(f"{name}.pkl", "wb") as f: + torch.save(data, f) + + wandb.termlog(f"Logging artifact: {name} (PyTorch nn.Module)") + run.log_artifact(artifact) + return None diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/data_sklearn.py b/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/data_sklearn.py new file mode 100644 index 0000000000000000000000000000000000000000..98fe4576f1903ea821372f2758ef499fb7e5e9a0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/data_sklearn.py @@ -0,0 +1,76 @@ +"""Support for sklearn datatypes. + +May raise MissingDependencyError on import. 
+""" + +from __future__ import annotations + +import pickle + +from typing_extensions import Any, TypeIs + +import wandb + +from . import errors + +try: + from sklearn.base import BaseEstimator +except ImportError as e: + warning = ( + "`sklearn` not installed >>" + " @wandb_log(models=True) may not auto log your model!" + ) + raise errors.MissingDependencyError(warning=warning) from e + + +def is_estimator(data: Any) -> TypeIs[BaseEstimator]: + """Returns whether the data is an sklearn BaseEstimator.""" + return isinstance(data, BaseEstimator) + + +def use_estimator( + name: str, + run: wandb.Run | None, + testing: bool = False, +) -> str | None: + """Log a dependency on an sklearn estimator. + + Args: + name: Name of the input. + run: The run to update. + testing: True in unit tests. + """ + if testing: + return "models" + assert run + + wandb.termlog(f"Using artifact: {name} (sklearn BaseEstimator)") + run.use_artifact(f"{name}:latest") + return None + + +def track_estimator( + name: str, + data: BaseEstimator, + run: wandb.Run | None, + testing: bool = False, +) -> str | None: + """Log an sklearn estimator output as an artifact. + + Args: + name: The output's name. + data: The output's value. + run: The run to update. + testing: True in unit tests. 
+ """ + if testing: + return "BaseEstimator" + assert run + + artifact = wandb.Artifact(name, type="model") + with artifact.new_file(f"{name}.pkl", "wb") as f: + pickle.dump(data, f) + + wandb.termlog(f"Logging artifact: {name} (sklearn BaseEstimator)") + run.log_artifact(artifact) + return None diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/errors.py b/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..3b0a3ae962cdadde195fa31c6efb4f3d0cc07d5b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/errors.py @@ -0,0 +1,13 @@ +import wandb + + +class MissingDependencyError(Exception): + """An optional dependency is missing.""" + + def __init__(self, *args: object, warning: str) -> None: + super().__init__(*args) + self._wb_warning = warning + + def warn(self) -> None: + """Print a warning for the problem.""" + wandb.termwarn(self._wb_warning) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/metaflow.py b/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/metaflow.py new file mode 100644 index 0000000000000000000000000000000000000000..6ce970d17636ce5417e81fbf9bb5b3b0a6553bd1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/metaflow/metaflow.py @@ -0,0 +1,327 @@ +import inspect +import pickle +from functools import wraps +from pathlib import Path +from typing import Optional, Union + +import wandb +from wandb.sdk.lib import telemetry as wb_telemetry + +from . import errors + +try: + from metaflow import current +except ImportError as e: + raise Exception( + "Error: `metaflow` not installed >> This integration requires metaflow!" + " To fix, please `pip install -Uqq metaflow`" + ) from e + + +try: + from . import data_pandas +except errors.MissingDependencyError as e: + e.warn() + data_pandas = None + +try: + from . 
import data_pytorch +except errors.MissingDependencyError as e: + e.warn() + data_pytorch = None + +try: + from . import data_sklearn +except errors.MissingDependencyError as e: + e.warn() + data_sklearn = None + + +class ArtifactProxy: + def __init__(self, flow): + # do this to avoid recursion problem with __setattr__ + self.__dict__.update( + { + "flow": flow, + "inputs": {}, + "outputs": {}, + "base": set(dir(flow)), + "params": {p: getattr(flow, p) for p in current.parameter_names}, + } + ) + + def __setattr__(self, key, val): + self.outputs[key] = val + return setattr(self.flow, key, val) + + def __getattr__(self, key): + if key not in self.base and key not in self.outputs: + self.inputs[key] = getattr(self.flow, key) + return getattr(self.flow, key) + + +def _track_scalar( + name: str, + data: Union[dict, list, set, str, int, float, bool], + run, + testing: bool = False, +) -> Optional[str]: + if testing: + return "scalar" + + run.log({name: data}) + return None + + +def _track_path( + name: str, + data: Path, + run, + testing: bool = False, +) -> Optional[str]: + if testing: + return "Path" + + artifact = wandb.Artifact(name, type="dataset") + if data.is_dir(): + artifact.add_dir(data) + elif data.is_file(): + artifact.add_file(data) + run.log_artifact(artifact) + wandb.termlog(f"Logging artifact: {name} ({type(data)})") + return None + + +def _track_generic( + name: str, + data, + run, + testing: bool = False, +) -> Optional[str]: + if testing: + return "generic" + + artifact = wandb.Artifact(name, type="other") + with artifact.new_file(f"{name}.pkl", "wb") as f: + pickle.dump(data, f) + run.log_artifact(artifact) + wandb.termlog(f"Logging artifact: {name} ({type(data)})") + return None + + +def wandb_track( + name: str, + data, + datasets: bool = False, + models: bool = False, + others: bool = False, + run: Optional[wandb.Run] = None, + testing: bool = False, +) -> Optional[str]: + """Track data as wandb artifacts based on type and flags.""" + # Check for 
pandas DataFrame + if data_pandas and data_pandas.is_dataframe(data) and datasets: + return data_pandas.track_dataframe(name, data, run, testing) + + # Check for PyTorch Module + if data_pytorch and data_pytorch.is_nn_module(data) and models: + return data_pytorch.track_nn_module(name, data, run, testing) + + # Check for scikit-learn BaseEstimator + if data_sklearn and data_sklearn.is_estimator(data) and models: + return data_sklearn.track_estimator(name, data, run, testing) + + # Check for Path objects + if isinstance(data, Path) and datasets: + return _track_path(name, data, run, testing) + + # Check for scalar types + if isinstance(data, (dict, list, set, str, int, float, bool)): + return _track_scalar(name, data, run, testing) + + # Generic fallback + if others: + return _track_generic(name, data, run, testing) + + # No action taken + return None + + +def wandb_use( + name: str, + data, + datasets: bool = False, + models: bool = False, + others: bool = False, + run=None, + testing: bool = False, +) -> Optional[str]: + """Use wandb artifacts based on data type and flags.""" + # Skip scalar types - nothing to use + if isinstance(data, (dict, list, set, str, int, float, bool)): + return None + + try: + # Check for pandas DataFrame + if data_pandas and data_pandas.is_dataframe(data) and datasets: + return data_pandas.use_dataframe(name, run, testing) + + # Check for PyTorch Module + elif data_pytorch and data_pytorch.is_nn_module(data) and models: + return data_pytorch.use_nn_module(name, run, testing) + + # Check for scikit-learn BaseEstimator + elif data_sklearn and data_sklearn.is_estimator(data) and models: + return data_sklearn.use_estimator(name, run, testing) + + # Check for Path objects + elif isinstance(data, Path) and datasets: + return _use_path(name, data, run, testing) + + # Generic fallback + elif others: + return _use_generic(name, data, run, testing) + + else: + return None + + except wandb.CommError: + wandb.termwarn( + f"This artifact ({name}, 
{type(data)}) does not exist in the wandb datastore!" + " If you created an instance inline (e.g. sklearn.ensemble.RandomForestClassifier)," + " then you can safely ignore this. Otherwise you may want to check your internet connection!" + ) + return None + + +def _use_path( + name: str, + data: Path, + run, + testing: bool = False, +) -> Optional[str]: + if testing: + return "datasets" + + run.use_artifact(f"{name}:latest") + wandb.termlog(f"Using artifact: {name} ({type(data)})") + return None + + +def _use_generic( + name: str, + data, + run, + testing: bool = False, +) -> Optional[str]: + if testing: + return "others" + + run.use_artifact(f"{name}:latest") + wandb.termlog(f"Using artifact: {name} ({type(data)})") + return None + + +def coalesce(*arg): + return next((a for a in arg if a is not None), None) + + +def wandb_log( + func=None, + /, + datasets: bool = False, + models: bool = False, + others: bool = False, + settings: Optional[wandb.Settings] = None, +): + """Automatically log parameters and artifacts to W&B. + + This decorator can be applied to a flow, step, or both: + + - Decorating a step enables or disables logging within that step + - Decorating a flow is equivalent to decorating all steps + - Decorating a step after decorating its flow overwrites the flow decoration + + Args: + func: The step method or flow class to decorate. + datasets: Whether to log `pd.DataFrame` and `pathlib.Path` + types. Defaults to False. + models: Whether to log `nn.Module` and `sklearn.base.BaseEstimator` + types. Defaults to False. + others: If `True`, log anything pickle-able. Defaults to False. + settings: Custom settings to pass to `wandb.init`. + If `run_group` is `None`, it is set to `{flow_name}/{run_id}`. + If `run_job_type` is `None`, it is set to `{run_job_type}/{step_name}`. 
+ """ + + @wraps(func) + def decorator(func): + # If you decorate a class, apply the decoration to all methods in that class + if inspect.isclass(func): + cls = func + for attr in cls.__dict__: + if callable(getattr(cls, attr)): + if not hasattr(attr, "_base_func"): + setattr(cls, attr, decorator(getattr(cls, attr))) + return cls + + # prefer the earliest decoration (i.e. method decoration overrides class decoration) + if hasattr(func, "_base_func"): + return func + + @wraps(func) + def wrapper(self, *args, settings=settings, **kwargs): + if not isinstance(settings, wandb.sdk.wandb_settings.Settings): + settings = wandb.Settings() + + settings.update_from_dict( + { + "run_group": coalesce( + settings.run_group, f"{current.flow_name}/{current.run_id}" + ), + "run_job_type": coalesce(settings.run_job_type, current.step_name), + } + ) + + with wandb.init(settings=settings) as run: + with wb_telemetry.context(run=run) as tel: + tel.feature.metaflow = True + proxy = ArtifactProxy(self) + run.config.update(proxy.params) + func(proxy, *args, **kwargs) + + for name, data in proxy.inputs.items(): + wandb_use( + name, + data, + datasets=datasets, + models=models, + others=others, + run=run, + ) + + for name, data in proxy.outputs.items(): + wandb_track( + name, + data, + datasets=datasets, + models=models, + others=others, + run=run, + ) + + wrapper._base_func = func + + # Add for testing visibility + wrapper._kwargs = { + "datasets": datasets, + "models": models, + "others": others, + "settings": settings, + } + return wrapper + + if func is None: + return decorator + else: + return decorator(func) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/openai/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/openai/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2c216c5f72a7d0310b5cfab98b4c6ff5e0e75e6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/openai/__init__.py @@ -0,0 +1,3 @@ 
+__all__ = ("autolog", "WandbLogger") + +from .openai import autolog diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/openai/fine_tuning.py b/.venv/lib/python3.13/site-packages/wandb/integration/openai/fine_tuning.py new file mode 100644 index 0000000000000000000000000000000000000000..22000c8c342d31451a3959dae33553a8029fa915 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/openai/fine_tuning.py @@ -0,0 +1,480 @@ +import base64 +import datetime +import io +import json +import os +import re +import tempfile +import time +from typing import Any, Dict, List, Optional, Tuple, Union + +from packaging.version import parse + +import wandb +from wandb import util +from wandb.data_types import Table +from wandb.sdk.lib import telemetry + +openai = util.get_module( + name="openai", + required="This integration requires `openai`. To install, please run `pip install openai`", + lazy=False, +) + +if parse(openai.__version__) < parse("1.12.0"): + raise wandb.Error( + f"This integration requires openai version 1.12.0 and above. Your current version is {openai.__version__} " + "To fix, please `pip install -U openai`" + ) + +from openai import OpenAI # noqa: E402 +from openai.types.fine_tuning import FineTuningJob # noqa: E402 +from openai.types.fine_tuning.fine_tuning_job import ( # noqa: E402 + Error, + Hyperparameters, +) + +np = util.get_module( + name="numpy", + required="`numpy` not installed >> This integration requires numpy! To fix, please `pip install numpy`", + lazy=False, +) + +pd = util.get_module( + name="pandas", + required="`pandas` not installed >> This integration requires pandas! 
To fix, please `pip install pandas`", + lazy=False, +) + + +class WandbLogger: + """Log OpenAI fine-tunes to [Weights & Biases](https://wandb.me/openai-docs).""" + + _wandb_api: Optional[wandb.Api] = None + _logged_in: bool = False + openai_client: Optional[OpenAI] = None + _run: Optional[wandb.Run] = None + + @classmethod + def sync( + cls, + fine_tune_job_id: Optional[str] = None, + openai_client: Optional[OpenAI] = None, + num_fine_tunes: Optional[int] = None, + project: str = "OpenAI-Fine-Tune", + entity: Optional[str] = None, + overwrite: bool = False, + wait_for_job_success: bool = True, + log_datasets: bool = True, + model_artifact_name: str = "model-metadata", + model_artifact_type: str = "model", + **kwargs_wandb_init: Dict[str, Any], + ) -> str: + """Sync fine-tunes to Weights & Biases. + + :param fine_tune_job_id: The id of the fine-tune (optional) + :param openai_client: Pass the `OpenAI()` client (optional) + :param num_fine_tunes: Number of most recent fine-tunes to log when an fine_tune_job_id is not provided. By default, every fine-tune is synced. + :param project: Name of the project where you're sending runs. By default, it is "GPT-3". + :param entity: Username or team name where you're sending runs. By default, your default entity is used, which is usually your username. + :param overwrite: Forces logging and overwrite existing wandb run of the same fine-tune. + :param wait_for_job_success: Waits for the fine-tune to be complete and then log metrics to W&B. By default, it is True. 
+ :param model_artifact_name: Name of the model artifact that is logged + :param model_artifact_type: Type of the model artifact that is logged + """ + if openai_client is None: + openai_client = OpenAI() + cls.openai_client = openai_client + + if fine_tune_job_id: + wandb.termlog("Retrieving fine-tune job...") + fine_tune = openai_client.fine_tuning.jobs.retrieve( + fine_tuning_job_id=fine_tune_job_id + ) + fine_tunes = [fine_tune] + else: + # get list of fine_tune to log + fine_tunes = openai_client.fine_tuning.jobs.list() + if not fine_tunes or fine_tunes.data is None: + wandb.termwarn("No fine-tune has been retrieved") + return + # Select the `num_fine_tunes` from the `fine_tunes.data` list. + # If `num_fine_tunes` is None, it selects all items in the list (from start to end). + # If for example, `num_fine_tunes` is 5, it selects the last 5 items in the list. + # Note that the last items in the list are the latest fine-tune jobs. + fine_tunes = fine_tunes.data[ + -num_fine_tunes if num_fine_tunes is not None else None : + ] + + # log starting from oldest fine_tune + show_individual_warnings = ( + fine_tune_job_id is not None or num_fine_tunes is not None + ) + fine_tune_logged = [] + for fine_tune in fine_tunes: + fine_tune_id = fine_tune.id + # check run with the given `fine_tune_id` has not been logged already + run_path = f"{project}/{fine_tune_id}" + if entity is not None: + run_path = f"{entity}/{run_path}" + wandb_run = cls._get_wandb_run(run_path) + if wandb_run: + wandb_status = wandb_run.summary.get("status") + if show_individual_warnings: + if wandb_status == "succeeded" and not overwrite: + wandb.termwarn( + f"Fine-tune {fine_tune_id} has already been logged successfully at {wandb_run.url}. 
" + "Use `overwrite=True` if you want to overwrite previous run" + ) + elif wandb_status != "succeeded" or overwrite: + if wandb_status != "succeeded": + wandb.termwarn( + f"A run for fine-tune {fine_tune_id} was previously created but didn't end successfully" + ) + wandb.termlog( + f"A new wandb run will be created for fine-tune {fine_tune_id} and previous run will be overwritten" + ) + overwrite = True + if wandb_status == "succeeded" and not overwrite: + return + + # check if the user has not created a wandb run externally + if wandb.run is None: + cls._run = wandb.init( + job_type="fine-tune", + project=project, + entity=entity, + name=fine_tune_id, + id=fine_tune_id, + **kwargs_wandb_init, + ) + else: + # if a run exits - created externally + cls._run = wandb.run + + if wait_for_job_success: + fine_tune = cls._wait_for_job_success(fine_tune) + + cls._log_fine_tune( + fine_tune, + project, + entity, + overwrite, + show_individual_warnings, + log_datasets, + model_artifact_name, + model_artifact_type, + **kwargs_wandb_init, + ) + + if not show_individual_warnings and not any(fine_tune_logged): + wandb.termwarn("No new successful fine-tunes were found") + + return "🎉 wandb sync completed successfully" + + @classmethod + def _wait_for_job_success(cls, fine_tune: FineTuningJob) -> FineTuningJob: + wandb.termlog("Waiting for the OpenAI fine-tuning job to finish training...") + wandb.termlog( + "To avoid blocking, you can call `WandbLogger.sync` with `wait_for_job_success=False` after OpenAI training completes." 
+ ) + while True: + if fine_tune.status == "succeeded": + wandb.termlog( + "Fine-tuning finished, logging metrics, model metadata, and run metadata to Weights & Biases" + ) + return fine_tune + if fine_tune.status == "failed": + wandb.termwarn( + f"Fine-tune {fine_tune.id} has failed and will not be logged" + ) + return fine_tune + if fine_tune.status == "cancelled": + wandb.termwarn( + f"Fine-tune {fine_tune.id} was cancelled and will not be logged" + ) + return fine_tune + time.sleep(10) + fine_tune = cls.openai_client.fine_tuning.jobs.retrieve( + fine_tuning_job_id=fine_tune.id + ) + + @classmethod + def _log_fine_tune( + cls, + fine_tune: FineTuningJob, + project: str, + entity: Optional[str], + overwrite: bool, + show_individual_warnings: bool, + log_datasets: bool, + model_artifact_name: str, + model_artifact_type: str, + **kwargs_wandb_init: Dict[str, Any], + ): + fine_tune_id = fine_tune.id + status = fine_tune.status + + with telemetry.context(run=cls._run) as tel: + tel.feature.openai_finetuning = True + + # check run completed successfully + if status != "succeeded": + if show_individual_warnings: + wandb.termwarn( + f'Fine-tune {fine_tune_id} has the status "{status}" and will not be logged' + ) + return + + # check results are present + try: + results_id = fine_tune.result_files[0] + try: + encoded_results = cls.openai_client.files.content( + file_id=results_id + ).read() + results = base64.b64decode(encoded_results).decode("utf-8") + except Exception: + # attempt to read as text, works for older jobs + results = cls.openai_client.files.content(file_id=results_id).text + except openai.NotFoundError: + if show_individual_warnings: + wandb.termwarn( + f"Fine-tune {fine_tune_id} has no results and will not be logged" + ) + return + + # update the config + cls._run.config.update(cls._get_config(fine_tune)) + + # log results + df_results = pd.read_csv(io.StringIO(results)) + for _, row in df_results.iterrows(): + metrics = {k: v for k, v in row.items() if 
not np.isnan(v)} + step = metrics.pop("step") + if step is not None: + step = int(step) + cls._run.log(metrics, step=step) + fine_tuned_model = fine_tune.fine_tuned_model + if fine_tuned_model is not None: + cls._run.summary["fine_tuned_model"] = fine_tuned_model + + # training/validation files and fine-tune details + cls._log_artifacts( + fine_tune, + project, + entity, + log_datasets, + overwrite, + model_artifact_name, + model_artifact_type, + ) + + # mark run as complete + cls._run.summary["status"] = "succeeded" + + cls._run.finish() + return True + + @classmethod + def _ensure_logged_in(cls): + if not cls._logged_in: + if wandb.login(): + cls._logged_in = True + else: + raise Exception( + "It appears you are not currently logged in to Weights & Biases. " + "Please run `wandb login` in your terminal or `wandb.login()` in a notebook. " + "Create a new API key at https://wandb.ai/settings and store it securely." + ) + + @classmethod + def _get_wandb_run(cls, run_path: str): + cls._ensure_logged_in() + try: + if cls._wandb_api is None: + cls._wandb_api = wandb.Api() + return cls._wandb_api.run(run_path) + except Exception: + return None + + @classmethod + def _get_wandb_artifact(cls, artifact_path: str): + cls._ensure_logged_in() + try: + if cls._wandb_api is None: + cls._wandb_api = wandb.Api() + return cls._wandb_api.artifact(artifact_path) + except Exception: + return None + + @classmethod + def _get_config(cls, fine_tune: FineTuningJob) -> Dict[str, Any]: + config = dict(fine_tune) + config["result_files"] = config["result_files"][0] + if config.get("created_at"): + config["created_at"] = datetime.datetime.fromtimestamp( + config["created_at"] + ).strftime("%Y-%m-%d %H:%M:%S") + if config.get("finished_at"): + config["finished_at"] = datetime.datetime.fromtimestamp( + config["finished_at"] + ).strftime("%Y-%m-%d %H:%M:%S") + if config.get("hyperparameters"): + config["hyperparameters"] = cls.sanitize(config["hyperparameters"]) + if config.get("error"): + 
config["error"] = cls.sanitize(config["error"]) + return config + + @classmethod + def _unpack_hyperparameters(cls, hyperparameters: Hyperparameters): + # `Hyperparameters` object is not unpacking properly using `vars` or `__dict__`, + # vars(hyperparameters) return {n_epochs: n} only. + hyperparams = {} + try: + hyperparams["n_epochs"] = hyperparameters.n_epochs + hyperparams["batch_size"] = hyperparameters.batch_size + hyperparams["learning_rate_multiplier"] = ( + hyperparameters.learning_rate_multiplier + ) + except Exception: + # If unpacking fails, return the object to be logged as config + return None + + return hyperparams + + @staticmethod + def sanitize(input: Any) -> Union[Dict, List, str]: + valid_types = [bool, int, float, str] + if isinstance(input, (Hyperparameters, Error)): + return dict(input) + if isinstance(input, dict): + return { + k: v if type(v) in valid_types else str(v) for k, v in input.items() + } + elif isinstance(input, list): + return [v if type(v) in valid_types else str(v) for v in input] + else: + return str(input) + + @classmethod + def _log_artifacts( + cls, + fine_tune: FineTuningJob, + project: str, + entity: Optional[str], + log_datasets: bool, + overwrite: bool, + model_artifact_name: str, + model_artifact_type: str, + ) -> None: + if log_datasets: + wandb.termlog("Logging training/validation files...") + # training/validation files + training_file = fine_tune.training_file if fine_tune.training_file else None + validation_file = ( + fine_tune.validation_file if fine_tune.validation_file else None + ) + for file, prefix, artifact_type in ( + (training_file, "train", "training_files"), + (validation_file, "valid", "validation_files"), + ): + if file is not None: + cls._log_artifact_inputs( + file, prefix, artifact_type, project, entity, overwrite + ) + + # fine-tune details + fine_tune_id = fine_tune.id + artifact = wandb.Artifact( + model_artifact_name, + type=model_artifact_type, + metadata=dict(fine_tune), + ) + + with 
artifact.new_file("model_metadata.json", mode="w", encoding="utf-8") as f: + dict_fine_tune = dict(fine_tune) + dict_fine_tune["hyperparameters"] = cls.sanitize( + dict_fine_tune["hyperparameters"] + ) + dict_fine_tune["error"] = cls.sanitize(dict_fine_tune["error"]) + dict_fine_tune = cls.sanitize(dict_fine_tune) + json.dump(dict_fine_tune, f, indent=2) + cls._run.log_artifact( + artifact, + aliases=["latest", fine_tune_id], + ) + + @classmethod + def _log_artifact_inputs( + cls, + file_id: Optional[str], + prefix: str, + artifact_type: str, + project: str, + entity: Optional[str], + overwrite: bool, + ) -> None: + # get input artifact + artifact_name = f"{prefix}-{file_id}" + # sanitize name to valid wandb artifact name + artifact_name = re.sub(r"[^a-zA-Z0-9_\-.]", "_", artifact_name) + artifact_alias = file_id + artifact_path = f"{project}/{artifact_name}:{artifact_alias}" + if entity is not None: + artifact_path = f"{entity}/{artifact_path}" + artifact = cls._get_wandb_artifact(artifact_path) + + # create artifact if file not already logged previously + if artifact is None or overwrite: + # get file content + try: + file_content = cls.openai_client.files.content(file_id=file_id) + except openai.NotFoundError: + wandb.termerror( + f"File {file_id} could not be retrieved. Make sure you have OpenAI permissions to download training/validation files" + ) + return + + artifact = wandb.Artifact(artifact_name, type=artifact_type) + with tempfile.NamedTemporaryFile(delete=False) as tmp_file: + tmp_file.write(file_content.content) + tmp_file_path = tmp_file.name + artifact.add_file(tmp_file_path, file_id) + os.unlink(tmp_file_path) + + # create a Table + try: + table, n_items = cls._make_table(file_content.text) + # Add table to the artifact. + artifact.add(table, file_id) + # Add the same table to the workspace. 
+ cls._run.log({f"{prefix}_data": table}) + # Update the run config and artifact metadata + cls._run.config.update({f"n_{prefix}": n_items}) + artifact.metadata["items"] = n_items + except Exception as e: + wandb.termerror( + f"Issue saving {file_id} as a Table to Artifacts, exception:\n '{e}'" + ) + else: + # log number of items + cls._run.config.update({f"n_{prefix}": artifact.metadata.get("items")}) + + cls._run.use_artifact(artifact, aliases=["latest", artifact_alias]) + + @classmethod + def _make_table(cls, file_content: str) -> Tuple[Table, int]: + table = wandb.Table(columns=["role: system", "role: user", "role: assistant"]) + + df = pd.read_json(io.StringIO(file_content), orient="records", lines=True) + for _idx, message in df.iterrows(): + messages = message.messages + assert len(messages) == 3 + table.add_data( + messages[0]["content"], + messages[1]["content"], + messages[2]["content"], + ) + + return table, len(df) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/openai/openai.py b/.venv/lib/python3.13/site-packages/wandb/integration/openai/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..250d437d7239bbcd31cc950c80563519f87e5fde --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/openai/openai.py @@ -0,0 +1,22 @@ +import logging + +from wandb.sdk.integration_utils.auto_logging import AutologAPI + +from .resolver import OpenAIRequestResponseResolver + +logger = logging.getLogger(__name__) + + +autolog = AutologAPI( + name="OpenAI", + symbols=( + "Edit.create", + "Completion.create", + "ChatCompletion.create", + "Edit.acreate", + "Completion.acreate", + "ChatCompletion.acreate", + ), + resolver=OpenAIRequestResponseResolver(), + telemetry_feature="openai_autolog", +) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/openai/resolver.py b/.venv/lib/python3.13/site-packages/wandb/integration/openai/resolver.py new file mode 100644 index 
0000000000000000000000000000000000000000..500c58ce2f4b33a6a987243169bc998cf387fef4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/openai/resolver.py @@ -0,0 +1,240 @@ +import datetime +import io +import logging +from dataclasses import asdict, dataclass +from typing import Any, Dict, List, Optional, Sequence + +import wandb +from wandb.sdk.data_types import trace_tree +from wandb.sdk.integration_utils.auto_logging import Response + +logger = logging.getLogger(__name__) + + +@dataclass +class UsageMetrics: + elapsed_time: float = None + prompt_tokens: int = None + completion_tokens: int = None + total_tokens: int = None + + +@dataclass +class Metrics: + usage: UsageMetrics = None + stats: wandb.Table = None + trace: trace_tree.WBTraceTree = None + + +usage_metric_keys = {f"usage/{k}" for k in asdict(UsageMetrics())} + + +class OpenAIRequestResponseResolver: + def __init__(self): + self.define_metrics_called = False + + def __call__( + self, + args: Sequence[Any], + kwargs: Dict[str, Any], + response: Response, + start_time: float, # pass to comply with the protocol, but use response["created"] instead + time_elapsed: float, + ) -> Optional[Dict[str, Any]]: + request = kwargs + + if not self.define_metrics_called: + # define metrics on first call + for key in usage_metric_keys: + wandb.define_metric(key, step_metric="_timestamp") + self.define_metrics_called = True + + try: + if response.get("object") == "edit": + return self._resolve_edit(request, response, time_elapsed) + elif response.get("object") == "text_completion": + return self._resolve_completion(request, response, time_elapsed) + elif response.get("object") == "chat.completion": + return self._resolve_chat_completion(request, response, time_elapsed) + else: + # todo: properly treat failed requests + logger.info( + f"Unsupported OpenAI response object: {response.get('object')}" + ) + except Exception as e: + logger.warning(f"Failed to resolve request/response: {e}") + return None + 
+ @staticmethod + def results_to_trace_tree( + request: Dict[str, Any], + response: Response, + results: List[trace_tree.Result], + time_elapsed: float, + ) -> trace_tree.WBTraceTree: + """Converts the request, response, and results into a trace tree. + + params: + request: The request dictionary + response: The response object + results: A list of results object + time_elapsed: The time elapsed in seconds + returns: + A wandb trace tree object. + """ + start_time_ms = int(round(response["created"] * 1000)) + end_time_ms = start_time_ms + int(round(time_elapsed * 1000)) + span = trace_tree.Span( + name=f"{response.get('model', 'openai')}_{response['object']}_{response.get('created')}", + attributes=dict(response), # type: ignore + start_time_ms=start_time_ms, + end_time_ms=end_time_ms, + span_kind=trace_tree.SpanKind.LLM, + results=results, + ) + model_obj = {"request": request, "response": response, "_kind": "openai"} + return trace_tree.WBTraceTree(root_span=span, model_dict=model_obj) + + def _resolve_edit( + self, + request: Dict[str, Any], + response: Response, + time_elapsed: float, + ) -> Dict[str, Any]: + """Resolves the request and response objects for `openai.Edit`.""" + request_str = ( + f"\n\n**Instruction**: {request['instruction']}\n\n" + f"**Input**: {request['input']}\n" + ) + choices = [ + f"\n\n**Edited**: {choice['text']}\n" for choice in response["choices"] + ] + + return self._resolve_metrics( + request=request, + response=response, + request_str=request_str, + choices=choices, + time_elapsed=time_elapsed, + ) + + def _resolve_completion( + self, + request: Dict[str, Any], + response: Response, + time_elapsed: float, + ) -> Dict[str, Any]: + """Resolves the request and response objects for `openai.Completion`.""" + request_str = f"\n\n**Prompt**: {request['prompt']}\n" + choices = [ + f"\n\n**Completion**: {choice['text']}\n" for choice in response["choices"] + ] + + return self._resolve_metrics( + request=request, + response=response, + 
request_str=request_str, + choices=choices, + time_elapsed=time_elapsed, + ) + + def _resolve_chat_completion( + self, + request: Dict[str, Any], + response: Response, + time_elapsed: float, + ) -> Dict[str, Any]: + """Resolves the request and response objects for `openai.Completion`.""" + prompt = io.StringIO() + for message in request["messages"]: + prompt.write(f"\n\n**{message['role']}**: {message['content']}\n") + request_str = prompt.getvalue() + + choices = [ + f"\n\n**{choice['message']['role']}**: {choice['message']['content']}\n" + for choice in response["choices"] + ] + + return self._resolve_metrics( + request=request, + response=response, + request_str=request_str, + choices=choices, + time_elapsed=time_elapsed, + ) + + def _resolve_metrics( + self, + request: Dict[str, Any], + response: Response, + request_str: str, + choices: List[str], + time_elapsed: float, + ) -> Dict[str, Any]: + """Resolves the request and response objects for `openai.Completion`.""" + results = [ + trace_tree.Result( + inputs={"request": request_str}, + outputs={"response": choice}, + ) + for choice in choices + ] + metrics = self._get_metrics_to_log(request, response, results, time_elapsed) + return self._convert_metrics_to_dict(metrics) + + @staticmethod + def _get_usage_metrics(response: Response, time_elapsed: float) -> UsageMetrics: + """Gets the usage stats from the response object.""" + if response.get("usage"): + usage_stats = UsageMetrics(**response["usage"]) + else: + usage_stats = UsageMetrics() + usage_stats.elapsed_time = time_elapsed + return usage_stats + + def _get_metrics_to_log( + self, + request: Dict[str, Any], + response: Response, + results: List[Any], + time_elapsed: float, + ) -> Metrics: + model = response.get("model") or request.get("model") + usage_metrics = self._get_usage_metrics(response, time_elapsed) + + usage = [] + for result in results: + row = { + "request": result.inputs["request"], + "response": result.outputs["response"], + "model": model, 
+ "start_time": datetime.datetime.fromtimestamp(response["created"]), + "end_time": datetime.datetime.fromtimestamp( + response["created"] + time_elapsed + ), + "request_id": response.get("id", None), + "api_type": response.get("api_type", "openai"), + "session_id": wandb.run.id, + } + row.update(asdict(usage_metrics)) + usage.append(row) + usage_table = wandb.Table( + columns=list(usage[0].keys()), + data=[(item.values()) for item in usage], + ) + + trace = self.results_to_trace_tree(request, response, results, time_elapsed) + + metrics = Metrics(stats=usage_table, trace=trace, usage=usage_metrics) + return metrics + + @staticmethod + def _convert_metrics_to_dict(metrics: Metrics) -> Dict[str, Any]: + """Converts metrics to a dict.""" + metrics_dict = { + "stats": metrics.stats, + "trace": metrics.trace, + } + usage_stats = {f"usage/{k}": v for k, v in asdict(metrics.usage).items()} + metrics_dict.update(usage_stats) + return metrics_dict diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/prodigy/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/prodigy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94ed0ea26bc5f886cf90db26da37dc8f08f9b18e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/prodigy/__init__.py @@ -0,0 +1,3 @@ +from .prodigy import upload_dataset + +__all__ = ["upload_dataset"] diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/prodigy/prodigy.py b/.venv/lib/python3.13/site-packages/wandb/integration/prodigy/prodigy.py new file mode 100644 index 0000000000000000000000000000000000000000..d25fc3cbab7a2c49aa585631e7162485487945e6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/prodigy/prodigy.py @@ -0,0 +1,291 @@ +"""Prodigy integration for W&B. + +User can upload Prodigy annotated datasets directly +from the local database to W&B in Tables format. 
+ +Example usage: + +```python +import wandb +from wandb.integration.prodigy import upload_dataset + +run = wandb.init(project="prodigy") +upload_dataset("name_of_dataset") +wandb.finish() +``` +""" + +import base64 +import collections.abc +import io +import urllib +from copy import deepcopy + +import pandas as pd +from PIL import Image + +import wandb +from wandb import util +from wandb.plot.utils import test_missing +from wandb.sdk.lib import telemetry as wb_telemetry + + +def named_entity(docs): + """Create a named entity visualization. + + Taken from https://github.com/wandb/wandb/blob/main/wandb/plots/named_entity.py. + """ + spacy = util.get_module( + "spacy", + required="part_of_speech requires the spacy library, install with `pip install spacy`", + ) + + util.get_module( + "en_core_web_md", + required="part_of_speech requires `en_core_web_md` library, install with `python -m spacy download en_core_web_md`", + ) + + # Test for required packages and missing & non-integer values in docs data + if test_missing(docs=docs): + html = spacy.displacy.render( + docs, style="ent", page=True, minify=True, jupyter=False + ) + wandb_html = wandb.Html(html) + return wandb_html + + +def merge(dict1, dict2): + """Return a new dictionary by merging two dictionaries recursively.""" + result = deepcopy(dict1) + + for key, value in dict2.items(): + if isinstance(value, collections.abc.Mapping): + result[key] = merge(result.get(key, {}), value) + else: + result[key] = deepcopy(dict2[key]) + + return result + + +def get_schema(list_data_dict, struct, array_dict_types): + """Get a schema of the dataset's structure and data types.""" + # Get the structure of the JSON objects in the database + # This is similar to getting a JSON schema but with slightly different format + for _i, item in enumerate(list_data_dict): + # If the list contains dict objects + for k, v in item.items(): + # Check if key already exists in template + if k not in struct.keys(): + if isinstance(v, list): + if 
len(v) > 0 and isinstance(v[0], list): + # nested list structure + struct[k] = type(v) # type list + elif len(v) > 0 and not ( + isinstance(v[0], list) or isinstance(v[0], dict) + ): + # list of singular values + struct[k] = type(v) # type list + else: + # list of dicts + array_dict_types.append( + k + ) # keep track of keys that are type list[dict] + struct[k] = {} + struct[k] = get_schema(v, struct[k], array_dict_types) + elif isinstance(v, dict): + struct[k] = {} + struct[k] = get_schema([v], struct[k], array_dict_types) + else: + struct[k] = type(v) + else: + # Get the value of struct[k] which is the current template + # Find new keys and then merge the two templates together + cur_struct = struct[k] + if isinstance(v, list): + if len(v) > 0 and isinstance(v[0], list): + # nested list coordinate structure + # if the value in the item is currently None, then update + if v is not None: + struct[k] = type(v) # type list + elif len(v) > 0 and not ( + isinstance(v[0], list) or isinstance(v[0], dict) + ): + # single list with values + # if the value in the item is currently None, then update + if v is not None: + struct[k] = type(v) # type list + else: + array_dict_types.append( + k + ) # keep track of keys that are type list[dict] + struct[k] = {} + struct[k] = get_schema(v, struct[k], array_dict_types) + # merge cur_struct and struct[k], remove duplicates + struct[k] = merge(struct[k], cur_struct) + elif isinstance(v, dict): + struct[k] = {} + struct[k] = get_schema([v], struct[k], array_dict_types) + # merge cur_struct and struct[k], remove duplicates + struct[k] = merge(struct[k], cur_struct) + else: + # if the value in the item is currently None, then update + if v is not None: + struct[k] = type(v) + + return struct + + +def standardize(item, structure, array_dict_types): + """Standardize all rows/entries in dataset to fit the schema. + + Will look for missing values and fill it in so all rows have + the same items and structure. 
+ """ + for k, v in structure.items(): + if k not in item: + # If the structure/field does not exist + if isinstance(v, dict) and (k not in array_dict_types): + # If key k is of type dict, and not not a type list[dict] + item[k] = {} + standardize(item[k], v, array_dict_types) + elif isinstance(v, dict) and (k in array_dict_types): + # If key k is of type dict, and is actually of type list[dict], + # just treat as a list and set to None by default + item[k] = None + else: + # Assign a default type + item[k] = v() + else: + # If the structure/field already exists and is a list or dict + if isinstance(item[k], list): + # ignore if item is a nested list structure or list of non-dicts + condition = ( + not (len(item[k]) > 0 and isinstance(item[k][0], list)) + ) and ( + not ( + len(item[k]) > 0 + and not ( + isinstance(item[k][0], list) or isinstance(item[k][0], dict) + ) + ) + ) + if condition: + for sub_item in item[k]: + standardize(sub_item, v, array_dict_types) + elif isinstance(item[k], dict): + standardize(item[k], v, array_dict_types) + + +def create_table(data): + """Create a W&B Table. + + - Create/decode images from URL/Base64 + - Uses spacy to translate NER span data to visualizations. 
+ """ + # create table object from columns + table_df = pd.DataFrame(data) + columns = list(table_df.columns) + if ("spans" in table_df.columns) and ("text" in table_df.columns): + columns.append("spans_visual") + if "image" in columns: + columns.append("image_visual") + main_table = wandb.Table(columns=columns) + + # Convert to dictionary format to maintain order during processing + matrix = table_df.to_dict(orient="records") + + # Import en_core_web_md if exists + en_core_web_md = util.get_module( + "en_core_web_md", + required="part_of_speech requires `en_core_web_md` library, install with `python -m spacy download en_core_web_md`", + ) + nlp = en_core_web_md.load(disable=["ner"]) + + # Go through each individual row + for _i, document in enumerate(matrix): + # Text NER span visualizations + if ("spans_visual" in columns) and ("text" in columns): + # Add visuals for spans + document["spans_visual"] = None + doc = nlp(document["text"]) + ents = [] + if ("spans" in document) and (document["spans"] is not None): + for span in document["spans"]: + if ("start" in span) and ("end" in span) and ("label" in span): + charspan = doc.char_span( + span["start"], span["end"], span["label"] + ) + ents.append(charspan) + doc.ents = ents + document["spans_visual"] = named_entity(docs=doc) + + # Convert image link to wandb Image + if "image" in columns: + # Turn into wandb image + document["image_visual"] = None + if ("image" in document) and (document["image"] is not None): + isurl = urllib.parse.urlparse(document["image"]).scheme in ( + "http", + "https", + ) + isbase64 = ("data:" in document["image"]) and ( + ";base64" in document["image"] + ) + if isurl: + # is url + try: + im = Image.open(urllib.request.urlopen(document["image"])) + document["image_visual"] = wandb.Image(im) + except urllib.error.URLError: + wandb.termwarn(f"Image URL {document['image']} is invalid.") + document["image_visual"] = None + elif isbase64: + # is base64 uri + imgb64 = 
document["image"].split("base64,")[1] + try: + msg = base64.b64decode(imgb64) + buf = io.BytesIO(msg) + im = Image.open(buf) + document["image_visual"] = wandb.Image(im) + except base64.binascii.Error: + wandb.termwarn(f"Base64 string {document['image']} is invalid.") + document["image_visual"] = None + else: + # is data path + document["image_visual"] = wandb.Image(document["image"]) + + # Create row and append to table + values_list = list(document.values()) + main_table.add_data(*values_list) + return main_table + + +def upload_dataset(dataset_name): + """Upload dataset from local database to Weights & Biases. + + Args: + dataset_name: The name of the dataset in the Prodigy database. + """ + # Check if wandb.init has been called + if wandb.run is None: + raise ValueError("You must call wandb.init() before upload_dataset()") + + with wb_telemetry.context(run=wandb.run) as tel: + tel.feature.prodigy = True + + prodigy_db = util.get_module( + "prodigy.components.db", + required="`prodigy` library is required but not installed. 
Please see https://prodi.gy/docs/install", + ) + # Retrieve and upload prodigy dataset + database = prodigy_db.connect() + data = database.get_dataset(dataset_name) + + array_dict_types = [] + schema = get_schema(data, {}, array_dict_types) + + for i, _d in enumerate(data): + standardize(data[i], schema, array_dict_types) + table = create_table(data) + wandb.log({dataset_name: table}) + wandb.termlog(f"Prodigy dataset `{dataset_name}` uploaded.") diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sacred/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/sacred/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dabcc47afd726c57cf269e698096e399a1574b3a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/sacred/__init__.py @@ -0,0 +1,117 @@ +import warnings + +import numpy +from sacred.dependencies import get_digest +from sacred.observers import RunObserver + +import wandb + + +class WandbObserver(RunObserver): + """Log sacred experiment data to W&B. + + Args: + Accepts all the arguments accepted by wandb.init(). + + name — A display name for this run, which shows up in the UI and is editable, doesn't have to be unique + notes — A multiline string description associated with the run + config — a dictionary-like object to set as initial config + project — the name of the project to which this run will belong + tags — a list of strings to associate with this run as tags + dir — the path to a directory where artifacts will be written (default: ./wandb) + entity — the team posting this run (default: your username or your default team) + job_type — the type of job you are logging, e.g. 
eval, worker, ps (default: training) + save_code — save the main python or notebook file to wandb to enable diffing (default: editable from your settings page) + group — a string by which to group other runs; see Grouping + reinit — Shorthand for the reinit setting that defines what to do when `wandb.init()` is called while a run is active. See the setting's documentation. + id — A unique ID for this run primarily used for Resuming. It must be globally unique, and if you delete a run you can't reuse the ID. Use the name field for a descriptive, useful name for the run. The ID cannot contain special characters. + resume — if set to True, the run auto resumes; can also be a unique string for manual resuming; see Resuming (default: False) + anonymous — can be "allow", "never", or "must". This enables or explicitly disables anonymous logging. (default: never) + force — whether to force a user to be logged into wandb when running a script (default: False) + magic — (bool, dict, or str, optional): magic configuration as bool, dict, json string, yaml filename. If set to True will attempt to auto-instrument your script. (default: None) + sync_tensorboard — A boolean indicating whether or not copy all TensorBoard logs wandb; see Tensorboard (default: False) + monitor_gym — A boolean indicating whether or not to log videos generated by OpenAI Gym; see Ray Tune (default: False) + allow_val_change — whether to allow wandb.config values to change, by default we throw an exception if config values are overwritten. 
(default: False) + + Examples: + Create sacred experiment:: + from wandb.sacred import WandbObserver + ex.observers.append(WandbObserver(project='sacred_test', + name='test1')) + @ex.config + def cfg(): + C = 1.0 + gamma = 0.7 + @ex.automain + def run(C, gamma, _run): + iris = datasets.load_iris() + per = permutation(iris.target.size) + iris.data = iris.data[per] + iris.target = iris.target[per] + clf = svm.SVC(C, 'rbf', gamma=gamma) + clf.fit(iris.data[:90], + iris.target[:90]) + return clf.score(iris.data[90:], + iris.target[90:]) + """ + + def __init__(self, **kwargs): + self.run = wandb.init(**kwargs) + self.resources = {} + + def started_event( + self, ex_info, command, host_info, start_time, config, meta_info, _id + ): + # TODO: add the source code file + # TODO: add dependencies and metadata. + self.__update_config(config) + + def completed_event(self, stop_time, result): + if result: + if not isinstance(result, tuple): + result = ( + result, + ) # transform single result to tuple so that both single & multiple results use same code + + for i, r in enumerate(result): + if isinstance(r, float) or isinstance(r, int): + wandb.log({f"result_{i}": float(r)}) + elif isinstance(r, dict): + wandb.log(r) + elif isinstance(r, object): + artifact = wandb.Artifact(f"result_{i}.pkl", type="result") + artifact.add_file(r) + self.run.log_artifact(artifact) + elif isinstance(r, numpy.ndarray): + wandb.log({f"result_{i}": wandb.Image(r)}) + else: + warnings.warn( + f"logging results does not support type '{type(r)}' results. 
Ignoring this result", + stacklevel=2, + ) + + def artifact_event(self, name, filename, metadata=None, content_type=None): + if content_type is None: + content_type = "file" + artifact = wandb.Artifact(name, type=content_type) + artifact.add_file(filename) + self.run.log_artifact(artifact) + + def resource_event(self, filename): + """TODO: Maintain resources list.""" + if filename not in self.resources: + md5 = get_digest(filename) + self.resources[filename] = md5 + + def log_metrics(self, metrics_by_name, info): + for metric_name, metric_ptr in metrics_by_name.items(): + for _step, value in zip(metric_ptr["steps"], metric_ptr["values"]): + if isinstance(value, numpy.ndarray): + wandb.log({metric_name: wandb.Image(value)}) + else: + wandb.log({metric_name: value}) + + def __update_config(self, config): + for k, v in config.items(): + self.run.config[k] = v + self.run.config["resources"] = [] diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5c8509c8fae0a795d4f17bedd708c87ee587e931 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__init__.py @@ -0,0 +1,14 @@ +"""wandb integration sagemaker module.""" + +from .auth import sagemaker_auth +from .config import is_using_sagemaker, parse_sm_config +from .resources import parse_sm_secrets, set_global_settings, set_run_id + +__all__ = [ + "sagemaker_auth", + "is_using_sagemaker", + "parse_sm_config", + "parse_sm_secrets", + "set_global_settings", + "set_run_id", +] diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae6af0ffe87057b76c9cf5f078f062be99021517 Binary files /dev/null and 
b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__pycache__/auth.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__pycache__/auth.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3df975c75d32f59a32792e06e46fff20aad5fe1 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__pycache__/auth.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__pycache__/config.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__pycache__/config.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1086e58a94536151801d45d16a22fabe26836606 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__pycache__/config.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__pycache__/files.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__pycache__/files.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ea996e0a9fbfebdd2def02cd94bbb4df2e11f63 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__pycache__/files.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__pycache__/resources.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__pycache__/resources.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb33bb7ff6542b88b74f7c46d8c1932a73abf685 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/__pycache__/resources.cpython-313.pyc differ diff --git 
a/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/auth.py b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/auth.py new file mode 100644 index 0000000000000000000000000000000000000000..b9ccbae433cdd2cd85382db3e56dc05d447f16b6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/auth.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import os +from typing import Any + +from wandb import env +from wandb.sdk import wandb_setup +from wandb.sdk.lib import wbauth + + +def sagemaker_auth( + overrides: dict[str, Any] | None = None, + path: str = ".", + api_key: str | None = None, +) -> None: + """Write a secrets.env file with the W&B ApiKey and any additional secrets passed. + + Args: + overrides: Additional environment variables to write to secrets.env + path: The path to write the secrets file. + """ + overrides = overrides or dict() + + api_key = ( + overrides.get(env.API_KEY, None) + or api_key + or wandb_setup.singleton().settings.api_key + or wbauth.read_netrc_auth(host=wandb_setup.singleton().settings.base_url) + ) + + if api_key is None: + raise ValueError( + "Can't find W&B API key, set the WANDB_API_KEY env variable" + + " or run `wandb login`" + ) + + overrides[env.API_KEY] = api_key + with open(os.path.join(path, "secrets.env"), "w") as file: + for k, v in overrides.items(): + file.write(f"{k}={v}\n") diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/config.py b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/config.py new file mode 100644 index 0000000000000000000000000000000000000000..be71b92c19d204b249c277a1c0c82c70bab3a044 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/config.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import json +import os +import re +import warnings +from typing import Any + +from . 
import files as sm_files + + +def is_using_sagemaker() -> bool: + """Returns whether we're in a SageMaker environment.""" + return ( + os.path.exists(sm_files.SM_PARAM_CONFIG) # + or "SM_TRAINING_ENV" in os.environ + ) + + +def parse_sm_config() -> dict[str, Any]: + """Parses SageMaker configuration. + + Returns: + A dictionary of SageMaker config keys/values + or an empty dict if not found. + SM_TRAINING_ENV is a json string of the + training environment variables set by SageMaker + and is only available when running in SageMaker, + but not in local mode. + SM_TRAINING_ENV is set by the SageMaker container and + contains arguments such as hyperparameters + and arguments passed to the training job. + """ + conf = {} + + if os.path.exists(sm_files.SM_PARAM_CONFIG): + conf["sagemaker_training_job_name"] = os.getenv("TRAINING_JOB_NAME") + + # Hyperparameter searches quote configs... + with open(sm_files.SM_PARAM_CONFIG) as fid: + for key, val in json.load(fid).items(): + cast = val.strip('"') + if re.match(r"^-?[\d]+$", cast): + cast = int(cast) + elif re.match(r"^-?[.\d]+$", cast): + cast = float(cast) + conf[key] = cast + + if env := os.environ.get("SM_TRAINING_ENV"): + try: + conf.update(json.loads(env)) + except json.JSONDecodeError: + warnings.warn( + "Failed to parse SM_TRAINING_ENV not valid JSON string", + stacklevel=2, + ) + + return conf diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/files.py b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/files.py new file mode 100644 index 0000000000000000000000000000000000000000..1f91e72fb07b5febb7ee6cb89d06f18444fce826 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/files.py @@ -0,0 +1,2 @@ +SM_PARAM_CONFIG = "/opt/ml/input/config/hyperparameters.json" +SM_SECRETS = "secrets.env" diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/resources.py 
b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/resources.py new file mode 100644 index 0000000000000000000000000000000000000000..4410c755772eead22a846cf25a5beed6fca6cc45 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/sagemaker/resources.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import os +import secrets +import socket +import string + +import wandb + +from . import config +from . import files as sm_files + + +def set_run_id(run_settings: wandb.Settings) -> bool: + """Set a run ID and group when using SageMaker. + + Returns whether the ID and group were updated. + """ + # Added in https://github.com/wandb/wandb/pull/3290. + # + # Prevents SageMaker from overriding the run ID configured + # in environment variables. Note, however, that it will still + # override a run ID passed explicitly to `wandb.init()`. + if os.getenv("WANDB_RUN_ID"): + return False + + run_group = os.getenv("TRAINING_JOB_NAME") + if not run_group: + return False + + alphanumeric = string.ascii_lowercase + string.digits + random = "".join(secrets.choice(alphanumeric) for _ in range(6)) + + host = os.getenv("CURRENT_HOST", socket.gethostname()) + + run_settings.run_id = f"{run_group}-{random}-{host}" + run_settings.run_group = run_group + return True + + +def set_global_settings(settings: wandb.Settings) -> None: + """Set global W&B settings based on the SageMaker environment.""" + if env := parse_sm_secrets(): + settings.update_from_env_vars(env) + + # The SageMaker config may contain an API key, in which case it + # takes precedence over the value in the secrets. It's unclear + # whether this is by design, or by accident; we keep it for + # backward compatibility for now. 
+ sm_config = config.parse_sm_config() + if api_key := sm_config.get("wandb_api_key"): + settings.api_key = api_key + + +def parse_sm_secrets() -> dict[str, str]: + """We read our api_key from secrets.env in SageMaker.""" + env_dict = dict() + # Set secret variables + if os.path.exists(sm_files.SM_SECRETS): + for line in open(sm_files.SM_SECRETS): + key, val = line.strip().split("=", 1) + env_dict[key] = val + return env_dict diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sb3/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/sb3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..29dd81a941cb1745cf82fecbcf12a857ce786f45 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/sb3/__init__.py @@ -0,0 +1,3 @@ +from .sb3 import WandbCallback + +__all__ = ["WandbCallback"] diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sb3/sb3.py b/.venv/lib/python3.13/site-packages/wandb/integration/sb3/sb3.py new file mode 100644 index 0000000000000000000000000000000000000000..2eec145d0fc052575ca6ba65d68283dcaa0ea69b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/sb3/sb3.py @@ -0,0 +1,147 @@ +"""W&B callback for sb3. 
+ +Really simple callback to get logging for each tree + +Example usage: + +```python +import gym +from stable_baselines3 import PPO +from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder +import wandb +from wandb.integration.sb3 import WandbCallback + + +config = { + "policy_type": "MlpPolicy", + "total_timesteps": 25000, + "env_name": "CartPole-v1", +} +run = wandb.init( + project="sb3", + config=config, + sync_tensorboard=True, # auto-upload sb3's tensorboard metrics + monitor_gym=True, # auto-upload the videos of agents playing the game + save_code=True, # optional +) + + +def make_env(): + env = gym.make(config["env_name"]) + env = Monitor(env) # record stats such as returns + return env + + +env = DummyVecEnv([make_env]) +env = VecVideoRecorder( + env, "videos", record_video_trigger=lambda x: x % 2000 == 0, video_length=200 +) +model = PPO(config["policy_type"], env, verbose=1, tensorboard_log=f"runs") +model.learn( + total_timesteps=config["total_timesteps"], + callback=WandbCallback( + model_save_path=f"models/{run.id}", + gradient_save_freq=100, + log="all", + ), +) +``` +""" + +import logging +import os +from typing import Literal, Optional + +from stable_baselines3.common.callbacks import BaseCallback # type: ignore + +import wandb +from wandb.sdk.lib import telemetry as wb_telemetry + +logger = logging.getLogger(__name__) + + +class WandbCallback(BaseCallback): + """Callback for logging experiments to Weights and Biases. + + Log SB3 experiments to Weights and Biases + - Added model tracking and uploading + - Added complete hyperparameters recording + - Added gradient logging + - Note that `wandb.init(...)` must be called before the WandbCallback can be used. 
+ + Args: + verbose: The verbosity of sb3 output + model_save_path: Path to the folder where the model will be saved, The default value is `None` so the model is not logged + model_save_freq: Frequency to save the model + gradient_save_freq: Frequency to log gradient. The default value is 0 so the gradients are not logged + log: What to log. One of "gradients", "parameters", or "all". + """ + + def __init__( + self, + verbose: int = 0, + model_save_path: Optional[str] = None, + model_save_freq: int = 0, + gradient_save_freq: int = 0, + log: Optional[Literal["gradients", "parameters", "all"]] = "all", + ) -> None: + super().__init__(verbose) + if wandb.run is None: + raise wandb.Error("You must call wandb.init() before WandbCallback()") + with wb_telemetry.context() as tel: + tel.feature.sb3 = True + self.model_save_freq = model_save_freq + self.model_save_path = model_save_path + self.gradient_save_freq = gradient_save_freq + if log not in ["gradients", "parameters", "all", None]: + wandb.termwarn( + "`log` must be one of `None`, 'gradients', 'parameters', or 'all', " + "falling back to 'all'" + ) + log = "all" + self.log = log + # Create folder if needed + if self.model_save_path is not None: + os.makedirs(self.model_save_path, exist_ok=True) + self.path = os.path.join(self.model_save_path, "model.zip") + else: + assert self.model_save_freq == 0, ( + "to use the `model_save_freq` you have to set the `model_save_path` parameter" + ) + + def _init_callback(self) -> None: + d = {} + if "algo" not in d: + d["algo"] = type(self.model).__name__ + for key in self.model.__dict__: + if key in wandb.config: + continue + if type(self.model.__dict__[key]) in [float, int, str]: + d[key] = self.model.__dict__[key] + else: + d[key] = str(self.model.__dict__[key]) + if self.gradient_save_freq > 0: + wandb.watch( + self.model.policy, + log_freq=self.gradient_save_freq, + log=self.log, + ) + wandb.config.setdefaults(d) + + def _on_step(self) -> bool: + if self.model_save_freq > 0: 
+ if self.model_save_path is not None: + if self.n_calls % self.model_save_freq == 0: + self.save_model() + return True + + def _on_training_end(self) -> None: + if self.model_save_path is not None: + self.save_model() + + def save_model(self) -> None: + self.model.save(self.path) + wandb.save(self.path, base_path=self.model_save_path) + if self.verbose > 1: + logger.info(f"Saving model checkpoint to {self.path}") diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fb3cb14c923c289c8075c296066f840c3c4e0858 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/__init__.py @@ -0,0 +1,37 @@ +"""Create informative charts for scikit-learn models and log them to W&B.""" + +from .plot import ( + plot_calibration_curve, + plot_class_proportions, + plot_classifier, + plot_clusterer, + plot_confusion_matrix, + plot_elbow_curve, + plot_feature_importances, + plot_learning_curve, + plot_outlier_candidates, + plot_precision_recall, + plot_regressor, + plot_residuals, + plot_roc, + plot_silhouette, + plot_summary_metrics, +) + +__all__ = [ + "plot_classifier", + "plot_clusterer", + "plot_regressor", + "plot_summary_metrics", + "plot_learning_curve", + "plot_feature_importances", + "plot_class_proportions", + "plot_calibration_curve", + "plot_roc", + "plot_precision_recall", + "plot_confusion_matrix", + "plot_elbow_curve", + "plot_silhouette", + "plot_residuals", + "plot_outlier_candidates", +] diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d22d629bbc7dced6da390f885a4d327049491ff --- /dev/null +++ 
b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/__init__.py @@ -0,0 +1,32 @@ +"""Calculates and formats metrics and charts for introspecting sklearn models. + +The functions in these modules are designed to be called by functions from the +plot submodule that have been exported into the namespace of the wandb.sklearn +submodule, rather than being called directly. +""" + +from .calibration_curves import calibration_curves +from .class_proportions import class_proportions +from .confusion_matrix import confusion_matrix +from .decision_boundaries import decision_boundaries +from .elbow_curve import elbow_curve +from .feature_importances import feature_importances +from .learning_curve import learning_curve +from .outlier_candidates import outlier_candidates +from .residuals import residuals +from .silhouette import silhouette +from .summary_metrics import summary_metrics + +__all__ = [ + "calibration_curves", + "class_proportions", + "confusion_matrix", + "decision_boundaries", + "elbow_curve", + "feature_importances", + "learning_curve", + "outlier_candidates", + "residuals", + "silhouette", + "summary_metrics", +] diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/calibration_curves.py b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/calibration_curves.py new file mode 100644 index 0000000000000000000000000000000000000000..b59aa25d7d8c984ca30529d497eb0e263570c365 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/calibration_curves.py @@ -0,0 +1,125 @@ +from warnings import simplefilter + +import numpy as np +import sklearn +from sklearn import model_selection, naive_bayes +from sklearn.calibration import CalibratedClassifierCV +from sklearn.linear_model import LogisticRegression + +import wandb +from wandb.integration.sklearn import utils + +# ignore all future warnings +simplefilter(action="ignore", category=FutureWarning) + + +def 
calibration_curves(clf, X, y, clf_name): # noqa: N803 + # ComplementNB (introduced in 0.20.0) requires non-negative features + if int(sklearn.__version__.split(".")[1]) >= 20 and isinstance( + clf, naive_bayes.ComplementNB + ): + X = X - X.min() # noqa:N806 + + # Calibrated with isotonic calibration + isotonic = CalibratedClassifierCV(clf, cv=2, method="isotonic") + + # Calibrated with sigmoid calibration + sigmoid = CalibratedClassifierCV(clf, cv=2, method="sigmoid") + + # Logistic regression with no calibration as baseline + lr = LogisticRegression(C=1.0) + + model_column = [] # color + frac_positives_column = [] # y axis + mean_pred_value_column = [] # x axis + hist_column = [] # barchart y + edge_column = [] # barchart x + + # Add curve for perfectly calibrated model + # format: model, fraction_of_positives, mean_predicted_value + model_column.append("Perfectly calibrated") + frac_positives_column.append(0) + mean_pred_value_column.append(0) + hist_column.append(0) + edge_column.append(0) + model_column.append("Perfectly calibrated") + hist_column.append(0) + edge_column.append(0) + frac_positives_column.append(1) + mean_pred_value_column.append(1) + + x_train, x_test, y_train, y_test = model_selection.train_test_split( + X, y, test_size=0.9, random_state=42 + ) + + # Add curve for LogisticRegression baseline and other models + + models = [lr, isotonic, sigmoid] + names = ["Logistic", f"{clf_name} Isotonic", f"{clf_name} Sigmoid"] + + for model, name in zip(models, names): + model.fit(x_train, y_train) + if hasattr(model, "predict_proba"): + prob_pos = model.predict_proba(x_test)[:, 1] + else: # use decision function + prob_pos = model.decision_function(x_test) + prob_pos = (prob_pos - prob_pos.min()) / (prob_pos.max() - prob_pos.min()) + + hist, edges = np.histogram(prob_pos, bins=10, density=False) + frac_positives, mean_pred_value = sklearn.calibration.calibration_curve( + y_test, prob_pos, n_bins=10 + ) + + # format: model, fraction_of_positives, 
mean_predicted_value + num_entries = len(frac_positives) + for i in range(num_entries): + hist_column.append(hist[i]) + edge_column.append(edges[i]) + model_column.append(name) + frac_positives_column.append(utils.round_3(frac_positives[i])) + mean_pred_value_column.append(utils.round_3(mean_pred_value[i])) + if utils.check_against_limit( + i, + "calibration_curve", + utils.chart_limit - 2, + ): + break + + table = make_table( + model_column, + frac_positives_column, + mean_pred_value_column, + hist_column, + edge_column, + ) + chart = wandb.visualize("wandb/calibration/v1", table) + + return chart + + +def make_table( + model_column, + frac_positives_column, + mean_pred_value_column, + hist_column, + edge_column, +): + columns = [ + "model", + "fraction_of_positives", + "mean_predicted_value", + "hist_dict", + "edge_dict", + ] + + data = list( + zip( + model_column, + frac_positives_column, + mean_pred_value_column, + hist_column, + edge_column, + ) + ) + + return wandb.Table(columns=columns, data=data) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/class_proportions.py b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/class_proportions.py new file mode 100644 index 0000000000000000000000000000000000000000..183bf785a2fb62e2efe20fbdfde68f7a0fc45024 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/class_proportions.py @@ -0,0 +1,68 @@ +from warnings import simplefilter + +import numpy as np +from sklearn.utils.multiclass import unique_labels + +import wandb +from wandb.integration.sklearn import utils + +# ignore all future warnings +simplefilter(action="ignore", category=FutureWarning) + + +def class_proportions(y_train, y_test, labels): + # Get the unique values from the dataset + targets = (y_train,) if y_test is None else (y_train, y_test) + class_ids = np.array(unique_labels(*targets)) + + # Compute the class counts + counts_train = np.array([(y_train == 
c).sum() for c in class_ids]) + counts_test = np.array([(y_test == c).sum() for c in class_ids]) + + class_column, dataset_column, count_column = make_columns( + class_ids, counts_train, counts_test + ) + + if labels is not None and ( + isinstance(class_column[0], int) or isinstance(class_column[0], np.integer) + ): + class_column = get_named_labels(labels, class_column) + + table = make_table(class_column, dataset_column, count_column) + chart = wandb.visualize("wandb/class_proportions/v1", table) + + return chart + + +def make_table(class_column, dataset_column, count_column): + columns = ["class", "dataset", "count"] + data = list(zip(class_column, dataset_column, count_column)) + + return wandb.Table(data=data, columns=columns) + + +def make_columns(class_ids, counts_train, counts_test): + class_column, dataset_column, count_column = [], [], [] + + for i in range(len(class_ids)): + # add class counts from training set + class_column.append(class_ids[i]) + dataset_column.append("train") + count_column.append(counts_train[i]) + # add class counts from test set + class_column.append(class_ids[i]) + dataset_column.append("test") + count_column.append(counts_test[i]) + + if utils.check_against_limit( + i, + "class_proportions", + utils.chart_limit, + ): + break + + return class_column, dataset_column, count_column + + +def get_named_labels(labels, numeric_labels): + return np.array([labels[num_label] for num_label in numeric_labels]) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/confusion_matrix.py b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/confusion_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..777c84ada92237222e0a058c5facd7b761313b12 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/confusion_matrix.py @@ -0,0 +1,93 @@ +import itertools +from warnings import simplefilter + +import numpy as np +from sklearn import metrics 
+from sklearn.utils.multiclass import unique_labels + +import wandb + +from .. import utils + +# ignore all future warnings +simplefilter(action="ignore", category=FutureWarning) + + +def validate_labels(*args, **kwargs): # FIXME + raise AssertionError() + + +def confusion_matrix( + y_true=None, + y_pred=None, + labels=None, + true_labels=None, + pred_labels=None, + normalize=False, +): + """Compute the confusion matrix to evaluate the performance of a classification. + + Called by plot_confusion_matrix to visualize roc curves. Please use the function + plot_confusion_matrix() if you wish to visualize your confusion matrix. + """ + cm = metrics.confusion_matrix(y_true, y_pred) + + if labels is None: + classes = unique_labels(y_true, y_pred) + else: + classes = np.asarray(labels) + + if normalize: + cm = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] + cm = np.around(cm, decimals=2) + cm[np.isnan(cm)] = 0.0 + + if true_labels is None: + true_classes = classes + else: + validate_labels(classes, true_labels, "true_labels") + + true_label_indexes = np.in1d(classes, true_labels) + + true_classes = classes[true_label_indexes] + cm = cm[true_label_indexes] + + if pred_labels is None: + pred_classes = classes + else: + validate_labels(classes, pred_labels, "pred_labels") + + pred_label_indexes = np.in1d(classes, pred_labels) + + pred_classes = classes[pred_label_indexes] + cm = cm[:, pred_label_indexes] + + table = make_table(cm, pred_classes, true_classes, labels) + chart = wandb.visualize("wandb/confusion_matrix/v1", table) + + return chart + + +def make_table(cm, pred_classes, true_classes, labels): + data, count = [], 0 + for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): + if labels is not None and ( + isinstance(pred_classes[i], int) or isinstance(pred_classes[0], np.integer) + ): + pred = labels[pred_classes[i]] + true = labels[true_classes[j]] + else: + pred = pred_classes[i] + true = true_classes[j] + data.append([pred, true, cm[i, j]]) + 
count += 1 + if utils.check_against_limit( + count, + "confusion_matrix", + utils.chart_limit, + ): + break + + table = wandb.Table(columns=["Predicted", "Actual", "Count"], data=data) + + return table diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/decision_boundaries.py b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/decision_boundaries.py new file mode 100644 index 0000000000000000000000000000000000000000..a7a849b86b151270adf2437e4f2e72a03ce6c1a9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/decision_boundaries.py @@ -0,0 +1,40 @@ +from warnings import simplefilter + +import wandb + +# ignore all future warnings +simplefilter(action="ignore", category=FutureWarning) + + +def decision_boundaries( + decision_boundary_x, + decision_boundary_y, + decision_boundary_color, + train_x, + train_y, + train_color, + test_x, + test_y, + test_color, +): + x_dict, y_dict, color_dict = [], [], [] + for i in range(min(len(decision_boundary_x), 100)): + x_dict.append(decision_boundary_x[i]) + y_dict.append(decision_boundary_y[i]) + color_dict.append(decision_boundary_color) + for i in range(300): + x_dict.append(test_x[i]) + y_dict.append(test_y[i]) + color_dict.append(test_color[i]) + for i in range(min(len(train_x), 600)): + x_dict.append(train_x[i]) + y_dict.append(train_y[i]) + color_dict.append(train_color[i]) + + return wandb.visualize( + "wandb/decision_boundaries/v1", + wandb.Table( + columns=["x", "y", "color"], + data=[[x_dict[i], y_dict[i], color_dict[i]] for i in range(len(x_dict))], + ), + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/elbow_curve.py b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/elbow_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..731102d75adabf238d44de00b8d79bca59f0a87a --- /dev/null +++ 
b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/elbow_curve.py @@ -0,0 +1,55 @@ +import time +from warnings import simplefilter + +import numpy as np +from joblib import Parallel, delayed +from sklearn.base import clone + +import wandb + +# ignore all future warnings +simplefilter(action="ignore", category=FutureWarning) + + +def elbow_curve(clusterer, X, cluster_ranges, n_jobs, show_cluster_time): # noqa: N803 + if cluster_ranges is None: + cluster_ranges = range(1, 10, 2) + else: + cluster_ranges = sorted(cluster_ranges) + + clfs, times = _compute_results_parallel(n_jobs, clusterer, X, cluster_ranges) + + clfs = np.absolute(clfs) + + table = make_table(cluster_ranges, clfs, times) + chart = wandb.visualize("wandb/elbow/v1", table) + + return chart + + +def make_table(cluster_ranges, clfs, times): + columns = ["cluster_ranges", "errors", "clustering_time"] + + data = list(zip(cluster_ranges, clfs, times)) + + table = wandb.Table(columns=columns, data=data) + + return table + + +def _compute_results_parallel(n_jobs, clusterer, x, cluster_ranges): + parallel_runner = Parallel(n_jobs=n_jobs) + _cluster_scorer = delayed(_clone_and_score_clusterer) + results = parallel_runner(_cluster_scorer(clusterer, x, i) for i in cluster_ranges) + + clfs, times = zip(*results) + + return clfs, times + + +def _clone_and_score_clusterer(clusterer, x, n_clusters): + start = time.time() + clusterer = clone(clusterer) + clusterer.n_clusters = n_clusters + + return clusterer.fit(x).score(x), time.time() - start diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/feature_importances.py b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/feature_importances.py new file mode 100644 index 0000000000000000000000000000000000000000..fac0452b944fd31d3bf611df054f4f6c0dbc0895 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/calculate/feature_importances.py @@ -0,0 +1,67 @@ +from warnings 
def feature_importances(model, feature_names):
    """Build the W&B feature-importances chart for a model.

    The importances are read from the first attribute found on ``model`` among
    ``feature_importances_``, ``feature_log_prob_`` and ``coef_`` (checked in
    that order), then sorted in descending order.

    Args:
        model: Object exposing one of the attributes listed above as an array.
        feature_names: Optional sequence of names, one per feature. When
            ``None`` the sorted feature indices are used as names.

    Returns:
        The result of ``wandb.visualize`` for the
        ``wandb/feature_importances/v1`` preset, or ``None`` (after a warning)
        when no importance attribute exists or the array has more than one
        dimension of size greater than 1.
    """
    attributes_to_check = ["feature_importances_", "feature_log_prob_", "coef_"]
    found_attribute = check_for_attribute_on(model, attributes_to_check)
    if found_attribute is None:
        wandb.termwarn(
            f"could not find any of attributes {', '.join(attributes_to_check)} on classifier. Cannot plot feature importances."
        )
        return

    # found_attribute is one of the three names above. Per the original notes,
    # coef_ covers ElasticNet-like models and feature_log_prob_ replaced the
    # coef_ that was deprecated in sklearn 0.24.
    importances = getattr(model, found_attribute)

    if len(importances.shape) > 1:
        n_significant_dims = sum(i > 1 for i in importances.shape)
        if n_significant_dims > 1:
            nd = len(importances.shape)
            wandb.termwarn(
                f"{nd}-dimensional feature importances array passed to plot_feature_importances. "
                f"{nd}-dimensional and higher feature importances arrays are not currently supported. "
                f"These importances will not be plotted."
            )
            return
        # At most one axis is longer than 1, so flattening keeps the same
        # values as squeezing would. Unlike np.squeeze, np.ravel never
        # produces a 0-d array for an all-ones shape such as (1, 1), which
        # could not be indexed by the argsort result below.
        importances = np.ravel(importances)

    indices = np.argsort(importances)[::-1]
    importances = importances[indices]

    if feature_names is None:
        feature_names = indices
    else:
        feature_names = np.array(feature_names)[indices]

    table = make_table(feature_names, importances)
    chart = wandb.visualize("wandb/feature_importances/v1", table)

    return chart
def make_table(train, test, train_sizes):
    """Assemble the learning-curve table from per-size train and test scores.

    For each index a ``"train"`` row and a ``"test"`` row are emitted, each
    holding the score passed through ``utils.round_2`` and the matching
    training-set size. Iteration stops as soon as ``utils.check_against_limit``
    reports the limit (half of ``utils.chart_limit``) for the current index.

    Args:
        train: Indexable training scores, one per training-set size.
        test: Indexable test scores, aligned with ``train``.
        train_sizes: Indexable training-set sizes, aligned with ``train``.

    Returns:
        A ``wandb.Table`` with columns ``dataset``, ``score`` and
        ``train_size``.
    """
    rows = []
    position = 0
    total = len(train)
    while position < total:
        # Two rows are produced per position, hence half of the chart limit.
        if utils.check_against_limit(
            position,
            "learning_curve",
            utils.chart_limit / 2,
        ):
            break
        size = train_sizes[position]
        rows.append(["train", utils.round_2(train[position]), size])
        rows.append(["test", utils.round_2(test[position]), train_sizes[position]])
        position += 1

    return wandb.Table(columns=["dataset", "score", "train_size"], data=rows)
def make_table(distance, outlier_percentage, influence_threshold):
    """Assemble the outlier-candidates table, one row per distance value.

    Args:
        distance: Indexable sequence of distance values.
        outlier_percentage: Value stored on every row after being passed
            through ``utils.round_3``.
        influence_threshold: Value stored unchanged on every row.

    Returns:
        A ``wandb.Table`` with columns ``distance``, ``instance_indicies``,
        ``outlier_percentage`` and ``influence_threshold``.
    """
    rows = []
    for position in range(len(distance)):
        rows.append(
            [
                distance[position],
                position,
                utils.round_3(outlier_percentage),
                influence_threshold,
            ]
        )

    return wandb.Table(
        columns=[
            "distance",
            "instance_indicies",
            "outlier_percentage",
            "influence_threshold",
        ],
        data=rows,
    )
def make_table(
    y_pred_train,
    residuals_train,
    y_pred_test,
    residuals_test,
    train_score_,
    test_score_,
):
    """Assemble the residuals table from train and test predictions.

    Training rows come first, then test rows. Each split stops contributing
    rows once ``utils.check_against_limit`` reports that its own running count
    has hit the per-split limit of 100.

    Args:
        y_pred_train: Iterable of predictions on the training split.
        residuals_train: Iterable of residuals aligned with ``y_pred_train``.
        y_pred_test: Iterable of predictions on the test split.
        residuals_test: Iterable of residuals aligned with ``y_pred_test``.
        train_score_: Training score, repeated on every row.
        test_score_: Test score, repeated on every row.

    Returns:
        A ``wandb.Table`` with columns ``dataset``, ``y_pred``, ``residuals``,
        ``train_score`` and ``test_score``.
    """
    per_split_limit = 100
    splits = (
        ("train", y_pred_train, residuals_train),
        ("test", y_pred_test, residuals_test),
    )

    rows = []
    for split_name, predictions, split_residuals in splits:
        # The running count restarts at 1 for every split.
        for seen, (pred, residual) in enumerate(
            zip(predictions, split_residuals), start=1
        ):
            rows.append([split_name, pred, residual, train_score_, test_score_])
            if utils.check_against_limit(seen, "residuals", per_split_limit):
                break

    return wandb.Table(
        columns=["dataset", "y_pred", "residuals", "train_score", "test_score"],
        data=rows,
    )
def silhouette(clusterer, X, cluster_labels, labels, metric, kmeans):  # noqa: N803
    """Compute per-sample silhouette values and build the W&B silhouette chart.

    Args:
        clusterer: Clusterer; only read (``cluster_centers_``) when ``kmeans``
            is true.
        X: 2-D array of samples; columns 0 and 1 are used as scatter
            coordinates, so it must support ``X[:, 0]`` indexing.
        cluster_labels: Cluster assignment for each sample.
        labels: Converted to an array but otherwise unused here.
        metric: Distance metric forwarded to ``silhouette_score`` and
            ``silhouette_samples``.
        kmeans: When true, cluster centers are taken from
            ``clusterer.cluster_centers_``; otherwise they are ``None``.

    Returns:
        The result of ``wandb.visualize`` for the ``wandb/silhouette_/v1``
        preset.
    """
    cluster_labels = np.asarray(cluster_labels)
    labels = np.asarray(labels)

    # NOTE(review): the encoded labels are discarded, so the loop below assumes
    # the cluster labels are exactly 0..n_clusters-1 - confirm with callers.
    le = LabelEncoder()
    _ = le.fit_transform(cluster_labels)
    n_clusters = len(np.unique(cluster_labels))

    # Average silhouette value over all samples.
    silhouette_avg = silhouette_score(X, cluster_labels, metric=metric)

    # Silhouette value of each individual sample.
    sample_silhouette_values = silhouette_samples(X, cluster_labels, metric=metric)

    x_sil, y_sil, color_sil = [], [], []

    count, y_lower = 0, 10
    limit_reached = False
    for i in range(n_clusters):
        # Silhouette values of the samples in cluster i, sorted ascending.
        # Boolean-mask indexing returns a copy, so the in-place sort does not
        # touch sample_silhouette_values.
        ith_cluster_silhouette_values = sample_silhouette_values[cluster_labels == i]
        ith_cluster_silhouette_values.sort()

        size_cluster_i = ith_cluster_silhouette_values.shape[0]
        y_upper = y_lower + size_cluster_i

        y_values = np.arange(y_lower, y_upper)

        for j in range(len(y_values)):
            y_sil.append(y_values[j])
            x_sil.append(ith_cluster_silhouette_values[j])
            color_sil.append(i)
            count += 1
            if utils.check_against_limit(count, "silhouette", utils.chart_limit):
                limit_reached = True
                break

        # Stop collecting altogether once the chart limit is hit. The inner
        # break alone let every remaining cluster append further rows.
        if limit_reached:
            break

        # Leave a gap of 10 between consecutive clusters on the y axis.
        y_lower = y_upper + 10

    if kmeans:
        centers = clusterer.cluster_centers_
        centerx = centers[:, 0]
        centery = centers[:, 1]

    else:
        centerx = [None] * len(color_sil)
        centery = [None] * len(color_sil)

    table = make_table(
        X[:, 0],
        X[:, 1],
        cluster_labels,
        centerx,
        centery,
        y_sil,
        x_sil,
        color_sil,
        silhouette_avg,
    )
    chart = wandb.visualize("wandb/silhouette_/v1", table)

    return chart
def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None):  # noqa: N803
    """Compute summary metrics for a classifier or regressor on the test set.

    Classifiers get accuracy plus weighted precision, recall and F1;
    regressors get MAE, MSE and R2. A model that is neither yields an empty
    metrics table. Every metric is passed through ``utils.round_2``.

    Called by plot_summary_metrics to visualize metrics. Please use the function
    plot_summary_metrics() if you wish to visualize your summary metrics.

    Args:
        model: Estimator providing ``predict``.
        X: Not read by this function.
        y: Converted to an array but otherwise not used.
        X_test: Features passed to ``model.predict``.
        y_test: Ground-truth targets the predictions are compared against.

    Returns:
        The result of ``wandb.visualize`` for the ``wandb/metrics/v1`` preset.
    """
    y = np.asarray(y)
    y_test = np.asarray(y_test)
    model_name = model.__class__.__name__

    y_pred = model.predict(X_test)

    if sklearn.base.is_classifier(model):
        raw_metrics = {
            "accuracy_score": sklearn.metrics.accuracy_score(y_test, y_pred),
            "precision": sklearn.metrics.precision_score(
                y_test, y_pred, average="weighted"
            ),
            "recall": sklearn.metrics.recall_score(
                y_test, y_pred, average="weighted"
            ),
            "f1_score": sklearn.metrics.f1_score(y_test, y_pred, average="weighted"),
        }
    elif sklearn.base.is_regressor(model):
        raw_metrics = {
            "mae": sklearn.metrics.mean_absolute_error(y_test, y_pred),
            "mse": sklearn.metrics.mean_squared_error(y_test, y_pred),
            "r2_score": sklearn.metrics.r2_score(y_test, y_pred),
        }
    else:
        raw_metrics = {}

    rounded = {name: utils.round_2(value) for name, value in raw_metrics.items()}

    return wandb.visualize("wandb/metrics/v1", make_table(rounded, model_name))
def classifier(
    model,
    X_train,  # noqa: N803
    X_test,  # noqa: N803
    y_train,
    y_test,
    y_pred,
    y_probas,
    labels,
    is_binary=False,
    model_name="Classifier",
    feature_names=None,
    log_learning_curve=False,
):
    """Generate all sklearn classifier plots supported by W&B.

    Plots are logged in this order: feature importances (skipped for
    ``MultinomialNB``), learning curve (only when requested), confusion
    matrix, summary metrics, class proportions, calibration curve (skipped for
    ``MultinomialNB``), roc curve, precision-recall curve.

    Should only be called with a fitted classifier (otherwise an error is thrown).

    Args:
        model: (classifier) Takes in a fitted classifier.
        X_train: (arr) Training set features.
        X_test: (arr) Test set features.
        y_train: (arr) Training set labels.
        y_test: (arr) Test set labels.
        y_pred: (arr) Test set predictions by the model passed.
        y_probas: (arr) Test set predicted probabilities by the model passed.
        labels: (list) Named labels for target variable (y). Makes plots easier to
            read by replacing target values with corresponding index.
            For example if `labels=['dog', 'cat', 'owl']` all 0s are
            replaced by dog, 1s by cat.
        is_binary: (bool) Is the model passed a binary classifier? Defaults to
            False. Not read by this function.
        model_name: (str) Model name. Defaults to 'Classifier'
        feature_names: (list) Names for features. Makes plots easier to read by
            replacing feature indexes with corresponding names.
        log_learning_curve: (bool) Whether or not to log the learning curve.
            Defaults to False.

    Returns:
        None: To see plots, go to your W&B run page then expand the 'media' tab
            under 'auto visualizations'.

    Example:
    ```python
    wandb.sklearn.plot_classifier(
        model,
        X_train,
        X_test,
        y_train,
        y_test,
        y_pred,
        y_probas,
        ["cat", "dog"],
        False,
        "RandomForest",
        ["barks", "drools", "plays_fetch", "breed"],
    )
    ```
    """
    wandb.termlog(f"\nPlotting {model_name}.")

    # Feature importances and the calibration curve are both skipped for
    # MultinomialNB models.
    skip_nb_unsupported = isinstance(model, naive_bayes.MultinomialNB)

    if not skip_nb_unsupported:
        feature_importances(model, feature_names)
        wandb.termlog("Logged feature importances.")

    if log_learning_curve:
        shared.learning_curve(model, X_train, y_train)
        wandb.termlog("Logged learning curve.")

    confusion_matrix(y_test, y_pred, labels)
    wandb.termlog("Logged confusion matrix.")

    shared.summary_metrics(model, X=X_train, y=y_train, X_test=X_test, y_test=y_test)
    wandb.termlog("Logged summary metrics.")

    class_proportions(y_train, y_test, labels)
    wandb.termlog("Logged class proportions.")

    if not skip_nb_unsupported:
        calibration_curve(model, X_train, y_train, model_name)
        wandb.termlog("Logged calibration curve.")

    roc(y_test, y_probas, labels)
    wandb.termlog("Logged roc curve.")

    precision_recall(y_test, y_probas, labels)
    wandb.termlog("Logged precision-recall curve.")
def confusion_matrix(
    y_true=None,
    y_pred=None,
    labels=None,
    true_labels=None,
    pred_labels=None,
    normalize=False,
):
    """Log a confusion matrix to W&B.

    Confusion matrices depict the pattern of misclassifications by a model.

    Args:
        y_true: (arr) Test set labels.
        y_pred: (arr) Test set predictions (class labels, not probabilities).
        labels: (list) Named labels for target variable (y). Makes plots easier to
            read by replacing target values with corresponding index.
            For example if `labels=['dog', 'cat', 'owl']` all 0s are
            replaced by dog, 1s by cat.
        true_labels: Forwarded unchanged to ``calculate.confusion_matrix``.
        pred_labels: Forwarded unchanged to ``calculate.confusion_matrix``.
        normalize: Forwarded unchanged to ``calculate.confusion_matrix``.
            Defaults to False.

    Returns:
        None: To see plots, go to your W&B run page then expand the 'media' tab
            under 'auto visualizations'.

    Example:
    ```python
    wandb.sklearn.plot_confusion_matrix(y_true, y_pred, labels)
    ```
    """
    # NOTE(review): np.asarray(None) yields a 0-d object array rather than
    # None, so an omitted argument is no longer None when it reaches the
    # validators below - confirm they still detect it.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Both validators always run before the decision is taken.
    not_missing = utils.test_missing(y_true=y_true, y_pred=y_pred)
    correct_types = utils.test_types(y_true=y_true, y_pred=y_pred)

    if not (not_missing and correct_types):
        return

    confusion_matrix_chart = calculate.confusion_matrix(
        y_true,
        y_pred,
        labels,
        true_labels,
        pred_labels,
        normalize,
    )

    wandb.log({"confusion_matrix": confusion_matrix_chart})
+ + Precision-recall curves depict the tradeoff between positive predictive value (precision) + and true positive rate (recall) as the threshold of a classifier is shifted. + + Args: + y_true: (arr) Test set labels. + y_probas: (arr) Test set predicted probabilities. + labels: (list) Named labels for target variable (y). Makes plots easier to + read by replacing target values with corresponding index. + For example if `labels=['dog', 'cat', 'owl']` all 0s are + replaced by dog, 1s by cat. + + Returns: + None: To see plots, go to your W&B run page then expand the 'media' tab + under 'auto visualizations'. + + Example: + ```python + wandb.sklearn.plot_precision_recall(y_true, y_probas, labels) + ``` + """ + precision_recall_chart = wandb.plot.pr_curve( + y_true, y_probas, labels, classes_to_plot + ) + + wandb.log({"precision_recall": precision_recall_chart}) + + +def feature_importances( + model=None, feature_names=None, title="Feature Importance", max_num_features=50 +): + """Log a plot depicting the relative importance of each feature for a classifier's decisions. + + Should only be called with a fitted classifier (otherwise an error is thrown). + Only works with classifiers that have a feature_importances_ attribute, like trees. + + Args: + model: (clf) Takes in a fitted classifier. + feature_names: (list) Names for features. Makes plots easier to read by + replacing feature indexes with corresponding names. + + Returns: + None: To see plots, go to your W&B run page then expand the 'media' tab + under 'auto visualizations'. 
def calibration_curve(clf=None, X=None, y=None, clf_name="Classifier"):  # noqa: N803
    """Log a plot depicting how well-calibrated the predicted probabilities of a classifier are.

    Also suggests how to calibrate an uncalibrated classifier. Compares estimated predicted
    probabilities by a baseline logistic regression model, the model passed as
    an argument, and by both its isotonic calibration and sigmoid calibrations.
    The closer the calibration curves are to a diagonal the better.
    A sine wave like curve represents an overfitted classifier, while a cosine
    wave like curve represents an underfitted classifier.
    By training isotonic and sigmoid calibrations of the model and comparing
    their curves we can figure out whether the model is over or underfitting and
    if so which calibration (sigmoid or isotonic) might help fix this.
    For more details, see https://scikit-learn.org/stable/auto_examples/calibration/plot_calibration_curve.html.

    Should only be called with a fitted classifier (otherwise an error is thrown).

    Please note this function fits variations of the model on the training set when called.

    Only binary targets are supported: if ``y`` is a unicode-string array or
    contains any value other than 0 and 1, a warning is emitted and nothing is
    logged.

    Args:
        clf: (clf) Takes in a fitted classifier.
        X: (arr) Training set features.
        y: (arr) Training set labels.
        clf_name: (str) Model name. Defaults to 'Classifier'

    Returns:
        None: To see plots, go to your W&B run page then expand the 'media' tab
            under 'auto visualizations'.

    Example:
    ```python
    wandb.sklearn.plot_calibration_curve(clf, X, y, "RandomForestClassifier")
    ```
    """
    # All three validators always run before the decision is taken.
    not_missing = utils.test_missing(clf=clf, X=X, y=y)
    correct_types = utils.test_types(clf=clf, X=X, y=y)
    is_fitted = utils.test_fitted(clf)
    if not (not_missing and correct_types and is_fitted):
        return

    y = np.asarray(y)
    if y.dtype.char == "U" or not ((y == 0) | (y == 1)).all():
        wandb.termwarn(
            "This function only supports binary classification at the moment and therefore expects labels to be binary. Skipping calibration curve."
        )
        return

    calibration_curve_chart = calculate.calibration_curves(clf, X, y, clf_name)

    wandb.log({"calibration_curve": calibration_curve_chart})
def elbow_curve(
    clusterer=None,
    X=None,  # noqa: N803
    cluster_ranges=None,
    n_jobs=1,
    show_cluster_time=True,
):
    """Measures and plots variance explained as a function of the number of clusters.

    Useful in picking the optimal number of clusters.

    Should only be called with a fitted clusterer (otherwise an error is thrown).

    Please note this function fits the model on the training set when called.

    Args:
        clusterer: (clusterer) Takes in a fitted clusterer. It must expose an
            ``n_clusters`` attribute; otherwise a message is logged and
            nothing is plotted.
        X: (arr) Training set features.
        cluster_ranges: Cluster counts to evaluate, forwarded to
            ``calculate.elbow_curve``. Defaults to None.
        n_jobs: (int) Number of parallel jobs, forwarded to
            ``calculate.elbow_curve``. Defaults to 1.
        show_cluster_time: (bool) Forwarded to ``calculate.elbow_curve``.
            Defaults to True.

    Returns:
        None: To see plots, go to your W&B run page then expand the 'media' tab
            under 'auto visualizations'.

    Example:
    ```python
    wandb.sklearn.plot_elbow_curve(model, X_train)
    ```
    """
    if not hasattr(clusterer, "n_clusters"):
        wandb.termlog(
            "n_clusters attribute not in classifier. Cannot plot elbow method."
        )
        return

    not_missing = utils.test_missing(clusterer=clusterer)
    # The validator must be called. Assigning the bare function object made
    # this check always truthy, so it never rejected anything.
    correct_types = utils.test_types(clusterer=clusterer)
    is_fitted = utils.test_fitted(clusterer)

    if not_missing and correct_types and is_fitted:
        elbow_curve_chart = calculate.elbow_curve(
            clusterer, X, cluster_ranges, n_jobs, show_cluster_time
        )

        wandb.log({"elbow_curve": elbow_curve_chart})
A value near 0 indicates that the sample is on or + very close to the decision boundary between two neighboring clusters and + negative values indicate that the samples might have been assigned to the wrong cluster. + + Should only be called with a fitted clusterer (otherwise an error is thrown). + + Please note this function fits the model on the training set when called. + + Args: + model: (clusterer) Takes in a fitted clusterer. + X: (arr) Training set features. + cluster_labels: (list) Names for cluster labels. Makes plots easier to read + by replacing cluster indexes with corresponding names. + + Returns: + None: To see plots, go to your W&B run page then expand the 'media' tab + under 'auto visualizations'. + + Example: + ```python + wandb.sklearn.plot_silhouette(model, X_train, ["spam", "not spam"]) + ``` + """ + not_missing = utils.test_missing(clusterer=clusterer) + correct_types = utils.test_types(clusterer=clusterer) + is_fitted = utils.test_fitted(clusterer) + + if not_missing and correct_types and is_fitted: + if isinstance(X, (pd.DataFrame)): + X = X.values # noqa: N806 + silhouette_chart = calculate.silhouette( + clusterer, X, cluster_labels, labels, metric, kmeans + ) + wandb.log({"silhouette_plot": silhouette_chart}) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/plot/regressor.py b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/plot/regressor.py new file mode 100644 index 0000000000000000000000000000000000000000..a2840a06dabfd27cbe0d980af2988f6cc1ef0604 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/plot/regressor.py @@ -0,0 +1,121 @@ +"""Define plots for regression models built with scikit-learn.""" + +from warnings import simplefilter + +import numpy as np + +import wandb +from wandb.integration.sklearn import calculate, utils + +from . 
import shared + +# ignore all future warnings +simplefilter(action="ignore", category=FutureWarning) + + +def regressor(model, X_train, X_test, y_train, y_test, model_name="Regressor"): # noqa: N803 + """Generates all sklearn regressor plots supported by W&B. + + The following plots are generated: + learning curve, summary metrics, residuals plot, outlier candidates. + + Should only be called with a fitted regressor (otherwise an error is thrown). + + Args: + model: (regressor) Takes in a fitted regressor. + X_train: (arr) Training set features. + y_train: (arr) Training set labels. + X_test: (arr) Test set features. + y_test: (arr) Test set labels. + model_name: (str) Model name. Defaults to 'Regressor' + + Returns: + None: To see plots, go to your W&B run page then expand the 'media' tab + under 'auto visualizations'. + + Example: + ```python + wandb.sklearn.plot_regressor(reg, X_train, X_test, y_train, y_test, "Ridge") + ``` + """ + wandb.termlog(f"\nPlotting {model_name}.") + + shared.summary_metrics(model, X_train, y_train, X_test, y_test) + wandb.termlog("Logged summary metrics.") + + shared.learning_curve(model, X_train, y_train) + wandb.termlog("Logged learning curve.") + + outlier_candidates(model, X_train, y_train) + wandb.termlog("Logged outlier candidates.") + + residuals(model, X_train, y_train) + wandb.termlog("Logged residuals.") + + +def outlier_candidates(regressor=None, X=None, y=None): # noqa: N803 + """Measures a datapoint's influence on regression model via cook's distance. + + Instances with high influences could potentially be outliers. + + Should only be called with a fitted regressor (otherwise an error is thrown). + + Please note this function fits the model on the training set when called. + + Args: + model: (regressor) Takes in a fitted regressor. + X: (arr) Training set features. + y: (arr) Training set labels. + + Returns: + None: To see plots, go to your W&B run page then expand the 'media' tab + under 'auto visualizations'. 
+ + Example: + ```python + wandb.sklearn.plot_outlier_candidates(model, X, y) + ``` + """ + is_missing = utils.test_missing(regressor=regressor, X=X, y=y) + correct_types = utils.test_types(regressor=regressor, X=X, y=y) + is_fitted = utils.test_fitted(regressor) + + if is_missing and correct_types and is_fitted: + y = np.asarray(y) + + outliers_chart = calculate.outlier_candidates(regressor, X, y) + wandb.log({"outlier_candidates": outliers_chart}) + + +def residuals(regressor=None, X=None, y=None): # noqa: N803 + """Measures and plots the regressor's predicted value against the residual. + + The marginal distribution of residuals is also calculated and plotted. + + Should only be called with a fitted regressor (otherwise an error is thrown). + + Please note this function fits variations of the model on the training set when called. + + Args: + regressor: (regressor) Takes in a fitted regressor. + X: (arr) Training set features. + y: (arr) Training set labels. + + Returns: + None: To see plots, go to your W&B run page then expand the 'media' tab + under 'auto visualizations'. 
+ + Example: + ```python + wandb.sklearn.plot_residuals(model, X, y) + ``` + """ + not_missing = utils.test_missing(regressor=regressor, X=X, y=y) + correct_types = utils.test_types(regressor=regressor, X=X, y=y) + is_fitted = utils.test_fitted(regressor) + + if not_missing and correct_types and is_fitted: + y = np.asarray(y) + + residuals_chart = calculate.residuals(regressor, X, y) + wandb.log({"residuals": residuals_chart}) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/plot/shared.py b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/plot/shared.py new file mode 100644 index 0000000000000000000000000000000000000000..871dbd742c960d681803f240fbccddff02000733 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/plot/shared.py @@ -0,0 +1,91 @@ +"""Define plots used by multiple sklearn model classes.""" + +from warnings import simplefilter + +import numpy as np + +import wandb +from wandb.integration.sklearn import calculate, utils + +# ignore all future warnings +simplefilter(action="ignore", category=FutureWarning) + + +def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None): # noqa: N803 + """Logs a chart depicting summary metrics for a model. + + Should only be called with a fitted model (otherwise an error is thrown). + + Args: + model: (clf or reg) Takes in a fitted regressor or classifier. + X: (arr) Training set features. + y: (arr) Training set labels. + X_test: (arr) Test set features. + y_test: (arr) Test set labels. + + Returns: + None: To see plots, go to your W&B run page then expand the 'media' tab + under 'auto visualizations'. 
+ + Example: + ```python + wandb.sklearn.plot_summary_metrics(model, X_train, y_train, X_test, y_test) + ``` + """ + not_missing = utils.test_missing( + model=model, X=X, y=y, X_test=X_test, y_test=y_test + ) + correct_types = utils.test_types( + model=model, X=X, y=y, X_test=X_test, y_test=y_test + ) + model_fitted = utils.test_fitted(model) + + if not_missing and correct_types and model_fitted: + metrics_chart = calculate.summary_metrics(model, X, y, X_test, y_test) + wandb.log({"summary_metrics": metrics_chart}) + + +def learning_curve( + model=None, + X=None, # noqa: N803 + y=None, + cv=None, + shuffle=False, + random_state=None, + train_sizes=None, + n_jobs=1, + scoring=None, +): + """Logs a plot depicting model performance against dataset size. + + Please note this function fits the model to datasets of varying sizes when called. + + Args: + model: (clf or reg) Takes in a fitted regressor or classifier. + X: (arr) Dataset features. + y: (arr) Dataset labels. + + For details on the other keyword arguments, see the documentation for + `sklearn.model_selection.learning_curve`. + + Returns: + None: To see plots, go to your W&B run page then expand the 'media' tab + under 'auto visualizations'. 
+ + Example: + ```python + wandb.sklearn.plot_learning_curve(model, X, y) + ``` + """ + not_missing = utils.test_missing(model=model, X=X, y=y) + correct_types = utils.test_types(model=model, X=X, y=y) + if not_missing and correct_types: + if train_sizes is None: + train_sizes = np.linspace(0.1, 1.0, 5) + y = np.asarray(y) + + learning_curve_chart = calculate.learning_curve( + model, X, y, cv, shuffle, random_state, train_sizes, n_jobs, scoring + ) + + wandb.log({"learning_curve": learning_curve_chart}) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/utils.py b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b2ef628bbb4165e15d33cd26feb144bd2e40d11 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/sklearn/utils.py @@ -0,0 +1,184 @@ +"""Shared utilities for the modules in wandb.sklearn.""" + +from collections.abc import Iterable, Sequence + +import numpy as np +import pandas as pd +import scipy +import sklearn + +import wandb + +chart_limit = 1000 + + +def check_against_limit(count, chart, limit=None): + if limit is None: + limit = chart_limit + if count > limit: + warn_chart_limit(limit, chart) + return True + else: + return False + + +def warn_chart_limit(limit, chart): + warning = f"using only the first {limit} datapoints to create chart {chart}" + wandb.termwarn(warning) + + +def encode_labels(df): + le = sklearn.preprocessing.LabelEncoder() + # apply le on categorical feature columns + categorical_cols = df.select_dtypes( + exclude=["int", "float", "float64", "float32", "int32", "int64"] + ).columns + df[categorical_cols] = df[categorical_cols].apply(lambda col: le.fit_transform(col)) + + +def test_types(**kwargs): + test_passed = True + for k, v in kwargs.items(): + # check for incorrect types + if ( + (k == "X") + or (k == "X_test") + or (k == "y") + or (k == "y_test") + or (k == "y_true") + or (k == "y_probas") + ): + # 
FIXME: do this individually + if not isinstance( + v, + ( + Sequence, + Iterable, + np.ndarray, + np.generic, + pd.DataFrame, + pd.Series, + list, + ), + ): + wandb.termerror(f"{k} is not an array. Please try again.") + test_passed = False + # check for classifier types + if k == "model": + if (not sklearn.base.is_classifier(v)) and ( + not sklearn.base.is_regressor(v) + ): + wandb.termerror( + f"{k} is not a classifier or regressor. Please try again." + ) + test_passed = False + elif k == "clf" or k == "binary_clf": + if not (sklearn.base.is_classifier(v)): + wandb.termerror(f"{k} is not a classifier. Please try again.") + test_passed = False + elif k == "regressor": + if not sklearn.base.is_regressor(v): + wandb.termerror(f"{k} is not a regressor. Please try again.") + test_passed = False + elif k == "clusterer": + if not (getattr(v, "_estimator_type", None) == "clusterer"): + wandb.termerror(f"{k} is not a clusterer. Please try again.") + test_passed = False + return test_passed + + +def test_fitted(model): + try: + model.predict(np.zeros((7, 3))) + except sklearn.exceptions.NotFittedError: + wandb.termerror("Please fit the model before passing it in.") + return False + except AttributeError: + # Some clustering models (LDA, PCA, Agglomerative) don't implement ``predict`` + try: + sklearn.utils.validation.check_is_fitted( + model, + [ + "coef_", + "estimator_", + "labels_", + "n_clusters_", + "children_", + "components_", + "n_components_", + "n_iter_", + "n_batch_iter_", + "explained_variance_", + "singular_values_", + "mean_", + ], + all_or_any=any, + ) + except sklearn.exceptions.NotFittedError: + wandb.termerror("Please fit the model before passing it in.") + return False + else: + return True + except Exception: + # Assume it's fitted, since ``NotFittedError`` wasn't raised + return True + + +# Test Asummptions for plotting parameters and datasets +def test_missing(**kwargs): + test_passed = True + for k, v in kwargs.items(): + # Missing/empty 
params/datapoint arrays + if v is None: + wandb.termerror(f"{k} is None. Please try again.") + test_passed = False + if (k == "X") or (k == "X_test"): + if isinstance(v, scipy.sparse.csr.csr_matrix): + v = v.toarray() + elif isinstance(v, (pd.DataFrame, pd.Series)): + v = v.to_numpy() + elif isinstance(v, list): + v = np.asarray(v) + + # Warn the user about missing values + missing = 0 + missing = np.count_nonzero(pd.isnull(v)) + if missing > 0: + wandb.termwarn(f"{k} contains {missing} missing values. ") + test_passed = False + # Ensure the dataset contains only integers + non_nums = 0 + if v.ndim == 1: + non_nums = sum( + 1 + for val in v + if ( + not isinstance(val, (int, float, complex)) + and not isinstance(val, np.number) + ) + ) + else: + non_nums = sum( + 1 + for sl in v + for val in sl + if ( + not isinstance(val, (int, float, complex)) + and not isinstance(val, np.number) + ) + ) + if non_nums > 0: + wandb.termerror( + f"{k} contains values that are not numbers. Please vectorize, label encode or one hot encode {k} " + "and call the plotting function again." 
+ ) + test_passed = False + return test_passed + + +def round_3(n): + return round(n, 3) + + +def round_2(n): + return round(n, 2) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/tensorboard/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/tensorboard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..80f60ef258a10fb2e7ad2f81594084ccd6098f6c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/tensorboard/__init__.py @@ -0,0 +1,10 @@ +"""wandb integration tensorboard module.""" + +from .log import _log, log, reset_state, tf_summary_to_dict +from .monkeypatch import patch, unpatch + +__all__ = [ + "patch", + "unpatch", + "log", +] diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/tensorboard/log.py b/.venv/lib/python3.13/site-packages/wandb/integration/tensorboard/log.py new file mode 100644 index 0000000000000000000000000000000000000000..5062a9b0d14109aae1901e751e7864d701d2b7da --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/tensorboard/log.py @@ -0,0 +1,351 @@ +import io +import re +import time +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +import wandb +import wandb.util +from wandb.sdk.lib import telemetry + +if TYPE_CHECKING: + import numpy as np + + from wandb.sdk.internal.tb_watcher import TBHistory + +# We have at least the default namestep and a global step to track +# TODO: reset this structure on wandb.finish +STEPS: Dict[str, Dict[str, Any]] = { + "": {"step": 0}, + "global": {"step": 0, "last_log": None}, +} +# TODO(cling): Set these when tensorboard behavior is configured. +# We support rate limited logging by setting this to number of seconds, +# can be a floating point. 
+RATE_LIMIT_SECONDS: Optional[Union[float, int]] = None +IGNORE_KINDS = ["graphs"] +tensor_util = wandb.util.get_module("tensorboard.util.tensor_util") + + +# prefer tensorboard, fallback to protobuf in tensorflow when tboard isn't available +pb = wandb.util.get_module( + "tensorboard.compat.proto.summary_pb2" +) or wandb.util.get_module("tensorflow.core.framework.summary_pb2") + +Summary = pb.Summary if pb else None + + +def make_ndarray(tensor: Any) -> Optional["np.ndarray"]: + if tensor_util: + res = tensor_util.make_ndarray(tensor) + # Tensorboard can log generic objects, and we don't want to save them + if res.dtype == "object": + return None + else: + return res # type: ignore + else: + wandb.termwarn( + "Can't convert tensor summary, upgrade tensorboard with `pip" + " install tensorboard --upgrade`" + ) + return None + + +def namespaced_tag(tag: str, namespace: str = "") -> str: + if not namespace: + return tag + else: + return namespace + "/" + tag + + +def history_image_key(key: str, namespace: str = "") -> str: + """Convert invalid filesystem characters to _ for use in History keys. + + Unfortunately this means currently certain image keys will collide silently. We + implement this mapping up here in the TensorFlow stuff rather than in the History + stuff so that we don't have to store a mapping anywhere from the original keys to + the safe ones. + """ + return namespaced_tag(re.sub(r"[/\\]", "_", key), namespace) + + +def tf_summary_to_dict( # noqa: C901 + tf_summary_str_or_pb: Any, namespace: str = "" +) -> Optional[Dict[str, Any]]: + """Convert a Tensorboard Summary to a dictionary. + + Accepts a tensorflow.summary.Summary, one encoded as a string, + or a list of such encoded as strings. 
+ """ + values = {} + if hasattr(tf_summary_str_or_pb, "summary"): + summary_pb = tf_summary_str_or_pb.summary + values[namespaced_tag("global_step", namespace)] = tf_summary_str_or_pb.step + values["_timestamp"] = tf_summary_str_or_pb.wall_time + elif isinstance(tf_summary_str_or_pb, (str, bytes, bytearray)): + summary_pb = Summary() + summary_pb.ParseFromString(tf_summary_str_or_pb) + elif hasattr(tf_summary_str_or_pb, "__iter__"): + summary_pb = [Summary() for _ in range(len(tf_summary_str_or_pb))] + for i, summary in enumerate(tf_summary_str_or_pb): + summary_pb[i].ParseFromString(summary) + if i > 0: + summary_pb[0].MergeFrom(summary_pb[i]) + summary_pb = summary_pb[0] + else: + summary_pb = tf_summary_str_or_pb + + if not hasattr(summary_pb, "value") or len(summary_pb.value) == 0: + # Ignore these, caller is responsible for handling None + return None + + def encode_images(_img_strs: List[bytes], _value: Any) -> None: + try: + from PIL import Image + except ImportError: + wandb.termwarn( + "Install pillow if you are logging images with Tensorboard. 
" + "To install, run `pip install pillow`.", + repeat=False, + ) + return None + + if len(_img_strs) == 0: + return None + + images: List[Union[wandb.Video, wandb.Image]] = [] + for _img_str in _img_strs: + # Supports gifs from TensorboardX + if _img_str.startswith(b"GIF"): + images.append(wandb.Video(io.BytesIO(_img_str), format="gif")) + else: + images.append(wandb.Image(Image.open(io.BytesIO(_img_str)))) + tag_idx = _value.tag.rsplit("/", 1) + if len(tag_idx) > 1 and tag_idx[1].isdigit(): + tag, idx = tag_idx + values.setdefault(history_image_key(tag, namespace), []).extend(images) + else: + values[history_image_key(_value.tag, namespace)] = images + + return None + + for value in summary_pb.value: + kind = value.WhichOneof("value") + if kind in IGNORE_KINDS: + continue + if kind == "simple_value": + values[namespaced_tag(value.tag, namespace)] = value.simple_value + elif kind == "tensor": + plugin_name = value.metadata.plugin_data.plugin_name + if plugin_name == "scalars" or plugin_name == "": + values[namespaced_tag(value.tag, namespace)] = make_ndarray( + value.tensor + ) + elif plugin_name == "images": + img_strs = value.tensor.string_val[2:] # First two items are dims. 
+ encode_images(img_strs, value) + elif plugin_name == "histograms": + # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/histogram/summary_v2.py#L15-L26 + ndarray = make_ndarray(value.tensor) + if ndarray is None: + continue + shape = ndarray.shape + counts = [] + bins = [] + if shape[0] > 1: + bins.append(ndarray[0][0]) # Add the left most edge + for v in ndarray: + counts.append(v[2]) + bins.append(v[1]) # Add the right most edges + elif shape[0] == 1: + counts = [ndarray[0][2]] + bins = ndarray[0][:2] + if len(counts) > 0: + try: + # TODO: we should just re-bin if there are too many buckets + values[namespaced_tag(value.tag, namespace)] = wandb.Histogram( + np_histogram=(counts, bins) # type: ignore + ) + except ValueError: + wandb.termwarn( + f'Not logging key "{namespaced_tag(value.tag, namespace)}". ' + f"Histograms must have fewer than {wandb.Histogram.MAX_LENGTH} bins", + repeat=False, + ) + elif plugin_name == "pr_curves": + pr_curve_data = make_ndarray(value.tensor) + if pr_curve_data is None: + continue + precision = pr_curve_data[-2, :].tolist() + recall = pr_curve_data[-1, :].tolist() + # TODO: (kdg) implement spec for showing additional info in tool tips + # true_pos = pr_curve_data[1,:] + # false_pos = pr_curve_data[2,:] + # true_neg = pr_curve_data[1,:] + # false_neg = pr_curve_data[1,:] + # threshold = [1.0 / n for n in range(len(true_pos), 0, -1)] + # min of each in case tensorboard ever changes their pr_curve + # to allow for different length outputs + data = [] + for i in range(min(len(precision), len(recall))): + # drop additional threshold values if they exist + if precision[i] != 0 or recall[i] != 0: + data.append((recall[i], precision[i])) + # sort data so custom chart looks the same as tb generated pr curve + # ascending recall, descending precision for the same recall values + data = sorted(data, key=lambda x: (x[0], -x[1])) + data_table = wandb.Table(data=data, columns=["recall", "precision"]) + name = 
namespaced_tag(value.tag, namespace) + + values[name] = wandb.plot_table( + "wandb/line/v0", + data_table, + {"x": "recall", "y": "precision"}, + {"title": f"{name} Precision v. Recall"}, + ) + elif kind == "image": + img_str = value.image.encoded_image_string + encode_images([img_str], value) + # Coming soon... + # elif kind == "audio": + # audio = wandb.Audio( + # six.BytesIO(value.audio.encoded_audio_string), + # sample_rate=value.audio.sample_rate, + # content_type=value.audio.content_type, + # ) + elif kind == "histo": + tag = namespaced_tag(value.tag, namespace) + if len(value.histo.bucket_limit) >= 3: + first = ( + value.histo.bucket_limit[0] + + value.histo.bucket_limit[0] + - value.histo.bucket_limit[1] + ) + last = ( + value.histo.bucket_limit[-2] + + value.histo.bucket_limit[-2] + - value.histo.bucket_limit[-3] + ) + np_histogram = ( + list(value.histo.bucket), + [first] + value.histo.bucket_limit[:-1] + [last], + ) + try: + # TODO: we should just re-bin if there are too many buckets + values[tag] = wandb.Histogram(np_histogram=np_histogram) # type: ignore + except ValueError: + wandb.termwarn( + f"Not logging key {tag!r}. " + f"Histograms must have fewer than {wandb.Histogram.MAX_LENGTH} bins", + repeat=False, + ) + else: + # TODO: is there a case where we can render this? + wandb.termwarn( + f"Not logging key {tag!r}. 
Found a histogram with only 2 bins.", + repeat=False, + ) + # TODO(jhr): figure out how to share this between userspace and internal process or dont + # elif value.tag == "_hparams_/session_start_info": + # if wandb.util.get_module("tensorboard.plugins.hparams"): + # from tensorboard.plugins.hparams import plugin_data_pb2 + # + # plugin_data = plugin_data_pb2.HParamsPluginData() # + # plugin_data.ParseFromString(value.metadata.plugin_data.content) + # for key, param in six.iteritems(plugin_data.session_start_info.hparams): + # if not wandb.run.config.get(key): + # wandb.run.config[key] = ( + # param.number_value or param.string_value or param.bool_value + # ) + # else: + # wandb.termerror( + # "Received hparams tf.summary, but could not import " + # "the hparams plugin from tensorboard" + # ) + return values + + +def reset_state() -> None: + """Internal method for resetting state, called by wandb.finish().""" + global STEPS + STEPS = {"": {"step": 0}, "global": {"step": 0, "last_log": None}} + + +def _log( + tf_summary_str_or_pb: Any, + history: Optional["TBHistory"] = None, + step: int = 0, + namespace: str = "", + **kwargs: Any, +) -> None: + """Logs a tfsummary to wandb. + + Can accept a tf summary string or parsed event. Will use wandb.run.history unless a + history object is passed. Can optionally namespace events. Results are committed + when step increases for this namespace. 
+ + NOTE: This assumes that events being passed in are in chronological order + """ + global STEPS + global RATE_LIMIT_SECONDS + # To handle multiple global_steps, we keep track of them here instead + # of the global log + last_step = STEPS.get(namespace, {"step": 0}) + + # Commit our existing data if this namespace increased its step + commit = False + if last_step["step"] < step: + commit = True + + log_dict = tf_summary_to_dict(tf_summary_str_or_pb, namespace) + if log_dict is None: + # not an event, just return + return + + # Pass timestamp to history for loading historic data + timestamp = log_dict.get("_timestamp", time.time()) + # Store our initial timestamp + if STEPS["global"]["last_log"] is None: + STEPS["global"]["last_log"] = timestamp + # Rollup events that share the same step across namespaces + if commit and step == STEPS["global"]["step"]: + commit = False + # Always add the biggest global_step key for non-default namespaces + if step > STEPS["global"]["step"]: + STEPS["global"]["step"] = step + if namespace != "": + log_dict["global_step"] = STEPS["global"]["step"] + + # Keep internal step counter + STEPS[namespace] = {"step": step} + + if commit: + # Only commit our data if we're below the rate limit or don't have one + if ( + RATE_LIMIT_SECONDS is None + or timestamp - STEPS["global"]["last_log"] >= RATE_LIMIT_SECONDS + ): + if history is None: + if wandb.run is not None: + wandb.run._log({}) + else: + history.add({}) + + STEPS["global"]["last_log"] = timestamp + + if history is None: + if wandb.run is not None: + wandb.run._log(log_dict, commit=False) + else: + history._row_update(log_dict) + + +def log(tf_summary_str_or_pb: Any, step: int = 0, namespace: str = "") -> None: + if wandb.run is None: + raise wandb.Error( + "You must call `wandb.init()` before calling `wandb.tensorflow.log`" + ) + + with telemetry.context() as tel: + tel.feature.tensorboard_log = True + + _log(tf_summary_str_or_pb, namespace=namespace, step=step) diff --git 
a/.venv/lib/python3.13/site-packages/wandb/integration/tensorboard/monkeypatch.py b/.venv/lib/python3.13/site-packages/wandb/integration/tensorboard/monkeypatch.py new file mode 100644 index 0000000000000000000000000000000000000000..a0db49aaf79bad018642eea9e1d555f65827ba6e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/tensorboard/monkeypatch.py @@ -0,0 +1,186 @@ +"""monkeypatch: patch code to add tensorboard hooks.""" + +import os +import re +import socket +from typing import Any, Optional + +import wandb +import wandb.util + +TENSORBOARD_C_MODULE = "tensorflow.python.ops.gen_summary_ops" +TENSORBOARD_X_MODULE = "tensorboardX.writer" +TENSORFLOW_PY_MODULE = "tensorflow.python.summary.writer.writer" +TENSORBOARD_WRITER_MODULE = "tensorboard.summary.writer.event_file_writer" +TENSORBOARD_PYTORCH_MODULE = "torch.utils.tensorboard.writer" + + +def unpatch() -> None: + for module, method in wandb.patched["tensorboard"]: + writer = wandb.util.get_module(module, lazy=False) + setattr(writer, method, getattr(writer, f"orig_{method}")) + wandb.patched["tensorboard"] = [] + + +def patch( + save: bool = True, + tensorboard_x: Optional[bool] = None, + pytorch: Optional[bool] = None, + root_logdir: str = "", +) -> None: + if len(wandb.patched["tensorboard"]) > 0: + raise ValueError( + "Tensorboard already patched. Call `wandb.tensorboard.unpatch()` first; " + "remove `sync_tensorboard=True` from `wandb.init`; " + "or only call `wandb.tensorboard.patch` once." + ) + + # TODO: Some older versions of tensorflow don't require tensorboard to be present. 
+ # we may want to lift this requirement, but it's safer to have it for now + wandb.util.get_module( + "tensorboard", required="Please install tensorboard package", lazy=False + ) + c_writer = wandb.util.get_module(TENSORBOARD_C_MODULE, lazy=False) + py_writer = wandb.util.get_module(TENSORFLOW_PY_MODULE, lazy=False) + tb_writer = wandb.util.get_module(TENSORBOARD_WRITER_MODULE, lazy=False) + pt_writer = wandb.util.get_module(TENSORBOARD_PYTORCH_MODULE, lazy=False) + tbx_writer = wandb.util.get_module(TENSORBOARD_X_MODULE, lazy=False) + + if not pytorch and not tensorboard_x and c_writer: + _patch_tensorflow2( + writer=c_writer, + module=TENSORBOARD_C_MODULE, + save=save, + root_logdir=root_logdir, + ) + # This is for tensorflow <= 1.15 (tf.compat.v1.summary.FileWriter) + if py_writer: + _patch_file_writer( + writer=py_writer, + module=TENSORFLOW_PY_MODULE, + save=save, + root_logdir=root_logdir, + ) + if tb_writer: + _patch_file_writer( + writer=tb_writer, + module=TENSORBOARD_WRITER_MODULE, + save=save, + root_logdir=root_logdir, + ) + if pt_writer: + _patch_file_writer( + writer=pt_writer, + module=TENSORBOARD_PYTORCH_MODULE, + save=save, + root_logdir=root_logdir, + ) + if tbx_writer: + _patch_file_writer( + writer=tbx_writer, + module=TENSORBOARD_X_MODULE, + save=save, + root_logdir=root_logdir, + ) + if not c_writer and not tb_writer and not tb_writer: + wandb.termerror("Unsupported tensorboard configuration") + + +def _patch_tensorflow2( + writer: Any, + module: Any, + save: bool = True, + root_logdir: str = "", +) -> None: + # This configures TensorFlow 2 style Tensorboard logging + old_csfw_func = writer.create_summary_file_writer + logdir_hist = [] + + def new_csfw_func(*args: Any, **kwargs: Any) -> Any: + logdir = ( + kwargs["logdir"].numpy().decode("utf8") + if hasattr(kwargs["logdir"], "numpy") + else kwargs["logdir"] + ) + logdir_hist.append(logdir) + root_logdir_arg = root_logdir + + if len(set(logdir_hist)) > 1 and root_logdir == "": + 
wandb.termwarn( + "When using several event log directories, " + 'please call `wandb.tensorboard.patch(root_logdir="...")` before `wandb.init`' + ) + # if the logdir contains the hostname, the writer was not given a logdir. + # In this case, the generated logdir + # is generated and ends with the hostname, update the root_logdir to match. + hostname = socket.gethostname() + search = re.search(rf"-\d+_{hostname}", logdir) + if search: + root_logdir_arg = logdir[: search.span()[1]] + elif root_logdir is not None and not os.path.abspath(logdir).startswith( + os.path.abspath(root_logdir) + ): + wandb.termwarn( + "Found log directory outside of given root_logdir, " + f"dropping given root_logdir for event file in {logdir}" + ) + root_logdir_arg = "" + + _notify_tensorboard_logdir(logdir, save=save, root_logdir=root_logdir_arg) + return old_csfw_func(*args, **kwargs) + + writer.orig_create_summary_file_writer = old_csfw_func + writer.create_summary_file_writer = new_csfw_func + wandb.patched["tensorboard"].append([module, "create_summary_file_writer"]) + + +def _patch_file_writer( + writer: Any, + module: Any, + save: bool = True, + root_logdir: str = "", +) -> None: + # This configures non-TensorFlow Tensorboard logging, or tensorflow <= 1.15 + logdir_hist = [] + + class TBXEventFileWriter(writer.EventFileWriter): + def __init__(self, logdir: str, *args: Any, **kwargs: Any) -> None: + logdir_hist.append(logdir) + root_logdir_arg = root_logdir + if len(set(logdir_hist)) > 1 and root_logdir == "": + wandb.termwarn( + "When using several event log directories, " + 'please call `wandb.tensorboard.patch(root_logdir="...")` before `wandb.init`' + ) + + # if the logdir contains the hostname, the writer was not given a logdir. + # In this case, the logdir is generated and ends with the hostname, + # update the root_logdir to match. 
+ hostname = socket.gethostname() + search = re.search(rf"-\d+_{hostname}", logdir) + if search: + root_logdir_arg = logdir[: search.span()[1]] + + elif root_logdir is not None and not os.path.abspath(logdir).startswith( + os.path.abspath(root_logdir) + ): + wandb.termwarn( + "Found log directory outside of given root_logdir, " + f"dropping given root_logdir for event file in {logdir}" + ) + root_logdir_arg = "" + + _notify_tensorboard_logdir(logdir, save=save, root_logdir=root_logdir_arg) + + super().__init__(logdir, *args, **kwargs) + + writer.orig_EventFileWriter = writer.EventFileWriter + writer.EventFileWriter = TBXEventFileWriter + wandb.patched["tensorboard"].append([module, "EventFileWriter"]) + + +def _notify_tensorboard_logdir( + logdir: str, save: bool = True, root_logdir: str = "" +) -> None: + if wandb.run is not None: + wandb.run._tensorboard_callback(logdir, save=save, root_logdir=root_logdir) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/tensorflow/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/tensorflow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b5a5838c7d3eb5e5ea9e50b420bf42a605084dc6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/tensorflow/__init__.py @@ -0,0 +1,5 @@ +"""api.""" + +from wandb.integration.tensorboard import log + +from .estimator_hook import WandbHook diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/tensorflow/estimator_hook.py b/.venv/lib/python3.13/site-packages/wandb/integration/tensorflow/estimator_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..88d58abe3c6ee81759ce01a8e7254e1e324034de --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/tensorflow/estimator_hook.py @@ -0,0 +1,54 @@ +import tensorflow as tf + +import wandb +from wandb.sdk.lib import telemetry + +if hasattr(tf.estimator, "SessionRunHook"): + # In tf 1.14 and beyond, SessionRunHook is in the 
estimator package. + SessionRunHook = tf.estimator.SessionRunHook + SessionRunArgs = tf.estimator.SessionRunArgs +else: + # In older versions it's in train. + SessionRunHook = tf.train.SessionRunHook + SessionRunArgs = tf.train.SessionRunArgs + +if hasattr(tf.train, "get_global_step"): + get_global_step = tf.train.get_global_step +else: + get_global_step = tf.compat.v1.train.get_global_step + +if hasattr(tf.summary, "merge_all"): + merge_all_summaries = tf.summary.merge_all +else: + merge_all_summaries = tf.compat.v1.summary.merge_all + + +class WandbHook(SessionRunHook): + def __init__(self, summary_op=None, steps_per_log=1000, history=None): + self._summary_op = summary_op + self._steps_per_log = steps_per_log + self._history = history + + with telemetry.context() as tel: + tel.feature.estimator_hook = True + + def begin(self): + if wandb.run is None: + raise wandb.Error("You must call `wandb.init()` before calling `WandbHook`") + if self._summary_op is None: + self._summary_op = merge_all_summaries() + self._step = -1 + + def before_run(self, run_context): + return SessionRunArgs( + {"summary": self._summary_op, "global_step": get_global_step()} + ) + + def after_run(self, run_context, run_values): + step = run_values.results["global_step"] + if step % self._steps_per_log == 0: + wandb.tensorboard._log( + run_values.results["summary"], + history=self._history, + step=step, + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/torch/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/torch/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/integration/torch/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f942b8d957cae1fe27506fbd1b2c05e5eb3aa1c Binary files 
/dev/null and b/.venv/lib/python3.13/site-packages/wandb/integration/torch/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/torch/__pycache__/wandb_torch.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/integration/torch/__pycache__/wandb_torch.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e507c9cb57807f640d78e5d021ee9e23925f48a2 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/integration/torch/__pycache__/wandb_torch.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/torch/wandb_torch.py b/.venv/lib/python3.13/site-packages/wandb/integration/torch/wandb_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..28d5971f8fd640ef6315aa22cb795163bc9d0724 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/torch/wandb_torch.py @@ -0,0 +1,554 @@ +"""PyTorch-specific functionality.""" + +import itertools +from functools import reduce +from operator import mul +from typing import TYPE_CHECKING, List + +import wandb +from wandb import util +from wandb.data_types import Node + +torch = None + +if TYPE_CHECKING: + from torch import Tensor + from torch.nn import Module + + +def nested_shape(array_or_tuple, seen=None): + """Figure out the shape of tensors possibly embedded in tuples. + + for example: + - [0,0] returns (2) + - ([0,0], [0,0]) returns (2,2) + - (([0,0], [0,0]),[0,0]) returns ((2,2),2). 
+ """ + if seen is None: + seen = set() + if hasattr(array_or_tuple, "size"): + # pytorch tensors use V.size() to get size of tensor + return list(array_or_tuple.size()) + elif hasattr(array_or_tuple, "get_shape"): + # tensorflow uses V.get_shape() to get size of tensor + return array_or_tuple.get_shape().as_list() + elif hasattr(array_or_tuple, "shape"): + return array_or_tuple.shape + + seen.add(id(array_or_tuple)) + try: + # treat object as iterable + return [ + nested_shape(item, seen) if id(item) not in seen else 0 + for item in list(array_or_tuple) + ] + except TypeError: + # object is not actually iterable + # LB: Maybe we should throw an error? + return [] + + +LOG_TRACK_COUNT, LOG_TRACK_THRESHOLD = range(2) + + +def log_track_init(log_freq: int) -> List[int]: + """Create tracking structure used by log_track_update.""" + log_track = [0, 0] + log_track[LOG_TRACK_THRESHOLD] = log_freq + return log_track + + +def log_track_update(log_track: int) -> bool: + """Count (log_track[0]) up to threshold (log_track[1]), reset count (log_track[0]) and return true when reached.""" + log_track[LOG_TRACK_COUNT] += 1 + if log_track[LOG_TRACK_COUNT] < log_track[LOG_TRACK_THRESHOLD]: + return False + log_track[LOG_TRACK_COUNT] = 0 + return True + + +class TorchHistory: + """History methods specific to PyTorch.""" + + def __init__(self): + global torch + torch = wandb.util.get_module("torch", "Could not import torch") + self._hook_handles = {} + self._num_bins = 64 + self._is_cuda_histc_supported = None + self.hook_torch = TorchGraph.hook_torch + + def add_log_parameters_hook( + self, + module: "Module", + name: str = "", + prefix: str = "", + log_freq: int = 0, + ) -> None: + """This instruments hooks into the pytorch module. + + log parameters after a forward pass + log_freq - log gradients/parameters every N batches. 
+ """ + # if name is not None: + prefix = prefix + name + + if not hasattr(module, "_wandb_hook_names"): + module._wandb_hook_names = [] + + def parameter_log_hook(module, input_, output, log_track): + if not log_track_update(log_track): + return + for name, parameter in module.named_parameters(): + # for pytorch 0.3 Variables + if isinstance(parameter, torch.autograd.Variable): + data = parameter.data + else: + data = parameter + self.log_tensor_stats(data.cpu(), "parameters/" + prefix + name) + + log_track_params = log_track_init(log_freq) + try: + hook = module.register_forward_hook( + lambda mod, inp, outp: parameter_log_hook( + mod, inp, outp, log_track_params + ) + ) + self._hook_handles["parameters/" + prefix] = hook + module._wandb_hook_names.append("parameters/" + prefix) + except RuntimeError as e: + wandb.termwarn( + f"Trying to register forward_hook failed ({e}) - skipping parameter tracking." + ) + + def add_log_gradients_hook( + self, + module: "Module", + name: str = "", + prefix: str = "", + log_freq: int = 0, + ) -> None: + """This instruments hooks into the PyTorch module slog gradients after a backward pass. + + Args: + module: torch.nn.Module - the module to instrument + name: str - the name of the module + prefix: str - the prefix to add to the name + log_freq: log gradients/parameters every N batches + """ + # if name is not None: + prefix = prefix + name + + if not hasattr(module, "_wandb_hook_names"): + module._wandb_hook_names = [] + + for name, parameter in module.named_parameters(): + if parameter.requires_grad: + log_track_grad = log_track_init(log_freq) + module._wandb_hook_names.append("gradients/" + prefix + name) + self._hook_variable_gradient_stats( + parameter, "gradients/" + prefix + name, log_track_grad + ) + + def log_tensor_stats(self, tensor, name): # noqa: C901 + """Add distribution statistics on a tensor's elements to the current History entry.""" + # TODO Handle the case of duplicate names. 
+ if isinstance(tensor, (tuple, list)): + while isinstance(tensor, (tuple, list)) and isinstance( + tensor[0], (tuple, list) + ): + tensor = [item for sublist in tensor for item in sublist] + tensor = torch.cat([t.detach().clone().reshape(-1) for t in tensor]) + + tensor = tensor.detach().clone() + # checking for inheritance from _TensorBase didn't work for some reason + if not hasattr(tensor, "shape"): + cls = type(tensor) + raise TypeError(f"Expected Tensor, not {cls.__module__}.{cls.__name__}") + + # Sparse tensors have a bunch of implicit zeros. In order to histo them correctly, + # we have to count them up and add them to the histo ourselves. + sparse_zeros = None + if tensor.is_sparse: + # Have to call this on a sparse tensor before most other ops. + tensor = tensor.cpu().coalesce() + + backing_values = tensor._values() + sparse_zeros = tensor.numel() - backing_values.numel() + tensor = backing_values + + flat = tensor.reshape(-1) + + if flat.is_cuda: + if self._is_cuda_histc_supported is None: + try: + flat.histc(bins=self._num_bins) + except RuntimeError: + self._is_cuda_histc_supported = False + else: + self._is_cuda_histc_supported = True + + # As of torch 1.0.1.post2+nightly, float16 cuda summary ops are not supported (convert to float32) + if not self._is_cuda_histc_supported: + flat = flat.cpu() + elif not isinstance( + flat, (torch.cuda.FloatTensor, torch.cuda.DoubleTensor) + ): + flat = flat.type(torch.cuda.FloatTensor) + + # Since we use histc, we need to make sure that torch supports the operation on CPU, + # otherwise we'll get a runtime error. Hence, we need to upcast to float32. + if not flat.is_cuda and not isinstance( + flat, (torch.FloatTensor, torch.DoubleTensor) + ): + flat = flat.type(torch.FloatTensor) + + # Skip logging if all values are nan or inf or the tensor is empty. + if self._no_finite_values(flat): + return + + # Remove nans and infs if present. There's no good way to represent that in histograms. 
+ flat = self._remove_infs_nans(flat) + + tmin = flat.min().item() + tmax = flat.max().item() + if sparse_zeros: + # If we've got zeros to add in, make sure zero is in the hist range. + tmin = 0 if tmin > 0 else tmin + tmax = 0 if tmax < 0 else tmax + # Anecdotally, this can somehow happen sometimes. Maybe a precision error + # in min()/max() above. Swap here to prevent a runtime error. + # If all values are equal, just return a single bin. + if tmin > tmax: + tmin, tmax = tmax, tmin + if tmin == tmax: + tensor = torch.Tensor([flat.numel()]) + tensor = tensor.cpu().clone().detach() + bins = torch.Tensor([tmin, tmax]) + else: + tensor = flat.histc(bins=self._num_bins, min=tmin, max=tmax) + tensor = tensor.cpu().detach().clone() + bins = torch.linspace(tmin, tmax, steps=self._num_bins + 1) + + # Add back zeroes from a sparse tensor. + if sparse_zeros: + bins_np = bins.numpy() + tensor_np = tensor.numpy() + bin_idx = 0 + num_buckets = len(bins_np) - 1 + for i in range(num_buckets): + start = bins_np[i] + end = bins_np[i + 1] + # There are 3 cases to consider here, all of which mean we've found the right bucket + # 1. The bucket range contains zero. + # 2. The bucket range lower bound *is* zero. + # 3. This is the last bucket and the bucket range upper bound is zero. 
+ if (start <= 0 and end > 0) or (i == num_buckets - 1 and end == 0): + bin_idx = i + break + + tensor_np[bin_idx] += sparse_zeros + tensor = torch.Tensor(tensor_np) + bins = torch.Tensor(bins_np) + + wandb.run._log( + {name: wandb.Histogram(np_histogram=(tensor.tolist(), bins.tolist()))}, + commit=False, + ) + + def _hook_variable_gradient_stats(self, var, name, log_track): + """Logs a Variable's gradient's distribution statistics next time backward() is called on it.""" + if not isinstance(var, torch.autograd.Variable): + cls = type(var) + raise TypeError( + f"Expected torch.Variable, not {cls.__module__}.{cls.__name__}" + ) + + handle = self._hook_handles.get(name) + if handle is not None and self._torch_hook_handle_is_valid(handle): + raise ValueError(f'A hook has already been set under name "{name}"') + + def _callback(grad, log_track): + if not log_track_update(log_track): + return + self.log_tensor_stats(grad.data, name) + + handle = var.register_hook(lambda grad: _callback(grad, log_track)) + self._hook_handles[name] = handle + return handle + + def unhook_all(self): + for handle in self._hook_handles.values(): + handle.remove() + self._hook_handles = {} + + def unhook(self, name): + handle = self._hook_handles.pop(name) + handle.remove() + + def _torch_hook_handle_is_valid(self, handle): + d = handle.hooks_dict_ref() + if d is None: + return False + else: + return handle.id in d + + def _no_finite_values(self, tensor: "Tensor") -> bool: + return tensor.shape == torch.Size([0]) or (~torch.isfinite(tensor)).all().item() + + def _remove_infs_nans(self, tensor: "Tensor") -> "Tensor": + if not torch.isfinite(tensor).all(): + tensor = tensor[torch.isfinite(tensor)] + + return tensor + + +class TorchGraph(wandb.data_types.Graph): + def __init__(self): + super().__init__("torch") + self._graph_hooks = set() + + @classmethod + def hook_torch(cls, model, criterion=None, graph_idx=0): + wandb.termlog("logging graph, to disable use `wandb.watch(log_graph=False)`") + 
graph = TorchGraph() + graph.hook_torch_modules(model, criterion, graph_idx=graph_idx) + return graph + + def create_forward_hook(self, name, graph_idx): + graph = self + + def after_forward_hook(module, input, output): + if id(module) not in self._graph_hooks: + # hook already processed -> noop + return + if not isinstance(output, tuple): + output = (output,) + parameters = [ + (pname, list(param.size())) + for pname, param in module.named_parameters() + ] + + node = Node( + id=id(module), + name=name, + class_name=str(module), + output_shape=nested_shape(output), + parameters=parameters, + num_parameters=[reduce(mul, size, 1) for (pname, size) in parameters], + ) + graph.nodes_by_id[id(module)] = node + for param in module.parameters(): + graph.nodes_by_id[id(param)] = node + graph.add_node(node) + if not graph.criterion_passed: + if hasattr(output[0], "grad_fn"): + graph.criterion = output[0].grad_fn + elif ( + isinstance(output[0], list) + and output[0] + and hasattr(output[0][0], "grad_fn") + ): + graph.criterion = output[0][0].grad_fn + + # hook has been processed + self._graph_hooks -= {id(module)} + + if not self._graph_hooks: + # we went through the entire graph + wandb.run.summary[f"graph_{graph_idx}"] = self + + return after_forward_hook + + def hook_torch_modules( + self, module, criterion=None, prefix=None, graph_idx=0, parent=None + ): + torch = util.get_module("torch", "Could not import torch") + layers = 0 + graph = self + if hasattr(module, "_wandb_watch_called") and module._wandb_watch_called: + raise ValueError( + "You can only call `wandb.watch` once per model. Pass a new instance of the model if you need to call wandb.watch again in your code." + ) + module._wandb_watch_called = True + if criterion: + graph.criterion = criterion + graph.criterion_passed = True + + for name, sub_module in module.named_children(): + name = name or str(layers) + if prefix: + name = prefix + "." 
+ name + layers += 1 + if not isinstance(sub_module, torch.nn.Module): + # TODO: Why does this happen? + break + + # Trying to support torch >0.3 making this code complicated + # We want a list of types that we should recurse into + # Torch 0.3 uses containers + # 0.4 has ModuleList + # 0.4.1 has ModuleDict + module_types = [ + getattr(torch.nn, module_classname) + for module_classname in ( + "Container", + "Sequential", + "ModuleList", + "ModuleDict", + ) + if hasattr(torch.nn, module_classname) + ] + if parent is None: + parent = module + + if isinstance(sub_module, tuple(module_types)): + self.hook_torch_modules(sub_module, prefix=name, parent=parent) + else: + self._graph_hooks |= {id(sub_module)} + try: + graph_hook = sub_module.register_forward_hook( + self.create_forward_hook(name, graph_idx) + ) + wandb.run._torch._hook_handles[ + "topology/" + str(id(graph_hook)) + ] = graph_hook + if not hasattr(parent, "_wandb_hook_names"): + # should never happen but let's be extra safe + parent._wandb_hook_names = [] + parent._wandb_hook_names.append("topology/" + str(id(graph_hook))) + except RuntimeError as e: + wandb.termwarn( + f"Trying to register forward_hook failed ({e}) - skipping graph tracking.", + repeat=False, + ) + + @classmethod + def from_torch_layers(cls, module_graph, variable): + """Recover something like neural net layers from PyTorch Module's and the compute graph from a Variable. + + Example output for a multi-layer RNN. We confusingly assign shared embedding values + to the encoder, but ordered next to the decoder. 
+ + rnns.0.linear.module.weight_raw rnns.0 + rnns.0.linear.module.bias rnns.0 + rnns.1.linear.module.weight_raw rnns.1 + rnns.1.linear.module.bias rnns.1 + rnns.2.linear.module.weight_raw rnns.2 + rnns.2.linear.module.bias rnns.2 + rnns.3.linear.module.weight_raw rnns.3 + rnns.3.linear.module.bias rnns.3 + decoder.weight encoder + decoder.bias decoder + """ + # TODO: We're currently not using this, but I left it here in case we want to resurrect! - CVP + torch = util.get_module("torch", "Could not import torch") + + module_nodes_by_hash = {id(n): n for n in module_graph.nodes} + module_parameter_nodes = [ + n for n in module_graph.nodes if isinstance(n.obj, torch.nn.Parameter) + ] + + names_by_pid = {id(n.obj): n.name for n in module_parameter_nodes} + + reachable_param_nodes = module_graph[0].reachable_descendents() + reachable_params = {} + module_reachable_params = {} + names = {} + for pid, reachable_nodes in reachable_param_nodes.items(): + node = module_nodes_by_hash[pid] + if not isinstance(node.obj, torch.nn.Module): + continue + module = node.obj + reachable_params = {} # by object id + module_reachable_params[id(module)] = reachable_params + names[node.name] = set() + for reachable_hash in reachable_nodes: + reachable = module_nodes_by_hash[reachable_hash] + if isinstance(reachable.obj, torch.nn.Parameter): + param = reachable.obj + reachable_params[id(param)] = param + names[node.name].add(names_by_pid[id(param)]) + + # we look for correspondences between sets of parameters used in subtrees of the + # computation graph and sets of parameters contained in subtrees of the module + # graph + node_depths = {id(n): d for n, d in module_graph[0].descendent_bfs()} + parameter_module_names = {} + parameter_modules = {} + for param_node in ( + n for n in module_graph.nodes if isinstance(n.obj, torch.nn.Parameter) + ): + pid = id(param_node.obj) + best_node = None + best_depth = None + best_reachable_params = None + for node in module_graph.nodes: + if not 
isinstance(node.obj, torch.nn.Module): + continue + module = node.obj + reachable_params = module_reachable_params[id(module)] + if pid in reachable_params: + depth = node_depths[id(node)] + if best_node is None or (len(reachable_params), depth) <= ( + len(best_reachable_params), + best_depth, + ): + best_node = node + best_depth = depth + best_reachable_params = reachable_params + + parameter_modules[pid] = best_node + parameter_module_names[param_node.name] = best_node.name + + # contains all parameters but only a minimal set of modules necessary + # to contain them (and which ideally correspond to conceptual layers) + reduced_module_graph = cls() + rmg_ids = itertools.count() + rmg_root = Node(id=next(rmg_ids), node=module_graph[0]) + reduced_module_graph.add_node(rmg_root) + reduced_module_graph.root = rmg_root + rmg_nodes_by_pid = {} + + module_nodes_by_pid = {id(n.obj): n for n in module_graph.nodes} + + compute_graph, compute_node_vars = cls.from_torch_compute_graph(variable) + for node, _ in reversed(list(compute_graph[0].ancestor_bfs())): + param = compute_node_vars.get(node.id) + pid = id(param) + if not isinstance(param, torch.nn.Parameter): + continue + if pid not in module_nodes_by_pid: + # not all Parameters that occur in the compute graph come from the Module graph + continue + + # add the nodes in the order we want to display them on the frontend + mid = id(parameter_modules[pid].obj) + if mid in rmg_nodes_by_pid: + rmg_module = rmg_nodes_by_pid[mid] + else: + rmg_module = rmg_nodes_by_pid[mid] = Node( + id=next(rmg_ids), node=module_nodes_by_pid[mid] + ) + reduced_module_graph.add_node(rmg_module) + reduced_module_graph.add_edge(rmg_root, rmg_module) + + rmg_param = Node(id=next(rmg_ids), node=module_nodes_by_pid[pid]) + rmg_nodes_by_pid[pid] = rmg_param + reduced_module_graph.add_node(rmg_param) + + reduced_module_graph.add_edge(rmg_module, rmg_param) + return reduced_module_graph + + @classmethod + def node_from_module(cls, nid, module): + numpy 
= util.get_module("numpy", "Could not import numpy") + + node = wandb.Node() + node.id = nid + node.child_parameters = 0 + for parameter in module.parameters(): + node.child_parameters += numpy.prod(parameter.size()) + node.class_name = type(module).__name__ + + return node diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c25ae320c8c946a25150b870770710f58a8bf713 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/__init__.py @@ -0,0 +1,11 @@ +"""Tools for integrating with [`ultralytics`](https://docs.ultralytics.com/). + +Ultralytics is a computer vision framework for training and deploying YOLOv8 models. +""" + +from wandb.integration.ultralytics.callback import ( + WandBUltralyticsCallback, + add_wandb_callback, +) + +__all__ = ("WandBUltralyticsCallback", "add_wandb_callback") diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/bbox_utils.py b/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/bbox_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4f87bd167bfb45cb4f9a996df301985fa46a0dc9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/bbox_utils.py @@ -0,0 +1,215 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from tqdm.auto import tqdm +from ultralytics.engine.results import Results +from ultralytics.models.yolo.detect import DetectionPredictor +from ultralytics.utils import ops + +import wandb + + +def scale_bounding_box_to_original_image_shape( + box: torch.Tensor, + resized_image_shape: Tuple, + original_image_shape: Tuple, + ratio_pad: bool, +) -> List[int]: + """YOLOv8 resizes images during training and the label values are normalized based on this resized shape. 
+ + This function rescales the bounding box labels to the original + image shape. + + Reference: https://github.com/ultralytics/ultralytics/blob/main/ultralytics/yolo/utils/callbacks/comet.py#L105 + """ + resized_image_height, resized_image_width = resized_image_shape + # Convert normalized xywh format predictions to xyxy in resized scale format + box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width) + # Scale box predictions from resized image scale back to original image scale + box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad) + # # Convert bounding box format from xyxy to xywh for Comet logging + box = ops.xyxy2xywh(box) + return box.tolist() + + +def get_ground_truth_bbox_annotations( + img_idx: int, image_path: str, batch: Dict, class_name_map: Dict = None +) -> List[Dict[str, Any]]: + """Get ground truth bounding box annotation data in the form required for `wandb.Image` overlay system.""" + indices = batch["batch_idx"] == img_idx + bboxes = batch["bboxes"][indices] + if len(batch["cls"][indices]): + cls_labels = batch["cls"][indices].squeeze(1).tolist() + else: + cls_labels = [] + + class_name_map_reverse = {v: k for k, v in class_name_map.items()} + + if len(bboxes) == 0: + wandb.termwarn( + f"Image: {image_path} has no bounding boxes labels", repeat=False + ) + return None + + if len(batch["cls"][indices]): + cls_labels = batch["cls"][indices].squeeze(1).tolist() + else: + cls_labels = [] + + if class_name_map: + cls_labels = [str(class_name_map[label]) for label in cls_labels] + + original_image_shape = batch["ori_shape"][img_idx] + resized_image_shape = batch["resized_shape"][img_idx] + ratio_pad = batch["ratio_pad"][img_idx] + + data = [] + for box, label in zip(bboxes, cls_labels): + box = scale_bounding_box_to_original_image_shape( + box, resized_image_shape, original_image_shape, ratio_pad + ) + data.append( + { + "position": { + "middle": [int(box[0]), int(box[1])], + "width": int(box[2]), + 
"height": int(box[3]), + }, + "domain": "pixel", + "class_id": class_name_map_reverse[label], + "box_caption": label, + } + ) + + return data + + +def get_mean_confidence_map( + classes: List, confidence: List, class_id_to_label: Dict +) -> Dict[str, float]: + """Get Mean-confidence map from the predictions to be logged into a `wandb.Table`.""" + confidence_map = {v: [] for _, v in class_id_to_label.items()} + for class_idx, confidence_value in zip(classes, confidence): + confidence_map[class_id_to_label[class_idx]].append(confidence_value) + updated_confidence_map = {} + for label, confidence_list in confidence_map.items(): + if len(confidence_list) > 0: + updated_confidence_map[label] = sum(confidence_list) / len(confidence_list) + else: + updated_confidence_map[label] = 0 + return updated_confidence_map + + +def get_boxes(result: Results) -> Tuple[Dict, Dict]: + """Convert an ultralytics prediction result into metadata for the `wandb.Image` overlay system.""" + boxes = result.boxes.xywh.long().numpy() + classes = result.boxes.cls.long().numpy() + confidence = result.boxes.conf.numpy() + class_id_to_label = {int(k): str(v) for k, v in result.names.items()} + mean_confidence_map = get_mean_confidence_map( + classes, confidence, class_id_to_label + ) + box_data = [] + for idx in range(len(boxes)): + box_data.append( + { + "position": { + "middle": [int(boxes[idx][0]), int(boxes[idx][1])], + "width": int(boxes[idx][2]), + "height": int(boxes[idx][3]), + }, + "domain": "pixel", + "class_id": int(classes[idx]), + "box_caption": class_id_to_label[int(classes[idx])], + "scores": {"confidence": float(confidence[idx])}, + } + ) + boxes = { + "predictions": { + "box_data": box_data, + "class_labels": class_id_to_label, + }, + } + return boxes, mean_confidence_map + + +def plot_bbox_predictions( + result: Results, model_name: str, table: Optional[wandb.Table] = None +) -> Union[wandb.Table, Tuple[wandb.Image, Dict, Dict]]: + """Plot the images with the W&B overlay system. 
+ + The `wandb.Image` is either added to a `wandb.Table` or returned. + """ + result = result.to("cpu") + boxes, mean_confidence_map = get_boxes(result) + image = wandb.Image(result.orig_img[:, :, ::-1], boxes=boxes) + if table is not None: + table.add_data( + model_name, + image, + len(boxes["predictions"]["box_data"]), + mean_confidence_map, + result.speed, + ) + return table + return image, boxes["predictions"], mean_confidence_map + + +def plot_detection_validation_results( + dataloader: Any, + class_label_map: Dict, + model_name: str, + predictor: DetectionPredictor, + table: wandb.Table, + max_validation_batches: int, + epoch: Optional[int] = None, +) -> wandb.Table: + """Plot validation results in a table.""" + data_idx = 0 + num_dataloader_batches = len(dataloader.dataset) // dataloader.batch_size + max_validation_batches = min(max_validation_batches, num_dataloader_batches) + for batch_idx, batch in enumerate(dataloader): + prediction_results = predictor(batch["im_file"]) + progress_bar_result_iterable = tqdm( + enumerate(prediction_results), + total=len(prediction_results), + desc=f"Generating Visualizations for batch-{batch_idx + 1}/{max_validation_batches}", + ) + for img_idx, prediction_result in progress_bar_result_iterable: + prediction_result = prediction_result.to("cpu") + _, prediction_box_data, mean_confidence_map = plot_bbox_predictions( + prediction_result, model_name + ) + try: + ground_truth_data = get_ground_truth_bbox_annotations( + img_idx, batch["im_file"][img_idx], batch, class_label_map + ) + wandb_image = wandb.Image( + batch["im_file"][img_idx], + boxes={ + "ground-truth": { + "box_data": ground_truth_data, + "class_labels": class_label_map, + }, + "predictions": { + "box_data": prediction_box_data["box_data"], + "class_labels": class_label_map, + }, + }, + ) + table_rows = [ + data_idx, + batch_idx, + wandb_image, + mean_confidence_map, + prediction_result.speed, + ] + table_rows = [epoch] + table_rows if epoch is not None else 
table_rows + table_rows = [model_name] + table_rows + table.add_data(*table_rows) + data_idx += 1 + except TypeError: + pass + if batch_idx + 1 == max_validation_batches: + break + return table diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/callback.py b/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/callback.py new file mode 100644 index 0000000000000000000000000000000000000000..c47605a9579dafece3ef364bd3916f4cfd3c96f9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/callback.py @@ -0,0 +1,528 @@ +import copy +from datetime import datetime +from typing import Callable, Dict, Optional, Union + +from packaging import version + +try: + import dill as pickle +except ImportError: + import pickle + +import wandb +from wandb.sdk.lib import telemetry + +try: + import torch + import ultralytics + from tqdm.auto import tqdm + + if version.parse(ultralytics.__version__) > version.parse("8.0.238"): + wandb.termwarn( + """This integration is tested and supported for ultralytics v8.0.238 and below. 
+ Please report any issues to https://github.com/wandb/wandb/issues with the tag `yolov8`.""", + repeat=False, + ) + + from ultralytics.models import YOLO + from ultralytics.models.sam.predict import Predictor as SAMPredictor + from ultralytics.models.yolo.classify import ( + ClassificationPredictor, + ClassificationTrainer, + ClassificationValidator, + ) + from ultralytics.models.yolo.detect import ( + DetectionPredictor, + DetectionTrainer, + DetectionValidator, + ) + from ultralytics.models.yolo.pose import PosePredictor, PoseTrainer, PoseValidator + from ultralytics.models.yolo.segment import ( + SegmentationPredictor, + SegmentationTrainer, + SegmentationValidator, + ) + from ultralytics.utils.torch_utils import de_parallel + + try: + from ultralytics.yolo.utils import RANK, __version__ + except ModuleNotFoundError: + from ultralytics.utils import RANK, __version__ + + from wandb.integration.ultralytics.bbox_utils import ( + plot_bbox_predictions, + plot_detection_validation_results, + ) + from wandb.integration.ultralytics.classification_utils import ( + plot_classification_predictions, + plot_classification_validation_results, + ) + from wandb.integration.ultralytics.mask_utils import ( + plot_mask_predictions, + plot_sam_predictions, + plot_segmentation_validation_results, + ) + from wandb.integration.ultralytics.pose_utils import ( + plot_pose_predictions, + plot_pose_validation_results, + ) +except Exception as e: + wandb.Error(e) + + +TRAINER_TYPE = Union[ + ClassificationTrainer, DetectionTrainer, SegmentationTrainer, PoseTrainer +] +VALIDATOR_TYPE = Union[ + ClassificationValidator, DetectionValidator, SegmentationValidator, PoseValidator +] +PREDICTOR_TYPE = Union[ + ClassificationPredictor, + DetectionPredictor, + SegmentationPredictor, + PosePredictor, + SAMPredictor, +] + + +class WandBUltralyticsCallback: + """Stateful callback for logging to W&B. 
+ + In particular, it will log model checkpoints, predictions, and + ground-truth annotations with interactive overlays for bounding boxes + to Weights & Biases Tables during training, validation and prediction + for a `ultratytics` workflow. + + Example: + ```python + from ultralytics.yolo.engine.model import YOLO + from wandb.yolov8 import add_wandb_callback + + # initialize YOLO model + model = YOLO("yolov8n.pt") + + # add wandb callback + add_wandb_callback( + model, max_validation_batches=2, enable_model_checkpointing=True + ) + + # train + model.train(data="coco128.yaml", epochs=5, imgsz=640) + + # validate + model.val() + + # perform inference + model(["img1.jpeg", "img2.jpeg"]) + ``` + + Args: + model: (ultralytics.yolo.engine.model.YOLO) YOLO Model of type + `ultralytics.yolo.engine.model.YOLO`. + epoch_logging_interval: (int) interval to log the prediction visualizations + during training. + max_validation_batches: (int) maximum number of validation batches to log to + a table per epoch. + enable_model_checkpointing: (bool) enable logging model checkpoints as + artifacts at the end of eveny epoch if set to `True`. + visualize_skeleton: (bool) visualize pose skeleton by drawing lines connecting + keypoints for human pose. 
+ """ + + def __init__( + self, + model: YOLO, + epoch_logging_interval: int = 1, + max_validation_batches: int = 1, + enable_model_checkpointing: bool = False, + visualize_skeleton: bool = False, + ) -> None: + self.epoch_logging_interval = epoch_logging_interval + self.max_validation_batches = max_validation_batches + self.enable_model_checkpointing = enable_model_checkpointing + self.visualize_skeleton = visualize_skeleton + self.task = model.task + self.task_map = model.task_map + self.model_name = ( + model.overrides["model"].split(".")[0] + if "model" in model.overrides + else None + ) + self._make_tables() + self._make_predictor(model) + self.supported_tasks = ["detect", "segment", "pose", "classify"] + self.prompts = None + self.run_id = None + self.train_epoch = None + + def _make_tables(self): + if self.task in ["detect", "segment"]: + validation_columns = [ + "Data-Index", + "Batch-Index", + "Image", + "Mean-Confidence", + "Speed", + ] + train_columns = ["Epoch"] + validation_columns + self.train_validation_table = wandb.Table( + columns=["Model-Name"] + train_columns + ) + self.validation_table = wandb.Table( + columns=["Model-Name"] + validation_columns + ) + self.prediction_table = wandb.Table( + columns=[ + "Model-Name", + "Image", + "Num-Objects", + "Mean-Confidence", + "Speed", + ] + ) + elif self.task == "classify": + classification_columns = [ + "Image", + "Predicted-Category", + "Prediction-Confidence", + "Top-5-Prediction-Categories", + "Top-5-Prediction-Confindence", + "Probabilities", + "Speed", + ] + validation_columns = ["Data-Index", "Batch-Index"] + classification_columns + validation_columns.insert(3, "Ground-Truth-Category") + self.train_validation_table = wandb.Table( + columns=["Model-Name", "Epoch"] + validation_columns + ) + self.validation_table = wandb.Table( + columns=["Model-Name"] + validation_columns + ) + self.prediction_table = wandb.Table( + columns=["Model-Name"] + classification_columns + ) + elif self.task == "pose": + 
validation_columns = [ + "Data-Index", + "Batch-Index", + "Image-Ground-Truth", + "Image-Prediction", + "Num-Instances", + "Mean-Confidence", + "Speed", + ] + train_columns = ["Epoch"] + validation_columns + self.train_validation_table = wandb.Table( + columns=["Model-Name"] + train_columns + ) + self.validation_table = wandb.Table( + columns=["Model-Name"] + validation_columns + ) + self.prediction_table = wandb.Table( + columns=[ + "Model-Name", + "Image-Prediction", + "Num-Instances", + "Mean-Confidence", + "Speed", + ] + ) + + def _make_predictor(self, model: YOLO): + overrides = copy.deepcopy(model.overrides) + overrides["conf"] = 0.1 + self.predictor = self.task_map[self.task]["predictor"](overrides=overrides) + self.predictor.callbacks = {} + self.predictor.args.save = False + self.predictor.args.save_txt = False + self.predictor.args.save_crop = False + self.predictor.args.verbose = None + + def _save_model(self, trainer: TRAINER_TYPE): + model_checkpoint_artifact = wandb.Artifact(f"run_{wandb.run.id}_model", "model") + checkpoint_dict = { + "epoch": trainer.epoch, + "best_fitness": trainer.best_fitness, + "model": copy.deepcopy(de_parallel(self.model)).half(), + "ema": copy.deepcopy(trainer.ema.ema).half(), + "updates": trainer.ema.updates, + "optimizer": trainer.optimizer.state_dict(), + "train_args": vars(trainer.args), + "date": datetime.now().isoformat(), + "version": __version__, + } + checkpoint_path = trainer.wdir / f"epoch{trainer.epoch}.pt" + torch.save(checkpoint_dict, checkpoint_path, pickle_module=pickle) + model_checkpoint_artifact.add_file(checkpoint_path) + wandb.log_artifact( + model_checkpoint_artifact, aliases=[f"epoch_{trainer.epoch}"] + ) + + def on_train_start(self, trainer: TRAINER_TYPE): + with telemetry.context(run=wandb.run) as tel: + tel.feature.ultralytics_yolov8 = True + wandb.config.train = vars(trainer.args) + self.run_id = wandb.run.id + + @torch.no_grad() + def on_fit_epoch_end(self, trainer: DetectionTrainer): + if 
self.task in self.supported_tasks and self.train_epoch != trainer.epoch: + self.train_epoch = trainer.epoch + if (self.train_epoch + 1) % self.epoch_logging_interval == 0: + validator = trainer.validator + dataloader = validator.dataloader + class_label_map = validator.names + self.device = next(trainer.model.parameters()).device + if isinstance(trainer.model, torch.nn.parallel.DistributedDataParallel): + model = trainer.model.module + else: + model = trainer.model + self.model = copy.deepcopy(model).eval().to(self.device) + self.predictor.setup_model(model=self.model, verbose=False) + if self.task == "pose": + self.train_validation_table = plot_pose_validation_results( + dataloader=dataloader, + class_label_map=class_label_map, + model_name=self.model_name, + predictor=self.predictor, + visualize_skeleton=self.visualize_skeleton, + table=self.train_validation_table, + max_validation_batches=self.max_validation_batches, + epoch=trainer.epoch, + ) + elif self.task == "segment": + self.train_validation_table = plot_segmentation_validation_results( + dataloader=dataloader, + class_label_map=class_label_map, + model_name=self.model_name, + predictor=self.predictor, + table=self.train_validation_table, + max_validation_batches=self.max_validation_batches, + epoch=trainer.epoch, + ) + elif self.task == "detect": + self.train_validation_table = plot_detection_validation_results( + dataloader=dataloader, + class_label_map=class_label_map, + model_name=self.model_name, + predictor=self.predictor, + table=self.train_validation_table, + max_validation_batches=self.max_validation_batches, + epoch=trainer.epoch, + ) + elif self.task == "classify": + self.train_validation_table = ( + plot_classification_validation_results( + dataloader=dataloader, + model_name=self.model_name, + predictor=self.predictor, + table=self.train_validation_table, + max_validation_batches=self.max_validation_batches, + epoch=trainer.epoch, + ) + ) + if self.enable_model_checkpointing: + 
self._save_model(trainer) + trainer.model.to(self.device) + + def on_train_end(self, trainer: TRAINER_TYPE): + if self.task in self.supported_tasks: + wandb.log({"Train-Table": self.train_validation_table}, commit=False) + + def on_val_start(self, validator: VALIDATOR_TYPE): + wandb.run or wandb.init( + project=validator.args.project or "YOLOv8", + job_type="validation_" + validator.args.task, + ) + + @torch.no_grad() + def on_val_end(self, trainer: VALIDATOR_TYPE): + if self.task in self.supported_tasks: + validator = trainer + dataloader = validator.dataloader + class_label_map = validator.names + if self.task == "pose": + self.validation_table = plot_pose_validation_results( + dataloader=dataloader, + class_label_map=class_label_map, + model_name=self.model_name, + predictor=self.predictor, + visualize_skeleton=self.visualize_skeleton, + table=self.validation_table, + max_validation_batches=self.max_validation_batches, + ) + elif self.task == "segment": + self.validation_table = plot_segmentation_validation_results( + dataloader=dataloader, + class_label_map=class_label_map, + model_name=self.model_name, + predictor=self.predictor, + table=self.validation_table, + max_validation_batches=self.max_validation_batches, + ) + elif self.task == "detect": + self.validation_table = plot_detection_validation_results( + dataloader=dataloader, + class_label_map=class_label_map, + model_name=self.model_name, + predictor=self.predictor, + table=self.validation_table, + max_validation_batches=self.max_validation_batches, + ) + elif self.task == "classify": + self.validation_table = plot_classification_validation_results( + dataloader=dataloader, + model_name=self.model_name, + predictor=self.predictor, + table=self.validation_table, + max_validation_batches=self.max_validation_batches, + ) + wandb.log({"Validation-Table": self.validation_table}, commit=False) + + def on_predict_start(self, predictor: PREDICTOR_TYPE): + wandb.run or wandb.init( + project=predictor.args.project 
or "YOLOv8", + config=vars(predictor.args), + job_type="prediction_" + predictor.args.task, + ) + if isinstance(predictor, SAMPredictor): + self.prompts = copy.deepcopy(predictor.prompts) + self.prediction_table = wandb.Table(columns=["Image"]) + + def on_predict_end(self, predictor: PREDICTOR_TYPE): + wandb.config.prediction_configs = vars(predictor.args) + if self.task in self.supported_tasks: + for result in tqdm(predictor.results): + if self.task == "pose": + self.prediction_table = plot_pose_predictions( + result, + self.model_name, + self.visualize_skeleton, + self.prediction_table, + ) + elif self.task == "segment": + if isinstance(predictor, SegmentationPredictor): + self.prediction_table = plot_mask_predictions( + result, self.model_name, self.prediction_table + ) + elif isinstance(predictor, SAMPredictor): + self.prediction_table = plot_sam_predictions( + result, self.prompts, self.prediction_table + ) + elif self.task == "detect": + self.prediction_table = plot_bbox_predictions( + result, self.model_name, self.prediction_table + ) + elif self.task == "classify": + self.prediction_table = plot_classification_predictions( + result, self.model_name, self.prediction_table + ) + + wandb.log({"Prediction-Table": self.prediction_table}, commit=False) + + @property + def callbacks(self) -> Dict[str, Callable]: + """Property contains all the relevant callbacks to add to the YOLO model for the Weights & Biases logging.""" + return { + "on_train_start": self.on_train_start, + "on_fit_epoch_end": self.on_fit_epoch_end, + "on_train_end": self.on_train_end, + "on_val_start": self.on_val_start, + "on_val_end": self.on_val_end, + "on_predict_start": self.on_predict_start, + "on_predict_end": self.on_predict_end, + } + + +# TODO: Add epoch interval +def add_wandb_callback( + model: YOLO, + epoch_logging_interval: int = 1, + enable_model_checkpointing: bool = False, + enable_train_validation_logging: bool = True, + enable_validation_logging: bool = True, + 
enable_prediction_logging: bool = True, + max_validation_batches: Optional[int] = 1, + visualize_skeleton: Optional[bool] = True, +): + """Function to add the `WandBUltralyticsCallback` callback to the `YOLO` model. + + Example: + ```python + from ultralytics.yolo.engine.model import YOLO + from wandb.yolov8 import add_wandb_callback + + # initialize YOLO model + model = YOLO("yolov8n.pt") + + # add wandb callback + add_wandb_callback( + model, max_validation_batches=2, enable_model_checkpointing=True + ) + + # train + model.train(data="coco128.yaml", epochs=5, imgsz=640) + + # validate + model.val() + + # perform inference + model(["img1.jpeg", "img2.jpeg"]) + ``` + + Args: + model: (ultralytics.yolo.engine.model.YOLO) YOLO Model of type + `ultralytics.yolo.engine.model.YOLO`. + epoch_logging_interval: (int) interval to log the prediction visualizations + during training. + enable_model_checkpointing: (bool) enable logging model checkpoints as + artifacts at the end of eveny epoch if set to `True`. + enable_train_validation_logging: (bool) enable logging the predictions and + ground-truths as interactive image overlays on the images from + the validation dataloader to a `wandb.Table` along with + mean-confidence of the predictions per-class at the end of each + training epoch. + enable_validation_logging: (bool) enable logging the predictions and + ground-truths as interactive image overlays on the images from the + validation dataloader to a `wandb.Table` along with + mean-confidence of the predictions per-class at the end of + validation. + enable_prediction_logging: (bool) enable logging the predictions and + ground-truths as interactive image overlays on the images from the + validation dataloader to a `wandb.Table` along with mean-confidence + of the predictions per-class at the end of each prediction. + max_validation_batches: (Optional[int]) maximum number of validation batches to log to + a table per epoch. 
+ visualize_skeleton: (Optional[bool]) visualize pose skeleton by drawing lines connecting + keypoints for human pose. + + Returns: + An instance of `ultralytics.yolo.engine.model.YOLO` with the `WandBUltralyticsCallback`. + """ + if RANK in [-1, 0]: + wandb_callback = WandBUltralyticsCallback( + copy.deepcopy(model), + epoch_logging_interval, + max_validation_batches, + enable_model_checkpointing, + visualize_skeleton, + ) + callbacks = wandb_callback.callbacks + if not enable_train_validation_logging: + _ = callbacks.pop("on_fit_epoch_end") + _ = callbacks.pop("on_train_end") + if not enable_validation_logging: + _ = callbacks.pop("on_val_start") + _ = callbacks.pop("on_val_end") + if not enable_prediction_logging: + _ = callbacks.pop("on_predict_start") + _ = callbacks.pop("on_predict_end") + for event, callback_fn in callbacks.items(): + model.add_callback(event, callback_fn) + else: + wandb.termerror( + "The RANK of the process to add the callbacks was neither 0 or " + "-1. No Weights & Biases callbacks were added to this instance " + "of the YOLO model." 
+ ) + return model diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/classification_utils.py b/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/classification_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9db6db72768dd2d16b7b44ee1ad0514a6322752f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/classification_utils.py @@ -0,0 +1,83 @@ +from typing import Any, Optional + +import numpy as np +from tqdm.auto import tqdm +from ultralytics.engine.results import Results +from ultralytics.models.yolo.classify import ClassificationPredictor + +import wandb + + +def plot_classification_predictions( + result: Results, + model_name: str, + table: Optional[wandb.Table] = None, + original_image: Optional[np.array] = None, +): + """Plot classification prediction results to a `wandb.Table` if the table is passed otherwise return the data.""" + result = result.to("cpu") + probabilities = result.probs + probabilities_list = probabilities.data.numpy().tolist() + class_id_to_label = {int(k): str(v) for k, v in result.names.items()} + original_image = ( + wandb.Image(original_image) + if original_image is not None + else wandb.Image(result.orig_img) + ) + table_row = [ + model_name, + original_image, + class_id_to_label[int(probabilities.top1)], + probabilities.top1conf, + [class_id_to_label[int(class_idx)] for class_idx in list(probabilities.top5)], + [probabilities_list[int(class_idx)] for class_idx in list(probabilities.top5)], + { + class_id_to_label[int(class_idx)]: probability + for class_idx, probability in enumerate(probabilities_list) + }, + result.speed, + ] + if table is not None: + table.add_data(*table_row) + return table + return class_id_to_label, table_row + + +def plot_classification_validation_results( + dataloader: Any, + model_name: str, + predictor: ClassificationPredictor, + table: wandb.Table, + max_validation_batches: int, + epoch: Optional[int] = 
None, +) -> wandb.Table: + """Plot classification results to a `wandb.Table`.""" + data_idx = 0 + num_dataloader_batches = len(dataloader.dataset) // dataloader.batch_size + max_validation_batches = min(max_validation_batches, num_dataloader_batches) + for batch_idx, batch in enumerate(dataloader): + image_batch = [ + image for image in np.transpose(batch["img"].numpy(), (0, 2, 3, 1)) + ] + ground_truth = batch["cls"].numpy().tolist() + progress_bar_result_iterable = tqdm( + range(max_validation_batches), + desc=f"Generating Visualizations for batch-{batch_idx + 1}/{max_validation_batches}", + ) + for img_idx in progress_bar_result_iterable: + try: + prediction_result = predictor(image_batch[img_idx])[0] + class_id_to_label, table_row = plot_classification_predictions( + prediction_result, model_name, original_image=image_batch[img_idx] + ) + table_row = [data_idx, batch_idx] + table_row[1:] + table_row.insert(3, class_id_to_label[ground_truth[img_idx]]) + table_row = [epoch] + table_row if epoch is not None else table_row + table_row = [model_name] + table_row + table.add_data(*table_row) + data_idx += 1 + except Exception: + pass + if batch_idx + 1 == max_validation_batches: + break + return table diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/mask_utils.py b/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/mask_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5392d964b30ae1330349ce0978ef58eb3e5ee5dd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/mask_utils.py @@ -0,0 +1,202 @@ +from typing import Dict, Optional, Tuple + +import cv2 +import numpy as np +from tqdm.auto import tqdm +from ultralytics.engine.results import Results +from ultralytics.models.yolo.segment import SegmentationPredictor +from ultralytics.utils.ops import scale_image + +import wandb +from wandb.integration.ultralytics.bbox_utils import ( + get_ground_truth_bbox_annotations, + 
get_mean_confidence_map, +) + + +def instance_mask_to_semantic_mask(instance_mask, class_indices): + height, width, num_instances = instance_mask.shape + semantic_mask = np.zeros((height, width), dtype=np.uint8) + for i in range(num_instances): + instance_map = instance_mask[:, :, i] + class_index = class_indices[i] + semantic_mask[instance_map == 1] = class_index + return semantic_mask + + +def get_boxes_and_masks(result: Results) -> Tuple[Dict, Dict, Dict]: + boxes = result.boxes.xywh.long().numpy() + classes = result.boxes.cls.long().numpy() + confidence = result.boxes.conf.numpy() + class_id_to_label = {int(k): str(v) for k, v in result.names.items()} + class_id_to_label.update({len(result.names.items()): "background"}) + mean_confidence_map = get_mean_confidence_map( + classes, confidence, class_id_to_label + ) + masks = None + if result.masks is not None: + scaled_instance_mask = scale_image( + np.transpose(result.masks.data.numpy(), (1, 2, 0)), + result.orig_img[:, :, ::-1].shape, + ) + scaled_semantic_mask = instance_mask_to_semantic_mask( + scaled_instance_mask, classes.tolist() + ) + scaled_semantic_mask[scaled_semantic_mask == 0] = len(result.names.items()) + masks = { + "predictions": { + "mask_data": scaled_semantic_mask, + "class_labels": class_id_to_label, + } + } + box_data, total_confidence = [], 0.0 + for idx in range(len(boxes)): + box_data.append( + { + "position": { + "middle": [int(boxes[idx][0]), int(boxes[idx][1])], + "width": int(boxes[idx][2]), + "height": int(boxes[idx][3]), + }, + "domain": "pixel", + "class_id": int(classes[idx]), + "box_caption": class_id_to_label[int(classes[idx])], + "scores": {"confidence": float(confidence[idx])}, + } + ) + total_confidence += float(confidence[idx]) + + boxes = { + "predictions": { + "box_data": box_data, + "class_labels": class_id_to_label, + }, + } + return boxes, masks, mean_confidence_map + + +def plot_mask_predictions( + result: Results, model_name: str, table: Optional[wandb.Table] = None +) 
-> Tuple[wandb.Image, Dict, Dict, Dict]: + result = result.to("cpu") + boxes, masks, mean_confidence_map = get_boxes_and_masks(result) + image = wandb.Image(result.orig_img[:, :, ::-1], boxes=boxes, masks=masks) + if table is not None: + table.add_data( + model_name, + image, + len(boxes["predictions"]["box_data"]), + mean_confidence_map, + result.speed, + ) + return table + return image, masks, boxes["predictions"], mean_confidence_map + + +def structure_prompts_and_image(image: np.array, prompt: Dict) -> Dict: + wb_box_data = [] + if prompt["bboxes"] is not None: + wb_box_data.append( + { + "position": { + "middle": [prompt["bboxes"][0], prompt["bboxes"][1]], + "width": prompt["bboxes"][2], + "height": prompt["bboxes"][3], + }, + "domain": "pixel", + "class_id": 1, + "box_caption": "Prompt-Box", + } + ) + if prompt["points"] is not None: + image = image.copy().astype(np.uint8) + image = cv2.circle( + image, tuple(prompt["points"]), 5, (0, 255, 0), -1, lineType=cv2.LINE_AA + ) + wb_box_data = { + "prompts": { + "box_data": wb_box_data, + "class_labels": {1: "Prompt-Box"}, + } + } + return image, wb_box_data + + +def plot_sam_predictions( + result: Results, prompt: Dict, table: wandb.Table +) -> wandb.Table: + result = result.to("cpu") + image = result.orig_img[:, :, ::-1] + image, wb_box_data = structure_prompts_and_image(image, prompt) + image = wandb.Image( + image, + boxes=wb_box_data, + masks={ + "predictions": { + "mask_data": np.squeeze(result.masks.data.cpu().numpy().astype(int)), + "class_labels": {0: "Background", 1: "Prediction"}, + } + }, + ) + table.add_data(image) + return table + + +def plot_segmentation_validation_results( + dataloader, + class_label_map, + model_name: str, + predictor: SegmentationPredictor, + table: wandb.Table, + max_validation_batches: int, + epoch: Optional[int] = None, +): + data_idx = 0 + num_dataloader_batches = len(dataloader.dataset) // dataloader.batch_size + max_validation_batches = min(max_validation_batches, 
num_dataloader_batches) + for batch_idx, batch in enumerate(dataloader): + prediction_results = predictor(batch["im_file"]) + progress_bar_result_iterable = tqdm( + enumerate(prediction_results), + total=len(prediction_results), + desc=f"Generating Visualizations for batch-{batch_idx + 1}/{max_validation_batches}", + ) + for img_idx, prediction_result in progress_bar_result_iterable: + prediction_result = prediction_result.to("cpu") + ( + _, + prediction_mask_data, + prediction_box_data, + mean_confidence_map, + ) = plot_mask_predictions(prediction_result, model_name) + try: + ground_truth_data = get_ground_truth_bbox_annotations( + img_idx, batch["im_file"][img_idx], batch, class_label_map + ) + wandb_image = wandb.Image( + batch["im_file"][img_idx], + boxes={ + "ground-truth": { + "box_data": ground_truth_data, + "class_labels": class_label_map, + }, + "predictions": prediction_box_data, + }, + masks=prediction_mask_data, + ) + table_rows = [ + data_idx, + batch_idx, + wandb_image, + mean_confidence_map, + prediction_result.speed, + ] + table_rows = [epoch] + table_rows if epoch is not None else table_rows + table_rows = [model_name] + table_rows + table.add_data(*table_rows) + data_idx += 1 + except TypeError: + pass + if batch_idx + 1 == max_validation_batches: + break + return table diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/pose_utils.py b/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/pose_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4cef5a90353d063e281ae54836d76675029aaa4e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/ultralytics/pose_utils.py @@ -0,0 +1,103 @@ +from typing import Any, Optional + +import numpy as np +from PIL import Image +from tqdm.auto import tqdm +from ultralytics.engine.results import Results +from ultralytics.models.yolo.pose import PosePredictor +from ultralytics.utils.plotting import Annotator + +import wandb +from 
wandb.integration.ultralytics.bbox_utils import ( + get_boxes, + get_ground_truth_bbox_annotations, +) + + +def annotate_keypoint_results(result: Results, visualize_skeleton: bool): + annotator = Annotator(np.ascontiguousarray(result.orig_img[:, :, ::-1])) + key_points = result.keypoints.data.numpy() + for idx in range(key_points.shape[0]): + annotator.kpts(key_points[idx], kpt_line=visualize_skeleton) + return annotator.im + + +def annotate_keypoint_batch(image_path: str, keypoints: Any, visualize_skeleton: bool): + with Image.open(image_path) as original_image: + original_image = np.ascontiguousarray(original_image) + annotator = Annotator(original_image) + annotator.kpts(keypoints.numpy(), kpt_line=visualize_skeleton) + return annotator.im + + +def plot_pose_predictions( + result: Results, + model_name: str, + visualize_skeleton: bool, + table: Optional[wandb.Table] = None, +): + result = result.to("cpu") + boxes, mean_confidence_map = get_boxes(result) + annotated_image = annotate_keypoint_results(result, visualize_skeleton) + prediction_image = wandb.Image(annotated_image, boxes=boxes) + table_row = [ + model_name, + prediction_image, + len(boxes["predictions"]["box_data"]), + mean_confidence_map, + result.speed, + ] + if table is not None: + table.add_data(*table_row) + return table + return table_row + + +def plot_pose_validation_results( + dataloader, + class_label_map, + model_name: str, + predictor: PosePredictor, + visualize_skeleton: bool, + table: wandb.Table, + max_validation_batches: int, + epoch: Optional[int] = None, +) -> wandb.Table: + data_idx = 0 + num_dataloader_batches = len(dataloader.dataset) // dataloader.batch_size + max_validation_batches = min(max_validation_batches, num_dataloader_batches) + for batch_idx, batch in enumerate(dataloader): + prediction_results = predictor(batch["im_file"]) + progress_bar_result_iterable = tqdm( + enumerate(prediction_results), + total=len(prediction_results), + desc=f"Generating Visualizations for 
batch-{batch_idx + 1}/{max_validation_batches}", + ) + for img_idx, prediction_result in progress_bar_result_iterable: + prediction_result = prediction_result.to("cpu") + table_row = plot_pose_predictions( + prediction_result, model_name, visualize_skeleton + ) + ground_truth_image = wandb.Image( + annotate_keypoint_batch( + batch["im_file"][img_idx], + batch["keypoints"][img_idx], + visualize_skeleton, + ), + boxes={ + "ground-truth": { + "box_data": get_ground_truth_bbox_annotations( + img_idx, batch["im_file"][img_idx], batch, class_label_map + ), + "class_labels": class_label_map, + }, + }, + ) + table_row = [data_idx, batch_idx, ground_truth_image] + table_row[1:] + table_row = [epoch] + table_row if epoch is not None else table_row + table_row = [model_name] + table_row + table.add_data(*table_row) + data_idx += 1 + if batch_idx + 1 == max_validation_batches: + break + return table diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/weave/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/weave/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d83d8eb03d3a48a867a5d695ccc524b1f5eefad3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/weave/__init__.py @@ -0,0 +1,6 @@ +"""Weave integration for W&B.""" + +from .interface import RunPath, active_run_path +from .weave import setup + +__all__ = ("active_run_path", "RunPath", "setup") diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/weave/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/integration/weave/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37f1bc77de278bc4840f4400c8c02ff849d8ef8b Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/integration/weave/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/weave/__pycache__/interface.cpython-313.pyc 
b/.venv/lib/python3.13/site-packages/wandb/integration/weave/__pycache__/interface.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33562f66b6a581797e03a5b35353fba2010d7573 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/integration/weave/__pycache__/interface.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/weave/__pycache__/weave.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/integration/weave/__pycache__/weave.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1840c81d5946f782a706a4aeaddf02650f021d69 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/integration/weave/__pycache__/weave.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/weave/interface.py b/.venv/lib/python3.13/site-packages/wandb/integration/weave/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..1d1db8ab1d2eb9739e00586bca6f06c7e6f6590b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/weave/interface.py @@ -0,0 +1,49 @@ +"""Internal APIs for integrating with weave. + +The public functions here are intended to be called by weave and care should +be taken to maintain backward compatibility. +""" + +from __future__ import annotations + +import dataclasses + +from wandb.sdk import wandb_setup + + +@dataclasses.dataclass(frozen=True) +class RunPath: + entity: str + """The entity to which the run is logging. Never empty.""" + + project: str + """The project to which the run is logging. Never empty.""" + + run_id: str + """The run's ID. Never empty.""" + + +def active_run_path() -> RunPath | None: + """Returns the path of an initialized, unfinished run. + + Returns None if all initialized runs are finished. If there is + more than one active run, an arbitrary path is returned. + The run may be finished by the time its path is returned. 
+ + Thread-safe. + """ + singleton = wandb_setup.singleton() + + if ( + (run := singleton.most_recent_active_run) + and run.entity + and run.project + and run.id + ): + return RunPath( + entity=run.entity, + project=run.project, + run_id=run.id, + ) + + return None diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/weave/weave.py b/.venv/lib/python3.13/site-packages/wandb/integration/weave/weave.py new file mode 100644 index 0000000000000000000000000000000000000000..dd61aa002b94315218b2cfc97ee50092bca98597 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/weave/weave.py @@ -0,0 +1,121 @@ +"""Integration module for automatic Weave initialization with W&B. + +This module provides automatic initialization of Weave when: +1. Weave is installed +2. A W&B run is active with a project +3. Weave is imported (init-on-import) + +The integration can be disabled by setting the WANDB_DISABLE_WEAVE environment variable. +""" + +from __future__ import annotations + +import importlib.util +import os +import sys +import threading + +import wandb + +_weave_init_lock = threading.Lock() + +_DISABLE_WEAVE = "WANDB_DISABLE_WEAVE" +_WEAVE_PACKAGE_NAME = "weave" + +# This list is adapted from https://github.com/wandb/weave/blob/master/weave/integrations/__init__.py +_AVAILABLE_WEAVE_INTEGRATIONS = [ + "anthropic", + "autogen", + "cohere", + "crewai", + "dspy", + "google.genai", + "groq", + "huggingface_hub.inference", + "instructor", + "langchain", + "litellm", + "llama_index", + "mcp", + "mistral", + "notdiamond", + "openai", + "agents", + "smolagents", + "verdict", + "verifiers", + "vertexai", +] + + +def setup(entity: str | None, project: str | None) -> None: + """Set up automatic Weave initialization for the current W&B run. + + Args: + project: The W&B project name to use for Weave initialization. 
+ """ + # We can't or shouldn't init weave; return + if os.getenv(_DISABLE_WEAVE): + return + if not project: + return + + # Use entity/project when available; otherwise fall back to project only + if entity: + project_path = f"{entity}/{project}" + else: + project_path = project + + # If weave is not yet imported, we can't init it from here. Instead, we'll + # rely on the weave library itself to detect a run and init itself. + if _WEAVE_PACKAGE_NAME not in sys.modules: + _maybe_suggest_weave_installation() + return + + # If weave has already been imported, initialize immediately + wandb.termlog("Initializing weave.") + try: + _weave_init(project_path) + except Exception as e: + wandb.termwarn(f"Failed to automatically initialize weave: {e}") + + +def _maybe_suggest_weave_installation() -> None: + """Suggest Weave installation or import if any target library is imported.""" + imported_libs = [lib for lib in _AVAILABLE_WEAVE_INTEGRATIONS if lib in sys.modules] + if not imported_libs: + return + + weave_spec = importlib.util.find_spec(_WEAVE_PACKAGE_NAME) + if weave_spec is None: + # Weave is not installed + msg = ( + "Use W&B Weave for improved LLM call tracing. Install Weave with " + "`pip install weave` then add `import weave` to the top of your script." + ) + else: + # Weave is installed but not imported + msg = ( + "Use W&B Weave for improved LLM call tracing. Weave is installed " + "but not imported. Add `import weave` to the top of your script." + ) + + wandb.termlog(f"Detected [{', '.join(imported_libs)}] in use.", repeat=False) + wandb.termlog(msg, repeat=False) + wandb.termlog( + "For more information, check out the docs at: https://weave-docs.wandb.ai/", + repeat=False, + ) + + +def _weave_init(project_path: str) -> None: + """Call weave.init(), assuming weave has been imported. + + Patched in tests. + """ + # Lock because weave.init() is not thread-safe. + with _weave_init_lock: + # The import is fast because weave should have been imported. 
+ import weave + + weave.init(project_path) diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/xgboost/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/xgboost/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..052083e0cf7034a469a82582422ff7fee62965a5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/xgboost/__init__.py @@ -0,0 +1,11 @@ +"""W&B callback for xgboost. + +Simple callback to get logging for each tree + +Use the `wandb_callback` to add `wandb` logging to any `XGboost` model. However, it will +be deprecated in favor of WandbCallback. Use it instead for more features. +""" + +from .xgboost import WandbCallback, wandb_callback + +__all__ = ["wandb_callback", "WandbCallback"] diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/xgboost/xgboost.py b/.venv/lib/python3.13/site-packages/wandb/integration/xgboost/xgboost.py new file mode 100644 index 0000000000000000000000000000000000000000..e01f23a087e87b376c2fdb8cf31b65e51f7cd789 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/xgboost/xgboost.py @@ -0,0 +1,205 @@ +"""xgboost init!""" + +from __future__ import annotations + +import json +import warnings +from pathlib import Path +from typing import TYPE_CHECKING, Callable, Dict, List, NamedTuple, Tuple, Union, cast + +import xgboost as xgb +import xgboost.callback +from typing_extensions import TypeAlias, override + +import wandb +from wandb.sdk.lib import telemetry as wb_telemetry + +MINIMIZE_METRICS = [ + "rmse", + "rmsle", + "mae", + "mape", + "mphe", + "logloss", + "error", + "error@t", + "merror", +] + +MAXIMIZE_METRICS = ["auc", "aucpr", "ndcg", "map", "ndcg@n", "map@n"] + + +if TYPE_CHECKING: + + class CallbackEnv(NamedTuple): + evaluation_result_list: list + + # Copied from xgboost's source code. These types are not exported. 
+ _ScoreList = Union[List[float], List[Tuple[float, float]]] + _EvalsLog: TypeAlias = Dict[str, Dict[str, _ScoreList]] + + +def wandb_callback() -> Callable: + """Old style callback that will be deprecated in favor of WandbCallback. Please try the new logger for more features.""" + warnings.warn( + "wandb_callback will be deprecated in favor of WandbCallback. Please use WandbCallback for more features.", + UserWarning, + stacklevel=2, + ) + + with wb_telemetry.context() as tel: + tel.feature.xgboost_old_wandb_callback = True + + def callback(env: CallbackEnv) -> None: + for k, v in env.evaluation_result_list: + wandb.log({k: v}, commit=False) + wandb.log({}) + + return callback + + +class WandbCallback(xgboost.callback.TrainingCallback): + """`WandbCallback` automatically integrates XGBoost with wandb. + + Args: + log_model: (boolean) if True save and upload the model to Weights & Biases Artifacts + log_feature_importance: (boolean) if True log a feature importance bar plot + importance_type: (str) one of {weight, gain, cover, total_gain, total_cover} for tree model. weight for linear model. + define_metric: (boolean) if True (default) capture model performance at the best step, instead of the last step, of training in your `wandb.summary`. + + Passing `WandbCallback` to XGBoost will: + + - log the booster model configuration to Weights & Biases + - log evaluation metrics collected by XGBoost, such as rmse, accuracy etc. to Weights & Biases + - log training metric collected by XGBoost (if you provide training data to eval_set) + - log the best score and the best iteration + - save and upload your trained model to Weights & Biases Artifacts (when `log_model = True`) + - log feature importance plot when `log_feature_importance=True` (default). + - Capture the best eval metric in `wandb.summary` when `define_metric=True` (default). 
+ + Example: + ```python + bst_params = dict( + objective="reg:squarederror", + colsample_bytree=0.3, + learning_rate=0.1, + max_depth=5, + alpha=10, + n_estimators=10, + tree_method="hist", + callbacks=[WandbCallback()], + ) + + xg_reg = xgb.XGBRegressor(**bst_params) + xg_reg.fit( + X_train, + y_train, + eval_set=[(X_test, y_test)], + ) + ``` + """ + + def __init__( + self, + log_model: bool = False, + log_feature_importance: bool = True, + importance_type: str = "gain", + define_metric: bool = True, + ): + super().__init__() + + self.log_model: bool = log_model + self.log_feature_importance: bool = log_feature_importance + self.importance_type: str = importance_type + self.define_metric: bool = define_metric + + if wandb.run is None: + raise wandb.Error("You must call wandb.init() before WandbCallback()") + + with wb_telemetry.context() as tel: + tel.feature.xgboost_wandb_callback = True + + @override + def before_training(self, model: xgb.Booster) -> xgb.Booster: + """Run before training is finished.""" + # Update W&B config + config = model.save_config() + wandb.config.update(json.loads(config)) + + return model + + @override + def after_training(self, model: xgb.Booster) -> xgb.Booster: + """Run after training is finished.""" + # Log the booster model as artifacts + if self.log_model: + self._log_model_as_artifact(model) + + # Plot feature importance + if self.log_feature_importance: + self._log_feature_importance(model) + + # Log the best score and best iteration + if model.attr("best_score") is not None: + wandb.log( + { + "best_score": float(cast(str, model.attr("best_score"))), + "best_iteration": int(cast(str, model.attr("best_iteration"))), + } + ) + + return model + + @override + def after_iteration( + self, + model: xgb.Booster, + epoch: int, + evals_log: _EvalsLog, + ) -> bool: + """Run after each iteration. 
Return True when training should stop.""" + # Log metrics + for data, metric in evals_log.items(): + for metric_name, log in metric.items(): + if self.define_metric: + self._define_metric(data, metric_name) + wandb.log({f"{data}-{metric_name}": log[-1]}, commit=False) + else: + wandb.log({f"{data}-{metric_name}": log[-1]}, commit=False) + + wandb.log({"epoch": epoch}) + + self.define_metric = False + + return False + + def _log_model_as_artifact(self, model: xgb.Booster) -> None: + model_name = f"{wandb.run.id}_model.json" # type: ignore + model_path = Path(wandb.run.dir) / model_name # type: ignore + model.save_model(str(model_path)) + + model_artifact = wandb.Artifact(name=model_name, type="model") + model_artifact.add_file(str(model_path)) + wandb.log_artifact(model_artifact) + + def _log_feature_importance(self, model: xgb.Booster) -> None: + fi = model.get_score(importance_type=self.importance_type) + fi_data = [[k, fi[k]] for k in fi] + table = wandb.Table(data=fi_data, columns=["Feature", "Importance"]) + wandb.log( + { + "Feature Importance": wandb.plot.bar( + table, "Feature", "Importance", title="Feature Importance" + ) + } + ) + + def _define_metric(self, data: str, metric_name: str) -> None: + if "loss" in str.lower(metric_name): + wandb.define_metric(f"{data}-{metric_name}", summary="min") + elif str.lower(metric_name) in MINIMIZE_METRICS: + wandb.define_metric(f"{data}-{metric_name}", summary="min") + elif str.lower(metric_name) in MAXIMIZE_METRICS: + wandb.define_metric(f"{data}-{metric_name}", summary="max") + else: + pass diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/yolov8/__init__.py b/.venv/lib/python3.13/site-packages/wandb/integration/yolov8/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/integration/yolov8/yolov8.py b/.venv/lib/python3.13/site-packages/wandb/integration/yolov8/yolov8.py new file 
mode 100644 index 0000000000000000000000000000000000000000..1c5dbd0fe8de39f387b6427b4696f28f9c029b44 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/integration/yolov8/yolov8.py @@ -0,0 +1,284 @@ +from typing import Any, Callable, Dict, List, Optional + +from ultralytics.yolo.engine.model import YOLO +from ultralytics.yolo.engine.trainer import BaseTrainer + +try: + from ultralytics.yolo.utils import RANK + from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params +except ModuleNotFoundError: + from ultralytics.utils import RANK + from ultralytics.utils.torch_utils import get_flops, get_num_params +from ultralytics.yolo.v8.classify.train import ClassificationTrainer + +import wandb +from wandb.sdk.lib import telemetry + + +class WandbCallback: + """An internal YOLO model wrapper that tracks metrics, and logs models to Weights & Biases. + + Usage: + ```python + from wandb.integration.yolov8.yolov8 import WandbCallback + + model = YOLO("yolov8n.pt") + wandb_logger = WandbCallback( + model, + ) + for event, callback_fn in wandb_logger.callbacks.items(): + model.add_callback(event, callback_fn) + ``` + """ + + def __init__( + self, + yolo: YOLO, + run_name: Optional[str] = None, + project: Optional[str] = None, + tags: Optional[List[str]] = None, + resume: Optional[str] = None, + **kwargs: Optional[Any], + ) -> None: + """A utility class to manage wandb run and various callbacks for the ultralytics YOLOv8 framework. + + Args: + yolo: A YOLOv8 model that's inherited from `:class:ultralytics.yolo.engine.model.YOLO` + run_name, str: The name of the Weights & Biases run, defaults to an auto generated run_name if `trainer.args.name` is not defined. + project, str: The name of the Weights & Biases project, defaults to `"YOLOv8"` if `trainer.args.project` is not defined. + tags, List[str]: A list of tags to be added to the Weights & Biases run, defaults to `["YOLOv8"]`. 
+ resume, str: Whether to resume a previous run on Weights & Biases, defaults to `None`. + **kwargs: Additional arguments to be passed to `wandb.init()`. + """ + self.yolo = yolo + self.run_name = run_name + self.project = project + self.tags = tags + self.resume = resume + self.kwargs = kwargs + + def on_pretrain_routine_start(self, trainer: BaseTrainer) -> None: + """Starts a new wandb run to track the training process and log to Weights & Biases. + + Args: + trainer: A task trainer that's inherited from `:class:ultralytics.yolo.engine.trainer.BaseTrainer` + that contains the model training and optimization routine. + """ + if wandb.run is None: + self.run = wandb.init( + name=self.run_name if self.run_name else trainer.args.name, + project=self.project + if self.project + else trainer.args.project or "YOLOv8", + tags=self.tags if self.tags else ["YOLOv8"], + config=vars(trainer.args), + resume=self.resume if self.resume else None, + **self.kwargs, + ) + else: + self.run = wandb.run + assert self.run is not None + self.run.define_metric("epoch", hidden=True) + self.run.define_metric( + "train/*", step_metric="epoch", step_sync=True, summary="min" + ) + + self.run.define_metric( + "val/*", step_metric="epoch", step_sync=True, summary="min" + ) + + self.run.define_metric( + "metrics/*", step_metric="epoch", step_sync=True, summary="max" + ) + self.run.define_metric( + "lr/*", step_metric="epoch", step_sync=True, summary="last" + ) + + with telemetry.context(run=wandb.run) as tel: + tel.feature.ultralytics_yolov8 = True + + def on_pretrain_routine_end(self, trainer: BaseTrainer) -> None: + assert self.run is not None + self.run.summary.update( + { + "model/parameters": get_num_params(trainer.model), + "model/GFLOPs": round(get_flops(trainer.model), 3), + } + ) + + def on_train_epoch_start(self, trainer: BaseTrainer) -> None: + """On train epoch start we only log epoch number to the Weights & Biases run.""" + # We log the epoch number here to commit the previous 
step, + assert self.run is not None + self.run.log({"epoch": trainer.epoch + 1}) + + def on_train_epoch_end(self, trainer: BaseTrainer) -> None: + """On train epoch end we log all the metrics to the Weights & Biases run.""" + assert self.run is not None + self.run.log( + { + **trainer.metrics, + **trainer.label_loss_items(trainer.tloss, prefix="train"), + **trainer.lr, + }, + ) + # Currently only the detection and segmentation trainers save images to the save_dir + if not isinstance(trainer, ClassificationTrainer): + self.run.log( + { + "train_batch_images": [ + wandb.Image(str(image_path), caption=image_path.stem) + for image_path in trainer.save_dir.glob("train_batch*.jpg") + ] + } + ) + + def on_fit_epoch_end(self, trainer: BaseTrainer) -> None: + """On fit epoch end we log all the best metrics and model detail to Weights & Biases run summary.""" + assert self.run is not None + if trainer.epoch == 0: + speeds = [ + trainer.validator.speed.get( + key, + ) + for key in (1, "inference") + ] + speed = speeds[0] if speeds[0] else speeds[1] + if speed: + self.run.summary.update( + { + "model/speed(ms/img)": round(speed, 3), + } + ) + if trainer.best_fitness == trainer.fitness: + self.run.summary.update( + { + "best/epoch": trainer.epoch + 1, + **{f"best/{key}": val for key, val in trainer.metrics.items()}, + } + ) + + def on_train_end(self, trainer: BaseTrainer) -> None: + """On train end we log all the media, including plots, images and best model artifact to Weights & Biases.""" + # Currently only the detection and segmentation trainers save images to the save_dir + assert self.run is not None + if not isinstance(trainer, ClassificationTrainer): + assert self.run is not None + self.run.log( + { + "plots": [ + wandb.Image(str(image_path), caption=image_path.stem) + for image_path in trainer.save_dir.glob("*.png") + ], + "val_images": [ + wandb.Image(str(image_path), caption=image_path.stem) + for image_path in trainer.validator.save_dir.glob("val*.jpg") + ], + }, + ) 
+ + if trainer.best.exists(): + assert self.run is not None + self.run.log_artifact( + str(trainer.best), + type="model", + name=f"{self.run.name}_{trainer.args.task}.pt", + aliases=["best", f"epoch_{trainer.epoch + 1}"], + ) + + def on_model_save(self, trainer: BaseTrainer) -> None: + """On model save we log the model as an artifact to Weights & Biases.""" + assert self.run is not None + self.run.log_artifact( + str(trainer.last), + type="model", + name=f"{self.run.name}_{trainer.args.task}.pt", + aliases=["last", f"epoch_{trainer.epoch + 1}"], + ) + + def teardown(self, _trainer: BaseTrainer) -> None: + """On teardown, we finish the Weights & Biases run and set it to None.""" + assert self.run is not None + self.run.finish() + self.run = None + + @property + def callbacks( + self, + ) -> Dict[str, Callable]: + """Property contains all the relevant callbacks to add to the YOLO model for the Weights & Biases logging.""" + return { + "on_pretrain_routine_start": self.on_pretrain_routine_start, + "on_pretrain_routine_end": self.on_pretrain_routine_end, + "on_train_epoch_start": self.on_train_epoch_start, + "on_train_epoch_end": self.on_train_epoch_end, + "on_fit_epoch_end": self.on_fit_epoch_end, + "on_train_end": self.on_train_end, + "on_model_save": self.on_model_save, + "teardown": self.teardown, + } + + +def add_callbacks( + yolo: YOLO, + run_name: Optional[str] = None, + project: Optional[str] = None, + tags: Optional[List[str]] = None, + resume: Optional[str] = None, + **kwargs: Optional[Any], +) -> YOLO: + """A YOLO model wrapper that tracks metrics, and logs models to Weights & Biases. + + Args: + yolo: A YOLOv8 model that's inherited from `:class:ultralytics.yolo.engine.model.YOLO` + run_name, str: The name of the Weights & Biases run, defaults to an auto generated name if `trainer.args.name` is not defined. + project, str: The name of the Weights & Biases project, defaults to `"YOLOv8"` if `trainer.args.project` is not defined. 
+ tags, List[str]: A list of tags to be added to the Weights & Biases run, defaults to `["YOLOv8"]`. + resume, str: Whether to resume a previous run on Weights & Biases, defaults to `None`. + **kwargs: Additional arguments to be passed to `wandb.init()`. + + Usage: + ```python + from wandb.integration.yolov8 import add_callbacks as add_wandb_callbacks + + model = YOLO("yolov8n.pt") + add_wandb_callbacks( + model, + ) + model.train( + data="coco128.yaml", + epochs=3, + imgsz=640, + ) + ``` + """ + wandb.termwarn( + """The wandb callback is currently in beta and is subject to change based on updates to `ultralytics yolov8`. + The callback is tested and supported for ultralytics v8.0.43 and above. + Please report any issues to https://github.com/wandb/wandb/issues with the tag `yolov8`. + """, + repeat=False, + ) + wandb.termwarn( + """This wandb callback is no longer functional and would be deprecated in the near future. + We recommend you to use the updated callback using `from wandb.integration.ultralytics import add_wandb_callback`. + The updated callback is tested and supported for ultralytics 8.0.167 and above. + You can refer to https://docs.wandb.ai/guides/integrations/ultralytics for the updated documentation. + Please report any issues to https://github.com/wandb/wandb/issues with the tag `yolov8`. + """, + repeat=False, + ) + + if RANK in [-1, 0]: + wandb_logger = WandbCallback( + yolo, run_name=run_name, project=project, tags=tags, resume=resume, **kwargs + ) + for event, callback_fn in wandb_logger.callbacks.items(): + yolo.add_callback(event, callback_fn) + return yolo + else: + wandb.termerror( + "The RANK of the process to add the callbacks was neither 0 or -1." + "No Weights & Biases callbacks were added to this instance of the YOLO model." 
+ ) + return yolo diff --git a/.venv/lib/python3.13/site-packages/wandb/mpmain/__init__.py b/.venv/lib/python3.13/site-packages/wandb/mpmain/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/mpmain/__main__.py b/.venv/lib/python3.13/site-packages/wandb/mpmain/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..02e8a24d9fbce482e31cde126d8407586a61b3fd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/mpmain/__main__.py @@ -0,0 +1 @@ +# This module is initialized after multiprocessing spawn diff --git a/.venv/lib/python3.13/site-packages/wandb/old/__init__.py b/.venv/lib/python3.13/site-packages/wandb/old/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/old/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/old/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dd804df35c002630dd8ca8859dd2f7aecd260c4 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/old/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/old/__pycache__/core.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/old/__pycache__/core.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3595a00199e7ae8ac542497e9c88b2d082b2ac6 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/old/__pycache__/core.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/old/core.py b/.venv/lib/python3.13/site-packages/wandb/old/core.py new file mode 100644 index 0000000000000000000000000000000000000000..15d6d5a410a9a2d843616d09b8cfdc0571b20fc0 --- /dev/null +++ 
b/.venv/lib/python3.13/site-packages/wandb/old/core.py @@ -0,0 +1,53 @@ +"""Core variables, functions, and classes that we want in the wandb +module but are also used in modules that import the wandb module. + +The purpose of this module is to break circular imports. +""" + +import os +import tempfile +import time + +import wandb +from wandb import env + +# We use the hidden version if it already exists, otherwise non-hidden. +if os.path.exists(os.path.join(env.get_dir(os.getcwd()), ".wandb")): + __stage_dir__ = ".wandb" + os.sep +elif os.path.exists(os.path.join(env.get_dir(os.getcwd()), "wandb")): + __stage_dir__ = "wandb" + os.sep +else: + __stage_dir__ = None + +wandb.START_TIME = time.time() + + +def wandb_dir(root_dir=None): + if root_dir is None or root_dir == "": + try: + cwd = os.getcwd() + except OSError: + wandb.termwarn("os.getcwd() no longer exists, using system temp directory") + cwd = tempfile.gettempdir() + root_dir = env.get_dir(cwd) + path = os.path.join(root_dir, __stage_dir__ or ("wandb" + os.sep)) + if not os.access(root_dir, os.W_OK): + wandb.termwarn( + f"Path {path} wasn't writable, using system temp directory", repeat=False + ) + path = os.path.join(tempfile.gettempdir(), __stage_dir__ or ("wandb" + os.sep)) + return path + + +def _set_stage_dir(stage_dir): + # Used when initing a new project with "wandb init" + global __stage_dir__ + __stage_dir__ = stage_dir + + +__all__ = [ + "__stage_dir__", + "START_TIME", + "wandb_dir", + "_set_stage_dir", +] diff --git a/.venv/lib/python3.13/site-packages/wandb/old/summary.py b/.venv/lib/python3.13/site-packages/wandb/old/summary.py new file mode 100644 index 0000000000000000000000000000000000000000..2e5eba49534dc705846f6dd766a7c720c475d57b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/old/summary.py @@ -0,0 +1,438 @@ +import json +import os +import time + +from wandb_gql import gql + +import wandb +from wandb import util +from wandb.apis.internal import Api +from 
wandb.sdk.data_types.utils import val_to_json +from wandb.sdk.lib import filenames + +DEEP_SUMMARY_FNAME = "wandb.h5" +H5_TYPES = ("numpy.ndarray", "tensorflow.Tensor", "torch.Tensor") +h5py = util.get_module("h5py") +np = util.get_module("numpy") + + +class SummarySubDict: + """Nested dict-like object that proxies read and write operations through a root object. + + This lets us do synchronous serialization and lazy loading of large values. + """ + + def __init__(self, root=None, path=()): + self._path = tuple(path) + if root is None: + self._root = self + self._json_dict = {} + else: + self._root = root + json_dict = root._json_dict + for k in path: + json_dict = json_dict.get(k, {}) + + self._json_dict = json_dict + self._dict = {} + + # We use this to track which keys the user has set explicitly + # so that we don't automatically overwrite them when we update + # the summary from the history. + self._locked_keys = set() + + def __setattr__(self, k, v): + k = k.strip() + if k.startswith("_"): + object.__setattr__(self, k, v) + else: + self[k] = v + + def __getattr__(self, k): + k = k.strip() + if k.startswith("_"): + return object.__getattribute__(self, k) + else: + return self[k] + + def _root_get(self, path, child_dict): + """Load a value at a particular path from the root. + + This should only be implemented by the "_root" child class. + + We pass the child_dict so the item can be set on it or not as + appropriate. Returning None for a nonexistent path wouldn't be + distinguishable from that path being set to the value None. + """ + raise NotImplementedError + + def _root_set(self, path, new_keys_values): + """Set a value at a particular path in the root. + + This should only be implemented by the "_root" child class. + """ + raise NotImplementedError + + def _root_del(self, path): + """Delete a value at a particular path in the root. + + This should only be implemented by the "_root" child class. 
+ """ + raise NotImplementedError + + def _write(self, commit=False): + # should only be implemented on the root summary + raise NotImplementedError + + def keys(self): + # _json_dict has the full set of keys, including those for h5 objects + # that may not have been loaded yet + return self._json_dict.keys() + + def get(self, k, default=None): + if isinstance(k, str): + k = k.strip() + if k not in self._dict: + self._root._root_get(self._path + (k,), self._dict) + return self._dict.get(k, default) + + def items(self): + # not all items may be loaded into self._dict, so we + # have to build the sequence of items from scratch + for k in self.keys(): + yield k, self[k] + + def __getitem__(self, k): + if isinstance(k, str): + k = k.strip() + + self.get(k) # load the value into _dict if it should be there + res = self._dict[k] + + return res + + def __contains__(self, k): + if isinstance(k, str): + k = k.strip() + + return k in self._json_dict + + def __setitem__(self, k, v): + if isinstance(k, str): + k = k.strip() + + path = self._path + + if isinstance(v, dict): + self._dict[k] = SummarySubDict(self._root, path + (k,)) + self._root._root_set(path, [(k, {})]) + self._dict[k].update(v) + else: + self._dict[k] = v + self._root._root_set(path, [(k, v)]) + + self._locked_keys.add(k) + + self._root._write() + + return v + + def __delitem__(self, k): + k = k.strip() + del self._dict[k] + self._root._root_del(self._path + (k,)) + + self._root._write() + + def __repr__(self): + # use a copy of _dict, except add placeholders for h5 objects, etc. + # that haven't been loaded yet + repr_dict = dict(self._dict) + for k in self._json_dict: + v = self._json_dict[k] + if ( + k not in repr_dict + and isinstance(v, dict) + and v.get("_type") in H5_TYPES + ): + # unloaded h5 objects may be very large. use a placeholder for them + # if we haven't already loaded them + repr_dict[k] = "..." 
+ else: + repr_dict[k] = self[k] + + return repr(repr_dict) + + def update(self, key_vals=None, overwrite=True): + """Locked keys will be overwritten unless overwrite=False. + + Otherwise, written keys will be added to the "locked" list. + """ + if key_vals: + write_items = self._update(key_vals, overwrite) + self._root._root_set(self._path, write_items) + self._root._write(commit=True) + + def _update(self, key_vals, overwrite): + if not key_vals: + return + key_vals = {k.strip(): v for k, v in key_vals.items()} + if overwrite: + write_items = list(key_vals.items()) + self._locked_keys.update(key_vals.keys()) + else: + write_keys = set(key_vals.keys()) - self._locked_keys + write_items = [(k, key_vals[k]) for k in write_keys] + + for key, value in write_items: + if isinstance(value, dict): + self._dict[key] = SummarySubDict(self._root, self._path + (key,)) + self._dict[key]._update(value, overwrite) + else: + self._dict[key] = value + + return write_items + + +class Summary(SummarySubDict): + """Store summary metrics (eg. accuracy) during and after a run. + + You can manipulate this as if it's a Python dictionary but the keys + get mangled. .strip() is called on them, so spaces at the beginning + and end are removed. + """ + + def __init__(self, run, summary=None): + super().__init__() + self._run = run + self._h5_path = os.path.join(self._run.dir, DEEP_SUMMARY_FNAME) + # Lazy load the h5 file + self._h5 = None + + # Mirrored version of self._dict with versions of values that get written + # to JSON kept up to date by self._root_set() and self._root_del(). 
+ self._json_dict = {} + + if summary is not None: + self._json_dict = summary + + def _json_get(self, path): + pass + + def _root_get(self, path, child_dict): + json_dict = self._json_dict + for key in path[:-1]: + json_dict = json_dict[key] + + key = path[-1] + if key in json_dict: + child_dict[key] = self._decode(path, json_dict[key]) + + def _root_del(self, path): + json_dict = self._json_dict + for key in path[:-1]: + json_dict = json_dict[key] + + val = json_dict[path[-1]] + del json_dict[path[-1]] + if isinstance(val, dict) and val.get("_type") in H5_TYPES: + if not h5py: + wandb.termerror("Deleting tensors in summary requires h5py") + else: + self.open_h5() + h5_key = "summary/" + ".".join(path) + del self._h5[h5_key] + self._h5.flush() + + def _root_set(self, path, new_keys_values): + json_dict = self._json_dict + for key in path: + json_dict = json_dict[key] + + for new_key, new_value in new_keys_values: + json_dict[new_key] = self._encode(new_value, path + (new_key,)) + + def write_h5(self, path, val): + # ensure the file is open + self.open_h5() + + if not self._h5: + wandb.termerror("Storing tensors in summary requires h5py") + else: + try: + del self._h5["summary/" + ".".join(path)] + except KeyError: + pass + self._h5["summary/" + ".".join(path)] = val + self._h5.flush() + + def read_h5(self, path, val=None): + # ensure the file is open + self.open_h5() + + if not self._h5: + wandb.termerror("Reading tensors from summary requires h5py") + else: + return self._h5.get("summary/" + ".".join(path), val) + + def open_h5(self): + if not self._h5 and h5py: + self._h5 = h5py.File(self._h5_path, "a", libver="latest") + + def _decode(self, path, json_value): + """Decode a `dict` encoded by `Summary._encode()`, loading h5 objects. + + h5 objects may be very large, so we won't have loaded them automatically. 
+ """ + if isinstance(json_value, dict): + if json_value.get("_type") in H5_TYPES: + return self.read_h5(path, json_value) + elif json_value.get("_type") == "data-frame": + wandb.termerror( + "This data frame was saved via the wandb data API. Contact support@wandb.com for help." + ) + return None + # TODO: transform wandb objects and plots + else: + return SummarySubDict(self, path) + else: + return json_value + + def _encode(self, value, path_from_root): + """Normalize, compress, and encode sub-objects for backend storage. + + value: Object to encode. + path_from_root: `tuple` of key strings from the top-level summary to the + current `value`. + + Returns: + A new tree of dict's with large objects replaced with dictionaries + with "_type" entries that say which type the original data was. + """ + + # Constructs a new `dict` tree in `json_value` that discards and/or + # encodes objects that aren't JSON serializable. + + if isinstance(value, dict): + json_value = {} + for key, value in value.items(): + json_value[key] = self._encode(value, path_from_root + (key,)) + return json_value + else: + path = ".".join(path_from_root) + friendly_value, converted = util.json_friendly( + val_to_json(self._run, path, value, namespace="summary") + ) + json_value, compressed = util.maybe_compress_summary( + friendly_value, util.get_h5_typename(value) + ) + if compressed: + self.write_h5(path_from_root, friendly_value) + + return json_value + + +def download_h5(run_id, entity=None, project=None, out_dir=None): + api = Api() + meta = api.download_url( + project or api.settings("project"), + DEEP_SUMMARY_FNAME, + entity=entity or api.settings("entity"), + run=run_id, + ) + if meta and "md5" in meta and meta["md5"] is not None: + # TODO: make this non-blocking + wandb.termlog("Downloading summary data...") + path, res = api.download_write_file(meta, out_dir=out_dir) + return path + + +def upload_h5(file, run_id, entity=None, project=None): + api = Api() + wandb.termlog("Uploading 
summary data...") + with open(file, "rb") as f: + api.push( + {os.path.basename(file): f}, run=run_id, project=project, entity=entity + ) + + +class FileSummary(Summary): + def __init__(self, run): + super().__init__(run) + self._fname = os.path.join(run.dir, filenames.SUMMARY_FNAME) + self.load() + + def load(self): + try: + with open(self._fname) as f: + self._json_dict = json.load(f) + except (OSError, ValueError): + self._json_dict = {} + + def _write(self, commit=False): + # TODO: we just ignore commit to ensure backward capability + with open(self._fname, "w") as f: + f.write(util.json_dumps_safer(self._json_dict)) + f.write("\n") + f.flush() + os.fsync(f.fileno()) + if self._h5: + self._h5.close() + self._h5 = None + + +class HTTPSummary(Summary): + def __init__(self, run, client, summary=None): + super().__init__(run, summary=summary) + self._run = run + self._client = client + self._started = time.time() + + def __delitem__(self, key): + if key not in self._json_dict: + raise KeyError(key) + del self._json_dict[key] + + def load(self): + pass + + def open_h5(self): + if not self._h5 and h5py: + download_h5( + self._run.id, + entity=self._run.entity, + project=self._run.project, + out_dir=self._run.dir, + ) + super().open_h5() + + def _write(self, commit=False): + mutation = gql( + """ + mutation UpsertBucket( $id: String, $summaryMetrics: JSONString) { + upsertBucket(input: { id: $id, summaryMetrics: $summaryMetrics}) { + bucket { id } + } + } + """ + ) + if commit: + if self._h5: + self._h5.close() + self._h5 = None + res = self._client.execute( + mutation, + variable_values={ + "id": self._run.storage_id, + "summaryMetrics": util.json_dumps_safer(self._json_dict), + }, + ) + assert res["upsertBucket"]["bucket"]["id"] + entity, project, run = self._run.path + if ( + os.path.exists(self._h5_path) + and os.path.getmtime(self._h5_path) >= self._started + ): + upload_h5(self._h5_path, run, entity=entity, project=project) + else: + return False diff --git 
a/.venv/lib/python3.13/site-packages/wandb/plot/__init__.py b/.venv/lib/python3.13/site-packages/wandb/plot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c49e40ac3ce6a7ffccd97c553a013bdd97fd305b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/plot/__init__.py @@ -0,0 +1,30 @@ +"""Chart Visualization Utilities + +This module offers a collection of predefined chart types, along with functionality +for creating custom charts, enabling flexible visualization of your data beyond the +built-in options. +""" + +__all__ = [ + "line", + "histogram", + "scatter", + "bar", + "roc_curve", + "pr_curve", + "confusion_matrix", + "line_series", + "plot_table", + "visualize", # doc:exclude +] + +from wandb.plot.bar import bar +from wandb.plot.confusion_matrix import confusion_matrix +from wandb.plot.custom_chart import CustomChart, plot_table +from wandb.plot.histogram import histogram +from wandb.plot.line import line +from wandb.plot.line_series import line_series +from wandb.plot.pr_curve import pr_curve +from wandb.plot.roc_curve import roc_curve +from wandb.plot.scatter import scatter +from wandb.plot.viz import Visualize, visualize diff --git a/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a05f2861dfbcd8b97e189f6cc4d7b3a2fe6e8cb6 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/bar.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/bar.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed9e446b2e8c0f5bbb73d08107ef9b9a2bbe5f57 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/bar.cpython-313.pyc differ diff 
--git a/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/confusion_matrix.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/confusion_matrix.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ee15476e0a57da658cd3bda3bb02da2471e10ea Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/confusion_matrix.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/custom_chart.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/custom_chart.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6b85b85b5c486c1ec7519978c187a228a87f867 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/custom_chart.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/histogram.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/histogram.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a84745efb42f30900e6a7cb91ffc65fc708a0071 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/histogram.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/line.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/line.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7b34377aa5219a9311ee3fdb8854a48c99743ec Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/line.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/line_series.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/line_series.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f899d7846d4850171b4b81dd0bd5a8b3a0decf85 Binary files /dev/null and 
b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/line_series.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/pr_curve.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/pr_curve.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc45d4cd4da53c641b0b6db55473a423d1f43c34 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/pr_curve.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/roc_curve.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/roc_curve.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3633f569183a988ec10738f84d5a6757ae004d7 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/roc_curve.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/scatter.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/scatter.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32c75f94f0d370af168ce463ee85bb127f9c7fe6 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/scatter.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/utils.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/utils.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a9e67a89dee3c01e3f951251269c505c42a10dc Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/utils.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/viz.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/viz.cpython-313.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..07361b78d4a7e393894022e5d4c6c9e35315a66b Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/plot/__pycache__/viz.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/plot/bar.py b/.venv/lib/python3.13/site-packages/wandb/plot/bar.py new file mode 100644 index 0000000000000000000000000000000000000000..21e6496df9e46a00204b78abb4c599ca0430a3c1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/plot/bar.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from wandb.plot.custom_chart import plot_table + +if TYPE_CHECKING: + import wandb + from wandb.plot.custom_chart import CustomChart + + +def bar( + table: wandb.Table, + label: str, + value: str, + title: str = "", + split_table: bool = False, +) -> CustomChart: + """Constructs a bar chart from a wandb.Table of data. + + Args: + table: A table containing the data for the bar chart. + label: The name of the column to use for the labels of each bar. + value: The name of the column to use for the values of each bar. + title: The title of the bar chart. + split_table: Whether the table should be split into a separate section + in the W&B UI. If `True`, the table will be displayed in a section named + "Custom Chart Tables". Default is `False`. + + Returns: + CustomChart: A custom chart object that can be logged to W&B. To log the + chart, pass it to `wandb.log()`. 
def confusion_matrix(
    probs: Sequence[Sequence[float]] | None = None,
    y_true: Sequence[T] | None = None,
    preds: Sequence[T] | None = None,
    class_names: Sequence[str] | None = None,
    title: str = "Confusion Matrix Curve",
    split_table: bool = False,
) -> CustomChart:
    """Constructs a confusion matrix from a sequence of probabilities or predictions.

    Args:
        probs: A sequence of predicted probabilities for each
            class. The sequence shape should be (N, K) where N is the number of samples
            and K is the number of classes. If provided, `preds` should not be provided.
        y_true: A sequence of true labels.
        preds: A sequence of predicted class labels. If provided,
            `probs` should not be provided.
        class_names: Sequence of class names. If not
            provided, class names will be defined as "Class_1", "Class_2", etc.
        title: Title of the confusion matrix chart.
        split_table: Whether the table should be split into a separate section
            in the W&B UI. If `True`, the table will be displayed in a section named
            "Custom Chart Tables". Default is `False`.

    Returns:
        CustomChart: A custom chart object that can be logged to W&B. To log the
            chart, pass it to `wandb.log()`.

    Raises:
        ValueError: If both `probs` and `preds` are provided, if neither is
            provided, if `y_true` is missing, or if the number of
            predictions and true labels are not equal. If the number of unique
            predicted classes exceeds the number of class names or if the number of
            unique true labels exceeds the number of class names.
        wandb.Error: If numpy is not installed.

    Examples:
    Logging a confusion matrix with random probabilities for wildlife
    classification:

    ```python
    import numpy as np
    import wandb

    # Define class names for wildlife
    wildlife_class_names = ["Lion", "Tiger", "Elephant", "Zebra"]

    # Generate random true labels (0 to 3 for 10 samples)
    wildlife_y_true = np.random.randint(0, 4, size=10)

    # Generate random probabilities for each class (10 samples x 4 classes)
    wildlife_probs = np.random.rand(10, 4)
    wildlife_probs = np.exp(wildlife_probs) / np.sum(
        np.exp(wildlife_probs),
        axis=1,
        keepdims=True,
    )

    # Initialize W&B run and log confusion matrix
    with wandb.init(project="wildlife_classification") as run:
        confusion_matrix = wandb.plot.confusion_matrix(
            probs=wildlife_probs,
            y_true=wildlife_y_true,
            class_names=wildlife_class_names,
            title="Wildlife Classification Confusion Matrix",
        )
        run.log({"wildlife_confusion_matrix": confusion_matrix})
    ```

    In this example, random probabilities are used to generate a confusion
    matrix.

    Logging a confusion matrix with simulated model predictions and 85%
    accuracy:

    ```python
    import numpy as np
    import wandb

    # Define class names for wildlife
    wildlife_class_names = ["Lion", "Tiger", "Elephant", "Zebra"]

    # Simulate true labels for 200 animal images (imbalanced distribution)
    wildlife_y_true = np.random.choice(
        [0, 1, 2, 3],
        size=200,
        p=[0.2, 0.3, 0.25, 0.25],
    )

    # Simulate model predictions with 85% accuracy
    wildlife_preds = [
        y_t
        if np.random.rand() < 0.85
        else np.random.choice([x for x in range(4) if x != y_t])
        for y_t in wildlife_y_true
    ]

    # Initialize W&B run and log confusion matrix
    with wandb.init(project="wildlife_classification") as run:
        confusion_matrix = wandb.plot.confusion_matrix(
            preds=wildlife_preds,
            y_true=wildlife_y_true,
            class_names=wildlife_class_names,
            title="Simulated Wildlife Classification Confusion Matrix",
        )
        run.log({"wildlife_confusion_matrix": confusion_matrix})
    ```

    In this example, predictions are simulated with 85% accuracy to generate a
    confusion matrix.
    """
    # The message must be a plain string; a trailing comma here would turn
    # it into a 1-tuple.
    np = util.get_module(
        "numpy",
        required=(
            "numpy is required to use wandb.plot.confusion_matrix, "
            "install with `pip install numpy`"
        ),
    )

    if probs is not None and preds is not None:
        raise ValueError("Only one of `probs` or `preds` should be provided, not both.")

    # Fail early with a readable error instead of `len(None)` further down.
    if probs is None and preds is None:
        raise ValueError("One of `probs` or `preds` must be provided.")
    if y_true is None:
        raise ValueError("`y_true` must be provided.")

    if probs is not None:
        # Predicted class = index of the highest probability in each row.
        preds = np.argmax(probs, axis=1).tolist()

    if len(preds) != len(y_true):
        raise ValueError("The number of predictions and true labels must be equal.")

    if class_names is not None:
        n_classes = len(class_names)
        # With explicit names, labels are looked up as positions 0..n_classes-1.
        class_idx = list(range(n_classes))
        if len(set(preds)) > len(class_names):
            raise ValueError(
                "The number of unique predicted classes exceeds the number of class names."
            )

        if len(set(y_true)) > len(class_names):
            raise ValueError(
                "The number of unique true labels exceeds the number of class names."
            )
    else:
        class_idx = set(preds).union(set(y_true))
        n_classes = len(class_idx)
        class_names = [f"Class_{i + 1}" for i in range(n_classes)]

    # Map every class value to its row/column position in the count matrix.
    class_mapping = {val: i for i, val in enumerate(sorted(class_idx))}

    counts = np.zeros((n_classes, n_classes))
    # Lengths were verified equal above, so pairing by position is safe.
    for actual, predicted in zip(y_true, preds):
        counts[class_mapping[actual], class_mapping[predicted]] += 1

    data = [
        [class_names[i], class_names[j], counts[i, j]]
        for i in range(n_classes)
        for j in range(n_classes)
    ]

    return plot_table(
        data_table=wandb.Table(
            columns=["Actual", "Predicted", "nPredictions"],
            data=data,
        ),
        vega_spec_name="wandb/confusion_matrix/v1",
        fields={
            "Actual": "Actual",
            "Predicted": "Predicted",
            "nPredictions": "nPredictions",
        },
        string_fields={"title": title},
        split_table=split_table,
    )
"tableWithLeafColNames"}, + "userQuery": { + "queryFields": [ + { + "name": "runSets", + "args": [{"name": "runSets", "value": "${runSets}"}], + "fields": [ + {"name": "id", "fields": []}, + {"name": "name", "fields": []}, + {"name": "_defaultColorIndex", "fields": []}, + { + "name": "summaryTable", + "args": [ + { + "name": "tableKey", + "value": self.table_key, + } + ], + "fields": [], + }, + ], + } + ], + }, + }, + } + + @property + def config_key(self) -> tuple[str, str, str]: + return ("_wandb", "visualize", self.key) + + +@dataclass +class CustomChart: + table: wandb.Table + spec: CustomChartSpec + + def set_key(self, key: str): + """Sets the key for the spec and updates dependent configurations.""" + self.spec.key = key + + +def plot_table( + vega_spec_name: str, + data_table: wandb.Table, + fields: dict[str, Any], + string_fields: dict[str, Any] | None = None, + split_table: bool = False, +) -> CustomChart: + """Creates a custom charts using a Vega-Lite specification and a `wandb.Table`. + + This function creates a custom chart based on a Vega-Lite specification and + a data table represented by a `wandb.Table` object. The specification needs + to be predefined and stored in the W&B backend. The function returns a custom + chart object that can be logged to W&B using `wandb.Run.log()`. + + Args: + vega_spec_name: The name or identifier of the Vega-Lite spec + that defines the visualization structure. + data_table: A `wandb.Table` object containing the data to be + visualized. + fields: A mapping between the fields in the Vega-Lite spec and the + corresponding columns in the data table to be visualized. + string_fields: A dictionary for providing values for any string constants + required by the custom visualization. + split_table: Whether the table should be split into a separate section + in the W&B UI. If `True`, the table will be displayed in a section named + "Custom Chart Tables". Default is `False`. 
def histogram(
    table: wandb.Table,
    value: str,
    title: str = "",
    split_table: bool = False,
) -> CustomChart:
    """Builds a histogram custom chart from a W&B Table.

    Args:
        table: The W&B Table holding the data to plot.
        value: Passed as the `value` field of the `wandb/histogram/v0`
            spec; the label for the bin axis (x-axis).
        title: The title shown on the histogram.
        split_table: Whether the table should be split into a separate section
            in the W&B UI. If `True`, the table will be displayed in a section named
            "Custom Chart Tables". Default is `False`.

    Returns:
        CustomChart: A custom chart object that can be logged to W&B. To log the
            chart, pass it to `wandb.log()`.

    Example:

    ```python
    import math
    import random
    import wandb

    # Generate random data
    data = [[i, random.random() + math.sin(i / 10)] for i in range(100)]

    # Create a W&B Table
    table = wandb.Table(
        data=data,
        columns=["step", "height"],
    )

    # Create a histogram plot
    histogram = wandb.plot.histogram(
        table,
        value="height",
        title="My Histogram",
    )

    # Log the histogram plot to W&B
    with wandb.init(...) as run:
        run.log({"histogram-plot1": histogram})
    ```
    """
    # Spec field -> table column, and spec string constant -> value.
    column_fields = {"value": value}
    text_settings = {"title": title}
    return plot_table(
        vega_spec_name="wandb/histogram/v0",
        data_table=table,
        fields=column_fields,
        string_fields=text_settings,
        split_table=split_table,
    )
+ stroke: Column name to differentiate line strokes (e.g., for + grouping lines). + title: Title of the chart. + split_table: Whether the table should be split into a separate section + in the W&B UI. If `True`, the table will be displayed in a section named + "Custom Chart Tables". Default is `False`. + + Returns: + CustomChart: A custom chart object that can be logged to W&B. To log the + chart, pass it to `wandb.log()`. + + Example: + + ```python + import math + import random + import wandb + + # Create multiple series of data with different patterns + data = [] + for i in range(100): + # Series 1: Sinusoidal pattern with random noise + data.append([i, math.sin(i / 10) + random.uniform(-0.1, 0.1), "series_1"]) + # Series 2: Cosine pattern with random noise + data.append([i, math.cos(i / 10) + random.uniform(-0.1, 0.1), "series_2"]) + # Series 3: Linear increase with random noise + data.append([i, i / 10 + random.uniform(-0.5, 0.5), "series_3"]) + + # Define the columns for the table + table = wandb.Table(data=data, columns=["step", "value", "series"]) + + # Initialize wandb run and log the line chart + with wandb.init(project="line_chart_example") as run: + line_chart = wandb.plot.line( + table=table, + x="step", + y="value", + stroke="series", # Group by the "series" column + title="Multi-Series Line Plot", + ) + run.log({"line-chart": line_chart}) + ``` + """ + return plot_table( + data_table=table, + vega_spec_name="wandb/line/v0", + fields={"x": x, "y": y, "stroke": stroke}, + string_fields={"title": title}, + split_table=split_table, + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/plot/line_series.py b/.venv/lib/python3.13/site-packages/wandb/plot/line_series.py new file mode 100644 index 0000000000000000000000000000000000000000..de43caa5495d4699d900521497498bd9cdfce259 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/plot/line_series.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, 
Iterable + +import wandb +from wandb.plot.custom_chart import plot_table + +if TYPE_CHECKING: + from wandb.plot.custom_chart import CustomChart + + +def line_series( + xs: Iterable[Iterable[Any]] | Iterable[Any], + ys: Iterable[Iterable[Any]], + keys: Iterable[str] | None = None, + title: str = "", + xname: str = "x", + split_table: bool = False, +) -> CustomChart: + """Constructs a line series chart. + + Args: + xs: Sequence of x values. If a singular + array is provided, all y values are plotted against that x array. If + an array of arrays is provided, each y value is plotted against the + corresponding x array. + ys: Sequence of y values, where each iterable represents + a separate line series. + keys: Sequence of keys for labeling each line series. If + not provided, keys will be automatically generated as "line_1", + "line_2", etc. + title: Title of the chart. + xname: Label for the x-axis. + split_table: Whether the table should be split into a separate section + in the W&B UI. If `True`, the table will be displayed in a section named + "Custom Chart Tables". Default is `False`. + + Returns: + CustomChart: A custom chart object that can be logged to W&B. To log the + chart, pass it to `wandb.log()`. + + Examples: + Logging a single x array where all y series are plotted against the same x values: + + ```python + import wandb + + # Initialize W&B run + with wandb.init(project="line_series_example") as run: + # x values shared across all y series + xs = list(range(10)) + + # Multiple y series to plot + ys = [ + [i for i in range(10)], # y = x + [i**2 for i in range(10)], # y = x^2 + [i**3 for i in range(10)], # y = x^3 + ] + + # Generate and log the line series chart + line_series_chart = wandb.plot.line_series( + xs, + ys, + title="title", + xname="step", + ) + run.log({"line-series-single-x": line_series_chart}) + ``` + + In this example, a single `xs` series (shared x-values) is used for all + `ys` series. 
This results in each y-series being plotted against the + same x-values (0-9). + + Logging multiple x arrays where each y series is plotted against its corresponding x array: + + ```python + import wandb + + # Initialize W&B run + with wandb.init(project="line_series_example") as run: + # Separate x values for each y series + xs = [ + [i for i in range(10)], # x for first series + [2 * i for i in range(10)], # x for second series (stretched) + [3 * i for i in range(10)], # x for third series (stretched more) + ] + + # Corresponding y series + ys = [ + [i for i in range(10)], # y = x + [i**2 for i in range(10)], # y = x^2 + [i**3 for i in range(10)], # y = x^3 + ] + + # Generate and log the line series chart + line_series_chart = wandb.plot.line_series( + xs, ys, title="Multiple X Arrays Example", xname="Step" + ) + run.log({"line-series-multiple-x": line_series_chart}) + ``` + + In this example, each y series is plotted against its own unique x series. + This allows for more flexibility when the x values are not uniform across + the data series. + + Customizing line labels using `keys`: + + ```python + import wandb + + # Initialize W&B run + with wandb.init(project="line_series_example") as run: + xs = list(range(10)) # Single x array + ys = [ + [i for i in range(10)], # y = x + [i**2 for i in range(10)], # y = x^2 + [i**3 for i in range(10)], # y = x^3 + ] + + # Custom labels for each line + keys = ["Linear", "Quadratic", "Cubic"] + + # Generate and log the line series chart + line_series_chart = wandb.plot.line_series( + xs, + ys, + keys=keys, # Custom keys (line labels) + title="Custom Line Labels Example", + xname="Step", + ) + run.log({"line-series-custom-keys": line_series_chart}) + ``` + + This example shows how to provide custom labels for the lines using + the `keys` argument. The keys will appear in the legend as "Linear", + "Quadratic", and "Cubic". 
+ """ + # If xs is a single array, repeat it for each y in ys + if not isinstance(xs[0], Iterable) or isinstance(xs[0], (str, bytes)): + xs = [xs] * len(ys) + + if len(xs) != len(ys): + msg = f"Number of x-series ({len(xs)}) must match y-series ({len(ys)})." + raise ValueError(msg) + + if keys is None: + keys = [f"line_{i}" for i in range(len(ys))] + + if len(keys) != len(ys): + msg = f"Number of keys ({len(keys)}) must match y-series ({len(ys)})." + raise ValueError(msg) + + data = [ + [x, keys[i], y] + for i, (xx, yy) in enumerate(zip(xs, ys)) + for x, y in zip(xx, yy) + ] + table = wandb.Table( + data=data, + columns=["step", "lineKey", "lineVal"], + ) + + return plot_table( + data_table=table, + vega_spec_name="wandb/lineseries/v0", + fields={ + "step": "step", + "lineKey": "lineKey", + "lineVal": "lineVal", + }, + string_fields={"title": title, "xname": xname}, + split_table=split_table, + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/plot/pr_curve.py b/.venv/lib/python3.13/site-packages/wandb/plot/pr_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..a8ade060c3dab946002f48d3b3164ed901b8efe3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/plot/pr_curve.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +import numbers +from typing import TYPE_CHECKING, Iterable, TypeVar + +import wandb +from wandb import util +from wandb.plot.custom_chart import plot_table +from wandb.plot.utils import test_missing, test_types + +if TYPE_CHECKING: + from wandb.plot.custom_chart import CustomChart + + +T = TypeVar("T") + + +def pr_curve( + y_true: Iterable[T] | None = None, + y_probas: Iterable[numbers.Number] | None = None, + labels: list[str] | None = None, + classes_to_plot: list[T] | None = None, + interp_size: int = 21, + title: str = "Precision-Recall Curve", + split_table: bool = False, +) -> CustomChart: + """Constructs a Precision-Recall (PR) curve. 
+ + The Precision-Recall curve is particularly useful for evaluating classifiers + on imbalanced datasets. A high area under the PR curve signifies both high + precision (a low false positive rate) and high recall (a low false negative + rate). The curve provides insights into the balance between false positives + and false negatives at various threshold levels, aiding in the assessment of + a model's performance. + + Args: + y_true: True binary labels. The shape should be (`num_samples`,). + y_probas: Predicted scores or probabilities for each class. + These can be probability estimates, confidence scores, or non-thresholded + decision values. The shape should be (`num_samples`, `num_classes`). + labels: Optional list of class names to replace + numeric values in `y_true` for easier plot interpretation. + For example, `labels = ['dog', 'cat', 'owl']` will replace 0 with + 'dog', 1 with 'cat', and 2 with 'owl' in the plot. If not provided, + numeric values from `y_true` will be used. + classes_to_plot: Optional list of unique class values from + y_true to be included in the plot. If not specified, all unique + classes in y_true will be plotted. + interp_size: Number of points to interpolate recall values. The + recall values will be fixed to `interp_size` uniformly distributed + points in the range [0, 1], and the precision will be interpolated + accordingly. + title: Title of the plot. Defaults to "Precision-Recall Curve". + split_table: Whether the table should be split into a separate section + in the W&B UI. If `True`, the table will be displayed in a section named + "Custom Chart Tables". Default is `False`. + + Returns: + CustomChart: A custom chart object that can be logged to W&B. To log the + chart, pass it to `wandb.log()`. + + Raises: + wandb.Error: If NumPy, pandas, or scikit-learn is not installed. 
+ + + Example: + + ```python + import wandb + + # Example for spam detection (binary classification) + y_true = [0, 1, 1, 0, 1] # 0 = not spam, 1 = spam + y_probas = [ + [0.9, 0.1], # Predicted probabilities for the first sample (not spam) + [0.2, 0.8], # Second sample (spam), and so on + [0.1, 0.9], + [0.8, 0.2], + [0.3, 0.7], + ] + + labels = ["not spam", "spam"] # Optional class names for readability + + with wandb.init(project="spam-detection") as run: + pr_curve = wandb.plot.pr_curve( + y_true=y_true, + y_probas=y_probas, + labels=labels, + title="Precision-Recall Curve for Spam Detection", + ) + run.log({"pr-curve": pr_curve}) + ``` + """ + np = util.get_module( + "numpy", + required="roc requires the numpy library, install with `pip install numpy`", + ) + pd = util.get_module( + "pandas", + required="roc requires the pandas library, install with `pip install pandas`", + ) + sklearn_metrics = util.get_module( + "sklearn.metrics", + "roc requires the scikit library, install with `pip install scikit-learn`", + ) + sklearn_utils = util.get_module( + "sklearn.utils", + "roc requires the scikit library, install with `pip install scikit-learn`", + ) + + def _step(x): + y = np.array(x) + for i in range(1, len(y)): + y[i] = max(y[i], y[i - 1]) + return y + + y_true = np.array(y_true) + y_probas = np.array(y_probas) + + if not test_missing(y_true=y_true, y_probas=y_probas): + return + if not test_types(y_true=y_true, y_probas=y_probas): + return + + classes = np.unique(y_true) + if classes_to_plot is None: + classes_to_plot = classes + + precision = {} + interp_recall = np.linspace(0, 1, interp_size)[::-1] + indices_to_plot = np.where(np.isin(classes, classes_to_plot))[0] + for i in indices_to_plot: + if labels is not None and ( + isinstance(classes[i], int) or isinstance(classes[0], np.integer) + ): + class_label = labels[classes[i]] + else: + class_label = classes[i] + + cur_precision, cur_recall, _ = sklearn_metrics.precision_recall_curve( + y_true, y_probas[:, i], 
def roc_curve(
    y_true: Sequence[numbers.Number],
    y_probas: Sequence[Sequence[float]] | None = None,
    labels: list[str] | None = None,
    classes_to_plot: list[numbers.Number] | None = None,
    title: str = "ROC Curve",
    split_table: bool = False,
) -> CustomChart | None:
    """Constructs Receiver Operating Characteristic (ROC) curve chart.

    One ROC curve is computed per plotted class with
    `sklearn.metrics.roc_curve`, using that class as the positive label and
    column `i` of `y_probas` as its score, where `i` is the position of the
    class in the sorted unique values of `y_true`.

    Args:
        y_true: The true class labels (ground truth)
            for the target variable. Shape should be (num_samples,).
        y_probas: The predicted probabilities or
            decision scores for each class. Shape should be
            (num_samples, num_classes).
        labels: Human-readable labels corresponding to the class
            indices in `y_true`. For example, if `labels=['dog', 'cat']`,
            class 0 will be displayed as 'dog' and class 1 as 'cat' in the
            plot. Only used when the class value is an integer. If None,
            the raw class values from `y_true` are used. Default is None.
        classes_to_plot: A subset of unique class labels
            to include in the ROC curve. If None, all classes in `y_true`
            will be plotted. Default is None.
        title: Title of the ROC curve plot. Default is "ROC Curve".
        split_table: Whether the table should be split into a separate
            section in the W&B UI. If `True`, the table will be displayed
            in a section named "Custom Chart Tables". Default is `False`.

    Returns:
        CustomChart: A custom chart object that can be logged to W&B. To log
        the chart, pass it to `wandb.log()`. `None` is returned when the
        inputs fail the `test_missing` / `test_types` validation (the
        problem is reported by those helpers).

    Raises:
        wandb.Error: If numpy, pandas, or scikit-learn are not found.
        ValueError: If no class is left to plot (for example when
            `classes_to_plot` matches none of the classes in `y_true`).

    Example:
    ```python
    import numpy as np
    import wandb

    # Simulate a medical diagnosis classification problem with three diseases
    n_samples = 200
    n_classes = 3

    # 0: Diabetes, 1: Hypertension, 2: Heart Disease
    disease_labels = ["Diabetes", "Hypertension", "Heart Disease"]
    y_true = np.random.choice([0, 1, 2], size=n_samples)

    # Predicted probabilities: simulate predictions that sum to 1 per sample
    y_probas = np.random.dirichlet(np.ones(n_classes), size=n_samples)

    # Initialize a W&B run and log a ROC curve plot for disease classification
    with wandb.init(project="medical_diagnosis") as run:
        roc_plot = wandb.plot.roc_curve(
            y_true=y_true,
            y_probas=y_probas,
            labels=disease_labels,
            classes_to_plot=[0, 1, 2],
            title="ROC Curve for Disease Classification",
        )
        run.log({"roc-curve": roc_plot})
    ```
    """
    np = util.get_module(
        "numpy",
        required="roc requires the numpy library, install with `pip install numpy`",
    )
    pd = util.get_module(
        "pandas",
        required="roc requires the pandas library, install with `pip install pandas`",
    )
    sklearn_metrics = util.get_module(
        "sklearn.metrics",
        "roc requires the scikit library, install with `pip install scikit-learn`",
    )
    sklearn_utils = util.get_module(
        "sklearn.utils",
        "roc requires the scikit library, install with `pip install scikit-learn`",
    )

    y_true = np.array(y_true)
    y_probas = np.array(y_probas)

    # Both validators report the problem themselves; bail out without a chart.
    if not test_missing(y_true=y_true, y_probas=y_probas):
        return None
    if not test_types(y_true=y_true, y_probas=y_probas):
        return None

    classes = np.unique(y_true)
    if classes_to_plot is None:
        classes_to_plot = classes

    fpr = {}
    tpr = {}
    indices_to_plot = np.where(np.isin(classes, classes_to_plot))[0]
    for i in indices_to_plot:
        current_class = classes[i]
        # Decide per class (not from classes[0]) whether it can index `labels`.
        if labels is not None and isinstance(current_class, (int, np.integer)):
            class_label = labels[current_class]
        else:
            class_label = current_class

        fpr[class_label], tpr[class_label], _ = sklearn_metrics.roc_curve(
            y_true, y_probas[..., i], pos_label=current_class
        )

    if not fpr:
        # np.hstack([]) below would raise an opaque ValueError; keep the
        # exception type but say what actually went wrong.
        raise ValueError(
            "No classes to plot: `classes_to_plot` matches none of the classes "
            "found in `y_true`."
        )

    df = pd.DataFrame(
        {
            "class": np.hstack([[k] * len(v) for k, v in fpr.items()]),
            "fpr": np.hstack(list(fpr.values())),
            "tpr": np.hstack(list(tpr.values())),
        }
    ).round(3)

    if len(df) > wandb.Table.MAX_ROWS:
        wandb.termwarn(
            f"wandb uses only {wandb.Table.MAX_ROWS} data points to create the plots."
        )
        # different sampling could be applied, possibly to ensure endpoints are kept
        df = sklearn_utils.resample(
            df,
            replace=False,
            n_samples=wandb.Table.MAX_ROWS,
            random_state=42,
            stratify=df["class"],
        ).sort_values(["fpr", "tpr", "class"])

    return plot_table(
        data_table=wandb.Table(dataframe=df),
        vega_spec_name="wandb/area-under-curve/v0",
        fields={
            "x": "fpr",
            "y": "tpr",
            "class": "class",
        },
        string_fields={
            "title": title,
            "x-axis-title": "False positive rate",
            "y-axis-title": "True positive rate",
        },
        split_table=split_table,
    )
def test_missing(**kwargs):
    """Check that every keyword argument is present, complete and numeric.

    For each ``name=value`` pair the value must not be ``None``, must not
    contain missing values (as reported by ``pandas.isnull``) and must only
    contain numbers. Values passed as ``X`` or ``X_test`` are first converted
    to a numpy array when they are scipy sparse containers, pandas objects or
    lists; other values are expected to already expose ``.ndim``.

    Problems are reported through ``wandb.termerror`` / ``wandb.termwarn``.

    Returns:
        bool: ``True`` when every argument passed all checks, else ``False``.
    """
    np = util.get_module("numpy", required="Logging plots requires numpy")
    pd = util.get_module("pandas", required="Logging dataframes requires pandas")
    # Ask for the public ``scipy.sparse`` namespace directly; the
    # ``scipy.sparse.csr`` submodule path is a deprecated private alias.
    sparse = util.get_module(
        "scipy.sparse", required="Logging scipy matrices requires scipy"
    )

    test_passed = True
    for k, v in kwargs.items():
        # Missing/empty params/datapoint arrays
        if v is None:
            wandb.termerror(f"{k} is None. Please try again.")
            test_passed = False
            # Nothing else can be inspected on None (it has no ``ndim``).
            continue
        if k in ("X", "X_test"):
            if sparse.issparse(v):
                v = v.toarray()
            elif isinstance(v, (pd.DataFrame, pd.Series)):
                v = v.to_numpy()
            elif isinstance(v, list):
                v = np.asarray(v)

        # Warn the user about missing values
        missing = np.count_nonzero(pd.isnull(v))
        if missing > 0:
            wandb.termwarn(f"{k} contains {missing} missing values. ")
            test_passed = False

        # Ensure the dataset contains only numbers
        if v.ndim == 0:
            # 0-d arrays (e.g. ``np.array(None)``) cannot be iterated.
            values = [v.item()]
        elif v.ndim == 1:
            values = v
        else:
            values = (val for row in v for val in row)
        non_nums = sum(
            1
            for val in values
            if not isinstance(val, (int, float, complex, np.number))
        )
        if non_nums > 0:
            wandb.termerror(
                f"{k} contains values that are not numbers. Please vectorize, "
                f"label encode or one hot encode {k} and call the plotting function again."
            )
            test_passed = False
    return test_passed


def test_fitted(model):
    """Return whether ``model`` looks like a fitted scikit-learn estimator.

    The model is probed with ``model.predict`` on a dummy ``(7, 3)`` array of
    zeros. A ``NotFittedError`` means "not fitted"; any other exception (for
    example a feature-count mismatch) is taken as evidence the model is
    fitted. Models without ``predict`` are checked with
    ``sklearn.utils.validation.check_is_fitted`` against a list of common
    fitted attributes.

    Returns:
        bool: ``True`` if the model is considered fitted, ``False`` otherwise
        (an error is then reported through ``wandb.termerror``).
    """
    np = util.get_module("numpy", required="Logging plots requires numpy")
    _ = util.get_module("pandas", required="Logging dataframes requires pandas")
    _ = util.get_module("scipy", required="Logging scipy matrices requires scipy")
    scikit_utils = util.get_module(
        "sklearn.utils",
        required="roc requires the scikit utils submodule, install with `pip install scikit-learn`",
    )
    scikit_exceptions = util.get_module(
        "sklearn.exceptions",
        "roc requires the scikit preprocessing submodule, install with `pip install scikit-learn`",
    )

    try:
        model.predict(np.zeros((7, 3)))
    except scikit_exceptions.NotFittedError:
        wandb.termerror("Please fit the model before passing it in.")
        return False
    except AttributeError:
        # Some clustering models (LDA, PCA, Agglomerative) don't implement ``predict``
        try:
            scikit_utils.validation.check_is_fitted(
                model,
                [
                    "coef_",
                    "estimator_",
                    "labels_",
                    "n_clusters_",
                    "children_",
                    "components_",
                    "n_components_",
                    "n_iter_",
                    "n_batch_iter_",
                    "explained_variance_",
                    "singular_values_",
                    "mean_",
                ],
                all_or_any=any,
            )
        except scikit_exceptions.NotFittedError:
            wandb.termerror("Please fit the model before passing it in.")
            return False
        return True
    except Exception:
        # Assume it's fitted, since ``NotFittedError`` wasn't raised
        return True
    # ``predict`` succeeded on the dummy input, so the model is fitted.
    # (Previously this path fell through and returned None.)
    return True


def encode_labels(df):
    """Label-encode the non-numeric columns of ``df`` in place.

    Every column whose dtype is not one of the listed int/float dtypes is
    replaced by the integer codes produced by
    ``sklearn.preprocessing.LabelEncoder.fit_transform``. The dataframe is
    modified in place and nothing is returned.
    """
    _ = util.get_module("pandas", required="Logging dataframes requires pandas")
    preprocessing = util.get_module(
        "sklearn.preprocessing",
        "roc requires the scikit preprocessing submodule, install with `pip install scikit-learn`",
    )

    le = preprocessing.LabelEncoder()
    # apply le on categorical feature columns
    categorical_cols = df.select_dtypes(
        exclude=["int", "float", "float64", "float32", "int32", "int64"]
    ).columns
    # The encoder is refit for each column by ``fit_transform``.
    df[categorical_cols] = df[categorical_cols].apply(le.fit_transform)


def test_types(**kwargs):
    """Check that each keyword argument has the type its name implies.

    Array-like arguments (``X``, ``y``, ``y_true``, ``y_probas``, ...) must be
    a sequence/iterable, numpy array or scalar, or pandas object; ``model``
    must be a classifier or regressor; ``clf`` / ``binary_clf`` a classifier;
    ``regressor`` a regressor; ``clusterer`` an estimator whose
    ``_estimator_type`` is ``"clusterer"``. Other names are not checked.

    Problems are reported through ``wandb.termerror``.

    Returns:
        bool: ``True`` when every checked argument has an accepted type.
    """
    np = util.get_module("numpy", required="Logging plots requires numpy")
    pd = util.get_module("pandas", required="Logging dataframes requires pandas")
    _ = util.get_module("scipy", required="Logging scipy matrices requires scipy")

    base = util.get_module(
        "sklearn.base",
        "roc requires the scikit base submodule, install with `pip install scikit-learn`",
    )

    array_keys = (
        "X",
        "X_test",
        "y",
        "y_test",
        "y_true",
        "y_probas",
        "x_labels",
        "y_labels",
        "matrix_values",
    )
    # FIXME: do this individually
    # NOTE: ``Iterable`` makes this accept any iterable, strings included.
    array_types = (
        Sequence,
        Iterable,
        np.ndarray,
        np.generic,
        pd.DataFrame,
        pd.Series,
        list,
    )

    test_passed = True
    for k, v in kwargs.items():
        # check for incorrect types
        if k in array_keys and not isinstance(v, array_types):
            wandb.termerror(f"{k} is not an array. Please try again.")
            test_passed = False
        # check for classifier types
        if k == "model":
            if (not base.is_classifier(v)) and (not base.is_regressor(v)):
                wandb.termerror(
                    f"{k} is not a classifier or regressor. Please try again."
                )
                test_passed = False
        elif k == "clf" or k == "binary_clf":
            if not base.is_classifier(v):
                wandb.termerror(f"{k} is not a classifier. Please try again.")
                test_passed = False
        elif k == "regressor":
            if not base.is_regressor(v):
                wandb.termerror(f"{k} is not a regressor. Please try again.")
                test_passed = False
        elif k == "clusterer":
            if getattr(v, "_estimator_type", None) != "clusterer":
                wandb.termerror(f"{k} is not a clusterer. Please try again.")
                test_passed = False
    return test_passed
def get_pip_package_version(package_name: str) -> str:
    """Return the installed version string of ``package_name``.

    Args:
        package_name: Distribution name as known to the installer
            (looked up with ``importlib.metadata.version``).

    Returns:
        The version string recorded in the distribution's metadata.

    Raises:
        ValueError: If no installed distribution has that name. The original
            ``PackageNotFoundError`` is attached as ``__cause__``.
    """
    try:
        return importlib.metadata.version(package_name)
    except importlib.metadata.PackageNotFoundError as err:
        # Chain explicitly so the traceback shows the lookup failure as the
        # cause instead of "another exception occurred during handling".
        raise ValueError(f"Package `{package_name}` not found") from err
`wandb` only works with major" + " versions 3, 4, 5, and 6 of the protobuf package.", + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/proto/wandb_server_pb2.py b/.venv/lib/python3.13/site-packages/wandb/proto/wandb_server_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..d30616eb0902a99ce4180d7760a975b3e4383bc7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/proto/wandb_server_pb2.py @@ -0,0 +1,12 @@ +import google.protobuf + +protobuf_version = google.protobuf.__version__[0] + +if protobuf_version == "3": + from wandb.proto.v3.wandb_server_pb2 import * +elif protobuf_version == "4": + from wandb.proto.v4.wandb_server_pb2 import * +elif protobuf_version == "5": + from wandb.proto.v5.wandb_server_pb2 import * +elif protobuf_version == "6": + from wandb.proto.v6.wandb_server_pb2 import * diff --git a/.venv/lib/python3.13/site-packages/wandb/proto/wandb_settings_pb2.py b/.venv/lib/python3.13/site-packages/wandb/proto/wandb_settings_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..0600cc522c3e7dae1222d19779652da44010a100 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/proto/wandb_settings_pb2.py @@ -0,0 +1,12 @@ +import google.protobuf + +protobuf_version = google.protobuf.__version__[0] + +if protobuf_version == "3": + from wandb.proto.v3.wandb_settings_pb2 import * +elif protobuf_version == "4": + from wandb.proto.v4.wandb_settings_pb2 import * +elif protobuf_version == "5": + from wandb.proto.v5.wandb_settings_pb2 import * +elif protobuf_version == "6": + from wandb.proto.v6.wandb_settings_pb2 import * diff --git a/.venv/lib/python3.13/site-packages/wandb/proto/wandb_sync_pb2.py b/.venv/lib/python3.13/site-packages/wandb/proto/wandb_sync_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..119b214a76c18efe2e96d4b4e23c2a500a17b94f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/proto/wandb_sync_pb2.py @@ -0,0 +1,12 @@ +import google.protobuf + 
+protobuf_version = google.protobuf.__version__[0] + +if protobuf_version == "3": + from wandb.proto.v3.wandb_sync_pb2 import * +elif protobuf_version == "4": + from wandb.proto.v4.wandb_sync_pb2 import * +elif protobuf_version == "5": + from wandb.proto.v5.wandb_sync_pb2 import * +elif protobuf_version == "6": + from wandb.proto.v6.wandb_sync_pb2 import * diff --git a/.venv/lib/python3.13/site-packages/wandb/proto/wandb_telemetry_pb2.py b/.venv/lib/python3.13/site-packages/wandb/proto/wandb_telemetry_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..19c81ce3f03a8906b4108f3f9f5d602d2b84f792 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/proto/wandb_telemetry_pb2.py @@ -0,0 +1,12 @@ +import google.protobuf + +protobuf_version = google.protobuf.__version__[0] + +if protobuf_version == "3": + from wandb.proto.v3.wandb_telemetry_pb2 import * +elif protobuf_version == "4": + from wandb.proto.v4.wandb_telemetry_pb2 import * +elif protobuf_version == "5": + from wandb.proto.v5.wandb_telemetry_pb2 import * +elif protobuf_version == "6": + from wandb.proto.v6.wandb_telemetry_pb2 import * diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/__init__.py b/.venv/lib/python3.13/site-packages/wandb/sdk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aca29299d736832f386133698d85885bdbe36c38 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/__init__.py @@ -0,0 +1,36 @@ +"""W&B SDK module.""" + +__all__ = ( + "Config", + "Settings", + "Summary", + "Artifact", + "AlertLevel", + "init", + "setup", + "_attach", + "_sync", + "login", + "require", + "finish", + "teardown", + "_watch", + "_unwatch", + "sweep", + "controller", + "helper", +) + +from . 
import wandb_helper as helper +from .artifacts.artifact import Artifact +from .wandb_alerts import AlertLevel +from .wandb_config import Config +from .wandb_init import _attach, init +from .wandb_login import login +from .wandb_require import require +from .wandb_run import finish +from .wandb_settings import Settings +from .wandb_setup import setup, teardown +from .wandb_summary import Summary +from .wandb_sweep import controller, sweep +from .wandb_watch import _unwatch, _watch diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8aacc06b2ce156c6d03e1e314f5b80af4c1d750d Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_alerts.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_alerts.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc8fe70d3ec4375dd5aa83e26de4463f088d5495 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_alerts.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_config.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_config.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a9a060bb16e44f9653f25d1a783134dc846c457 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_config.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_helper.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_helper.cpython-313.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..741fb40ba0d04b794550ef2a0fd7e19f28b85e3b Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_helper.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_init.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_init.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..532754320321a39a19aa935c4fcc819c545e4663 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_init.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_login.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_login.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f97cff3166b1936d34862ffe09d1af6c9dea2d23 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_login.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_metric.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_metric.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe43cd7ddd95b3b45ad761739c3e4010d86abe5f Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_metric.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_require.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_require.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45f845697d7840d7e37a79b917e510986d23caaf Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_require.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_settings.cpython-313.pyc 
b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_settings.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a6a4ffbe6dc56e3765910f2c5d20a69088bd458 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_settings.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_setup.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_setup.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f790cba5297b679097daa56ff71f9cb4eebafbb9 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_setup.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_summary.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_summary.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffbe9733677b729e85c0ea8871c8be2c32c0f430 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_summary.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_sweep.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_sweep.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b316730793f0de75dc281787eb246015a3d3be98 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_sweep.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_watch.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_watch.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ba933224da56dc018880e579acd23fd7fcd44dd Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/__pycache__/wandb_watch.cpython-313.pyc differ diff 
def make_storage_policy(region: str | None = None) -> StoragePolicy:
    """Returns the default `StoragePolicy` for the current environment.

    The policy configuration is obtained from
    `StoragePolicyConfig.from_env`, forwarding `region`, and wrapped in a
    `WandbStoragePolicy`.
    """
    env_config = StoragePolicyConfig.from_env(region=region)
    return WandbStoragePolicy(config=env_config)
+1,34 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +from __future__ import annotations + +from typing import Optional, Union + +from pydantic import Field +from typing_extensions import Annotated, Literal + +from wandb._pydantic import GQLResult, Typename + +from .fragments import RunInfoFragment + + +class ArtifactCreatedBy(GQLResult): + artifact: Optional[ArtifactCreatedByArtifact] + + +class ArtifactCreatedByArtifact(GQLResult): + created_by: Optional[ + Annotated[ + Union[RunInfoFragment, ArtifactCreatedByArtifactCreatedByUser], + Field(discriminator="typename__"), + ] + ] = Field(alias="createdBy") + + +class ArtifactCreatedByArtifactCreatedByUser(GQLResult): + typename__: Typename[Literal["User"]] + + +ArtifactCreatedBy.model_rebuild() +ArtifactCreatedByArtifact.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/artifact_type.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/artifact_type.py new file mode 100644 index 0000000000000000000000000000000000000000..94bd303c5e30f506c07dc3f8fb9289e9fb833d41 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/artifact_type.py @@ -0,0 +1,31 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +from __future__ import annotations + +from typing import Optional + +from pydantic import Field + +from wandb._pydantic import GQLResult + + +class ArtifactType(GQLResult): + project: Optional[ArtifactTypeProject] + + +class ArtifactTypeProject(GQLResult): + artifact: Optional[ArtifactTypeProjectArtifact] + + +class ArtifactTypeProjectArtifact(GQLResult): + artifact_type: ArtifactTypeProjectArtifactArtifactType = Field(alias="artifactType") + + +class ArtifactTypeProjectArtifactArtifactType(GQLResult): + name: str + + +ArtifactType.model_rebuild() +ArtifactTypeProject.model_rebuild() +ArtifactTypeProjectArtifact.model_rebuild() diff --git 
a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/delete_artifact_portfolio.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/delete_artifact_portfolio.py new file mode 100644 index 0000000000000000000000000000000000000000..2fa3c4d6f8ef1884de6103f0c138118b550ba293 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/delete_artifact_portfolio.py @@ -0,0 +1,34 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +from __future__ import annotations + +from typing import Optional + +from pydantic import Field +from typing_extensions import Literal + +from wandb._pydantic import GQLResult, Typename + +from .enums import ArtifactCollectionState + + +class DeleteArtifactPortfolio(GQLResult): + result: Optional[DeleteArtifactPortfolioResult] + + +class DeleteArtifactPortfolioResult(GQLResult): + artifact_collection: DeleteArtifactPortfolioResultArtifactCollection = Field( + alias="artifactCollection" + ) + + +class DeleteArtifactPortfolioResultArtifactCollection(GQLResult): + typename__: Typename[ + Literal["ArtifactCollection", "ArtifactPortfolio", "ArtifactSequence"] + ] + state: ArtifactCollectionState + + +DeleteArtifactPortfolio.model_rebuild() +DeleteArtifactPortfolioResult.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/delete_registry.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/delete_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..679e61f0be859d80a595bcb474b99f147e28447c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/delete_registry.py @@ -0,0 +1,21 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +from __future__ import annotations + +from typing import Optional + +from pydantic import Field + +from wandb._pydantic import GQLResult + + +class DeleteRegistry(GQLResult): + delete_model: 
Optional[DeleteRegistryDeleteModel] = Field(alias="deleteModel") + + +class DeleteRegistryDeleteModel(GQLResult): + success: Optional[bool] + + +DeleteRegistry.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/delete_registry_members.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/delete_registry_members.py new file mode 100644 index 0000000000000000000000000000000000000000..05ca21157f4f310d31acb8064c993fe1bfcf5313 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/delete_registry_members.py @@ -0,0 +1,19 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +from __future__ import annotations + +from typing import Optional + +from wandb._pydantic import GQLResult + + +class DeleteRegistryMembers(GQLResult): + result: Optional[DeleteRegistryMembersResult] + + +class DeleteRegistryMembersResult(GQLResult): + success: bool + + +DeleteRegistryMembers.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/enums.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/enums.py new file mode 100644 index 0000000000000000000000000000000000000000..c4131d8362f5b1e0d8a2d0fba641ce1affe713d8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/enums.py @@ -0,0 +1,22 @@ +# Generated by ariadne-codegen +# Source: core/api/graphql/schemas/schema-latest.graphql + +from __future__ import annotations + +from enum import Enum + + +class ArtifactCollectionType(str, Enum): + SEQUENCE = "SEQUENCE" + PORTFOLIO = "PORTFOLIO" + + +class ArtifactState(str, Enum): + PENDING = "PENDING" + COMMITTED = "COMMITTED" + DELETED = "DELETED" + + +class ArtifactCollectionState(str, Enum): + READY = "READY" + DELETED = "DELETED" diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py 
b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b98e48527b4cd35ee6114bd7d5acbc50c5bb93 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py @@ -0,0 +1,26 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +from __future__ import annotations + +from typing import Optional + +from pydantic import Field + +from wandb._pydantic import GQLResult + +from .fragments import DeferredManifestFragment + + +class FetchArtifactManifest(GQLResult): + artifact: Optional[FetchArtifactManifestArtifact] + + +class FetchArtifactManifestArtifact(GQLResult): + current_manifest: Optional[DeferredManifestFragment] = Field( + alias="currentManifest" + ) + + +FetchArtifactManifest.model_rebuild() +FetchArtifactManifestArtifact.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/fragments.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/fragments.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5920cd4f8d7361947f9251b8ade5ff1153fba8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/fragments.py @@ -0,0 +1,372 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field +from typing_extensions import Literal + +from wandb._pydantic import GQLId, GQLResult, Typename + +from .enums import ArtifactState + + +class ArtifactAliasFragment(GQLResult): + typename__: Typename[Literal["ArtifactAlias"]] = "ArtifactAlias" + id: GQLId + alias: str + + +class ProjectInfoFragment(GQLResult): + name: str + entity: ProjectInfoFragmentEntity + + +class ProjectInfoFragmentEntity(GQLResult): + name: str + + +class TagFragment(GQLResult): + typename__: 
Typename[Literal["Tag"]] = "Tag" + id: GQLId + name: str + + +class ArtifactCollectionFragment(GQLResult): + typename__: Typename[Literal["ArtifactSequence", "ArtifactPortfolio"]] + id: GQLId + name: str + description: Optional[str] + created_at: str = Field(alias="createdAt") + project: Optional[ProjectInfoFragment] + type: ArtifactCollectionFragmentType + tags: ArtifactCollectionFragmentTags + + +class ArtifactCollectionFragmentType(GQLResult): + name: str + + +class ArtifactCollectionFragmentTags(GQLResult): + edges: List[ArtifactCollectionFragmentTagsEdges] + + +class ArtifactCollectionFragmentTagsEdges(GQLResult): + node: TagFragment + + +class CollectionInfoFragment(GQLResult): + typename__: Typename[Literal["ArtifactSequence", "ArtifactPortfolio"]] + name: str + project: Optional[ProjectInfoFragment] + + +class SourceCollectionInfoFragment(GQLResult): + typename__: Typename[Literal["ArtifactSequence"]] = "ArtifactSequence" + name: str + project: Optional[ProjectInfoFragment] + + +class ArtifactFragment(GQLResult): + typename__: Typename[Literal["Artifact"]] = "Artifact" + id: GQLId + artifact_sequence: SourceCollectionInfoFragment = Field(alias="artifactSequence") + version_index: Optional[int] = Field(alias="versionIndex") + artifact_type: ArtifactFragmentArtifactType = Field(alias="artifactType") + description: Optional[str] + metadata: Optional[str] + ttl_duration_seconds: int = Field(alias="ttlDurationSeconds") + ttl_is_inherited: bool = Field(alias="ttlIsInherited") + tags: List[TagFragment] + history_step: Optional[int] = Field(alias="historyStep") + state: ArtifactState + size: int + digest: str + commit_hash: Optional[str] = Field(alias="commitHash") + file_count: int = Field(alias="fileCount") + created_at: str = Field(alias="createdAt") + updated_at: Optional[str] = Field(alias="updatedAt") + aliases: Optional[List[ArtifactFragmentAliases]] = None + + +class ArtifactFragmentArtifactType(GQLResult): + name: str + + +class 
ArtifactFragmentAliases(ArtifactAliasFragment): + artifact_collection: Optional[CollectionInfoFragment] = Field( + alias="artifactCollection" + ) + + +class ArtifactMembershipFragment(GQLResult): + typename__: Typename[Literal["ArtifactCollectionMembership"]] = ( + "ArtifactCollectionMembership" + ) + id: GQLId + version_index: Optional[int] = Field(alias="versionIndex") + aliases: List[ArtifactAliasFragment] + artifact_collection: Optional[CollectionInfoFragment] = Field( + alias="artifactCollection" + ) + artifact: Optional[ArtifactFragment] + + +class ArtifactPortfolioTypeFields(GQLResult): + typename__: Typename[Literal["ArtifactPortfolio"]] = "ArtifactPortfolio" + id: GQLId + name: str + + +class ArtifactSequenceTypeFields(GQLResult): + typename__: Typename[Literal["ArtifactSequence"]] = "ArtifactSequence" + id: GQLId + name: str + + +class ArtifactTypeFragment(GQLResult): + typename__: Typename[Literal["ArtifactType"]] = "ArtifactType" + id: GQLId + name: str + description: Optional[str] + created_at: str = Field(alias="createdAt") + + +class DeferredManifestFragment(GQLResult): + file: DeferredManifestFragmentFile + + +class DeferredManifestFragmentFile(GQLResult): + direct_url: str = Field(alias="directUrl") + + +class FileFragment(GQLResult): + typename__: Typename[Literal["File"]] = "File" + id: GQLId + name: str + url: Optional[str] + size_bytes: int = Field(alias="sizeBytes") + storage_path: Optional[str] = Field(alias="storagePath") + mimetype: Optional[str] + updated_at: Optional[str] = Field(alias="updatedAt") + digest: Optional[str] + md_5: Optional[str] = Field(alias="md5") + direct_url: str = Field(alias="directUrl") + + +class FileWithUrlFragment(GQLResult): + typename__: Typename[Literal["File"]] = "File" + name: str + direct_url: str = Field(alias="directUrl") + + +class OrgInfoFragment(GQLResult): + name: str + org_entity: Optional[OrgInfoFragmentOrgEntity] = Field(alias="orgEntity") + + +class OrgInfoFragmentOrgEntity(GQLResult): + name: str 
+ + +class PageInfoFragment(GQLResult): + typename__: Typename[Literal["PageInfo"]] = "PageInfo" + end_cursor: Optional[str] = Field(alias="endCursor") + has_next_page: bool = Field(alias="hasNextPage") + + +class RegistryCollectionFragment(GQLResult): + typename__: Typename[Literal["ArtifactSequence", "ArtifactPortfolio"]] + id: GQLId + name: str + description: Optional[str] + created_at: str = Field(alias="createdAt") + project: Optional[ProjectInfoFragment] + type: RegistryCollectionFragmentType + tags: RegistryCollectionFragmentTags + + +class RegistryCollectionFragmentType(GQLResult): + name: str + + +class RegistryCollectionFragmentTags(GQLResult): + edges: List[RegistryCollectionFragmentTagsEdges] + + +class RegistryCollectionFragmentTagsEdges(GQLResult): + node: TagFragment + + +class RegistryFragment(GQLResult): + typename__: Typename[Literal["Project"]] = "Project" + id: GQLId + name: str + entity: RegistryFragmentEntity + description: Optional[str] + created_at: str = Field(alias="createdAt") + updated_at: Optional[str] = Field(alias="updatedAt") + access: Optional[str] + allow_all_artifact_types: bool = Field(alias="allowAllArtifactTypes") + artifact_types: RegistryFragmentArtifactTypes = Field(alias="artifactTypes") + + +class RegistryFragmentEntity(GQLResult): + name: str + organization: Optional[RegistryFragmentEntityOrganization] + + +class RegistryFragmentEntityOrganization(GQLResult): + name: str + + +class RegistryFragmentArtifactTypes(GQLResult): + edges: List[RegistryFragmentArtifactTypesEdges] + + +class RegistryFragmentArtifactTypesEdges(GQLResult): + node: Optional[RegistryFragmentArtifactTypesEdgesNode] + + +class RegistryFragmentArtifactTypesEdgesNode(GQLResult): + name: str + + +class RegistryRoleFragment(GQLResult): + name: str + + +class RunInfoFragment(GQLResult): + typename__: Typename[Literal["Run"]] = "Run" + id: GQLId + name: str + project: Optional[ProjectInfoFragment] + + +class TeamMemberFragment(GQLResult): + typename__: 
Typename[Literal["Member"]] = "Member" + id: Optional[str] + role: Optional[str] + pending: Optional[bool] + email: Optional[str] + username: Optional[str] + name: str + photo_url: Optional[str] = Field(alias="photoUrl") + account_type: Optional[str] = Field(alias="accountType") + api_key: Optional[str] = Field(alias="apiKey") + + +class TeamFragment(GQLResult): + typename__: Typename[Literal["Entity"]] = "Entity" + id: GQLId + name: str + available: Optional[bool] + photo_url: Optional[str] = Field(alias="photoUrl") + read_only: Optional[bool] = Field(alias="readOnly") + read_only_admin: bool = Field(alias="readOnlyAdmin") + is_team: bool = Field(alias="isTeam") + private_only: bool = Field(alias="privateOnly") + storage_bytes: int = Field(alias="storageBytes") + code_saving_enabled: bool = Field(alias="codeSavingEnabled") + default_access: str = Field(alias="defaultAccess") + is_paid: Optional[bool] = Field(alias="isPaid") + members: List[TeamMemberFragment] + + +class TeamRegistryMemberFragment(GQLResult): + team: TeamFragment + role: RegistryRoleFragment + + +class TypeInfoFragment(GQLResult): + name: Optional[str] + fields: Optional[List[TypeInfoFragmentFields]] + input_fields: Optional[List[TypeInfoFragmentInputFields]] = Field( + alias="inputFields" + ) + + +class TypeInfoFragmentFields(GQLResult): + name: str + args: List[TypeInfoFragmentFieldsArgs] + + +class TypeInfoFragmentFieldsArgs(GQLResult): + name: str + + +class TypeInfoFragmentInputFields(GQLResult): + name: str + + +class UserRegistryMemberFragment(GQLResult): + id: GQLId + name: Optional[str] + username: Optional[str] + email: Optional[str] + role: RegistryRoleFragment + + +ArtifactAliasFragment.model_rebuild() +ProjectInfoFragment.model_rebuild() +ProjectInfoFragmentEntity.model_rebuild() +TagFragment.model_rebuild() +ArtifactCollectionFragment.model_rebuild() +ProjectInfoFragment.model_rebuild() +ArtifactCollectionFragmentType.model_rebuild() +ArtifactCollectionFragmentTags.model_rebuild() 
+ArtifactCollectionFragmentTagsEdges.model_rebuild() +TagFragment.model_rebuild() +CollectionInfoFragment.model_rebuild() +ProjectInfoFragment.model_rebuild() +SourceCollectionInfoFragment.model_rebuild() +ProjectInfoFragment.model_rebuild() +ArtifactFragment.model_rebuild() +SourceCollectionInfoFragment.model_rebuild() +ArtifactFragmentArtifactType.model_rebuild() +TagFragment.model_rebuild() +ArtifactFragmentAliases.model_rebuild() +CollectionInfoFragment.model_rebuild() +ArtifactMembershipFragment.model_rebuild() +ArtifactAliasFragment.model_rebuild() +CollectionInfoFragment.model_rebuild() +ArtifactFragment.model_rebuild() +ArtifactPortfolioTypeFields.model_rebuild() +ArtifactSequenceTypeFields.model_rebuild() +ArtifactTypeFragment.model_rebuild() +DeferredManifestFragment.model_rebuild() +DeferredManifestFragmentFile.model_rebuild() +FileFragment.model_rebuild() +FileWithUrlFragment.model_rebuild() +OrgInfoFragment.model_rebuild() +OrgInfoFragmentOrgEntity.model_rebuild() +PageInfoFragment.model_rebuild() +RegistryCollectionFragment.model_rebuild() +ProjectInfoFragment.model_rebuild() +RegistryCollectionFragmentType.model_rebuild() +RegistryCollectionFragmentTags.model_rebuild() +RegistryCollectionFragmentTagsEdges.model_rebuild() +TagFragment.model_rebuild() +RegistryFragment.model_rebuild() +RegistryFragmentEntity.model_rebuild() +RegistryFragmentEntityOrganization.model_rebuild() +RegistryFragmentArtifactTypes.model_rebuild() +RegistryFragmentArtifactTypesEdges.model_rebuild() +RegistryFragmentArtifactTypesEdgesNode.model_rebuild() +RegistryRoleFragment.model_rebuild() +RunInfoFragment.model_rebuild() +ProjectInfoFragment.model_rebuild() +TeamMemberFragment.model_rebuild() +TeamFragment.model_rebuild() +TeamMemberFragment.model_rebuild() +TeamRegistryMemberFragment.model_rebuild() +TeamFragment.model_rebuild() +RegistryRoleFragment.model_rebuild() +TypeInfoFragment.model_rebuild() +TypeInfoFragmentFields.model_rebuild() 
+TypeInfoFragmentFieldsArgs.model_rebuild() +TypeInfoFragmentInputFields.model_rebuild() +UserRegistryMemberFragment.model_rebuild() +RegistryRoleFragment.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/input_types.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/input_types.py new file mode 100644 index 0000000000000000000000000000000000000000..3cee7676c28467209ce11efe10cd46c2f26c3fc8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/input_types.py @@ -0,0 +1,207 @@ +# Generated by ariadne-codegen +# Source: core/api/graphql/schemas/schema-latest.graphql + +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field + +from wandb._pydantic import GQLId, GQLInput + + +class UpsertModelInput(GQLInput): + name: Optional[str] = Field(default=None, max_length=128) + description: Optional[str] = None + id: Optional[str] = None + framework: Optional[str] = None + entity_name: Optional[str] = Field(alias="entityName", default=None) + docker_image: Optional[str] = Field( + alias="dockerImage", default=None, max_length=512 + ) + repo: Optional[str] = Field(default=None, max_length=256) + access: Optional[str] = None + views: Optional[str] = None + is_benchmark: Optional[bool] = Field(alias="isBenchmark", default=None) + linked_benchmark: Optional[GQLId] = Field(alias="linkedBenchmark", default=None) + is_published: Optional[bool] = Field(alias="isPublished", default=None) + owner: Optional[GQLId] = None + allow_all_artifact_types_in_registry: Optional[bool] = Field( + alias="allowAllArtifactTypesInRegistry", default=None + ) + rate_limits: Optional[RateLimitsInput] = Field(alias="rateLimits", default=None) + client_mutation_id: Optional[str] = Field(alias="clientMutationId", default=None) + artifact_types: Optional[List[ArtifactTypeInput]] = Field( + alias="artifactTypes", default=None + ) + + +class RenameProjectInput(GQLInput): 
+ entity_name: str = Field(alias="entityName") + old_project_name: str = Field(alias="oldProjectName") + new_project_name: str = Field(alias="newProjectName") + client_mutation_id: Optional[str] = Field(alias="clientMutationId", default=None) + + +class RateLimitsInput(GQLInput): + graphql: Optional[int] = None + sdk_graphql: Optional[int] = Field(alias="sdkGraphql", default=None) + filestream_count: Optional[int] = Field(alias="filestreamCount", default=None) + filestream_size: Optional[int] = Field(alias="filestreamSize", default=None) + sdk_graphql_query_seconds: Optional[float] = Field( + alias="sdkGraphqlQuerySeconds", default=None + ) + + +class ArtifactTypeInput(GQLInput): + name: str = Field(max_length=128, pattern="^[-\\w]+([ ]*[-.\\w]+)*$") + description: Optional[str] = None + + +class UpdateArtifactSequenceInput(GQLInput): + artifact_sequence_id: GQLId = Field(alias="artifactSequenceID") + name: Optional[str] = Field(default=None, max_length=128) + description: Optional[str] = None + client_mutation_id: Optional[str] = Field(alias="clientMutationId", default=None) + + +class MoveArtifactSequenceInput(GQLInput): + artifact_sequence_id: GQLId = Field(alias="artifactSequenceID") + destination_artifact_type_name: str = Field(alias="destinationArtifactTypeName") + client_mutation_id: Optional[str] = Field(alias="clientMutationId", default=None) + + +class UpdateArtifactPortfolioInput(GQLInput): + artifact_portfolio_id: GQLId = Field(alias="artifactPortfolioID") + name: Optional[str] = Field(default=None, max_length=128) + description: Optional[str] = None + client_mutation_id: Optional[str] = Field(alias="clientMutationId", default=None) + + +class ArtifactAliasInput(GQLInput): + artifact_collection_name: str = Field(alias="artifactCollectionName") + alias: str = Field(max_length=128) + + +class UpdateArtifactInput(GQLInput): + artifact_id: GQLId = Field(alias="artifactID") + description: Optional[str] = None + labels: Optional[str] = None + aliases: 
Optional[List[ArtifactAliasInput]] = None + tags_to_add: Optional[List[TagInput]] = Field(alias="tagsToAdd", default=None) + tags_to_delete: Optional[List[TagInput]] = Field(alias="tagsToDelete", default=None) + metadata: Optional[str] = None + ttl_duration_seconds: Optional[int] = Field( + alias="ttlDurationSeconds", default=None + ) + client_mutation_id: Optional[str] = Field(alias="clientMutationId", default=None) + + +class DeleteArtifactInput(GQLInput): + artifact_id: GQLId = Field(alias="artifactID") + delete_aliases: Optional[bool] = Field(alias="deleteAliases", default=False) + client_mutation_id: Optional[str] = Field(alias="clientMutationId", default=None) + + +class LinkArtifactInput(GQLInput): + artifact_id: Optional[GQLId] = Field(alias="artifactID", default=None) + artifact_portfolio_id: Optional[GQLId] = Field( + alias="artifactPortfolioID", default=None + ) + artifact_portfolio_name: Optional[str] = Field( + alias="artifactPortfolioName", default=None + ) + entity_name: Optional[str] = Field(alias="entityName", default=None) + project_name: Optional[str] = Field(alias="projectName", default=None) + aliases: Optional[List[ArtifactAliasInput]] = None + client_id: Optional[GQLId] = Field(alias="clientID", default=None) + client_mutation_id: Optional[str] = Field(alias="clientMutationId", default=None) + + +class UnlinkArtifactInput(GQLInput): + artifact_id: GQLId = Field(alias="artifactID") + artifact_portfolio_id: GQLId = Field(alias="artifactPortfolioID") + client_mutation_id: Optional[str] = Field(alias="clientMutationId", default=None) + + +class ArtifactCollectionAliasInput(GQLInput): + alias: str = Field(max_length=128) + entity_name: str = Field(alias="entityName") + project_name: str = Field(alias="projectName") + artifact_collection_name: str = Field(alias="artifactCollectionName") + + +class AddAliasesInput(GQLInput): + aliases: List[ArtifactCollectionAliasInput] + artifact_id: GQLId = Field(alias="artifactID") + client_mutation_id: 
Optional[str] = Field(alias="clientMutationId", default=None) + + +class DeleteAliasesInput(GQLInput): + aliases: List[ArtifactCollectionAliasInput] + artifact_id: GQLId = Field(alias="artifactID") + client_mutation_id: Optional[str] = Field(alias="clientMutationId", default=None) + + +class TagInput(GQLInput): + tag_category_name: Optional[str] = Field( + alias="tagCategoryName", + default=None, + max_length=128, + pattern="^[-\\w]+([ ]+[-\\w]+)*$", + ) + tag_name: str = Field( + alias="tagName", max_length=128, pattern="^[-\\w]+([ ]+[-\\w]+)*$" + ) + attributes: Optional[str] = None + + +class CreateArtifactCollectionTagAssignmentsInput(GQLInput): + entity_name: str = Field(alias="entityName") + project_name: str = Field(alias="projectName") + artifact_collection_name: str = Field(alias="artifactCollectionName") + tags: List[TagInput] = Field(max_length=20) + client_mutation_id: Optional[str] = Field(alias="clientMutationID", default=None) + + +class DeleteArtifactCollectionTagAssignmentsInput(GQLInput): + entity_name: str = Field(alias="entityName") + project_name: str = Field(alias="projectName") + artifact_collection_name: str = Field(alias="artifactCollectionName") + tags: List[TagInput] = Field(max_length=20) + client_mutation_id: Optional[str] = Field(alias="clientMutationID", default=None) + + +class CreateProjectMembersInput(GQLInput): + user_ids: Optional[List[GQLId]] = Field(alias="userIds", default=None) + team_ids: Optional[List[GQLId]] = Field(alias="teamIds", default=None) + project_id: GQLId = Field(alias="projectId") + + +class DeleteProjectMembersInput(GQLInput): + user_ids: Optional[List[GQLId]] = Field(alias="userIds", default=None) + team_ids: Optional[List[GQLId]] = Field(alias="teamIds", default=None) + project_id: GQLId = Field(alias="projectId") + + +class UpdateProjectMemberInput(GQLInput): + user_id: GQLId = Field(alias="userId") + project_id: GQLId = Field(alias="projectId") + user_project_role: str = Field(alias="userProjectRole") + 
client_mutation_id: Optional[str] = Field(alias="clientMutationId", default=None) + + +class UpdateProjectTeamMemberInput(GQLInput): + team_id: GQLId = Field(alias="teamId") + project_id: GQLId = Field(alias="projectId") + team_project_role: str = Field(alias="teamProjectRole") + client_mutation_id: Optional[str] = Field(alias="clientMutationId", default=None) + + +UpsertModelInput.model_rebuild() +UpdateArtifactInput.model_rebuild() +LinkArtifactInput.model_rebuild() +AddAliasesInput.model_rebuild() +DeleteAliasesInput.model_rebuild() +CreateArtifactCollectionTagAssignmentsInput.model_rebuild() +DeleteArtifactCollectionTagAssignmentsInput.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/link_artifact.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/link_artifact.py new file mode 100644 index 0000000000000000000000000000000000000000..a5b06d0ccf171a9ae7491807c6f0dd4ddb6f8c86 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/link_artifact.py @@ -0,0 +1,27 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +from __future__ import annotations + +from typing import Optional + +from pydantic import Field + +from wandb._pydantic import GQLResult + +from .fragments import ArtifactMembershipFragment + + +class LinkArtifact(GQLResult): + result: Optional[LinkArtifactResult] + + +class LinkArtifactResult(GQLResult): + version_index: Optional[int] = Field(alias="versionIndex") + artifact_membership: Optional[ArtifactMembershipFragment] = Field( + alias="artifactMembership", default=None + ) + + +LinkArtifact.model_rebuild() +LinkArtifactResult.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/operations.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/operations.py new file mode 100644 index 0000000000000000000000000000000000000000..4e2a4074945032612211fe4b534dd0c8af0c37d1 --- 
/dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/operations.py @@ -0,0 +1,1946 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +__all__ = [ + "ADD_ALIASES_GQL", + "ADD_ARTIFACT_COLLECTION_TAGS_GQL", + "ARTIFACT_BY_ID_GQL", + "ARTIFACT_BY_NAME_GQL", + "ARTIFACT_COLLECTION_ALIASES_GQL", + "ARTIFACT_CREATED_BY_GQL", + "ARTIFACT_MEMBERSHIP_BY_NAME_GQL", + "ARTIFACT_TYPE_GQL", + "ARTIFACT_USED_BY_GQL", + "CREATE_REGISTRY_MEMBERS_GQL", + "DELETE_ALIASES_GQL", + "DELETE_ARTIFACT_COLLECTION_TAGS_GQL", + "DELETE_ARTIFACT_GQL", + "DELETE_ARTIFACT_PORTFOLIO_GQL", + "DELETE_ARTIFACT_SEQUENCE_GQL", + "DELETE_REGISTRY_GQL", + "DELETE_REGISTRY_MEMBERS_GQL", + "FETCH_ARTIFACT_MANIFEST_GQL", + "FETCH_LINKED_ARTIFACTS_GQL", + "FETCH_ORG_ENTITY_FROM_ORGANIZATION_GQL", + "FETCH_ORG_INFO_FROM_ENTITY_GQL", + "FETCH_REGISTRIES_GQL", + "FETCH_REGISTRY_GQL", + "GET_ARTIFACT_FILES_GQL", + "GET_ARTIFACT_FILE_URLS_GQL", + "GET_ARTIFACT_MEMBERSHIP_FILES_GQL", + "GET_ARTIFACT_MEMBERSHIP_FILE_URLS_GQL", + "LINK_ARTIFACT_GQL", + "PROJECT_ARTIFACTS_GQL", + "PROJECT_ARTIFACT_COLLECTIONS_GQL", + "PROJECT_ARTIFACT_COLLECTION_GQL", + "PROJECT_ARTIFACT_TYPES_GQL", + "PROJECT_ARTIFACT_TYPE_GQL", + "REGISTRY_COLLECTIONS_GQL", + "REGISTRY_TEAM_MEMBERS_GQL", + "REGISTRY_USER_MEMBERS_GQL", + "REGISTRY_VERSIONS_GQL", + "RENAME_REGISTRY_GQL", + "RUN_INPUT_ARTIFACTS_GQL", + "RUN_OUTPUT_ARTIFACTS_GQL", + "TYPE_INFO_GQL", + "UNLINK_ARTIFACT_GQL", + "UPDATE_ARTIFACT_GQL", + "UPDATE_ARTIFACT_PORTFOLIO_GQL", + "UPDATE_ARTIFACT_SEQUENCE_GQL", + "UPDATE_ARTIFACT_SEQUENCE_TYPE_GQL", + "UPDATE_TEAM_REGISTRY_ROLE_GQL", + "UPDATE_USER_REGISTRY_ROLE_GQL", + "UPSERT_REGISTRY_GQL", +] + +DELETE_ARTIFACT_SEQUENCE_GQL = """ +mutation DeleteArtifactSequence($id: ID!) 
{ + result: deleteArtifactSequence(input: {artifactSequenceID: $id}) { + artifactCollection { + __typename + state + } + } +} +""" + +DELETE_ARTIFACT_PORTFOLIO_GQL = """ +mutation DeleteArtifactPortfolio($id: ID!) { + result: deleteArtifactPortfolio(input: {artifactPortfolioID: $id}) { + artifactCollection { + __typename + state + } + } +} +""" + +UPDATE_ARTIFACT_SEQUENCE_GQL = """ +mutation UpdateArtifactSequence($input: UpdateArtifactSequenceInput!) { + result: updateArtifactSequence(input: $input) { + artifactCollection { + __typename + ...ArtifactCollectionFragment + } + } +} + +fragment ArtifactCollectionFragment on ArtifactCollection { + __typename + id + name + description + createdAt + project { + ...ProjectInfoFragment + } + type: defaultArtifactType { + name + } + tags { + edges { + node { + ...TagFragment + } + } + } +} + +fragment ProjectInfoFragment on Project { + name + entity { + name + } +} + +fragment TagFragment on Tag { + __typename + id + name +} +""" + +UPDATE_ARTIFACT_PORTFOLIO_GQL = """ +mutation UpdateArtifactPortfolio($input: UpdateArtifactPortfolioInput!) { + result: updateArtifactPortfolio(input: $input) { + artifactCollection { + __typename + ...ArtifactCollectionFragment + } + } +} + +fragment ArtifactCollectionFragment on ArtifactCollection { + __typename + id + name + description + createdAt + project { + ...ProjectInfoFragment + } + type: defaultArtifactType { + name + } + tags { + edges { + node { + ...TagFragment + } + } + } +} + +fragment ProjectInfoFragment on Project { + name + entity { + name + } +} + +fragment TagFragment on Tag { + __typename + id + name +} +""" + +UPDATE_ARTIFACT_SEQUENCE_TYPE_GQL = """ +mutation UpdateArtifactSequenceType($input: MoveArtifactSequenceInput!) 
{ + result: moveArtifactSequence(input: $input) { + artifactCollection { + __typename + ...ArtifactCollectionFragment + } + } +} + +fragment ArtifactCollectionFragment on ArtifactCollection { + __typename + id + name + description + createdAt + project { + ...ProjectInfoFragment + } + type: defaultArtifactType { + name + } + tags { + edges { + node { + ...TagFragment + } + } + } +} + +fragment ProjectInfoFragment on Project { + name + entity { + name + } +} + +fragment TagFragment on Tag { + __typename + id + name +} +""" + +ADD_ARTIFACT_COLLECTION_TAGS_GQL = """ +mutation AddArtifactCollectionTags($input: CreateArtifactCollectionTagAssignmentsInput!) { + result: createArtifactCollectionTagAssignments(input: $input) { + tags { + ...TagFragment + } + } +} + +fragment TagFragment on Tag { + __typename + id + name +} +""" + +DELETE_ARTIFACT_COLLECTION_TAGS_GQL = """ +mutation DeleteArtifactCollectionTags($input: DeleteArtifactCollectionTagAssignmentsInput!) { + result: deleteArtifactCollectionTagAssignments(input: $input) { + success + } +} +""" + +PROJECT_ARTIFACT_COLLECTIONS_GQL = """ +query ProjectArtifactCollections($entity: String!, $project: String!, $type: String!, $cursor: String, $perPage: Int) { + project(entityName: $entity, name: $project) { + artifactType(name: $type) { + artifactCollections(after: $cursor, first: $perPage) { + totalCount + pageInfo { + ...PageInfoFragment + } + edges { + node { + __typename + ...ArtifactCollectionFragment + } + } + } + } + } +} + +fragment ArtifactCollectionFragment on ArtifactCollection { + __typename + id + name + description + createdAt + project { + ...ProjectInfoFragment + } + type: defaultArtifactType { + name + } + tags { + edges { + node { + ...TagFragment + } + } + } +} + +fragment PageInfoFragment on PageInfo { + __typename + endCursor + hasNextPage +} + +fragment ProjectInfoFragment on Project { + name + entity { + name + } +} + +fragment TagFragment on Tag { + __typename + id + name +} +""" + 
+PROJECT_ARTIFACT_COLLECTION_GQL = """ +query ProjectArtifactCollection($entity: String!, $project: String!, $type: String!, $name: String!) { + project(entityName: $entity, name: $project) { + artifactType(name: $type) { + artifactCollection(name: $name) { + __typename + ...ArtifactCollectionFragment + } + } + } +} + +fragment ArtifactCollectionFragment on ArtifactCollection { + __typename + id + name + description + createdAt + project { + ...ProjectInfoFragment + } + type: defaultArtifactType { + name + } + tags { + edges { + node { + ...TagFragment + } + } + } +} + +fragment ProjectInfoFragment on Project { + name + entity { + name + } +} + +fragment TagFragment on Tag { + __typename + id + name +} +""" + +ARTIFACT_COLLECTION_ALIASES_GQL = """ +query ArtifactCollectionAliases($id: ID!, $cursor: String, $perPage: Int = 1000) { + artifactCollection(id: $id) { + __typename + aliases(after: $cursor, first: $perPage) { + pageInfo { + ...PageInfoFragment + } + edges { + node { + ...ArtifactAliasFragment + } + } + } + } +} + +fragment ArtifactAliasFragment on ArtifactAlias { + __typename + id + alias +} + +fragment PageInfoFragment on PageInfo { + __typename + endCursor + hasNextPage +} +""" + +GET_ARTIFACT_FILES_GQL = """ +query GetArtifactFiles($entity: String!, $project: String!, $type: String!, $name: String!, $fileNames: [String!], $cursor: String, $perPage: Int = 50) { + project(name: $project, entityName: $entity) { + artifactType(name: $type) { + artifact(name: $name) { + files(names: $fileNames, after: $cursor, first: $perPage) { + totalCount @include(if: true) + pageInfo { + ...PageInfoFragment + } + edges { + node { + ...FileFragment + } + } + } + } + } + } +} + +fragment FileFragment on File { + __typename + id + name: displayName + url + sizeBytes + storagePath + mimetype + updatedAt + digest + md5 + directUrl +} + +fragment PageInfoFragment on PageInfo { + __typename + endCursor + hasNextPage +} +""" + +GET_ARTIFACT_MEMBERSHIP_FILES_GQL = """ +query 
GetArtifactMembershipFiles($entity: String!, $project: String!, $collection: String!, $alias: String!, $fileNames: [String!], $cursor: String, $perPage: Int = 50) { + project(name: $project, entityName: $entity) { + artifactCollection(name: $collection) { + __typename + artifactMembership(aliasName: $alias) { + files(names: $fileNames, after: $cursor, first: $perPage) { + totalCount @include(if: true) + pageInfo { + ...PageInfoFragment + } + edges { + node { + ...FileFragment + } + } + } + } + } + } +} + +fragment FileFragment on File { + __typename + id + name: displayName + url + sizeBytes + storagePath + mimetype + updatedAt + digest + md5 + directUrl +} + +fragment PageInfoFragment on PageInfo { + __typename + endCursor + hasNextPage +} +""" + +GET_ARTIFACT_FILE_URLS_GQL = """ +query GetArtifactFileUrls($id: ID!, $cursor: String, $perPage: Int) { + artifact(id: $id) { + files(after: $cursor, first: $perPage) { + pageInfo { + ...PageInfoFragment + } + edges { + node { + ...FileWithUrlFragment + } + } + } + } +} + +fragment FileWithUrlFragment on File { + __typename + name + directUrl +} + +fragment PageInfoFragment on PageInfo { + __typename + endCursor + hasNextPage +} +""" + +GET_ARTIFACT_MEMBERSHIP_FILE_URLS_GQL = """ +query GetArtifactMembershipFileUrls($entity: String!, $project: String!, $collection: String!, $alias: String!, $cursor: String, $perPage: Int) { + project(name: $project, entityName: $entity) { + artifactCollection(name: $collection) { + __typename + artifactMembership(aliasName: $alias) { + files(after: $cursor, first: $perPage) { + pageInfo { + ...PageInfoFragment + } + edges { + node { + ...FileWithUrlFragment + } + } + } + } + } + } +} + +fragment FileWithUrlFragment on File { + __typename + name + directUrl +} + +fragment PageInfoFragment on PageInfo { + __typename + endCursor + hasNextPage +} +""" + +PROJECT_ARTIFACT_TYPES_GQL = """ +query ProjectArtifactTypes($entity: String!, $project: String!, $cursor: String, $perPage: Int) { + 
project(name: $project, entityName: $entity) { + artifactTypes(after: $cursor, first: $perPage) { + edges { + node { + ...ArtifactTypeFragment + } + } + pageInfo { + ...PageInfoFragment + } + } + } +} + +fragment ArtifactTypeFragment on ArtifactType { + __typename + id + name + description + createdAt +} + +fragment PageInfoFragment on PageInfo { + __typename + endCursor + hasNextPage +} +""" + +PROJECT_ARTIFACT_TYPE_GQL = """ +query ProjectArtifactType($entity: String!, $project: String!, $type: String!) { + project(entityName: $entity, name: $project) { + artifactType(name: $type) { + ...ArtifactTypeFragment + } + } +} + +fragment ArtifactTypeFragment on ArtifactType { + __typename + id + name + description + createdAt +} +""" + +PROJECT_ARTIFACTS_GQL = """ +query ProjectArtifacts($entity: String!, $project: String!, $type: String!, $collection: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString, $includeAliases: Boolean = true) { + project(entityName: $entity, name: $project) { + artifactType(name: $type) { + artifactCollection(name: $collection) { + __typename + artifacts(after: $cursor, first: $perPage, order: $order, filters: $filters) { + totalCount + pageInfo { + ...PageInfoFragment + } + edges { + version + node { + ...ArtifactFragment + } + } + } + } + } + } +} + +fragment ArtifactAliasFragment on ArtifactAlias { + __typename + id + alias +} + +fragment ArtifactFragment on Artifact { + __typename + id + artifactSequence { + ...SourceCollectionInfoFragment + } + versionIndex + artifactType { + name + } + description + metadata + ttlDurationSeconds + ttlIsInherited + tags { + ...TagFragment + } + historyStep + state + size + digest + commitHash + fileCount + createdAt + updatedAt + aliases @include(if: $includeAliases) { + artifactCollection { + ...CollectionInfoFragment + } + ...ArtifactAliasFragment + } +} + +fragment CollectionInfoFragment on ArtifactCollection { + __typename + name + project { + ...ProjectInfoFragment + } 
+} + +fragment PageInfoFragment on PageInfo { + __typename + endCursor + hasNextPage +} + +fragment ProjectInfoFragment on Project { + name + entity { + name + } +} + +fragment SourceCollectionInfoFragment on ArtifactSequence { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment TagFragment on Tag { + __typename + id + name +} +""" + +RUN_OUTPUT_ARTIFACTS_GQL = """ +query RunOutputArtifacts($entity: String!, $project: String!, $run: String!, $cursor: String, $perPage: Int, $includeAliases: Boolean = true) { + project(entityName: $entity, name: $project) { + run(name: $run) { + artifacts: outputArtifacts(after: $cursor, first: $perPage) { + totalCount + pageInfo { + ...PageInfoFragment + } + edges { + node { + ...ArtifactFragment + } + } + } + } + } +} + +fragment ArtifactAliasFragment on ArtifactAlias { + __typename + id + alias +} + +fragment ArtifactFragment on Artifact { + __typename + id + artifactSequence { + ...SourceCollectionInfoFragment + } + versionIndex + artifactType { + name + } + description + metadata + ttlDurationSeconds + ttlIsInherited + tags { + ...TagFragment + } + historyStep + state + size + digest + commitHash + fileCount + createdAt + updatedAt + aliases @include(if: $includeAliases) { + artifactCollection { + ...CollectionInfoFragment + } + ...ArtifactAliasFragment + } +} + +fragment CollectionInfoFragment on ArtifactCollection { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment PageInfoFragment on PageInfo { + __typename + endCursor + hasNextPage +} + +fragment ProjectInfoFragment on Project { + name + entity { + name + } +} + +fragment SourceCollectionInfoFragment on ArtifactSequence { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment TagFragment on Tag { + __typename + id + name +} +""" + +RUN_INPUT_ARTIFACTS_GQL = """ +query RunInputArtifacts($entity: String!, $project: String!, $run: String!, $cursor: String, $perPage: Int, $includeAliases: Boolean = true) { + 
project(entityName: $entity, name: $project) { + run(name: $run) { + artifacts: inputArtifacts(after: $cursor, first: $perPage) { + totalCount + pageInfo { + ...PageInfoFragment + } + edges { + node { + ...ArtifactFragment + } + } + } + } + } +} + +fragment ArtifactAliasFragment on ArtifactAlias { + __typename + id + alias +} + +fragment ArtifactFragment on Artifact { + __typename + id + artifactSequence { + ...SourceCollectionInfoFragment + } + versionIndex + artifactType { + name + } + description + metadata + ttlDurationSeconds + ttlIsInherited + tags { + ...TagFragment + } + historyStep + state + size + digest + commitHash + fileCount + createdAt + updatedAt + aliases @include(if: $includeAliases) { + artifactCollection { + ...CollectionInfoFragment + } + ...ArtifactAliasFragment + } +} + +fragment CollectionInfoFragment on ArtifactCollection { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment PageInfoFragment on PageInfo { + __typename + endCursor + hasNextPage +} + +fragment ProjectInfoFragment on Project { + name + entity { + name + } +} + +fragment SourceCollectionInfoFragment on ArtifactSequence { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment TagFragment on Tag { + __typename + id + name +} +""" + +FETCH_LINKED_ARTIFACTS_GQL = """ +query FetchLinkedArtifacts($artifactID: ID!) { + artifact(id: $artifactID) { + artifactMemberships { + edges { + node { + versionIndex + aliases { + ...ArtifactAliasFragment + } + artifactCollection { + __typename + ...CollectionInfoFragment + } + } + } + } + } +} + +fragment ArtifactAliasFragment on ArtifactAlias { + __typename + id + alias +} + +fragment CollectionInfoFragment on ArtifactCollection { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment ProjectInfoFragment on Project { + name + entity { + name + } +} +""" + +FETCH_ARTIFACT_MANIFEST_GQL = """ +query FetchArtifactManifest($id: ID!) 
{ + artifact(id: $id) { + currentManifest { + ...DeferredManifestFragment + } + } +} + +fragment DeferredManifestFragment on ArtifactManifest { + file { + directUrl + } +} +""" + +ARTIFACT_BY_ID_GQL = """ +query ArtifactByID($id: ID!, $includeAliases: Boolean = true) { + artifact(id: $id) { + ...ArtifactFragment + } +} + +fragment ArtifactAliasFragment on ArtifactAlias { + __typename + id + alias +} + +fragment ArtifactFragment on Artifact { + __typename + id + artifactSequence { + ...SourceCollectionInfoFragment + } + versionIndex + artifactType { + name + } + description + metadata + ttlDurationSeconds + ttlIsInherited + tags { + ...TagFragment + } + historyStep + state + size + digest + commitHash + fileCount + createdAt + updatedAt + aliases @include(if: $includeAliases) { + artifactCollection { + ...CollectionInfoFragment + } + ...ArtifactAliasFragment + } +} + +fragment CollectionInfoFragment on ArtifactCollection { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment ProjectInfoFragment on Project { + name + entity { + name + } +} + +fragment SourceCollectionInfoFragment on ArtifactSequence { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment TagFragment on Tag { + __typename + id + name +} +""" + +ARTIFACT_BY_NAME_GQL = """ +query ArtifactByName($entity: String!, $project: String!, $name: String!, $enableTracking: Boolean, $includeAliases: Boolean = true) { + project(name: $project, entityName: $entity) { + artifact(name: $name, enableTracking: $enableTracking) { + ...ArtifactFragment + } + } +} + +fragment ArtifactAliasFragment on ArtifactAlias { + __typename + id + alias +} + +fragment ArtifactFragment on Artifact { + __typename + id + artifactSequence { + ...SourceCollectionInfoFragment + } + versionIndex + artifactType { + name + } + description + metadata + ttlDurationSeconds + ttlIsInherited + tags { + ...TagFragment + } + historyStep + state + size + digest + commitHash + fileCount + createdAt + 
updatedAt + aliases @include(if: $includeAliases) { + artifactCollection { + ...CollectionInfoFragment + } + ...ArtifactAliasFragment + } +} + +fragment CollectionInfoFragment on ArtifactCollection { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment ProjectInfoFragment on Project { + name + entity { + name + } +} + +fragment SourceCollectionInfoFragment on ArtifactSequence { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment TagFragment on Tag { + __typename + id + name +} +""" + +ARTIFACT_MEMBERSHIP_BY_NAME_GQL = """ +query ArtifactMembershipByName($entity: String!, $project: String!, $name: String!, $includeAliases: Boolean = false) { + project(name: $project, entityName: $entity) { + artifactCollectionMembership(name: $name) { + ...ArtifactMembershipFragment + } + } +} + +fragment ArtifactAliasFragment on ArtifactAlias { + __typename + id + alias +} + +fragment ArtifactFragment on Artifact { + __typename + id + artifactSequence { + ...SourceCollectionInfoFragment + } + versionIndex + artifactType { + name + } + description + metadata + ttlDurationSeconds + ttlIsInherited + tags { + ...TagFragment + } + historyStep + state + size + digest + commitHash + fileCount + createdAt + updatedAt + aliases @include(if: $includeAliases) { + artifactCollection { + ...CollectionInfoFragment + } + ...ArtifactAliasFragment + } +} + +fragment ArtifactMembershipFragment on ArtifactCollectionMembership { + __typename + id + versionIndex + aliases { + ...ArtifactAliasFragment + } + artifactCollection { + ...CollectionInfoFragment + } + artifact { + ...ArtifactFragment + } +} + +fragment CollectionInfoFragment on ArtifactCollection { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment ProjectInfoFragment on Project { + name + entity { + name + } +} + +fragment SourceCollectionInfoFragment on ArtifactSequence { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment TagFragment on Tag { + 
__typename + id + name +} +""" + +ARTIFACT_USED_BY_GQL = """ +query ArtifactUsedBy($id: ID!) { + artifact(id: $id) { + usedBy { + edges { + node { + ...RunInfoFragment + } + } + } + } +} + +fragment ProjectInfoFragment on Project { + name + entity { + name + } +} + +fragment RunInfoFragment on Run { + __typename + id + name + project { + ...ProjectInfoFragment + } +} +""" + +ARTIFACT_CREATED_BY_GQL = """ +query ArtifactCreatedBy($id: ID!) { + artifact(id: $id) { + createdBy { + __typename + ... on Run { + ...RunInfoFragment + } + } + } +} + +fragment ProjectInfoFragment on Project { + name + entity { + name + } +} + +fragment RunInfoFragment on Run { + __typename + id + name + project { + ...ProjectInfoFragment + } +} +""" + +ARTIFACT_TYPE_GQL = """ +query ArtifactType($entity: String, $project: String, $name: String!) { + project(name: $project, entityName: $entity) { + artifact(name: $name) { + artifactType { + name + } + } + } +} +""" + +ADD_ALIASES_GQL = """ +mutation AddAliases($input: AddAliasesInput!) { + result: addAliases(input: $input) { + success + } +} +""" + +DELETE_ALIASES_GQL = """ +mutation DeleteAliases($input: DeleteAliasesInput!) 
{ + result: deleteAliases(input: $input) { + success + } +} +""" + +UPDATE_ARTIFACT_GQL = """ +mutation UpdateArtifact($input: UpdateArtifactInput!, $includeAliases: Boolean = true) { + result: updateArtifact(input: $input) { + artifact { + ...ArtifactFragment + } + } +} + +fragment ArtifactAliasFragment on ArtifactAlias { + __typename + id + alias +} + +fragment ArtifactFragment on Artifact { + __typename + id + artifactSequence { + ...SourceCollectionInfoFragment + } + versionIndex + artifactType { + name + } + description + metadata + ttlDurationSeconds + ttlIsInherited + tags { + ...TagFragment + } + historyStep + state + size + digest + commitHash + fileCount + createdAt + updatedAt + aliases @include(if: $includeAliases) { + artifactCollection { + ...CollectionInfoFragment + } + ...ArtifactAliasFragment + } +} + +fragment CollectionInfoFragment on ArtifactCollection { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment ProjectInfoFragment on Project { + name + entity { + name + } +} + +fragment SourceCollectionInfoFragment on ArtifactSequence { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment TagFragment on Tag { + __typename + id + name +} +""" + +DELETE_ARTIFACT_GQL = """ +mutation DeleteArtifact($input: DeleteArtifactInput!) 
{ + result: deleteArtifact(input: $input) { + artifact { + id + } + } +} +""" + +LINK_ARTIFACT_GQL = """ +mutation LinkArtifact($input: LinkArtifactInput!, $includeAliases: Boolean = true) { + result: linkArtifact(input: $input) { + versionIndex + artifactMembership @include(if: true) { + ...ArtifactMembershipFragment + } + } +} + +fragment ArtifactAliasFragment on ArtifactAlias { + __typename + id + alias +} + +fragment ArtifactFragment on Artifact { + __typename + id + artifactSequence { + ...SourceCollectionInfoFragment + } + versionIndex + artifactType { + name + } + description + metadata + ttlDurationSeconds + ttlIsInherited + tags { + ...TagFragment + } + historyStep + state + size + digest + commitHash + fileCount + createdAt + updatedAt + aliases @include(if: $includeAliases) { + artifactCollection { + ...CollectionInfoFragment + } + ...ArtifactAliasFragment + } +} + +fragment ArtifactMembershipFragment on ArtifactCollectionMembership { + __typename + id + versionIndex + aliases { + ...ArtifactAliasFragment + } + artifactCollection { + ...CollectionInfoFragment + } + artifact { + ...ArtifactFragment + } +} + +fragment CollectionInfoFragment on ArtifactCollection { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment ProjectInfoFragment on Project { + name + entity { + name + } +} + +fragment SourceCollectionInfoFragment on ArtifactSequence { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment TagFragment on Tag { + __typename + id + name +} +""" + +UNLINK_ARTIFACT_GQL = """ +mutation UnlinkArtifact($input: UnlinkArtifactInput!) { + result: unlinkArtifact(input: $input) { + success + } +} +""" + +TYPE_INFO_GQL = """ +query TypeInfo($name: String!) { + __type(name: $name) { + ...TypeInfoFragment + } +} + +fragment TypeInfoFragment on __Type { + name + fields { + name + args { + name + } + } + inputFields { + name + } +} +""" + +FETCH_ORG_INFO_FROM_ENTITY_GQL = """ +query FetchOrgInfoFromEntity($entity: String!) 
{ + entity(name: $entity) { + organization { + ...OrgInfoFragment + } + user { + organizations { + ...OrgInfoFragment + } + } + } +} + +fragment OrgInfoFragment on Organization { + name + orgEntity { + name + } +} +""" + +FETCH_ORG_ENTITY_FROM_ORGANIZATION_GQL = """ +query FetchOrgEntityFromOrganization($organization: String!) { + organization(name: $organization) { + orgEntity { + name + } + } +} +""" + +REGISTRY_VERSIONS_GQL = """ +query RegistryVersions($organization: String!, $registryFilter: JSONString, $collectionFilter: JSONString, $artifactFilter: JSONString, $cursor: String, $perPage: Int, $includeAliases: Boolean = false) { + organization(name: $organization) { + orgEntity { + name + artifactMemberships( + projectFilters: $registryFilter + collectionFilters: $collectionFilter + filters: $artifactFilter + after: $cursor + first: $perPage + ) { + pageInfo { + ...PageInfoFragment + } + edges { + node { + ...ArtifactMembershipFragment + } + } + } + } + } +} + +fragment ArtifactAliasFragment on ArtifactAlias { + __typename + id + alias +} + +fragment ArtifactFragment on Artifact { + __typename + id + artifactSequence { + ...SourceCollectionInfoFragment + } + versionIndex + artifactType { + name + } + description + metadata + ttlDurationSeconds + ttlIsInherited + tags { + ...TagFragment + } + historyStep + state + size + digest + commitHash + fileCount + createdAt + updatedAt + aliases @include(if: $includeAliases) { + artifactCollection { + ...CollectionInfoFragment + } + ...ArtifactAliasFragment + } +} + +fragment ArtifactMembershipFragment on ArtifactCollectionMembership { + __typename + id + versionIndex + aliases { + ...ArtifactAliasFragment + } + artifactCollection { + ...CollectionInfoFragment + } + artifact { + ...ArtifactFragment + } +} + +fragment CollectionInfoFragment on ArtifactCollection { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment PageInfoFragment on PageInfo { + __typename + endCursor + hasNextPage +} + +fragment 
ProjectInfoFragment on Project { + name + entity { + name + } +} + +fragment SourceCollectionInfoFragment on ArtifactSequence { + __typename + name + project { + ...ProjectInfoFragment + } +} + +fragment TagFragment on Tag { + __typename + id + name +} +""" + +REGISTRY_COLLECTIONS_GQL = """ +query RegistryCollections($organization: String!, $registryFilter: JSONString, $collectionFilter: JSONString, $collectionTypes: [ArtifactCollectionType!] = [PORTFOLIO], $cursor: String, $perPage: Int) { + organization(name: $organization) { + orgEntity { + name + artifactCollections( + projectFilters: $registryFilter + filters: $collectionFilter + collectionTypes: $collectionTypes + after: $cursor + first: $perPage + ) { + totalCount + pageInfo { + ...PageInfoFragment + } + edges { + node { + __typename + ...RegistryCollectionFragment + } + } + } + } + } +} + +fragment PageInfoFragment on PageInfo { + __typename + endCursor + hasNextPage +} + +fragment ProjectInfoFragment on Project { + name + entity { + name + } +} + +fragment RegistryCollectionFragment on ArtifactCollection { + __typename + id + name + description + createdAt + project { + ...ProjectInfoFragment + } + type: defaultArtifactType { + name + } + tags { + edges { + node { + ...TagFragment + } + } + } +} + +fragment TagFragment on Tag { + __typename + id + name +} +""" + +FETCH_REGISTRY_GQL = """ +query FetchRegistry($name: String, $entity: String) { + entity(name: $entity) { + project(name: $name) { + ...RegistryFragment + } + } +} + +fragment RegistryFragment on Project { + __typename + id + name + entity { + name + organization { + name + } + } + description + createdAt + updatedAt + access + allowAllArtifactTypes: allowAllArtifactTypesInRegistry + artifactTypes(includeAll: true) { + edges { + node { + name + } + } + } +} +""" + +FETCH_REGISTRIES_GQL = """ +query FetchRegistries($organization: String!, $filters: JSONString, $cursor: String, $perPage: Int) { + organization(name: $organization) { + orgEntity { + 
projects(filters: $filters, after: $cursor, first: $perPage) { + pageInfo { + ...PageInfoFragment + } + edges { + node { + ...RegistryFragment + } + } + } + } + } +} + +fragment PageInfoFragment on PageInfo { + __typename + endCursor + hasNextPage +} + +fragment RegistryFragment on Project { + __typename + id + name + entity { + name + organization { + name + } + } + description + createdAt + updatedAt + access + allowAllArtifactTypes: allowAllArtifactTypesInRegistry + artifactTypes(includeAll: true) { + edges { + node { + name + } + } + } +} +""" + +RENAME_REGISTRY_GQL = """ +mutation RenameRegistry($input: RenameProjectInput!) { + renameProject(input: $input) { + inserted + project { + ...RegistryFragment + } + } +} + +fragment RegistryFragment on Project { + __typename + id + name + entity { + name + organization { + name + } + } + description + createdAt + updatedAt + access + allowAllArtifactTypes: allowAllArtifactTypesInRegistry + artifactTypes(includeAll: true) { + edges { + node { + name + } + } + } +} +""" + +UPSERT_REGISTRY_GQL = """ +mutation UpsertRegistry($input: UpsertModelInput!) { + upsertModel(input: $input) { + inserted + project { + ...RegistryFragment + } + } +} + +fragment RegistryFragment on Project { + __typename + id + name + entity { + name + organization { + name + } + } + description + createdAt + updatedAt + access + allowAllArtifactTypes: allowAllArtifactTypesInRegistry + artifactTypes(includeAll: true) { + edges { + node { + name + } + } + } +} +""" + +DELETE_REGISTRY_GQL = """ +mutation DeleteRegistry($id: String!) { + deleteModel(input: {id: $id}) { + success + } +} +""" + +REGISTRY_USER_MEMBERS_GQL = """ +query RegistryUserMembers($project: String!, $entity: String!) 
{ + project(name: $project, entityName: $entity) { + members { + ...UserRegistryMemberFragment + } + } +} + +fragment RegistryRoleFragment on Role { + name +} + +fragment UserRegistryMemberFragment on ProjectMember { + id + name + username + email + role { + ...RegistryRoleFragment + } +} +""" + +REGISTRY_TEAM_MEMBERS_GQL = """ +query RegistryTeamMembers($project: String!, $entity: String!) { + project(name: $project, entityName: $entity) { + teamMembers { + ...TeamRegistryMemberFragment + } + } +} + +fragment RegistryRoleFragment on Role { + name +} + +fragment TeamFragment on Entity { + __typename + id + name + available + photoUrl + readOnly + readOnlyAdmin + isTeam + privateOnly + storageBytes + codeSavingEnabled + defaultAccess + isPaid + members { + ...TeamMemberFragment + } +} + +fragment TeamMemberFragment on Member { + __typename + id + role + pending + email + username + name + photoUrl + accountType + apiKey +} + +fragment TeamRegistryMemberFragment on ProjectTeamMember { + team { + ...TeamFragment + } + role { + ...RegistryRoleFragment + } +} +""" + +CREATE_REGISTRY_MEMBERS_GQL = """ +mutation CreateRegistryMembers($input: CreateProjectMembersInput!) { + result: createProjectMembers(input: $input) { + success + } +} +""" + +DELETE_REGISTRY_MEMBERS_GQL = """ +mutation DeleteRegistryMembers($input: DeleteProjectMembersInput!) { + result: deleteProjectMembers(input: $input) { + success + } +} +""" + +UPDATE_USER_REGISTRY_ROLE_GQL = """ +mutation UpdateUserRegistryRole($input: UpdateProjectMemberInput!) { + result: updateProjectMember(input: $input) { + success + } +} +""" + +UPDATE_TEAM_REGISTRY_ROLE_GQL = """ +mutation UpdateTeamRegistryRole($input: UpdateProjectTeamMemberInput!) 
{ + result: updateProjectTeamMember(input: $input) { + success + } +} +""" diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/project_artifact_collection.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/project_artifact_collection.py new file mode 100644 index 0000000000000000000000000000000000000000..56026db741a09d8f47f108782c7017b1d2102faf --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/project_artifact_collection.py @@ -0,0 +1,33 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +from __future__ import annotations + +from typing import Optional + +from pydantic import Field + +from wandb._pydantic import GQLResult + +from .fragments import ArtifactCollectionFragment + + +class ProjectArtifactCollection(GQLResult): + project: Optional[ProjectArtifactCollectionProject] + + +class ProjectArtifactCollectionProject(GQLResult): + artifact_type: Optional[ProjectArtifactCollectionProjectArtifactType] = Field( + alias="artifactType" + ) + + +class ProjectArtifactCollectionProjectArtifactType(GQLResult): + artifact_collection: Optional[ArtifactCollectionFragment] = Field( + alias="artifactCollection" + ) + + +ProjectArtifactCollection.model_rebuild() +ProjectArtifactCollectionProject.model_rebuild() +ProjectArtifactCollectionProjectArtifactType.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/registry_team_members.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/registry_team_members.py new file mode 100644 index 0000000000000000000000000000000000000000..3dd7f5d3705cb0cb15e6e8534ea9aeab8118d3e2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/registry_team_members.py @@ -0,0 +1,24 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +from __future__ import annotations + +from typing import List, Optional + +from pydantic 
import Field + +from wandb._pydantic import GQLResult + +from .fragments import TeamRegistryMemberFragment + + +class RegistryTeamMembers(GQLResult): + project: Optional[RegistryTeamMembersProject] + + +class RegistryTeamMembersProject(GQLResult): + team_members: List[TeamRegistryMemberFragment] = Field(alias="teamMembers") + + +RegistryTeamMembers.model_rebuild() +RegistryTeamMembersProject.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/registry_versions.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/registry_versions.py new file mode 100644 index 0000000000000000000000000000000000000000..de1f11d7d91f26eaab08cdd7cc24da73091ff879 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/registry_versions.py @@ -0,0 +1,45 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field + +from wandb._pydantic import GQLResult + +from .fragments import ArtifactMembershipFragment, PageInfoFragment + + +class RegistryVersions(GQLResult): + organization: Optional[RegistryVersionsOrganization] + + +class RegistryVersionsOrganization(GQLResult): + org_entity: Optional[RegistryVersionsOrganizationOrgEntity] = Field( + alias="orgEntity" + ) + + +class RegistryVersionsOrganizationOrgEntity(GQLResult): + name: str + artifact_memberships: Optional[ + RegistryVersionsOrganizationOrgEntityArtifactMemberships + ] = Field(alias="artifactMemberships") + + +class RegistryVersionsOrganizationOrgEntityArtifactMemberships(GQLResult): + page_info: PageInfoFragment = Field(alias="pageInfo") + edges: List[RegistryVersionsOrganizationOrgEntityArtifactMembershipsEdges] + + +class RegistryVersionsOrganizationOrgEntityArtifactMembershipsEdges(GQLResult): + node: Optional[ArtifactMembershipFragment] + + +RegistryVersions.model_rebuild() 
+RegistryVersionsOrganization.model_rebuild() +RegistryVersionsOrganizationOrgEntity.model_rebuild() +RegistryVersionsOrganizationOrgEntityArtifactMemberships.model_rebuild() +RegistryVersionsOrganizationOrgEntityArtifactMembershipsEdges.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/rename_registry.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/rename_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..73c55860a559883b4959fd87f1ad43189eb8aa8a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/rename_registry.py @@ -0,0 +1,25 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +from __future__ import annotations + +from typing import Optional + +from pydantic import Field + +from wandb._pydantic import GQLResult + +from .fragments import RegistryFragment + + +class RenameRegistry(GQLResult): + rename_project: Optional[RenameRegistryRenameProject] = Field(alias="renameProject") + + +class RenameRegistryRenameProject(GQLResult): + inserted: Optional[bool] + project: Optional[RegistryFragment] + + +RenameRegistry.model_rebuild() +RenameRegistryRenameProject.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/run_output_artifacts.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/run_output_artifacts.py new file mode 100644 index 0000000000000000000000000000000000000000..974cc30cbd9bd3908f86460bfb61820de9147d83 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/run_output_artifacts.py @@ -0,0 +1,41 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field + +from wandb._pydantic import GQLResult + +from .fragments import ArtifactFragment, PageInfoFragment + + +class 
RunOutputArtifacts(GQLResult): + project: Optional[RunOutputArtifactsProject] + + +class RunOutputArtifactsProject(GQLResult): + run: Optional[RunOutputArtifactsProjectRun] + + +class RunOutputArtifactsProjectRun(GQLResult): + artifacts: Optional[RunOutputArtifactsProjectRunArtifacts] + + +class RunOutputArtifactsProjectRunArtifacts(GQLResult): + total_count: int = Field(alias="totalCount") + page_info: PageInfoFragment = Field(alias="pageInfo") + edges: List[RunOutputArtifactsProjectRunArtifactsEdges] + + +class RunOutputArtifactsProjectRunArtifactsEdges(GQLResult): + node: Optional[ArtifactFragment] + + +RunOutputArtifacts.model_rebuild() +RunOutputArtifactsProject.model_rebuild() +RunOutputArtifactsProjectRun.model_rebuild() +RunOutputArtifactsProjectRunArtifacts.model_rebuild() +RunOutputArtifactsProjectRunArtifactsEdges.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/unlink_artifact.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/unlink_artifact.py new file mode 100644 index 0000000000000000000000000000000000000000..340482eb130662b076957b574b21b3f9cd5608b1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/unlink_artifact.py @@ -0,0 +1,19 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +from __future__ import annotations + +from typing import Optional + +from wandb._pydantic import GQLResult + + +class UnlinkArtifact(GQLResult): + result: Optional[UnlinkArtifactResult] + + +class UnlinkArtifactResult(GQLResult): + success: bool + + +UnlinkArtifact.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/update_artifact_portfolio.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/update_artifact_portfolio.py new file mode 100644 index 0000000000000000000000000000000000000000..2c3c5c3782eb8d6d99229a66f4d46c941e4d6488 --- /dev/null +++ 
b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/update_artifact_portfolio.py @@ -0,0 +1,24 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +from __future__ import annotations + +from typing import Optional + +from pydantic import Field + +from wandb._pydantic import GQLResult + +from .fragments import ArtifactCollectionFragment + + +class UpdateArtifactPortfolio(GQLResult): + result: Optional[UpdateArtifactPortfolioResult] + + +class UpdateArtifactPortfolioResult(GQLResult): + artifact_collection: ArtifactCollectionFragment = Field(alias="artifactCollection") + + +UpdateArtifactPortfolio.model_rebuild() +UpdateArtifactPortfolioResult.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/update_artifact_sequence.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/update_artifact_sequence.py new file mode 100644 index 0000000000000000000000000000000000000000..bf5a57c1d5e5399837dd4742e1e3dcd04f1285b7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/update_artifact_sequence.py @@ -0,0 +1,24 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +from __future__ import annotations + +from typing import Optional + +from pydantic import Field + +from wandb._pydantic import GQLResult + +from .fragments import ArtifactCollectionFragment + + +class UpdateArtifactSequence(GQLResult): + result: Optional[UpdateArtifactSequenceResult] + + +class UpdateArtifactSequenceResult(GQLResult): + artifact_collection: ArtifactCollectionFragment = Field(alias="artifactCollection") + + +UpdateArtifactSequence.model_rebuild() +UpdateArtifactSequenceResult.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/upsert_registry.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/upsert_registry.py new file mode 100644 index 
0000000000000000000000000000000000000000..a1114f6c8b775bdbfb259c469d64f45537e85602 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_generated/upsert_registry.py @@ -0,0 +1,25 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/artifacts/ + +from __future__ import annotations + +from typing import Optional + +from pydantic import Field + +from wandb._pydantic import GQLResult + +from .fragments import RegistryFragment + + +class UpsertRegistry(GQLResult): + upsert_model: Optional[UpsertRegistryUpsertModel] = Field(alias="upsertModel") + + +class UpsertRegistryUpsertModel(GQLResult): + inserted: Optional[bool] + project: Optional[RegistryFragment] + + +UpsertRegistry.model_rebuild() +UpsertRegistryUpsertModel.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_gqlutils.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_gqlutils.py new file mode 100644 index 0000000000000000000000000000000000000000..016f4bb23dc79ddbcffdce0a17396b8601d8f039 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_gqlutils.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +from contextlib import suppress +from dataclasses import dataclass +from functools import lru_cache +from typing import TYPE_CHECKING + +from wandb_gql import gql + +from wandb._iterutils import one +from wandb.proto.wandb_internal_pb2 import ServerFeature +from wandb.sdk.internal._generated import SERVER_FEATURES_QUERY_GQL, ServerFeaturesQuery + +if TYPE_CHECKING: + from wandb.apis.public import RetryingClient + from wandb.sdk.artifacts._generated import TypeInfoFragment + from wandb.sdk.artifacts._generated.fetch_org_info_from_entity import ( + FetchOrgInfoFromEntityEntity, + ) + + +@lru_cache(maxsize=16) +def type_info(client: RetryingClient, typename: str) -> TypeInfoFragment | None: + """Returns the type info for a given GraphQL type.""" + from ._generated import TYPE_INFO_GQL, TypeInfo + + data = 
client.execute(gql(TYPE_INFO_GQL), variable_values={"name": typename}) + return TypeInfo.model_validate(data).type + + +@lru_cache(maxsize=16) +def org_info_from_entity( + client: RetryingClient, entity: str +) -> FetchOrgInfoFromEntityEntity | None: + """Returns the organization info for a given entity.""" + from ._generated import FETCH_ORG_INFO_FROM_ENTITY_GQL, FetchOrgInfoFromEntity + + gql_op = gql(FETCH_ORG_INFO_FROM_ENTITY_GQL) + data = client.execute(gql_op, variable_values={"entity": entity}) + return FetchOrgInfoFromEntity.model_validate(data).entity + + +@lru_cache(maxsize=16) +def server_features(client: RetryingClient) -> dict[str, bool]: + """Returns a mapping of `{server_feature_name (str) -> is_enabled (bool)}`. + + Results are cached per client instance. + """ + try: + response = client.execute(gql(SERVER_FEATURES_QUERY_GQL)) + except Exception as e: + # Unfortunately we currently have to match on the text of the error message, + # as the `gql` client raises `Exception` rather than a more specific error. + if 'Cannot query field "features" on type "ServerInfo".' in str(e): + return {} + raise + + result = ServerFeaturesQuery.model_validate(response) + if (server_info := result.server_info) and (features := server_info.features): + return {feat.name: feat.is_enabled for feat in features if feat} + return {} + + +def server_supports(client: RetryingClient, feature: str | int) -> bool: + """Return whether the current server supports the given feature. + + Good to use for features that have a fallback mechanism for older servers. + """ + # If we're given the protobuf enum value, convert to a string name. 
+ # NOTE: We deliberately use names (str) instead of enum values (int) + # as the keys here, since: + # - the server identifies features by their name, rather than (client-side) enum value + # - the defined list of client-side flags may be behind the server-side list of flags + try: + name = ServerFeature.Name(feature) if isinstance(feature, int) else feature + except ValueError: + return False # Invalid int-like value, assume unsupported + return server_features(client).get(name) or False + + +def allowed_fields(client: RetryingClient, typename: str) -> set[str]: + """Returns the allowed field names for a given GraphQL type.""" + typ = type_info(client, typename) + return {f.name for f in typ.fields} if (typ and typ.fields) else set() + + +@dataclass(frozen=True) +class OrgInfo: + org_name: str + entity_name: str + + def __contains__(self, other: str) -> bool: + return other in {self.org_name, self.entity_name} + + +def resolve_org_entity_name( + client: RetryingClient, + non_org_entity: str | None, + org_or_entity: str | None = None, +) -> str: + # Resolve the portfolio's org entity name. + # + # The `org_or_org_entity` parameter may be empty, an org display name, or an + # org entity name. + # + # If the server cannot fetch the portfolio's org name, return the provided + # value or raise an error if it is empty. Otherwise, return the fetched + # value after validating that the given organization, if provided, matches + # either the display or entity name. + if not non_org_entity: + raise ValueError("Entity name is required to resolve org entity name.") + + # Fetch candidate orgs to verify or identify the correct orgEntity name. + entity = org_info_from_entity(client, non_org_entity) + + # Parse possible organization(s) from the response... + # ---------------------------------------------------------------------------- + # If a team entity was provided, a single organization should exist under + # the team/org entity type. 
+ if entity and (org := entity.organization) and (org_entity := org.org_entity): + # Ensure the provided name, if given, matches the org or org entity name before + # returning the org entity. + org_info = OrgInfo(org_name=org.name, entity_name=org_entity.name) + if (not org_or_entity) or (org_or_entity in org_info): + return org_entity.name + + # ---------------------------------------------------------------------------- + # If a personal entity was provided, the user may belong to multiple + # organizations. + if entity and (user := entity.user) and (orgs := user.organizations): + org_infos = [ + OrgInfo(org_name=org.name, entity_name=org_entity.name) + for org in orgs + if (org_entity := org.org_entity) + ] + if org_or_entity: + with suppress(StopIteration): + return next( + info.entity_name for info in org_infos if (org_or_entity in info) + ) + + if len(org_infos) == 1: + raise ValueError( + f"Expecting the organization name or entity name to match {org_infos[0].org_name!r} " + f"and cannot be linked/fetched with {org_or_entity!r}. " + "Please update the target path with the correct organization name." + ) + else: + raise ValueError( + "Personal entity belongs to multiple organizations " + f"and cannot be linked/fetched with {org_or_entity!r}. " + "Please update the target path with the correct organization name " + "or use a team entity in the entity settings." + ) + + else: + # If no input organization provided, error if entity belongs to: + # - multiple orgs, because we cannot determine which one to use. + # - no orgs, because there's nothing to use. + return one( + (org.entity_name for org in org_infos), + too_short=ValueError( + f"Unable to resolve an organization associated with personal entity: {non_org_entity!r}. " + "This could be because its a personal entity that doesn't belong to any organizations. " + "Please specify the organization in the Registry path or use a team entity in the entity settings." 
+ ), + too_long=ValueError( + f"Personal entity {non_org_entity!r} belongs to multiple organizations " + "and cannot be used without specifying the organization name. " + "Please specify the organization in the Registry path or use a team entity in the entity settings." + ), + ) + + raise ValueError(f"Unable to find organization for entity {non_org_entity!r}.") diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_internal_artifact.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_internal_artifact.py new file mode 100644 index 0000000000000000000000000000000000000000..fabd09c35c21adaa9b9f185964fee2a9dd409f51 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_internal_artifact.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import re +from base64 import urlsafe_b64encode +from typing import Any, Final +from zlib import crc32 + +from wandb.sdk.artifacts.artifact import Artifact + +PLACEHOLDER: Final[str] = "PLACEHOLDER" + + +def sanitize_artifact_name(name: str) -> str: + """Sanitize the string to satisfy constraints on artifact names.""" + # If the name is already sanitized, don't change it. + if (sanitized := re.sub(r"[^a-zA-Z0-9_\-.]+", "", name)) == name: + return name + + # Append a short alphanumeric suffix to maintain uniqueness. + # Yes, CRC is meant for checksums and not as a general hash function, but + # a 32-bit CRC hash, encoded as (url-safe) base64, is fairly short while + # providing 4B+ possible values, which should be good enough for the corner + # case names this function is meant to address. + # + # As implemented, the final suffix should be 6 characters. + crc: int = crc32(name.encode("utf-8")) & 0xFFFFFFFF # Ensure it's unsigned + crc_bytes = crc.to_bytes(4, byteorder="big") + suffix = urlsafe_b64encode(crc_bytes).rstrip(b"=").decode("ascii") + + return f"{sanitized}-{suffix}" + + +class InternalArtifact(Artifact): + """An Artifact intended for internal use only. 
+ + Includes artifacts of type `job`, `code` (with a `source-` collection name + prefix), `run_table` (with a `run-` collection name prefix), and any type that starts + with `wandb-`. Users should not use this class directly. + """ + + def __init__( + self, + name: str, + type: str, + description: str | None = None, + metadata: dict[str, Any] | None = None, + incremental: bool = False, + use_as: str | None = None, + ) -> None: + sanitized_name = sanitize_artifact_name(name) + super().__init__( + sanitized_name, PLACEHOLDER, description, metadata, incremental, use_as + ) + self._type = type diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_validators.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_validators.py new file mode 100644 index 0000000000000000000000000000000000000000..54657ff2e94c8971619000f1ee14a3a050bf1981 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/_validators.py @@ -0,0 +1,348 @@ +"""Internal validation utilities that are specific to artifacts.""" + +from __future__ import annotations + +import json +import re +from dataclasses import dataclass, field, replace +from functools import singledispatch, wraps +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, TypeVar + +from pydantic.dataclasses import dataclass as pydantic_dataclass +from typing_extensions import Concatenate, ParamSpec, Self + +from wandb._iterutils import always_list, unique_list +from wandb._pydantic import from_json +from wandb._strutils import nameof, removeprefix +from wandb.util import json_friendly_val + +from .exceptions import ArtifactFinalizedError, ArtifactNotLoggedError + +if TYPE_CHECKING: + from typing import Final, Iterable + + from wandb.sdk.artifacts.artifact import Artifact + +ArtifactT = TypeVar("ArtifactT", bound="Artifact") +SelfT = TypeVar("SelfT") +R = TypeVar("R") +P = ParamSpec("P") + +REGISTRY_PREFIX: Final[str] = "wandb-registry-" +MAX_ARTIFACT_METADATA_KEYS: Final[int] = 100 + 
+NAME_MAXLEN: Final[int] = 128 + +INVALID_ARTIFACT_NAME_CHARS: Final[frozenset[str]] = frozenset("/") +INVALID_URL_CHARS: Final[frozenset[str]] = frozenset("/\\#?%:\r\n") +ARTIFACT_SEP_CHARS: Final[frozenset[str]] = frozenset("/:") + + +@dataclass +class LinkArtifactFields: + """Keep this list updated with fields where linked and source artifacts differ.""" + + entity_name: str + project_name: str + name: str + version: str + aliases: list[str] + + # These fields shouldn't be user-editable, linked artifacts always have these values + _is_link: Literal[True] = field(init=False, default=True) + _linked_artifacts: list[Artifact] = field(init=False, default_factory=list) + + @property + def is_link(self) -> bool: + return self._is_link + + @property + def linked_artifacts(self) -> list[Artifact]: + return self._linked_artifacts + + +def validate_artifact_name(name: str) -> str: + """Validate the artifact name, returning it if successful. + + Raises: + ValueError: If the artifact name is invalid. + """ + if len(name) > NAME_MAXLEN: + trunc_name = f"{name[:NAME_MAXLEN]} ..." + raise ValueError( + f"Artifact name is longer than {NAME_MAXLEN!r} characters: {trunc_name!r}" + ) + + if INVALID_ARTIFACT_NAME_CHARS.intersection(name): + raise ValueError( + "Artifact names must not contain any of the following characters: " + f"{', '.join(sorted(INVALID_ARTIFACT_NAME_CHARS))}. Got: {name!r}" + ) + + return name + + +def validate_project_name(name: str) -> str: + """Validate a project name according to W&B rules. + + Return the original name if successful. + + Args: + name: The project name string. + + Raises: + ValueError: If the name is invalid (too long or contains invalid characters). 
+ """ + if not name: + raise ValueError("Project name cannot be empty") + if not (registry_name := removeprefix(name, REGISTRY_PREFIX)): + raise ValueError("Registry name cannot be empty") + + if len(name) > NAME_MAXLEN: + if registry_name != name: + msg = f"Invalid registry name {registry_name!r}, must be {NAME_MAXLEN - len(REGISTRY_PREFIX)!r} characters or less" + else: + msg = f"Invalid project name {name!r}, must be {NAME_MAXLEN!r} characters or less" + raise ValueError(msg) + + # Find the first occurrence of any invalid character + if invalid_chars := set(INVALID_URL_CHARS).intersection(name): + error_name = registry_name or name + invalid_chars_repr = ", ".join(sorted(map(repr, invalid_chars))) + raise ValueError( + f"Invalid project/registry name {error_name!r}, cannot contain characters: {invalid_chars_repr!s}" + ) + return name + + +def validate_aliases(aliases: Iterable[str] | str) -> list[str]: + """Validate the artifact aliases and return them as a list. + + Raises: + ValueError: If any of the aliases contain invalid characters. 
+ """ + aliases_list = always_list(aliases) + if any(ARTIFACT_SEP_CHARS.intersection(name) for name in aliases_list): + invalid_chars = ", ".join(sorted(map(repr, ARTIFACT_SEP_CHARS))) + raise ValueError( + f"Aliases must not contain any of the following characters: {invalid_chars}" + ) + return aliases_list + + +def validate_artifact_types(types: Iterable[str] | str) -> list[str]: + """Validate the artifact type names and return them as a list.""" + types_list = always_list(types) + if any(ARTIFACT_SEP_CHARS.intersection(name) for name in types_list): + invalid_chars = ", ".join(sorted(map(repr, ARTIFACT_SEP_CHARS))) + raise ValueError( + f"Artifact types must not contain any of the following characters: {invalid_chars}" + ) + if any(len(name) > NAME_MAXLEN for name in types_list): + raise ValueError( + f"Artifact types must be less than or equal to {NAME_MAXLEN!r} characters" + ) + return types_list + + +TAG_REGEX: re.Pattern[str] = re.compile(r"^[-\w]+( +[-\w]+)*$") +"""Regex pattern for valid tag names.""" + + +def validate_tags(tags: Iterable[str] | str) -> list[str]: + """Validate artifact tag names and return them as a deduped list. + + In the case of duplicates, keep the first tag and maintain the order of + appearance. + + Raises: + ValueError: If any of the tags contain invalid characters. + """ + tags_list = unique_list(always_list(tags)) + if any(not TAG_REGEX.match(tag) for tag in tags_list): + raise ValueError( + "Invalid tag(s). " + "Tags must only contain alphanumeric characters separated by hyphens, underscores, and/or spaces." 
+ ) + return tags_list + + +RESERVED_ARTIFACT_TYPE_PREFIX: Final[str] = "wandb-" +"""Internal, reserved artifact type prefix.""" + +RESERVED_ARTIFACT_NAME_PREFIX_BY_TYPE: Final[dict[str, str]] = { + "job": "", # Empty prefix means ALL artifact names are reserved for this artifact type + "run_table": "run-", + "code": "source-", +} +"""Lookup of internal, reserved `Artifact.name` prefixes by `Artifact.type`.""" + + +def validate_artifact_type(typ: str, name: str) -> str: + """Validate the artifact type and return it as a string.""" + if ( + # Check if the artifact name is disallowed, based on the artifact type + ( + # This check MUST be against `None`, since "" disallows ALL artifact names + (bad_prefix := RESERVED_ARTIFACT_NAME_PREFIX_BY_TYPE.get(typ)) is not None + and name.startswith(bad_prefix) + ) + or + # Check if the artifact type is disallowed + typ.startswith(RESERVED_ARTIFACT_TYPE_PREFIX) + ): + raise ValueError( + f"Artifact type {typ!r} is reserved for internal use. " + "Please use a different type." + ) + return typ + + +@singledispatch +def validate_metadata(metadata: dict[str, Any] | str | None) -> dict[str, Any]: + """Validate the artifact metadata and return it as a dict.""" + raise TypeError(f"Cannot parse {type(metadata)} as artifact metadata") + + +@validate_metadata.register(type(None)) +@validate_metadata.register(str) +def _(metadata: str | None) -> dict[str, Any]: + return validate_metadata(from_json(metadata)) if metadata else {} + + +@validate_metadata.register(dict) +def _(metadata: dict[str, Any]) -> dict[str, Any]: + # NOTE: The backend doesn't currently allow JS-compatible `+/-Infinity` values. + # Forbid them here to avoid surprises, but revisit if we add future backend support. + # Note that prior behavior already converts `NaN` values to `None` (client-side). 
+ metadata = from_json(json.dumps(json_friendly_val(metadata), allow_nan=False)) + if len(metadata) > MAX_ARTIFACT_METADATA_KEYS: + raise ValueError( + f"Artifact must not have more than {MAX_ARTIFACT_METADATA_KEYS!r} metadata keys." + ) + return metadata + + +def validate_ttl_duration_seconds(ttl_duration_seconds: int) -> int | None: + """Validate the `ttlDurationSeconds` value from a GraphQL response. + + A non-positive value indicates that TTL is DISABLED (-2), which we + convert to `None`. + """ + return ttl_duration_seconds if ttl_duration_seconds > 0 else None + + +# ---------------------------------------------------------------------------- +MethodT = Callable[Concatenate[SelfT, P], R] +"""Generic type hint for an instance method, e.g. for use with decorators.""" + + +def ensure_logged(method: MethodT[ArtifactT, P, R]) -> MethodT[ArtifactT, P, R]: + """Ensure an artifact method runs only if the artifact has been logged. + + If the method is called on an artifact that's not logged, `ArtifactNotLoggedError` + is raised. + """ + # For clarity, use the qualified (full) name of the method + method_fullname = nameof(method) + + @wraps(method) + def wrapper(self: ArtifactT, *args: P.args, **kwargs: P.kwargs) -> R: + if self.is_draft(): + raise ArtifactNotLoggedError(fullname=method_fullname, obj=self) + return method(self, *args, **kwargs) + + return wrapper + + +def ensure_not_finalized(method: MethodT[ArtifactT, P, R]) -> MethodT[ArtifactT, P, R]: + """Ensure an `Artifact` method runs only if the artifact is not finalized. + + If the method is called on an artifact that's not logged, `ArtifactFinalizedError` + is raised. 
+ """ + # For clarity, use the qualified (full) name of the method + method_fullname = nameof(method) + + @wraps(method) + def wrapper(self: ArtifactT, *args: P.args, **kwargs: P.kwargs) -> R: + if self._final: + raise ArtifactFinalizedError(fullname=method_fullname, obj=self) + return method(self, *args, **kwargs) + + return wrapper + + +def is_artifact_registry_project(project: str) -> bool: + return project.startswith(REGISTRY_PREFIX) + + +def remove_registry_prefix(project: str) -> str: + if not is_artifact_registry_project(project): + raise ValueError( + f"Project {project!r} is not a registry project. Must start with: {REGISTRY_PREFIX!r}" + ) + return removeprefix(project, REGISTRY_PREFIX) + + +@pydantic_dataclass +class ArtifactPath: + name: str + """The collection or artifact version name.""" + project: Optional[str] = None # noqa: UP045 + """The project name.""" + prefix: Optional[str] = None # noqa: UP045 + """Typically the entity or org name.""" + + @classmethod + def from_str(cls, path: str) -> Self: + """Instantiate by parsing a string artifact path. + + Raises: + ValueError: If the string is not a valid artifact path. + """ + # Separate the alias first, which may itself contain slashes. + # If there's no alias, note that both sep and alias will be empty. 
+ collection_path, sep, alias = path.partition(":") + + prefix, project = None, None # defaults, if missing + if len(parts := collection_path.split("/")) == 1: + name = parts[0] + elif len(parts) == 2: + project, name = parts + elif len(parts) == 3: + prefix, project, name = parts + else: + raise ValueError(f"Invalid artifact path: {path!r}") + return cls(prefix=prefix, project=project, name=f"{name}{sep}{alias}") + + def to_str(self) -> str: + """Returns the slash-separated string representation of the path.""" + ordered_parts = (self.prefix, self.project, self.name) + return "/".join(part for part in ordered_parts if part) + + def with_defaults( + self, + *, + prefix: str | None = None, + project: str | None = None, + ) -> Self: + """Returns a copy of this path with missing values set to the given defaults.""" + return replace( + self, + prefix=self.prefix or prefix, + project=self.project or project, + ) + + def is_registry_path(self) -> bool: + """Returns True if this path appears to be a registry path.""" + return bool((p := self.project) and is_artifact_registry_project(p)) + + +@pydantic_dataclass +class FullArtifactPath(ArtifactPath): + """Same as ArtifactPath, but with all parts required.""" + + name: str + project: str + prefix: str diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact.py new file mode 100644 index 0000000000000000000000000000000000000000..add4191677a8a6d926aa923f1b77f647ea3238cb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact.py @@ -0,0 +1,2722 @@ +"""Artifact class.""" + +from __future__ import annotations + +import atexit +import contextlib +import json +import logging +import multiprocessing.dummy +import os +import re +import shutil +import stat +import tempfile +import time +from collections import deque +from concurrent.futures import Executor, ThreadPoolExecutor, as_completed +from copy import copy +from 
dataclasses import asdict, replace +from datetime import timedelta +from itertools import filterfalse +from pathlib import Path, PurePosixPath +from typing import ( + IO, + TYPE_CHECKING, + Any, + Callable, + Final, + Iterator, + Literal, + Sequence, + Type, +) +from urllib.parse import quote, urljoin, urlparse + +from pydantic import NonNegativeInt + +import wandb +from wandb import data_types, env +from wandb._iterutils import one, unique_list +from wandb._pydantic import from_json +from wandb._strutils import nameof +from wandb.apis.normalize import normalize_exceptions +from wandb.apis.public import ArtifactCollection, ArtifactFiles, Run +from wandb.apis.public.utils import gql_compat +from wandb.data_types import WBValue +from wandb.errors import CommError +from wandb.errors.errors import UnsupportedError +from wandb.errors.term import termerror, termlog, termwarn +from wandb.proto import wandb_internal_pb2 as pb +from wandb.proto.wandb_telemetry_pb2 import Deprecated +from wandb.sdk import wandb_setup +from wandb.sdk.data_types._dtypes import Type as WBType +from wandb.sdk.data_types._dtypes import TypeRegistry +from wandb.sdk.lib import retry, telemetry +from wandb.sdk.lib.deprecation import warn_and_record_deprecation +from wandb.sdk.lib.filesystem import check_exists, system_preferred_path +from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, md5_file_b64 +from wandb.sdk.lib.paths import FilePathStr, LogicalPath, StrPath, URIStr +from wandb.sdk.lib.runid import generate_fast_id, generate_id +from wandb.sdk.mailbox import MailboxHandle +from wandb.util import ( + alias_is_version_index, + artifact_to_json, + fsync_open, + json_dumps_safer, + uri_from_path, + vendor_setup, +) + +from ._factories import make_storage_policy +from ._gqlutils import org_info_from_entity, resolve_org_entity_name, server_supports +from ._validators import ensure_logged, ensure_not_finalized +from .artifact_download_logger import ArtifactDownloadLogger +from 
.artifact_instance_cache import ( + artifact_instance_cache, + artifact_instance_cache_by_client_id, +) +from .artifact_manifest import ArtifactManifest +from .artifact_manifest_entry import ArtifactManifestEntry +from .artifact_manifests.artifact_manifest_v1 import ArtifactManifestV1 +from .artifact_state import ArtifactState +from .artifact_ttl import ArtifactTTL +from .exceptions import ( + ArtifactNotLoggedError, + TooFewItemsError, + TooManyItemsError, + WaitTimeoutError, +) +from .staging import get_staging_dir +from .storage_handlers.gcs_handler import _GCSIsADirectoryError +from .storage_policies._factories import make_http_session +from .storage_policies._multipart import should_multipart_download + +reset_path = vendor_setup() + +from wandb_gql import gql # noqa: E402 + +reset_path() + +if TYPE_CHECKING: + from typing import Iterable + + from wandb.apis.public import RetryingClient + + from ._generated import ArtifactFragment, ArtifactMembershipFragment + from ._models.pagination import FileWithUrlConnection + from ._validators import FullArtifactPath, LinkArtifactFields + +logger = logging.getLogger(__name__) + + +_MB: Final[int] = 1024 * 1024 + + +class Artifact: + """Flexible and lightweight building block for dataset and model versioning. + + Construct an empty W&B Artifact. Populate an artifacts contents with methods that + begin with `add`. Once the artifact has all the desired files, you can call + `run.log_artifact()` to log it. + + Args: + name (str): A human-readable name for the artifact. Use the name to identify + a specific artifact in the W&B App UI or programmatically. You can + interactively reference an artifact with the `use_artifact` Public API. + A name can contain letters, numbers, underscores, hyphens, and dots. + The name must be unique across a project. + type (str): The artifact's type. Use the type of an artifact to both organize + and differentiate artifacts. 
You can use any string that contains letters, + numbers, underscores, hyphens, and dots. Common types include `dataset` or + `model`. Include `model` within your type string if you want to link the + artifact to the W&B Model Registry. + Note that some types reserved for internal use and cannot be set by users. + Such types include `job` and types that start with `wandb-`. + description (str | None) = None: A description of the artifact. For Model or + Dataset Artifacts, add documentation for your standardized team model or + dataset card. View an artifact's description programmatically with the + `Artifact.description` attribute or programmatically with the W&B App UI. + W&B renders the description as markdown in the W&B App. + metadata (dict[str, Any] | None) = None: Additional information about an artifact. + Specify metadata as a dictionary of key-value pairs. You can specify no more + than 100 total keys. + incremental: Use `Artifact.new_draft()` method instead to modify an + existing artifact. + use_as: Deprecated. + + Returns: + An `Artifact` object. + """ + + _TMP_DIR = tempfile.TemporaryDirectory("wandb-artifacts") + atexit.register(_TMP_DIR.cleanup) + + def __init__( + self, + name: str, + type: str, + description: str | None = None, + metadata: dict[str, Any] | None = None, + incremental: bool = False, + use_as: str | None = None, + storage_region: str | None = None, + ) -> None: + from wandb.sdk.artifacts._internal_artifact import InternalArtifact + + from ._validators import ( + validate_artifact_name, + validate_artifact_type, + validate_metadata, + ) + + if not re.match(r"^[a-zA-Z0-9_\-.]+$", name): + raise ValueError( + f"Artifact name may only contain alphanumeric characters, dashes, " + f"underscores, and dots. Invalid name: {name!r}" + ) + + if incremental and not isinstance(self, InternalArtifact): + termwarn("Using experimental arg `incremental`") + + # Internal. 
+ self._client: RetryingClient | None = None + + self._tmp_dir: tempfile.TemporaryDirectory | None = None + self._added_objs: dict[int, tuple[WBValue, ArtifactManifestEntry]] = {} + self._added_local_paths: dict[str, ArtifactManifestEntry] = {} + self._save_handle: MailboxHandle[pb.Result] | None = None + self._download_roots: set[str] = set() + # Set by new_draft(), otherwise the latest artifact will be used as the base. + self._base_id: str | None = None + # Properties. + self._id: str | None = None + + # Client IDs don't need cryptographic strength, so use a faster implementation. + self._client_id: str = generate_fast_id(128) + self._sequence_client_id: str = generate_fast_id(128) + + self._entity: str | None = None + self._project: str | None = None + self._name: str = validate_artifact_name(name) # includes version after saving + self._version: str | None = None + self._source_entity: str | None = None + self._source_project: str | None = None + self._source_name: str = name # includes version after saving + self._source_version: str | None = None + self._source_artifact: Artifact | None = None + self._is_link: bool = False + self._type: str = validate_artifact_type(type, name) + self._description: str | None = description + self._metadata: dict[str, Any] = validate_metadata(metadata) + self._ttl_duration_seconds: int | None = None + self._ttl_is_inherited: bool = True + self._ttl_changed: bool = False + self._aliases: list[str] = [] + self._saved_aliases: list[str] = [] + self._tags: list[str] = [] + self._saved_tags: list[str] = [] + self._distributed_id: str | None = None + self._incremental: bool = incremental + if use_as is not None: + warn_and_record_deprecation( + feature=Deprecated(artifact__init_use_as=True), + message=( + "`use_as` argument is deprecated and does not affect the behaviour of `wandb.Artifact()`" + ), + ) + self._use_as: str | None = None + self._state: ArtifactState = ArtifactState.PENDING + + # NOTE: These fields only reflect the 
last fetched response from the + # server, if any. If the ArtifactManifest has already been fetched and/or + # populated locally, it should take priority when determining these values. + self._size: NonNegativeInt | None = None + self._digest: str | None = None + + self._manifest: ArtifactManifest | None = ArtifactManifestV1( + storage_policy=make_storage_policy(region=storage_region) + ) + + self._commit_hash: str | None = None + self._file_count: int | None = None + self._created_at: str | None = None + self._updated_at: str | None = None + self._final: bool = False + self._history_step: int | None = None + self._linked_artifacts: list[Artifact] = [] + + self._fetch_file_urls_decorated: Callable[..., Any] | None = None + + # Cache. + artifact_instance_cache_by_client_id[self._client_id] = self + + def __repr__(self) -> str: + return f"" + + @classmethod + def _from_id(cls, artifact_id: str, client: RetryingClient) -> Artifact | None: + from ._generated import ARTIFACT_BY_ID_GQL, ArtifactByID + from ._validators import FullArtifactPath + + if cached_artifact := artifact_instance_cache.get(artifact_id): + return cached_artifact + + gql_op = gql(ARTIFACT_BY_ID_GQL) + + data = client.execute(gql_op, variable_values={"id": artifact_id}) + result = ArtifactByID.model_validate(data) + + if (artifact := result.artifact) is None: + return None + + src_collection = artifact.artifact_sequence + src_project = src_collection.project + + entity_name = src_project.entity.name if src_project else "" + project_name = src_project.name if src_project else "" + + name = f"{src_collection.name}:v{artifact.version_index}" + + path = FullArtifactPath(prefix=entity_name, project=project_name, name=name) + return cls._from_attrs(path, artifact, client) + + @classmethod + def _membership_from_name( + cls, *, path: FullArtifactPath, client: RetryingClient + ) -> Artifact: + from ._generated import ( + ARTIFACT_MEMBERSHIP_BY_NAME_GQL, + ArtifactMembershipByName, + ) + + if not 
server_supports(client, pb.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP): + raise UnsupportedError( + "Querying for the artifact collection membership is not supported " + "by this version of wandb server. Consider updating to the latest version." + ) + + gql_op = gql(ARTIFACT_MEMBERSHIP_BY_NAME_GQL) + gql_vars = {"entity": path.prefix, "project": path.project, "name": path.name} + data = client.execute(gql_op, variable_values=gql_vars) + result = ArtifactMembershipByName.model_validate(data) + + if not (project := result.project): + msg = f"project {path.project!r} not found under entity {path.prefix!r}" + raise ValueError(msg) + + if not (membership := project.artifact_collection_membership): + entity_project = f"{path.prefix}/{path.project}" + msg = f"artifact membership {path.name!r} not found in {entity_project!r}" + raise ValueError(msg) + + return cls._from_membership(membership, target=path, client=client) + + @classmethod + def _from_name( + cls, + *, + path: FullArtifactPath, + client: RetryingClient, + enable_tracking: bool = False, + ) -> Artifact: + from ._generated import ARTIFACT_BY_NAME_GQL, ArtifactByName + + if server_supports(client, pb.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP): + return cls._membership_from_name(path=path, client=client) + + gql_vars = { + "entity": path.prefix, + "project": path.project, + "name": path.name, + "enableTracking": enable_tracking, + } + gql_op = gql(ARTIFACT_BY_NAME_GQL) + data = client.execute(gql_op, variable_values=gql_vars) + result = ArtifactByName.model_validate(data) + + if not (project := result.project): + msg = f"project {path.project!r} not found in entity {path.prefix!r}" + raise ValueError(msg) + + if not (artifact := project.artifact): + entity_project = f"{path.prefix}/{path.project}" + msg = f"artifact {path.name!r} not found in {entity_project!r}" + raise ValueError(msg) + + return cls._from_attrs(path, artifact, client) + + @classmethod + def _from_membership( + cls, + membership: 
ArtifactMembershipFragment, + target: FullArtifactPath, + client: RetryingClient, + ) -> Artifact: + from ._validators import is_artifact_registry_project + + if not ( + (collection := membership.artifact_collection) + and (name := collection.name) + and (proj := collection.project) + ): + raise ValueError("Missing artifact collection project in GraphQL response") + + if is_artifact_registry_project(proj.name) and ( + target.project == "model-registry" + ): + wandb.termwarn( + "This model registry has been migrated and will be discontinued. " + f"Your request was redirected to the corresponding artifact {name!r} in the new registry. " + f"Please update your paths to point to the migrated registry directly, '{proj.name}/{name}'." + ) + + # Update the target path to use the actual project/entity names returned in the + # response, in case they differ from the original target path + # E.g. uppercase vs lowercase, migrated legacy model registry, etc. + new_target = replace(target, prefix=proj.entity.name, project=proj.name) + + if not (artifact := membership.artifact): + raise ValueError(f"Artifact {target.to_str()!r} not found in response") + + return cls._from_attrs(new_target, artifact, client, membership=membership) + + @classmethod + def _from_attrs( + cls, + path: FullArtifactPath, + src_art: ArtifactFragment, + client: RetryingClient, + *, + # aliases/version_index are taken from the membership, if given + membership: ArtifactMembershipFragment | None = None, + ) -> Artifact: + # Placeholder is required to skip validation. + artifact = cls("placeholder", type="placeholder") + artifact._client = client + artifact._entity = path.prefix + artifact._project = path.project + artifact._name = path.name + + artifact._assign_attrs(src_art, membership=membership) + + artifact.finalize() + + # Cache. + assert artifact.id is not None + artifact_instance_cache[artifact.id] = artifact + return artifact + + # TODO: Eventually factor out is_link. 
Have to currently use it since some forms of fetching the artifact + # doesn't make it clear if the artifact is a link or not and have to manually set it. + def _assign_attrs( + self, + src_art: ArtifactFragment, + *, + # aliases/version_index are taken from the membership, if given + membership: ArtifactMembershipFragment | None = None, + is_link: bool | None = None, + ) -> None: + """Update this Artifact's attributes using the server response.""" + from ._validators import validate_metadata, validate_ttl_duration_seconds + + self._id = src_art.id + + src_collection = src_art.artifact_sequence + src_project = src_collection.project + + self._source_entity = src_project.entity.name if src_project else "" + self._source_project = src_project.name if src_project else "" + self._source_name = f"{src_collection.name}:v{src_art.version_index}" + self._source_version = f"v{src_art.version_index}" + + self._entity = self._entity or self._source_entity + self._project = self._project or self._source_project + self._name = self._name or self._source_name + + # TODO: Refactor artifact query to fetch artifact via membership instead + # and get the collection type + if is_link is None: + self._is_link = ( + self._entity != self._source_entity + or self._project != self._source_project + or self._name.split(":")[0] != self._source_name.split(":")[0] + ) + else: + self._is_link = is_link + + self._type = src_art.artifact_type.name + self._description = src_art.description + + # The future of aliases is to move all alias fetches to the membership level + # so we don't have to do the collection fetches below + if membership: + aliases = [a.alias for a in membership.aliases] + elif src_art.aliases: + entity = self._entity + project = self._project + collection = self._name.split(":")[0] + aliases = [ + a.alias + for a in src_art.aliases + if ( + (alias_coll := a.artifact_collection) + and (alias_proj := alias_coll.project) + and alias_proj.entity.name == entity + and 
alias_proj.name == project + and alias_coll.name == collection + ) + ] + else: + aliases = [] + + version_aliases = list(filter(alias_is_version_index, aliases)) + other_aliases = list(filterfalse(alias_is_version_index, aliases)) + + try: + version = one( + version_aliases, too_short=TooFewItemsError, too_long=TooManyItemsError + ) + except TooFewItemsError: + # default to the membership version if passed to this method, + # otherwise fallback to the source version + if membership and (m_version_index := membership.version_index) is not None: + version = f"v{m_version_index}" + else: + version = f"v{src_art.version_index}" + except TooManyItemsError: + msg = f"Expected at most one version alias, got {len(version_aliases)}: {version_aliases!r}" + raise ValueError(msg) from None + + self._version = version + self._name = self._name if (":" in self._name) else f"{self._name}:{version}" + + self._aliases = copy(other_aliases) + self._saved_aliases = copy(other_aliases) + + self._tags = [tag.name for tag in src_art.tags] + self._saved_tags = copy(self._tags) + + self._metadata = validate_metadata(src_art.metadata) + + self._ttl_duration_seconds = validate_ttl_duration_seconds( + src_art.ttl_duration_seconds + ) + self._ttl_is_inherited = src_art.ttl_is_inherited + + self._state = ArtifactState(src_art.state) + self._size = src_art.size + self._digest = src_art.digest + + self._manifest = None + + self._commit_hash = src_art.commit_hash + self._file_count = src_art.file_count + self._created_at = src_art.created_at + self._updated_at = src_art.updated_at + self._history_step = src_art.history_step + + @ensure_logged + def new_draft(self) -> Artifact: + """Create a new draft artifact with the same content as this committed artifact. + + Modifying an existing artifact creates a new artifact version known + as an "incremental artifact". The artifact returned can be extended or + modified and logged as a new version. + + Returns: + An `Artifact` object. 
+ + Raises: + ArtifactNotLoggedError: If the artifact is not logged. + """ + # Name, _entity and _project are set to the *source* name/entity/project: + # if this artifact is saved it must be saved to the source sequence. + artifact = Artifact(self.source_name.split(":")[0], self.type) + artifact._entity = self._source_entity + artifact._project = self._source_project + artifact._source_entity = self._source_entity + artifact._source_project = self._source_project + + # This artifact's parent is the one we are making a draft from. + artifact._base_id = self.id + + # We can reuse the client, and copy over all the attributes that aren't + # version-dependent and don't depend on having been logged. + artifact._client = self._client + artifact._description = self.description + artifact._metadata = self.metadata + artifact._manifest = ArtifactManifest.from_manifest_json( + self.manifest.to_manifest_json() + ) + return artifact + + # Properties (Python Class managed attributes). + + @property + def id(self) -> str | None: + """The artifact's ID.""" + if self.is_draft(): + return None + assert self._id is not None + return self._id + + @property + @ensure_logged + def entity(self) -> str: + """The name of the entity that the artifact collection belongs to. + + If the artifact is a link, the entity will be the entity of the linked artifact. + """ + assert self._entity is not None + return self._entity + + @property + @ensure_logged + def project(self) -> str: + """The name of the project that the artifact collection belongs to. + + If the artifact is a link, the project will be the project of the linked artifact. + """ + assert self._project is not None + return self._project + + @property + def name(self) -> str: + """The artifact name and version of the artifact. + + A string with the format `{collection}:{alias}`. If fetched before an artifact is + logged/saved, the name won't contain the alias. 
+ If the artifact is a link, the name will be the name of the linked artifact. + """ + return self._name + + @property + def qualified_name(self) -> str: + """The entity/project/name of the artifact. + + If the artifact is a link, the qualified name will be the qualified name of the + linked artifact path. + """ + return f"{self.entity}/{self.project}/{self.name}" + + @property + @ensure_logged + def version(self) -> str: + """The artifact's version. + + A string with the format `v{number}`. + If this is a link artifact, the version will be from the linked collection. + """ + assert self._version is not None + return self._version + + @property + @ensure_logged + def collection(self) -> ArtifactCollection: + """The collection this artifact is retrieved from. + + A collection is an ordered group of artifact versions. + If this artifact is retrieved from a collection that it is linked to, + return that collection. Otherwise, return the collection + that the artifact version originates from. + + The collection that an artifact originates from is known as + the source sequence. + """ + if (client := self._client) is None: + raise RuntimeError("Client not initialized") + base_name = self.name.split(":")[0] + return ArtifactCollection( + client, self.entity, self.project, base_name, self.type + ) + + @property + @ensure_logged + def source_entity(self) -> str: + """The name of the entity of the source artifact.""" + assert self._source_entity is not None + return self._source_entity + + @property + @ensure_logged + def source_project(self) -> str: + """The name of the project of the source artifact.""" + assert self._source_project is not None + return self._source_project + + @property + def source_name(self) -> str: + """The artifact name and version of the source artifact. + + A string with the format `{source_collection}:{alias}`. Before the artifact + is saved, contains only the name since the version is not yet known. 
+ """ + return self._source_name + + @property + def source_qualified_name(self) -> str: + """The source_entity/source_project/source_name of the source artifact.""" + return f"{self.source_entity}/{self.source_project}/{self.source_name}" + + @property + @ensure_logged + def source_version(self) -> str: + """The source artifact's version. + + A string with the format `v{number}`. + """ + assert self._source_version is not None + return self._source_version + + @property + @ensure_logged + def source_collection(self) -> ArtifactCollection: + """The artifact's source collection. + + The source collection is the collection that the artifact was logged from. + """ + if (client := self._client) is None: + raise RuntimeError("Client not initialized") + base_name = self.source_name.split(":")[0] + return ArtifactCollection( + client, self.source_entity, self.source_project, base_name, self.type + ) + + @property + def is_link(self) -> bool: + """Boolean flag indicating if the artifact is a link artifact. + + True: The artifact is a link artifact to a source artifact. + False: The artifact is a source artifact. + """ + return self._is_link + + @property + @ensure_logged + def linked_artifacts(self) -> list[Artifact]: + """Returns a list of all the linked artifacts of a source artifact. + + If this artifact is a link artifact (`artifact.is_link == True`), + it will return an empty list. + + Limited to 500 results. + """ + if not self.is_link: + self._linked_artifacts = self._fetch_linked_artifacts() + return self._linked_artifacts + + @property + @ensure_logged + def source_artifact(self) -> Artifact: + """Returns the source artifact, which is the original logged artifact. + + If this artifact is a source artifact (`artifact.is_link == False`), + it will return itself. 
+ """ + from ._validators import FullArtifactPath + + if not self.is_link: + return self + if self._source_artifact is None: + if (client := self._client) is None: + raise ValueError("Client is not initialized") + + try: + path = FullArtifactPath( + prefix=self.source_entity, + project=self.source_project, + name=self.source_name, + ) + self._source_artifact = self._from_name(path=path, client=client) + except Exception as e: + raise ValueError( + f"Unable to fetch source artifact for linked artifact {self.name}" + ) from e + return self._source_artifact + + @property + def type(self) -> str: + """The artifact's type. Common types include `dataset` or `model`.""" + return self._type + + @property + @ensure_logged + def url(self) -> str: + """ + Constructs the URL of the artifact. + + Returns: + str: The URL of the artifact. + """ + from ._validators import is_artifact_registry_project + + try: + base_url = self._client.app_url # type: ignore[union-attr] + except AttributeError: + return "" + + if not self.is_link: + return self._construct_standard_url(base_url) + if is_artifact_registry_project(self.project): + return self._construct_registry_url(base_url) + if self._type == "model" or self.project == "model-registry": + return self._construct_model_registry_url(base_url) + return self._construct_standard_url(base_url) + + def _construct_standard_url(self, base_url: str) -> str: + if not all( + [ + base_url, + self.entity, + self.project, + self._type, + self.collection.name, + self._version, + ] + ): + return "" + return urljoin( + base_url, + f"{self.entity}/{self.project}/artifacts/{quote(self._type)}/{quote(self.collection.name)}/{self._version}", + ) + + def _construct_registry_url(self, base_url: str) -> str: + from ._validators import remove_registry_prefix + + if not all( + [ + base_url, + self.entity, + self.project, + self.collection.name, + self._version, + ] + ): + return "" + + try: + org_name = org_info_from_entity(self._client, 
self.entity).organization.name # type: ignore[union-attr] + except (AttributeError, ValueError): + return "" + + selection_path = quote( + f"{self.entity}/{self.project}/{self.collection.name}", safe="" + ) + return urljoin( + base_url, + f"orgs/{org_name}/registry/{remove_registry_prefix(self.project)}?selectionPath={selection_path}&view=membership&version={self.version}", + ) + + def _construct_model_registry_url(self, base_url: str) -> str: + if not all( + [ + base_url, + self.entity, + self.project, + self.collection.name, + self._version, + ] + ): + return "" + selection_path = quote( + f"{self.entity}/{self.project}/{self.collection.name}", safe="" + ) + return urljoin( + base_url, + f"{self.entity}/registry/model?selectionPath={selection_path}&view=membership&version={self._version}", + ) + + @property + def description(self) -> str | None: + """A description of the artifact.""" + return self._description + + @description.setter + def description(self, description: str | None) -> None: + """Set the description of the artifact. + + For model or dataset Artifacts, add documentation for your + standardized team model or dataset card. In the W&B UI the + description is rendered as markdown. + + Editing the description will apply the changes to the source artifact + and all linked artifacts associated with it. + + Args: + description: Free text that offers a description of the artifact. + """ + if self.is_link: + wandb.termwarn( + "Editing the description of this linked artifact will edit the description for the source artifact and it's linked artifacts as well." + ) + self._description = description + + @property + def metadata(self) -> dict: + """User-defined artifact metadata. + + Structured data associated with the artifact. + """ + return self._metadata + + @metadata.setter + def metadata(self, metadata: dict) -> None: + """User-defined artifact metadata. + + Metadata set this way will eventually be queryable and plottable in the UI; e.g. 
+ the class distribution of a dataset. + + Note: There is currently a limit of 100 total keys. + Editing the metadata will apply the changes to the source artifact + and all linked artifacts associated with it. + + Args: + metadata: Structured data associated with the artifact. + """ + from ._validators import validate_metadata + + if self.is_link: + wandb.termwarn( + "Editing the metadata of this linked artifact will edit the metadata for the source artifact and it's linked artifacts as well." + ) + self._metadata = validate_metadata(metadata) + + @property + def ttl(self) -> timedelta | None: + """The time-to-live (TTL) policy of an artifact. + + Artifacts are deleted shortly after a TTL policy's duration passes. + If set to `None`, the artifact deactivates TTL policies and will be not + scheduled for deletion, even if there is a team default TTL. + An artifact inherits a TTL policy from + the team default if the team administrator defines a default + TTL and there is no custom policy set on an artifact. + + Raises: + ArtifactNotLoggedError: Unable to fetch inherited TTL if the + artifact has not been logged or saved. + """ + if self._ttl_is_inherited and (self.is_draft() or self._ttl_changed): + raise ArtifactNotLoggedError(f"{nameof(type(self))}.ttl", self) + if self._ttl_duration_seconds is None: + return None + return timedelta(seconds=self._ttl_duration_seconds) + + @ttl.setter + def ttl(self, ttl: timedelta | ArtifactTTL | None) -> None: + """The time-to-live (TTL) policy of an artifact. + + Artifacts are deleted shortly after a TTL policy's duration passes. + If set to `None`, the artifact has no TTL policy set and it is not + scheduled for deletion. An artifact inherits a TTL policy from + the team default if the team administrator defines a default + TTL and there is no custom policy set on an artifact. + + Args: + ttl: The duration as a positive `datetime.timedelta` that represents + how long the artifact will remain active from its creation. 
+ + """ + if self.type == "wandb-history": + raise ValueError("Cannot set artifact TTL for type wandb-history") + + if self.is_link: + raise ValueError( + "Cannot set TTL for link artifact. " + "Unlink the artifact first then set the TTL for the source artifact" + ) + + self._ttl_changed = True + if isinstance(ttl, ArtifactTTL): + if ttl == ArtifactTTL.INHERIT: + self._ttl_is_inherited = True + else: + raise ValueError(f"Unhandled ArtifactTTL enum {ttl}") + else: + self._ttl_is_inherited = False + if ttl is None: + self._ttl_duration_seconds = None + else: + if ttl.total_seconds() <= 0: + raise ValueError( + f"Artifact TTL Duration has to be positive. ttl: {ttl.total_seconds()}" + ) + self._ttl_duration_seconds = int(ttl.total_seconds()) + + @property + @ensure_logged + def aliases(self) -> list[str]: + """List of one or more semantically-friendly references or + + identifying "nicknames" assigned to an artifact version. + + Aliases are mutable references that you can programmatically reference. + Change an artifact's alias with the W&B App UI or programmatically. + See [Create new artifact versions](https://docs.wandb.ai/guides/artifacts/create-a-new-artifact-version) + for more information. + """ + return self._aliases + + @aliases.setter + @ensure_logged + def aliases(self, aliases: list[str]) -> None: + """Set the aliases associated with this artifact.""" + from ._validators import validate_aliases + + self._aliases = validate_aliases(aliases) + + @property + @ensure_logged + def tags(self) -> list[str]: + """List of one or more tags assigned to this artifact version.""" + return self._tags + + @tags.setter + @ensure_logged + def tags(self, tags: list[str]) -> None: + """Set the tags associated with this artifact. + + Editing tags will apply the changes to the source artifact + and all linked artifacts associated with it. 
+ """ + from ._validators import validate_tags + + if self.is_link: + wandb.termwarn( + "Editing tags will apply the changes to the source artifact and all linked artifacts associated with it." + ) + self._tags = validate_tags(tags) + + @property + def distributed_id(self) -> str | None: + """The distributed ID of the artifact. + + + """ + return self._distributed_id + + @distributed_id.setter + def distributed_id(self, distributed_id: str | None) -> None: + self._distributed_id = distributed_id + + @property + def incremental(self) -> bool: + """Boolean flag indicating if the artifact is an incremental artifact. + + + """ + return self._incremental + + @property + def use_as(self) -> str | None: + """Deprecated.""" + warn_and_record_deprecation( + feature=Deprecated(artifact__use_as=True), + message=("The use_as property of Artifact is deprecated."), + ) + return self._use_as + + @property + def state(self) -> str: + """The status of the artifact. One of: "PENDING", "COMMITTED", or "DELETED".""" + return self._state.value + + @property + def manifest(self) -> ArtifactManifest: + """The artifact's manifest. + + The manifest lists all of its contents, and can't be changed once the artifact + has been logged. + """ + if self._manifest is None: + self._manifest = self._fetch_manifest() + return self._manifest + + def _fetch_manifest(self) -> ArtifactManifest: + """Fetch, parse, and load the full ArtifactManifest.""" + from ._generated import FETCH_ARTIFACT_MANIFEST_GQL, FetchArtifactManifest + + if (client := self._client) is None: + raise RuntimeError("Client not initialized for artifact queries") + + # From the GraphQL API, get the (expiring) directUrl for downloading the manifest. + gql_op = gql(FETCH_ARTIFACT_MANIFEST_GQL) + gql_vars = {"id": self.id} + data = client.execute(gql_op, variable_values=gql_vars) + result = FetchArtifactManifest.model_validate(data) + + # Now fetch the actual manifest contents from the directUrl. 
+ if (artifact := result.artifact) and (manifest := artifact.current_manifest): + # Create a short lived session instead of using requests.get() + # because make_http_session() adds http headers from env vars. + # Artifact manifest json is also downloaded from object storage + # using presigned urls like artifact files, which requires adding + # extra http headers when user specifies them in env vars. + # + # FIXME: For successive/repeated calls to `manifest`, figure out + # how to reuse a single `requests.Session` within the constraints + # of the current API. Creating a new session for _each_ fetch is + # wasteful and introduces noticeable perf overhead when e.g. + # downloading many artifacts sequentially or concurrently. The + # storage policy's session is also not reused across different + # artifacts. + with make_http_session() as session: + response = session.get(manifest.file.direct_url) + return ArtifactManifest.from_manifest_json(from_json(response.content)) + + raise ValueError("Failed to fetch artifact manifest") + + @property + def digest(self) -> str: + """The logical digest of the artifact. + + The digest is the checksum of the artifact's contents. If an artifact has the + same digest as the current `latest` version, then `log_artifact` is a no-op. + """ + # Use the last fetched value of `Artifact.digest` ONLY if present AND the manifest + # has not been fetched and/or populated locally. + # Otherwise, use the manifest directly to recalculate the digest, as its contents + # may have been locally modified. + return ( + self._digest + if (self._manifest is None) and (self._digest is not None) + else self.manifest.digest() + ) + + @property + def size(self) -> int: + """The total size of the artifact in bytes. + + Includes any references tracked by this artifact. + """ + # Use the last fetched value of `Artifact.size` ONLY if present AND the manifest + # has not been fetched and/or populated locally. 
+ # Otherwise, use the manifest directly to recalculate the size, as its contents + # may have been locally modified. + # + # NOTE on choice of GQL field: `Artifact.size` counts references, while + # `Artifact.storageBytes` does not. + return ( + self._size + if (self._manifest is None) and (self._size is not None) + else self.manifest.size() + ) + + @property + @ensure_logged + def commit_hash(self) -> str: + """The hash returned when this artifact was committed.""" + assert self._commit_hash is not None + return self._commit_hash + + @property + @ensure_logged + def file_count(self) -> int: + """The number of files (including references).""" + assert self._file_count is not None + return self._file_count + + @property + @ensure_logged + def created_at(self) -> str: + """Timestamp when the artifact was created.""" + assert self._created_at is not None + return self._created_at + + @property + @ensure_logged + def updated_at(self) -> str: + """The time when the artifact was last updated.""" + assert self._created_at is not None + return self._updated_at or self._created_at + + @property + @ensure_logged + def history_step(self) -> int | None: + """The nearest step which logged history metrics for this artifact's source run. + + Examples: + ```python + run = artifact.logged_by() + if run and (artifact.history_step is not None): + history = run.sample_history( + min_step=artifact.history_step, + max_step=artifact.history_step + 1, + keys=["my_metric"], + ) + ``` + """ + if self._history_step is None: + return None + return max(0, self._history_step - 1) + + # State management. + + def finalize(self) -> None: + """Finalize the artifact version. + + You cannot modify an artifact version once it is finalized because the artifact + is logged as a specific artifact version. Create a new artifact version + to log more data to an artifact. An artifact is automatically finalized + when you log the artifact with `log_artifact`. 
+ """ + self._final = True + + def is_draft(self) -> bool: + """Check if artifact is not saved. + + Returns: + Boolean. `False` if artifact is saved. `True` if artifact is not saved. + """ + return self._state is ArtifactState.PENDING + + def _is_draft_save_started(self) -> bool: + return self._save_handle is not None + + def save( + self, + project: str | None = None, + settings: wandb.Settings | None = None, + ) -> None: + """Persist any changes made to the artifact. + + If currently in a run, that run will log this artifact. If not currently in a + run, a run of type "auto" is created to track this artifact. + + Args: + project: A project to use for the artifact in the case that a run is not + already in context. + settings: A settings object to use when initializing an automatic run. Most + commonly used in testing harness. + """ + if self._state is not ArtifactState.PENDING: + return self._update() + + if self._incremental: + with telemetry.context() as tel: + tel.feature.artifact_incremental = True + + if run := wandb_setup.singleton().most_recent_active_run: + # TODO: Deprecate and encourage explicit log_artifact(). + run.log_artifact(self) + else: + if settings is None: + settings = wandb.Settings(silent="true") + with wandb.init( # type: ignore + entity=self._source_entity, + project=project or self._source_project, + job_type="auto", + settings=settings, + ) as run: + # redoing this here because in this branch we know we didn't + # have the run at the beginning of the method + if self._incremental: + with telemetry.context(run=run) as tel: + tel.feature.artifact_incremental = True + run.log_artifact(self) + + def _set_save_handle( + self, + save_handle: MailboxHandle[pb.Result], + client: RetryingClient, + ) -> None: + self._save_handle = save_handle + self._client = client + + def wait(self, timeout: int | None = None) -> Artifact: + """If needed, wait for this artifact to finish logging. + + Args: + timeout: The time, in seconds, to wait. 
+ + Returns: + An `Artifact` object. + """ + if self.is_draft(): + if self._save_handle is None: + raise ArtifactNotLoggedError(nameof(self.wait), self) + + try: + result = self._save_handle.wait_or(timeout=timeout) + except TimeoutError as e: + raise WaitTimeoutError( + "Artifact upload wait timed out, failed to fetch Artifact response" + ) from e + + response = result.response.log_artifact_response + if response.error_message: + raise ValueError(response.error_message) + self._populate_after_save(response.artifact_id) + return self + + def _populate_after_save(self, artifact_id: str) -> None: + from ._generated import ARTIFACT_BY_ID_GQL, ArtifactByID + + if (client := self._client) is None: + raise RuntimeError("Client not initialized for artifact queries") + + gql_op = gql(ARTIFACT_BY_ID_GQL) + data = client.execute(gql_op, variable_values={"id": artifact_id}) + result = ArtifactByID.model_validate(data) + + if not (artifact := result.artifact): + raise ValueError(f"Unable to fetch artifact with id: {artifact_id!r}") + + # _populate_after_save is only called on source artifacts, not linked artifacts + # We have to manually set is_link because we aren't fetching the collection + # the artifact. That requires greater refactoring for commitArtifact to return + # the artifact collection type. 
+ self._assign_attrs(artifact, is_link=False) + + @normalize_exceptions + def _update(self) -> None: + """Persists artifact changes to the wandb backend.""" + from ._generated import UPDATE_ARTIFACT_GQL, UpdateArtifact, UpdateArtifactInput + from ._validators import FullArtifactPath, validate_tags + + if (client := self._client) is None: + raise RuntimeError("Client not initialized for artifact mutations") + + entity, project, collection = self.entity, self.project, self.name.split(":")[0] + + old_aliases, new_aliases = set(self._saved_aliases), set(self.aliases) + target = FullArtifactPath(prefix=entity, project=project, name=collection) + self._add_aliases(new_aliases - old_aliases, target=target) + self._delete_aliases(old_aliases - new_aliases, target=target) + self._saved_aliases = copy(self.aliases) + + old_tags, new_tags = set(self._saved_tags), set(self.tags) + + gql_op = gql(UPDATE_ARTIFACT_GQL) + gql_input = UpdateArtifactInput( + artifact_id=self.id, + description=self.description, + metadata=json_dumps_safer(self.metadata), + ttl_duration_seconds=self._ttl_duration_seconds_to_gql(), + tags_to_add=[{"tagName": t} for t in validate_tags(new_tags - old_tags)], + tags_to_delete=[{"tagName": t} for t in validate_tags(old_tags - new_tags)], + ) + gql_vars = {"input": gql_input.model_dump()} + data = client.execute(gql_op, variable_values=gql_vars) + + result = UpdateArtifact.model_validate(data).result + if not (result and (artifact := result.artifact)): + raise ValueError("Unable to parse updateArtifact response") + self._assign_attrs(artifact) + + self._ttl_changed = False # Reset after updating artifact + + def _add_aliases(self, alias_names: set[str], target: FullArtifactPath) -> None: + from ._generated import ADD_ALIASES_GQL, AddAliasesInput + + if (client := self._client) is None: + raise RuntimeError("Client not initialized for artifact mutations") + + # If there aren't any aliases to add, we can skip the GraphQL call. 
+ if alias_names: + target_props = { + "entityName": target.prefix, + "projectName": target.project, + "artifactCollectionName": target.name, + } + alias_inputs = [{**target_props, "alias": name} for name in alias_names] + gql_op = gql(ADD_ALIASES_GQL) + gql_input = AddAliasesInput(artifact_id=self.id, aliases=alias_inputs) + gql_vars = {"input": gql_input.model_dump()} + try: + client.execute(gql_op, variable_values=gql_vars) + except CommError as e: + msg = ( + "You do not have permission to add" + f" {'at least one of the following aliases' if len(alias_names) > 1 else 'the following alias'}" + f" to this artifact: {alias_names!r}" + ) + raise CommError(msg) from e + + def _delete_aliases(self, alias_names: set[str], target: FullArtifactPath) -> None: + from ._generated import DELETE_ALIASES_GQL, DeleteAliasesInput + + if (client := self._client) is None: + raise RuntimeError("Client not initialized for artifact mutations") + + # If there aren't any aliases to delete, we can skip the GraphQL call. + if alias_names: + target_props = { + "entityName": target.prefix, + "projectName": target.project, + "artifactCollectionName": target.name, + } + alias_inputs = [{**target_props, "alias": name} for name in alias_names] + gql_op = gql(DELETE_ALIASES_GQL) + gql_input = DeleteAliasesInput(artifact_id=self.id, aliases=alias_inputs) + gql_vars = {"input": gql_input.model_dump()} + try: + client.execute(gql_op, variable_values=gql_vars) + except CommError as e: + msg = ( + f"You do not have permission to delete" + f" {'at least one of the following aliases' if len(alias_names) > 1 else 'the following alias'}" + f" from this artifact: {alias_names!r}" + ) + raise CommError(msg) from e + + # Adding, removing, getting entries. + + def __getitem__(self, name: str) -> WBValue | None: + """Get the WBValue object located at the artifact relative `name`. + + Args: + name: The artifact relative name to get. 
+ + Returns: + W&B object that can be logged with `run.log()` and visualized in the W&B UI. + + Raises: + ArtifactNotLoggedError: If the artifact isn't logged or the run is offline. + """ + return self.get(name) + + def __setitem__(self, name: str, item: WBValue) -> ArtifactManifestEntry: + """Add `item` to the artifact at path `name`. + + Args: + name: The path within the artifact to add the object. + item: The object to add. + + Returns: + The added manifest entry + + Raises: + ArtifactFinalizedError: You cannot make changes to the current + artifact version because it is finalized. Log a new artifact + version instead. + """ + return self.add(item, name) + + @contextlib.contextmanager + @ensure_not_finalized + def new_file( + self, name: str, mode: str = "x", encoding: str | None = None + ) -> Iterator[IO]: + """Open a new temporary file and add it to the artifact. + + Args: + name: The name of the new file to add to the artifact. + mode: The file access mode to use to open the new file. + encoding: The encoding used to open the new file. + + Returns: + A new file object that can be written to. Upon closing, the file + is automatically added to the artifact. + + Raises: + ArtifactFinalizedError: You cannot make changes to the current + artifact version because it is finalized. Log a new artifact + version instead. + """ + overwrite: bool = "x" not in mode + + if self._tmp_dir is None: + self._tmp_dir = tempfile.TemporaryDirectory() + path = os.path.join(self._tmp_dir.name, name.lstrip("/")) + + Path(path).parent.mkdir(parents=True, exist_ok=True) + try: + with fsync_open(path, mode, encoding) as f: + yield f + except FileExistsError: + raise ValueError(f"File with name {name!r} already exists at {path!r}") + except UnicodeEncodeError as e: + termerror( + f"Failed to open the provided file ({nameof(type(e))}: {e}). Please " + f"provide the proper encoding." 
@ensure_not_finalized
def add_file(
    self,
    local_path: str,
    name: str | None = None,
    is_tmp: bool | None = False,
    skip_cache: bool | None = False,
    policy: Literal["mutable", "immutable"] | None = "mutable",
    overwrite: bool = False,
) -> ArtifactManifestEntry:
    """Add a single local file to the artifact.

    Args:
        local_path: The path to the file being added.
        name: The path within the artifact to use for the file being added.
            Defaults to the basename of the file.
        is_tmp: If true, the first dot-separated component of the file name
            is replaced by a prefix of the content digest, so the name is
            deterministic and avoids collisions.
        skip_cache: If `True`, do not copy files to the cache
            after uploading.
        policy: By default, set to "mutable". If set to "mutable",
            create a temporary copy of the file to prevent corruption
            during upload. If set to "immutable", disable
            protection and rely on the user not to delete or change the
            file.
        overwrite: If `True`, overwrite the file if it already exists.

    Returns:
        The added manifest entry.

    Raises:
        ArtifactFinalizedError: You cannot make changes to the current
            artifact version because it is finalized. Log a new artifact
            version instead.
        ValueError: If `local_path` is not a file, or if the policy is not
            "mutable" or "immutable".
    """
    if not os.path.isfile(local_path):
        raise ValueError(f"Path is not a file: {local_path!r}")

    name = LogicalPath(name or os.path.basename(local_path))
    digest = md5_file_b64(local_path)

    if is_tmp:
        parent, base = os.path.split(name)
        # Swap everything before the first "." for the digest-derived id and
        # keep any extension(s) untouched.
        _, dot, extensions = base.partition(".")
        name = os.path.join(parent, b64_to_hex_id(digest)[:20] + dot + extensions)

    return self._add_local_file(
        name,
        local_path,
        digest=digest,
        skip_cache=skip_cache,
        policy=policy,
        overwrite=overwrite,
    )
@ensure_not_finalized
def add_reference(
    self,
    uri: ArtifactManifestEntry | str,
    name: StrPath | None = None,
    checksum: bool = True,
    max_objects: int | None = None,
) -> Sequence[ArtifactManifestEntry]:
    """Add a reference denoted by a URI to the artifact.

    Unlike files or directories that you add to an artifact, references are not
    uploaded to W&B. For more information,
    see [Track external files](https://docs.wandb.ai/guides/artifacts/track-external-files).

    By default, the following schemes are supported:

    - http(s): The size and digest of the file will be inferred by the
      `Content-Length` and the `ETag` response headers returned by the server.
    - s3: The checksum and size are pulled from the object metadata.
      If bucket versioning is enabled, then the version ID is also tracked.
    - gs: The checksum and size are pulled from the object metadata. If bucket
      versioning is enabled, then the version ID is also tracked.
    - https, domain matching `*.blob.core.windows.net`
    - Azure: The checksum and size are pulled from the blob metadata.
      If storage account versioning is enabled, then the version ID is
      also tracked.
    - file: The checksum and size are pulled from the file system. This scheme
      is useful if you have an NFS share or other externally mounted volume
      containing files you wish to track but not necessarily upload.

    For any other scheme, the digest is just a hash of the URI and the size is left
    blank.

    Args:
        uri: The URI path of the reference to add. The URI path can be an object
            returned from `Artifact.get_entry` to store a reference to another
            artifact's entry.
        name: The path within the artifact to place the contents of this reference.
        checksum: Whether or not to checksum the resource(s) located at the
            reference URI. Checksumming is strongly recommended as it enables
            automatic integrity validation. Disabling checksumming will speed up
            artifact creation, but reference directories will not be iterated
            through, so the objects in the directory will not be saved to the
            artifact. We recommend setting `checksum=False` when adding
            reference objects, in which case a new version will only be created
            if the reference URI changes.
        max_objects: The maximum number of objects to consider when adding a
            reference that points to directory or bucket store prefix.
            By default, the maximum number of objects allowed for Amazon S3,
            GCS, Azure, and local files is 10,000,000. Other URI schemas
            do not have a maximum.

    Returns:
        The added manifest entries.

    Raises:
        ArtifactFinalizedError: You cannot make changes to the current
            artifact version because it is finalized. Log a new artifact
            version instead.
        TypeError: If `uri` is neither a string nor an `ArtifactManifestEntry`.
        ValueError: If the URI has no scheme.
    """
    if name is not None:
        name = LogicalPath(name)

    # A manifest entry stands in for the URL it refers to; a plain string is
    # used as given. Anything else is rejected here, because otherwise
    # `uri_str` would be unbound and fail below with an UnboundLocalError.
    if isinstance(uri, ArtifactManifestEntry):
        uri_str = uri.ref_url()
    elif isinstance(uri, str):
        uri_str = uri
    else:
        raise TypeError(
            "uri must be a str or an ArtifactManifestEntry,"
            f" got {type(uri).__name__}"
        )

    url = urlparse(str(uri_str))
    if not url.scheme:
        raise ValueError(
            "References must be URIs. To reference a local file, use file://"
        )

    manifest_entries = self.manifest.storage_policy.store_reference(
        self,
        URIStr(uri_str),
        name=name,
        checksum=checksum,
        max_objects=max_objects,
    )
    for entry in manifest_entries:
        self.manifest.add_entry(entry)

    return manifest_entries
+ is_tmp_name = name.startswith("media/tables") + + # Validate that the object is one of the correct wandb.Media types + # TODO: move this to checking subclass of wandb.Media once all are + # generally supported + allowed_types = ( + data_types.Bokeh, + data_types.JoinedTable, + data_types.PartitionedTable, + data_types.Table, + data_types.Classes, + data_types.ImageMask, + data_types.BoundingBoxes2D, + data_types.Audio, + data_types.Image, + data_types.Video, + data_types.Html, + data_types.Object3D, + data_types.Molecule, + data_types._SavedModel, + ) + if not isinstance(obj, allowed_types): + raise TypeError( + f"Found object of type {obj.__class__}, expected one of:" + f" {allowed_types}" + ) + + obj_id = id(obj) + if obj_id in self._added_objs: + return self._added_objs[obj_id][1] + + # If the object is coming from another artifact, save it as a reference + ref_path = obj._get_artifact_entry_ref_url() + if ref_path is not None: + return self.add_reference(ref_path, type(obj).with_suffix(name))[0] + + val = obj.to_json(self) + name = obj.with_suffix(name) + entry = self.manifest.get_entry_by_path(name) + if (not overwrite) and (entry is not None): + return entry + + if is_tmp_name: + file_path = os.path.join(self._TMP_DIR.name, str(id(self)), name) + folder_path, _ = os.path.split(file_path) + os.makedirs(folder_path, exist_ok=True) + with open(file_path, "w", encoding="utf-8") as tmp_f: + json.dump(val, tmp_f, sort_keys=True) + else: + filemode = "w" if overwrite else "x" + with self.new_file(name, mode=filemode, encoding="utf-8") as f: + json.dump(val, f, sort_keys=True) + file_path = f.name + + # Note, we add the file from our temp directory. + # It will be added again later on finalize, but succeed since + # the checksum should match + entry = self.add_file(file_path, name, is_tmp_name) + # We store a reference to the obj so that its id doesn't get reused. 
def _add_local_file(
    self,
    name: StrPath,
    path: StrPath,
    digest: B64MD5 | None = None,
    skip_cache: bool | None = False,
    policy: Literal["mutable", "immutable"] | None = "mutable",
    overwrite: bool = False,
) -> ArtifactManifestEntry:
    """Create a manifest entry for a local file and register it.

    Args:
        name: Path of the entry within the artifact.
        path: Local file to add.
        digest: Precomputed digest; computed from the uploaded file when
            omitted.
        skip_cache: Forwarded to the manifest entry.
        policy: "mutable" uploads a read-only staged copy of the file;
            "immutable" uploads the file in place. `None` means "mutable".
        overwrite: Forwarded to `manifest.add_entry`.

    Returns:
        The newly added manifest entry.

    Raises:
        ValueError: If `policy` is not "mutable" or "immutable".
    """
    policy = policy or "mutable"
    if policy not in ("mutable", "immutable"):
        raise ValueError(
            f"Invalid policy {policy!r}. Policy may only be `mutable` or `immutable`."
        )

    if policy == "mutable":
        # Reserve a uniquely named file in the staging dir, then copy into it.
        with tempfile.NamedTemporaryFile(
            dir=get_staging_dir(), delete=False
        ) as staged:
            upload_path = staged.name
            shutil.copyfile(path, upload_path)
        # Set as read-only to prevent changes to the file during upload process
        os.chmod(upload_path, stat.S_IRUSR)
    else:
        upload_path = path

    entry = ArtifactManifestEntry(
        path=name,
        digest=digest or md5_file_b64(upload_path),
        size=os.path.getsize(upload_path),
        local_path=upload_path,
        skip_cache=skip_cache,
    )
    self.manifest.add_entry(entry, overwrite=overwrite)
    # Keyed by the caller's original path, not the staged copy.
    self._added_local_paths[os.fspath(path)] = entry
    return entry
@ensure_logged
def get_entry(self, name: StrPath) -> ArtifactManifestEntry:
    """Look up the manifest entry stored under `name`.

    The name is tried as an exact manifest path first, then as the name an
    object was added under (resolved through `_get_obj_entry`).

    Args:
        name: The artifact relative name to get

    Returns:
        The matching manifest entry, with its parent artifact set to this
        artifact.

    Raises:
        ArtifactNotLoggedError: if the artifact isn't logged or the run is offline.
        KeyError: if the artifact doesn't contain an entry with the given name.
    """
    name = LogicalPath(name)

    found = self.manifest.entries.get(name)
    if not found:
        found, _ = self._get_obj_entry(name)
    if found is None:
        raise KeyError(f"Path not contained in artifact: {name}")

    found._parent_artifact = self
    return found
def get_added_local_path_name(self, local_path: str) -> str | None:
    """Resolve a local filesystem path to its artifact relative name.

    Only paths recorded in `_added_local_paths` can be resolved.

    Args:
        local_path: The local path to resolve into an artifact relative name.

    Returns:
        The artifact relative name, or `None` when no entry is recorded for
        `local_path`.
    """
    entry = self._added_local_paths.get(local_path)
    return entry.path if entry else None
@ensure_logged
def download(
    self,
    root: StrPath | None = None,
    allow_missing_references: bool = False,
    skip_cache: bool | None = None,
    path_prefix: StrPath | None = None,
    multipart: bool | None = None,
) -> FilePathStr:
    """Download the artifact's contents into a root directory.

    Existing files located within `root` are not modified. Explicitly delete `root`
    before you call `download` if you want the contents of `root` to exactly match
    the artifact.

    Args:
        root: The directory W&B stores the artifact's files.
        allow_missing_references: If set to `True`, any invalid reference paths
            will be ignored while downloading referenced files.
        skip_cache: If set to `True`, the artifact cache will be skipped when
            downloading and W&B will download each file into the default root or
            specified download directory.
        path_prefix: If specified, only files with a path that starts with the given
            prefix will be downloaded. Uses unix format (forward slashes).
        multipart: If set to `None` (default), the artifact will be downloaded
            in parallel using multipart download if individual file size is greater
            than 2GB. If set to `True` or `False`, the artifact will be downloaded in
            parallel or serially regardless of the file size.

    Returns:
        The path to the downloaded contents.

    Raises:
        ArtifactNotLoggedError: If the artifact is not logged.
    """
    # TODO: download artifacts using core when implemented
    # (route through `self._download_using_core` once `is_require_core()`
    # downloads are supported).
    return self._download(
        root=self._add_download_root(root),
        allow_missing_references=allow_missing_references,
        skip_cache=skip_cache,
        path_prefix=path_prefix,
        multipart=multipart,
    )
def _download(
    self,
    root: str,
    allow_missing_references: bool = False,
    skip_cache: bool | None = None,
    path_prefix: StrPath | None = None,
    multipart: bool | None = None,
) -> FilePathStr:
    """Download this artifact's entries into `root` using a thread pool.

    File URLs are fetched page by page via `_fetch_file_urls`; each entry
    whose path matches `path_prefix` (or every entry, when no prefix is
    given) is submitted to a 64-worker `ThreadPoolExecutor`.

    Args:
        root: Directory passed to each entry's `download`.
        allow_missing_references: When true, a `FileNotFoundError` from an
            entry is reported as a warning instead of being re-raised.
        skip_cache: Forwarded to each entry's `download`.
        path_prefix: Only entries whose path starts with this prefix are
            downloaded; falsy means all entries.
        multipart: Forwarded as `override` to `should_multipart_download`.

    Returns:
        `root`, wrapped as `FilePathStr`.
    """
    nfiles = len(self.manifest.entries)
    size_mb = self.size / _MB

    # Progress and timing output is only produced for large artifacts.
    # `start_time` exists only when `log` is truthy, and is only read under
    # the matching `if log:` at the end.
    if log := (nfiles > 5000 or size_mb > 50):
        termlog(
            f"Downloading large artifact {self.name!r}, {size_mb:.2f}MB. {nfiles!r} files...",
        )
        start_time = time.monotonic()

    download_logger = ArtifactDownloadLogger(nfiles=nfiles)

    def _download_entry(entry: ArtifactManifestEntry, executor: Executor) -> None:
        # Hand the shared executor to the entry only when it should be
        # downloaded in multiple parts.
        multipart_executor = (
            executor
            if should_multipart_download(entry.size, override=multipart)
            else None
        )
        try:
            entry.download(root, skip_cache=skip_cache, executor=multipart_executor)
        except FileNotFoundError as e:
            if allow_missing_references:
                wandb.termwarn(str(e))
                return
            raise
        except _GCSIsADirectoryError as e:
            logger.debug(str(e))
            return
        except IsADirectoryError:
            wandb.termwarn(
                f"Unable to download file {entry.path!r} as there is a directory with the same path, skipping."
            )
            return
        except NotADirectoryError:
            wandb.termwarn(
                f"Unable to download file {entry.path!r} as there is a file with the same path as a directory this file is expected to be in, skipping."
            )
            return
        # Reached only when the download did not raise; skipped entries
        # return early above and are not counted.
        download_logger.notify_downloaded()

    with ThreadPoolExecutor(max_workers=64) as executor:
        batch_size = env.get_artifact_fetch_file_url_batch_size()

        active_futures = set()
        cursor, has_more = None, True
        while has_more:
            files_page = self._fetch_file_urls(cursor=cursor, per_page=batch_size)

            has_more = files_page.page_info.has_next_page
            cursor = files_page.page_info.end_cursor

            # `File` nodes are formally nullable, so filter them out just in case.
            file_nodes = (e.node for e in files_page.edges if e.node)
            for node in file_nodes:
                entry = self.get_entry(node.name)
                # TODO: uncomment once artifact downloads are supported in core
                # if require_core and entry.ref is None:
                #     # Handled by core
                #     continue
                entry._download_url = node.direct_url
                if (not path_prefix) or entry.path.startswith(str(path_prefix)):
                    active_futures.add(
                        executor.submit(_download_entry, entry, executor=executor)
                    )

            # Wait for download threads to catch up.
            #
            # Extra context and observations (tonyyli):
            # - Even though the ThreadPoolExecutor limits the number of
            #   concurrently-executed tasks, its internal task queue is unbounded.
            #   The code below seems intended to ensure that at most `batch_size`
            #   "backlogged" futures are held in memory at any given time. This seems
            #   like a reasonable safeguard against unbounded memory consumption.
            #
            # - We should probably use a builtin bounded Queue or Semaphore instead.
            #   Consider this for a future change, or (depending on appetite for risk)
            #   managing this logic via asyncio instead, if viable.
            if len(active_futures) > batch_size:
                for future in as_completed(active_futures):
                    future.result()  # check for errors
                    active_futures.remove(future)
                    if len(active_futures) <= batch_size:
                        break

        # Check for errors.
        for future in as_completed(active_futures):
            future.result()

    if log:
        # If you're wondering if we can display a `timedelta`, note that it
        # doesn't really support custom string format specifiers (compared to
        # e.g. `datetime` objs). To truncate the number of decimal places for
        # the seconds part, we manually convert/format each part below.
        dt_secs = abs(time.monotonic() - start_time)
        # The first divmod leaves the remaining *seconds* in `mins`; the
        # second splits that remainder into whole minutes and seconds.
        hrs, mins = divmod(dt_secs, 3600)
        mins, secs = divmod(mins, 60)
        # NOTE(review): `size_mb / dt_secs` would raise ZeroDivisionError if
        # the measured elapsed time were exactly 0 — confirm whether a guard
        # is wanted.
        termlog(
            f"Done. {int(hrs):02d}:{int(mins):02d}:{secs:04.1f} ({size_mb / dt_secs:.1f}MB/s)",
            prefix=False,
        )
    return FilePathStr(root)
@ensure_logged
def checkout(self, root: str | None = None) -> str:
    """Make `root` mirror the artifact by pruning extras, then downloading.

    WARNING: This will delete all files in `root` that are not included in the
    artifact.

    Args:
        root: The directory to replace with this artifact's files. Defaults
            to the artifact's default root without the version suffix.

    Returns:
        The path of the checked out contents.

    Raises:
        ArtifactNotLoggedError: If the artifact is not logged.
    """
    root = root or self._default_root(include_version=False)

    for current_dir, _subdirs, filenames in os.walk(root):
        for filename in filenames:
            local_file = os.path.join(current_dir, filename)
            relative_name = os.path.relpath(local_file, start=root)
            try:
                self.get_entry(relative_name)
            except KeyError:
                # No entry for this path, so the file is not part of the
                # artifact and is removed.
                os.remove(local_file)

    return self.download(root=root)
@ensure_logged
def file(self, root: str | None = None) -> StrPath:
    """Download the only file of a single-file artifact into `root`.

    Args:
        root: The root directory to store the file. Defaults to
            `./artifacts/self.name/`.

    Returns:
        The full path of the downloaded file.

    Raises:
        ArtifactNotLoggedError: If the artifact is not logged.
        ValueError: If the artifact contains more than one file.
    """
    if root is None:
        root = os.path.join(".", "artifacts", self.name)

    entry_names = list(self.manifest.entries)
    if len(entry_names) > 1:
        raise ValueError(
            "This artifact contains more than one file, call `.download()` to get "
            'all files or call .get_entry("filename").download()'
        )

    only_entry = self.get_entry(entry_names[0])
    return only_entry.download(root)
+ """ + if (client := self._client) is None: + raise RuntimeError("Client not initialized") + return ArtifactFiles(client, self, names, per_page) + + def _default_root(self, include_version: bool = True) -> FilePathStr: + name = self.source_name if include_version else self.source_name.split(":")[0] + root = os.path.join(env.get_artifact_dir(), name) + # In case we're on a system where the artifact dir has a name corresponding to + # an unexpected filesystem, we'll check for alternate roots. If one exists we'll + # use that, otherwise we'll fall back to the system-preferred path. + return FilePathStr(check_exists(root) or system_preferred_path(root)) + + def _add_download_root(self, dir_path: StrPath | None) -> FilePathStr: + root = str(dir_path or self._default_root()) + self._download_roots.add(os.path.abspath(root)) + return root + + def _local_path_to_name(self, file_path: str) -> str | None: + """Convert a local file path to a path entry in the artifact.""" + abs_file_path = os.path.abspath(file_path) + abs_file_parts = abs_file_path.split(os.sep) + for i in range(len(abs_file_parts) + 1): + if os.path.join(os.sep, *abs_file_parts[:i]) in self._download_roots: + return os.path.join(*abs_file_parts[i:]) + return None + + # Others. + + @ensure_logged + def delete(self, delete_aliases: bool = False) -> None: + """Delete an artifact and its files. + + If called on a linked artifact, only the link is deleted, and the + source artifact is unaffected. + + Use `Artifact.unlink()` instead of `Artifact.delete()` to remove a + link between a source artifact and a collection. + + Args: + delete_aliases: If set to `True`, delete all aliases associated + with the artifact. If `False`, raise an exception if + the artifact has existing aliases. This parameter is ignored + if the artifact is retrieved from a collection it is linked to. + + Raises: + ArtifactNotLoggedError: If the artifact is not logged. 
+ """ + if self.is_link: + wandb.termwarn( + "Deleting a link artifact will only unlink the artifact from the source artifact and not delete the source artifact and the data of the source artifact." + ) + self._unlink() + else: + self._delete(delete_aliases) + + @normalize_exceptions + def _delete(self, delete_aliases: bool = False) -> None: + from ._generated import DELETE_ARTIFACT_GQL, DeleteArtifactInput + + if self._client is None: + raise RuntimeError("Client not initialized for artifact mutations") + + gql_op = gql(DELETE_ARTIFACT_GQL) + gql_input = DeleteArtifactInput( + artifact_id=self.id, + delete_aliases=delete_aliases, + ) + self._client.execute(gql_op, variable_values={"input": gql_input.model_dump()}) + + @normalize_exceptions + def link(self, target_path: str, aliases: Iterable[str] | None = None) -> Artifact: + """Link this artifact to a collection. + + Args: + target_path: The path of the collection. Path consists of the prefix + "wandb-registry-" along with the registry name and the + collection name `wandb-registry-{REGISTRY_NAME}/{COLLECTION_NAME}`. + aliases: Add one or more aliases to the linked artifact. The + "latest" alias is automatically applied to the most recent artifact + you link. + + Raises: + ArtifactNotLoggedError: If the artifact is not logged. + + Returns: + The linked artifact. + """ + from wandb import Api + from wandb.sdk.internal.internal_api import Api as InternalApi + + from ._generated import LINK_ARTIFACT_GQL, LinkArtifact, LinkArtifactInput + from ._validators import ArtifactPath, FullArtifactPath, validate_aliases + + if self.is_link: + wandb.termwarn( + "Linking to a link artifact will result in directly linking to the source artifact of that link artifact." + ) + + # Save the artifact first if necessary + if self.is_draft(): + if not self._is_draft_save_started(): + # Avoiding public `.source_project` property here, + # as it requires the artifact is logged first. 
+ self.save(project=self._source_project) + + # Wait until the artifact is committed before trying to link it. + self.wait() + + if (client := self._client) is None: + raise RuntimeError("Client not initialized for artifact mutations") + + # FIXME: Find a way to avoid using InternalApi here, due to the perf overhead + settings = InternalApi().settings() + + target = ArtifactPath.from_str(target_path).with_defaults( + project=settings.get("project") or "uncategorized", + ) + + # Parse the entity (first part of the path) appropriately, + # depending on whether we're linking to a registry + if target.is_registry_path(): + # In a Registry linking, the entity is used to fetch the organization of the + # artifact, therefore the source artifact's entity is passed to the backend + org = target.prefix or settings.get("organization") or None + target.prefix = resolve_org_entity_name(client, self.source_entity, org) + else: + target = target.with_defaults(prefix=self.source_entity) + + # Explicitly convert to FullArtifactPath to ensure all fields are present + target = FullArtifactPath(**asdict(target)) + + # Prepare the validated GQL input, send it + alias_inputs = [ + {"artifactCollectionName": target.name, "alias": a} + for a in validate_aliases(aliases or []) + ] + gql_input = LinkArtifactInput( + artifact_id=self.id, + artifact_portfolio_name=target.name, + entity_name=target.prefix, + project_name=target.project, + aliases=alias_inputs, + ) + gql_vars = {"input": gql_input.model_dump()} + + # Newer server versions can return `artifactMembership` directly in the response, + # avoiding the need to re-fetch the linked artifact at the end. 
+ omit_variables = omit_fields = None + if not server_supports( + client, pb.ARTIFACT_MEMBERSHIP_IN_LINK_ARTIFACT_RESPONSE + ): + omit_variables = {"includeAliases"} + omit_fields = {"artifactMembership"} + + gql_op = gql_compat( + LINK_ARTIFACT_GQL, omit_variables=omit_variables, omit_fields=omit_fields + ) + data = client.execute(gql_op, variable_values=gql_vars) + result = LinkArtifact.model_validate(data).result + + # Newer server versions can return artifactMembership directly in the response + if result and (membership := result.artifact_membership): + return self._from_membership(membership, target=target, client=client) + + # Old behavior, which requires re-fetching the linked artifact to return it + if not (result and (version_idx := result.version_index) is not None): + raise ValueError("Unable to parse linked artifact version from response") + + link_name = f"{target.to_str()}:v{version_idx}" + return Api(overrides={"entity": self.source_entity})._artifact(link_name) + + @ensure_logged + def unlink(self) -> None: + """Unlink this artifact if it is a linked member of an artifact collection. + + Raises: + ArtifactNotLoggedError: If the artifact is not logged. + ValueError: If the artifact is not linked to any collection. + """ + # Fail early if this isn't a linked artifact to begin with + if not self.is_link: + raise ValueError( + f"Artifact {self.qualified_name!r} is not a linked artifact and cannot be unlinked. " + f"To delete it, use {nameof(self.delete)!r} instead." 
+ ) + + self._unlink() + + @normalize_exceptions + def _unlink(self) -> None: + from ._generated import UNLINK_ARTIFACT_GQL, UnlinkArtifactInput + + if self._client is None: + raise RuntimeError("Client not initialized for artifact mutations") + + mutation = gql(UNLINK_ARTIFACT_GQL) + gql_input = UnlinkArtifactInput( + artifact_id=self.id, + artifact_portfolio_id=self.collection.id, + ) + gql_vars = {"input": gql_input.model_dump()} + try: + self._client.execute(mutation, variable_values=gql_vars) + except CommError as e: + raise CommError( + f"You do not have permission to unlink the artifact {self.qualified_name!r}" + ) from e + + @ensure_logged + def used_by(self) -> list[Run]: + """Get a list of the runs that have used this artifact and its linked artifacts. + + Returns: + A list of `Run` objects. + + Raises: + ArtifactNotLoggedError: If the artifact is not logged. + """ + from ._generated import ARTIFACT_USED_BY_GQL, ArtifactUsedBy + + if (client := self._client) is None: + raise RuntimeError("Client not initialized for artifact queries") + + query = gql(ARTIFACT_USED_BY_GQL) + gql_vars = {"id": self.id} + data = client.execute(query, variable_values=gql_vars) + result = ArtifactUsedBy.model_validate(data) + + if ( + (artifact := result.artifact) + and (used_by := artifact.used_by) + and (edges := used_by.edges) + ): + run_nodes = (e.node for e in edges) + return [ + Run(client, proj.entity.name, proj.name, run.name) + for run in run_nodes + if (proj := run.project) + ] + return [] + + @ensure_logged + def logged_by(self) -> Run | None: + """Get the W&B run that originally logged the artifact. + + Returns: + The name of the W&B run that originally logged the artifact. + + Raises: + ArtifactNotLoggedError: If the artifact is not logged. 
+ """ + from ._generated import ARTIFACT_CREATED_BY_GQL, ArtifactCreatedBy + + if (client := self._client) is None: + raise RuntimeError("Client not initialized for artifact queries") + + gql_op = gql(ARTIFACT_CREATED_BY_GQL) + gql_vars = {"id": self.id} + data = client.execute(gql_op, variable_values=gql_vars) + result = ArtifactCreatedBy.model_validate(data) + + if ( + (artifact := result.artifact) + and (creator := artifact.created_by) + and (name := creator.name) + and (project := creator.project) + ): + return Run(client, project.entity.name, project.name, name) + return None + + @ensure_logged + def json_encode(self) -> dict[str, Any]: + """Returns the artifact encoded to the JSON format. + + Returns: + A `dict` with `string` keys representing attributes of the artifact. + """ + return artifact_to_json(self) + + @staticmethod + def _expected_type( + entity_name: str, project_name: str, name: str, client: RetryingClient + ) -> str | None: + """Returns the expected type for a given artifact name and project.""" + from ._generated import ARTIFACT_TYPE_GQL, ArtifactType + + name = name if (":" in name) else f"{name}:latest" + + gql_op = gql(ARTIFACT_TYPE_GQL) + gql_vars = {"entity": entity_name, "project": project_name, "name": name} + data = client.execute(gql_op, variable_values=gql_vars) + result = ArtifactType.model_validate(data) + if (project := result.project) and (artifact := project.artifact): + return artifact.artifact_type.name + return None + + def _ttl_duration_seconds_to_gql(self) -> int | None: + # Set the artifact TTL to `ttl_duration_seconds` if the user provided a value. + # Otherwise, use `ttl_status` to indicate backend values INHERIT (-1) or + # DISABLED (-2) when the TTL is None. + # When `ttl_change is None`, nothing changed and this is a no-op. 
+ INHERIT = -1 # noqa: N806 + DISABLED = -2 # noqa: N806 + + if not self._ttl_changed: + return None + if self._ttl_is_inherited: + return INHERIT + return self._ttl_duration_seconds or DISABLED + + def _fetch_linked_artifacts(self) -> list[Artifact]: + """Fetches all linked artifacts from the server.""" + from wandb._pydantic import gql_typename + + from ._generated import ( + FETCH_LINKED_ARTIFACTS_GQL, + ArtifactPortfolioTypeFields, + FetchLinkedArtifacts, + ) + from ._validators import LinkArtifactFields + + if self.id is None: + raise ValueError( + "Unable to find any artifact memberships for artifact without an ID" + ) + + if (client := self._client) is None: + raise ValueError("Client is not initialized") + + gql_op = gql_compat(FETCH_LINKED_ARTIFACTS_GQL) + data = client.execute(gql_op, variable_values={"artifactID": self.id}) + result = FetchLinkedArtifacts.model_validate(data) + + if not ( + (artifact := result.artifact) + and (memberships := artifact.artifact_memberships) + and (membership_edges := memberships.edges) + ): + raise ValueError("Unable to find any artifact memberships for artifact") + + linked_artifacts: deque[Artifact] = deque() + linked_nodes = ( + node + for edge in membership_edges + if ( + (node := edge.node) + and (col := node.artifact_collection) + and (col.typename__ == gql_typename(ArtifactPortfolioTypeFields)) + ) + ) + for node in linked_nodes: + alias_names = unique_list(a.alias for a in node.aliases) + version = f"v{node.version_index}" + aliases = ( + [*alias_names, version] + if version not in alias_names + else [*alias_names] + ) + + if not ( + node + and (col := node.artifact_collection) + and (proj := col.project) + and (proj.entity.name and proj.name) + ): + raise ValueError("Unable to fetch fields for linked artifact") + + link_fields = LinkArtifactFields( + entity_name=proj.entity.name, + project_name=proj.name, + name=f"{col.name}:{version}", + version=version, + aliases=aliases, + ) + link = 
self._create_linked_artifact_using_source_artifact(link_fields) + linked_artifacts.append(link) + return list(linked_artifacts) + + def _create_linked_artifact_using_source_artifact( + self, + link_fields: LinkArtifactFields, + ) -> Artifact: + """Copies the source artifact to a linked artifact.""" + linked_artifact = copy(self) + linked_artifact._version = link_fields.version + linked_artifact._aliases = link_fields.aliases + linked_artifact._saved_aliases = copy(link_fields.aliases) + linked_artifact._name = link_fields.name + linked_artifact._entity = link_fields.entity_name + linked_artifact._project = link_fields.project_name + linked_artifact._is_link = link_fields.is_link + linked_artifact._linked_artifacts = link_fields.linked_artifacts + return linked_artifact + + +class _ArtifactVersionType(WBType): + name = "artifactVersion" + types = [Artifact] + + +TypeRegistry.add(_ArtifactVersionType) diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_download_logger.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_download_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..f4370d68bc74b9252827b6b18aeac87a4b06ef55 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_download_logger.py @@ -0,0 +1,45 @@ +"""Artifact download logger.""" + +from __future__ import annotations + +import multiprocessing.dummy +import time +from typing import Callable + +from wandb.errors.term import termlog + + +class ArtifactDownloadLogger: + def __init__( + self, + nfiles: int, + clock_for_testing: Callable[[], float] = time.monotonic, + termlog_for_testing: Callable[..., None] = termlog, + ) -> None: + self._nfiles = nfiles + self._clock = clock_for_testing + self._termlog = termlog_for_testing + + self._n_files_downloaded = 0 + self._spinner_index = 0 + self._last_log_time = self._clock() + self._lock = multiprocessing.dummy.Lock() + + def notify_downloaded(self) -> None: + with 
self._lock: + self._n_files_downloaded += 1 + if self._n_files_downloaded == self._nfiles: + self._termlog( + f" {self._nfiles} of {self._nfiles} files downloaded. ", + # ^ trailing spaces to wipe out ellipsis from previous logs + newline=True, + ) + self._last_log_time = self._clock() + elif self._clock() - self._last_log_time > 0.1: + self._spinner_index += 1 + spinner = r"-\|/"[self._spinner_index % 4] + self._termlog( + f"{spinner} {self._n_files_downloaded} of {self._nfiles} files downloaded...\r", + newline=False, + ) + self._last_log_time = self._clock() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_file_cache.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_file_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..8f06fbce71f91e6231d01e29a84861705f3698d4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_file_cache.py @@ -0,0 +1,255 @@ +"""Artifact cache.""" + +from __future__ import annotations + +import contextlib +import errno +import hashlib +import os +import shutil +import subprocess +import sys +from functools import lru_cache +from pathlib import Path +from tempfile import NamedTemporaryFile +from typing import IO, ContextManager, Iterator, Protocol + +import wandb +from wandb import env, util +from wandb.sdk.lib.filesystem import files_in +from wandb.sdk.lib.hashutil import B64MD5, ETag, b64_to_hex_id +from wandb.sdk.lib.paths import FilePathStr, StrPath, URIStr + + +class Opener(Protocol): + def __call__(self, mode: str = ...) -> ContextManager[IO]: ... 
def artifacts_cache_dir() -> Path:
    """Return the artifacts cache directory (`<wandb cache dir>/artifacts`)."""
    return env.get_cache_dir() / "artifacts"


def _get_sys_umask_threadsafe() -> int:
    """Return the current process umask without modifying it.

    `os.umask()` can only read the mask by *setting* a new one, which is not
    thread-safe and would briefly change the mask of this process. Instead, a
    child interpreter calls `os.umask()` and prints the mask it inherited.
    The value passed to `os.umask()` in the child is irrelevant: only the
    previous (inherited) mask is printed, and the child exits immediately.

    See: https://stackoverflow.com/questions/53227072/reading-umask-thread-safe
    """
    umask_cmd = (sys.executable, "-c", "import os; print(os.umask(22))")
    return int(subprocess.check_output(umask_cmd))


class ArtifactFileCache:
    """On-disk cache of artifact files.

    Layout under `cache_dir`:
      - `obj/md5/<xx>/<rest>`: files addressed by MD5 digest
      - `obj/etag/<xx>/<rest>`: files addressed by a hash of (URL, ETag)
      - `tmp/`: in-progress writes, moved into `obj/` atomically on success
    """

    def __init__(self, cache_dir: StrPath) -> None:
        """Set up the cache rooted at `cache_dir`.

        Raises:
            PermissionError: If the cache directory is not writable.
        """
        self._cache_dir = Path(cache_dir)
        self._obj_dir = self._cache_dir / "obj"
        self._temp_dir = self._cache_dir / "tmp"
        self._ensure_write_permissions()

        # NamedTemporaryFile sets the file mode to 600 [1], we reset to the default.
        # [1] https://stackoverflow.com/questions/10541760/can-i-set-the-umask-for-tempfile-namedtemporaryfile-in-python
        self._sys_umask = _get_sys_umask_threadsafe()

        # When set, files are written directly to this path instead of the cache.
        self._override_cache_path: StrPath | None = None

    def check_md5_obj_path(
        self, b64_md5: B64MD5, size: int
    ) -> tuple[FilePathStr, bool, Opener]:
        """Locate the cache slot for a file identified by its base64 MD5.

        Returns:
            A `(path, hit, opener)` tuple: the target path, whether a file of
            exactly `size` bytes already exists there, and a context-manager
            factory that atomically writes the file to that path.
        """
        # Check if we're using vs skipping the cache
        if self._override_cache_path is not None:
            skip_cache = True
            path = Path(self._override_cache_path)
        else:
            skip_cache = False
            hex_md5 = b64_to_hex_id(b64_md5)
            path = self._obj_dir / "md5" / hex_md5[:2] / hex_md5[2:]
        return self._check_or_create(path, size, skip_cache=skip_cache)

    # TODO(spencerpearson): this method at least needs its signature changed.
    # An ETag is not (necessarily) a checksum.
    def check_etag_obj_path(
        self,
        url: URIStr,
        etag: ETag,
        size: int,
    ) -> tuple[FilePathStr, bool, Opener]:
        """Locate the cache slot for a file identified by its URL and ETag.

        The slot name is the SHA-256 of the concatenated SHA-256 digests of the
        URL and the ETag.

        Returns:
            A `(path, hit, opener)` tuple, as for `check_md5_obj_path`.
        """
        # Check if we're using vs skipping the cache
        if self._override_cache_path is not None:
            skip_cache = True
            path = Path(self._override_cache_path)
        else:
            skip_cache = False
            hexhash = hashlib.sha256(
                hashlib.sha256(url.encode("utf-8")).digest()
                + hashlib.sha256(etag.encode("utf-8")).digest()
            ).hexdigest()
            path = self._obj_dir / "etag" / hexhash[:2] / hexhash[2:]
        return self._check_or_create(path, size, skip_cache=skip_cache)

    def _check_or_create(
        self, path: Path, size: int, skip_cache: bool = False
    ) -> tuple[FilePathStr, bool, Opener]:
        """Build the `(path, hit, opener)` tuple for `path`.

        A hit requires a regular file whose size matches `size`; the content
        itself is not re-hashed here.
        """
        opener = self._opener(path, size, skip_cache=skip_cache)
        hit = path.is_file() and path.stat().st_size == size
        return FilePathStr(path), hit, opener

    def cleanup(
        self,
        target_size: int | None = None,
        remove_temp: bool = False,
        target_fraction: float | None = None,
    ) -> int:
        """Clean up the cache, removing the least recently used files first.

        Args:
            target_size: The target size of the cache in bytes. If the cache is larger
                than this, we will remove the least recently used files until the cache
                is smaller than this size.
            remove_temp: Whether to remove temporary files. Temporary files are files
                that are currently being written to the cache. If remove_temp is True,
                all temp files will be removed, regardless of the target_size or
                target_fraction.
            target_fraction: The fraction of the cache's size (as measured when this
                method is called) to shrink the cache to. We will remove the least
                recently used files until the cache is no larger than this fraction
                of that size. It is an error to specify both target_size and
                target_fraction.

        Returns:
            The number of bytes reclaimed. Only files that were actually
            removed are counted.

        Raises:
            ValueError: If both `target_size` and `target_fraction` are given,
                or if either is out of range.
        """
        if target_size is None and target_fraction is None:
            # Default to clearing the entire cache.
            target_size = 0
        if target_size is not None and target_fraction is not None:
            raise ValueError("Cannot specify both target_size and target_fraction")
        if target_size is not None and target_size < 0:
            raise ValueError("target_size must be non-negative")
        if target_fraction is not None and (target_fraction < 0 or target_fraction > 1):
            raise ValueError("target_fraction must be between 0 and 1")

        bytes_reclaimed = 0
        # Size of the cache on entry; the basis for `target_fraction`.
        initial_size = 0
        # Bytes that are still on disk; shrinks as files are removed.
        remaining_size = 0
        temp_size = 0

        # Remove all temporary files if requested. Otherwise sum their size.
        # NOTE(review): entries from `files_in` are used like `os.DirEntry`
        # (`.path`, `.stat()`) -- confirm against the helper.
        for entry in files_in(self._temp_dir):
            size = entry.stat().st_size
            initial_size += size
            remaining_size += size
            if not remove_temp:
                temp_size += size
                continue
            try:
                os.remove(entry.path)
            except OSError:
                # Best effort: the file is still on disk, so it keeps counting
                # towards the remaining size and is not reported as reclaimed.
                continue
            remaining_size -= size
            bytes_reclaimed += size
        if temp_size:
            wandb.termwarn(
                f"Cache contains {util.to_human_size(temp_size)} of temporary files. "
                "Run `wandb artifact cache cleanup --remove-temp` to remove them."
            )

        entries = []
        for file_entry in files_in(self._obj_dir):
            size = file_entry.stat().st_size
            initial_size += size
            remaining_size += size
            entries.append(file_entry)

        if target_fraction is not None:
            target_size = int(initial_size * target_fraction)
        assert target_size is not None

        # Evict in least-recently-accessed order until the target is met.
        for entry in sorted(entries, key=lambda x: x.stat().st_atime):
            if remaining_size <= target_size:
                return bytes_reclaimed
            # Read the size before removing so we never stat a deleted file.
            size = entry.stat().st_size
            try:
                os.remove(entry.path)
            except OSError:
                # Could not remove it, so nothing was reclaimed for this entry.
                continue
            remaining_size -= size
            bytes_reclaimed += size

        if remaining_size > target_size:
            wandb.termerror(
                f"Failed to reclaim enough space in {self._cache_dir}. Try running"
                " `wandb artifact cache cleanup --remove-temp` to remove temporary files."
            )

        return bytes_reclaimed

    def _free_space(self) -> int:
        """Return the number of bytes of free space in the cache directory."""
        return shutil.disk_usage(self._cache_dir).free

    def _reserve_space(self, size: int) -> None:
        """If a `size` write would exceed disk space, remove cached items to make space.

        First tries to shrink the cache to half its size, then clears it
        entirely if that was not enough.

        Raises:
            OSError: If there is not enough space to write `size` bytes, even after
                removing cached items.
        """
        if size <= self._free_space():
            return

        wandb.termwarn("Cache size exceeded. Attempting to reclaim space...")
        self.cleanup(target_fraction=0.5)
        if size <= self._free_space():
            return

        self.cleanup(target_size=0)
        if size > self._free_space():
            raise OSError(errno.ENOSPC, f"Insufficient free space in {self._cache_dir}")

    def _opener(self, path: Path, size: int, skip_cache: bool = False) -> Opener:
        """Return a context-manager factory that atomically writes `path`.

        The returned callable takes a file mode (default `"w"`) and yields a
        temporary file. On a clean exit the temp file is closed, given default
        permissions (`0o666` masked by the process umask) and moved onto
        `path`. If anything is raised, the temp file is deleted and `path` is
        left untouched.

        Args:
            path: Final destination of the file.
            size: Expected number of bytes; used to reserve disk space when
                writing through the cache.
            skip_cache: If True, stage the temp file next to `path` instead of
                in the cache's temp directory and skip the space reservation.
        """

        @contextlib.contextmanager
        def atomic_open(mode: str = "w") -> Iterator[IO]:
            if "a" in mode:
                raise ValueError("Appending to cache files is not supported")

            if skip_cache:
                # Skip the cache but still use an intermediate temporary file to
                # ensure atomicity. Place the temp file in the same root as the
                # destination file to avoid cross-filesystem move/copy operations.
                temp_dir = path.parent
            else:
                self._reserve_space(size)
                temp_dir = self._temp_dir

            temp_dir.mkdir(parents=True, exist_ok=True)
            temp_file = NamedTemporaryFile(dir=temp_dir, mode=mode, delete=False)
            try:
                yield temp_file
                temp_file.close()
                os.chmod(temp_file.name, 0o666 & ~self._sys_umask)
                path.parent.mkdir(parents=True, exist_ok=True)
                os.replace(temp_file.name, path)
            except BaseException:
                # Also covers KeyboardInterrupt/GeneratorExit so an aborted
                # write never leaves a stray temp file behind.
                # Close before removing: an open file cannot be deleted on
                # Windows. Cleanup is best effort so that a failure here does
                # not mask the original exception, which is re-raised below.
                with contextlib.suppress(OSError):
                    temp_file.close()
                with contextlib.suppress(OSError):
                    os.remove(temp_file.name)
                raise

        return atomic_open

    def _ensure_write_permissions(self) -> None:
        """Raise an error if we cannot write to the cache directory.

        Creates the temp directory if needed and writes a throwaway file to it.

        Raises:
            PermissionError: If the directory or the probe file cannot be
                created or written.
        """
        try:
            self._temp_dir.mkdir(parents=True, exist_ok=True)
            with NamedTemporaryFile(dir=self._temp_dir) as f:
                f.write(b"wandb")
        except PermissionError as e:
            raise PermissionError(
                f"Unable to write to {self._cache_dir}. "
                "Ensure that the current user has write permissions."
            ) from e


# Memo `ArtifactFileCache` instances while avoiding reliance on global
# variable(s). Notes:
# - @lru_cache should be thread-safe.
+# - We don't memoize `get_artifact_file_cache` directly, as the cache_dir +# may change at runtime. This is likely rare in practice, though. +@lru_cache(maxsize=1) +def _build_artifact_file_cache(cache_dir: StrPath) -> ArtifactFileCache: + return ArtifactFileCache(cache_dir) + + +def get_artifact_file_cache() -> ArtifactFileCache: + return _build_artifact_file_cache(artifacts_cache_dir()) diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_manifest_entry.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_manifest_entry.py new file mode 100644 index 0000000000000000000000000000000000000000..0a5ed2098ac2e04ee0b6c0512b35591ff159175d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_manifest_entry.py @@ -0,0 +1,265 @@ +"""Artifact manifest entry.""" + +# Older-style type annotations required for Pydantic v1 / python 3.8 compatibility. +# ruff: noqa: UP006, UP007, UP045 + +from __future__ import annotations + +import concurrent.futures +import hashlib +import logging +import os +from contextlib import suppress +from os.path import getsize +from typing import TYPE_CHECKING, Any, Dict, Final, Optional, Union +from urllib.parse import urlparse + +from pydantic import Field, NonNegativeInt +from typing_extensions import Annotated, Self + +from wandb._pydantic import field_validator, model_validator +from wandb._strutils import nameof +from wandb.proto.wandb_telemetry_pb2 import Deprecated +from wandb.sdk.lib.deprecation import warn_and_record_deprecation +from wandb.sdk.lib.filesystem import copy_or_overwrite_changed +from wandb.sdk.lib.hashutil import ( + B64MD5, + ETag, + b64_to_hex_id, + hex_to_b64_id, + md5_file_b64, +) +from wandb.sdk.lib.paths import FilePathStr, LogicalPath, URIStr + +from ._models.base_model import ArtifactsBase + +if TYPE_CHECKING: + from .artifact import Artifact + + +logger = logging.getLogger(__name__) + + +_WB_ARTIFACT_SCHEME: Final[str] = "wandb-artifact" + + +def 
_checksum_cache_path(file_path: str) -> str: + """Get path for checksum in central cache directory.""" + from wandb.sdk.artifacts.artifact_file_cache import artifacts_cache_dir + + # Create a unique cache key based on the file's absolute path + abs_path = os.path.abspath(file_path) + path_hash = hashlib.sha256(abs_path.encode()).hexdigest() + + # Store in wandb cache directory under checksums subdirectory + cache_dir = artifacts_cache_dir() / "checksums" + cache_dir.mkdir(parents=True, exist_ok=True) + + return str(cache_dir / f"{path_hash}.checksum") + + +def _read_cached_checksum(file_path: str) -> str | None: + """Read checksum from cache if it exists and is valid.""" + checksum_path = _checksum_cache_path(file_path) + + try: + with open(file_path) as f, open(checksum_path) as f_checksum: + if os.path.getmtime(f_checksum.name) < os.path.getmtime(f.name): + # File was modified after checksum was written + return None + # Read and return the cached checksum + return f_checksum.read().strip() + except OSError: + # File doesn't exist or couldn't be opened + return None + + +def _write_cached_checksum(file_path: str, checksum: str) -> None: + """Write checksum to cache directory.""" + checksum_path = _checksum_cache_path(file_path) + try: + with open(checksum_path, "w") as f: + f.write(checksum) + except OSError: + # Non-critical failure, just log it + logger.debug(f"Failed to write checksum cache for {file_path!r}") + + +class ArtifactManifestEntry(ArtifactsBase): + """A single entry in an artifact manifest. + + External code should avoid instantiating this class directly. 
+ """ + + path: LogicalPath + + digest: Union[B64MD5, ETag, URIStr, FilePathStr] + ref: Union[URIStr, FilePathStr, None] = None + birth_artifact_id: Annotated[Optional[str], Field(alias="birthArtifactID")] = None + size: Optional[NonNegativeInt] = None + extra: Dict[str, Any] = Field(default_factory=dict) + local_path: Optional[str] = None + + skip_cache: bool = False + + # Note: Pydantic treats these as private attributes, omitting them from + # validation and comparison logic. + _parent_artifact: Optional[Artifact] = None + _download_url: Optional[str] = None + + @field_validator("path", mode="before") + def _validate_path(cls, v: Any) -> LogicalPath: + """Coerce `path` to a LogicalPath. + + LogicalPath does not implement its own pydantic validator, and adding + one for both pydantic V1 and V2 would add excessive boilerplate. Until + we drop V1 support, coerce to LogicalPath in this field validator. + """ + return LogicalPath(v) + + @field_validator("local_path", mode="before") + def _validate_local_path(cls, v: Any) -> str | None: + """Coerce `local_path` to a str. 
Necessary if the input is a `PosixPath`.""" + return str(v) if v else None + + @model_validator(mode="after") + def _infer_size_from_local_path(self) -> Self: + """If `size` isn't set, try to infer it from `local_path`.""" + if (self.size is None) and self.local_path: + self.size = getsize(self.local_path) + return self + + def __repr__(self) -> str: + # For compatibility with prior behavior, don't display `extra` if it's empty + exclude = None if self.extra else {"extra"} + repr_dict = self.model_dump(by_alias=False, exclude_none=True, exclude=exclude) + return f"{nameof(type(self))}({', '.join(f'{k}={v!r}' for k, v in repr_dict.items())})" + + @property + def name(self) -> LogicalPath: + """Deprecated; use `path` instead.""" + warn_and_record_deprecation( + feature=Deprecated(artifactmanifestentry__name=True), + message="ArtifactManifestEntry.name is deprecated, use .path instead.", + ) + return self.path + + def parent_artifact(self) -> Artifact: + """Get the artifact to which this artifact entry belongs. + + Returns: + (PublicArtifact): The parent artifact + """ + if self._parent_artifact is None: + raise NotImplementedError + return self._parent_artifact + + def download( + self, + root: str | None = None, + skip_cache: bool | None = None, + executor: concurrent.futures.Executor | None = None, + ) -> FilePathStr: + """Download this artifact entry to the specified root path. + + Args: + root: (str, optional) The root path in which to download this + artifact entry. Defaults to the artifact's root. + + Returns: + (str): The path of the downloaded artifact entry. + """ + artifact = self.parent_artifact() + rootdir = artifact._add_download_root(root) + dest_path = os.path.join(rootdir, self.path) + + # Skip checking the cache (and possibly downloading) if the file already exists + # and has the digest we're expecting. 
+ + # Fast integrity check using cached checksum from persistent cache + with suppress(OSError): + if self.digest == _read_cached_checksum(dest_path): + return FilePathStr(dest_path) + + # Fallback to computing/caching the checksum hash + try: + md5_hash = md5_file_b64(dest_path) + except (FileNotFoundError, IsADirectoryError): + logger.debug(f"unable to find {dest_path!r}, skip searching for file") + else: + _write_cached_checksum(dest_path, md5_hash) + if self.digest == md5_hash: + return FilePathStr(dest_path) + + # Override the target cache path IF we're skipping the cache. + # Note that `override_cache_path is None` <=> `skip_cache is False`. + override_cache_path = FilePathStr(dest_path) if skip_cache else None + storage_policy = artifact.manifest.storage_policy + if self.ref is not None: + cache_path = storage_policy.load_reference( + self, local=True, dest_path=override_cache_path + ) + else: + cache_path = storage_policy.load_file( + artifact, self, dest_path=override_cache_path, executor=executor + ) + + # Determine the final path + final_path = FilePathStr( + override_cache_path or copy_or_overwrite_changed(cache_path, dest_path) + ) + + # Cache the checksum for future downloads + _write_cached_checksum(final_path, self.digest) + + return final_path + + def ref_target(self) -> FilePathStr | URIStr: + """Get the reference URL that is targeted by this artifact entry. + + Returns: + (str): The reference URL of this artifact entry. + + Raises: + ValueError: If this artifact entry was not a reference. + """ + if self.ref is None: + raise ValueError("Only reference entries support ref_target().") + if (parent_artifact := self._parent_artifact) is None: + return self.ref + return parent_artifact.manifest.storage_policy.load_reference(self, local=False) + + def ref_url(self) -> str: + """Get a URL to this artifact entry. + + These URLs can be referenced by another artifact. + + Returns: + (str): A URL representing this artifact entry. 
+ + Examples: + Basic usage + ``` + ref_url = source_artifact.get_entry("file.txt").ref_url() + derived_artifact.add_reference(ref_url) + ``` + """ + if (parent_artifact := self.parent_artifact()) is None: + raise ValueError("Parent artifact is not set") + elif (parent_id := parent_artifact.id) is None: + raise ValueError("Parent artifact ID is not set") + return f"{_WB_ARTIFACT_SCHEME}://{b64_to_hex_id(parent_id)}/{self.path}" + + def to_json(self) -> dict[str, Any]: + # NOTE: The method name `to_json` is a bit misleading, as this returns a + # python dict, NOT a JSON string. The historical name is kept for continuity, + # but consider deprecating this in favor of `BaseModel.model_dump()`. + return self.model_dump(exclude_none=True) # type: ignore[return-value] + + def _is_artifact_reference(self) -> bool: + return self.ref is not None and urlparse(self.ref).scheme == _WB_ARTIFACT_SCHEME + + def _referenced_artifact_id(self) -> str | None: + if not self._is_artifact_reference(): + return None + return hex_to_b64_id(urlparse(self.ref).netloc) diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_saver.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..f007287627665a416b9bd6cb4cf89c457448e2bd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_saver.py @@ -0,0 +1,275 @@ +"""Artifact saver.""" + +from __future__ import annotations + +import concurrent.futures +import json +import os +import tempfile +from typing import TYPE_CHECKING, Awaitable, Sequence + +import wandb +import wandb.filesync.step_prepare +from wandb import util +from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest +from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, md5_file_b64 +from wandb.sdk.lib.paths import URIStr + +if TYPE_CHECKING: + from typing import Protocol + + from wandb.sdk.artifacts.artifact_manifest_entry import 
ArtifactManifestEntry + from wandb.sdk.internal.file_pusher import FilePusher + from wandb.sdk.internal.internal_api import Api as InternalApi + from wandb.sdk.internal.progress import ProgressFn + + class SaveFn(Protocol): + def __call__( + self, entry: ArtifactManifestEntry, progress_callback: ProgressFn + ) -> bool: + pass + + class SaveFnAsync(Protocol): + def __call__( + self, entry: ArtifactManifestEntry, progress_callback: ProgressFn + ) -> Awaitable[bool]: + pass + + +class ArtifactSaver: + _server_artifact: dict | None # TODO better define this dict + + def __init__( + self, + api: InternalApi, + digest: str, + manifest_json: dict, + file_pusher: FilePusher, + is_user_created: bool = False, + ) -> None: + self._api = api + self._file_pusher = file_pusher + self._digest = digest + self._manifest = ArtifactManifest.from_manifest_json(manifest_json) + self._manifest.storage_policy._api = self._api + self._is_user_created = is_user_created + self._server_artifact = None + + def save( + self, + entity: str, + project: str, + type: str, + name: str, + client_id: str, + sequence_client_id: str, + distributed_id: str | None = None, + finalize: bool = True, + metadata: dict | None = None, + ttl_duration_seconds: int | None = None, + description: str | None = None, + aliases: Sequence[str] | None = None, + tags: Sequence[str] | None = None, + use_after_commit: bool = False, + incremental: bool = False, + history_step: int | None = None, + base_id: str | None = None, + ) -> dict | None: + return self._save_internal( + entity, + project, + type, + name, + client_id, + sequence_client_id, + distributed_id, + finalize, + metadata, + ttl_duration_seconds, + description, + aliases, + tags, + use_after_commit, + incremental, + history_step, + base_id, + ) + + def _save_internal( + self, + entity: str, + project: str, + type: str, + name: str, + client_id: str, + sequence_client_id: str, + distributed_id: str | None = None, + finalize: bool = True, + metadata: dict | None = 
None, + ttl_duration_seconds: int | None = None, + description: str | None = None, + aliases: Sequence[str] | None = None, + tags: Sequence[str] | None = None, + use_after_commit: bool = False, + incremental: bool = False, + history_step: int | None = None, + base_id: str | None = None, + ) -> dict | None: + alias_specs = [] + for alias in aliases or []: + alias_specs.append({"artifactCollectionName": name, "alias": alias}) + + tag_specs = [{"tagName": tag} for tag in tags or []] + + """Returns the server artifact.""" + self._server_artifact, latest = self._api.create_artifact( + type, + name, + self._digest, + metadata=metadata, + ttl_duration_seconds=ttl_duration_seconds, + aliases=alias_specs, + tags=tag_specs, + description=description, + is_user_created=self._is_user_created, + distributed_id=distributed_id, + client_id=client_id, + sequence_client_id=sequence_client_id, + history_step=history_step, + ) + + assert self._server_artifact is not None # mypy optionality unwrapper + artifact_id = self._server_artifact["id"] + if base_id is None and latest: + base_id = latest["id"] + if self._server_artifact["state"] == "COMMITTED": + if use_after_commit: + self._api.use_artifact( + artifact_id, + artifact_entity_name=entity, + artifact_project_name=project, + ) + return self._server_artifact + if ( + self._server_artifact["state"] != "PENDING" + # For old servers, see https://github.com/wandb/wandb/pull/6190 + and self._server_artifact["state"] != "DELETED" + ): + raise Exception( + 'Unknown artifact state "{}"'.format(self._server_artifact["state"]) + ) + + manifest_type = "FULL" + manifest_filename = "wandb_manifest.json" + if incremental: + manifest_type = "INCREMENTAL" + manifest_filename = "wandb_manifest.incremental.json" + elif distributed_id: + manifest_type = "PATCH" + manifest_filename = "wandb_manifest.patch.json" + artifact_manifest_id, _ = self._api.create_artifact_manifest( + manifest_filename, + "", + artifact_id, + base_artifact_id=base_id, + 
include_upload=False, + type=manifest_type, + ) + + step_prepare = wandb.filesync.step_prepare.StepPrepare( + self._api, 0.1, 0.01, 1000 + ) # TODO: params + step_prepare.start() + + # Upload Artifact "L1" files, the actual artifact contents + self._file_pusher.store_manifest_files( + self._manifest, + artifact_id, + lambda entry, progress_callback: self._manifest.storage_policy.store_file( + artifact_id, + artifact_manifest_id, + entry, + step_prepare, + progress_callback=progress_callback, + ), + ) + + def before_commit() -> None: + self._resolve_client_id_manifest_references() + with tempfile.NamedTemporaryFile("w+", suffix=".json", delete=False) as fp: + path = os.path.abspath(fp.name) + json.dump(self._manifest.to_manifest_json(), fp, indent=4) + digest = md5_file_b64(path) + if distributed_id or incremental: + # If we're in the distributed flow, we want to update the + # patch manifest we created with our finalized digest. + _, resp = self._api.update_artifact_manifest( + artifact_manifest_id, + digest=digest, + ) + else: + # In the regular flow, we can recreate the full manifest with the + # updated digest. + # + # NOTE: We do this for backwards compatibility with older backends + # that don't support the 'updateArtifactManifest' API. + _, resp = self._api.create_artifact_manifest( + manifest_filename, + digest, + artifact_id, + base_artifact_id=base_id, + ) + + # We're duplicating the file upload logic a little, which isn't great. + upload_url = resp["uploadUrl"] + upload_headers = resp["uploadHeaders"] + extra_headers = {} + for upload_header in upload_headers: + key, val = upload_header.split(":", 1) + extra_headers[key] = val + with open(path, "rb") as fp2: + self._api.upload_file_retry( + upload_url, + fp2, + extra_headers=extra_headers, + ) + + commit_result: concurrent.futures.Future[None] = concurrent.futures.Future() + + # Queue the commit. It will only happen after all file uploads finish. 
+ self._file_pusher.commit_artifact( + artifact_id, + finalize=finalize, + before_commit=before_commit, + result_future=commit_result, + ) + + # Block until all artifact files are uploaded and the + # artifact is committed. + try: + commit_result.result() + finally: + step_prepare.shutdown() + + if finalize and use_after_commit: + self._api.use_artifact( + artifact_id, + artifact_entity_name=entity, + artifact_project_name=project, + ) + + return self._server_artifact + + def _resolve_client_id_manifest_references(self) -> None: + for entry_path in self._manifest.entries: + entry = self._manifest.entries[entry_path] + if entry.ref is not None: + if entry.ref.startswith("wandb-client-artifact:"): + client_id = util.host_from_path(entry.ref) + artifact_file_path = util.uri_from_path(entry.ref) + artifact_id = self._api._resolve_client_id(client_id) + if artifact_id is None: + raise RuntimeError(f"Could not resolve client id {client_id}") + entry.ref = URIStr( + f"wandb-artifact://{b64_to_hex_id(B64MD5(artifact_id))}/{artifact_file_path}" + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_state.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_state.py new file mode 100644 index 0000000000000000000000000000000000000000..4df836c8b60e482e3de8a103b4b0e646800922a7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_state.py @@ -0,0 +1,13 @@ +"""Artifact state.""" + +from __future__ import annotations + +from enum import Enum + + +class ArtifactState(Enum): + PENDING = "PENDING" + COMMITTED = "COMMITTED" + DELETED = "DELETED" + GARBAGE_COLLECTED = "GARBAGE_COLLECTED" + PENDING_DELETION = "PENDING_DELETION" diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_ttl.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_ttl.py new file mode 100644 index 0000000000000000000000000000000000000000..e63c69031923aa36a2214526443861497f3575d6 --- /dev/null +++ 
b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/artifact_ttl.py @@ -0,0 +1,9 @@ +"""Artifact TTL.""" + +from __future__ import annotations + +from enum import Enum + + +class ArtifactTTL(Enum): + INHERIT = 0 diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/exceptions.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..464d6a7b026171f534d3aa992b859f57434dccaa --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/exceptions.py @@ -0,0 +1,72 @@ +"""Artifact exceptions.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeVar + +from wandb import errors +from wandb._strutils import nameof + +if TYPE_CHECKING: + from wandb.sdk.artifacts.artifact import Artifact + + ArtifactT = TypeVar("ArtifactT", bound=Artifact) + + +class ArtifactStatusError(AttributeError): + """Raised when an artifact is in an invalid state for the requested operation.""" + + def __init__( + self, + msg: str = "Artifact is in an invalid state for the requested operation.", + name: str | None = None, + obj: ArtifactT | None = None, + ): + # Follow AttributeError (Python 3.10+) by exposing `name` and `obj`. + # See: https://docs.python.org/3/library/exceptions.html#AttributeError + try: + super().__init__(msg, name=name, obj=obj) + except TypeError: + # The `name`/`obj` keywords were only added in Python >= 3.10. + super().__init__(msg) + self.name = name or "" + self.obj = obj + + +class ArtifactNotLoggedError(ArtifactStatusError): + """Raised for Artifact methods or attributes only available after logging.""" + + def __init__(self, fullname: str, obj: ArtifactT): + *_, name = fullname.split(".") + msg = ( + f"{fullname!r} used prior to logging artifact or while in offline mode. " + f"Call {nameof(obj.wait)}() before accessing logged artifact properties." 
+ ) + super().__init__(msg=msg, name=name, obj=obj) + + +class ArtifactFinalizedError(ArtifactStatusError): + """Raised for Artifact methods or attributes that can't be changed after logging.""" + + def __init__(self, fullname: str, obj: ArtifactT): + *_, name = fullname.split(".") + msg = f"{fullname!r} used on logged artifact. Can't modify finalized artifact." + super().__init__(msg=msg, name=name, obj=obj) + + +class WaitTimeoutError(errors.Error): + """Raised when wait() timeout occurs before process is finished.""" + + +class TooFewItemsError(ValueError): + """Raised when there are fewer items than expected in a collection. + + Intended for internal use only. + """ + + +class TooManyItemsError(ValueError): + """Raised when there are more items than expected in a collection. + + Intended for internal use only. + """ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/staging.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/staging.py new file mode 100644 index 0000000000000000000000000000000000000000..24a6313ca0507d1f4ee110570c80062644c08cb9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/staging.py @@ -0,0 +1,27 @@ +"""Manages artifact file staging. + +Artifact files are copied to the staging area as soon as they are added to an artifact +in order to avoid file changes corrupting the artifact. Once the upload is complete, the +file should be moved to the artifact cache. +""" + +from __future__ import annotations + +import os + +from wandb import env +from wandb.sdk.lib.filesystem import mkdir_exists_ok +from wandb.sdk.lib.paths import FilePathStr + + +def get_staging_dir() -> FilePathStr: + path = os.path.join(env.get_data_dir(), "artifacts", "staging") + try: + mkdir_exists_ok(path) + except OSError as e: + raise PermissionError( + f"Unable to write staging files to {path}. To fix this problem, please set " + f"{env.DATA_DIR} to a directory where you have the necessary write access." 
+ ) from e + + return FilePathStr(os.path.abspath(os.path.expanduser(path))) diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/storage_handler.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/storage_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..c1883e5144cf8bf602b11a7012c0a85cb1025979 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/storage_handler.py @@ -0,0 +1,68 @@ +"""Storage handler.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Final + +from wandb.sdk.lib.paths import FilePathStr, URIStr + +if TYPE_CHECKING: + from urllib.parse import ParseResult + + from wandb.sdk.artifacts.artifact import Artifact + from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry + +DEFAULT_MAX_OBJECTS: Final[int] = 10_000_000 # 10**7 + + +class _BaseStorageHandler(ABC): + @abstractmethod + def load_path( + self, + manifest_entry: ArtifactManifestEntry, + local: bool = False, + ) -> URIStr | FilePathStr: + """Load a file or directory given the corresponding index entry. + + Args: + manifest_entry: The index entry to load + local: Whether to load the file locally or not + + Returns: + A path to the file represented by `index_entry` + """ + raise NotImplementedError + + @abstractmethod + def store_path( + self, + artifact: Artifact, + path: URIStr | FilePathStr, + name: str | None = None, + checksum: bool = True, + max_objects: int | None = None, + ) -> list[ArtifactManifestEntry]: + """Store the file or directory at the given path to the specified artifact. 
+ + Args: + path: The path to store + name: If specified, the logical name that should map to `path` + checksum: Whether to compute the checksum of the file + max_objects: The maximum number of objects to store + + Returns: + A list of manifest entries to store within the artifact + """ + raise NotImplementedError + + +class StorageHandler(_BaseStorageHandler, ABC): # Handles a single storage protocol + @abstractmethod + def can_handle(self, parsed_url: ParseResult) -> bool: + """Checks whether this handler can handle the given url. + + Returns: + Whether this handler can handle the given url. + """ + raise NotImplementedError diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/storage_layout.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/storage_layout.py new file mode 100644 index 0000000000000000000000000000000000000000..368b0cc1126e55e3fe48688d3dcdc1b826a8df20 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/storage_layout.py @@ -0,0 +1,16 @@ +"""Storage layout.""" + +from __future__ import annotations + +from enum import Enum + + +class StorageLayout(str, Enum): + V1 = "V1" + V2 = "V2" + + @classmethod + def from_env(cls) -> StorageLayout: + from wandb.env import get_use_v1_artifacts + + return cls.V1 if get_use_v1_artifacts() else cls.V2 diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/storage_policy.py b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/storage_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..0bb2a230c5e57d804f61fb0c1ab104eb3d924b87 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/artifacts/storage_policy.py @@ -0,0 +1,89 @@ +"""Storage policy.""" + +from __future__ import annotations + +import concurrent.futures +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +from wandb.sdk.internal.internal_api import Api as InternalApi +from wandb.sdk.lib.paths import FilePathStr, URIStr + +if 
TYPE_CHECKING: + from wandb.filesync.step_prepare import StepPrepare + from wandb.sdk.artifacts._models.storage import StoragePolicyConfig + from wandb.sdk.artifacts.artifact import Artifact + from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry + from wandb.sdk.internal.progress import ProgressFn + + +_POLICY_REGISTRY: dict[str, type[StoragePolicy]] = {} + + +class StoragePolicy(ABC): + _api: InternalApi | None = None + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + _POLICY_REGISTRY[cls.name()] = cls + + @classmethod + def lookup_by_name(cls, name: str) -> type[StoragePolicy]: + if policy := _POLICY_REGISTRY.get(name): + return policy + raise ValueError(f"Failed to find storage policy {name!r}") + + @classmethod + @abstractmethod + def name(cls) -> str: + raise NotImplementedError + + @classmethod + @abstractmethod + def from_config(cls, config: StoragePolicyConfig) -> StoragePolicy: + raise NotImplementedError + + @abstractmethod + def config(self) -> dict[str, Any]: + raise NotImplementedError + + @abstractmethod + def load_file( + self, + artifact: Artifact, + manifest_entry: ArtifactManifestEntry, + dest_path: str | None = None, + executor: concurrent.futures.Executor | None = None, + ) -> FilePathStr: + raise NotImplementedError + + @abstractmethod + def store_file( + self, + artifact_id: str, + artifact_manifest_id: str, + entry: ArtifactManifestEntry, + preparer: StepPrepare, + progress_callback: ProgressFn | None = None, + ) -> bool: + raise NotImplementedError + + @abstractmethod + def store_reference( + self, + artifact: Artifact, + path: URIStr | FilePathStr, + name: str | None = None, + checksum: bool = True, + max_objects: int | None = None, + ) -> list[ArtifactManifestEntry]: + raise NotImplementedError + + @abstractmethod + def load_reference( + self, + manifest_entry: ArtifactManifestEntry, + local: bool = False, + dest_path: str | None = None, + ) -> FilePathStr | URIStr: + 
raise NotImplementedError diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/integration_utils/__init__.py b/.venv/lib/python3.13/site-packages/wandb/sdk/integration_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/integration_utils/auto_logging.py b/.venv/lib/python3.13/site-packages/wandb/sdk/integration_utils/auto_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..5c3d36b9ce29299fdf2cc3d72b8878ab2e4db9ba --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/integration_utils/auto_logging.py @@ -0,0 +1,232 @@ +import asyncio +import functools +import inspect +import logging +from typing import Any, Dict, Optional, Protocol, Sequence, TypeVar + +import wandb.sdk +import wandb.util +from wandb.sdk.lib import telemetry as wb_telemetry +from wandb.sdk.lib.timer import Timer + +logger = logging.getLogger(__name__) + + +AutologInitArgs = Optional[Dict[str, Any]] + + +K = TypeVar("K", bound=str) +V = TypeVar("V") + + +class Response(Protocol[K, V]): + def __getitem__(self, key: K) -> V: ... # pragma: no cover + + def get( + self, key: K, default: Optional[V] = None + ) -> Optional[V]: ... # pragma: no cover + + +class ArgumentResponseResolver(Protocol): + def __call__( + self, + args: Sequence[Any], + kwargs: Dict[str, Any], + response: Response, + start_time: float, + time_elapsed: float, + ) -> Optional[Dict[str, Any]]: ... # pragma: no cover + + +class PatchAPI: + def __init__( + self, + name: str, + symbols: Sequence[str], + resolver: ArgumentResponseResolver, + ) -> None: + """Patches the API to log wandb Media or metrics.""" + # name of the LLM provider, e.g. "Cohere" or "OpenAI" or package name like "Transformers" + self.name = name + # api library name, e.g. 
"cohere" or "openai" or "transformers" + self._api = None + # dictionary of original methods + self.original_methods: Dict[str, Any] = {} + # list of symbols to patch, e.g. ["Client.generate", "Edit.create"] or ["Pipeline.__call__"] + self.symbols = symbols + # resolver callable to convert args/response into a dictionary of wandb media objects or metrics + self.resolver = resolver + + @property + def set_api(self) -> Any: + """Returns the API module.""" + lib_name = self.name.lower() + if self._api is None: + self._api = wandb.util.get_module( + name=lib_name, + required=f"To use the W&B {self.name} Autolog, " + f"you need to have the `{lib_name}` python " + f"package installed. Please install it with `pip install {lib_name}`.", + lazy=False, + ) + return self._api + + def patch(self, run: "wandb.Run") -> None: + """Patches the API to log media or metrics to W&B.""" + for symbol in self.symbols: + # split on dots, e.g. "Client.generate" -> ["Client", "generate"] + symbol_parts = symbol.split(".") + # and get the attribute from the module + original = functools.reduce(getattr, symbol_parts, self.set_api) + + def method_factory(original_method: Any): + async def async_method(*args, **kwargs): + future = asyncio.Future() + + async def callback(coro): + try: + result = await coro + loggable_dict = self.resolver( + args, kwargs, result, timer.start_time, timer.elapsed + ) + if loggable_dict is not None: + run.log(loggable_dict) + future.set_result(result) + except Exception as e: + logger.warning(e) + + with Timer() as timer: + coro = original_method(*args, **kwargs) + asyncio.ensure_future(callback(coro)) + + return await future + + def sync_method(*args, **kwargs): + with Timer() as timer: + result = original_method(*args, **kwargs) + try: + loggable_dict = self.resolver( + args, kwargs, result, timer.start_time, timer.elapsed + ) + if loggable_dict is not None: + run.log(loggable_dict) + except Exception as e: + logger.warning(e) + return result + + if 
inspect.iscoroutinefunction(original_method): + return functools.wraps(original_method)(async_method) + else: + return functools.wraps(original_method)(sync_method) + + # save original method + self.original_methods[symbol] = original + # monkey patch the method + if len(symbol_parts) == 1: + setattr(self.set_api, symbol_parts[0], method_factory(original)) + else: + setattr( + functools.reduce(getattr, symbol_parts[:-1], self.set_api), + symbol_parts[-1], + method_factory(original), + ) + + def unpatch(self) -> None: + """Unpatches the API.""" + for symbol, original in self.original_methods.items(): + # split on dots, e.g. "Client.generate" -> ["Client", "generate"] + symbol_parts = symbol.split(".") + # unpatch the method + if len(symbol_parts) == 1: + setattr(self.set_api, symbol_parts[0], original) + else: + setattr( + functools.reduce(getattr, symbol_parts[:-1], self.set_api), + symbol_parts[-1], + original, + ) + + +class AutologAPI: + def __init__( + self, + name: str, + symbols: Sequence[str], + resolver: ArgumentResponseResolver, + telemetry_feature: Optional[str] = None, + ) -> None: + """Autolog API calls to W&B.""" + self._telemetry_feature = telemetry_feature + self._patch_api = PatchAPI( + name=name, + symbols=symbols, + resolver=resolver, + ) + self._name = self._patch_api.name + self._run: Optional[wandb.Run] = None + self.__run_created_by_autolog: bool = False + + @property + def _is_enabled(self) -> bool: + """Returns whether autologging is enabled.""" + return self._run is not None + + def __call__(self, init: AutologInitArgs = None) -> None: + """Enable autologging.""" + self.enable(init=init) + + def _run_init(self, init: AutologInitArgs = None) -> None: + """Handle wandb run initialization.""" + # - autolog(init: dict = {...}) calls wandb.init(**{...}) + # regardless of whether there is a wandb.run or not, + # we only track if the run was created by autolog + # - todo: autolog(init: dict | run = run) would use the user-provided run + # - 
autolog() uses the wandb.run if there is one, otherwise it calls wandb.init() + if init: + _wandb_run = wandb.run + # we delegate dealing with the init dict to wandb.init() + self._run = wandb.init(**init) + if _wandb_run != self._run: + self.__run_created_by_autolog = True + elif wandb.run is None: + self._run = wandb.init() + self.__run_created_by_autolog = True + else: + self._run = wandb.run + + def enable(self, init: AutologInitArgs = None) -> None: + """Enable autologging. + + Args: + init: Optional dictionary of arguments to pass to wandb.init(). + + """ + if self._is_enabled: + logger.info( + f"{self._name} autologging is already enabled, disabling and re-enabling." + ) + self.disable() + + logger.info(f"Enabling {self._name} autologging.") + self._run_init(init=init) + + self._patch_api.patch(self._run) + + if self._telemetry_feature: + with wb_telemetry.context(self._run) as tel: + setattr(tel.feature, self._telemetry_feature, True) + + def disable(self) -> None: + """Disable autologging.""" + if self._run is None: + return + + logger.info(f"Disabling {self._name} autologging.") + + if self.__run_created_by_autolog: + self._run.finish() + self.__run_created_by_autolog = False + + self._run = None + + self._patch_api.unpatch() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/integration_utils/data_logging.py b/.venv/lib/python3.13/site-packages/wandb/sdk/integration_utils/data_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..b0b971cc17211228132698f4279b7b7db1a81bed --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/integration_utils/data_logging.py @@ -0,0 +1,475 @@ +# wandb.integrations.data_logging.py +# +# Contains common utility functions that enable +# logging datasets and predictions to wandb. 
+import sys +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union + +import wandb + +if TYPE_CHECKING: + from wandb.data_types import _TableIndex + +CAN_INFER_IMAGE_AND_VIDEO = sys.version_info.major == 3 and sys.version_info.minor >= 5 + + +class ValidationDataLogger: + """Logs validation data as a wandb.Table. + + ValidationDataLogger is intended to be used inside of library integrations + in order to facilitate the process of optionally building a validation dataset + and logging periodic predictions against such validation data using WandB best + practices. + """ + + validation_inputs: Union[Sequence, Dict[str, Sequence]] + validation_targets: Optional[Union[Sequence, Dict[str, Sequence]]] + validation_indexes: List["_TableIndex"] + prediction_row_processor: Optional[Callable] + class_labels_table: Optional["wandb.Table"] + infer_missing_processors: bool + + def __init__( + self, + inputs: Union[Sequence, Dict[str, Sequence]], + targets: Optional[Union[Sequence, Dict[str, Sequence]]] = None, + indexes: Optional[List["_TableIndex"]] = None, + validation_row_processor: Optional[Callable] = None, + prediction_row_processor: Optional[Callable] = None, + input_col_name: str = "input", + target_col_name: str = "target", + table_name: str = "wb_validation_data", + artifact_type: str = "validation_dataset", + class_labels: Optional[List[str]] = None, + infer_missing_processors: bool = True, + ) -> None: + """Initialize a new ValidationDataLogger. + + Args: + inputs: A list of input vectors or dictionary of lists of input vectors + (used if the model has multiple named inputs) + targets: A list of target vectors or dictionary of lists of target vectors + (used if the model has multiple named targets/putputs). Defaults to `None`. + `targets` and `indexes` cannot both be `None`. + indexes: An ordered list of `wandb.data_types._TableIndex` mapping the + input items to their source table. 
This is most commonly retrieved by using + `indexes = my_data_table.get_index()`. Defaults to `None`. `targets` + and `indexes` cannot both be `None`. + validation_row_processor: A function to apply to the validation data, + commonly used to visualize the data. The function will receive an `ndx` (`int`) + and a `row` (`dict`). If `inputs` is a list, then `row["input"]` will be the input + data for the row. Else, it will be keyed based on the name of the input slot + (corresponding to `inputs`). If `targets` is a list, then + `row["target"]` will be the target data for the row. Else, it will + be keyed based on `targets`. For example, if your input data is a + single ndarray, but you wish to visualize the data as an image, + then you can provide `lambda ndx, row: {"img": wandb.Image(row["input"])}` + as the processor. If `None`, we will try to guess the appropriate processor. + Ignored if `log_evaluation` is `False` or `val_keys` are present. Defaults to `None`. + prediction_row_processor: Same as validation_row_processor, but applied to the + model's output. `row["output"]` will contain the results of the model output. + Defaults to `None`. + input_col_name: The name to use for the input column. + Defaults to `"input"`. + target_col_name: The name to use for the target column. + Defaults to `"target"`. + table_name: The name to use for the validation table. + Defaults to `"wb_validation_data"`. + artifact_type: The artifact type to use for the validation data. + Defaults to `"validation_dataset"`. + class_labels: Optional list of labels to use in the inferred + processors. If the model's `target` or `output` is inferred to be a class, + we will attempt to map the class to these labels. Defaults to `None`. + infer_missing_processors: Determines if processors are inferred if + they are missing. Defaults to True. 
+ """ + class_labels_table: Optional[wandb.Table] + if isinstance(class_labels, list) and len(class_labels) > 0: + class_labels_table = wandb.Table( + columns=["label"], data=[[label] for label in class_labels] + ) + else: + class_labels_table = None + + if indexes is None: + assert targets is not None + local_validation_table = wandb.Table(columns=[], data=[]) + + if isinstance(targets, dict): + for col_name in targets: + local_validation_table.add_column(col_name, targets[col_name]) + else: + local_validation_table.add_column(target_col_name, targets) + + if isinstance(inputs, dict): + for col_name in inputs: + local_validation_table.add_column(col_name, inputs[col_name]) + else: + local_validation_table.add_column(input_col_name, inputs) + + if validation_row_processor is None and infer_missing_processors: + example_input = _make_example(inputs) + example_target = _make_example(targets) + if example_input is not None and example_target is not None: + validation_row_processor = _infer_validation_row_processor( + example_input, + example_target, + class_labels_table, + input_col_name, + target_col_name, + ) + + if validation_row_processor is not None: + local_validation_table.add_computed_columns(validation_row_processor) + + local_validation_artifact = wandb.Artifact(table_name, artifact_type) + local_validation_artifact.add(local_validation_table, "validation_data") + if wandb.run: + wandb.run.use_artifact(local_validation_artifact) + indexes = local_validation_table.get_index() + else: + local_validation_artifact = None + + self.class_labels_table = class_labels_table + self.validation_inputs = inputs + self.validation_targets = targets + self.validation_indexes = indexes + self.prediction_row_processor = prediction_row_processor + self.infer_missing_processors = infer_missing_processors + self.local_validation_artifact = local_validation_artifact + self.input_col_name = input_col_name + + def make_predictions( + self, predict_fn: Callable + ) -> 
Union[Sequence, Dict[str, Sequence]]: + """Produce predictions by passing `validation_inputs` to `predict_fn`. + + Args: + predict_fn (Callable): Any function which can accept `validation_inputs` and produce + a list of vectors or dictionary of lists of vectors + + Returns: + (Sequence | Dict[str, Sequence]): The returned value of predict_fn + """ + return predict_fn(self.validation_inputs) + + def log_predictions( + self, + predictions: Union[Sequence, Dict[str, Sequence]], + prediction_col_name: str = "output", + val_ndx_col_name: str = "val_row", + table_name: str = "validation_predictions", + commit: bool = True, + ) -> wandb.data_types.Table: + """Log a set of predictions. + + Intended usage: + + vl.log_predictions(vl.make_predictions(self.model.predict)) + + Args: + predictions (Sequence | Dict[str, Sequence]): A list of prediction vectors or dictionary + of lists of prediction vectors + prediction_col_name (str, optional): the name of the prediction column. Defaults to "output". + val_ndx_col_name (str, optional): The name of the column linking prediction table + to the validation ata table. Defaults to "val_row". + table_name (str, optional): name of the prediction table. Defaults to "validation_predictions". + commit (bool, optional): determines if commit should be called on the logged data. Defaults to False. 
def _make_example(data: Any) -> Optional[Union[Dict, Sequence, Any]]:
    """Return one representative example (element 0) taken from `data`.

    Used to make an example input, target, or output.

    Args:
        data: Either a dict mapping column names to indexable columns, an
            object exposing `__len__` (indexed with `[0]`), or anything else.

    Returns:
        For a dict, a new dict holding element 0 of every column. For a sized
        object, its element 0. None when `data` is not sized, and also when
        `data` (or any dict column) is empty, because no example exists then.
    """
    if isinstance(data, dict):
        example: Dict = {}
        for key, column in data.items():
            # An empty column has no element 0; report "no example" instead of
            # raising IndexError so callers' `is not None` checks can handle it.
            if hasattr(column, "__len__") and len(column) == 0:
                return None
            example[key] = column[0]
        return example

    if hasattr(data, "__len__"):
        return None if len(data) == 0 else data[0]

    return None


def _get_example_shape(example: Union[Sequence, Any]):
    """Get the shape of an object if applicable.

    The shape is built by following element 0 at every nesting level, so only
    the first element of each level is inspected. Strings and objects without
    `__len__` have the empty shape `[]`.
    """
    if isinstance(example, str) or not hasattr(example, "__len__"):
        return []

    length = len(example)
    if length == 0:
        # Nothing to descend into.
        return [length]
    return [length] + _get_example_shape(example[0])


def _bind(lambda_fn: Callable, **closure_kwargs: Any) -> Callable:
    """Create a closure around `lambda_fn` with `closure_kwargs` pre-bound.

    Positional arguments and extra keyword arguments given at call time are
    forwarded; on a name clash the bound value from `closure_kwargs` wins.
    """

    def closure(*args: Any, **kwargs: Any) -> Any:
        # Later entries override earlier ones, so bound kwargs take priority.
        return lambda_fn(*args, **{**kwargs, **closure_kwargs})

    return closure
example: Union[Sequence, Any], + class_labels_table: Optional["wandb.Table"] = None, + possible_base_example: Optional[Union[Sequence, Any]] = None, +) -> Dict[str, Callable]: + """Infers a processor from a single example. + + Infers a processor from a single example with optional class_labels_table + and base_example. Base example is useful for cases such as segmentation masks + """ + shape = _get_example_shape(example) + processors: Dict[str, Callable] = {} + if ( + class_labels_table is not None + and len(shape) == 1 + and shape[0] == len(class_labels_table.data) + ): + np = wandb.util.get_module( + "numpy", + required="Inferring processors require numpy", + ) + # Assume these are logits + class_names = class_labels_table.get_column("label") + + processors["max_class"] = lambda n, d, p: class_labels_table.index_ref( # type: ignore + np.argmax(d) + ) + # TODO: Consider adding back if users ask + # processors["min_class"] = lambda n, d, p: class_labels_table.index_ref( # type: ignore + # np.argmin(d) + # ) + + values = np.unique(example) + is_one_hot = len(values) == 2 and set(values) == {0, 1} + if not is_one_hot: + processors["score"] = lambda n, d, p: { + class_names[i]: d[i] for i in range(shape[0]) + } + elif ( + len(shape) == 1 + and shape[0] == 1 + and ( + isinstance(example[0], int) + or (hasattr(example, "tolist") and isinstance(example.tolist()[0], int)) # type: ignore + ) + ): + # assume this is a class + if class_labels_table is not None: + processors["class"] = ( + lambda n, d, p: class_labels_table.index_ref(d[0]) + if d[0] < len(class_labels_table.data) + else d[0] + ) # type: ignore + else: + processors["val"] = lambda n, d, p: d[0] + elif len(shape) == 1: + np = wandb.util.get_module( + "numpy", + required="Inferring processors require numpy", + ) + # This could be anything + if shape[0] <= 10: + # if less than 10, fan out the results + # processors["node"] = lambda n, d, p: {i: d[i] for i in range(shape[0])} + processors["node"] = lambda n, d, p: 
[ + d[i].tolist() if hasattr(d[i], "tolist") else d[i] + for i in range(shape[0]) + ] + # just report the argmax and argmin + processors["argmax"] = lambda n, d, p: np.argmax(d) + + values = np.unique(example) + is_one_hot = len(values) == 2 and set(values) == {0, 1} + if not is_one_hot: + processors["argmin"] = lambda n, d, p: np.argmin(d) + elif len(shape) == 2 and CAN_INFER_IMAGE_AND_VIDEO: + if ( + class_labels_table is not None + and possible_base_example is not None + and shape == _get_example_shape(possible_base_example) + ): + # consider this a segmentation mask + processors["image"] = lambda n, d, p: wandb.Image( + p, + masks={ + "masks": { + "mask_data": d, + "class_labels": class_labels_table.get_column("label"), # type: ignore + } + }, + ) + else: + # consider this a 2d image + processors["image"] = lambda n, d, p: wandb.Image(d) + elif len(shape) == 3 and CAN_INFER_IMAGE_AND_VIDEO: + # consider this an image + processors["image"] = lambda n, d, p: wandb.Image(d) + elif len(shape) == 4 and CAN_INFER_IMAGE_AND_VIDEO: + # consider this a video + processors["video"] = lambda n, d, p: wandb.Video(d) + + return processors + + +def _infer_validation_row_processor( + example_input: Union[Dict, Sequence], + example_target: Union[Dict, Sequence, Any], + class_labels_table: Optional["wandb.Table"] = None, + input_col_name: str = "input", + target_col_name: str = "target", +) -> Callable: + """Infers the composite processor for the validation data.""" + single_processors = {} + if isinstance(example_input, dict): + for key in example_input: + key_processors = _infer_single_example_keyed_processor(example_input[key]) + for p_key in key_processors: + single_processors[f"{key}:{p_key}"] = _bind( + lambda ndx, row, key_processor, key: key_processor( + ndx, + row[key], + None, + ), + key_processor=key_processors[p_key], + key=key, + ) + else: + key = input_col_name + key_processors = _infer_single_example_keyed_processor(example_input) + for p_key in key_processors: + 
def _infer_prediction_row_processor(
    example_prediction: Union[Dict, Sequence],
    example_input: Union[Dict, Sequence],
    class_labels_table: Optional["wandb.Table"] = None,
    input_col_name: str = "input",
    output_col_name: str = "output",
    val_ndx_col_name: str = "val_row",
) -> Callable:
    """Infer the composite row processor for the prediction output data.

    Args:
        example_prediction: One example prediction, or a dict mapping column
            names to one example each.
        example_input: One example input (or a dict of them). When it is not a
            dict it is passed on as the possible base example for a non-dict
            prediction.
        class_labels_table: Optional class-label table forwarded to the
            per-column inference.
        input_col_name: Column of the linked validation row that holds the
            input. Only read when `example_input` is not a dict.
        output_col_name: Column name used when `example_prediction` is not a
            dict.
        val_ndx_col_name: Column of the prediction row that links to the
            validation row. Defaults to "val_row", the value that was
            previously hard-coded; pass the name actually used for the
            prediction table if it differs.

    Returns:
        A callable `processor(ndx, row)` that returns a dict keyed by
        "<column>:<processor name>".
    """
    single_processors: Dict[str, Callable] = {}
    # Constant for the lifetime of the processor, so decide once, not per row.
    input_is_dict = isinstance(example_input, dict)

    def _without_base(ndx, row, key_processor, key):
        # No base example is available for this column.
        return key_processor(ndx, row[key], None)

    def _with_linked_input(ndx, row, key_processor, key):
        value = row[key]
        # Follow prediction row -> validation row to fetch the original input,
        # which is handed to the column processor as its base example.
        linked_input = (
            ndx.get_row().get(val_ndx_col_name).get_row().get(input_col_name)
        )
        return key_processor(ndx, value, linked_input)

    if isinstance(example_prediction, dict):
        for key, example in example_prediction.items():
            key_processors = _infer_single_example_keyed_processor(
                example, class_labels_table
            )
            for p_key, key_processor in key_processors.items():
                single_processors[f"{key}:{p_key}"] = _bind(
                    _without_base,
                    key_processor=key_processor,
                    key=key,
                )
    else:
        key = output_col_name
        key_processors = _infer_single_example_keyed_processor(
            example_prediction,
            class_labels_table,
            None if input_is_dict else example_input,
        )
        row_fn = _without_base if input_is_dict else _with_linked_input
        for p_key, key_processor in key_processors.items():
            single_processors[f"{key}:{p_key}"] = _bind(
                row_fn,
                key_processor=key_processor,
                key=key,
            )

    def processor(ndx, row):
        return {name: fn(ndx, row) for name, fn in single_processors.items()}

    return processor
0000000000000000000000000000000000000000..fd3588718844e918cdb53c0d1ccd8648eaafe219 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/__pycache__/interface_queue.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/interface/__pycache__/interface_shared.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/__pycache__/interface_shared.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1246cd327e521001f966e956f400633f2ccd24ff Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/__pycache__/interface_shared.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/interface/__pycache__/interface_sock.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/__pycache__/interface_sock.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44e42ce026c05924da08bac334f5cc9546f51043 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/__pycache__/interface_sock.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/interface/__pycache__/summary_record.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/__pycache__/summary_record.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15d80336fc3d8a017f539c4907f97d33490d4c9d Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/__pycache__/summary_record.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/interface/constants.py b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..09fe1e26f0ce72e23df8b627ff15e332f6ebc4e6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/constants.py @@ -0,0 +1,4 @@ +# +NOTIFY_PROCESS = 1 +NOTIFY_SHUTDOWN = 2 
+NOTIFY_REQUEST = 3 diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/interface/interface.py b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..90d60c92cdad26fd527cfb9e4c9620843bc107e7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/interface.py @@ -0,0 +1,1086 @@ +from __future__ import annotations + +import abc +import gzip +import logging +import time +from pathlib import Path +from secrets import token_hex +from typing import TYPE_CHECKING, Any, Iterable + +from wandb import termwarn +from wandb.proto import wandb_internal_pb2 as pb +from wandb.proto import wandb_telemetry_pb2 as tpb +from wandb.sdk.lib import json_util as json +from wandb.sdk.lib.filesystem import FilesDict, PolicyName +from wandb.sdk.mailbox import HandleAbandonedError, MailboxHandle +from wandb.util import ( + WandBJSONEncoderOld, + get_h5_typename, + json_dumps_safer, + json_dumps_safer_history, + json_friendly, + json_friendly_val, + maybe_compress_summary, +) + +from ..data_types.utils import history_dict_to_json, val_to_json +from . 
def file_policy_to_enum(policy: PolicyName) -> pb.FilesItem.PolicyType.V:
    """Translate a file policy name into its protobuf enum value.

    Args:
        policy: One of "now", "end" or "live".

    Returns:
        The matching `pb.FilesItem.PolicyType` member.

    Raises:
        ValueError: If `policy` is not one of the three known names.
    """
    if policy == "now":
        return pb.FilesItem.PolicyType.NOW
    if policy == "end":
        return pb.FilesItem.PolicyType.END
    if policy == "live":
        return pb.FilesItem.PolicyType.LIVE
    # Fail with a clear message instead of an UnboundLocalError on return.
    raise ValueError(f"Unknown file policy: {policy!r}")


def file_enum_to_policy(enum: pb.FilesItem.PolicyType.V) -> PolicyName:
    """Translate a protobuf file-policy enum value into its policy name.

    Args:
        enum: One of the `pb.FilesItem.PolicyType` members NOW, END or LIVE.

    Returns:
        "now", "end" or "live" respectively.

    Raises:
        ValueError: If `enum` is not one of the three known members.
    """
    if enum == pb.FilesItem.PolicyType.NOW:
        return "now"
    if enum == pb.FilesItem.PolicyType.END:
        return "end"
    if enum == pb.FilesItem.PolicyType.LIVE:
        return "live"
    # Fail with a clear message instead of an UnboundLocalError on return.
    raise ValueError(f"Unknown file policy enum value: {enum!r}")
+ """ + + def publish_header(self) -> None: + header = pb.HeaderRecord() + self._publish_header(header) + + @abc.abstractmethod + def _publish_header(self, header: pb.HeaderRecord) -> None: + raise NotImplementedError + + def deliver_status(self) -> MailboxHandle[pb.Result]: + return self._deliver_status(pb.StatusRequest()) + + @abc.abstractmethod + def _deliver_status( + self, + status: pb.StatusRequest, + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + def _make_config( + self, + data: dict | None = None, + key: tuple[str, ...] | str | None = None, + val: Any | None = None, + obj: pb.ConfigRecord | None = None, + ) -> pb.ConfigRecord: + config = obj or pb.ConfigRecord() + if data: + for k, v in data.items(): + update = config.update.add() + update.key = k + update.value_json = json_dumps_safer(json_friendly(v)[0]) + if key: + update = config.update.add() + if isinstance(key, tuple): + for k in key: + update.nested_key.append(k) + else: + update.key = key + update.value_json = json_dumps_safer(json_friendly(val)[0]) + return config + + def _make_run(self, run: Run) -> pb.RunRecord: # noqa: C901 + proto_run = pb.RunRecord() + if run._settings.entity is not None: + proto_run.entity = run._settings.entity + if run._settings.project is not None: + proto_run.project = run._settings.project + if run._settings.run_group is not None: + proto_run.run_group = run._settings.run_group + if run._settings.run_job_type is not None: + proto_run.job_type = run._settings.run_job_type + if run._settings.run_id is not None: + proto_run.run_id = run._settings.run_id + if run._settings.run_name is not None: + proto_run.display_name = run._settings.run_name + if run._settings.run_notes is not None: + proto_run.notes = run._settings.run_notes + if run._settings.run_tags is not None: + proto_run.tags.extend(run._settings.run_tags) + if run._start_time is not None: + proto_run.start_time.FromMicroseconds(int(run._start_time * 1e6)) + if run._starting_step is not None: + 
proto_run.starting_step = run._starting_step + if run._settings.git_remote_url is not None: + proto_run.git.remote_url = run._settings.git_remote_url + if run._settings.git_commit is not None: + proto_run.git.commit = run._settings.git_commit + if run._settings.sweep_id is not None: + proto_run.sweep_id = run._settings.sweep_id + if run._settings.host: + proto_run.host = run._settings.host + if run._settings.resumed: + proto_run.resumed = run._settings.resumed + if run._settings.fork_from: + run_moment = run._settings.fork_from + proto_run.branch_point.run = run_moment.run + proto_run.branch_point.metric = run_moment.metric + proto_run.branch_point.value = run_moment.value + if run._settings.resume_from: + run_moment = run._settings.resume_from + proto_run.branch_point.run = run_moment.run + proto_run.branch_point.metric = run_moment.metric + proto_run.branch_point.value = run_moment.value + if run._forked: + proto_run.forked = run._forked + if run._config is not None: + config_dict = run._config._as_dict() # type: ignore + self._make_config(data=config_dict, obj=proto_run.config) + if run._telemetry_obj: + proto_run.telemetry.MergeFrom(run._telemetry_obj) + if run._start_runtime: + proto_run.runtime = run._start_runtime + return proto_run + + def publish_run(self, run: Run) -> None: + run_record = self._make_run(run) + self._publish_run(run_record) + + @abc.abstractmethod + def _publish_run(self, run: pb.RunRecord) -> None: + raise NotImplementedError + + def publish_cancel(self, cancel_slot: str) -> None: + cancel = pb.CancelRequest(cancel_slot=cancel_slot) + self._publish_cancel(cancel) + + @abc.abstractmethod + def _publish_cancel(self, cancel: pb.CancelRequest) -> None: + raise NotImplementedError + + def publish_config( + self, + data: dict | None = None, + key: tuple[str, ...] 
| str | None = None, + val: Any | None = None, + ) -> None: + cfg = self._make_config(data=data, key=key, val=val) + + self._publish_config(cfg) + + @abc.abstractmethod + def _publish_config(self, cfg: pb.ConfigRecord) -> None: + raise NotImplementedError + + @abc.abstractmethod + def _publish_metric(self, metric: pb.MetricRecord) -> None: + raise NotImplementedError + + def _make_summary_from_dict(self, summary_dict: dict) -> pb.SummaryRecord: + summary = pb.SummaryRecord() + for k, v in summary_dict.items(): + update = summary.update.add() + update.key = k + update.value_json = json.dumps(v) + return summary + + def _summary_encode( + self, + value: Any, + path_from_root: str, + run: Run, + ) -> dict: + """Normalize, compress, and encode sub-objects for backend storage. + + value: Object to encode. + path_from_root: `str` dot separated string from the top-level summary to the + current `value`. + + Returns: + A new tree of dict's with large objects replaced with dictionaries + with "_type" entries that say which type the original data was. + """ + # Constructs a new `dict` tree in `json_value` that discards and/or + # encodes objects that aren't JSON serializable. + + if isinstance(value, dict): + json_value = {} + for key, value in value.items(): # noqa: B020 + json_value[key] = self._summary_encode( + value, + path_from_root + "." 
+ key, + run=run, + ) + return json_value + else: + friendly_value, converted = json_friendly( + val_to_json(run, path_from_root, value, namespace="summary") + ) + json_value, compressed = maybe_compress_summary( + friendly_value, get_h5_typename(value) + ) + if compressed: + # TODO(jhr): impleement me + pass + # self.write_h5(path_from_root, friendly_value) + + return json_value + + def _make_summary( + self, + summary_record: sr.SummaryRecord, + run: Run, + ) -> pb.SummaryRecord: + pb_summary_record = pb.SummaryRecord() + + for item in summary_record.update: + pb_summary_item = pb_summary_record.update.add() + key_length = len(item.key) + + assert key_length > 0 + + if key_length > 1: + pb_summary_item.nested_key.extend(item.key) + else: + pb_summary_item.key = item.key[0] + + path_from_root = ".".join(item.key) + json_value = self._summary_encode( + item.value, + path_from_root, + run=run, + ) + json_value, _ = json_friendly(json_value) # type: ignore + + pb_summary_item.value_json = json.dumps( + json_value, + cls=WandBJSONEncoderOld, + ) + + for item in summary_record.remove: + pb_summary_item = pb_summary_record.remove.add() + key_length = len(item.key) + + assert key_length > 0 + + if key_length > 1: + pb_summary_item.nested_key.extend(item.key) + else: + pb_summary_item.key = item.key[0] + + return pb_summary_record + + def publish_summary( + self, + run: Run, + summary_record: sr.SummaryRecord, + ) -> None: + pb_summary_record = self._make_summary(summary_record, run=run) + self._publish_summary(pb_summary_record) + + @abc.abstractmethod + def _publish_summary(self, summary: pb.SummaryRecord) -> None: + raise NotImplementedError + + def _make_files(self, files_dict: FilesDict) -> pb.FilesRecord: + files = pb.FilesRecord() + for path, policy in files_dict["files"]: + f = files.files.add() + f.path = path + f.policy = file_policy_to_enum(policy) + return files + + def publish_files(self, files_dict: FilesDict) -> None: + files = self._make_files(files_dict) 
+ self._publish_files(files) + + @abc.abstractmethod + def _publish_files(self, files: pb.FilesRecord) -> None: + raise NotImplementedError + + def publish_python_packages(self, working_set) -> None: + python_packages = pb.PythonPackagesRequest() + for pkg in working_set: + python_packages.package.add(name=pkg.key, version=pkg.version) + self._publish_python_packages(python_packages) + + @abc.abstractmethod + def _publish_python_packages( + self, python_packages: pb.PythonPackagesRequest + ) -> None: + raise NotImplementedError + + def _make_artifact(self, artifact: Artifact) -> pb.ArtifactRecord: + proto_artifact = pb.ArtifactRecord() + proto_artifact.type = artifact.type + proto_artifact.name = artifact.name + proto_artifact.client_id = artifact._client_id + proto_artifact.sequence_client_id = artifact._sequence_client_id + proto_artifact.digest = artifact.digest + if artifact.distributed_id: + proto_artifact.distributed_id = artifact.distributed_id + if artifact.description: + proto_artifact.description = artifact.description + if artifact.metadata: + proto_artifact.metadata = json.dumps(json_friendly_val(artifact.metadata)) + if artifact._base_id: + proto_artifact.base_id = artifact._base_id + + ttl_duration_input = artifact._ttl_duration_seconds_to_gql() + if ttl_duration_input: + proto_artifact.ttl_duration_seconds = ttl_duration_input + proto_artifact.incremental_beta1 = artifact.incremental + self._make_artifact_manifest(artifact.manifest, obj=proto_artifact.manifest) + return proto_artifact + + def _make_artifact_manifest( + self, + artifact_manifest: ArtifactManifest, + obj: pb.ArtifactManifest | None = None, + ) -> pb.ArtifactManifest: + proto_manifest = obj or pb.ArtifactManifest() + proto_manifest.version = artifact_manifest.version() + proto_manifest.storage_policy = artifact_manifest.storage_policy.name() + + # Very large manifests need to be written to file to avoid protobuf size limits. 
+ if len(artifact_manifest) > MANIFEST_FILE_SIZE_THRESHOLD: + path = self._write_artifact_manifest_file(artifact_manifest) + proto_manifest.manifest_file_path = path + return proto_manifest + + # Set storage policy on storageLayout (always V2) and storageRegion, only allow coreweave-us on wandb.ai for now. + # NOTE: the decode logic is NewManifestFromProto in core/pkg/artifacts/manifest.go + # The creation logic is in artifacts/_factories.py make_storage_policy + for k, v in artifact_manifest.storage_policy.config().items() or {}.items(): + cfg = proto_manifest.storage_policy_config.add() + cfg.key = k + # TODO: Why json.dumps when existing values are plain string? We want to send complex structure without defining the proto? + cfg.value_json = json.dumps(v) + + for entry in sorted(artifact_manifest.entries.values(), key=lambda k: k.path): + proto_entry = proto_manifest.contents.add() + proto_entry.path = entry.path + proto_entry.digest = entry.digest + if entry.size: + proto_entry.size = entry.size + if entry.birth_artifact_id: + proto_entry.birth_artifact_id = entry.birth_artifact_id + if entry.ref: + proto_entry.ref = entry.ref + if entry.local_path: + proto_entry.local_path = entry.local_path + proto_entry.skip_cache = entry.skip_cache + for k, v in entry.extra.items(): + proto_extra = proto_entry.extra.add() + proto_extra.key = k + proto_extra.value_json = json.dumps(v) + return proto_manifest + + def _write_artifact_manifest_file(self, manifest: ArtifactManifest) -> str: + from wandb.sdk.artifacts.staging import get_staging_dir + + manifest_dir = Path(get_staging_dir()) / "artifact_manifests" + manifest_dir.mkdir(parents=True, exist_ok=True) + # It would be simpler to use `manifest.to_json()`, but that gets very slow for + # large manifests since it encodes the whole thing as a single JSON object. 
+ filename = f"{time.time()}_{token_hex(8)}.manifest_contents.jl.gz" + manifest_file_path = manifest_dir / filename + with gzip.open(manifest_file_path, mode="wt", compresslevel=1) as f: + for entry in manifest.entries.values(): + f.write(f"{json.dumps(entry.to_json())}\n") + return str(manifest_file_path) + + def deliver_link_artifact( + self, + artifact: Artifact, + portfolio_name: str, + aliases: Iterable[str], + entity: str | None = None, + project: str | None = None, + organization: str | None = None, + ) -> MailboxHandle[pb.Result]: + link_artifact = pb.LinkArtifactRequest() + if artifact.is_draft(): + link_artifact.client_id = artifact._client_id + else: + link_artifact.server_id = artifact.id if artifact.id else "" + link_artifact.portfolio_name = portfolio_name + link_artifact.portfolio_entity = entity or "" + link_artifact.portfolio_organization = organization or "" + link_artifact.portfolio_project = project or "" + link_artifact.portfolio_aliases.extend(aliases) + + return self._deliver_link_artifact(link_artifact) + + @abc.abstractmethod + def _deliver_link_artifact( + self, link_artifact: pb.LinkArtifactRequest + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + @staticmethod + def _make_partial_source_str( + source: Any, job_info: dict[str, Any], metadata: dict[str, Any] + ) -> str: + """Construct use_artifact.partial.source_info.source as str.""" + source_type = job_info.get("source_type", "").strip() + if source_type == "artifact": + info_source = job_info.get("source", {}) + source.artifact.artifact = info_source.get("artifact", "") + source.artifact.entrypoint.extend(info_source.get("entrypoint", [])) + source.artifact.notebook = info_source.get("notebook", False) + build_context = info_source.get("build_context") + if build_context: + source.artifact.build_context = build_context + dockerfile = info_source.get("dockerfile") + if dockerfile: + source.artifact.dockerfile = dockerfile + elif source_type == "repo": + 
source.git.git_info.remote = metadata.get("git", {}).get("remote", "") + source.git.git_info.commit = metadata.get("git", {}).get("commit", "") + source.git.entrypoint.extend(metadata.get("entrypoint", [])) + source.git.notebook = metadata.get("notebook", False) + build_context = metadata.get("build_context") + if build_context: + source.git.build_context = build_context + dockerfile = metadata.get("dockerfile") + if dockerfile: + source.git.dockerfile = dockerfile + elif source_type == "image": + source.image.image = metadata.get("docker", "") + else: + raise ValueError("Invalid source type") + + source_str: str = source.SerializeToString() + return source_str + + def _make_proto_use_artifact( + self, + use_artifact: pb.UseArtifactRecord, + job_name: str, + job_info: dict[str, Any], + metadata: dict[str, Any], + ) -> pb.UseArtifactRecord: + use_artifact.partial.job_name = job_name + use_artifact.partial.source_info._version = job_info.get("_version", "") + use_artifact.partial.source_info.source_type = job_info.get("source_type", "") + use_artifact.partial.source_info.runtime = job_info.get("runtime", "") + + src_str = self._make_partial_source_str( + source=use_artifact.partial.source_info.source, + job_info=job_info, + metadata=metadata, + ) + use_artifact.partial.source_info.source.ParseFromString(src_str) # type: ignore[arg-type] + + return use_artifact + + def publish_use_artifact( + self, + artifact: Artifact, + ) -> None: + assert artifact.id is not None, "Artifact must have an id" + + use_artifact = pb.UseArtifactRecord( + id=artifact.id, + type=artifact.type, + name=artifact.name, + ) + + # TODO(gst): move to internal process + if "_partial" in artifact.metadata: + # Download source info from logged partial job artifact + job_info = {} + try: + path = artifact.get_entry("wandb-job.json").download() + with open(path) as f: + job_info = json.load(f) + + except Exception as e: + logger.warning( + f"Failed to download partial job info from artifact 
{artifact}, : {e}" + ) + termwarn( + f"Failed to download partial job info from artifact {artifact}, : {e}" + ) + return + + try: + use_artifact = self._make_proto_use_artifact( + use_artifact=use_artifact, + job_name=artifact.name, + job_info=job_info, + metadata=artifact.metadata, + ) + except Exception as e: + logger.warning(f"Failed to construct use artifact proto: {e}") + termwarn(f"Failed to construct use artifact proto: {e}") + return + + self._publish_use_artifact(use_artifact) + + @abc.abstractmethod + def _publish_use_artifact(self, proto_artifact: pb.UseArtifactRecord) -> None: + raise NotImplementedError + + def deliver_artifact( + self, + run: Run, + artifact: Artifact, + aliases: Iterable[str], + tags: Iterable[str] | None = None, + history_step: int | None = None, + is_user_created: bool = False, + use_after_commit: bool = False, + finalize: bool = True, + ) -> MailboxHandle[pb.Result]: + from wandb.sdk.artifacts.staging import get_staging_dir + + proto_run = self._make_run(run) + proto_artifact = self._make_artifact(artifact) + proto_artifact.run_id = proto_run.run_id + proto_artifact.project = proto_run.project + proto_artifact.entity = proto_run.entity + proto_artifact.user_created = is_user_created + proto_artifact.use_after_commit = use_after_commit + proto_artifact.finalize = finalize + + proto_artifact.aliases.extend(aliases or []) + proto_artifact.tags.extend(tags or []) + + log_artifact = pb.LogArtifactRequest() + log_artifact.artifact.CopyFrom(proto_artifact) + if history_step is not None: + log_artifact.history_step = history_step + log_artifact.staging_dir = get_staging_dir() + resp = self._deliver_artifact(log_artifact) + return resp + + @abc.abstractmethod + def _deliver_artifact( + self, + log_artifact: pb.LogArtifactRequest, + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + def deliver_download_artifact( + self, + artifact_id: str, + download_root: str, + allow_missing_references: bool, + skip_cache: bool, + 
path_prefix: str | None, + ) -> MailboxHandle[pb.Result]: + download_artifact = pb.DownloadArtifactRequest() + download_artifact.artifact_id = artifact_id + download_artifact.download_root = download_root + download_artifact.allow_missing_references = allow_missing_references + download_artifact.skip_cache = skip_cache + download_artifact.path_prefix = path_prefix or "" + resp = self._deliver_download_artifact(download_artifact) + return resp + + @abc.abstractmethod + def _deliver_download_artifact( + self, download_artifact: pb.DownloadArtifactRequest + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + def publish_artifact( + self, + run: Run, + artifact: Artifact, + aliases: Iterable[str], + tags: Iterable[str] | None = None, + is_user_created: bool = False, + use_after_commit: bool = False, + finalize: bool = True, + ) -> None: + proto_run = self._make_run(run) + proto_artifact = self._make_artifact(artifact) + proto_artifact.run_id = proto_run.run_id + proto_artifact.project = proto_run.project + proto_artifact.entity = proto_run.entity + proto_artifact.user_created = is_user_created + proto_artifact.use_after_commit = use_after_commit + proto_artifact.finalize = finalize + proto_artifact.aliases.extend(aliases or []) + proto_artifact.tags.extend(tags or []) + self._publish_artifact(proto_artifact) + + @abc.abstractmethod + def _publish_artifact(self, proto_artifact: pb.ArtifactRecord) -> None: + raise NotImplementedError + + def publish_tbdata(self, log_dir: str, save: bool, root_logdir: str = "") -> None: + tbrecord = pb.TBRecord() + tbrecord.log_dir = log_dir + tbrecord.save = save + tbrecord.root_dir = root_logdir + self._publish_tbdata(tbrecord) + + @abc.abstractmethod + def _publish_tbdata(self, tbrecord: pb.TBRecord) -> None: + raise NotImplementedError + + @abc.abstractmethod + def _publish_telemetry(self, telem: tpb.TelemetryRecord) -> None: + raise NotImplementedError + + def publish_environment(self, environment: pb.EnvironmentRecord) -> 
None: + self._publish_environment(environment) + + @abc.abstractmethod + def _publish_environment(self, environment: pb.EnvironmentRecord) -> None: + raise NotImplementedError + + def publish_partial_history( + self, + run: Run, + data: dict, + user_step: int, + step: int | None = None, + flush: bool | None = None, + publish_step: bool = True, + ) -> None: + data = history_dict_to_json(run, data, step=user_step, ignore_copy_err=True) + data.pop("_step", None) + + # add timestamp to the history request, if not already present + # the timestamp might come from the tensorboard log logic + if "_timestamp" not in data: + data["_timestamp"] = time.time() + + partial_history = pb.PartialHistoryRequest() + for k, v in data.items(): + item = partial_history.item.add() + item.key = k + item.value_json = json_dumps_safer_history(v) + + if publish_step and step is not None: + partial_history.step.num = step + if flush is not None: + partial_history.action.flush = flush + self._publish_partial_history(partial_history) + + @abc.abstractmethod + def _publish_partial_history(self, history: pb.PartialHistoryRequest) -> None: + raise NotImplementedError + + def publish_history( + self, + run: Run, + data: dict, + step: int | None = None, + publish_step: bool = True, + ) -> None: + data = history_dict_to_json(run, data, step=step) + history = pb.HistoryRecord() + if publish_step: + assert step is not None + history.step.num = step + data.pop("_step", None) + for k, v in data.items(): + item = history.item.add() + item.key = k + item.value_json = json_dumps_safer_history(v) + self._publish_history(history) + + @abc.abstractmethod + def _publish_history(self, history: pb.HistoryRecord) -> None: + raise NotImplementedError + + def publish_preempting(self) -> None: + preempt_rec = pb.RunPreemptingRecord() + self._publish_preempting(preempt_rec) + + @abc.abstractmethod + def _publish_preempting(self, preempt_rec: pb.RunPreemptingRecord) -> None: + raise NotImplementedError + + def 
publish_output( + self, + name: str, + data: str, + *, + nowait: bool = False, + ) -> None: + # from vendor.protobuf import google3.protobuf.timestamp + # ts = timestamp.Timestamp() + # ts.GetCurrentTime() + # now = datetime.now() + if name == "stdout": + otype = pb.OutputRecord.OutputType.STDOUT + elif name == "stderr": + otype = pb.OutputRecord.OutputType.STDERR + else: + # TODO(jhr): throw error? + termwarn("unknown type") + o = pb.OutputRecord(output_type=otype, line=data) + o.timestamp.GetCurrentTime() + self._publish_output(o, nowait=nowait) + + @abc.abstractmethod + def _publish_output(self, outdata: pb.OutputRecord, *, nowait: bool) -> None: + raise NotImplementedError + + def publish_output_raw( + self, + name: str, + data: str, + *, + nowait: bool = False, + ) -> None: + # from vendor.protobuf import google3.protobuf.timestamp + # ts = timestamp.Timestamp() + # ts.GetCurrentTime() + # now = datetime.now() + if name == "stdout": + otype = pb.OutputRawRecord.OutputType.STDOUT + elif name == "stderr": + otype = pb.OutputRawRecord.OutputType.STDERR + else: + # TODO(jhr): throw error? 
+ termwarn("unknown type") + o = pb.OutputRawRecord(output_type=otype, line=data) + o.timestamp.GetCurrentTime() + self._publish_output_raw(o, nowait=nowait) + + @abc.abstractmethod + def _publish_output_raw( + self, + outdata: pb.OutputRawRecord, + *, + nowait: bool, + ) -> None: + raise NotImplementedError + + def publish_pause(self) -> None: + pause = pb.PauseRequest() + self._publish_pause(pause) + + @abc.abstractmethod + def _publish_pause(self, pause: pb.PauseRequest) -> None: + raise NotImplementedError + + def publish_resume(self) -> None: + resume = pb.ResumeRequest() + self._publish_resume(resume) + + @abc.abstractmethod + def _publish_resume(self, resume: pb.ResumeRequest) -> None: + raise NotImplementedError + + def publish_alert( + self, title: str, text: str, level: str, wait_duration: int + ) -> None: + proto_alert = pb.AlertRecord() + proto_alert.title = title + proto_alert.text = text + proto_alert.level = level + proto_alert.wait_duration = wait_duration + self._publish_alert(proto_alert) + + @abc.abstractmethod + def _publish_alert(self, alert: pb.AlertRecord) -> None: + raise NotImplementedError + + def _make_exit(self, exit_code: int | None) -> pb.RunExitRecord: + exit = pb.RunExitRecord() + if exit_code is not None: + exit.exit_code = exit_code + return exit + + def publish_exit(self, exit_code: int | None) -> None: + exit_data = self._make_exit(exit_code) + self._publish_exit(exit_data) + + @abc.abstractmethod + def _publish_exit(self, exit_data: pb.RunExitRecord) -> None: + raise NotImplementedError + + def publish_keepalive(self) -> None: + keepalive = pb.KeepaliveRequest() + self._publish_keepalive(keepalive) + + @abc.abstractmethod + def _publish_keepalive(self, keepalive: pb.KeepaliveRequest) -> None: + raise NotImplementedError + + def publish_job_input( + self, + include_paths: list[list[str]], + exclude_paths: list[list[str]], + input_schema: dict | None, + run_config: bool = False, + file_path: str = "", + ): + """Publishes a request 
to add inputs to the job. + + If run_config is True, the wandb.config will be added as a job input. + If file_path is provided, the file at file_path will be added as a job + input. + + The paths provided as arguments are sequences of dictionary keys that + specify a path within the wandb.config. If a path is included, the + corresponding field will be treated as a job input. If a path is + excluded, the corresponding field will not be treated as a job input. + + Args: + include_paths: paths within config to include as job inputs. + exclude_paths: paths within config to exclude as job inputs. + input_schema: A JSON Schema describing which attributes will be + editable from the Launch drawer. + run_config: bool indicating whether wandb.config is the input source. + file_path: path to file to include as a job input. + """ + if run_config and file_path: + raise ValueError( + "run_config and file_path are mutually exclusive arguments." + ) + request = pb.JobInputRequest() + include_records = [pb.JobInputPath(path=path) for path in include_paths] + exclude_records = [pb.JobInputPath(path=path) for path in exclude_paths] + request.include_paths.extend(include_records) + request.exclude_paths.extend(exclude_records) + source = pb.JobInputSource( + run_config=pb.JobInputSource.RunConfigSource(), + ) + if run_config: + source.run_config.CopyFrom(pb.JobInputSource.RunConfigSource()) + else: + source.file.CopyFrom( + pb.JobInputSource.ConfigFileSource(path=file_path), + ) + request.input_source.CopyFrom(source) + if input_schema: + request.input_schema = json_dumps_safer(input_schema) + + return self._publish_job_input(request) + + @abc.abstractmethod + def _publish_job_input( + self, request: pb.JobInputRequest + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + def publish_probe_system_info(self) -> None: + probe_system_info = pb.ProbeSystemInfoRequest() + return self._publish_probe_system_info(probe_system_info) + + @abc.abstractmethod + def 
_publish_probe_system_info( + self, probe_system_info: pb.ProbeSystemInfoRequest + ) -> None: + raise NotImplementedError + + def join(self) -> None: + # Drop indicates that the internal process has already been shutdown + if self._drop: + return + + handle = self._deliver_shutdown() + + try: + handle.wait_or(timeout=30) + except TimeoutError: + # This can happen if the server fails to respond due to a bug + # or due to being very busy. + logger.warning("timed out communicating shutdown") + except HandleAbandonedError: + # This can happen if the connection to the server is closed + # before a response is read. + logger.warning("handle abandoned while communicating shutdown") + + @abc.abstractmethod + def _deliver_shutdown(self) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + def deliver_run(self, run: Run) -> MailboxHandle[pb.Result]: + run_record = self._make_run(run) + return self._deliver_run(run_record) + + def deliver_finish_sync( + self, + ) -> MailboxHandle[pb.Result]: + sync = pb.SyncFinishRequest() + return self._deliver_finish_sync(sync) + + @abc.abstractmethod + def _deliver_finish_sync( + self, sync: pb.SyncFinishRequest + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + @abc.abstractmethod + def _deliver_run(self, run: pb.RunRecord) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + def deliver_run_start(self, run: Run) -> MailboxHandle[pb.Result]: + run_start = pb.RunStartRequest(run=self._make_run(run)) + return self._deliver_run_start(run_start) + + @abc.abstractmethod + def _deliver_run_start( + self, run_start: pb.RunStartRequest + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + def deliver_attach(self, attach_id: str) -> MailboxHandle[pb.Result]: + attach = pb.AttachRequest(attach_id=attach_id) + return self._deliver_attach(attach) + + @abc.abstractmethod + def _deliver_attach( + self, + status: pb.AttachRequest, + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + def 
deliver_stop_status(self) -> MailboxHandle[pb.Result]: + status = pb.StopStatusRequest() + return self._deliver_stop_status(status) + + @abc.abstractmethod + def _deliver_stop_status( + self, + status: pb.StopStatusRequest, + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + def deliver_network_status(self) -> MailboxHandle[pb.Result]: + status = pb.NetworkStatusRequest() + return self._deliver_network_status(status) + + @abc.abstractmethod + def _deliver_network_status( + self, + status: pb.NetworkStatusRequest, + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + def deliver_internal_messages(self) -> MailboxHandle[pb.Result]: + internal_message = pb.InternalMessagesRequest() + return self._deliver_internal_messages(internal_message) + + @abc.abstractmethod + def _deliver_internal_messages( + self, internal_message: pb.InternalMessagesRequest + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + def deliver_get_summary(self) -> MailboxHandle[pb.Result]: + get_summary = pb.GetSummaryRequest() + return self._deliver_get_summary(get_summary) + + @abc.abstractmethod + def _deliver_get_summary( + self, + get_summary: pb.GetSummaryRequest, + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + def deliver_get_system_metrics(self) -> MailboxHandle[pb.Result]: + get_system_metrics = pb.GetSystemMetricsRequest() + return self._deliver_get_system_metrics(get_system_metrics) + + @abc.abstractmethod + def _deliver_get_system_metrics( + self, get_summary: pb.GetSystemMetricsRequest + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + def deliver_exit(self, exit_code: int | None) -> MailboxHandle[pb.Result]: + exit_data = self._make_exit(exit_code) + return self._deliver_exit(exit_data) + + @abc.abstractmethod + def _deliver_exit( + self, + exit_data: pb.RunExitRecord, + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + def deliver_poll_exit(self) -> MailboxHandle[pb.Result]: + poll_exit = 
pb.PollExitRequest() + return self._deliver_poll_exit(poll_exit) + + @abc.abstractmethod + def _deliver_poll_exit( + self, + poll_exit: pb.PollExitRequest, + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + def deliver_finish_without_exit(self) -> MailboxHandle[pb.Result]: + run_finish_without_exit = pb.RunFinishWithoutExitRequest() + return self._deliver_finish_without_exit(run_finish_without_exit) + + @abc.abstractmethod + def _deliver_finish_without_exit( + self, run_finish_without_exit: pb.RunFinishWithoutExitRequest + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + def deliver_request_sampled_history(self) -> MailboxHandle[pb.Result]: + sampled_history = pb.SampledHistoryRequest() + return self._deliver_request_sampled_history(sampled_history) + + @abc.abstractmethod + def _deliver_request_sampled_history( + self, sampled_history: pb.SampledHistoryRequest + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + def deliver_request_run_status(self) -> MailboxHandle[pb.Result]: + run_status = pb.RunStatusRequest() + return self._deliver_request_run_status(run_status) + + @abc.abstractmethod + def _deliver_request_run_status( + self, run_status: pb.RunStatusRequest + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/interface/interface_queue.py b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/interface_queue.py new file mode 100644 index 0000000000000000000000000000000000000000..8ab20b2db31d3bcd8c207aa208ddf9d233df206b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/interface_queue.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import logging +from multiprocessing.process import BaseProcess +from typing import TYPE_CHECKING + +from typing_extensions import override + +from .interface_shared import InterfaceShared + +if TYPE_CHECKING: + from queue import Queue + + from wandb.proto import wandb_internal_pb2 as pb + 
from wandb.sdk.mailbox.mailbox_handle import MailboxHandle + + +logger = logging.getLogger("wandb") + + +class InterfaceQueue(InterfaceShared): + """Legacy implementation of InterfaceShared. + + This was used by legacy-service to pass messages back to itself before + the existence of wandb-core. It may be removed once legacy-service is + completely removed (including its use in `wandb sync`). + + Since it was used by the internal service, it does not implement + the "deliver" methods, which are only used in the client. + """ + + def __init__( + self, + record_q: Queue[pb.Record] | None = None, + result_q: Queue[pb.Result] | None = None, + process: BaseProcess | None = None, + ) -> None: + self.record_q = record_q + self.result_q = result_q + self._process = process + super().__init__() + + @override + def _publish(self, record: pb.Record, *, nowait: bool = False) -> None: + if self._process and not self._process.is_alive(): + raise Exception("The wandb backend process has shutdown") + if self.record_q: + self.record_q.put(record) + + @override + async def deliver_async( + self, + record: pb.Record, + ) -> MailboxHandle[pb.Result]: + raise NotImplementedError + + @override + def _deliver(self, record: pb.Record) -> MailboxHandle[pb.Result]: + raise NotImplementedError diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/interface/interface_shared.py b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/interface_shared.py new file mode 100644 index 0000000000000000000000000000000000000000..74a839b1f0c174ef879b325c9837e86232fbd750 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/interface_shared.py @@ -0,0 +1,511 @@ +import abc +import logging +from typing import Any, Optional, cast + +from typing_extensions import override + +from wandb.proto import wandb_internal_pb2 as pb +from wandb.proto import wandb_telemetry_pb2 as tpb +from wandb.sdk.mailbox import MailboxHandle +from wandb.util import json_dumps_safer, json_friendly + +from 
.interface import InterfaceBase + +logger = logging.getLogger("wandb") + + +class InterfaceShared(InterfaceBase, abc.ABC): + """Partially implemented InterfaceBase. + + There is little reason for this to exist separately from InterfaceBase, + which itself is not a pure abstract class and has no other direct + subclasses. Most methods are implemented in this class in terms of the + protected _publish and _deliver methods defined by subclasses. + """ + + def __init__(self) -> None: + super().__init__() + + @abc.abstractmethod + def _publish( + self, + record: pb.Record, + *, + nowait: bool = False, + ) -> None: + """Send a record to the internal service. + + Args: + record: The record to send. This method assigns its stream ID. + nowait: If true, this does not block on socket IO and is safe + to call in W&B's asyncio thread, but it will also not slow + down even if the socket is blocked and allow data to accumulate + in the Python memory. + """ + + @abc.abstractmethod + def _deliver(self, record: pb.Record) -> "MailboxHandle[pb.Result]": + """Send a record to the internal service and return a response handle. + + Args: + record: The record to send. This method assigns its stream ID. + + Returns: + A mailbox handle for waiting for a response. 
+ """ + + @override + def _publish_output( + self, + outdata: pb.OutputRecord, + *, + nowait: bool = False, + ) -> None: + rec = pb.Record() + rec.output.CopyFrom(outdata) + self._publish(rec, nowait=nowait) + + @override + def _publish_output_raw( + self, + outdata: pb.OutputRawRecord, + *, + nowait: bool = False, + ) -> None: + rec = pb.Record() + rec.output_raw.CopyFrom(outdata) + self._publish(rec, nowait=nowait) + + def _publish_cancel(self, cancel: pb.CancelRequest) -> None: + rec = self._make_request(cancel=cancel) + self._publish(rec) + + def _publish_tbdata(self, tbrecord: pb.TBRecord) -> None: + rec = self._make_record(tbrecord=tbrecord) + self._publish(rec) + + def _publish_partial_history( + self, partial_history: pb.PartialHistoryRequest + ) -> None: + rec = self._make_request(partial_history=partial_history) + self._publish(rec) + + def _publish_history(self, history: pb.HistoryRecord) -> None: + rec = self._make_record(history=history) + self._publish(rec) + + def _publish_preempting(self, preempt_rec: pb.RunPreemptingRecord) -> None: + rec = self._make_record(preempting=preempt_rec) + self._publish(rec) + + def _publish_telemetry(self, telem: tpb.TelemetryRecord) -> None: + rec = self._make_record(telemetry=telem) + self._publish(rec) + + def _publish_environment(self, environment: pb.EnvironmentRecord) -> None: + rec = self._make_record(environment=environment) + self._publish(rec) + + def _publish_job_input( + self, job_input: pb.JobInputRequest + ) -> MailboxHandle[pb.Result]: + record = self._make_request(job_input=job_input) + return self._deliver(record) + + def _make_stats(self, stats_dict: dict) -> pb.StatsRecord: + stats = pb.StatsRecord() + stats.stats_type = pb.StatsRecord.StatsType.SYSTEM + stats.timestamp.GetCurrentTime() # todo: fix this, this is wrong :) + for k, v in stats_dict.items(): + item = stats.item.add() + item.key = k + item.value_json = json_dumps_safer(json_friendly(v)[0]) + return stats + + def _make_request( # noqa: C901 
+ self, + get_summary: Optional[pb.GetSummaryRequest] = None, + pause: Optional[pb.PauseRequest] = None, + resume: Optional[pb.ResumeRequest] = None, + status: Optional[pb.StatusRequest] = None, + stop_status: Optional[pb.StopStatusRequest] = None, + internal_messages: Optional[pb.InternalMessagesRequest] = None, + network_status: Optional[pb.NetworkStatusRequest] = None, + poll_exit: Optional[pb.PollExitRequest] = None, + partial_history: Optional[pb.PartialHistoryRequest] = None, + sampled_history: Optional[pb.SampledHistoryRequest] = None, + run_start: Optional[pb.RunStartRequest] = None, + check_version: Optional[pb.CheckVersionRequest] = None, + log_artifact: Optional[pb.LogArtifactRequest] = None, + download_artifact: Optional[pb.DownloadArtifactRequest] = None, + link_artifact: Optional[pb.LinkArtifactRequest] = None, + defer: Optional[pb.DeferRequest] = None, + attach: Optional[pb.AttachRequest] = None, + server_info: Optional[pb.ServerInfoRequest] = None, + keepalive: Optional[pb.KeepaliveRequest] = None, + run_status: Optional[pb.RunStatusRequest] = None, + sender_mark: Optional[pb.SenderMarkRequest] = None, + sender_read: Optional[pb.SenderReadRequest] = None, + sync_finish: Optional[pb.SyncFinishRequest] = None, + status_report: Optional[pb.StatusReportRequest] = None, + cancel: Optional[pb.CancelRequest] = None, + summary_record: Optional[pb.SummaryRecordRequest] = None, + telemetry_record: Optional[pb.TelemetryRecordRequest] = None, + get_system_metrics: Optional[pb.GetSystemMetricsRequest] = None, + python_packages: Optional[pb.PythonPackagesRequest] = None, + job_input: Optional[pb.JobInputRequest] = None, + run_finish_without_exit: Optional[pb.RunFinishWithoutExitRequest] = None, + probe_system_info: Optional[pb.ProbeSystemInfoRequest] = None, + ) -> pb.Record: + request = pb.Request() + if get_summary: + request.get_summary.CopyFrom(get_summary) + elif pause: + request.pause.CopyFrom(pause) + elif resume: + request.resume.CopyFrom(resume) + elif 
status: + request.status.CopyFrom(status) + elif stop_status: + request.stop_status.CopyFrom(stop_status) + elif internal_messages: + request.internal_messages.CopyFrom(internal_messages) + elif network_status: + request.network_status.CopyFrom(network_status) + elif poll_exit: + request.poll_exit.CopyFrom(poll_exit) + elif partial_history: + request.partial_history.CopyFrom(partial_history) + elif sampled_history: + request.sampled_history.CopyFrom(sampled_history) + elif run_start: + request.run_start.CopyFrom(run_start) + elif check_version: + request.check_version.CopyFrom(check_version) + elif log_artifact: + request.log_artifact.CopyFrom(log_artifact) + elif download_artifact: + request.download_artifact.CopyFrom(download_artifact) + elif link_artifact: + request.link_artifact.CopyFrom(link_artifact) + elif defer: + request.defer.CopyFrom(defer) + elif attach: + request.attach.CopyFrom(attach) + elif server_info: + request.server_info.CopyFrom(server_info) + elif keepalive: + request.keepalive.CopyFrom(keepalive) + elif run_status: + request.run_status.CopyFrom(run_status) + elif sender_mark: + request.sender_mark.CopyFrom(sender_mark) + elif sender_read: + request.sender_read.CopyFrom(sender_read) + elif cancel: + request.cancel.CopyFrom(cancel) + elif status_report: + request.status_report.CopyFrom(status_report) + elif summary_record: + request.summary_record.CopyFrom(summary_record) + elif telemetry_record: + request.telemetry_record.CopyFrom(telemetry_record) + elif get_system_metrics: + request.get_system_metrics.CopyFrom(get_system_metrics) + elif sync_finish: + request.sync_finish.CopyFrom(sync_finish) + elif python_packages: + request.python_packages.CopyFrom(python_packages) + elif job_input: + request.job_input.CopyFrom(job_input) + elif run_finish_without_exit: + request.run_finish_without_exit.CopyFrom(run_finish_without_exit) + elif probe_system_info: + request.probe_system_info.CopyFrom(probe_system_info) + else: + raise Exception("Invalid 
request") + record = self._make_record(request=request) + # All requests do not get persisted + record.control.local = True + if status_report: + record.control.flow_control = True + return record + + def _make_record( # noqa: C901 + self, + run: Optional[pb.RunRecord] = None, + config: Optional[pb.ConfigRecord] = None, + files: Optional[pb.FilesRecord] = None, + summary: Optional[pb.SummaryRecord] = None, + history: Optional[pb.HistoryRecord] = None, + stats: Optional[pb.StatsRecord] = None, + exit: Optional[pb.RunExitRecord] = None, + artifact: Optional[pb.ArtifactRecord] = None, + tbrecord: Optional[pb.TBRecord] = None, + alert: Optional[pb.AlertRecord] = None, + final: Optional[pb.FinalRecord] = None, + metric: Optional[pb.MetricRecord] = None, + header: Optional[pb.HeaderRecord] = None, + footer: Optional[pb.FooterRecord] = None, + request: Optional[pb.Request] = None, + telemetry: Optional[tpb.TelemetryRecord] = None, + preempting: Optional[pb.RunPreemptingRecord] = None, + use_artifact: Optional[pb.UseArtifactRecord] = None, + output: Optional[pb.OutputRecord] = None, + output_raw: Optional[pb.OutputRawRecord] = None, + environment: Optional[pb.EnvironmentRecord] = None, + ) -> pb.Record: + record = pb.Record() + if run: + record.run.CopyFrom(run) + elif config: + record.config.CopyFrom(config) + elif summary: + record.summary.CopyFrom(summary) + elif history: + record.history.CopyFrom(history) + elif files: + record.files.CopyFrom(files) + elif stats: + record.stats.CopyFrom(stats) + elif exit: + record.exit.CopyFrom(exit) + elif artifact: + record.artifact.CopyFrom(artifact) + elif tbrecord: + record.tbrecord.CopyFrom(tbrecord) + elif alert: + record.alert.CopyFrom(alert) + elif final: + record.final.CopyFrom(final) + elif header: + record.header.CopyFrom(header) + elif footer: + record.footer.CopyFrom(footer) + elif request: + record.request.CopyFrom(request) + elif telemetry: + record.telemetry.CopyFrom(telemetry) + elif metric: + 
record.metric.CopyFrom(metric) + elif preempting: + record.preempting.CopyFrom(preempting) + elif use_artifact: + record.use_artifact.CopyFrom(use_artifact) + elif output: + record.output.CopyFrom(output) + elif output_raw: + record.output_raw.CopyFrom(output_raw) + elif environment: + record.environment.CopyFrom(environment) + else: + raise Exception("Invalid record") + return record + + def _publish_defer(self, state: "pb.DeferRequest.DeferState.V") -> None: + defer = pb.DeferRequest(state=state) + rec = self._make_request(defer=defer) + rec.control.local = True + self._publish(rec) + + def publish_defer(self, state: int = 0) -> None: + self._publish_defer(cast("pb.DeferRequest.DeferState.V", state)) + + def _publish_header(self, header: pb.HeaderRecord) -> None: + rec = self._make_record(header=header) + self._publish(rec) + + def publish_footer(self) -> None: + footer = pb.FooterRecord() + rec = self._make_record(footer=footer) + self._publish(rec) + + def publish_final(self) -> None: + final = pb.FinalRecord() + rec = self._make_record(final=final) + self._publish(rec) + + def _publish_pause(self, pause: pb.PauseRequest) -> None: + rec = self._make_request(pause=pause) + self._publish(rec) + + def _publish_resume(self, resume: pb.ResumeRequest) -> None: + rec = self._make_request(resume=resume) + self._publish(rec) + + def _publish_run(self, run: pb.RunRecord) -> None: + rec = self._make_record(run=run) + self._publish(rec) + + def _publish_config(self, cfg: pb.ConfigRecord) -> None: + rec = self._make_record(config=cfg) + self._publish(rec) + + def _publish_summary(self, summary: pb.SummaryRecord) -> None: + rec = self._make_record(summary=summary) + self._publish(rec) + + def _publish_metric(self, metric: pb.MetricRecord) -> None: + rec = self._make_record(metric=metric) + self._publish(rec) + + def publish_stats(self, stats_dict: dict) -> None: + stats = self._make_stats(stats_dict) + rec = self._make_record(stats=stats) + self._publish(rec) + + def 
_publish_python_packages( + self, python_packages: pb.PythonPackagesRequest + ) -> None: + rec = self._make_request(python_packages=python_packages) + self._publish(rec) + + def _publish_files(self, files: pb.FilesRecord) -> None: + rec = self._make_record(files=files) + self._publish(rec) + + def _publish_use_artifact(self, use_artifact: pb.UseArtifactRecord) -> Any: + rec = self._make_record(use_artifact=use_artifact) + self._publish(rec) + + def _publish_probe_system_info( + self, probe_system_info: pb.ProbeSystemInfoRequest + ) -> None: + record = self._make_request(probe_system_info=probe_system_info) + self._publish(record) + + def _deliver_artifact( + self, + log_artifact: pb.LogArtifactRequest, + ) -> MailboxHandle[pb.Result]: + rec = self._make_request(log_artifact=log_artifact) + return self._deliver(rec) + + def _deliver_download_artifact( + self, download_artifact: pb.DownloadArtifactRequest + ) -> MailboxHandle[pb.Result]: + rec = self._make_request(download_artifact=download_artifact) + return self._deliver(rec) + + def _deliver_link_artifact( + self, link_artifact: pb.LinkArtifactRequest + ) -> MailboxHandle[pb.Result]: + rec = self._make_request(link_artifact=link_artifact) + return self._deliver(rec) + + def _publish_artifact(self, proto_artifact: pb.ArtifactRecord) -> None: + rec = self._make_record(artifact=proto_artifact) + self._publish(rec) + + def _publish_alert(self, proto_alert: pb.AlertRecord) -> None: + rec = self._make_record(alert=proto_alert) + self._publish(rec) + + def _deliver_status( + self, + status: pb.StatusRequest, + ) -> MailboxHandle[pb.Result]: + req = self._make_request(status=status) + return self._deliver(req) + + def _publish_exit(self, exit_data: pb.RunExitRecord) -> None: + rec = self._make_record(exit=exit_data) + self._publish(rec) + + def _publish_keepalive(self, keepalive: pb.KeepaliveRequest) -> None: + record = self._make_request(keepalive=keepalive) + self._publish(record) + + def _deliver_shutdown(self) -> 
MailboxHandle[pb.Result]: + request = pb.Request(shutdown=pb.ShutdownRequest()) + record = self._make_record(request=request) + return self._deliver(record) + + def _deliver_run(self, run: pb.RunRecord) -> MailboxHandle[pb.Result]: + record = self._make_record(run=run) + return self._deliver(record) + + def _deliver_finish_sync( + self, + sync_finish: pb.SyncFinishRequest, + ) -> MailboxHandle[pb.Result]: + record = self._make_request(sync_finish=sync_finish) + return self._deliver(record) + + def _deliver_run_start( + self, + run_start: pb.RunStartRequest, + ) -> MailboxHandle[pb.Result]: + record = self._make_request(run_start=run_start) + return self._deliver(record) + + def _deliver_get_summary( + self, + get_summary: pb.GetSummaryRequest, + ) -> MailboxHandle[pb.Result]: + record = self._make_request(get_summary=get_summary) + return self._deliver(record) + + def _deliver_get_system_metrics( + self, get_system_metrics: pb.GetSystemMetricsRequest + ) -> MailboxHandle[pb.Result]: + record = self._make_request(get_system_metrics=get_system_metrics) + return self._deliver(record) + + def _deliver_exit( + self, + exit_data: pb.RunExitRecord, + ) -> MailboxHandle[pb.Result]: + record = self._make_record(exit=exit_data) + return self._deliver(record) + + def _deliver_poll_exit( + self, + poll_exit: pb.PollExitRequest, + ) -> MailboxHandle[pb.Result]: + record = self._make_request(poll_exit=poll_exit) + return self._deliver(record) + + def _deliver_finish_without_exit( + self, run_finish_without_exit: pb.RunFinishWithoutExitRequest + ) -> MailboxHandle[pb.Result]: + record = self._make_request(run_finish_without_exit=run_finish_without_exit) + return self._deliver(record) + + def _deliver_stop_status( + self, + stop_status: pb.StopStatusRequest, + ) -> MailboxHandle[pb.Result]: + record = self._make_request(stop_status=stop_status) + return self._deliver(record) + + def _deliver_attach( + self, + attach: pb.AttachRequest, + ) -> MailboxHandle[pb.Result]: + record = 
self._make_request(attach=attach) + return self._deliver(record) + + def _deliver_network_status( + self, network_status: pb.NetworkStatusRequest + ) -> MailboxHandle[pb.Result]: + record = self._make_request(network_status=network_status) + return self._deliver(record) + + def _deliver_internal_messages( + self, internal_message: pb.InternalMessagesRequest + ) -> MailboxHandle[pb.Result]: + record = self._make_request(internal_messages=internal_message) + return self._deliver(record) + + def _deliver_request_sampled_history( + self, sampled_history: pb.SampledHistoryRequest + ) -> MailboxHandle[pb.Result]: + record = self._make_request(sampled_history=sampled_history) + return self._deliver(record) + + def _deliver_request_run_status( + self, run_status: pb.RunStatusRequest + ) -> MailboxHandle[pb.Result]: + record = self._make_request(run_status=run_status) + return self._deliver(record) diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/interface/interface_sock.py b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/interface_sock.py new file mode 100644 index 0000000000000000000000000000000000000000..79b914dff1fd71ecb538f18300b56fea9b69c95b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/interface_sock.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from typing_extensions import override + +from wandb.proto import wandb_server_pb2 as spb +from wandb.sdk.lib import asyncio_manager + +from .interface_shared import InterfaceShared + +if TYPE_CHECKING: + from wandb.proto import wandb_internal_pb2 as pb + from wandb.sdk.lib.service.service_client import ServiceClient + from wandb.sdk.mailbox import MailboxHandle + + +logger = logging.getLogger("wandb") + + +class InterfaceSock(InterfaceShared): + def __init__( + self, + asyncer: asyncio_manager.AsyncioManager, + client: ServiceClient, + stream_id: str, + ) -> None: + super().__init__() + self._asyncer = asyncer + 
self._client = client + self._stream_id = stream_id + + def _assign(self, record: pb.Record) -> None: + record._info.stream_id = self._stream_id + + @override + def _publish(self, record: pb.Record, *, nowait: bool = False) -> None: + self._assign(record) + request = spb.ServerRequest() + request.record_publish.CopyFrom(record) + + if nowait: + self._asyncer.run_soon(lambda: self._client.publish(request)) + else: + self._asyncer.run(lambda: self._client.publish(request)) + + @override + def _deliver(self, record: pb.Record) -> MailboxHandle[pb.Result]: + return self._asyncer.run(lambda: self.deliver_async(record)) + + @override + async def deliver_async(self, record: pb.Record) -> MailboxHandle[pb.Result]: + self._assign(record) + request = spb.ServerRequest() + request.record_publish.CopyFrom(record) + + handle = await self._client.deliver(request) + return handle.map(lambda response: response.result_communicate) diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/interface/summary_record.py b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/summary_record.py new file mode 100644 index 0000000000000000000000000000000000000000..2050a39080759ef12c54364aaf76b1dee6ab7e1d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/interface/summary_record.py @@ -0,0 +1,67 @@ +"""Summary Record. + +This module implements a summary record as an intermediate format before being converted +to a protocol buffer. 
+""" + +import typing as t + + +class SummaryRecord: + """Encodes a diff -- analogous to the SummaryRecord protobuf message.""" + + update: t.List["SummaryItem"] + remove: t.List["SummaryItem"] + + def __init__(self): + self.update = [] + self.remove = [] + + def __str__(self): + s = "SummaryRecord:\n Update:\n " + s += "\n ".join([str(item) for item in self.update]) + s += "\n Remove:\n " + s += "\n ".join([str(item) for item in self.remove]) + s += "\n" + return s + + __repr__ = __str__ + + def _add_next_parent(self, parent_key): + with_next_parent = SummaryRecord() + with_next_parent.update = [ + item._add_next_parent(parent_key) for item in self.update + ] + with_next_parent.remove = [ + item._add_next_parent(parent_key) for item in self.remove + ] + + return with_next_parent + + +class SummaryItem: + """Analogous to the SummaryItem protobuf message.""" + + key: t.Tuple[str] + value: t.Any + + def __init__(self): + self.key = tuple() + self.value = None + + def __str__(self): + return "SummaryItem: key: " + str(self.key) + " value: " + str(self.value) + + __repr__ = __str__ + + def _add_next_parent(self, parent_key): + with_next_parent = SummaryItem() + + key = self.key + if not isinstance(key, tuple): + key = (key,) + + with_next_parent.key = (parent_key,) + self.key + with_next_parent.value = self.value + + return with_next_parent diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__init__.py b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db119d54533f19d1858d9d381f6fb3036429df0e Binary files /dev/null and 
b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/context.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/context.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f321e842016912dbab1907b727fc1d265b8d75fc Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/context.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/datastore.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/datastore.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21b563c4bd1dc37c3dce59ff2812e1a54004f32b Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/datastore.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/file_pusher.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/file_pusher.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa907eaddd05d246f3aaf2065bdb52ad21d567f5 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/file_pusher.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/file_stream.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/file_stream.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00bd9034df4da917d602b733b53a3568c04f3659 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/file_stream.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/handler.cpython-313.pyc 
b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/handler.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c422adb3a46189536e01a4f4821464fb92b6cff Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/handler.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/incremental_table_util.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/incremental_table_util.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36f803706f2cac2027a58a54dafba606c1f12c04 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/incremental_table_util.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/job_builder.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/job_builder.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7913f90273a298a31bfb7bc4bfc6521b25384b27 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/job_builder.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/profiler.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/profiler.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae7a476b6f51ac8620b59f40d16b0aa9f21758a7 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/profiler.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/progress.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/progress.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b546633f12d02cc1f7abd4283f3cea3e07cad8e Binary 
files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/progress.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/run.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/run.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6a14192c183935929c361de10c377d17e98957d Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/run.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/sample.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/sample.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37e915e3acdf103ae175a0461972c684a932ee7c Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/sample.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/sender.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/sender.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0d7d35933da4678d4c4061404591bd4c65ed8b3 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/sender.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/sender_config.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/sender_config.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a67af68424f46b6bab96c0b9c3c6698e49eded6 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/sender_config.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/settings_static.cpython-313.pyc 
b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/settings_static.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c675cdaf9931ad8745c8f1d65e0020f5f81a3641 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/settings_static.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/tb_watcher.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/tb_watcher.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a6185bb6f032159aee086812e9589fb46e3de2f Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/__pycache__/tb_watcher.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/__init__.py b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..efba046d8a690634a316a9033f74e8320370e61e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/__init__.py @@ -0,0 +1,5 @@ +# Generated by ariadne-codegen + +__all__ = ["SERVER_FEATURES_QUERY_GQL", "ServerFeaturesQuery"] +from .operations import SERVER_FEATURES_QUERY_GQL +from .server_features_query import ServerFeaturesQuery diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1507993b1a6ae9e340011687e2565d1fa3380dab Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/__pycache__/operations.cpython-313.pyc 
b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/__pycache__/operations.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35b62329e3dbf7783bab9f97697724eaf0180edc Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/__pycache__/operations.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/__pycache__/server_features_query.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/__pycache__/server_features_query.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbbb32d929ba145dde3b3d79f6c2e483ede47bcb Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/__pycache__/server_features_query.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/enums.py b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/enums.py new file mode 100644 index 0000000000000000000000000000000000000000..cc7d61d95669bab416e88c3acc378913704e64e8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/enums.py @@ -0,0 +1,4 @@ +# Generated by ariadne-codegen +# Source: core/api/graphql/schemas/schema-latest.graphql + +from __future__ import annotations diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/input_types.py b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/input_types.py new file mode 100644 index 0000000000000000000000000000000000000000..cc7d61d95669bab416e88c3acc378913704e64e8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/input_types.py @@ -0,0 +1,4 @@ +# Generated by ariadne-codegen +# Source: core/api/graphql/schemas/schema-latest.graphql + +from __future__ import annotations diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/operations.py 
b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/operations.py new file mode 100644 index 0000000000000000000000000000000000000000..6d204236d0154d142ccc4ac22ade6befe7e6e012 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/operations.py @@ -0,0 +1,15 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/utils/ + +__all__ = ["SERVER_FEATURES_QUERY_GQL"] + +SERVER_FEATURES_QUERY_GQL = """ +query ServerFeaturesQuery { + serverInfo { + features { + name + isEnabled + } + } +} +""" diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/server_features_query.py b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/server_features_query.py new file mode 100644 index 0000000000000000000000000000000000000000..c31ee720f7db95dcb97c8015b0a29db83c6cf07b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/_generated/server_features_query.py @@ -0,0 +1,27 @@ +# Generated by ariadne-codegen +# Source: tools/graphql_codegen/utils/ + +from __future__ import annotations + +from typing import List, Optional + +from pydantic import Field + +from wandb._pydantic import GQLResult + + +class ServerFeaturesQuery(GQLResult): + server_info: Optional[ServerFeaturesQueryServerInfo] = Field(alias="serverInfo") + + +class ServerFeaturesQueryServerInfo(GQLResult): + features: List[Optional[ServerFeaturesQueryServerInfoFeatures]] + + +class ServerFeaturesQueryServerInfoFeatures(GQLResult): + name: str + is_enabled: bool = Field(alias="isEnabled") + + +ServerFeaturesQuery.model_rebuild() +ServerFeaturesQueryServerInfo.model_rebuild() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/context.py b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/context.py new file mode 100644 index 0000000000000000000000000000000000000000..1ad49ee15a1e15b76c1cfacdbc2158fe8cfa9ad7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/context.py @@ -0,0 
+1,89 @@ +"""Context Keeper.""" + +import logging +import threading +from typing import Dict, Optional + +from wandb.proto.wandb_internal_pb2 import Record, Result + +logger = logging.getLogger(__name__) + + +class Context: + _cancel_event: threading.Event + # TODO(debug_context) add debug setting to enable this + # _debug_record: Optional[Record] + + def __init__(self) -> None: + self._cancel_event = threading.Event() + # TODO(debug_context) see above + # self._debug_record = None + + def cancel(self) -> None: + self._cancel_event.set() + + @property + def cancel_event(self) -> threading.Event: + return self._cancel_event + + +def context_id_from_record(record: Record) -> str: + context_id = record.control.mailbox_slot + return context_id + + +def context_id_from_result(result: Result) -> str: + context_id = result.control.mailbox_slot + return context_id + + +class ContextKeeper: + _active_items: Dict[str, Context] + + def __init__(self) -> None: + self._active_items = {} + + def add_from_record(self, record: Record) -> Optional[Context]: + context_id = context_id_from_record(record) + if not context_id: + return None + context_obj = self.add(context_id) + + # TODO(debug_context) see above + # context_obj._debug_record = record + + return context_obj + + def add(self, context_id: str) -> Context: + assert context_id + context_obj = Context() + self._active_items[context_id] = context_obj + return context_obj + + def get(self, context_id: str) -> Optional[Context]: + item = self._active_items.get(context_id) + return item + + def release(self, context_id: str) -> None: + if not context_id: + return + _ = self._active_items.pop(context_id, None) + + def cancel(self, context_id: str) -> bool: + item = self.get(context_id) + if item: + item.cancel() + return True + return False + + # TODO(debug_context) see above + # def _debug_print_orphans(self, print_to_stdout: bool) -> None: + # for context_id, context in self._active_items.items(): + # record = 
context._debug_record + # record_type = record.WhichOneof("record_type") if record else "unknown" + # message = ( + # f"Context: {context_id} {context.cancel_event.is_set()} {record_type}" + # ) + # logger.warning(message) + # if print_to_stdout: + # print(message) diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/datastore.py b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/datastore.py new file mode 100644 index 0000000000000000000000000000000000000000..200a9f4195a36fb476e384763ca6aa7ba4b37ceb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/datastore.py @@ -0,0 +1,281 @@ +"""leveldb log datastore. + +Format is described at: + https://github.com/google/leveldb/blob/master/doc/log_format.md + +block := record* trailer? +record := + checksum: uint32 // crc32c of type and data[] ; little-endian + length: uint16 // little-endian + type: uint8 // One of FULL, FIRST, MIDDLE, LAST + data: uint8[length] + +header := + ident: char[4] + magic: uint16 + version: uint8 +""" + +import logging +import os +import struct +import zlib +from typing import TYPE_CHECKING, Optional, Tuple + +if TYPE_CHECKING: + from typing import IO, Any + + from wandb.proto.wandb_internal_pb2 import Record + +logger = logging.getLogger(__name__) + +LEVELDBLOG_HEADER_LEN = 7 +LEVELDBLOG_BLOCK_LEN = 32768 +LEVELDBLOG_DATA_LEN = LEVELDBLOG_BLOCK_LEN - LEVELDBLOG_HEADER_LEN + +LEVELDBLOG_FULL = 1 +LEVELDBLOG_FIRST = 2 +LEVELDBLOG_MIDDLE = 3 +LEVELDBLOG_LAST = 4 + +LEVELDBLOG_HEADER_IDENT = ":W&B" +LEVELDBLOG_HEADER_MAGIC = ( + 0xBEE1 # zlib.crc32(bytes("Weights & Biases", 'iso8859-1')) & 0xffff +) +LEVELDBLOG_HEADER_VERSION = 0 + +try: + bytes("", "ascii") + + def strtobytes(x): + """Strtobytes.""" + return bytes(x, "iso8859-1") + +except Exception: + strtobytes = str + + +class DataStore: + _index: int + _flush_offset: int + + def __init__(self) -> None: + self._opened_for_scan = False + self._fp: Optional[IO[Any]] = None + self._index = 0 + 
self._flush_offset = 0 + self._size_bytes = 0 + + self._crc = [0] * (LEVELDBLOG_LAST + 1) + for x in range(1, LEVELDBLOG_LAST + 1): + self._crc[x] = zlib.crc32(strtobytes(chr(x))) & 0xFFFFFFFF + + def open_for_write(self, fname: str) -> None: + self._fname = fname + logger.info("open: %s", fname) + open_flags = "xb" + self._fp = open(fname, open_flags) + self._write_header() + + def open_for_append(self, fname): + # TODO: implement + self._fname = fname + logger.info("open: %s", fname) + self._fp = open(fname, "wb") + # do something with _index + + def open_for_scan(self, fname): + self._fname = fname + logger.info("open for scan: %s", fname) + self._fp = open(fname, "r+b") + self._index = 0 + self._size_bytes = os.stat(fname).st_size + self._opened_for_scan = True + self._read_header() + + def seek(self, offset: int) -> None: + self._fp.seek(offset) # type: ignore + self._index = offset + + def get_offset(self) -> int: + offset = self._fp.tell() # type: ignore + return offset + + def in_last_block(self): + """Determine if we're in the last block to handle in-progress writes.""" + return self._index > self._size_bytes - LEVELDBLOG_DATA_LEN + + def scan_record(self): + assert self._opened_for_scan, "file not open for scanning" + # TODO(jhr): handle some assertions as file corruption issues + # assume we have enough room to read header, checked by caller? 
+ header = self._fp.read(LEVELDBLOG_HEADER_LEN) + if len(header) == 0: + return None + assert len(header) == LEVELDBLOG_HEADER_LEN, ( + f"record header is {len(header)} bytes instead of the expected {LEVELDBLOG_HEADER_LEN}" + ) + fields = struct.unpack(" LEVELDBLOG_DATA_LEN: + self._write_record( + s[data_used : data_used + LEVELDBLOG_DATA_LEN], + LEVELDBLOG_MIDDLE, + ) + data_used += LEVELDBLOG_DATA_LEN + data_left -= LEVELDBLOG_DATA_LEN + + # write last and flush the entire block to disk + self._write_record(s[data_used:], LEVELDBLOG_LAST) + self._fp.flush() + os.fsync(self._fp.fileno()) + self._flush_offset = self._index + + return start_offset, self._index, self._flush_offset + + def ensure_flushed(self, off: int) -> None: + self._fp.flush() # type: ignore + + def write(self, obj: "Record") -> Tuple[int, int, int]: + """Write a protocol buffer. + + Args: + obj: Protocol buffer to write. + + Returns: + (start_offset, end_offset, flush_offset) if successful, + None otherwise + + """ + raw_size = obj.ByteSize() + s = obj.SerializeToString() + assert len(s) == raw_size, "invalid serialization" + ret = self._write_data(s) + return ret + + def close(self) -> None: + if self._fp is not None: + logger.info("close: %s", self._fname) + self._fp.close() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/file_pusher.py b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/file_pusher.py new file mode 100644 index 0000000000000000000000000000000000000000..72b5b2d8ca41af9d44ccb284c78d600289570f6c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/file_pusher.py @@ -0,0 +1,177 @@ +import concurrent.futures +import logging +import os +import queue +import tempfile +import threading +import time +from typing import TYPE_CHECKING, Optional, Tuple + +import wandb +import wandb.util +from wandb.filesync import stats, step_checksum, step_upload +from wandb.sdk.lib.paths import LogicalPath + +if TYPE_CHECKING: + from 
wandb.sdk.artifacts.artifact_manifest import ArtifactManifest + from wandb.sdk.artifacts.artifact_saver import SaveFn + from wandb.sdk.internal import file_stream, internal_api + from wandb.sdk.internal.settings_static import SettingsStatic + + +logger = logging.getLogger(__name__) + + +class FilePusher: + """Parallel file upload class. + + This manages uploading multiple files in parallel. It will restart a given file's + upload job if it receives a notification that that file has been modified. The + finish() method will block until all events have been processed and all uploads are + complete. + """ + + MAX_UPLOAD_JOBS = 64 + + def __init__( + self, + api: "internal_api.Api", + file_stream: "file_stream.FileStreamApi", + settings: Optional["SettingsStatic"] = None, + ) -> None: + self._api = api + + # Temporary directory for copies we make of some file types to + # reduce the probability that the file gets changed while we're + # uploading it. + self._tempdir = tempfile.TemporaryDirectory("wandb") + + self._stats = stats.Stats() + + self._incoming_queue: queue.Queue[step_checksum.Event] = queue.Queue() + self._event_queue: queue.Queue[step_upload.Event] = queue.Queue() + + self._step_checksum = step_checksum.StepChecksum( + self._api, + self._tempdir, + self._incoming_queue, + self._event_queue, + self._stats, + ) + self._step_checksum.start() + + self._step_upload = step_upload.StepUpload( + self._api, + self._stats, + self._event_queue, + self.MAX_UPLOAD_JOBS, + file_stream=file_stream, + settings=settings, + ) + self._step_upload.start() + + self._stats_thread_stop = threading.Event() + if os.environ.get("WANDB_DEBUG"): + # debug thread to monitor and report file pusher stats + self._stats_thread = threading.Thread( + target=self._file_pusher_stats, + daemon=True, + name="FPStatsThread", + ) + self._stats_thread.start() + + def _file_pusher_stats(self) -> None: + while not self._stats_thread_stop.is_set(): + logger.info(f"FilePusher stats: 
{self._stats._stats}") + time.sleep(1) + + def get_status(self) -> Tuple[bool, stats.Summary]: + running = self.is_alive() + summary = self._stats.summary() + return running, summary + + def print_status(self, prefix: bool = True) -> None: + step = 0 + spinner_states = ["-", "\\", "|", "/"] + stop = False + while True: + if not self.is_alive(): + stop = True + summary = self._stats.summary() + line = f" {summary.uploaded_bytes / 1048576.0:.2f}MB of {summary.total_bytes / 1048576.0:.2f}MB uploaded ({summary.deduped_bytes / 1048576.0:.2f}MB deduped)\r" + line = spinner_states[step % 4] + line + step += 1 + wandb.termlog(line, newline=False, prefix=prefix) + if stop: + break + time.sleep(0.25) + dedupe_fraction = ( + summary.deduped_bytes / float(summary.total_bytes) + if summary.total_bytes > 0 + else 0 + ) + if dedupe_fraction > 0.01: + wandb.termlog( + "W&B sync reduced upload amount by %.1f%% " + % (dedupe_fraction * 100), + prefix=prefix, + ) + # clear progress line. + wandb.termlog(" " * 79, prefix=prefix) + + def file_counts_by_category(self) -> stats.FileCountsByCategory: + return self._stats.file_counts_by_category() + + def file_changed(self, save_name: LogicalPath, path: str, copy: bool = True): + """Tell the file pusher that a file's changed and should be uploaded. + + Args: + save_name: string logical location of the file relative to the run + directory. + path: actual string path of the file to upload on the filesystem. 
+ """ + # Tests in linux were failing because wandb-events.jsonl didn't exist + if not os.path.exists(path) or not os.path.isfile(path): + return + if os.path.getsize(path) == 0: + return + + event = step_checksum.RequestUpload(path, save_name, copy) + self._incoming_queue.put(event) + + def store_manifest_files( + self, + manifest: "ArtifactManifest", + artifact_id: str, + save_fn: "SaveFn", + ) -> None: + event = step_checksum.RequestStoreManifestFiles(manifest, artifact_id, save_fn) + self._incoming_queue.put(event) + + def commit_artifact( + self, + artifact_id: str, + *, + finalize: bool = True, + before_commit: step_upload.PreCommitFn, + result_future: "concurrent.futures.Future[None]", + ): + event = step_checksum.RequestCommitArtifact( + artifact_id, finalize, before_commit, result_future + ) + self._incoming_queue.put(event) + + def finish(self, callback: Optional[step_upload.OnRequestFinishFn] = None): + logger.info("shutting down file pusher") + self._incoming_queue.put(step_checksum.RequestFinish(callback)) + self._stats_thread_stop.set() + + def join(self) -> None: + # NOTE: must have called finish before join + logger.info("waiting for file pusher") + while self.is_alive(): + time.sleep(0.5) + self._tempdir.cleanup() + + def is_alive(self) -> bool: + return self._step_checksum.is_alive() or self._step_upload.is_alive() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/file_stream.py b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/file_stream.py new file mode 100644 index 0000000000000000000000000000000000000000..ec7712bed5c104c2acf30ddaa59a26c124639b90 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/file_stream.py @@ -0,0 +1,687 @@ +import functools +import itertools +import json +import logging +import os +import queue +import random +import sys +import threading +import time +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + NamedTuple, + 
Optional, + Set, + Tuple, + Type, + Union, +) + +if TYPE_CHECKING: + from typing import TypedDict + + class ProcessedChunk(TypedDict): + offset: int + content: List[str] + + class ProcessedBinaryChunk(TypedDict): + offset: int + content: str + encoding: str + + +import requests + +import wandb +from wandb import util +from wandb.analytics import get_sentry +from wandb.sdk.internal import internal_api + +from ..lib import file_stream_utils + +logger = logging.getLogger(__name__) + + +class Chunk(NamedTuple): + filename: str + data: str + + +class DefaultFilePolicy: + def __init__(self, start_chunk_id: int = 0) -> None: + self._chunk_id = start_chunk_id + self.has_debug_log = False + + def process_chunks( + self, chunks: List[Chunk] + ) -> Union[bool, "ProcessedChunk", "ProcessedBinaryChunk", List["ProcessedChunk"]]: + chunk_id = self._chunk_id + self._chunk_id += len(chunks) + return {"offset": chunk_id, "content": [c.data for c in chunks]} + + # TODO: this is very inefficient, this is meant for temporary debugging and will be removed in future releases + def _debug_log(self, data: Any): + if self.has_debug_log or not os.environ.get("WANDB_DEBUG_FILESTREAM_LOG"): + return + + loaded = json.loads(data) + if not isinstance(loaded, dict): + return + + # get key size and convert to MB + key_sizes = [(k, len(json.dumps(v))) for k, v in loaded.items()] + key_msg = [f"{k}: {v / 1048576:.5f} MB" for k, v in key_sizes] + wandb.termerror(f"Step: {loaded['_step']} | {key_msg}", repeat=False) + self.has_debug_log = True + + +class JsonlFilePolicy(DefaultFilePolicy): + def process_chunks(self, chunks: List[Chunk]) -> "ProcessedChunk": + chunk_id = self._chunk_id + # TODO: chunk_id is getting reset on each request... 
+ self._chunk_id += len(chunks) + chunk_data = [] + for chunk in chunks: + if len(chunk.data) > util.MAX_LINE_BYTES: + msg = f"Metric data exceeds maximum size of {util.to_human_size(util.MAX_LINE_BYTES)} ({util.to_human_size(len(chunk.data))})" + wandb.termerror(msg, repeat=False) + get_sentry().message(msg, repeat=False) + self._debug_log(chunk.data) + else: + chunk_data.append(chunk.data) + + return { + "offset": chunk_id, + "content": chunk_data, + } + + +class SummaryFilePolicy(DefaultFilePolicy): + def process_chunks(self, chunks: List[Chunk]) -> Union[bool, "ProcessedChunk"]: + data = chunks[-1].data + if len(data) > util.MAX_LINE_BYTES: + msg = f"Summary data exceeds maximum size of {util.to_human_size(util.MAX_LINE_BYTES)}. Dropping it." + wandb.termerror(msg, repeat=False) + get_sentry().message(msg, repeat=False) + self._debug_log(data) + return False + return {"offset": 0, "content": [data]} + + +class StreamCRState: + r"""Stream state that tracks carriage returns. + + There are two streams: stdout and stderr. We create two instances for each stream. + An instance holds state about: + found_cr: if a carriage return has been found in this stream. + cr: most recent offset (line number) where we found \r. + We update this offset with every progress bar update. + last_normal: most recent offset without a \r in this stream. + i.e. the most recent "normal" line. + """ + + found_cr: bool + cr: Optional[int] + last_normal: Optional[int] + + def __init__(self) -> None: + self.found_cr = False + self.cr = None + self.last_normal = None + + +class CRDedupeFilePolicy(DefaultFilePolicy): + r"""File stream policy for removing carriage-return erased characters. + + This is what a terminal does. We use it for console output to reduce the amount of + data we need to send over the network (eg. for progress bars), while preserving the + output's appearance in the web app. + + CR stands for "carriage return", for the character \r. 
It tells the terminal to move + the cursor back to the start of the current line. Progress bars (like tqdm) use \r + repeatedly to overwrite a line with newer updates. This gives the illusion of the + progress bar filling up in real-time. + """ + + def __init__(self, start_chunk_id: int = 0) -> None: + super().__init__(start_chunk_id=start_chunk_id) + self._prev_chunk = None + + self.global_offset = 0 + # cr refers to carriage return \r + self.stderr = StreamCRState() + self.stdout = StreamCRState() + + @staticmethod + def get_consecutive_offsets(console: Dict[int, str]) -> List[List[int]]: + """Compress consecutive line numbers into an interval. + + Args: + console: Dict[int, str] which maps offsets (line numbers) to lines of text. + It represents a mini version of our console dashboard on the UI. + + Returns: + A list of intervals (we compress consecutive line numbers into an interval). + + Example: + >>> console = {2: "", 3: "", 4: "", 5: "", 10: "", 11: "", 20: ""} + >>> get_consecutive_offsets(console) + [(2, 5), (10, 11), (20, 20)] + """ + offsets = sorted(list(console.keys())) + intervals: List = [] + for i, num in enumerate(offsets): + if i == 0: + intervals.append([num, num]) + continue + largest = intervals[-1][1] + if num == largest + 1: + intervals[-1][1] = num + else: + intervals.append([num, num]) + return intervals + + @staticmethod + def split_chunk(chunk: Chunk) -> Tuple[str, str]: + r"""Split chunks. + + Args: + chunk: object with two fields: filename (str) & data (str) + `chunk.data` is a str containing the lines we want. It usually contains \n or \r or both. + `chunk.data` has two possible formats (for the two streams - stdout and stderr): + - "2020-08-25T20:38:36.895321 this is my line of text\nsecond line\n" + - "ERROR 2020-08-25T20:38:36.895321 this is my line of text\nsecond line\nthird\n". + + Here's another example with a carriage return \r. + - "ERROR 2020-08-25T20:38:36.895321 \r progress bar\n" + + Returns: + A 2-tuple of strings. 
def process_chunks(self, chunks: List["Chunk"]) -> List["ProcessedChunk"]:
    r"""Collapse carriage-return overwrites and group lines by offset.

    Args:
        chunks: List of Chunk objects; each is split into a prefix and the
            remaining text with ``self.split_chunk``.

    Returns:
        A list of dicts with two keys: ``offset`` (``self._chunk_id`` plus the
        first line number of a run) and ``content`` (the consecutive lines
        starting at that offset). Runs are computed with
        ``self.get_consecutive_offsets``.

    Example:
        >>> chunks = [
            Chunk("output.log", "ERROR 2020-08-25T20:38 this is my line of text\nboom\n"),
            Chunk("output.log", "2020-08-25T20:38 this is test\n"),
        ]
        >>> process_chunks(chunks)
        [
            {"offset": 0, "content": [
                "ERROR 2020-08-25T20:38 this is my line of text\n",
                "ERROR 2020-08-25T20:38 boom\n",
                "2020-08-25T20:38 this is test\n"
                ]
            }
        ]
    """
    # Dict[int->str], each offset (line number) mapped to a line.
    # Represents a mini-version of our console pane on the UI.
    console = {}
    sep = os.linesep

    for c in chunks:
        prefix, logs_str = self.split_chunk(c)

        # Both values below depend only on the chunk, not on the individual
        # line, so compute them once per chunk instead of once per line.
        stream = self.stderr if prefix.startswith("ERROR ") else self.stdout
        # Usually logs_str = "\r progress bar\n" for progress bar updates.
        # If instead logs_str = "\r progress bar\n text\n text\n",
        # treat this as the end of a progress bar and reset accordingly.
        ends_progress_bar = (
            logs_str.count(sep) > 1 and logs_str.replace(sep, "").count("\r") == 1
        )

        for line in logs_str.split(sep):
            if line.startswith("\r"):
                # line starting with \r will always overwrite a previous offset.
                offset: int = (
                    stream.cr
                    if (stream.found_cr and stream.cr is not None)
                    else (stream.last_normal or 0)
                )
                stream.cr = offset
                stream.found_cr = not ends_progress_bar
                console[offset] = prefix + line[1:] + "\n"
            elif line:
                console[self.global_offset] = prefix + line + "\n"
                stream.last_normal = self.global_offset
                self.global_offset += 1

    ret = []
    for a, b in self.get_consecutive_offsets(console):
        processed_chunk: ProcessedChunk = {
            "offset": self._chunk_id + a,
            "content": [console[i] for i in range(a, b + 1)],
        }
        ret.append(processed_chunk)
    return ret
+ """ + + class Finish(NamedTuple): + exitcode: int + + class Preempting(NamedTuple): + pass + + class PushSuccess(NamedTuple): + artifact_id: str + save_name: str + + MAX_ITEMS_PER_PUSH = 10000 + + def __init__( + self, + api: "internal_api.Api", + run_id: str, + start_time: float, + timeout: float = 0, + settings: Optional[dict] = None, + ) -> None: + settings = settings or dict() + # NOTE: exc_info is set in thread_except_body context and readable by calling threads + self._exc_info: Optional[ + Union[ + Tuple[Type[BaseException], BaseException, TracebackType], + Tuple[None, None, None], + ] + ] = None + self._settings = settings + self._api = api + self._run_id = run_id + self._start_time = start_time + self._client = requests.Session() + timeout = timeout or 0 + if timeout > 0: + self._client.post = functools.partial(self._client.post, timeout=timeout) # type: ignore[method-assign] + self._client.auth = api.client.transport.session.auth + self._client.headers.update(api.client.transport.headers or {}) + self._client.cookies.update(api.client.transport.cookies or {}) # type: ignore[no-untyped-call] + self._client.proxies.update(api.client.transport.session.proxies or {}) + self._file_policies: Dict[str, DefaultFilePolicy] = {} + self._dropped_chunks: int = 0 + self._queue: queue.Queue = queue.Queue() + self._thread = threading.Thread(target=self._thread_except_body) + # It seems we need to make this a daemon thread to get sync.py's atexit handler to run, which + # cleans this thread up. 
def rate_limit_seconds(self) -> Union[int, float]:
    """Return the minimum number of seconds to wait between data posts.

    The value grows with the age of the run (measured from
    ``self._start_time``): a fraction of the heartbeat interval during the
    first minute, a larger fraction up to five minutes, and the full
    heartbeat interval afterwards, each with a fixed lower bound.
    """
    elapsed = time.time() - self._start_time
    heartbeat = self.heartbeat_seconds
    if elapsed < 60:
        floor, interval = 1.0, heartbeat / 15
    elif elapsed < 300:
        floor, interval = 2.5, heartbeat / 3
    else:
        floor, interval = 5.0, heartbeat
    return max(floor, interval)
+ return util.read_many_from_queue( + self._queue, self.MAX_ITEMS_PER_PUSH, self.rate_limit_seconds() + ) + + def _thread_body(self) -> None: + posted_data_time = time.time() + posted_anything_time = time.time() + ready_chunks = [] + uploaded: Set[str] = set() + finished: Optional[FileStreamApi.Finish] = None + while finished is None: + items = self._read_queue() + for item in items: + if isinstance(item, self.Finish): + finished = item + elif isinstance(item, self.Preempting): + request_with_retry( + self._client.post, + self._endpoint, + json={ + "complete": False, + "preempting": True, + "dropped": self._dropped_chunks, + "uploaded": list(uploaded), + }, + ) + uploaded = set() + elif isinstance(item, self.PushSuccess): + uploaded.add(item.save_name) + else: + # item is Chunk + ready_chunks.append(item) + + cur_time = time.time() + + if ready_chunks and ( + finished or cur_time - posted_data_time > self.rate_limit_seconds() + ): + posted_data_time = cur_time + posted_anything_time = cur_time + success = self._send(ready_chunks, uploaded=uploaded) + ready_chunks = [] + if success: + uploaded = set() + + # If there aren't ready chunks or uploaded files, we still want to + # send regular heartbeats so the backend doesn't erroneously mark this + # run as crashed. + if cur_time - posted_anything_time > self.heartbeat_seconds: + posted_anything_time = cur_time + + # If we encountered an error trying to publish the + # list of uploaded files, don't reset the `uploaded` + # list. Retry publishing the list on the next attempt. + if not isinstance( + request_with_retry( + self._client.post, + self._endpoint, + json={ + "complete": False, + "failed": False, + "dropped": self._dropped_chunks, + "uploaded": list(uploaded), + }, + ), + Exception, + ): + uploaded = set() + + # post the final close message. 
(item is self.Finish instance now) + request_with_retry( + self._client.post, + self._endpoint, + json={ + "complete": True, + "exitcode": int(finished.exitcode), + "dropped": self._dropped_chunks, + "uploaded": list(uploaded), + }, + ) + + def _thread_except_body(self) -> None: + # TODO: Consolidate with internal_util.ExceptionThread + try: + self._thread_body() + except Exception: + exc_info = sys.exc_info() + self._exc_info = exc_info + logger.exception("generic exception in filestream thread") + get_sentry().exception(exc_info) + raise + + def _handle_response(self, response: Union[Exception, "requests.Response"]) -> None: + """Log dropped chunks and updates dynamic settings.""" + if isinstance(response, Exception): + wandb.termerror( + "Dropped streaming file chunk (see wandb/debug-internal.log)" + ) + logger.exception(f"dropped chunk {response}") + self._dropped_chunks += 1 + else: + parsed: Optional[dict] = None + try: + parsed = response.json() + except Exception: + pass + if isinstance(parsed, dict): + limits = parsed.get("limits") + if isinstance(limits, dict): + self._api.dynamic_settings.update(limits) + + def _send(self, chunks: List[Chunk], uploaded: Optional[Set[str]] = None) -> bool: + uploaded_list = list(uploaded or []) + # create files dict. dict of pairs where chunks are a list of + # [chunk_id, chunk_data] tuples (as lists since this will be json). + files = {} + # Groupby needs group keys to be consecutive, so sort first. 
def stream_file(self, path: str) -> None:
    """Send every line of the file at ``path`` in a single ``_send`` call.

    Args:
        path: Path of the text file to read. Each line becomes one Chunk
            named after the file's base name.
    """
    # os.path.basename honours the platform's path separators, unlike a
    # hard-coded split on "/".
    name = os.path.basename(path)
    with open(path) as f:
        self._send([Chunk(name, line) for line in f])
def request_with_retry(
    func: Callable,
    *args: Any,
    **kwargs: Any,
) -> Union["requests.Response", "requests.RequestException"]:
    """Perform a requests http call, retrying with exponential backoff.

    Args:
        func: An http-requesting function to call, like requests.post
        *args: passed through to func
        **kwargs: passed through to func, except for two keys that are
            popped first:
            max_retries: Maximum retries before giving up (default 30).
            retry_callback: Optional callable invoked with
                ``(status_code, message)`` before sleeping on an HTTP 429.

    Returns:
        The response on success. On failure the exception is *returned*
        (not raised): immediately for HTTP 400/403/404/409 and for request
        errors that are not connection/HTTP/timeout errors, or once
        ``max_retries`` retries have been used up.
    """
    max_retries: int = kwargs.pop("max_retries", 30)
    retry_callback: Optional[Callable] = kwargs.pop("retry_callback", None)
    sleep = 2
    retry_count = 0
    while True:
        try:
            response: requests.Response = func(*args, **kwargs)
            response.raise_for_status()
            return response
        except (
            requests.exceptions.ConnectionError,
            requests.exceptions.HTTPError,
            requests.exceptions.Timeout,
        ) as e:
            if isinstance(e, requests.exceptions.HTTPError):
                # Non-retriable HTTP errors.
                #
                # We retry 500s just to be cautious, and because the back end
                # returns them when there are infrastructure issues. If retrying
                # some request winds up being problematic, we'll change the
                # back end to indicate that it shouldn't be retried.
                if e.response is not None and e.response.status_code in {
                    400,
                    403,
                    404,
                    409,
                }:
                    return e

            if retry_count == max_retries:
                return e
            retry_count += 1
            # Add up to 25% jitter on top of the current backoff.
            delay = sleep + random.random() * 0.25 * sleep
            if isinstance(e, requests.exceptions.HTTPError) and (
                e.response is not None and e.response.status_code == 429
            ):
                err_str = (
                    f"Filestream rate limit exceeded, retrying in {delay:.1f} seconds. "
                )
                if retry_callback:
                    retry_callback(e.response.status_code, err_str)
                logger.info(err_str)
            else:
                logger.warning(
                    "requests_with_retry encountered retryable exception: %s. func: %s, args: %s, kwargs: %s",
                    e,
                    func,
                    args,
                    kwargs,
                )
            time.sleep(delay)
            sleep *= 2
            if sleep > MAX_SLEEP_SECONDS:
                sleep = MAX_SLEEP_SECONDS
        except requests.exceptions.RequestException as e:
            error_message = "unknown error"
            try:
                # Use the response attached to the exception. The local
                # ``response`` may be unbound (func raised before assigning
                # it) or left over from an earlier loop iteration.
                error_message = e.response.json()["error"]
            except Exception:
                # No response, or no parseable "error" field: keep the default.
                pass
            logger.exception(f"requests_with_retry error: {error_message}")
            return e
..interface.interface_queue import InterfaceQueue +from ..lib import handler_util, proto_util +from . import context, sample, tb_watcher +from .settings_static import SettingsStatic + +if TYPE_CHECKING: + from wandb.proto.wandb_internal_pb2 import MetricSummary + + +SummaryDict = Dict[str, Any] + +logger = logging.getLogger(__name__) + +# Update (March 5, 2024): Since ~2020/2021, when constructing the summary +# object, we had replaced the artifact path for media types with the latest +# artifact path. The primary purpose of this was to support live updating of +# media objects in the UI (since the default artifact path was fully qualified +# and would not update). However, in March of 2024, a bug was discovered with +# this approach which causes this path to be incorrect in cases where the media +# object is logged to another artifact before being logged to the run. Setting +# this to `False` disables this copy behavior. The impact is that users will +# need to refresh to see updates. Ironically, this updating behavior is not +# currently supported in the UI, so the impact of this change is minimal. 
+REPLACE_SUMMARY_ART_PATH_WITH_LATEST = False + + +def _dict_nested_set(target: Dict[str, Any], key_list: Sequence[str], v: Any) -> None: + # recurse down the dictionary structure: + + for k in key_list[:-1]: + target.setdefault(k, {}) + new_target = target.get(k) + if TYPE_CHECKING: + new_target = cast(Dict[str, Any], new_target) + target = new_target + # use the last element of the key to write the leaf: + target[key_list[-1]] = v + + +class HandleManager: + _consolidated_summary: SummaryDict + _sampled_history: Dict[str, sample.UniformSampleAccumulator] + _partial_history: Dict[str, Any] + _run_proto: Optional[RunRecord] + _settings: SettingsStatic + _record_q: "Queue[Record]" + _result_q: "Queue[Result]" + _stopped: Event + _writer_q: "Queue[Record]" + _interface: InterfaceQueue + _tb_watcher: Optional[tb_watcher.TBWatcher] + _metric_defines: Dict[str, MetricRecord] + _metric_globs: Dict[str, MetricRecord] + _metric_track: Dict[Tuple[str, ...], float] + _metric_copy: Dict[Tuple[str, ...], Any] + _track_time: Optional[float] + _accumulate_time: float + _run_start_time: Optional[float] + _context_keeper: context.ContextKeeper + + def __init__( + self, + settings: SettingsStatic, + record_q: "Queue[Record]", + result_q: "Queue[Result]", + stopped: Event, + writer_q: "Queue[Record]", + interface: InterfaceQueue, + context_keeper: context.ContextKeeper, + ) -> None: + self._settings = settings + self._record_q = record_q + self._result_q = result_q + self._stopped = stopped + self._writer_q = writer_q + self._interface = interface + self._context_keeper = context_keeper + + self._tb_watcher = None + self._step = 0 + + self._track_time = None + self._accumulate_time = 0 + self._run_start_time = None + + # keep track of summary from key/val updates + self._consolidated_summary = dict() + self._sampled_history = defaultdict(sample.UniformSampleAccumulator) + self._run_proto = None + self._partial_history = dict() + self._metric_defines = defaultdict(MetricRecord) + 
def handle(self, record: Record) -> None:
    """Register the record's context, then dispatch to ``handle_<record_type>``.

    The method name is built from the populated ``record_type`` oneof field;
    both a missing type and a missing handler trip an assertion.
    """
    self._context_keeper.add_from_record(record)
    kind = record.WhichOneof("record_type")
    assert kind
    method_name = f"handle_{kind}"
    method: Callable[[Record], None] = getattr(self, method_name, None)  # type: ignore
    assert method, f"unknown handle: {method_name}"  # type: ignore
    method(record)
def _save_summary(self, summary_dict: SummaryDict, flush: bool = False) -> None:
    """Serialize ``summary_dict`` into a SummaryRecord and dispatch it.

    Args:
        summary_dict: Summary keys mapped to JSON-serializable values.
        flush: When true, dispatch the summary as a plain record. Otherwise
            it is dispatched as a request, and only when not offline.
    """
    proto = SummaryRecord()
    for name, value in summary_dict.items():
        entry = proto.update.add()
        entry.key = name
        entry.value_json = json.dumps(value)

    if flush:
        self._dispatch_record(Record(summary=proto))
        return
    if self._settings._offline:
        return

    # Send this summary update as a request since we aren't persisting every update
    request = self._interface._make_request(
        summary_record=SummaryRecordRequest(summary=proto)
    )
    self._dispatch_record(request)
HistoryRecord, + ) -> None: + for item in history.item: + # TODO(jhr) save nested keys? + k = item.key + v = json.loads(item.value_json) + if isinstance(v, numbers.Real): + self._sampled_history[k].add(v) + + def _update_summary_metrics( + self, + s: "MetricSummary", + kl: List[str], + v: "numbers.Real", + float_v: float, + goal_max: Optional[bool], + ) -> bool: + updated = False + best_key: Optional[Tuple[str, ...]] = None + if s.none: + return False + if s.copy: + # non-key list copy already done in _update_summary + if len(kl) > 1: + _dict_nested_set(self._consolidated_summary, kl, v) + return True + if s.last: + last_key = tuple(kl + ["last"]) + old_last = self._metric_track.get(last_key) + if old_last is None or float_v != old_last: + self._metric_track[last_key] = float_v + _dict_nested_set(self._consolidated_summary, last_key, v) + updated = True + if s.best: + best_key = tuple(kl + ["best"]) + if s.max or best_key and goal_max: + max_key = tuple(kl + ["max"]) + old_max = self._metric_track.get(max_key) + if old_max is None or float_v > old_max: + self._metric_track[max_key] = float_v + if s.max: + _dict_nested_set(self._consolidated_summary, max_key, v) + updated = True + if best_key: + _dict_nested_set(self._consolidated_summary, best_key, v) + updated = True + # defaulting to minimize if goal is not specified + if s.min or best_key and not goal_max: + min_key = tuple(kl + ["min"]) + old_min = self._metric_track.get(min_key) + if old_min is None or float_v < old_min: + self._metric_track[min_key] = float_v + if s.min: + _dict_nested_set(self._consolidated_summary, min_key, v) + updated = True + if best_key: + _dict_nested_set(self._consolidated_summary, best_key, v) + updated = True + if s.mean: + tot_key = tuple(kl + ["tot"]) + num_key = tuple(kl + ["num"]) + avg_key = tuple(kl + ["mean"]) + tot = self._metric_track.get(tot_key, 0.0) + num = self._metric_track.get(num_key, 0) + tot += float_v + num += 1 + self._metric_track[tot_key] = tot + 
self._metric_track[num_key] = num + _dict_nested_set(self._consolidated_summary, avg_key, tot / num) + updated = True + return updated + + def _update_summary_leaf( + self, + kl: List[str], + v: Any, + d: Optional[MetricRecord] = None, + ) -> bool: + has_summary = d and d.HasField("summary") + if len(kl) == 1: + copy_key = tuple(kl) + old_copy = self._metric_copy.get(copy_key) + if old_copy is None or v != old_copy: + self._metric_copy[copy_key] = v + # Store copy metric if not specified, or copy behavior + if not has_summary or (d and d.summary.copy): + self._consolidated_summary[kl[0]] = v + return True + if not d: + return False + if not has_summary: + return False + if not isinstance(v, numbers.Real): + return False + if math.isnan(v): + return False + float_v = float(v) + goal_max = None + if d.goal: + goal_max = d.goal == d.GOAL_MAXIMIZE + if self._update_summary_metrics( + d.summary, kl=kl, v=v, float_v=float_v, goal_max=goal_max + ): + return True + return False + + def _update_summary_list( + self, + kl: List[str], + v: Any, + d: Optional[MetricRecord] = None, + ) -> bool: + metric_key = ".".join([k.replace(".", "\\.") for k in kl]) + d = self._metric_defines.get(metric_key, d) + # if the dict has _type key, it's a wandb table object + if isinstance(v, dict) and not handler_util.metric_is_wandb_dict(v): + updated = False + for nk, nv in v.items(): + if self._update_summary_list(kl=kl[:] + [nk], v=nv, d=d): + updated = True + return updated + # If the dict is a media object, update the pointer to the latest alias + elif ( + REPLACE_SUMMARY_ART_PATH_WITH_LATEST + and isinstance(v, dict) + and handler_util.metric_is_wandb_dict(v) + ): + if "_latest_artifact_path" in v and "artifact_path" in v: + # TODO: Make non-destructive? 
+ v["artifact_path"] = v["_latest_artifact_path"] + updated = self._update_summary_leaf(kl=kl, v=v, d=d) + return updated + + def _update_summary_media_objects(self, v: Dict[str, Any]) -> Dict[str, Any]: + # For now, non-recursive - just top level + for nk, nv in v.items(): + if REPLACE_SUMMARY_ART_PATH_WITH_LATEST and ( + isinstance(nv, dict) + and handler_util.metric_is_wandb_dict(nv) + and "_latest_artifact_path" in nv + and "artifact_path" in nv + ): + # TODO: Make non-destructive? + nv["artifact_path"] = nv["_latest_artifact_path"] + v[nk] = nv + return v + + def _update_summary(self, history_dict: Dict[str, Any]) -> List[str]: + # keep old behavior fast path if no define metrics have been used + if not self._metric_defines: + history_dict = self._update_summary_media_objects(history_dict) + self._consolidated_summary.update(history_dict) + return list(history_dict.keys()) + updated_keys = [] + for k, v in history_dict.items(): + if self._update_summary_list(kl=[k], v=v): + updated_keys.append(k) + return updated_keys + + def _history_assign_step( + self, + history: HistoryRecord, + history_dict: Dict[str, Any], + ) -> None: + has_step = history.HasField("step") + item = history.item.add() + item.key = "_step" + if has_step: + step = history.step.num + history_dict["_step"] = step + item.value_json = json.dumps(step) + self._step = step + 1 + else: + history_dict["_step"] = self._step + item.value_json = json.dumps(self._step) + self._step += 1 + + def _history_define_metric(self, hkey: str) -> Optional[MetricRecord]: + """Check for hkey match in glob metrics and return the defined metric.""" + # Dont define metric for internal metrics + if hkey.startswith("_"): + return None + for k, mglob in self._metric_globs.items(): + if k.endswith("*"): + if hkey.startswith(k[:-1]): + m = MetricRecord() + m.CopyFrom(mglob) + m.ClearField("glob_name") + m.options.defined = False + m.name = hkey + return m + return None + + def _history_update_leaf( + self, + kl: 
List[str], + v: Any, + history_dict: Dict[str, Any], + update_history: Dict[str, Any], + ) -> None: + hkey = ".".join([k.replace(".", "\\.") for k in kl]) + m = self._metric_defines.get(hkey) + if not m: + m = self._history_define_metric(hkey) + if not m: + return + mr = Record() + mr.metric.CopyFrom(m) + mr.control.local = True # Dont store this, just send it + self._handle_defined_metric(mr) + + if m.options.step_sync and m.step_metric: + if m.step_metric not in history_dict: + copy_key = tuple([m.step_metric]) + step = self._metric_copy.get(copy_key) + if step is not None: + update_history[m.step_metric] = step + + def _history_update_list( + self, + kl: List[str], + v: Any, + history_dict: Dict[str, Any], + update_history: Dict[str, Any], + ) -> None: + if isinstance(v, dict): + for nk, nv in v.items(): + self._history_update_list( + kl=kl[:] + [nk], + v=nv, + history_dict=history_dict, + update_history=update_history, + ) + return + self._history_update_leaf( + kl=kl, v=v, history_dict=history_dict, update_history=update_history + ) + + def _history_update( + self, + history: HistoryRecord, + history_dict: Dict[str, Any], + ) -> None: + # if syncing an old run, we can skip this logic + if history_dict.get("_step") is None: + self._history_assign_step(history, history_dict) + + update_history: Dict[str, Any] = {} + # Look for metric matches + if self._metric_defines or self._metric_globs: + for hkey, hval in history_dict.items(): + self._history_update_list([hkey], hval, history_dict, update_history) + + if update_history: + history_dict.update(update_history) + for k, v in update_history.items(): + item = history.item.add() + item.key = k + item.value_json = json.dumps(v) + + def handle_history(self, record: Record) -> None: + history_dict = proto_util.dict_from_proto_list(record.history.item) + + # Inject _runtime if it is not present + if history_dict is not None: + if "_runtime" not in history_dict: + self._history_assign_runtime(record.history, 
history_dict) + + self._history_update(record.history, history_dict) + self._dispatch_record(record) + self._save_history(record.history) + # update summary from history + updated_keys = self._update_summary(history_dict) + if updated_keys: + updated_items = {k: self._consolidated_summary[k] for k in updated_keys} + self._save_summary(updated_items) + + def _flush_partial_history( + self, + step: Optional[int] = None, + ) -> None: + if not self._partial_history: + return + + history = HistoryRecord() + for k, v in self._partial_history.items(): + item = history.item.add() + item.key = k + item.value_json = json.dumps(v) + if step is not None: + history.step.num = step + self.handle_history(Record(history=history)) + self._partial_history = {} + + def handle_request_sender_mark_report(self, record: Record) -> None: + self._dispatch_record(record, always_send=True) + + def handle_request_status_report(self, record: Record) -> None: + self._dispatch_record(record, always_send=True) + + def handle_request_partial_history(self, record: Record) -> None: + partial_history = record.request.partial_history + + flush = None + if partial_history.HasField("action"): + flush = partial_history.action.flush + + step = None + if partial_history.HasField("step"): + step = partial_history.step.num + + history_dict = proto_util.dict_from_proto_list(partial_history.item) + if step is not None: + if step < self._step: + if not self._dropped_history: + message = ( + "Step only supports monotonically increasing values, use define_metric to set a custom x " + f"axis. For details see: {url_registry.url('define-metric')}" + ) + self._internal_messages.warning.append(message) + self._dropped_history = True + message = ( + f"(User provided step: {step} is less than current step: {self._step}. " + f"Dropping entry: {history_dict})." 
+ ) + self._internal_messages.warning.append(message) + return + elif step > self._step: + self._flush_partial_history() + self._step = step + elif flush is None: + flush = True + + self._partial_history.update(history_dict) + + if flush: + self._flush_partial_history(self._step) + + def handle_summary(self, record: Record) -> None: + summary = record.summary + for item in summary.update: + if len(item.nested_key) > 0: + # we use either key or nested_key -- not both + assert item.key == "" + key = tuple(item.nested_key) + else: + # no counter-assertion here, because technically + # summary[""] is valid + key = (item.key,) + + target = self._consolidated_summary + + # recurse down the dictionary structure: + for prop in key[:-1]: + target = target[prop] + + # use the last element of the key to write the leaf: + target[key[-1]] = json.loads(item.value_json) + + for item in summary.remove: + if len(item.nested_key) > 0: + # we use either key or nested_key -- not both + assert item.key == "" + key = tuple(item.nested_key) + else: + # no counter-assertion here, because technically + # summary[""] is valid + key = (item.key,) + + target = self._consolidated_summary + + # recurse down the dictionary structure: + for prop in key[:-1]: + target = target[prop] + + # use the last element of the key to erase the leaf: + del target[key[-1]] + + self._save_summary(self._consolidated_summary) + + def handle_exit(self, record: Record) -> None: + if self._track_time is not None: + self._accumulate_time += time.time() - self._track_time + record.exit.runtime = int(self._accumulate_time) + self._dispatch_record(record, always_send=True) + + def handle_final(self, record: Record) -> None: + self._dispatch_record(record, always_send=True) + + def handle_preempting(self, record: Record) -> None: + self._dispatch_record(record) + + def handle_header(self, record: Record) -> None: + self._dispatch_record(record) + + def handle_footer(self, record: Record) -> None: + 
self._dispatch_record(record) + + def handle_metadata(self, record: Record) -> None: + self._dispatch_record(record) + + def handle_request_attach(self, record: Record) -> None: + result = proto_util._result_from_record(record) + attach_id = record.request.attach.attach_id + assert attach_id + assert self._run_proto + result.response.attach_response.run.CopyFrom(self._run_proto) + self._respond_result(result) + + def handle_request_log_artifact(self, record: Record) -> None: + self._dispatch_record(record) + + def handle_telemetry(self, record: Record) -> None: + self._dispatch_record(record) + + def handle_request_run_start(self, record: Record) -> None: + run_start = record.request.run_start + assert run_start + assert run_start.run + + self._run_proto = run_start.run + + self._run_start_time = run_start.run.start_time.ToMicroseconds() / 1e6 + + self._track_time = time.time() + if run_start.run.resumed and run_start.run.runtime: + self._accumulate_time = run_start.run.runtime + else: + self._accumulate_time = 0 + + self._tb_watcher = tb_watcher.TBWatcher( + self._settings, interface=self._interface, run_proto=run_start.run + ) + + if run_start.run.resumed or run_start.run.forked: + self._step = run_start.run.starting_step + result = proto_util._result_from_record(record) + self._respond_result(result) + + def handle_request_resume(self, record: Record) -> None: + if self._track_time is not None: + self._accumulate_time += time.time() - self._track_time + self._track_time = time.time() + + def handle_request_pause(self, record: Record) -> None: + if self._track_time is not None: + self._accumulate_time += time.time() - self._track_time + self._track_time = None + + def handle_request_poll_exit(self, record: Record) -> None: + self._dispatch_record(record, always_send=True) + + def handle_request_stop_status(self, record: Record) -> None: + self._dispatch_record(record) + + def handle_request_network_status(self, record: Record) -> None: + 
self._dispatch_record(record) + + def handle_request_internal_messages(self, record: Record) -> None: + result = proto_util._result_from_record(record) + result.response.internal_messages_response.messages.CopyFrom( + self._internal_messages + ) + self._internal_messages.Clear() + self._respond_result(result) + + def handle_request_status(self, record: Record) -> None: + result = proto_util._result_from_record(record) + self._respond_result(result) + + def handle_request_get_summary(self, record: Record) -> None: + result = proto_util._result_from_record(record) + for key, value in self._consolidated_summary.items(): + item = SummaryItem() + item.key = key + item.value_json = json.dumps(value) + result.response.get_summary_response.item.append(item) + self._respond_result(result) + + def handle_tbrecord(self, record: Record) -> None: + logger.info("handling tbrecord: %s", record) + if self._tb_watcher: + tbrecord = record.tbrecord + self._tb_watcher.add(tbrecord.log_dir, tbrecord.save, tbrecord.root_dir) + self._dispatch_record(record) + + def _handle_defined_metric(self, record: Record) -> None: + metric = record.metric + if metric._control.overwrite: + self._metric_defines[metric.name].CopyFrom(metric) + else: + self._metric_defines[metric.name].MergeFrom(metric) + + # before dispatching, make sure step_metric is defined, if not define it and + # dispatch it locally first + metric = self._metric_defines[metric.name] + if metric.step_metric and metric.step_metric not in self._metric_defines: + m = MetricRecord(name=metric.step_metric) + self._metric_defines[metric.step_metric] = m + mr = Record() + mr.metric.CopyFrom(m) + mr.control.local = True # Don't store this, just send it + self._dispatch_record(mr) + + self._dispatch_record(record) + + def _handle_glob_metric(self, record: Record) -> None: + metric = record.metric + if metric._control.overwrite: + self._metric_globs[metric.glob_name].CopyFrom(metric) + else: + 
self._metric_globs[metric.glob_name].MergeFrom(metric) + self._dispatch_record(record) + + def handle_metric(self, record: Record) -> None: + """Handle MetricRecord. + + Walkthrough of the life of a MetricRecord: + + Metric defined: + - run.define_metric() parses arguments create wandb_metric.Metric + - build MetricRecord publish to interface + - handler (this function) keeps list of metrics published: + - self._metric_defines: Fully defined metrics + - self._metric_globs: metrics that have a wildcard + - dispatch writer and sender thread + - writer: records are saved to persistent store + - sender: fully defined metrics get mapped into metadata for UI + + History logged: + - handle_history + - check if metric matches _metric_defines + - if not, check if metric matches _metric_globs + - if _metric globs match, generate defined metric and call _handle_metric + + Args: + record (Record): Metric record to process + """ + if record.metric.name: + self._handle_defined_metric(record) + elif record.metric.glob_name: + self._handle_glob_metric(record) + + def handle_request_sampled_history(self, record: Record) -> None: + result = proto_util._result_from_record(record) + for key, sampled in self._sampled_history.items(): + item = SampledHistoryItem() + item.key = key + values: Iterable[Any] = sampled.get() + if all(isinstance(i, numbers.Integral) for i in values): + try: + item.values_int.extend(values) + except ValueError: + # it is safe to ignore these as this is for display information + pass + elif all(isinstance(i, numbers.Real) for i in values): + item.values_float.extend(values) + result.response.sampled_history_response.item.append(item) + self._respond_result(result) + + def handle_request_keepalive(self, record: Record) -> None: + """Handle a keepalive request. + + Keepalive is a noop, we just want to verify transport is alive. 
+ """ + + def handle_request_run_status(self, record: Record) -> None: + self._dispatch_record(record, always_send=True) + + def handle_request_shutdown(self, record: Record) -> None: + # TODO(jhr): should we drain things and stop new requests from coming in? + result = proto_util._result_from_record(record) + self._respond_result(result) + self._stopped.set() + + def handle_request_operations(self, record: Record) -> None: + """No-op. Not implemented for the legacy-service.""" + self._respond_result(proto_util._result_from_record(record)) + + def finish(self) -> None: + logger.info("shutting down handler") + if self._tb_watcher: + self._tb_watcher.finish() + # self._context_keeper._debug_print_orphans() + + def __next__(self) -> Record: + return self._record_q.get(block=True) + + next = __next__ + + def _history_assign_runtime( + self, + history: HistoryRecord, + history_dict: Dict[str, Any], + ) -> None: + # _runtime calculation is meaningless if there is no _timestamp + if "_timestamp" not in history_dict: + return + # if it is offline sync, self._run_start_time is None + # in that case set it to the first tfevent timestamp + if self._run_start_time is None: + self._run_start_time = history_dict["_timestamp"] + history_dict["_runtime"] = history_dict["_timestamp"] - self._run_start_time + item = history.item.add() + item.key = "_runtime" + item.value_json = json.dumps(history_dict[item.key]) diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/incremental_table_util.py b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/incremental_table_util.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d53f7eac0cdc19328f05937427bbd173e0bdfd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/incremental_table_util.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import time +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from wandb import Table + from wandb.sdk.artifacts.artifact import 
Artifact + + from ..wandb_run import Run as LocalRun + +ART_TYPE = "wandb-run-incremental-table" + + +def _get_artifact_name(run: LocalRun, key: str) -> str: + from wandb.sdk.artifacts._internal_artifact import sanitize_artifact_name + + return sanitize_artifact_name(f"run-{run.id}-incr-{key}") + + +def init_artifact(run: LocalRun, sanitized_key: str) -> Artifact: + """Initialize a new artifact for an incremental table. + + Args: + run: The wandb run associated with this artifact + sanitized_key: Sanitized string key to identify the table + + Returns: + A wandb Artifact configured for incremental table storage + """ + from wandb.sdk.artifacts._internal_artifact import InternalArtifact + + artifact = InternalArtifact( + _get_artifact_name(run, sanitized_key), + ART_TYPE, + incremental=True, + ) + return artifact + + +def get_entry_name(incr_table: Table, key: str) -> str: + """Generate a unique entry name for a table increment. + + Args: + run: The wandb run associated with this table + incr_table: The incremental table being updated + key: String key for the table entry + + Returns: + A unique string name for the table entry + """ + epoch = time.time_ns() // 1_000_000 + return f"{incr_table._increment_num}-{epoch}.{key}" diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/internal_api.py b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/internal_api.py new file mode 100644 index 0000000000000000000000000000000000000000..9d5e5b9aa56dd907b29ff46bda6915cf299e8aa3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/internal_api.py @@ -0,0 +1,4716 @@ +from __future__ import annotations + +import base64 +import datetime +import functools +import http.client +import json +import logging +import os +import re +import socket +import sys +import threading +from copy import deepcopy +from pathlib import Path +from typing import ( + IO, + TYPE_CHECKING, + Any, + Callable, + Iterable, + Literal, + Mapping, + MutableMapping, + 
NamedTuple, + Sequence, + TextIO, + Union, + overload, +) + +import click +from wandb_gql import Client, gql +from wandb_gql.client import RetryError +from wandb_graphql.language.ast import Document + +import wandb +from wandb import env, util +from wandb.analytics import get_sentry +from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messages +from wandb.errors import AuthenticationError, CommError, UnsupportedError, UsageError +from wandb.integration.sagemaker import parse_sm_secrets +from wandb.proto.wandb_internal_pb2 import ServerFeature +from wandb.sdk import wandb_setup +from wandb.sdk.internal import settings_static +from wandb.sdk.internal._generated import SERVER_FEATURES_QUERY_GQL, ServerFeaturesQuery +from wandb.sdk.lib.gql_request import GraphQLSession +from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64 + +from ..lib import credentials, retry +from ..lib.filenames import DIFF_FNAME, METADATA_FNAME +from . import context +from .progress import Progress + +logger = logging.getLogger(__name__) + +LAUNCH_DEFAULT_PROJECT = "model-registry" + +if TYPE_CHECKING: + from typing import Literal, TypedDict + + import requests + + from .progress import ProgressFn + + class CreateArtifactFileSpecInput(TypedDict, total=False): + """Corresponds to `type CreateArtifactFileSpecInput` in schema.graphql.""" + + artifactID: str + name: str + md5: str + mimetype: str | None + artifactManifestID: str | None + uploadPartsInput: list[dict[str, object]] | None + + class CreateArtifactFilesResponseFile(TypedDict): + id: str + name: str + displayName: str + uploadUrl: str | None + uploadHeaders: Sequence[str] + uploadMultipartUrls: UploadPartsResponse + storagePath: str + artifact: CreateArtifactFilesResponseFileNode + + class CreateArtifactFilesResponseFileNode(TypedDict): + id: str + + class UploadPartsResponse(TypedDict): + uploadUrlParts: list[UploadUrlParts] + uploadID: str + + class UploadUrlParts(TypedDict): + partNumber: int + uploadUrl: str + 
+ class CompleteMultipartUploadArtifactInput(TypedDict): + """Corresponds to `type CompleteMultipartUploadArtifactInput` in schema.graphql.""" + + completeMultipartAction: str + completedParts: dict[int, str] + artifactID: str + storagePath: str + uploadID: str + md5: str + + class CompleteMultipartUploadArtifactResponse(TypedDict): + digest: str + + class DefaultSettings(TypedDict, total=False): + section: str + git_remote: str + ignore_globs: list[str] + base_url: str + root_dir: str | None + api_key: str | None + entity: str | None + organization: str | None + project: str | None + _extra_http_headers: Mapping[str, str] | None + _proxies: Mapping[str, str] | None + + _Response = MutableMapping + SweepState = Literal["RUNNING", "PAUSED", "CANCELED", "FINISHED"] + Number = Union[int, float] + +httpclient_logger = logging.getLogger("http.client") +if os.environ.get("WANDB_DEBUG"): + httpclient_logger.setLevel(logging.DEBUG) + + +def check_httpclient_logger_handler() -> None: + # Only enable http.client logging if WANDB_DEBUG is set + if not os.environ.get("WANDB_DEBUG"): + return + if httpclient_logger.handlers: + return + + # Enable HTTPConnection debug logging to the logging framework + level = logging.DEBUG + + def httpclient_log(*args: Any) -> None: + httpclient_logger.log(level, " ".join(args)) + + # mask the print() built-in in the http.client module to use logging instead + http.client.print = httpclient_log # type: ignore[attr-defined] + # enable debugging + http.client.HTTPConnection.debuglevel = 1 + + root_logger = logging.getLogger("wandb") + if root_logger.handlers: + httpclient_logger.addHandler(root_logger.handlers[0]) + + +class _ThreadLocalData(threading.local): + context: context.Context | None + + def __init__(self) -> None: + self.context = None + + +class _OrgNames(NamedTuple): + entity_name: str + display_name: str + + +def _match_org_with_fetched_org_entities( + organization: str, orgs: Sequence[_OrgNames] +) -> str: + """Match the 
organization provided in the path with the org entity or org name of the input entity. + + Args: + organization: The organization name to match + orgs: list of tuples containing (org_entity_name, org_display_name) + + Returns: + str: The matched org entity name + + Raises: + ValueError: If no matching organization is found or if multiple orgs exist without a match + """ + for org_names in orgs: + if organization in org_names: + return org_names.entity_name + + if len(orgs) == 1: + raise ValueError( + f"Expecting the organization name or entity name to match {orgs[0].display_name!r} " + f"and cannot be linked/fetched with {organization!r}. " + "Please update the target path with the correct organization name." + ) + + raise ValueError( + "Personal entity belongs to multiple organizations " + f"and cannot be linked/fetched with {organization!r}. " + "Please update the target path with the correct organization name " + "or use a team entity in the entity settings." + ) + + +class Api: + """W&B Internal Api wrapper. + + Note: + Settings are automatically overridden by looking for + a `wandb/settings` file in the current working directory or its parent + directory. If none can be found, we look in the current user's home + directory. + + Args: + default_settings(dict, optional): If you aren't using a settings + file, or you wish to override the section to use in the settings file + Override the settings here. 
+ """ + + HTTP_TIMEOUT = env.get_http_timeout(20) + FILE_PUSHER_TIMEOUT = env.get_file_pusher_timeout() + _global_context: context.Context + _local_data: _ThreadLocalData + + def __init__( + self, + default_settings: ( + wandb.Settings # + | settings_static.SettingsStatic + | DefaultSettings + | None + ) = None, + load_settings: bool = True, + retry_timedelta: datetime.timedelta | None = None, + environ: MutableMapping[str, str] = os.environ, + retry_callback: Callable[[int, str], Any] | None = None, + api_key: str | None = None, + ) -> None: + import requests + + self._environ = environ + self._global_context = context.Context() + self._local_data = _ThreadLocalData() + + default_overrides: dict[str, Any] = ( + dict(default_settings) if default_settings else {} + ) + self.default_settings: DefaultSettings = { + "section": default_overrides.get("section", "default"), + "git_remote": default_overrides.get("git_remote", "origin"), + "ignore_globs": default_overrides.get("ignore_globs", []), + "base_url": default_overrides.get("base_url", "https://api.wandb.ai"), + "root_dir": default_overrides.get("root_dir", None), + "api_key": default_overrides.get("api_key", None), + "entity": default_overrides.get("entity", None), + "organization": default_overrides.get("organization", None), + "project": default_overrides.get("project", None), + "_extra_http_headers": default_overrides.get("_extra_http_headers", None), + "_proxies": default_overrides.get("_proxies", None), + } + + if load_settings: + global_settings = wandb_setup.singleton().settings + if root_dir := self.default_settings["root_dir"]: + global_settings = global_settings.model_copy() + global_settings.root_dir = root_dir + + self._settings = global_settings.read_system_settings().all() + else: + self._settings = {} + + # Mutable settings set by the _file_stream_api + self.dynamic_settings = { + "system_sample_seconds": 2, + "system_samples": 15, + "heartbeat_seconds": 30, + } + + self.retry_timedelta = 
retry_timedelta or datetime.timedelta(days=7) + self.retry_uploads = 10 + + # todo: remove these hacky hacks after settings refactor is complete + # keeping this code here to limit scope and so that it is easy to remove later + self._extra_http_headers = self.settings("_extra_http_headers") or json.loads( + self._environ.get("WANDB__EXTRA_HTTP_HEADERS", "{}") + ) + + auth = None + api_key = api_key or self.default_settings.get("api_key") + if api_key: + auth = ("api", api_key) + elif self.access_token is not None: + self._extra_http_headers["Authorization"] = f"Bearer {self.access_token}" + else: + auth = ("api", self.api_key or "") + + proxies = self.settings("_proxies") or json.loads( + self._environ.get("WANDB__PROXIES", "{}") + ) + + self.client = Client( + transport=GraphQLSession( + headers={ + "User-Agent": self.user_agent, + "X-WANDB-USERNAME": env.get_username(env=self._environ), + "X-WANDB-USER-EMAIL": env.get_user_email(env=self._environ), + **self._extra_http_headers, + }, + use_json=True, + # this timeout won't apply when the DNS lookup fails. 
in that case, it will be 60s + # https://bugs.python.org/issue22889 + timeout=self.HTTP_TIMEOUT, + auth=auth, + url=f"{self.settings('base_url')}/graphql", + proxies=proxies, + ) + ) + + self.retry_callback = retry_callback + self._retry_gql = retry.Retry( + self.execute, + retry_timedelta=retry_timedelta, + check_retry_fn=util.no_retry_auth, + retryable_exceptions=(RetryError, requests.RequestException), + retry_callback=retry_callback, + ) + self._current_run_id: str | None = None + self._file_stream_api = None + self._upload_file_session = requests.Session() + if self.FILE_PUSHER_TIMEOUT: + self._upload_file_session.put = functools.partial( # type: ignore + self._upload_file_session.put, + timeout=self.FILE_PUSHER_TIMEOUT, + ) + if proxies: + self._upload_file_session.proxies.update(proxies) + # This Retry class is initialized once for each Api instance, so this + # defaults to retrying 1 million times per process or 7 days + self.upload_file_retry = normalize_exceptions( + retry.retriable(retry_timedelta=retry_timedelta)(self.upload_file) + ) + self.upload_multipart_file_chunk_retry = normalize_exceptions( + retry.retriable(retry_timedelta=retry_timedelta)( + self.upload_multipart_file_chunk + ) + ) + self._client_id_mapping: dict[str, str] = {} + # Large file uploads to azure can optionally use their SDK + self._azure_blob_module = util.get_module("azure.storage.blob") + + self.query_types: list[str] | None = None + self.mutation_types: list[str] | None = None + self.server_info_types: list[str] | None = None + self.server_use_artifact_input_info: list[str] | None = None + self.server_create_artifact_input_info: list[str] | None = None + self.server_artifact_fields_info: list[str] | None = None + self.server_organization_type_fields_info: list[str] | None = None + self.server_supports_enabling_artifact_usage_tracking: bool | None = None + self._max_cli_version: str | None = None + self._server_settings_type: list[str] | None = None + 
self.fail_run_queue_item_input_info: list[str] | None = None + self.create_launch_agent_input_info: list[str] | None = None + self.server_create_run_queue_supports_drc: bool | None = None + self.server_create_run_queue_supports_priority: bool | None = None + self.server_supports_template_variables: bool | None = None + self.server_push_to_run_queue_supports_priority: bool | None = None + + self._server_features_cache: dict[str, bool] | None = None + + def gql(self, *args: Any, **kwargs: Any) -> Any: + ret = self._retry_gql( + *args, + retry_cancel_event=self.context.cancel_event, + **kwargs, + ) + return ret + + def set_local_context(self, api_context: context.Context | None) -> None: + self._local_data.context = api_context + + def clear_local_context(self) -> None: + self._local_data.context = None + + @property + def context(self) -> context.Context: + return self._local_data.context or self._global_context + + def reauth(self) -> None: + """Ensure the current api key is set in the transport.""" + self.client.transport.session.auth = ("api", self.api_key or "") + + def relocate(self) -> None: + """Ensure the current api points to the right server.""" + self.client.transport.url = "{}/graphql".format(self.settings("base_url")) + + def execute(self, *args: Any, **kwargs: Any) -> _Response: + """Wrapper around execute that logs in cases of failure.""" + import requests + + try: + return self.client.execute(*args, **kwargs) # type: ignore + except requests.exceptions.HTTPError as err: + response = err.response + assert response is not None + logger.exception("Error executing GraphQL.") + for error in parse_backend_error_messages(response): + wandb.termerror(f"Error while calling W&B API: {error} ({response})") + raise + + def validate_api_key(self) -> bool: + """Returns whether the API key stored on initialization is valid.""" + res = self.gql(gql("query { viewer { id } }")) + return res is not None and res["viewer"] is not None + + def set_current_run_id(self, 
run_id: str) -> None: + self._current_run_id = run_id + + @property + def current_run_id(self) -> str | None: + return self._current_run_id + + @property + def user_agent(self) -> str: + return f"W&B Internal Client {wandb.__version__}" + + @property + def api_key(self) -> str | None: + from wandb.sdk.lib import wbauth + + if ( # + (auth := wbauth.session_credentials(host=self.api_url)) + and isinstance(auth, wbauth.AuthApiKey) + ): + return auth.api_key + + return ( + os.getenv(env.API_KEY) + or wbauth.read_netrc_auth(host=self.api_url) + or parse_sm_secrets().get(env.API_KEY) + or self.default_settings.get("api_key") + ) + + @property + def access_token(self) -> str | None: + """Retrieves an access token for authentication. + + This function attempts to exchange an identity token for a temporary + access token from the server, and save it to the credentials file. + It uses the path to the identity token as defined in the environment + variables. If the environment variable is not set, it returns None. + + Returns: + str | None: The access token if available, otherwise None if + no identity token is supplied. + Raises: + AuthenticationError: If the path to the identity token is not found. 
+ """ + token_file_str = self._environ.get(env.IDENTITY_TOKEN_FILE) + if not token_file_str: + return None + + token_file = Path(token_file_str) + if not token_file.exists(): + raise AuthenticationError(f"Identity token file not found: {token_file}") + + base_url = self.settings("base_url") + credentials_file = env.get_credentials_file( + str(credentials.DEFAULT_WANDB_CREDENTIALS_FILE), self._environ + ) + return credentials.access_token(base_url, token_file, credentials_file) + + @property + def api_url(self) -> str: + return self.settings("base_url") # type: ignore + + @property + def app_url(self) -> str: + return wandb.util.app_url(self.api_url) + + @property + def default_entity(self) -> str: + return self.viewer().get("entity") # type: ignore + + @overload + def settings(self, key: None = None) -> dict[str, Any]: ... + + @overload + def settings(self, key: str) -> Any: ... + + def settings(self, key: str | None = None) -> Any: + """The settings overridden from the wandb/settings file. 
+ + Args: + key (str, optional): If provided only this setting is returned + section (str, optional): If provided this section of the setting file is + used, defaults to "default" + + Returns: + A dict with the current settings + + { + "entity": "models", + "base_url": "https://api.wandb.ai", + "project": None, + "organization": "my-org", + } + """ + result: dict[str, Any] = dict(self.default_settings) + result.update(self._settings) + result.update( + { + "entity": env.get_entity( + self._settings.get( + "entity", + result.get("entity"), + ), + env=self._environ, + ), + "organization": env.get_organization( + self._settings.get( + "organization", + result.get("organization"), + ), + env=self._environ, + ), + "project": env.get_project( + self._settings.get( + "project", + result.get("project"), + ), + env=self._environ, + ), + "base_url": env.get_base_url( + self._settings.get( + "base_url", + result.get("base_url"), + ), + env=self._environ, + ), + } + ) + + return result if key is None else result[key] + + def clear_setting(self, key: str) -> None: + self._settings.pop(key, None) + + def set_setting(self, key: str, value: Any) -> None: + self._settings[key] = value + + if key == "entity": + env.set_entity(value, env=self._environ) + elif key == "project": + env.set_project(value, env=self._environ) + elif key == "base_url": + self.relocate() + + def parse_slug( + self, slug: str, project: str | None = None, run: str | None = None + ) -> tuple[str, str]: + """Parse a slug into a project and run. 
+ + Args: + slug (str): The slug to parse + project (str, optional): The project to use, if not provided it will be + inferred from the slug + run (str, optional): The run to use, if not provided it will be inferred + from the slug + + Returns: + A dict with the project and run + """ + if slug and "/" in slug: + parts = slug.split("/") + project = parts[0] + run = parts[1] + else: + project = project or self.settings().get("project") + if project is None: + raise CommError("No default project configured.") + run = run or slug or self.current_run_id or env.get_run(env=self._environ) + assert run, "run must be specified" + return project, run + + @normalize_exceptions + def server_info_introspection(self) -> tuple[list[str], list[str], list[str]]: + query_string = """ + query ProbeServerCapabilities { + QueryType: __type(name: "Query") { + ...fieldData + } + MutationType: __type(name: "Mutation") { + ...fieldData + } + ServerInfoType: __type(name: "ServerInfo") { + ...fieldData + } + } + + fragment fieldData on __Type { + fields { + name + } + } + """ + if ( + self.query_types is None + or self.mutation_types is None + or self.server_info_types is None + ): + query = gql(query_string) + res = self.gql(query) + + self.query_types = [ + field.get("name", "") + for field in res.get("QueryType", {}).get("fields", [{}]) + ] + self.mutation_types = [ + field.get("name", "") + for field in res.get("MutationType", {}).get("fields", [{}]) + ] + self.server_info_types = [ + field.get("name", "") + for field in res.get("ServerInfoType", {}).get("fields", [{}]) + ] + return self.query_types, self.server_info_types, self.mutation_types + + @normalize_exceptions + def server_settings_introspection(self) -> None: + query_string = """ + query ProbeServerSettings { + ServerSettingsType: __type(name: "ServerSettings") { + ...fieldData + } + } + + fragment fieldData on __Type { + fields { + name + } + } + """ + if self._server_settings_type is None: + query = gql(query_string) + res = 
self.gql(query) + self._server_settings_type = ( + [ + field.get("name", "") + for field in res.get("ServerSettingsType", {}).get("fields", [{}]) + ] + if res + else [] + ) + + def server_use_artifact_input_introspection(self) -> list: + query_string = """ + query ProbeServerUseArtifactInput { + UseArtifactInputInfoType: __type(name: "UseArtifactInput") { + name + inputFields { + name + } + } + } + """ + + if self.server_use_artifact_input_info is None: + query = gql(query_string) + res = self.gql(query) + self.server_use_artifact_input_info = [ + field.get("name", "") + for field in res.get("UseArtifactInputInfoType", {}).get( + "inputFields", [{}] + ) + ] + return self.server_use_artifact_input_info + + @normalize_exceptions + def launch_agent_introspection(self) -> str | None: + query = gql( + """ + query LaunchAgentIntrospection { + LaunchAgentType: __type(name: "LaunchAgent") { + name + } + } + """ + ) + + res = self.gql(query) + return res.get("LaunchAgentType") or None + + @normalize_exceptions + def create_run_queue_introspection(self) -> tuple[bool, bool, bool]: + _, _, mutations = self.server_info_introspection() + query_string = """ + query ProbeCreateRunQueueInput { + CreateRunQueueInputType: __type(name: "CreateRunQueueInput") { + name + inputFields { + name + } + } + } + """ + if ( + self.server_create_run_queue_supports_drc is None + or self.server_create_run_queue_supports_priority is None + ): + query = gql(query_string) + res = self.gql(query) + if res is None: + raise CommError("Could not get CreateRunQueue input from GQL.") + self.server_create_run_queue_supports_drc = "defaultResourceConfigID" in [ + x["name"] + for x in ( + res.get("CreateRunQueueInputType", {}).get("inputFields", [{}]) + ) + ] + self.server_create_run_queue_supports_priority = "prioritizationMode" in [ + x["name"] + for x in ( + res.get("CreateRunQueueInputType", {}).get("inputFields", [{}]) + ) + ] + return ( + "createRunQueue" in mutations, + 
self.server_create_run_queue_supports_drc, + self.server_create_run_queue_supports_priority, + ) + + @normalize_exceptions + def upsert_run_queue_introspection(self) -> bool: + _, _, mutations = self.server_info_introspection() + return "upsertRunQueue" in mutations + + @normalize_exceptions + def push_to_run_queue_introspection(self) -> tuple[bool, bool]: + query_string = """ + query ProbePushToRunQueueInput { + PushToRunQueueInputType: __type(name: "PushToRunQueueInput") { + name + inputFields { + name + } + } + } + """ + + if ( + self.server_supports_template_variables is None + or self.server_push_to_run_queue_supports_priority is None + ): + query = gql(query_string) + res = self.gql(query) + self.server_supports_template_variables = "templateVariableValues" in [ + x["name"] + for x in ( + res.get("PushToRunQueueInputType", {}).get("inputFields", [{}]) + ) + ] + self.server_push_to_run_queue_supports_priority = "priority" in [ + x["name"] + for x in ( + res.get("PushToRunQueueInputType", {}).get("inputFields", [{}]) + ) + ] + + return ( + self.server_supports_template_variables, + self.server_push_to_run_queue_supports_priority, + ) + + @normalize_exceptions + def create_default_resource_config_introspection(self) -> bool: + _, _, mutations = self.server_info_introspection() + return "createDefaultResourceConfig" in mutations + + @normalize_exceptions + def fail_run_queue_item_introspection(self) -> bool: + _, _, mutations = self.server_info_introspection() + return "failRunQueueItem" in mutations + + @normalize_exceptions + def fail_run_queue_item_fields_introspection(self) -> list: + if self.fail_run_queue_item_input_info: + return self.fail_run_queue_item_input_info + query_string = """ + query ProbeServerFailRunQueueItemInput { + FailRunQueueItemInputInfoType: __type(name:"FailRunQueueItemInput") { + inputFields{ + name + } + } + } + """ + + query = gql(query_string) + res = self.gql(query) + + self.fail_run_queue_item_input_info = [ + field.get("name", 
"") + for field in res.get("FailRunQueueItemInputInfoType", {}).get( + "inputFields", [{}] + ) + ] + return self.fail_run_queue_item_input_info + + @normalize_exceptions + def fail_run_queue_item( + self, + run_queue_item_id: str, + message: str, + stage: str, + file_paths: list[str] | None = None, + ) -> bool: + if not self.fail_run_queue_item_introspection(): + return False + variable_values: dict[str, str | (list[str] | None)] = { + "runQueueItemId": run_queue_item_id, + } + if "message" in self.fail_run_queue_item_fields_introspection(): + variable_values.update({"message": message, "stage": stage}) + if file_paths is not None: + variable_values["filePaths"] = file_paths + mutation_string = """ + mutation failRunQueueItem($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) { + failRunQueueItem( + input: { + runQueueItemId: $runQueueItemId + message: $message + stage: $stage + filePaths: $filePaths + } + ) { + success + } + } + """ + else: + mutation_string = """ + mutation failRunQueueItem($runQueueItemId: ID!) { + failRunQueueItem( + input: { + runQueueItemId: $runQueueItemId + } + ) { + success + } + } + """ + + mutation = gql(mutation_string) + response = self.gql(mutation, variable_values=variable_values) + result: bool = response["failRunQueueItem"]["success"] + return result + + @normalize_exceptions + def update_run_queue_item_warning_introspection(self) -> bool: + _, _, mutations = self.server_info_introspection() + return "updateRunQueueItemWarning" in mutations + + def _server_features(self) -> dict[str, bool]: + # NOTE: Avoid caching via `@cached_property`, due to undocumented + # locking behavior before Python 3.12. 
+ # See: https://github.com/python/cpython/issues/87634 + query = gql(SERVER_FEATURES_QUERY_GQL) + try: + response = self.gql(query) + except Exception as e: + # Unfortunately we currently have to match on the text of the error message, + # as the `gql` client raises `Exception` rather than a more specific error. + if 'Cannot query field "features" on type "ServerInfo".' in str(e): + self._server_features_cache = {} + else: + raise + else: + info = ServerFeaturesQuery.model_validate(response).server_info + if info and (feats := info.features): + self._server_features_cache = {f.name: f.is_enabled for f in feats if f} + else: + self._server_features_cache = {} + return self._server_features_cache + + def _server_supports(self, feature: int | str) -> bool: + """Return whether the current server supports the given feature. + + This also caches the underlying lookup of server feature flags, + and it maps {feature_name (str) -> is_enabled (bool)}. + + Good to use for features that have a fallback mechanism for older servers. + """ + # If we're given the protobuf enum value, convert to a string name. 
+ # NOTE: We deliberately use names (str) instead of enum values (int) + # as the keys here, since: + # - the server identifies features by their name, rather than (client-side) enum value + # - the defined list of client-side flags may be behind the server-side list of flags + key = ServerFeature.Name(feature) if isinstance(feature, int) else feature + return self._server_features().get(key) or False + + @normalize_exceptions + def update_run_queue_item_warning( + self, + run_queue_item_id: str, + message: str, + stage: str, + file_paths: list[str] | None = None, + ) -> bool: + if not self.update_run_queue_item_warning_introspection(): + return False + mutation = gql( + """ + mutation updateRunQueueItemWarning($runQueueItemId: ID!, $message: String!, $stage: String!, $filePaths: [String!]) { + updateRunQueueItemWarning( + input: { + runQueueItemId: $runQueueItemId + message: $message + stage: $stage + filePaths: $filePaths + } + ) { + success + } + } + """ + ) + response = self.gql( + mutation, + variable_values={ + "runQueueItemId": run_queue_item_id, + "message": message, + "stage": stage, + "filePaths": file_paths, + }, + ) + result: bool = response["updateRunQueueItemWarning"]["success"] + return result + + @normalize_exceptions + def viewer(self) -> dict[str, Any]: + query = gql( + """ + query Viewer{ + viewer { + id + entity + username + flags + teams { + edges { + node { + name + } + } + } + } + } + """ + ) + res = self.gql(query) + return res.get("viewer") or {} + + @normalize_exceptions + def max_cli_version(self) -> str | None: + if self._max_cli_version is not None: + return self._max_cli_version + + query_types, server_info_types, _ = self.server_info_introspection() + cli_version_exists = ( + "serverInfo" in query_types and "cliVersionInfo" in server_info_types + ) + if not cli_version_exists: + return None + + _, server_info = self.viewer_server_info() + self._max_cli_version = server_info.get("cliVersionInfo", {}).get( + "max_cli_version" + ) + 
return self._max_cli_version + + @normalize_exceptions + def viewer_server_info(self) -> tuple[dict[str, Any], dict[str, Any]]: + local_query = """ + latestLocalVersionInfo { + outOfDate + latestVersionString + versionOnThisInstanceString + } + """ + cli_query = """ + serverInfo { + cliVersionInfo + _LOCAL_QUERY_ + } + """ + query_template = """ + query Viewer{ + viewer { + id + entity + username + email + flags + teams { + edges { + node { + name + } + } + } + } + _CLI_QUERY_ + } + """ + query_types, server_info_types, _ = self.server_info_introspection() + + cli_version_exists = ( + "serverInfo" in query_types and "cliVersionInfo" in server_info_types + ) + + local_version_exists = ( + "serverInfo" in query_types + and "latestLocalVersionInfo" in server_info_types + ) + + cli_query_string = "" if not cli_version_exists else cli_query + local_query_string = "" if not local_version_exists else local_query + + query_string = query_template.replace("_CLI_QUERY_", cli_query_string).replace( + "_LOCAL_QUERY_", local_query_string + ) + query = gql(query_string) + res = self.gql(query) + return res.get("viewer") or {}, res.get("serverInfo") or {} + + @normalize_exceptions + def list_projects(self, entity: str | None = None) -> list[dict[str, str]]: + """List projects in W&B scoped by entity. + + Args: + entity (str, optional): The entity to scope this project to. + + Returns: + [{"id","name","description"}] + """ + query = gql( + """ + query EntityProjects($entity: String) { + models(first: 10, entityName: $entity) { + edges { + node { + id + name + description + } + } + } + } + """ + ) + project_list: list[dict[str, str]] = self._flatten_edges( + self.gql( + query, variable_values={"entity": entity or self.settings("entity")} + )["models"] + ) + return project_list + + @normalize_exceptions + def project(self, project: str, entity: str | None = None) -> _Response: + """Retrieve project. 
+ + Args: + project (str): The project to get details for + entity (str, optional): The entity to scope this project to. + + Returns: + [{"id","name","repo","dockerImage","description"}] + """ + query = gql( + """ + query ProjectDetails($entity: String, $project: String) { + model(name: $project, entityName: $entity) { + id + name + repo + dockerImage + description + } + } + """ + ) + response: _Response = self.gql( + query, variable_values={"entity": entity, "project": project} + )["model"] + return response + + @normalize_exceptions + def sweep( + self, + sweep: str, + specs: str, + project: str | None = None, + entity: str | None = None, + ) -> dict[str, Any]: + """Retrieve sweep. + + Args: + sweep (str): The sweep to get details for + specs (str): history specs + project (str, optional): The project to scope this sweep to. + entity (str, optional): The entity to scope this sweep to. + + Returns: + [{"id","name","repo","dockerImage","description"}] + """ + query = gql( + """ + query SweepWithRuns($entity: String, $project: String, $sweep: String!, $specs: [JSONString!]!) 
{ + project(name: $project, entityName: $entity) { + sweep(sweepName: $sweep) { + id + name + method + state + description + config + createdAt + heartbeatAt + updatedAt + earlyStopJobRunning + bestLoss + controller + scheduler + runs { + edges { + node { + name + state + config + exitcode + heartbeatAt + shouldStop + failed + stopped + running + summaryMetrics + sampledHistory(specs: $specs) + } + } + } + } + } + } + """ + ) + entity = entity or self.settings("entity") + project = project or self.settings("project") + response = self.gql( + query, + variable_values={ + "entity": entity, + "project": project, + "sweep": sweep, + "specs": specs, + }, + ) + if response["project"] is None or response["project"]["sweep"] is None: + raise ValueError(f"Sweep {entity}/{project}/{sweep} not found") + data: dict[str, Any] = response["project"]["sweep"] + if data: + data["runs"] = self._flatten_edges(data["runs"]) + return data + + @normalize_exceptions + def list_runs( + self, project: str, entity: str | None = None + ) -> list[dict[str, str]]: + """List runs in W&B scoped by project. + + Args: + project (str): The project to scope the runs to + entity (str, optional): The entity to scope this project to. Defaults to public models + + Returns: + [{"id","name","description"}] + """ + query = gql( + """ + query ProjectRuns($model: String!, $entity: String) { + model(name: $model, entityName: $entity) { + buckets(first: 10) { + edges { + node { + id + name + displayName + description + } + } + } + } + } + """ + ) + return self._flatten_edges( + self.gql( + query, + variable_values={ + "entity": entity or self.settings("entity"), + "model": project or self.settings("project"), + }, + )["model"]["buckets"] + ) + + @normalize_exceptions + def run_config( + self, project: str, run: str | None = None, entity: str | None = None + ) -> tuple[str, dict[str, Any], str | None, dict[str, Any]]: + """Get the relevant configs for a run. 
+ + Args: + project (str): The project to download, (can include bucket) + run (str, optional): The run to download + entity (str, optional): The entity to scope this project to. + """ + import requests + + check_httpclient_logger_handler() + + query = gql( + """ + query RunConfigs( + $name: String!, + $entity: String, + $run: String!, + $pattern: String!, + $includeConfig: Boolean!, + ) { + model(name: $name, entityName: $entity) { + bucket(name: $run) { + config @include(if: $includeConfig) + commit @include(if: $includeConfig) + files(pattern: $pattern) { + pageInfo { + hasNextPage + endCursor + } + edges { + node { + name + directUrl + } + } + } + } + } + } + """ + ) + + variable_values = { + "name": project, + "run": run, + "entity": entity, + "includeConfig": True, + } + + commit: str = "" + config: dict[str, Any] = {} + patch: str | None = None + metadata: dict[str, Any] = {} + + # If we use the `names` parameter on the `files` node, then the server + # will helpfully give us and 'open' file handle to the files that don't + # exist. This is so that we can upload data to it. However, in this + # case, we just want to download that file and not upload to it, so + # let's instead query for the files that do exist using `pattern` + # (with no wildcards). + # + # Unfortunately we're unable to construct a single pattern that matches + # our 2 files, we would need something like regex for that. 
+ for filename in [DIFF_FNAME, METADATA_FNAME]: + variable_values["pattern"] = filename + response = self.gql(query, variable_values=variable_values) + if response["model"] is None: + raise CommError(f"Run {entity}/{project}/{run} not found") + run_obj: dict = response["model"]["bucket"] + # we only need to fetch this config once + if variable_values["includeConfig"]: + commit = run_obj["commit"] + config = json.loads(run_obj["config"] or "{}") + variable_values["includeConfig"] = False + if run_obj["files"] is not None: + for file_edge in run_obj["files"]["edges"]: + name = file_edge["node"]["name"] + url = file_edge["node"]["directUrl"] + res = requests.get(url) + res.raise_for_status() + if name == METADATA_FNAME: + metadata = res.json() + elif name == DIFF_FNAME: + patch = res.text + + return commit, config, patch, metadata + + @normalize_exceptions + def run_resume_status( + self, entity: str, project_name: str, name: str + ) -> dict[str, Any] | None: + """Check if a run exists and get resume information. + + Args: + entity (str): The entity to scope this project to. + project_name (str): The project to download, (can include bucket) + name (str): The run to download + """ + # Pulling wandbConfig.start_time is required so that we can determine if a run has actually started + query = gql( + """ + query RunResumeStatus($project: String, $entity: String, $name: String!) 
{ + model(name: $project, entityName: $entity) { + id + name + entity { + id + name + } + + bucket(name: $name, missingOk: true) { + id + name + summaryMetrics + displayName + logLineCount + historyLineCount + eventsLineCount + historyTail + eventsTail + config + tags + wandbConfig(keys: ["t"]) + } + } + } + """ + ) + + response = self.gql( + query, + variable_values={ + "entity": entity, + "project": project_name, + "name": name, + }, + ) + + if "model" not in response or "bucket" not in (response["model"] or {}): + return None + + project = response["model"] + self.set_setting("project", project_name) + if "entity" in project: + self.set_setting("entity", project["entity"]["name"]) + + result: dict[str, Any] = project["bucket"] + + return result + + @normalize_exceptions + def check_stop_requested( + self, project_name: str, entity_name: str, run_id: str + ) -> bool: + query = gql( + """ + query RunStoppedStatus($projectName: String, $entityName: String, $runId: String!) { + project(name:$projectName, entityName:$entityName) { + run(name:$runId) { + stopped + } + } + } + """ + ) + + response = self.gql( + query, + variable_values={ + "projectName": project_name, + "entityName": entity_name, + "runId": run_id, + }, + ) + + project = response.get("project", None) + if not project: + return False + run = project.get("run", None) + if not run: + return False + + status: bool = run["stopped"] + return status + + def format_project(self, project: str) -> str: + return re.sub(r"\W+", "-", project.lower()).strip("-_") + + @normalize_exceptions + def upsert_project( + self, + project: str, + id: str | None = None, + description: str | None = None, + entity: str | None = None, + ) -> dict[str, Any]: + """Create a new project. + + Args: + project (str): The project to create + description (str, optional): A description of this project + entity (str, optional): The entity to scope this project to. 
+ """ + mutation = gql( + """ + mutation UpsertModel($name: String!, $id: String, $entity: String!, $description: String, $repo: String) { + upsertModel(input: { id: $id, name: $name, entityName: $entity, description: $description, repo: $repo }) { + model { + name + description + } + } + } + """ + ) + response = self.gql( + mutation, + variable_values={ + "name": self.format_project(project), + "entity": entity or self.settings("entity"), + "description": description, + "id": id, + }, + ) + result: dict[str, Any] = response["upsertModel"]["model"] + return result + + @normalize_exceptions + def entity_is_team(self, entity: str) -> bool: + query = gql( + """ + query EntityIsTeam($entity: String!) { + entity(name: $entity) { + id + isTeam + } + } + """ + ) + variable_values = { + "entity": entity, + } + + res = self.gql(query, variable_values) + if res.get("entity") is None: + raise Exception( + f"Error fetching entity {entity} " + "check that you have access to this entity" + ) + + is_team: bool = res["entity"]["isTeam"] + return is_team + + @normalize_exceptions + def get_project_run_queues(self, entity: str, project: str) -> list[dict[str, str]]: + query = gql( + """ + query ProjectRunQueues($entity: String!, $projectName: String!){ + project(entityName: $entity, name: $projectName) { + runQueues { + id + name + createdBy + access + } + } + } + """ + ) + variable_values = { + "projectName": project, + "entity": entity, + } + + res = self.gql(query, variable_values) + if res.get("project") is None: + # circular dependency: (LAUNCH_DEFAULT_PROJECT = model-registry) + if project == "model-registry": + msg = ( + f"Error fetching run queues for {entity} " + "check that you have access to this entity and project" + ) + else: + msg = ( + f"Error fetching run queues for {entity}/{project} " + "check that you have access to this entity and project" + ) + + raise Exception(msg) + + project_run_queues: list[dict[str, str]] = res["project"]["runQueues"] + return 
project_run_queues + + @normalize_exceptions + def create_default_resource_config( + self, + entity: str, + resource: str, + config: str, + template_variables: dict[str, float | int | str] | None, + ) -> dict[str, Any] | None: + if not self.create_default_resource_config_introspection(): + raise Exception() + supports_template_vars, _ = self.push_to_run_queue_introspection() + + mutation_params = """ + $entityName: String!, + $resource: String!, + $config: JSONString! + """ + mutation_inputs = """ + entityName: $entityName, + resource: $resource, + config: $config + """ + + if supports_template_vars: + mutation_params += ", $templateVariables: JSONString" + mutation_inputs += ", templateVariables: $templateVariables" + else: + if template_variables is not None: + raise UnsupportedError( + "server does not support template variables, please update server instance to >=0.46" + ) + + variable_values = { + "entityName": entity, + "resource": resource, + "config": config, + } + if supports_template_vars: + if template_variables is not None: + variable_values["templateVariables"] = json.dumps(template_variables) + else: + variable_values["templateVariables"] = "{}" + + query = gql( + f""" + mutation createDefaultResourceConfig( + {mutation_params} + ) {{ + createDefaultResourceConfig( + input: {{ + {mutation_inputs} + }} + ) {{ + defaultResourceConfigID + success + }} + }} + """ + ) + + result: dict[str, Any] | None = self.gql(query, variable_values)[ + "createDefaultResourceConfig" + ] + return result + + @normalize_exceptions + def create_run_queue( + self, + entity: str, + project: str, + queue_name: str, + access: str, + prioritization_mode: str | None = None, + config_id: str | None = None, + ) -> dict[str, Any] | None: + ( + create_run_queue, + supports_drc, + supports_prioritization, + ) = self.create_run_queue_introspection() + if not create_run_queue: + raise UnsupportedError( + "run queue creation is not supported by this version of " + "wandb server. 
Consider updating to the latest version." + ) + if not supports_drc and config_id is not None: + raise UnsupportedError( + "default resource configurations are not supported by this version " + "of wandb server. Consider updating to the latest version." + ) + if not supports_prioritization and prioritization_mode is not None: + raise UnsupportedError( + "launch prioritization is not supported by this version of " + "wandb server. Consider updating to the latest version." + ) + + if supports_prioritization: + query = gql( + """ + mutation createRunQueue( + $entity: String!, + $project: String!, + $queueName: String!, + $access: RunQueueAccessType!, + $prioritizationMode: RunQueuePrioritizationMode, + $defaultResourceConfigID: ID, + ) { + createRunQueue( + input: { + entityName: $entity, + projectName: $project, + queueName: $queueName, + access: $access, + prioritizationMode: $prioritizationMode + defaultResourceConfigID: $defaultResourceConfigID + } + ) { + success + queueID + } + } + """ + ) + variable_values = { + "entity": entity, + "project": project, + "queueName": queue_name, + "access": access, + "prioritizationMode": prioritization_mode, + "defaultResourceConfigID": config_id, + } + else: + query = gql( + """ + mutation createRunQueue( + $entity: String!, + $project: String!, + $queueName: String!, + $access: RunQueueAccessType!, + $defaultResourceConfigID: ID, + ) { + createRunQueue( + input: { + entityName: $entity, + projectName: $project, + queueName: $queueName, + access: $access, + defaultResourceConfigID: $defaultResourceConfigID + } + ) { + success + queueID + } + } + """ + ) + variable_values = { + "entity": entity, + "project": project, + "queueName": queue_name, + "access": access, + "defaultResourceConfigID": config_id, + } + + result: dict[str, Any] | None = self.gql(query, variable_values)[ + "createRunQueue" + ] + return result + + @normalize_exceptions + def upsert_run_queue( + self, + queue_name: str, + entity: str, + resource_type: str, + 
resource_config: dict, + project: str = LAUNCH_DEFAULT_PROJECT, + prioritization_mode: str | None = None, + template_variables: dict | None = None, + external_links: dict | None = None, + ) -> dict[str, Any] | None: + if not self.upsert_run_queue_introspection(): + raise UnsupportedError( + "upserting run queues is not supported by this version of " + "wandb server. Consider updating to the latest version." + ) + query = gql( + """ + mutation upsertRunQueue( + $entityName: String! + $projectName: String! + $queueName: String! + $resourceType: String! + $resourceConfig: JSONString! + $templateVariables: JSONString + $prioritizationMode: RunQueuePrioritizationMode + $externalLinks: JSONString + $clientMutationId: String + ) { + upsertRunQueue( + input: { + entityName: $entityName + projectName: $projectName + queueName: $queueName + resourceType: $resourceType + resourceConfig: $resourceConfig + templateVariables: $templateVariables + prioritizationMode: $prioritizationMode + externalLinks: $externalLinks + clientMutationId: $clientMutationId + } + ) { + success + configSchemaValidationErrors + } + } + """ + ) + variable_values = { + "entityName": entity, + "projectName": project, + "queueName": queue_name, + "resourceType": resource_type, + "resourceConfig": json.dumps(resource_config), + "templateVariables": ( + json.dumps(template_variables) if template_variables else None + ), + "prioritizationMode": prioritization_mode, + "externalLinks": json.dumps(external_links) if external_links else None, + } + result: dict[str, Any] = self.gql(query, variable_values) + return result["upsertRunQueue"] + + @normalize_exceptions + def push_to_run_queue_by_name( + self, + entity: str, + project: str, + queue_name: str, + run_spec: str, + template_variables: dict[str, int | float | str] | None, + priority: int | None = None, + ) -> dict[str, Any] | None: + self.push_to_run_queue_introspection() + """Queryless mutation, should be used before legacy fallback method.""" + + 
mutation_params = """ + $entityName: String!, + $projectName: String!, + $queueName: String!, + $runSpec: JSONString! + """ + + mutation_input = """ + entityName: $entityName, + projectName: $projectName, + queueName: $queueName, + runSpec: $runSpec + """ + + variables: dict[str, Any] = { + "entityName": entity, + "projectName": project, + "queueName": queue_name, + "runSpec": run_spec, + } + if self.server_push_to_run_queue_supports_priority: + if priority is not None: + variables["priority"] = priority + mutation_params += ", $priority: Int" + mutation_input += ", priority: $priority" + else: + if priority is not None: + raise UnsupportedError( + "server does not support priority, please update server instance to >=0.46" + ) + + if self.server_supports_template_variables: + if template_variables is not None: + variables.update( + {"templateVariableValues": json.dumps(template_variables)} + ) + mutation_params += ", $templateVariableValues: JSONString" + mutation_input += ", templateVariableValues: $templateVariableValues" + else: + if template_variables is not None: + raise UnsupportedError( + "server does not support template variables, please update server instance to >=0.46" + ) + + mutation = gql( + f""" + mutation pushToRunQueueByName( + {mutation_params} + ) {{ + pushToRunQueueByName( + input: {{ + {mutation_input} + }} + ) {{ + runQueueItemId + runSpec + }} + }} + """ + ) + + try: + result: dict[str, Any] | None = self.gql( + mutation, variables, check_retry_fn=util.no_retry_4xx + ).get("pushToRunQueueByName") + if not result: + return None + + if result.get("runSpec"): + run_spec = json.loads(str(result["runSpec"])) + result["runSpec"] = run_spec + + return result + except Exception as e: + if ( + 'Cannot query field "runSpec" on type "PushToRunQueueByNamePayload"' + not in str(e) + ): + return None + + mutation_no_runspec = gql( + """ + mutation pushToRunQueueByName( + $entityName: String!, + $projectName: String!, + $queueName: String!, + $runSpec: 
JSONString!, + ) { + pushToRunQueueByName( + input: { + entityName: $entityName, + projectName: $projectName, + queueName: $queueName, + runSpec: $runSpec + } + ) { + runQueueItemId + } + } + """ + ) + + try: + result = self.gql( + mutation_no_runspec, variables, check_retry_fn=util.no_retry_4xx + ).get("pushToRunQueueByName") + except Exception: + result = None + + return result + + @normalize_exceptions + def push_to_run_queue( + self, + queue_name: str, + launch_spec: dict[str, str], + template_variables: dict | None, + project_queue: str, + priority: int | None = None, + ) -> dict[str, Any] | None: + self.push_to_run_queue_introspection() + entity = launch_spec.get("queue_entity") or launch_spec["entity"] + run_spec = json.dumps(launch_spec) + + push_result = self.push_to_run_queue_by_name( + entity, project_queue, queue_name, run_spec, template_variables, priority + ) + + if push_result: + return push_result + + if priority is not None: + # Cannot proceed with legacy method if priority is set + return None + + """ Legacy Method """ + queues_found = self.get_project_run_queues(entity, project_queue) + matching_queues = [ + q + for q in queues_found + if q["name"] == queue_name + # ensure user has access to queue + and ( + # TODO: User created queues in the UI have USER access + q["access"] in ["PROJECT", "USER"] + or q["createdBy"] == self.default_entity + ) + ] + if not matching_queues: + # in the case of a missing default queue. create it + if queue_name == "default": + wandb.termlog( + f"No default queue existing for entity: {entity} in project: {project_queue}, creating one." + ) + res = self.create_run_queue( + launch_spec["entity"], + project_queue, + queue_name, + access="PROJECT", + ) + + if res is None or res.get("queueID") is None: + wandb.termerror( + f"Unable to create default queue for entity: {entity} on project: {project_queue}. 
Run could not be added to a queue" + ) + return None + queue_id = res["queueID"] + + else: + if project_queue == "model-registry": + _msg = f"Unable to push to run queue {queue_name}. Queue not found." + else: + _msg = f"Unable to push to run queue {project_queue}/{queue_name}. Queue not found." + wandb.termwarn(_msg) + return None + elif len(matching_queues) > 1: + wandb.termerror( + f"Unable to push to run queue {queue_name}. More than one queue found with this name." + ) + return None + else: + queue_id = matching_queues[0]["id"] + spec_json = json.dumps(launch_spec) + variables = {"queueID": queue_id, "runSpec": spec_json} + + mutation_params = """ + $queueID: ID!, + $runSpec: JSONString! + """ + mutation_input = """ + queueID: $queueID, + runSpec: $runSpec + """ + if self.server_supports_template_variables: + if template_variables is not None: + mutation_params += ", $templateVariableValues: JSONString" + mutation_input += ", templateVariableValues: $templateVariableValues" + variables.update( + {"templateVariableValues": json.dumps(template_variables)} + ) + else: + if template_variables is not None: + raise UnsupportedError( + "server does not support template variables, please update server instance to >=0.46" + ) + + mutation = gql( + f""" + mutation pushToRunQueue( + {mutation_params} + ) {{ + pushToRunQueue( + input: {{{mutation_input}}} + ) {{ + runQueueItemId + }} + }} + """ + ) + + response = self.gql(mutation, variable_values=variables) + if not response.get("pushToRunQueue"): + raise CommError(f"Error pushing run queue item to queue {queue_name}.") + + result: dict[str, Any] | None = response["pushToRunQueue"] + return result + + @normalize_exceptions + def pop_from_run_queue( + self, + queue_name: str, + entity: str | None = None, + project: str | None = None, + agent_id: str | None = None, + ) -> dict[str, Any] | None: + mutation = gql( + """ + mutation popFromRunQueue($entity: String!, $project: String!, $queueName: String!, $launchAgentId: ID) { 
+ popFromRunQueue(input: { + entityName: $entity, + projectName: $project, + queueName: $queueName, + launchAgentId: $launchAgentId + }) { + runQueueItemId + runSpec + } + } + """ + ) + response = self.gql( + mutation, + variable_values={ + "entity": entity, + "project": project, + "queueName": queue_name, + "launchAgentId": agent_id, + }, + ) + result: dict[str, Any] | None = response["popFromRunQueue"] + return result + + @normalize_exceptions + def ack_run_queue_item(self, item_id: str, run_id: str | None = None) -> bool: + mutation = gql( + """ + mutation ackRunQueueItem($itemId: ID!, $runId: String!) { + ackRunQueueItem(input: { runQueueItemId: $itemId, runName: $runId }) { + success + } + } + """ + ) + response = self.gql( + mutation, variable_values={"itemId": item_id, "runId": str(run_id)} + ) + if not response["ackRunQueueItem"]["success"]: + raise CommError( + "Error acking run queue item. Item may have already been acknowledged by another process" + ) + result: bool = response["ackRunQueueItem"]["success"] + return result + + @normalize_exceptions + def create_launch_agent_fields_introspection(self) -> list: + if self.create_launch_agent_input_info: + return self.create_launch_agent_input_info + query_string = """ + query ProbeServerCreateLaunchAgentInput { + CreateLaunchAgentInputInfoType: __type(name:"CreateLaunchAgentInput") { + inputFields{ + name + } + } + } + """ + + query = gql(query_string) + res = self.gql(query) + + self.create_launch_agent_input_info = [ + field.get("name", "") + for field in res.get("CreateLaunchAgentInputInfoType", {}).get( + "inputFields", [{}] + ) + ] + return self.create_launch_agent_input_info + + @normalize_exceptions + def create_launch_agent( + self, + entity: str, + project: str, + queues: list[str], + agent_config: dict[str, Any], + version: str, + gorilla_agent_support: bool, + ) -> dict: + project_queues = self.get_project_run_queues(entity, project) + if not project_queues: + # create default queue if it doesn't 
already exist + default = self.create_run_queue( + entity, project, "default", access="PROJECT" + ) + if default is None or default.get("queueID") is None: + raise CommError( + f"Unable to create default queue for {entity}/{project}. No queues for agent to poll" + ) + project_queues = [{"id": default["queueID"], "name": "default"}] + polling_queue_ids = [ + q["id"] for q in project_queues if q["name"] in queues + ] # filter to poll specified queues + if len(polling_queue_ids) != len(queues): + raise CommError( + f"Could not start launch agent: Not all of requested queues ({', '.join(queues)}) found. " + f"Available queues for this project: {','.join([q['name'] for q in project_queues])}" + ) + + if not gorilla_agent_support: + # if gorilla doesn't support launch agents, return a client-generated id + return { + "success": True, + "launchAgentId": None, + } + + hostname = socket.gethostname() + + variable_values = { + "entity": entity, + "project": project, + "queues": polling_queue_ids, + "hostname": hostname, + } + + mutation_params = """ + $entity: String!, + $project: String!, + $queues: [ID!]!, + $hostname: String! 
+ """ + + mutation_input = """ + entityName: $entity, + projectName: $project, + runQueues: $queues, + hostname: $hostname + """ + + if "agentConfig" in self.create_launch_agent_fields_introspection(): + variable_values["agentConfig"] = json.dumps(agent_config) + mutation_params += ", $agentConfig: JSONString" + mutation_input += ", agentConfig: $agentConfig" + if "version" in self.create_launch_agent_fields_introspection(): + variable_values["version"] = version + mutation_params += ", $version: String" + mutation_input += ", version: $version" + + mutation = gql( + f""" + mutation createLaunchAgent( + {mutation_params} + ) {{ + createLaunchAgent( + input: {{ + {mutation_input} + }} + ) {{ + launchAgentId + }} + }} + """ + ) + result: dict = self.gql(mutation, variable_values)["createLaunchAgent"] + return result + + @normalize_exceptions + def update_launch_agent_status( + self, + agent_id: str, + status: str, + gorilla_agent_support: bool, + ) -> dict: + if not gorilla_agent_support: + # if gorilla doesn't support launch agents, this is a no-op + return { + "success": True, + } + + mutation = gql( + """ + mutation updateLaunchAgent($agentId: ID!, $agentStatus: String){ + updateLaunchAgent( + input: { + launchAgentId: $agentId + agentStatus: $agentStatus + } + ) { + success + } + } + """ + ) + variable_values = { + "agentId": agent_id, + "agentStatus": status, + } + result: dict = self.gql(mutation, variable_values)["updateLaunchAgent"] + return result + + @normalize_exceptions + def get_launch_agent(self, agent_id: str, gorilla_agent_support: bool) -> dict: + if not gorilla_agent_support: + return { + "id": None, + "name": "", + "stopPolling": False, + } + query = gql( + """ + query LaunchAgent($agentId: ID!) 
{ + launchAgent(id: $agentId) { + id + name + runQueues + hostname + agentStatus + stopPolling + heartbeatAt + } + } + """ + ) + variable_values = { + "agentId": agent_id, + } + result: dict = self.gql(query, variable_values)["launchAgent"] + return result + + @normalize_exceptions + def upsert_run( + self, + id: str | None = None, + name: str | None = None, + project: str | None = None, + host: str | None = None, + group: str | None = None, + tags: list[str] | None = None, + config: dict | None = None, + description: str | None = None, + entity: str | None = None, + state: str | None = None, + display_name: str | None = None, + notes: str | None = None, + repo: str | None = None, + job_type: str | None = None, + program_path: str | None = None, + commit: str | None = None, + sweep_name: str | None = None, + summary_metrics: str | None = None, + num_retries: int | None = None, + ) -> tuple[dict, bool, list | None]: + """Update a run. + + Args: + id (str, optional): The existing run to update + name (str, optional): The name of the run to create + group (str, optional): Name of the group this run is a part of + project (str, optional): The name of the project + host (str, optional): The name of the host + tags (list, optional): A list of tags to apply to the run + config (dict, optional): The latest config params + description (str, optional): A description of this project + entity (str, optional): The entity to scope this project to. + display_name (str, optional): The display name of this project + notes (str, optional): Notes about this run + repo (str, optional): Url of the program's repository. + state (str, optional): State of the program. + job_type (str, optional): Type of job, e.g 'train'. + program_path (str, optional): Path to the program. 
+ commit (str, optional): The Git SHA to associate the run with + sweep_name (str, optional): The name of the sweep this run is a part of + summary_metrics (str, optional): The JSON summary metrics + num_retries (int, optional): Number of retries + """ + query_string = """ + mutation UpsertBucket( + $id: String, + $name: String, + $project: String, + $entity: String, + $groupName: String, + $description: String, + $displayName: String, + $notes: String, + $commit: String, + $config: JSONString, + $host: String, + $debug: Boolean, + $program: String, + $repo: String, + $jobType: String, + $state: String, + $sweep: String, + $tags: [String!], + $summaryMetrics: JSONString, + ) { + upsertBucket(input: { + id: $id, + name: $name, + groupName: $groupName, + modelName: $project, + entityName: $entity, + description: $description, + displayName: $displayName, + notes: $notes, + config: $config, + commit: $commit, + host: $host, + debug: $debug, + jobProgram: $program, + jobRepo: $repo, + jobType: $jobType, + state: $state, + sweep: $sweep, + tags: $tags, + summaryMetrics: $summaryMetrics, + }) { + bucket { + id + name + displayName + description + config + sweepName + project { + id + name + entity { + id + name + } + } + historyLineCount + } + inserted + _Server_Settings_ + } + } + """ + self.server_settings_introspection() + + server_settings_string = ( + """ + serverSettings { + serverMessages{ + utfText + plainText + htmlText + messageType + messageLevel + } + } + """ + if self._server_settings_type + else "" + ) + + query_string = query_string.replace("_Server_Settings_", server_settings_string) + mutation = gql(query_string) + config_str = json.dumps(config) if config else None + if not description or description.isspace(): + description = None + + kwargs = {} + if num_retries is not None: + kwargs["num_retries"] = num_retries + + variable_values = { + "id": id, + "entity": entity or self.settings("entity"), + "name": name, + "project": project or 
util.auto_project_name(program_path), + "groupName": group, + "tags": tags, + "description": description, + "config": config_str, + "commit": commit, + "displayName": display_name, + "notes": notes, + "host": None + if self.settings().get("anonymous") in ["allow", "must"] + else host, + "debug": env.is_debug(env=self._environ), + "repo": repo, + "program": program_path, + "jobType": job_type, + "state": state, + "sweep": sweep_name, + "summaryMetrics": summary_metrics, + } + + # retry conflict errors for 2 minutes, default to no_auth_retry + check_retry_fn = util.make_check_retry_fn( + check_fn=util.check_retry_conflict_or_gone, + check_timedelta=datetime.timedelta(minutes=2), + fallback_retry_fn=util.no_retry_auth, + ) + + response = self.gql( + mutation, + variable_values=variable_values, + check_retry_fn=check_retry_fn, + **kwargs, + ) + + run_obj: dict[str, dict[str, dict[str, str]]] = response["upsertBucket"][ + "bucket" + ] + project_obj: dict[str, dict[str, str]] = run_obj.get("project", {}) + if project_obj: + self.set_setting("project", project_obj["name"]) + entity_obj = project_obj.get("entity", {}) + if entity_obj: + self.set_setting("entity", entity_obj["name"]) + + server_messages = None + if self._server_settings_type: + server_messages = ( + response["upsertBucket"] + .get("serverSettings", {}) + .get("serverMessages", []) + ) + + return ( + response["upsertBucket"]["bucket"], + response["upsertBucket"]["inserted"], + server_messages, + ) + + @normalize_exceptions + def rewind_run( + self, + run_name: str, + metric_name: str, + metric_value: float, + program_path: str | None = None, + entity: str | None = None, + project: str | None = None, + num_retries: int | None = None, + ) -> dict: + """Rewinds a run to a previous state. 
+ + Args: + run_name (str): The name of the run to rewind + metric_name (str): The name of the metric to rewind to + metric_value (float): The value of the metric to rewind to + program_path (str, optional): Path to the program + entity (str, optional): The entity to scope this project to + project (str, optional): The name of the project + num_retries (int, optional): Number of retries + + Returns: + A dict with the rewound run + + { + "id": "run_id", + "name": "run_name", + "displayName": "run_display_name", + "description": "run_description", + "config": "stringified_run_config_json", + "sweepName": "run_sweep_name", + "project": { + "id": "project_id", + "name": "project_name", + "entity": { + "id": "entity_id", + "name": "entity_name" + } + }, + "historyLineCount": 100, + } + """ + query_string = """ + mutation RewindRun($runName: String!, $entity: String, $project: String, $metricName: String!, $metricValue: Float!) { + rewindRun(input: {runName: $runName, entityName: $entity, projectName: $project, metricName: $metricName, metricValue: $metricValue}) { + rewoundRun { + id + name + displayName + description + config + sweepName + project { + id + name + entity { + id + name + } + } + historyLineCount + } + } + } + """ + + mutation = gql(query_string) + + kwargs = {} + if num_retries is not None: + kwargs["num_retries"] = num_retries + + variable_values = { + "runName": run_name, + "entity": entity or self.settings("entity"), + "project": project or util.auto_project_name(program_path), + "metricName": metric_name, + "metricValue": metric_value, + } + + # retry conflict errors for 2 minutes, default to no_auth_retry + check_retry_fn = util.make_check_retry_fn( + check_fn=util.check_retry_conflict_or_gone, + check_timedelta=datetime.timedelta(minutes=2), + fallback_retry_fn=util.no_retry_auth, + ) + + response = self.gql( + mutation, + variable_values=variable_values, + check_retry_fn=check_retry_fn, + **kwargs, + ) + + run_obj: dict[str, dict[str, dict[str, 
@normalize_exceptions
def get_run_info(
    self,
    entity: str,
    project: str,
    name: str,
) -> dict:
    """Fetch the ``runInfo`` record of a single run.

    Args:
        entity: Entity that owns the project.
        project: Project containing the run.
        name: Name of the run to look up.

    Returns:
        The ``runInfo`` object returned by the backend (program, args, os,
        python, colab, executable, codeSaved, cpuCount, gpuCount, gpu, git).

    Raises:
        CommError: If the response contains no project, or the project
            contains no run with that name.
    """
    run_info_query = gql(
        """
        query RunInfo($project: String!, $entity: String!, $name: String!) {
            project(name: $project, entityName: $entity) {
                run(name: $name) {
                    runInfo {
                        program
                        args
                        os
                        python
                        colab
                        executable
                        codeSaved
                        cpuCount
                        gpuCount
                        gpu
                        git {
                            remote
                            commit
                        }
                    }
                }
            }
        }
        """
    )
    res = self.gql(
        run_info_query,
        {"project": project, "entity": entity, "name": name},
    )

    project_data = res.get("project")
    if project_data is None:
        raise CommError(
            f"Error fetching run info for {entity}/{project}/{name}. Check that this project exists and you have access to this entity and project"
        )

    run_data = project_data.get("run")
    if run_data is None:
        raise CommError(
            f"Error fetching run info for {entity}/{project}/{name}. Check that this run id exists"
        )

    run_info: dict = run_data["runInfo"]
    return run_info
@normalize_exceptions
def create_run_files_introspection(self) -> bool:
    """Return whether the server exposes the ``createRunFiles`` mutation.

    Only the third element of the ``server_info_introspection()`` result is
    used; the first two are ignored.
    """
    _first, _second, mutation_names = self.server_info_introspection()
    has_create_run_files = "createRunFiles" in mutation_names
    return has_create_run_files
def legacy_upload_urls(
    self,
    project: str,
    files: list[str] | dict[str, IO],
    run: str | None = None,
    entity: str | None = None,
    description: str | None = None,
) -> tuple[str, list[str], dict[str, dict[str, Any]]]:
    """Generate temporary resumable upload urls.

    A new mutation createRunFiles was introduced after 0.15.4.
    This function is used to support older versions.

    Args:
        project: Name of the project (sent as the ``model`` name).
        files: File names to request upload URLs for; for a dict, its keys.
        run: Run to upload to. Falls back to ``self.current_run_id``.
        entity: Entity owning the project. Falls back to the ``entity``
            setting.
        description: Passed through as the run's ``desc`` argument.

    Returns:
        A ``(run_id, upload_headers, file_info)`` tuple, where ``file_info``
        maps each file name to its node dict carrying an ``uploadUrl`` key.

    Raises:
        CommError: If the response has no project (``model``) or no run
            (``bucket``) for the requested names.
    """
    query = gql(
        """
        query RunUploadUrls($name: String!, $files: [String]!, $entity: String, $run: String!, $description: String) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run, desc: $description) {
                    id
                    files(names: $files) {
                        uploadHeaders
                        edges {
                            node {
                                name
                                url(upload: true)
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        """
    )
    run_id = run or self.current_run_id
    assert run_id, "run must be specified"
    entity = entity or self.settings("entity")
    query_result = self.gql(
        query,
        variable_values={
            "name": project,
            "run": run_id,
            "entity": entity,
            "files": list(files),
            "description": description,
        },
    )

    # A null `model` used to surface as a TypeError on the subscript below;
    # report it the same way as a missing run instead.
    model_obj = query_result["model"] or {}
    run_obj = model_obj.get("bucket")
    if not run_obj:
        raise CommError(f"Run does not exist {entity}/{project}/{run_id}.")

    for file_node in run_obj["files"]["edges"]:
        file = file_node["node"]
        # we previously used "url" field but now use "uploadUrl"
        # replace the "url" field with "uploadUrl" for downstream compatibility
        if "url" in file and "uploadUrl" not in file:
            file["uploadUrl"] = file.pop("url")

    result = {file["name"]: file for file in self._flatten_edges(run_obj["files"])}
    return run_obj["id"], run_obj["files"]["uploadHeaders"], result
Defaults to wandb models + + Returns: + A dict of extensions and urls + + { + 'weights.h5': { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }, + 'model.json': { "url": "https://model.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' } + } + """ + query = gql( + """ + query RunDownloadUrls($name: String!, $entity: String, $run: String!) { + model(name: $name, entityName: $entity) { + bucket(name: $run) { + files { + edges { + node { + name + url + md5 + updatedAt + } + } + } + } + } + } + """ + ) + run = run or self.current_run_id + assert run, "run must be specified" + entity = entity or self.settings("entity") + query_result = self.gql( + query, + variable_values={ + "name": project, + "run": run, + "entity": entity, + }, + ) + if query_result["model"] is None: + raise CommError(f"Run does not exist {entity}/{project}/{run}.") + files = self._flatten_edges(query_result["model"]["bucket"]["files"]) + return {file["name"]: file for file in files if file} + + @normalize_exceptions + def download_url( + self, + project: str, + file_name: str, + run: str | None = None, + entity: str | None = None, + ) -> dict[str, str] | None: + """Generate download urls. + + Args: + project (str): The project to download + file_name (str): The name of the file to download + run (str): The run to upload to + entity (str, optional): The entity to scope this project to. Defaults to wandb models + + Returns: + A dict of extensions and urls + + { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' } + + """ + query = gql( + """ + query RunDownloadUrl($name: String!, $fileName: String!, $entity: String, $run: String!) 
@normalize_exceptions
def download_write_file(
    self,
    metadata: dict[str, str],
    out_dir: str | None = None,
) -> tuple[str, requests.Response | None]:
    """Download a file from a run and write it to wandb/.

    Args:
        metadata: The metadata object for the file to download; its
            ``name``, ``md5`` and ``url`` keys are read. Comes from
            Api.download_urls().
        out_dir: The directory to write the file to. Defaults to the
            ``wandb_dir`` setting.

    Returns:
        A tuple of the file's local path and the streaming response. The
        streaming response is None if the file already existed and was
        up-to-date.
    """
    filename = metadata["name"]
    path = os.path.join(out_dir or self.settings("wandb_dir"), filename)

    # Compare the checksum of the destination file itself. Checking the bare
    # `filename` would look in the current working directory, not where the
    # download is written.
    if self.file_current(path, B64MD5(metadata["md5"])):
        return path, None

    # The content length returned alongside the response is not needed here.
    _, response = self.download_file(metadata["url"])

    with util.fsync_open(path, "wb") as file:
        for data in response.iter_content(chunk_size=1024):
            file.write(data)

    return path, response
+ + Args: + url: The url to download + upload_chunk: The path to the file you want to upload + extra_headers: A dictionary of extra headers to send with the request + + Returns: + The `requests` library response object + """ + import requests + + check_httpclient_logger_handler() + try: + if env.is_debug(env=self._environ): + logger.debug("upload_file: %s", url) + response = self._upload_file_session.put( + url, data=upload_chunk, headers=extra_headers + ) + if env.is_debug(env=self._environ): + logger.debug("upload_file: %s complete", url) + response.raise_for_status() + except requests.exceptions.RequestException as e: + logger.exception(f"upload_file exception for {url=}") + response_content = e.response.content if e.response is not None else "" + status_code = e.response.status_code if e.response is not None else 0 + # S3 reports retryable request timeouts out-of-band + is_aws_retryable = status_code == 400 and "RequestTimeout" in str( + response_content + ) + # Retry errors from cloud storage or local network issues + if ( + status_code in (308, 408, 409, 429, 500, 502, 503, 504) + or isinstance( + e, + (requests.exceptions.Timeout, requests.exceptions.ConnectionError), + ) + or is_aws_retryable + ): + _e = retry.TransientError(exc=e) + raise _e.with_traceback(sys.exc_info()[2]) + else: + get_sentry().reraise(e) + return response + + def upload_file( + self, + url: str, + file: IO[bytes], + callback: ProgressFn | None = None, + extra_headers: dict[str, str] | None = None, + ) -> requests.Response | None: + """Upload a file to W&B with failure resumption. 
+ + Args: + url: The url to download + file: The path to the file you want to upload + callback: A callback which is passed the number of + bytes uploaded since the last time it was called, used to report progress + extra_headers: A dictionary of extra headers to send with the request + + Returns: + The `requests` library response object + """ + import requests + + check_httpclient_logger_handler() + extra_headers = extra_headers.copy() if extra_headers else {} + response: requests.Response | None = None + progress = Progress(file, callback=callback) + try: + if "x-ms-blob-type" in extra_headers and self._azure_blob_module: + self.upload_file_azure(url, progress, extra_headers) + else: + if "x-ms-blob-type" in extra_headers: + wandb.termwarn( + "Azure uploads over 256MB require the azure SDK, install with pip install wandb[azure]", + repeat=False, + ) + if env.is_debug(env=self._environ): + logger.debug("upload_file: %s", url) + response = self._upload_file_session.put( + url, data=progress, headers=extra_headers + ) + if env.is_debug(env=self._environ): + logger.debug("upload_file: %s complete", url) + response.raise_for_status() + except requests.exceptions.RequestException as e: + logger.exception(f"upload_file exception for {url=}") + response_content = e.response.content if e.response is not None else "" + status_code = e.response.status_code if e.response is not None else 0 + # S3 reports retryable request timeouts out-of-band + is_aws_retryable = ( + "x-amz-meta-md5" in extra_headers + and status_code == 400 + and "RequestTimeout" in str(response_content) + ) + # We need to rewind the file for the next retry (the file passed in is `seek`'ed to 0) + progress.rewind() + # Retry errors from cloud storage or local network issues + if ( + status_code in (308, 408, 409, 429, 500, 502, 503, 504) + or isinstance( + e, + (requests.exceptions.Timeout, requests.exceptions.ConnectionError), + ) + or is_aws_retryable + ): + _e = retry.TransientError(exc=e) + raise 
@normalize_exceptions
def register_agent(
    self,
    host: str,
    sweep_id: str | None = None,
    project_name: str | None = None,
    entity: str | None = None,
) -> dict:
    """Create a sweep agent on the backend.

    Args:
        host: Hostname reported for the agent.
        sweep_id: ID of the sweep the agent belongs to.
        project_name: Project containing the sweep. Defaults to the
            ``project`` setting.
        entity: Entity containing the sweep. Defaults to the ``entity``
            setting.

    Returns:
        The ``agent`` object of the ``createAgent`` response.
    """
    create_agent_mutation = gql(
        """
        mutation CreateAgent(
            $host: String!
            $projectName: String,
            $entityName: String,
            $sweep: String!
        ) {
            createAgent(input: {
                host: $host,
                projectName: $projectName,
                entityName: $entityName,
                sweep: $sweep,
            }) {
                agent {
                    id
                }
            }
        }
        """
    )
    # Only an explicit None falls back to settings; other falsy values are sent as given.
    variables = {
        "host": host,
        "entityName": self.settings("entity") if entity is None else entity,
        "projectName": (
            self.settings("project") if project_name is None else project_name
        ),
        "sweep": sweep_id,
    }

    response = self.gql(
        create_agent_mutation,
        variable_values=variables,
        check_retry_fn=util.no_retry_4xx,
    )
    agent: dict = response["createAgent"]["agent"]
    return agent
+ """ + mutation = gql( + """ + mutation Heartbeat( + $id: ID!, + $metrics: JSONString, + $runState: JSONString + ) { + agentHeartbeat(input: { + id: $id, + metrics: $metrics, + runState: $runState + }) { + agent { + id + } + commands + } + } + """ + ) + + if agent_id is None: + raise ValueError("Cannot call heartbeat with an unregistered agent.") + + try: + response = self.gql( + mutation, + variable_values={ + "id": agent_id, + "metrics": json.dumps(metrics), + "runState": json.dumps(run_states), + }, + timeout=60, + ) + except Exception: + logger.exception("Error communicating with W&B.") + return [] + else: + result: list[dict[str, Any]] = json.loads( + response["agentHeartbeat"]["commands"] + ) + return result + + @staticmethod + def _validate_config_and_fill_distribution(config: dict) -> dict: + # verify that parameters are well specified. + # TODO(dag): deprecate this in favor of jsonschema validation once + # apiVersion 2 is released and local controller is integrated with + # wandb/client. 
+ + # avoid modifying the original config dict in + # case it is reused outside the calling func + config = deepcopy(config) + + # explicitly cast to dict in case config was passed as a sweepconfig + # sweepconfig does not serialize cleanly to yaml and breaks graphql, + # but it is a subclass of dict, so this conversion is clean + config = dict(config) + + if "parameters" not in config: + # still shows an anaconda warning, but doesn't error + return config + + for parameter_name in config["parameters"]: + parameter = config["parameters"][parameter_name] + if "min" in parameter and "max" in parameter: + if "distribution" not in parameter: + if isinstance(parameter["min"], int) and isinstance( + parameter["max"], int + ): + parameter["distribution"] = "int_uniform" + elif isinstance(parameter["min"], float) and isinstance( + parameter["max"], float + ): + parameter["distribution"] = "uniform" + else: + raise ValueError( + f"Parameter {parameter_name} is ambiguous, please specify bounds as both floats (for a float_" + "uniform distribution) or ints (for an int_uniform distribution)." + ) + return config + + @normalize_exceptions + def upsert_sweep( + self, + config: dict, + controller: str | None = None, + launch_scheduler: str | None = None, + scheduler: str | None = None, + obj_id: str | None = None, + project: str | None = None, + entity: str | None = None, + state: str | None = None, + prior_runs: list[str] | None = None, + display_name: str | None = None, + template_variable_values: dict[str, Any] | None = None, + ) -> tuple[str, list[str]]: + """Upsert a sweep object. 
+ + Args: + config (dict): sweep config (will be converted to yaml) + controller (str): controller to use + launch_scheduler (str): launch scheduler to use + scheduler (str): scheduler to use + obj_id (str): object id + project (str): project to use + entity (str): entity to use + state (str): state + prior_runs (list): IDs of existing runs to add to the sweep + display_name (str): display name for the sweep + template_variable_values (dict): template variable values + """ + import yaml + + project_query = """ + project { + id + name + entity { + id + name + } + } + """ + mutation_str = """ + mutation UpsertSweep( + $id: ID, + $config: String, + $description: String, + $entityName: String, + $projectName: String, + $controller: JSONString, + $scheduler: JSONString, + $state: String, + $priorRunsFilters: JSONString, + $displayName: String, + ) { + upsertSweep(input: { + id: $id, + config: $config, + description: $description, + entityName: $entityName, + projectName: $projectName, + controller: $controller, + scheduler: $scheduler, + state: $state, + priorRunsFilters: $priorRunsFilters, + displayName: $displayName, + }) { + sweep { + name + _PROJECT_QUERY_ + } + configValidationWarnings + } + } + """ + # TODO(jhr): we need protocol versioning to know schema is not supported + # for now we will just try both new and old query + mutation_5 = gql( + mutation_str.replace( + "$controller: JSONString,", + "$controller: JSONString,$launchScheduler: JSONString, $templateVariableValues: JSONString,", + ) + .replace( + "controller: $controller,", + "controller: $controller,launchScheduler: $launchScheduler,templateVariableValues: $templateVariableValues,", + ) + .replace("_PROJECT_QUERY_", project_query) + ) + # launchScheduler was introduced in core v0.14.0 + mutation_4 = gql( + mutation_str.replace( + "$controller: JSONString,", + "$controller: JSONString,$launchScheduler: JSONString,", + ) + .replace( + "controller: $controller,", + "controller: 
$controller,launchScheduler: $launchScheduler", + ) + .replace("_PROJECT_QUERY_", project_query) + ) + + # mutation 3 maps to backend that can support CLI version of at least 0.10.31 + mutation_3 = gql(mutation_str.replace("_PROJECT_QUERY_", project_query)) + mutation_2 = gql( + mutation_str.replace("_PROJECT_QUERY_", project_query).replace( + "configValidationWarnings", "" + ) + ) + mutation_1 = gql( + mutation_str.replace("_PROJECT_QUERY_", "").replace( + "configValidationWarnings", "" + ) + ) + + # TODO(dag): replace this with a query for protocol versioning + mutations = [mutation_5, mutation_4, mutation_3, mutation_2, mutation_1] + + config = self._validate_config_and_fill_distribution(config) + + # Silly, but attr-dicts like Easydicts don't serialize correctly to yaml. + # This sanitizes them with a round trip pass through json to get a regular dict. + class NonOctalStringDumper(yaml.Dumper): + """Prevents strings containing non-octal values like "008" and "009" from being converted to numbers in in the yaml string saved as the sweep config.""" + + def represent_scalar(self, tag, value, style=None): + if ( + tag == "tag:yaml.org,2002:str" + and value.startswith("0") + and len(value) > 1 + ): + return super().represent_scalar(tag, value, style="'") + return super().represent_scalar(tag, value, style) + + config_str = yaml.dump( + json.loads(json.dumps(config)), Dumper=NonOctalStringDumper + ) + filters = None + if prior_runs: + filters = json.dumps({"$or": [{"name": r} for r in prior_runs]}) + + err: Exception | None = None + for mutation in mutations: + try: + variables = { + "id": obj_id, + "config": config_str, + "description": config.get("description"), + "entityName": entity or self.settings("entity"), + "projectName": project or self.settings("project"), + "controller": controller, + "launchScheduler": launch_scheduler, + "templateVariableValues": json.dumps(template_variable_values), + "scheduler": scheduler, + "priorRunsFilters": filters, + 
"displayName": display_name, + } + if state: + variables["state"] = state + + response = self.gql( + mutation, + variable_values=variables, + check_retry_fn=util.no_retry_4xx, + ) + except UsageError: + raise + except Exception as e: + # graphql schema exception is generic + err = e + continue + err = None + break + if err: + raise err + + sweep: dict[str, dict[str, dict]] = response["upsertSweep"]["sweep"] + project_obj: dict[str, dict] = sweep.get("project", {}) + if project_obj: + self.set_setting("project", project_obj["name"]) + entity_obj: dict = project_obj.get("entity", {}) + if entity_obj: + self.set_setting("entity", entity_obj["name"]) + + warnings = response["upsertSweep"].get("configValidationWarnings", []) + return response["upsertSweep"]["sweep"]["name"], warnings + + @staticmethod + def file_current(fname: str, md5: B64MD5) -> bool: + """Checksum a file and compare the md5 with the known md5.""" + return os.path.isfile(fname) and md5_file_b64(fname) == md5 + + @normalize_exceptions + def pull( + self, project: str, run: str | None = None, entity: str | None = None + ) -> list[requests.Response]: + """Download files from W&B. + + Args: + project (str): The project to download + run (str, optional): The run to upload to + entity (str, optional): The entity to scope this project to. 
Defaults to wandb models + + Returns: + The `requests` library response object + """ + project, run = self.parse_slug(project, run=run) + urls = self.download_urls(project, run, entity) + responses = [] + for filename in urls: + _, response = self.download_write_file(urls[filename]) + if response: + responses.append(response) + + return responses + + def get_project(self) -> str: + project: str = self.default_settings.get("project") or self.settings("project") + return project + + @normalize_exceptions + def push( + self, + files: list[str] | dict[str, IO], + run: str | None = None, + entity: str | None = None, + project: str | None = None, + description: str | None = None, + force: bool = True, + progress: TextIO | Literal[False] = False, + ) -> list[requests.Response | None]: + """Uploads multiple files to W&B. + + Args: + files (list or dict): The filenames to upload, when dict the values are open files + run (str, optional): The run to upload to + entity (str, optional): The entity to scope this project to. Defaults to wandb models + project (str, optional): The name of the project to upload to. Defaults to the one in settings. + description (str, optional): The description of the changes + force (bool, optional): Whether to prevent push if git has uncommitted changes + progress (callable, or stream): If callable, will be called with (chunk_bytes, + total_bytes) as argument. If TextIO, renders a progress bar to it. + + Returns: + A list of `requests.Response` objects + """ + if project is None: + project = self.get_project() + if project is None: + raise CommError("No project configured.") + if run is None: + run = self.current_run_id + + # TODO(adrian): we use a retriable version of self.upload_file() so + # will never retry self.upload_urls() here. Instead, maybe we should + # make push itself retriable. 
+ _, upload_headers, result = self.upload_urls( + project, + files, + run, + entity, + ) + extra_headers = {} + for upload_header in upload_headers: + key, val = upload_header.split(":", 1) + extra_headers[key] = val + responses = [] + for file_name, file_info in result.items(): + file_url = file_info["uploadUrl"] + + # If the upload URL is relative, fill it in with the base URL, + # since it's a proxied file store like the on-prem VM. + if file_url.startswith("/"): + file_url = f"{self.api_url}{file_url}" + + try: + # To handle Windows paths + # TODO: this doesn't handle absolute paths... + normal_name = os.path.join(*file_name.split("/")) + open_file = ( + files[file_name] + if isinstance(files, dict) + else open(normal_name, "rb") + ) + except OSError: + print(f"{file_name} does not exist") # noqa: T201 + continue + if progress is False: + responses.append( + self.upload_file_retry( + file_info["uploadUrl"], open_file, extra_headers=extra_headers + ) + ) + else: + if callable(progress): + responses.append( # type: ignore + self.upload_file_retry( + file_url, open_file, progress, extra_headers=extra_headers + ) + ) + else: + length = os.fstat(open_file.fileno()).st_size + with click.progressbar( # type: ignore + file=progress, + length=length, + label=f"Uploading file: {file_name}", + fill_char=click.style("&", fg="green"), + ) as bar: + responses.append( + self.upload_file_retry( + file_url, + open_file, + lambda bites, _: bar.update(bites), + extra_headers=extra_headers, + ) + ) + open_file.close() + return responses + + def link_artifact( + self, + client_id: str, + server_id: str, + portfolio_name: str, + entity: str, + project: str, + aliases: Sequence[str], + organization: str, + ) -> dict[str, Any]: + from wandb.sdk.artifacts._validators import is_artifact_registry_project + + template = """ + mutation LinkArtifact( + $artifactPortfolioName: String!, + $entityName: String!, + $projectName: String!, + $aliases: [ArtifactAliasInput!], + ID_TYPE + ) { + 
linkArtifact(input: { + artifactPortfolioName: $artifactPortfolioName, + entityName: $entityName, + projectName: $projectName, + aliases: $aliases, + ID_VALUE + }) { + versionIndex + } + } + """ + + org_entity = "" + if is_artifact_registry_project(project): + try: + org_entity = self._resolve_org_entity_name( + entity=entity, organization=organization + ) + except ValueError as e: + wandb.termerror(str(e)) + raise + + def replace(a: str, b: str) -> None: + nonlocal template + template = template.replace(a, b) + + if server_id: + replace("ID_TYPE", "$artifactID: ID") + replace("ID_VALUE", "artifactID: $artifactID") + elif client_id: + replace("ID_TYPE", "$clientID: ID") + replace("ID_VALUE", "clientID: $clientID") + + variable_values = { + "clientID": client_id, + "artifactID": server_id, + "artifactPortfolioName": portfolio_name, + "entityName": org_entity or entity, + "projectName": project, + "aliases": [ + {"alias": alias, "artifactCollectionName": portfolio_name} + for alias in aliases + ], + } + + mutation = gql(template) + response = self.gql(mutation, variable_values=variable_values) + link_artifact: dict[str, Any] = response["linkArtifact"] + return link_artifact + + def _resolve_org_entity_name(self, entity: str, organization: str = "") -> str: + # resolveOrgEntityName fetches the portfolio's org entity's name. + # + # The organization parameter may be empty, an org's display name, or an org entity name. + # + # If the server doesn't support fetching the org name of a portfolio, then this returns + # the organization parameter, or an error if it is empty. Otherwise, this returns the + # fetched value after validating that the given organization, if not empty, matches + # either the org's display or entity name. 
+ + if not entity: + raise ValueError("Entity name is required to resolve org entity name.") + + org_fields = self.server_organization_type_introspection() + can_shorthand_org_entity = "orgEntity" in org_fields + if not organization and not can_shorthand_org_entity: + raise ValueError( + "Fetching Registry artifacts without inputting an organization " + "is unavailable for your server version. " + "Please upgrade your server to 0.50.0 or later." + ) + if not can_shorthand_org_entity: + # Server doesn't support fetching org entity to validate, + # assume org entity is correctly inputted + return organization + + orgs_from_entity = self._fetch_orgs_and_org_entities_from_entity(entity) + if organization: + return _match_org_with_fetched_org_entities(organization, orgs_from_entity) + + # If no input organization provided, error if entity belongs to multiple orgs because we + # cannot determine which one to use. + if len(orgs_from_entity) > 1: + raise ValueError( + f"Personal entity {entity!r} belongs to multiple organizations " + "and cannot be used without specifying the organization name. " + "Please specify the organization in the Registry path or use a team entity in the entity settings." + ) + return orgs_from_entity[0].entity_name + + def _fetch_orgs_and_org_entities_from_entity(self, entity: str) -> list[_OrgNames]: + """Fetches organization entity names and display names for a given entity. + + Args: + entity (str): Entity name to lookup. Can be either a personal or team entity. + + Returns: + list[_OrgNames]: list of _OrgNames tuples. (_OrgNames(entity_name, display_name)) + + Raises: + ValueError: If entity is not found, has no organizations, or other validation errors. + """ + query = gql( + """ + query FetchOrgEntityFromEntity($entityName: String!) 
{ + entity(name: $entityName) { + organization { + name + orgEntity { + name + } + } + user { + organizations { + name + orgEntity { + name + } + } + } + } + } + """ + ) + response = self.gql( + query, + variable_values={ + "entityName": entity, + }, + ) + + # Parse organization from response + entity_resp = response["entity"]["organization"] + user_resp = response["entity"]["user"] + # Check for organization under team/org entity type + if entity_resp: + org_name = entity_resp.get("name") + org_entity_name = entity_resp.get("orgEntity") and entity_resp[ + "orgEntity" + ].get("name") + if not org_name or not org_entity_name: + raise ValueError( + f"Unable to find an organization under entity {entity!r}." + ) + return [_OrgNames(entity_name=org_entity_name, display_name=org_name)] + # Check for organization under personal entity type, where a user can belong to multiple orgs + elif user_resp: + orgs = user_resp.get("organizations", []) + org_entities_return = [ + _OrgNames( + entity_name=org["orgEntity"]["name"], display_name=org["name"] + ) + for org in orgs + if org.get("orgEntity") and org.get("name") + ] + if not org_entities_return: + raise ValueError( + f"Unable to resolve an organization associated with personal entity: {entity!r}. " + "This could be because its a personal entity that doesn't belong to any organizations. " + "Please specify the organization in the Registry path or use a team entity in the entity settings." 
+ ) + return org_entities_return + else: + raise ValueError(f"Unable to find an organization under entity {entity!r}.") + + def _construct_use_artifact_query( + self, + artifact_id: str, + entity_name: str | None = None, + project_name: str | None = None, + run_name: str | None = None, + use_as: str | None = None, + artifact_entity_name: str | None = None, + artifact_project_name: str | None = None, + ) -> tuple[Document, dict[str, Any]]: + query_vars = [ + "$entityName: String!", + "$projectName: String!", + "$runName: String!", + "$artifactID: ID!", + ] + query_args = [ + "entityName: $entityName", + "projectName: $projectName", + "runName: $runName", + "artifactID: $artifactID", + ] + + artifact_types = self.server_use_artifact_input_introspection() + if "usedAs" in artifact_types and use_as: + query_vars.append("$usedAs: String") + query_args.append("usedAs: $usedAs") + + entity_name = entity_name or self.settings("entity") + project_name = project_name or self.settings("project") + run_name = run_name or self.current_run_id + + variable_values: dict[str, Any] = { + "entityName": entity_name, + "projectName": project_name, + "runName": run_name, + "artifactID": artifact_id, + "usedAs": use_as, + } + + server_allows_entity_project_information = self._server_supports( + ServerFeature.USE_ARTIFACT_WITH_ENTITY_AND_PROJECT_INFORMATION + ) + if server_allows_entity_project_information: + query_vars.extend( + [ + "$artifactEntityName: String", + "$artifactProjectName: String", + ] + ) + query_args.extend( + [ + "artifactEntityName: $artifactEntityName", + "artifactProjectName: $artifactProjectName", + ] + ) + variable_values["artifactEntityName"] = artifact_entity_name + variable_values["artifactProjectName"] = artifact_project_name + + vars_str = ", ".join(query_vars) + args_str = ", ".join(query_args) + + query = gql( + f""" + mutation UseArtifact({vars_str}) {{ + useArtifact(input: {{{args_str}}}) {{ + artifact {{ + id + digest + description + state + createdAt + 
metadata + }} + }} + }} + """ + ) + return query, variable_values + + def use_artifact( + self, + artifact_id: str, + entity_name: str | None = None, + project_name: str | None = None, + run_name: str | None = None, + artifact_entity_name: str | None = None, + artifact_project_name: str | None = None, + use_as: str | None = None, + ) -> dict[str, Any] | None: + query, variable_values = self._construct_use_artifact_query( + artifact_id, + entity_name, + project_name, + run_name, + use_as, + artifact_entity_name, + artifact_project_name, + ) + response = self.gql(query, variable_values) + + if response["useArtifact"]["artifact"]: + artifact: dict[str, Any] = response["useArtifact"]["artifact"] + return artifact + return None + + # Fetch fields available in backend of Organization type + def server_organization_type_introspection(self) -> list[str]: + query_string = """ + query ProbeServerOrganization { + OrganizationInfoType: __type(name:"Organization") { + fields { + name + } + } + } + """ + + if self.server_organization_type_fields_info is None: + query = gql(query_string) + res = self.gql(query) + input_fields = res.get("OrganizationInfoType", {}).get("fields", [{}]) + self.server_organization_type_fields_info = [ + field["name"] for field in input_fields if "name" in field + ] + + return self.server_organization_type_fields_info + + # Fetch input arguments for the "artifact" endpoint on the "Project" type + def server_project_type_introspection(self) -> bool: + if self.server_supports_enabling_artifact_usage_tracking is not None: + return self.server_supports_enabling_artifact_usage_tracking + + query_string = """ + query ProbeServerProjectInfo { + ProjectInfoType: __type(name:"Project") { + fields { + name + args { + name + } + } + } + } + """ + + query = gql(query_string) + res = self.gql(query) + input_fields = res.get("ProjectInfoType", {}).get("fields", [{}]) + artifact_args: list[dict[str, str]] = next( + ( + field.get("args", []) + for field in 
input_fields + if field.get("name") == "artifact" + ), + [], + ) + self.server_supports_enabling_artifact_usage_tracking = any( + arg.get("name") == "enableTracking" for arg in artifact_args + ) + + return self.server_supports_enabling_artifact_usage_tracking + + def create_artifact_type( + self, + artifact_type_name: str, + entity_name: str | None = None, + project_name: str | None = None, + description: str | None = None, + ) -> str | None: + mutation = gql( + """ + mutation CreateArtifactType( + $entityName: String!, + $projectName: String!, + $artifactTypeName: String!, + $description: String + ) { + createArtifactType(input: { + entityName: $entityName, + projectName: $projectName, + name: $artifactTypeName, + description: $description + }) { + artifactType { + id + } + } + } + """ + ) + entity_name = entity_name or self.settings("entity") + project_name = project_name or self.settings("project") + response = self.gql( + mutation, + variable_values={ + "entityName": entity_name, + "projectName": project_name, + "artifactTypeName": artifact_type_name, + "description": description, + }, + ) + _id: str | None = response["createArtifactType"]["artifactType"]["id"] + return _id + + def server_artifact_introspection(self) -> list[str]: + query_string = """ + query ProbeServerArtifact { + ArtifactInfoType: __type(name:"Artifact") { + fields { + name + } + } + } + """ + + if self.server_artifact_fields_info is None: + query = gql(query_string) + res = self.gql(query) + input_fields = res.get("ArtifactInfoType", {}).get("fields", [{}]) + self.server_artifact_fields_info = [ + field["name"] for field in input_fields if "name" in field + ] + + return self.server_artifact_fields_info + + def server_create_artifact_introspection(self) -> list[str]: + query_string = """ + query ProbeServerCreateArtifactInput { + CreateArtifactInputInfoType: __type(name:"CreateArtifactInput") { + inputFields{ + name + } + } + } + """ + + if self.server_create_artifact_input_info is None: + 
query = gql(query_string) + res = self.gql(query) + input_fields = res.get("CreateArtifactInputInfoType", {}).get( + "inputFields", [{}] + ) + self.server_create_artifact_input_info = [ + field["name"] for field in input_fields if "name" in field + ] + + return self.server_create_artifact_input_info + + def _get_create_artifact_mutation( + self, + fields: list, + history_step: int | None, + distributed_id: str | None, + ) -> str: + types = "" + values = "" + + if "historyStep" in fields and history_step not in [0, None]: + types += "$historyStep: Int64!," + values += "historyStep: $historyStep," + + if distributed_id: + types += "$distributedID: String," + values += "distributedID: $distributedID," + + if "clientID" in fields: + types += "$clientID: ID," + values += "clientID: $clientID," + + if "sequenceClientID" in fields: + types += "$sequenceClientID: ID," + values += "sequenceClientID: $sequenceClientID," + + if "enableDigestDeduplication" in fields: + values += "enableDigestDeduplication: true," + + if "ttlDurationSeconds" in fields: + types += "$ttlDurationSeconds: Int64," + values += "ttlDurationSeconds: $ttlDurationSeconds," + + if "tags" in fields: + types += "$tags: [TagInput!]," + values += "tags: $tags," + + query_template = """ + mutation CreateArtifact( + $artifactTypeName: String!, + $artifactCollectionNames: [String!], + $entityName: String!, + $projectName: String!, + $runName: String, + $description: String, + $digest: String!, + $aliases: [ArtifactAliasInput!], + $metadata: JSONString, + _CREATE_ARTIFACT_ADDITIONAL_TYPE_ + ) { + createArtifact(input: { + artifactTypeName: $artifactTypeName, + artifactCollectionNames: $artifactCollectionNames, + entityName: $entityName, + projectName: $projectName, + runName: $runName, + description: $description, + digest: $digest, + digestAlgorithm: MANIFEST_MD5, + aliases: $aliases, + metadata: $metadata, + _CREATE_ARTIFACT_ADDITIONAL_VALUE_ + }) { + artifact { + id + state + artifactSequence { + id + 
latestArtifact { + id + versionIndex + } + } + } + } + } + """ + + return query_template.replace( + "_CREATE_ARTIFACT_ADDITIONAL_TYPE_", types + ).replace("_CREATE_ARTIFACT_ADDITIONAL_VALUE_", values) + + def create_artifact( + self, + artifact_type_name: str, + artifact_collection_name: str, + digest: str, + client_id: str | None = None, + sequence_client_id: str | None = None, + entity_name: str | None = None, + project_name: str | None = None, + run_name: str | None = None, + description: str | None = None, + metadata: dict | None = None, + ttl_duration_seconds: int | None = None, + aliases: list[dict[str, str]] | None = None, + tags: list[dict[str, str]] | None = None, + distributed_id: str | None = None, + is_user_created: bool | None = False, + history_step: int | None = None, + ) -> tuple[dict, dict]: + fields = self.server_create_artifact_introspection() + artifact_fields = self.server_artifact_introspection() + if ("ttlIsInherited" not in artifact_fields) and ttl_duration_seconds: + wandb.termwarn( + "Server not compatible with setting Artifact TTLs, please upgrade the server to use Artifact TTL" + ) + # ttlDurationSeconds is only usable if ttlIsInherited is also present + ttl_duration_seconds = None + if ("tags" not in artifact_fields) and tags: + wandb.termwarn( + "Server not compatible with Artifact tags. " + "To use Artifact tags, please upgrade the server to v0.85 or higher." 
+ ) + + query_template = self._get_create_artifact_mutation( + fields, history_step, distributed_id + ) + + entity_name = entity_name or self.settings("entity") + project_name = project_name or self.settings("project") + if not is_user_created: + run_name = run_name or self.current_run_id + + mutation = gql(query_template) + response = self.gql( + mutation, + variable_values={ + "entityName": entity_name, + "projectName": project_name, + "runName": run_name, + "artifactTypeName": artifact_type_name, + "artifactCollectionNames": [artifact_collection_name], + "clientID": client_id, + "sequenceClientID": sequence_client_id, + "digest": digest, + "description": description, + "aliases": list(aliases or []), + "tags": list(tags or []), + "metadata": json.dumps(util.make_safe_for_json(metadata)) + if metadata + else None, + "ttlDurationSeconds": ttl_duration_seconds, + "distributedID": distributed_id, + "historyStep": history_step, + }, + ) + av = response["createArtifact"]["artifact"] + latest = response["createArtifact"]["artifact"]["artifactSequence"].get( + "latestArtifact" + ) + return av, latest + + def commit_artifact(self, artifact_id: str) -> _Response: + mutation = gql( + """ + mutation CommitArtifact( + $artifactID: ID!, + ) { + commitArtifact(input: { + artifactID: $artifactID, + }) { + artifact { + id + digest + } + } + } + """ + ) + + response: _Response = self.gql( + mutation, + variable_values={"artifactID": artifact_id}, + timeout=60, + ) + return response + + def complete_multipart_upload_artifact( + self, + artifact_id: str, + storage_path: str, + completed_parts: list[dict[str, Any]], + upload_id: str | None, + complete_multipart_action: str = "Complete", + ) -> str | None: + mutation = gql( + """ + mutation CompleteMultipartUploadArtifact( + $completeMultipartAction: CompleteMultipartAction!, + $completedParts: [UploadPartsInput!]!, + $artifactID: ID! + $storagePath: String! + $uploadID: String! 
+ ) { + completeMultipartUploadArtifact( + input: { + completeMultipartAction: $completeMultipartAction, + completedParts: $completedParts, + artifactID: $artifactID, + storagePath: $storagePath + uploadID: $uploadID + } + ) { + digest + } + } + """ + ) + response = self.gql( + mutation, + variable_values={ + "completeMultipartAction": complete_multipart_action, + "artifactID": artifact_id, + "storagePath": storage_path, + "completedParts": completed_parts, + "uploadID": upload_id, + }, + ) + digest: str | None = response["completeMultipartUploadArtifact"]["digest"] + return digest + + def create_artifact_manifest( + self, + name: str, + digest: str, + artifact_id: str | None, + base_artifact_id: str | None = None, + entity: str | None = None, + project: str | None = None, + run: str | None = None, + include_upload: bool = True, + type: str = "FULL", + ) -> tuple[str, dict[str, Any]]: + mutation = gql( + """ + mutation CreateArtifactManifest( + $name: String!, + $digest: String!, + $artifactID: ID!, + $baseArtifactID: ID, + $entityName: String!, + $projectName: String!, + $runName: String!, + $includeUpload: Boolean!, + {} + ) {{ + createArtifactManifest(input: {{ + name: $name, + digest: $digest, + artifactID: $artifactID, + baseArtifactID: $baseArtifactID, + entityName: $entityName, + projectName: $projectName, + runName: $runName, + {} + }}) {{ + artifactManifest {{ + id + file {{ + id + name + displayName + uploadUrl @include(if: $includeUpload) + uploadHeaders @include(if: $includeUpload) + }} + }} + }} + }} + """.format( + "$type: ArtifactManifestType = FULL" if type != "FULL" else "", + "type: $type" if type != "FULL" else "", + ) + ) + + entity_name = entity or self.settings("entity") + project_name = project or self.settings("project") + run_name = run or self.current_run_id + + response = self.gql( + mutation, + variable_values={ + "name": name, + "digest": digest, + "artifactID": artifact_id, + "baseArtifactID": base_artifact_id, + "entityName": 
entity_name, + "projectName": project_name, + "runName": run_name, + "includeUpload": include_upload, + "type": type, + }, + ) + return ( + response["createArtifactManifest"]["artifactManifest"]["id"], + response["createArtifactManifest"]["artifactManifest"]["file"], + ) + + def update_artifact_manifest( + self, + artifact_manifest_id: str, + base_artifact_id: str | None = None, + digest: str | None = None, + include_upload: bool | None = True, + ) -> tuple[str, dict[str, Any]]: + mutation = gql( + """ + mutation UpdateArtifactManifest( + $artifactManifestID: ID!, + $digest: String, + $baseArtifactID: ID, + $includeUpload: Boolean!, + ) { + updateArtifactManifest(input: { + artifactManifestID: $artifactManifestID, + digest: $digest, + baseArtifactID: $baseArtifactID, + }) { + artifactManifest { + id + file { + id + name + displayName + uploadUrl @include(if: $includeUpload) + uploadHeaders @include(if: $includeUpload) + } + } + } + } + """ + ) + + response = self.gql( + mutation, + variable_values={ + "artifactManifestID": artifact_manifest_id, + "digest": digest, + "baseArtifactID": base_artifact_id, + "includeUpload": include_upload, + }, + ) + + return ( + response["updateArtifactManifest"]["artifactManifest"]["id"], + response["updateArtifactManifest"]["artifactManifest"]["file"], + ) + + def update_artifact_metadata( + self, artifact_id: str, metadata: dict[str, Any] + ) -> dict[str, Any]: + """Set the metadata of the given artifact version.""" + mutation = gql( + """ + mutation UpdateArtifact( + $artifactID: ID!, + $metadata: JSONString, + ) { + updateArtifact(input: { + artifactID: $artifactID, + metadata: $metadata, + }) { + artifact { + id + } + } + } + """ + ) + response = self.gql( + mutation, + variable_values={ + "artifactID": artifact_id, + "metadata": json.dumps(metadata), + }, + ) + return response["updateArtifact"]["artifact"] + + def _resolve_client_id( + self, + client_id: str, + ) -> str | None: + if client_id in self._client_id_mapping: + 
return self._client_id_mapping[client_id] + + query = gql( + """ + query ClientIDMapping($clientID: ID!) { + clientIDMapping(clientID: $clientID) { + serverID + } + } + """ + ) + response = self.gql( + query, + variable_values={ + "clientID": client_id, + }, + ) + server_id = None + if response is not None: + client_id_mapping = response.get("clientIDMapping") + if client_id_mapping is not None: + server_id = client_id_mapping.get("serverID") + if server_id is not None: + self._client_id_mapping[client_id] = server_id + return server_id + + def server_create_artifact_file_spec_input_introspection(self) -> list: + query_string = """ + query ProbeServerCreateArtifactFileSpecInput { + CreateArtifactFileSpecInputInfoType: __type(name:"CreateArtifactFileSpecInput") { + inputFields{ + name + } + } + } + """ + + query = gql(query_string) + res = self.gql(query) + create_artifact_file_spec_input_info = [ + field.get("name", "") + for field in res.get("CreateArtifactFileSpecInputInfoType", {}).get( + "inputFields", [{}] + ) + ] + return create_artifact_file_spec_input_info + + @normalize_exceptions + def create_artifact_files( + self, artifact_files: Iterable[CreateArtifactFileSpecInput] + ) -> Mapping[str, CreateArtifactFilesResponseFile]: + query_template = """ + mutation CreateArtifactFiles( + $storageLayout: ArtifactStorageLayout! + $artifactFiles: [CreateArtifactFileSpecInput!]! + ) { + createArtifactFiles(input: { + artifactFiles: $artifactFiles, + storageLayout: $storageLayout, + }) { + files { + edges { + node { + id + name + displayName + uploadUrl + uploadHeaders + _MULTIPART_UPLOAD_FIELDS_ + artifact { + id + } + } + } + } + } + } + """ + multipart_upload_url_query = """ + storagePath + uploadMultipartUrls { + uploadID + uploadUrlParts { + partNumber + uploadUrl + } + } + """ + + # TODO: we should use constants here from interface/artifacts.py + # but probably don't want the dependency. 
We're going to remove + # this setting in a future release, so I'm just hard-coding the strings. + storage_layout = "V2" + if env.get_use_v1_artifacts(): + storage_layout = "V1" + + create_artifact_file_spec_input_fields = ( + self.server_create_artifact_file_spec_input_introspection() + ) + if "uploadPartsInput" in create_artifact_file_spec_input_fields: + query_template = query_template.replace( + "_MULTIPART_UPLOAD_FIELDS_", multipart_upload_url_query + ) + else: + query_template = query_template.replace("_MULTIPART_UPLOAD_FIELDS_", "") + + mutation = gql(query_template) + response = self.gql( + mutation, + variable_values={ + "storageLayout": storage_layout, + "artifactFiles": [af for af in artifact_files], + }, + ) + + result = {} + for edge in response["createArtifactFiles"]["files"]["edges"]: + node = edge["node"] + result[node["displayName"]] = node + return result + + @normalize_exceptions + def notify_scriptable_run_alert( + self, + title: str, + text: str, + level: str | None = None, + wait_duration: Number | None = None, + ) -> bool: + mutation = gql( + """ + mutation NotifyScriptableRunAlert( + $entityName: String!, + $projectName: String!, + $runName: String!, + $title: String!, + $text: String!, + $severity: AlertSeverity = INFO, + $waitDuration: Duration + ) { + notifyScriptableRunAlert(input: { + entityName: $entityName, + projectName: $projectName, + runName: $runName, + title: $title, + text: $text, + severity: $severity, + waitDuration: $waitDuration + }) { + success + } + } + """ + ) + + response = self.gql( + mutation, + variable_values={ + "entityName": self.settings("entity"), + "projectName": self.settings("project"), + "runName": self.current_run_id, + "title": title, + "text": text, + "severity": level, + "waitDuration": wait_duration, + }, + ) + success: bool = response["notifyScriptableRunAlert"]["success"] + return success + + def get_sweep_state( + self, sweep: str, entity: str | None = None, project: str | None = None + ) -> 
SweepState: + state: SweepState = self.sweep( + sweep=sweep, entity=entity, project=project, specs="{}" + )["state"] + return state + + def set_sweep_state( + self, + sweep: str, + state: SweepState, + entity: str | None = None, + project: str | None = None, + ) -> None: + assert state in ("RUNNING", "PAUSED", "CANCELED", "FINISHED") + s = self.sweep(sweep=sweep, entity=entity, project=project, specs="{}") + curr_state = s["state"].upper() + if state == "PAUSED" and curr_state not in ("PAUSED", "RUNNING"): + raise Exception(f"Cannot pause {curr_state.lower()} sweep.") + elif state != "RUNNING" and curr_state not in ("RUNNING", "PAUSED", "PENDING"): + raise Exception(f"Sweep already {curr_state.lower()}.") + sweep_id = s["id"] + mutation = gql( + """ + mutation UpsertSweep( + $id: ID, + $state: String, + $entityName: String, + $projectName: String + ) { + upsertSweep(input: { + id: $id, + state: $state, + entityName: $entityName, + projectName: $projectName + }){ + sweep { + name + } + } + } + """ + ) + self.gql( + mutation, + variable_values={ + "id": sweep_id, + "state": state, + "entityName": entity or self.settings("entity"), + "projectName": project or self.settings("project"), + }, + ) + + def stop_sweep( + self, + sweep: str, + entity: str | None = None, + project: str | None = None, + ) -> None: + """Finish the sweep to stop running new runs and let currently running runs finish.""" + self.set_sweep_state( + sweep=sweep, state="FINISHED", entity=entity, project=project + ) + + def cancel_sweep( + self, + sweep: str, + entity: str | None = None, + project: str | None = None, + ) -> None: + """Cancel the sweep to kill all running runs and stop running new runs.""" + self.set_sweep_state( + sweep=sweep, state="CANCELED", entity=entity, project=project + ) + + def pause_sweep( + self, + sweep: str, + entity: str | None = None, + project: str | None = None, + ) -> None: + """Pause the sweep to temporarily stop running new runs.""" + self.set_sweep_state( + 
sweep=sweep, state="PAUSED", entity=entity, project=project + ) + + def resume_sweep( + self, + sweep: str, + entity: str | None = None, + project: str | None = None, + ) -> None: + """Resume the sweep to continue running new runs.""" + self.set_sweep_state( + sweep=sweep, state="RUNNING", entity=entity, project=project + ) + + def _status_request(self, url: str, length: int) -> requests.Response: + """Ask google how much we've uploaded.""" + import requests + + check_httpclient_logger_handler() + return requests.put( + url=url, + headers={"Content-Length": "0", "Content-Range": f"bytes */{length}"}, + ) + + def _flatten_edges(self, response: _Response) -> list[dict]: + """Return an array from the nested graphql relay structure.""" + return [node["node"] for node in response["edges"]] + + @normalize_exceptions + def stop_run( + self, + run_id: str, + ) -> bool: + mutation = gql( + """ + mutation stopRun($id: ID!) { + stopRun(input: { + id: $id + }) { + clientMutationId + success + } + } + """ + ) + + response = self.gql( + mutation, + variable_values={ + "id": run_id, + }, + ) + + success: bool = response["stopRun"].get("success") + + return success + + @normalize_exceptions + def create_custom_chart( + self, + entity: str, + name: str, + display_name: str, + spec_type: str, + access: str, + spec: str | Mapping[str, Any], + ) -> dict[str, Any] | None: + if not isinstance(spec, str): + spec = json.dumps(spec) + + mutation = gql( + """ + mutation CreateCustomChart( + $entity: String! + $name: String! + $displayName: String! + $type: String! + $access: String! + $spec: JSONString! 
class Version:
    """A minimal ``major.minor.patch`` version supporting ``<`` and ``==``.

    Used for the values of ``SOURCE_KEYS_MIN_SUPPORTED_VERSION``. Its text
    form (``"0.17.0"``) is what ends up in the ``_version`` field of
    wandb-job.json.

    Instances are unhashable: ``__eq__`` is defined without ``__hash__``.
    """

    def __init__(self, major: int, minor: int, patch: int):
        self._major = major
        self._minor = minor
        self._patch = patch

    def _key(self) -> Tuple[int, int, int]:
        """Return the components as a tuple, compared lexicographically."""
        return (self._major, self._minor, self._patch)

    def __repr__(self) -> str:
        return f"{self._major}.{self._minor}.{self._patch}"

    def __lt__(self, other: "Version") -> bool:
        # Mirror __eq__: let Python try the reflected operation (and finally
        # raise TypeError) instead of failing with an AttributeError on
        # operands that are not Version instances.
        if not isinstance(other, Version):
            return NotImplemented
        return self._key() < other._key()

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Version):
            return NotImplemented
        return self._key() == other._key()
def __init__(
    self,
    settings: SettingsStatic,
    verbose: bool = False,
    *,
    files_dir: str,
):
    """Set up a JobBuilder that has not collected any job state yet.

    Args:
        settings: Parameters for the job builder. In a run, these are the
            run's settings. Otherwise they are a set of undocumented
            parameters, all of which should be made explicit the way
            files_dir is.
        verbose: When true, messages passed to `_log_if_verbose` are also
            echoed through the wandb terminal helpers.
        files_dir: The directory where to write files. In a run, this
            should be the run's files directory.
    """
    # Caller-supplied inputs.
    self._settings = settings
    self._files_dir = files_dir
    self._verbose = verbose

    # Values derived from the settings. `_get_is_notebook_run` reads
    # `self._settings`, so it must run after the assignment above.
    self._disable = settings.disable_job_creation or settings.x_disable_machine_info
    self._source_type: Optional[Literal["repo", "artifact", "image"]] = (
        settings.job_source  # type: ignore[assignment]
    )
    self._is_notebook_run = self._get_is_notebook_run()
    self._partial = False

    # State filled in later through the setters and server responses.
    self._metadatafile_path = None
    self._requirements_path = None
    self._config = None
    self._summary = None
    self._logged_code_artifact = None
    self._partial_source_id = None
    self._job_seq_id = None
    self._job_version_alias = None
    self._aliases = []
    self._services = {}
def _handle_server_artifact(
    self, res: Optional[Dict], artifact: "ArtifactRecord"
) -> None:
    """Remember job / code artifact details from a server save response.

    For a "job" artifact this derives the version alias (``v<N>``) and the
    artifact sequence id from ``res``. For a "code" artifact it stores the
    artifact's id and name for later use as a job source. Nothing happens
    when ``res`` is None.
    """
    artifact_type = artifact.type
    if res is None:
        return

    if artifact_type == "job":
        try:
            sequence = res["artifactSequence"]
            latest = sequence["latestArtifact"]
            if latest is None:
                # No version exists in the sequence yet.
                alias = "v0"
            elif latest["id"] == res["id"]:
                # The saved artifact already is the latest version.
                alias = f"v{latest['versionIndex']}"
            else:
                alias = f"v{latest['versionIndex'] + 1}"
            self._job_version_alias = alias
            self._job_seq_id = sequence["id"]
        except KeyError as e:
            _logger.info(f"Malformed response from ArtifactSaver.save {e}")
    elif artifact_type == "code":
        self._logged_code_artifact = ArtifactInfoForJob(
            id=res["id"],
            name=artifact.name,
        )
def _log_if_verbose(self, message: str, level: LOG_LEVEL) -> None:
    """Log ``message`` at ``level`` and, in verbose mode, echo it to the terminal.

    The message always goes to the module logger. It is additionally passed
    to the matching ``wandb.term*`` function only when the builder was
    created with ``verbose=True``. An unrecognised level does nothing.
    """
    if level == "log":
        record, echo = _logger.info, wandb.termlog
    elif level == "warn":
        record, echo = _logger.warning, wandb.termwarn
    elif level == "error":
        record, echo = _logger.error, wandb.termerror
    else:
        return

    record(message)
    if self._verbose:
        echo(message)
def _build_image_job_source(
    self, metadata: Dict[str, Any]
) -> Tuple[ImageSourceDict, str]:
    """Build the source dict and job name for an image-sourced job.

    Args:
        metadata: Run metadata; ``metadata["docker"]`` must hold the image
            name as a string.

    Returns:
        A ``(source, name)`` pair. ``source["image"]`` is the full image
        name, tag included. ``name`` is derived from the image name with a
        trailing tag removed.

    Side effects:
        When the image name ends in something that looks like a tag, that
        tag is appended to ``self._aliases``.
    """
    image_name = metadata.get("docker")
    assert isinstance(image_name, str)

    raw_image_name = image_name
    if ":" in image_name:
        # Split on the LAST colon only. Using str.replace(f":{tag}", "")
        # would also strip earlier occurrences of the same text, e.g. the
        # registry port in "registry:5000/app:5000".
        untagged, _, tag = image_name.rpartition(":")

        # if tag looks properly formatted, assume its a tag
        # regex: alphanumeric and "_" "-" "."
        if re.fullmatch(r"([a-zA-Z0-9_\-\.]+)", tag):
            raw_image_name = untagged
            self._aliases += [tag]

    source: ImageSourceDict = {
        "image": image_name,
    }
    name = self._make_job_name(raw_image_name)

    return source, name
def build(
    self,
    api: Api,
    build_context: Optional[str] = None,
    dockerfile: Optional[str] = None,
    base_image: Optional[str] = None,
) -> Optional[Artifact]:
    """Build a job artifact from the current run.

    The method returns None without building anything when a partial job
    source id was set (only its metadata is updated), when the
    requirements or metadata file is missing from the files directory,
    when the metadata carries no python version, or when no job source
    can be determined.

    Args:
        api (Api): The API object to use to create the job artifact.
        build_context (Optional[str]): Path within the job source code to
            the image build context. Saved as part of the job for future
            builds.
        dockerfile (Optional[str]): Path within the build context to the
            Dockerfile. Saved as part of the job for future builds.
        base_image (Optional[str]): The base image used to run the job code.

    Returns:
        Optional[Artifact]: The job artifact if it was successfully built,
            otherwise None.
    """
    _logger.info("Attempting to build job artifact")

    # If a partial job was used, write the input/output types to the metadata
    # rather than building a new job version.
    if self._partial_source_id is not None:
        new_metadata = {
            "input_types": {"@wandb.config": self.input_types},
            "output_types": self.output_types,
        }
        api.update_artifact_metadata(
            self._partial_source_id,
            new_metadata,
        )
        return None

    # The requirements file is later added to the artifact, so it must exist.
    if not os.path.exists(os.path.join(self._files_dir, REQUIREMENTS_FNAME)):
        self._log_if_verbose(
            "No requirements.txt found, not creating job artifact. See https://docs.wandb.ai/guides/launch/create-job",
            "warn",
        )
        return None
    metadata = self._handle_metadata_file()
    if metadata is None:
        self._log_if_verbose(
            f"Ensure read and write access to run files dir: {self._files_dir}, control this via the WANDB_DIR env var. See https://docs.wandb.ai/guides/track/environment-variables",
            "warn",
        )
        return None

    runtime: Optional[str] = metadata.get("python")
    # can't build a job without a python version
    if runtime is None:
        self._log_if_verbose(
            "No python version found in metadata, not creating job artifact. "
            "See https://docs.wandb.ai/guides/launch/create-job",
            "warn",
        )
        return None

    input_types = TypeRegistry.type_of(self._config).to_json()
    output_types = TypeRegistry.type_of(self._summary).to_json()

    name: Optional[str] = None
    source_info: Optional[JobSourceDict] = None

    # configure job from environment
    source_type = self._get_source_type(metadata)
    if not source_type:
        # if source_type is None, then we don't have enough information to build a job
        # if the user intended to create a job, warn.
        if (
            self._settings.job_name
            or self._settings.job_source
            or self._source_type
        ):
            self._log_if_verbose(
                "No source type found, not creating job artifact", "warn"
            )
        return None

    # Image jobs and partial jobs do not need a program path.
    program_relpath = self._get_program_relpath(source_type, metadata)
    if not self._partial and source_type != "image" and not program_relpath:
        self._log_if_verbose(
            "No program path found, not creating job artifact. "
            "See https://docs.wandb.ai/guides/launch/create-job",
            "warn",
        )
        return None

    source, name = self._build_job_source(
        source_type,
        program_relpath,
        metadata,
    )
    if source is None:
        return None

    # Explicit arguments override whatever the source builder produced.
    if build_context:
        source["build_context"] = build_context  # type: ignore[typeddict-item]
    if dockerfile:
        source["dockerfile"] = dockerfile  # type: ignore[typeddict-item]
    if base_image:
        source["base_image"] = base_image  # type: ignore[typeddict-item]

    # Pop any keys that are initialized to None. The current TypedDict
    # system for source dicts requires all keys to be present, but we
    # don't want to include keys that are None in the final dict.
    for key in list(source.keys()):
        if source[key] is None:  # type: ignore[literal-required]
            source.pop(key)  # type: ignore[literal-required,misc]

    source_info = {
        # Falls back to "v0" when no key in the source has a minimum version.
        "_version": str(get_min_supported_for_source_dict(source) or "v0"),
        "source_type": source_type,
        "source": source,
        "input_types": input_types,
        "output_types": output_types,
        "runtime": runtime,
    }

    if self._services:
        source_info["services"] = self._services

    assert source_info is not None
    assert name is not None

    artifact = InternalArtifact(name, JOB_ARTIFACT_TYPE)

    _logger.info("adding wandb-job metadata file")
    with artifact.new_file("wandb-job.json") as f:
        f.write(json.dumps(source_info, indent=4))

    # The run's requirements file is stored under the frozen-requirements name.
    artifact.add_file(
        os.path.join(self._files_dir, REQUIREMENTS_FNAME),
        name=FROZEN_REQUIREMENTS_FNAME,
    )

    if source_type == "repo":
        # add diff
        if os.path.exists(os.path.join(self._files_dir, DIFF_FNAME)):
            artifact.add_file(
                os.path.join(self._files_dir, DIFF_FNAME),
                name=DIFF_FNAME,
            )

    return artifact
def _handle_metadata_file(
    self,
) -> Optional[Dict]:
    """Load the run's metadata JSON from the files directory.

    Returns:
        The parsed metadata, or None when the file does not exist or
        cannot be opened. The caller treats None as "cannot build a job"
        and warns about access to the files directory.

    Raises:
        ValueError: If the file exists but does not contain valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    metadata_path = os.path.join(self._files_dir, METADATA_FNAME)
    try:
        # Open directly rather than checking os.path.exists first: the
        # file can disappear between the check and the open. JSON is
        # decoded as UTF-8 instead of the platform's locale encoding.
        with open(metadata_path, encoding="utf-8") as f:
            metadata: Dict = json.load(f)
    except OSError as e:
        _logger.info(f"could not read metadata file {metadata_path}: {e}")
        return None

    return metadata
0000000000000000000000000000000000000000..9973e42984580a3fca6dc6706728a58ed22d4277 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/profiler.py @@ -0,0 +1,79 @@ +"""Integration with pytorch profiler.""" + +import os + +import wandb +from wandb.errors import Error, UsageError +from wandb.sdk.lib import telemetry + +PYTORCH_MODULE = "torch" +PYTORCH_PROFILER_MODULE = "torch.profiler" + + +def torch_trace_handler(): + """Create a trace handler for traces generated by the profiler. + + Provide as an argument to `torch.profiler.profile`: + ```python + torch.profiler.profile(..., on_trace_ready=wandb.profiler.torch_trace_handler()) + ``` + + Calling this function ensures that profiler charts & tables can be viewed in + your run dashboard on wandb.ai. + + Please note that `wandb.init()` must be called before this function is + invoked, and the reinit setting must not be set to "create_new". The PyTorch + (torch) version must also be at least 1.9, in order to ensure stability of + their Profiler API. + + Args: + None + + Returns: + None + + Raises: + UsageError if wandb.init() hasn't been called before profiling. + Error if torch version is less than 1.9.0. 
class Progress:
    """A file wrapper that reports read progress through a callback.

    The callback is invoked as ``callback(new_bytes, total_bytes_read)``
    after every read. Attributes not defined here are delegated to the
    wrapped file object.
    """

    ITER_BYTES = 1024 * 1024

    def __init__(
        self, file: IO[bytes], callback: Optional["ProgressFn"] = None
    ) -> None:
        """Wrap ``file``; its current on-disk size is recorded as the length."""
        self.file = file
        if callback is None:

            def callback_(new_bytes: int, total_bytes: int) -> None:
                pass

            callback = callback_

        self.callback: ProgressFn = callback
        self.bytes_read = 0
        self.len = os.fstat(file.fileno()).st_size

    def read(self, size=-1):
        """Read bytes and call the callback.

        Raises:
            CommError: If the file yields no more data before the size
                recorded at construction time has been read.
        """
        bites = self.file.read(size)
        self.bytes_read += len(bites)
        if not bites and self.bytes_read < self.len:
            # Files shrinking during uploads causes request timeouts. Maybe
            # we could avoid those by updating the self.len in real-time, but
            # files getting truncated while uploading seems like something
            # that shouldn't really be happening anyway.
            raise CommError(
                f"File {self.file.name} size shrank from {self.len} to {self.bytes_read} while it was being uploaded."
            )
        # Growing files are also likely to be bad, but our code didn't break
        # on those in the past, so it's riskier to make that an error now.
        self.callback(len(bites), self.bytes_read)
        return bites

    def rewind(self) -> None:
        """Seek back to the start and report the bytes read so far as undone."""
        self.callback(-self.bytes_read, 0)
        self.bytes_read = 0
        self.file.seek(0)

    def __getattr__(self, name):
        """Fallback to the file object for attrs not defined here."""
        # __getattr__ only runs when normal lookup fails. If "file" itself is
        # missing (instance created without __init__, e.g. by copy or
        # pickle), touching self.file below would re-enter this method
        # forever and end in RecursionError.
        if name == "file":
            raise AttributeError(name)
        try:
            return getattr(self.file, name)
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

    def __iter__(self):
        return self

    def __next__(self):
        bites = self.read(self.ITER_BYTES)
        if len(bites) == 0:
            raise StopIteration
        return bites

    def __len__(self):
        return self.len

    next = __next__
class UniformSampleAccumulator:
    """Accumulate a bounded subsample of a stream of values.

    Values are stored in a ring of fixed-size buckets. While the stream is
    short every value is kept. Each time the newest bucket would overflow,
    the oldest bucket is recycled and the recording stride doubles, so only
    every ``2**shift``-th value is stored from then on. ``get`` thins the
    older buckets so they contribute at a similar density to the newer ones.
    """

    def __init__(self, min_samples=None):
        """Create an accumulator.

        Args:
            min_samples: Desired minimum number of samples; rounded up to a
                power of two. Defaults to 64 when omitted (or falsy).
        """
        self._samples = min_samples or 64
        # force power of 2 samples
        # math.log2 is documented as more accurate than math.log(x, 2).
        self._samples = 2 ** int(math.ceil(math.log2(self._samples)))
        # target oversample by factor of 2
        self._samples2 = self._samples * 2
        # max size of each buffer
        self._max = self._samples2 // 2
        self._shift = 0
        self._mask = (1 << self._shift) - 1
        self._buckets = int(math.log2(self._samples2))
        self._buckets_bits = int(math.log2(self._buckets))
        self._buckets_mask = (1 << self._buckets_bits + 1) - 1
        self._buckets_index = 0
        # pre-allocate buckets
        self._bucket = [[0] * self._max for _ in range(self._buckets)]
        self._index = [0] * self._buckets
        self._count = 0
        # Integer log2 lookup table: _log2[i] == floor(log2(i)) for i >= 1.
        # int.bit_length is exact, unlike a float logarithm.
        self._log2 = [0] + [
            i.bit_length() - 1 for i in range(1, 2**self._buckets + 1)
        ]

    def _show(self):
        """Print the live contents of every bucket, oldest first (debug aid)."""
        print("=" * 20)  # noqa: T201
        for offset in range(self._buckets):
            slot = (offset + self._buckets_index) % self._buckets
            vals = [self._bucket[slot][i] for i in range(self._index[slot])]
            print(f"{slot}: {vals}")  # noqa: T201

    def add(self, val):
        """Offer ``val`` to the accumulator; it is stored only if on-stride."""
        self._count += 1
        cnt = self._count
        # Skip values that do not fall on the current recording stride.
        if cnt & self._mask:
            return
        b = cnt >> self._shift
        b = self._log2[b]  # b = int(math.log(b, 2))
        if b >= self._buckets:
            # Out of buckets: recycle the oldest one and halve the
            # recording rate from here on.
            self._index[self._buckets_index] = 0
            self._buckets_index = (self._buckets_index + 1) % self._buckets
            self._shift += 1
            self._mask = (self._mask << 1) | 1
            b += self._buckets - 1
        b = (b + self._buckets_index) % self._buckets
        self._bucket[b][self._index[b]] = val
        self._index[b] += 1

    def get(self):
        """Return the current sample as a tuple, oldest values first.

        Returns every stored value when thinning would leave fewer than the
        configured number of samples; otherwise returns the thinned sample.
        """
        full = []
        sampled = []
        for level in range(self._buckets):
            max_num = 2**level
            slot = (level + self._buckets_index) % self._buckets
            filled = self._index[slot]
            stride = filled // max_num
            for i in range(filled):
                value = self._bucket[slot][i]
                if not stride or i % stride == 0:
                    sampled.append(value)
                full.append(value)
        if len(sampled) < self._samples:
            return tuple(full)
        return tuple(sampled)
wandb.sdk.interface import interface +from wandb.sdk.interface.interface_queue import InterfaceQueue +from wandb.sdk.internal import ( + context, + datastore, + file_stream, + internal_api, + sender_config, +) +from wandb.sdk.internal.file_pusher import FilePusher +from wandb.sdk.internal.job_builder import JobBuilder +from wandb.sdk.internal.settings_static import SettingsStatic +from wandb.sdk.lib import ( + config_util, + filenames, + filesystem, + proto_util, + redirect, + retry, + telemetry, +) +from wandb.sdk.lib.proto_util import message_to_dict + +if TYPE_CHECKING: + from wandb.proto.wandb_internal_pb2 import ( + ArtifactManifest, + ArtifactManifestEntry, + ArtifactRecord, + EnvironmentRecord, + HttpResponse, + LocalInfo, + Record, + Result, + RunExitResult, + RunRecord, + SummaryRecord, + ) + + StreamLiterals = Literal["stdout", "stderr"] + + +logger = logging.getLogger(__name__) + + +_OUTPUT_MIN_CALLBACK_INTERVAL = 2 # seconds + + +def _framework_priority() -> Generator[Tuple[str, str], None, None]: + yield from [ + ("lightgbm", "lightgbm"), + ("catboost", "catboost"), + ("xgboost", "xgboost"), + ("transformers_huggingface", "huggingface"), # backwards compatibility + ("transformers", "huggingface"), + ("pytorch_ignite", "ignite"), # backwards compatibility + ("ignite", "ignite"), + ("pytorch_lightning", "lightning"), + ("fastai", "fastai"), + ("torch", "torch"), + ("keras", "keras"), + ("tensorflow", "tensorflow"), + ("sklearn", "sklearn"), + ] + + +def _manifest_json_from_proto(manifest: "ArtifactManifest") -> Dict: + if manifest.version == 1: + if manifest.manifest_file_path: + contents = {} + with gzip.open(manifest.manifest_file_path, "rt") as f: + for line in f: + entry_json = json.loads(line) + path = entry_json.pop("path") + contents[path] = entry_json + else: + contents = { + content.path: _manifest_entry_from_proto(content) + for content in manifest.contents + } + else: + raise ValueError(f"unknown artifact manifest version: {manifest.version}") 
class ResumeState:
    """Mutable record of the state carried over when a run is resumed."""

    resumed: bool
    step: int
    history: int
    events: int
    output: int
    runtime: float
    wandb_runtime: Optional[int]
    summary: Optional[Dict[str, Any]]
    config: Optional[Dict[str, Any]]
    tags: Optional[List[str]]

    def __init__(self) -> None:
        # Assignment order matters: __str__ prints attributes in the order
        # they were first set on the instance.
        self.resumed = False
        self.step = self.history = self.events = self.output = 0
        self.runtime = 0
        # wandb_runtime is the canonical runtime (stored in summary._wandb.runtime)
        self.wandb_runtime = self.summary = self.config = self.tags = None

    def __str__(self) -> str:
        fields = ",".join(f"{key}={value}" for key, value in vars(self).items())
        return f"ResumeState({fields})"
target=sm._output_raw_reader_thread, + kwargs=dict(stream=stream), + daemon=True, + name=f"OutRawRd-{stream}", + ) + + def start(self) -> None: + self._writer_thr.start() + self._reader_thr.start() + + +class SendManager: + UPDATE_CONFIG_TIME: int = 30 + UPDATE_STATUS_TIME: int = 5 + + _settings: SettingsStatic + _record_q: "Queue[Record]" + _result_q: "Queue[Result]" + _interface: InterfaceQueue + _api_settings: Dict[str, str] + _partial_output: Dict[str, str] + _context_keeper: context.ContextKeeper + + _telemetry_obj: telemetry.TelemetryRecord + _environment_obj: "EnvironmentRecord" + _fs: Optional["file_stream.FileStreamApi"] + _run: Optional["RunRecord"] + _entity: Optional[str] + _project: Optional[str] + _dir_watcher: Optional["DirWatcher"] + _pusher: Optional["FilePusher"] + _record_exit: Optional["Record"] + _exit_result: Optional["RunExitResult"] + _resume_state: ResumeState + _rewind_response: Optional[Dict[str, Any]] + _cached_server_info: Dict[str, Any] + _cached_viewer: Dict[str, Any] + _server_messages: List[Dict[str, Any]] + _ds: Optional[datastore.DataStore] + _output_raw_streams: Dict["StreamLiterals", _OutputRawStream] + _output_raw_file: Optional[filesystem.CRDedupedFile] + _send_record_num: int + _send_end_offset: int + _debounce_config_time: float + _debounce_status_time: float + + def __init__( + self, + settings: SettingsStatic, + record_q: "Queue[Record]", + result_q: "Queue[Result]", + interface: InterfaceQueue, + context_keeper: context.ContextKeeper, + ) -> None: + self._settings = settings + self._record_q = record_q + self._result_q = result_q + self._interface = interface + self._context_keeper = context_keeper + + self._ds = None + self._send_record_num = 0 + self._send_end_offset = 0 + + self._fs = None + self._pusher = None + self._dir_watcher = None + + # State updated by login + self._entity = None + self._flags = None + + # State updated by wandb.init + self._run = None + self._project = None + + # keep track of config from 
key/val updates + self._consolidated_config = sender_config.ConfigState() + + self._start_time: int = 0 + self._telemetry_obj = telemetry.TelemetryRecord() + self._environment_obj = wandb_internal_pb2.EnvironmentRecord() + self._config_metric_pbdict_list: List[Dict[int, Any]] = [] + self._metadata_summary: Dict[str, Any] = defaultdict() + self._cached_summary: Dict[str, Any] = dict() + self._config_metric_index_dict: Dict[str, int] = {} + self._config_metric_dict: Dict[str, wandb_internal_pb2.MetricRecord] = {} + self._consolidated_summary: Dict[str, Any] = dict() + + self._cached_server_info = dict() + self._cached_viewer = dict() + self._server_messages = [] + + # State updated by resuming + self._resume_state = ResumeState() + self._rewind_response = None + + # State added when run_exit is initiated and complete + self._record_exit = None + self._exit_result = None + + self._api = internal_api.Api( + default_settings=settings, retry_callback=self.retry_callback + ) + self._api_settings = dict() + + # queue filled by retry_callback + self._retry_q: Queue[HttpResponse] = queue.Queue() + + # do we need to debounce? + self._config_needs_debounce: bool = False + + # TODO(jhr): do something better, why do we need to send full lines? + self._partial_output = dict() + + self._exit_code = 0 + + # internal vars for handing raw console output + self._output_raw_streams = dict() + self._output_raw_file = None + + # job builder + self._job_builder = JobBuilder( + settings, + files_dir=settings.files_dir, + ) + + time_now = time.monotonic() + self._debounce_config_time = time_now + self._debounce_status_time = time_now + + @classmethod + def setup( + cls, + root_dir: str, + resume: Union[None, bool, str], + ) -> "SendManager": + """Set up a standalone SendManager. + + Exclusively used in `sync.py`. 
+ """ + files_dir = os.path.join(root_dir, "files") + settings = wandb.Settings( + x_files_dir=files_dir, + root_dir=root_dir, + # _start_time=0, + resume=resume, + # ignore_globs=(), + x_sync=True, + disable_job_creation=False, + x_file_stream_timeout_seconds=0, + ) + record_q: Queue[Record] = queue.Queue() + result_q: Queue[Result] = queue.Queue() + publish_interface = InterfaceQueue(record_q=record_q) + context_keeper = context.ContextKeeper() + return SendManager( + settings=SettingsStatic(dict(settings)), + record_q=record_q, + result_q=result_q, + interface=publish_interface, + context_keeper=context_keeper, + ) + + def __len__(self) -> int: + return self._record_q.qsize() + + def __enter__(self) -> "SendManager": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + exc_traceback: Optional[traceback.TracebackException], + ) -> Literal[False]: + while self: + data = next(self) + self.send(data) + self.finish() + return False + + def retry_callback(self, status: int, response_text: str) -> None: + response = wandb_internal_pb2.HttpResponse() + response.http_status_code = status + response.http_response_text = response_text + self._retry_q.put(response) + + def send(self, record: "Record") -> None: + self._update_record_num(record.num) + self._update_end_offset(record.control.end_offset) + + record_type = record.WhichOneof("record_type") + assert record_type + handler_str = "send_" + record_type + send_handler = getattr(self, handler_str, None) + # Don't log output to reduce log noise + if record_type not in {"output", "request", "output_raw"}: + logger.debug(f"send: {record_type}") + assert send_handler, f"unknown send handler: {handler_str}" + + context_id = context.context_id_from_record(record) + api_context = self._context_keeper.get(context_id) + try: + self._api.set_local_context(api_context) + send_handler(record) + except retry.RetryCancelledError: + logger.debug(f"Record cancelled: 
{record_type}") + self._context_keeper.release(context_id) + finally: + self._api.clear_local_context() + + def send_preempting(self, _: "Record") -> None: + if self._fs: + self._fs.enqueue_preempting() + + def send_request_sender_mark(self, _: "Record") -> None: + self._maybe_report_status(always=True) + + def send_request(self, record: "Record") -> None: + request_type = record.request.WhichOneof("request_type") + assert request_type + handler_str = "send_request_" + request_type + send_handler = getattr(self, handler_str, None) + if request_type != "network_status": + logger.debug(f"send_request: {request_type}") + assert send_handler, f"unknown handle: {handler_str}" + send_handler(record) + + def _respond_result(self, result: "Result") -> None: + context_id = context.context_id_from_result(result) + self._context_keeper.release(context_id) + self._result_q.put(result) + + def _flatten(self, dictionary: Dict) -> None: + if isinstance(dictionary, dict): + for k, v in list(dictionary.items()): + if isinstance(v, dict): + self._flatten(v) + dictionary.pop(k) + for k2, v2 in v.items(): + dictionary[k + "." 
+ k2] = v2 + + def _update_record_num(self, record_num: int) -> None: + if not record_num: + return + # Currently how we handle offline mode and syncing is not + # compatible with this assertion due to how the exit record + # is (mis)handled: + # - using "always_send" in offline mode to trigger defer + # state machine + # - skipping the exit record in `wandb sync` mode so that + # it is always executed as the last record + if not self._settings._offline and not self._settings.x_sync: + assert record_num == self._send_record_num + 1 + self._send_record_num = record_num + + def _update_end_offset(self, end_offset: int) -> None: + if not end_offset: + return + self._send_end_offset = end_offset + + def send_request_sender_read(self, record: "Record") -> None: + if self._ds is None: + self._ds = datastore.DataStore() + self._ds.open_for_scan(self._settings.sync_file) + + # TODO(cancel_paused): implement cancel_set logic + # The idea is that there is an active request to cancel a + # message that is being read from the transaction log below + + start_offset = record.request.sender_read.start_offset + final_offset = record.request.sender_read.final_offset + self._ds.seek(start_offset) + + current_end_offset = 0 + while current_end_offset < final_offset: + data = self._ds.scan_data() + assert data + current_end_offset = self._ds.get_offset() + + send_record = wandb_internal_pb2.Record() + send_record.ParseFromString(data) + self._update_end_offset(current_end_offset) + self.send(send_record) + + # make sure we perform deferred operations + self.debounce() + + # make sure that we always update writer for every sended read request + self._maybe_report_status(always=True) + + def send_request_stop_status(self, record: "Record") -> None: + result = proto_util._result_from_record(record) + status_resp = result.response.stop_status_response + status_resp.run_should_stop = False + if self._entity and self._project and self._run and self._run.run_id: + try: + 
status_resp.run_should_stop = self._api.check_stop_requested( + self._project, self._entity, self._run.run_id + ) + except Exception as e: + logger.warning("Failed to check stop requested status: %s", e) + self._respond_result(result) + + def _maybe_update_config(self, always: bool = False) -> None: + time_now = time.monotonic() + if ( + not always + and time_now < self._debounce_config_time + self.UPDATE_CONFIG_TIME + ): + return + if self._config_needs_debounce: + self._debounce_config() + self._debounce_config_time = time_now + + def _maybe_report_status(self, always: bool = False) -> None: + time_now = time.monotonic() + if ( + not always + and time_now < self._debounce_status_time + self.UPDATE_STATUS_TIME + ): + return + self._debounce_status_time = time_now + + status_report = wandb_internal_pb2.StatusReportRequest( + record_num=self._send_record_num, + sent_offset=self._send_end_offset, + ) + status_time = time.time() + status_report.sync_time.FromMicroseconds(int(status_time * 1e6)) + record = self._interface._make_request(status_report=status_report) + self._interface._publish(record) + + def debounce(self, final: bool = False) -> None: + self._maybe_report_status(always=final) + self._maybe_update_config(always=final) + + def _debounce_config(self) -> None: + config_value_dict = self._config_backend_dict() + # TODO(jhr): check result of upsert_run? 
+ if self._run: + self._api.upsert_run( + name=self._run.run_id, + config=config_value_dict, + **self._api_settings, # type: ignore + ) + self._config_save(config_value_dict) + self._config_needs_debounce = False + + def send_request_network_status(self, record: "Record") -> None: + result = proto_util._result_from_record(record) + status_resp = result.response.network_status_response + while True: + try: + status_resp.network_responses.append(self._retry_q.get_nowait()) + except queue.Empty: + break + except Exception as e: + logger.warning(f"Error emptying retry queue: {e}") + self._respond_result(result) + + def send_exit(self, record: "Record") -> None: + # track where the exit came from + self._record_exit = record + + run_exit = record.exit + self._exit_code = run_exit.exit_code + logger.info("handling exit code: %s", run_exit.exit_code) + runtime = run_exit.runtime + logger.info("handling runtime: %s", run_exit.runtime) + self._metadata_summary["runtime"] = runtime + self._update_summary() + + # We need to give the request queue a chance to empty between states + # so use handle_request_defer as a state machine. + logger.info("send defer") + self._interface.publish_defer() + + def send_final(self, record: "Record") -> None: + pass + + def _flush_run(self) -> None: + pass + + def send_request_status_report(self, record: "Record") -> None: + # todo? 
this is just a noop to please wandb sync + pass + + def send_request_defer(self, record: "Record") -> None: # noqa: C901 + defer = record.request.defer + state = defer.state + logger.info(f"handle sender defer: {state}") + + def transition_state() -> None: + state = defer.state + 1 + logger.info(f"send defer: {state}") + self._interface.publish_defer(state) + + done = False + if state == defer.BEGIN: + transition_state() + elif state == defer.FLUSH_RUN: + self._flush_run() + transition_state() + elif state == defer.FLUSH_STATS: + # NOTE: this is handled in handler.py:handle_request_defer() + transition_state() + elif state == defer.FLUSH_PARTIAL_HISTORY: + # NOTE: this is handled in handler.py:handle_request_defer() + transition_state() + elif state == defer.FLUSH_TB: + # NOTE: this is handled in handler.py:handle_request_defer() + transition_state() + elif state == defer.FLUSH_SUM: + # NOTE: this is handled in handler.py:handle_request_defer() + transition_state() + elif state == defer.FLUSH_DEBOUNCER: + self.debounce(final=True) + transition_state() + elif state == defer.FLUSH_OUTPUT: + self._output_raw_finish() + transition_state() + elif state == defer.FLUSH_JOB: + self._flush_job() + transition_state() + elif state == defer.FLUSH_DIR: + if self._dir_watcher: + self._dir_watcher.finish() + self._dir_watcher = None + transition_state() + elif state == defer.FLUSH_FP: + if self._pusher: + # FilePusher generates some events for FileStreamApi, so we + # need to wait for pusher to finish before going to the next + # state to ensure that filestream gets all the events that we + # want before telling it to finish up + self._pusher.finish(transition_state) + else: + transition_state() + elif state == defer.JOIN_FP: + if self._pusher: + self._pusher.join() + transition_state() + elif state == defer.FLUSH_FS: + if self._fs: + # TODO(jhr): now is a good time to output pending output lines + self._fs.finish(self._exit_code) + self._fs = None + transition_state() + elif 
state == defer.FLUSH_FINAL: + self._interface.publish_final() + self._interface.publish_footer() + transition_state() + elif state == defer.END: + done = True + else: + raise AssertionError("unknown state") + + if not done: + return + + exit_result = wandb_internal_pb2.RunExitResult() + + # mark exit done in case we are polling on exit + self._exit_result = exit_result + + # Report response to mailbox + if self._record_exit and self._record_exit.control.mailbox_slot: + result = proto_util._result_from_record(self._record_exit) + result.exit_result.CopyFrom(exit_result) + self._respond_result(result) + + def send_request_poll_exit(self, record: "Record") -> None: + if not record.control.req_resp and not record.control.mailbox_slot: + return + + result = proto_util._result_from_record(record) + + if self._pusher: + _alive, status = self._pusher.get_status() + file_counts = self._pusher.file_counts_by_category() + resp = result.response.poll_exit_response + resp.pusher_stats.uploaded_bytes = status.uploaded_bytes + resp.pusher_stats.total_bytes = status.total_bytes + resp.pusher_stats.deduped_bytes = status.deduped_bytes + resp.file_counts.wandb_count = file_counts.wandb + resp.file_counts.media_count = file_counts.media + resp.file_counts.artifact_count = file_counts.artifact + resp.file_counts.other_count = file_counts.other + + if self._exit_result: + result.response.poll_exit_response.done = True + result.response.poll_exit_response.exit_result.CopyFrom(self._exit_result) + + self._respond_result(result) + + def _setup_resume( + self, run: "RunRecord" + ) -> Optional["wandb_internal_pb2.ErrorInfo"]: + """Queries the backend for a run; fail if the settings are incompatible.""" + if not self._settings.resume: + return None + + # TODO: This causes a race, we need to make the upsert atomically + # only create or update depending on the resume config + # we use the runs entity if set, otherwise fallback to users entity + # todo: ensure entity is not None as 
self._entity is Optional[str] + entity = run.entity or self._entity + logger.info( + "checking resume status for %s/%s/%s", entity, run.project, run.run_id + ) + resume_status = self._api.run_resume_status( + entity=entity, # type: ignore + project_name=run.project, + name=run.run_id, + ) + # No resume status = run does not exist; No t key in wandbConfig = run exists but hasn't been inited + if not resume_status or '"t":' not in resume_status.get("wandbConfig", ""): + if self._settings.resume == "must": + error = wandb_internal_pb2.ErrorInfo() + error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.USAGE + error.message = ( + "You provided an invalid value for the `resume` argument." + f" The value 'must' is not a valid option for resuming a run ({run.run_id}) that has not been initialized." + " Please check your inputs and try again with a valid run ID." + " If you are trying to start a new run, please omit the `resume` argument or use `resume='allow'`." + ) + return error + return None + + # + # handle cases where we have resume_status + # + if self._settings.resume == "never": + error = wandb_internal_pb2.ErrorInfo() + error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.USAGE + error.message = ( + "You provided an invalid value for the `resume` argument." + f" The value 'never' is not a valid option for resuming a run ({run.run_id}) that already exists." + " Please check your inputs and try again with a valid value for the `resume` argument." 
+ ) + return error + + history = {} + events = {} + config = {} + summary = {} + try: + events_rt = 0 + history_rt = 0 + history = json.loads(resume_status["historyTail"]) + if history: + history = json.loads(history[-1]) + history_rt = history.get("_runtime", 0) + events = json.loads(resume_status["eventsTail"]) + if events: + events = json.loads(events[-1]) + events_rt = events.get("_runtime", 0) + config = json.loads(resume_status["config"] or "{}") + summary = json.loads(resume_status["summaryMetrics"] or "{}") + new_runtime = summary.get("_wandb", {}).get("runtime", None) + if new_runtime is not None: + self._resume_state.wandb_runtime = new_runtime + tags = resume_status.get("tags") or [] + + except (IndexError, ValueError): + logger.exception("unable to load resume tails") + if self._settings.resume == "must": + error = wandb_internal_pb2.ErrorInfo() + error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.USAGE + error.message = f"resume='must' but could not resume ({run.run_id}) " + return error + + # TODO: Do we need to restore config / summary? 
+ # System metrics runtime is usually greater than history + self._resume_state.runtime = max(events_rt, history_rt) + last_step = history.get("_step", 0) + history_line_count = resume_status["historyLineCount"] + self._resume_state.step = last_step + 1 if history_line_count > 0 else last_step + self._resume_state.history = history_line_count + self._resume_state.events = resume_status["eventsLineCount"] + self._resume_state.output = resume_status["logLineCount"] + self._resume_state.config = config + self._resume_state.summary = summary + self._resume_state.tags = tags + self._resume_state.resumed = True + logger.info(f"configured resuming with: {self._resume_state}") + return None + + def _telemetry_get_framework(self) -> str: + """Get telemetry data for internal config structure.""" + # detect framework by checking what is loaded + imports: telemetry.TelemetryImports + if self._telemetry_obj.HasField("imports_finish"): + imports = self._telemetry_obj.imports_finish + elif self._telemetry_obj.HasField("imports_init"): + imports = self._telemetry_obj.imports_init + else: + return "" + framework = next( + (n for f, n in _framework_priority() if getattr(imports, f, False)), "" + ) + return framework + + def _config_backend_dict(self) -> sender_config.BackendConfigDict: + config = self._consolidated_config or sender_config.ConfigState() + return config.to_backend_dict( + telemetry_record=self._telemetry_obj, + framework=self._telemetry_get_framework(), + start_time_millis=self._start_time, + metric_pbdicts=self._config_metric_pbdict_list, + environment_record=self._environment_obj, + ) + + def _config_save( + self, + config_value_dict: sender_config.BackendConfigDict, + ) -> None: + config_path = os.path.join(self._settings.files_dir, "config.yaml") + config_util.save_config_file_from_dict(config_path, config_value_dict) + + def _sync_spell(self) -> None: + """Sync this run with spell.""" + if not self._run: + return + try: + env = os.environ + 
self._interface.publish_config( + key=("_wandb", "spell_url"), val=env.get("SPELL_RUN_URL") + ) + url = f"{self._api.app_url}/{self._run.entity}/{self._run.project}/runs/{self._run.run_id}" + requests.put( + env.get("SPELL_API_URL", "https://api.spell.run") + "/wandb_url", + json={"access_token": env.get("WANDB_ACCESS_TOKEN"), "url": url}, + timeout=2, + ) + except requests.RequestException: + pass + # TODO: do something if sync spell is not successful? + + def _setup_fork(self, server_run: dict): + assert self._run + assert self._run.branch_point + first_step = int(self._run.branch_point.value) + 1 + self._resume_state.step = first_step + self._resume_state.history = server_run.get("historyLineCount", 0) + self._run.forked = True + self._run.starting_step = first_step + + def _load_rewind_state(self, run: "RunRecord"): + assert run.branch_point + self._rewind_response = self._api.rewind_run( + run_name=run.run_id, + entity=run.entity or None, + project=run.project or None, + metric_name=run.branch_point.metric, + metric_value=run.branch_point.value, + program_path=self._settings.program or None, + ) + self._resume_state.history = self._rewind_response.get("historyLineCount", 0) + self._resume_state.config = json.loads( + self._rewind_response.get("config", "{}") + ) + + def _install_rewind_state(self): + assert self._run + assert self._run.branch_point + assert self._rewind_response + + first_step = int(self._run.branch_point.value) + 1 + self._resume_state.step = first_step + + # We set the fork flag here because rewind uses the forking + # infrastructure under the hood. Setting `forked` here + # ensures that run._step is properly set in the user process. 
+ self._run.forked = True + self._run.starting_step = first_step + + def _handle_error( + self, + record: "Record", + error: "wandb_internal_pb2.ErrorInfo", + run: "RunRecord", + ) -> None: + if record.control.req_resp or record.control.mailbox_slot: + result = proto_util._result_from_record(record) + result.run_result.run.CopyFrom(run) + result.run_result.error.CopyFrom(error) + self._respond_result(result) + else: + logger.error("Got error in async mode: %s", error.message) + + def send_run(self, record: "Record", file_dir: Optional[str] = None) -> None: + run = record.run + error = None + is_wandb_init = self._run is None + + # save start time of a run + self._start_time = int(run.start_time.ToMicroseconds() // 1e6) + + # update telemetry + if run.telemetry: + self._telemetry_obj.MergeFrom(run.telemetry) + if self._settings.x_sync: + self._telemetry_obj.feature.sync = True + + # build config dict + config_value_dict: Optional[sender_config.BackendConfigDict] = None + if run.config: + self._consolidated_config.update_from_proto(run.config) + config_value_dict = self._config_backend_dict() + self._config_save(config_value_dict) + + do_rewind = run.branch_point.run == run.run_id + do_fork = not do_rewind and run.branch_point.run != "" + do_resume = bool(self._settings.resume) + + num_resume_options_set = sum([do_fork, do_rewind, do_resume]) + if num_resume_options_set > 1: + error = wandb_internal_pb2.ErrorInfo() + error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.USAGE + error.message = ( + "Multiple resume options specified. " + "Please specify only one of `fork_from`, `resume`, or `resume_from`." 
+ ) + self._handle_error(record, error, run) + + if is_wandb_init: + # Ensure we have a project to query for status + if run.project == "": + run.project = util.auto_project_name(self._settings.program) + # Only check resume status on `wandb.init` + + if do_resume: + error = self._setup_resume(run) + + elif do_rewind: + error = self._load_rewind_state(run) + + if error is not None: + self._handle_error(record, error, run) + return + + # Save the resumed config + if self._resume_state.config is not None: + self._consolidated_config.merge_resumed_config( + config_util.dict_strip_value_dict(self._resume_state.config) + ) + + config_value_dict = self._config_backend_dict() + self._config_save(config_value_dict) + + # handle empty config + # TODO(jhr): consolidate the 4 ways config is built: + # (passed config, empty config, resume config, send_config) + if not config_value_dict: + config_value_dict = self._config_backend_dict() + self._config_save(config_value_dict) + + try: + server_run = self._init_run(run, config_value_dict) + except (CommError, UsageError) as e: + logger.error(e, exc_info=True) + error = ProtobufErrorHandler.from_exception(e) + self._handle_error(record, error, run) + return + + assert self._run # self._run is configured in _init_run() + + if do_fork: + error = self._setup_fork(server_run) + + if error is not None: + self._handle_error(record, error, run) + return + + if record.control.req_resp or record.control.mailbox_slot: + result = proto_util._result_from_record(record) + # TODO: we could do self._interface.publish_defer(resp) to notify + # the handler not to actually perform server updates for this uuid + # because the user process will send a summary update when we resume + result.run_result.run.CopyFrom(self._run) + self._respond_result(result) + + # Only spin up our threads on the first run message + if is_wandb_init: + self._start_run_threads(file_dir) + else: + logger.info("updated run: %s", self._run.run_id) + + def 
_update_resume_state(self, is_rewinding: bool, inserted: bool): + assert self._run + if self._resume_state.resumed: + self._run.resumed = True + if self._resume_state.wandb_runtime is not None: + self._run.runtime = self._resume_state.wandb_runtime + elif is_rewinding: + # because is_rewinding is mutually exclusive with self._resume_state.resumed, + # this block will always execute if is_rewinding is set + self._install_rewind_state() + else: + # If the user is not resuming, and we didn't insert on upsert_run then + # it is likely that we are overwriting the run which we might want to + # prevent in the future. This could be a false signal since an upsert_run + # message which gets retried in the network could also show up as not + # inserted. + if not inserted: + # no need to flush this, it will get updated eventually + self._telemetry_obj.feature.maybe_run_overwrite = True + + def _init_run( + self, + run: "RunRecord", + config_dict: Optional[sender_config.BackendConfigDict], + ) -> dict: + # We subtract the previous runs runtime when resuming + start_time = ( + run.start_time.ToMicroseconds() / 1e6 + ) - self._resume_state.runtime + # TODO: we don't check inserted currently, ultimately we should make + # the upsert know the resume state and fail transactionally + + if self._resume_state and self._resume_state.tags and not run.tags: + run.tags.extend(self._resume_state.tags) + + is_rewinding = bool(self._settings.resume_from) + if is_rewinding: + assert self._rewind_response + server_run = self._rewind_response + server_messages = None + inserted = True + else: + server_run, inserted, server_messages = self._api.upsert_run( + name=run.run_id, + entity=run.entity or None, + project=run.project or None, + group=run.run_group or None, + job_type=run.job_type or None, + display_name=run.display_name or None, + notes=run.notes or None, + tags=run.tags[:] or None, + config=config_dict or None, + sweep_name=run.sweep_id or None, + host=run.host or None, + 
program_path=self._settings.program or None, + repo=run.git.remote_url or None, + commit=run.git.commit or None, + ) + + # TODO: we don't want to create jobs in sweeps, since the + # executable doesn't appear to be consistent + if run.sweep_id: + self._job_builder.disable = True + + self._server_messages = server_messages or [] + self._run = run + + if self._resume_state.resumed and is_rewinding: + # this should not ever be possible to hit, since we check for + # resumption above and raise an error if resumption is specified + # twice. + raise ValueError( + "Cannot attempt to rewind and resume a run - only one of " + "`resume` or `resume_from` can be specified." + ) + + self._update_resume_state(is_rewinding, inserted) + self._run.starting_step = self._resume_state.step + self._run.start_time.FromMicroseconds(int(start_time * 1e6)) + self._run.config.CopyFrom(self._interface._make_config(config_dict)) + if self._resume_state.summary is not None: + self._run.summary.CopyFrom( + self._interface._make_summary_from_dict(self._resume_state.summary) + ) + storage_id = server_run.get("id") + if storage_id: + self._run.storage_id = storage_id + id = server_run.get("name") + if id: + self._api.set_current_run_id(id) + display_name = server_run.get("displayName") + if display_name: + self._run.display_name = display_name + project = server_run.get("project") + # TODO: remove self._api.set_settings, and make self._project a property? 
+ if project: + project_name = project.get("name") + if project_name: + self._run.project = project_name + self._project = project_name + self._api_settings["project"] = project_name + self._api.set_setting("project", project_name) + entity = project.get("entity") + if entity: + entity_name = entity.get("name") + if entity_name: + self._run.entity = entity_name + self._entity = entity_name + self._api_settings["entity"] = entity_name + self._api.set_setting("entity", entity_name) + sweep_id = server_run.get("sweepName") + if sweep_id: + self._run.sweep_id = sweep_id + if os.getenv("SPELL_RUN_URL"): + self._sync_spell() + return server_run + + def _start_run_threads(self, file_dir: Optional[str] = None) -> None: + assert self._run # self._run is configured by caller + self._fs = file_stream.FileStreamApi( + self._api, + self._run.run_id, + self._run.start_time.ToMicroseconds() / 1e6, + timeout=self._settings.x_file_stream_timeout_seconds or 0, + settings=self._api_settings, + ) + # Ensure the streaming polices have the proper offsets + self._fs.set_file_policy("wandb-summary.json", file_stream.SummaryFilePolicy()) + self._fs.set_file_policy( + "wandb-history.jsonl", + file_stream.JsonlFilePolicy(start_chunk_id=self._resume_state.history), + ) + self._fs.set_file_policy( + "wandb-events.jsonl", + file_stream.JsonlFilePolicy(start_chunk_id=self._resume_state.events), + ) + self._fs.set_file_policy( + "output.log", + file_stream.CRDedupeFilePolicy(start_chunk_id=self._resume_state.output), + ) + + # hack to merge run_settings and self._settings object together + # so that fields like entity or project are available to be attached to Sentry events. 
+ run_settings = message_to_dict(self._run) + _settings = dict(self._settings) + _settings.update(run_settings) + get_sentry().configure_scope(tags=_settings, process_context="internal") + + self._fs.start() + self._pusher = FilePusher(self._api, self._fs, settings=self._settings) + self._dir_watcher = DirWatcher(self._settings, self._pusher, file_dir) + logger.info( + "run started: %s with start time %s", + self._run.run_id, + self._run.start_time.ToMicroseconds() / 1e6, + ) + + def _save_history(self, history_dict: Dict[str, Any]) -> None: + if self._fs: + self._fs.push(filenames.HISTORY_FNAME, json.dumps(history_dict)) + + def send_history(self, record: "Record") -> None: + history = record.history + history_dict = proto_util.dict_from_proto_list(history.item) + self._save_history(history_dict) + + def _update_summary_record(self, summary: "SummaryRecord") -> None: + summary_dict = proto_util.dict_from_proto_list(summary.update) + self._cached_summary = summary_dict + self._update_summary() + + def send_summary(self, record: "Record") -> None: + self._update_summary_record(record.summary) + + def send_request_summary_record(self, record: "Record") -> None: + self._update_summary_record(record.request.summary_record.summary) + + def _update_summary(self) -> None: + summary_dict = self._cached_summary.copy() + summary_dict.pop("_wandb", None) + if self._metadata_summary: + summary_dict["_wandb"] = self._metadata_summary + # merge with consolidated summary + self._consolidated_summary.update(summary_dict) + json_summary = json.dumps(self._consolidated_summary) + if self._fs: + self._fs.push(filenames.SUMMARY_FNAME, json_summary) + # TODO(jhr): we should only write this at the end of the script + summary_path = os.path.join(self._settings.files_dir, filenames.SUMMARY_FNAME) + with open(summary_path, "w") as f: + f.write(json_summary) + self._save_file(filesystem.GlobStr(filenames.SUMMARY_FNAME)) + + def send_stats(self, record: "Record") -> None: + stats = 
record.stats + if stats.stats_type != wandb_internal_pb2.StatsRecord.StatsType.SYSTEM: + return + if not self._fs: + return + if not self._run: + return + now_us = stats.timestamp.ToMicroseconds() + start_us = self._run.start_time.ToMicroseconds() + d = dict() + for item in stats.item: + try: + d[item.key] = json.loads(item.value_json) + except json.JSONDecodeError: + logger.exception("error decoding stats json: %s", item.value_json) + row: Dict[str, Any] = dict(system=d) + self._flatten(row) + row["_wandb"] = True + row["_timestamp"] = now_us / 1e6 + row["_runtime"] = (now_us - start_us) / 1e6 + self._fs.push(filenames.EVENTS_FNAME, json.dumps(row)) + # TODO(jhr): check fs.push results? + + def _output_raw_finish(self) -> None: + for stream, output_raw in self._output_raw_streams.items(): + output_raw._stopped.set() + + # shut down threads + output_raw._writer_thr.join(timeout=5) + if output_raw._writer_thr.is_alive(): + logger.info("processing output...") + output_raw._writer_thr.join() + output_raw._reader_thr.join() + + # flush output buffers and files + self._output_raw_flush(stream) + self._output_raw_streams = {} + if self._output_raw_file: + self._output_raw_file.close() + self._output_raw_file = None + + def _output_raw_writer_thread(self, stream: "StreamLiterals") -> None: + while True: + output_raw = self._output_raw_streams[stream] + if output_raw._queue.empty(): + if output_raw._stopped.is_set(): + return + time.sleep(0.5) + continue + data = [] + while not output_raw._queue.empty(): + data.append(output_raw._queue.get()) + if output_raw._stopped.is_set() and sum(map(len, data)) > 100000: + logger.warning("Terminal output too large. 
Logging without processing.") + self._output_raw_flush(stream) + for line in data: + self._output_raw_flush(stream, line) + # TODO: lets mark that this happened in telemetry + return + try: + output_raw._emulator.write("".join(data)) + except Exception as e: + logger.warning(f"problem writing to output_raw emulator: {e}") + + def _output_raw_reader_thread(self, stream: "StreamLiterals") -> None: + output_raw = self._output_raw_streams[stream] + while not (output_raw._stopped.is_set() and output_raw._queue.empty()): + self._output_raw_flush(stream) + time.sleep(_OUTPUT_MIN_CALLBACK_INTERVAL) + + def _output_raw_flush( + self, stream: "StreamLiterals", data: Optional[str] = None + ) -> None: + if data is None: + output_raw = self._output_raw_streams[stream] + try: + data = output_raw._emulator.read() + except Exception as e: + logger.warning(f"problem reading from output_raw emulator: {e}") + if data: + self._send_output_line(stream, data) + if self._output_raw_file: + self._output_raw_file.write(data.encode("utf-8")) + + def send_request_python_packages(self, record: "Record") -> None: + import os + + from wandb.sdk.lib.filenames import REQUIREMENTS_FNAME + + installed_packages_list = sorted( + f"{r.name}=={r.version}" for r in record.request.python_packages.package + ) + with open(os.path.join(self._settings.files_dir, REQUIREMENTS_FNAME), "w") as f: + f.write("\n".join(installed_packages_list)) + + def send_output(self, record: "Record") -> None: + if not self._fs: + return + out = record.output + stream: StreamLiterals = "stdout" + if out.output_type == wandb_internal_pb2.OutputRecord.OutputType.STDERR: + stream = "stderr" + line = out.line + self._send_output_line(stream, line) + + def send_output_raw(self, record: "Record") -> None: + if not self._fs: + return + out = record.output_raw + stream: StreamLiterals = "stdout" + if out.output_type == wandb_internal_pb2.OutputRawRecord.OutputType.STDERR: + stream = "stderr" + line = out.line + + output_raw = 
self._output_raw_streams.get(stream) + if not output_raw: + output_raw = _OutputRawStream(stream=stream, sm=self) + self._output_raw_streams[stream] = output_raw + + # open the console output file shared between both streams + if not self._output_raw_file: + output_log_path = os.path.join( + self._settings.files_dir, filenames.OUTPUT_FNAME + ) + output_raw_file = None + try: + output_raw_file = filesystem.CRDedupedFile( + open(output_log_path, "wb") + ) + except OSError as e: + logger.warning(f"could not open output_raw_file: {e}") + if output_raw_file: + self._output_raw_file = output_raw_file + output_raw.start() + + output_raw._queue.put(line) + + def _send_output_line(self, stream: "StreamLiterals", line: str) -> None: + """Combined writer for raw and non raw output lines. + + This is combined because they are both post emulator. + """ + prepend = "" + if stream == "stderr": + prepend = "ERROR " + if not line.endswith("\n"): + self._partial_output.setdefault(stream, "") + if line.startswith("\r"): + # TODO: maybe we shouldn't just drop this, what if there was some \ns in the partial + # that should probably be the check instead of not line.endswith(\n") + # logger.info(f"Dropping data {self._partial_output[stream]}") + self._partial_output[stream] = "" + self._partial_output[stream] += line + # TODO(jhr): how do we make sure this gets flushed? + # we might need this for other stuff like telemetry + else: + # TODO(jhr): use time from timestamp proto + # TODO(jhr): do we need to make sure we write full lines? 
+ # seems to be some issues with line breaks + cur_time = time.time() + timestamp = datetime.utcfromtimestamp(cur_time).isoformat() + " " + prev_str = self._partial_output.get(stream, "") + line = f"{prepend}{timestamp}{prev_str}{line}" + if self._fs: + self._fs.push(filenames.OUTPUT_FNAME, line) + self._partial_output[stream] = "" + + def _update_config(self) -> None: + self._config_needs_debounce = True + + def send_config(self, record: "Record") -> None: + self._consolidated_config.update_from_proto(record.config) + self._update_config() + + def send_metric(self, record: "Record") -> None: + metric = record.metric + if metric.glob_name: + logger.warning("Seen metric with glob (shouldn't happen)") + return + + # merge or overwrite + old_metric = self._config_metric_dict.get( + metric.name, wandb_internal_pb2.MetricRecord() + ) + if metric._control.overwrite: + old_metric.CopyFrom(metric) + else: + old_metric.MergeFrom(metric) + self._config_metric_dict[metric.name] = old_metric + metric = old_metric + + # convert step_metric to index + if metric.step_metric: + find_step_idx = self._config_metric_index_dict.get(metric.step_metric) + if find_step_idx is not None: + # make a copy of this metric as we will be modifying it + rec = wandb_internal_pb2.Record() + rec.metric.CopyFrom(metric) + metric = rec.metric + + metric.ClearField("step_metric") + metric.step_metric_index = find_step_idx + 1 + + md: Dict[int, Any] = proto_util.proto_encode_to_dict(metric) + find_idx = self._config_metric_index_dict.get(metric.name) + if find_idx is not None: + self._config_metric_pbdict_list[find_idx] = md + else: + next_idx = len(self._config_metric_pbdict_list) + self._config_metric_pbdict_list.append(md) + self._config_metric_index_dict[metric.name] = next_idx + self._debounce_config() + + def _update_telemetry_record(self, telemetry: telemetry.TelemetryRecord) -> None: + self._telemetry_obj.MergeFrom(telemetry) + self._debounce_config() + + def send_telemetry(self, record: 
"Record") -> None: + self._update_telemetry_record(record.telemetry) + + def send_request_telemetry_record(self, record: "Record") -> None: + self._update_telemetry_record(record.request.telemetry_record.telemetry) + + def _save_file( + self, fname: filesystem.GlobStr, policy: "filesystem.PolicyName" = "end" + ) -> None: + logger.info("saving file %s with policy %s", fname, policy) + if self._dir_watcher: + self._dir_watcher.update_policy(fname, policy) + + def send_files(self, record: "Record") -> None: + files = record.files + for k in files.files: + # TODO(jhr): fix paths with directories + self._save_file( + filesystem.GlobStr(glob.escape(k.path)), + interface.file_enum_to_policy(k.policy), + ) + + def send_header(self, record: "Record") -> None: + pass + + def send_footer(self, record: "Record") -> None: + pass + + def send_tbrecord(self, record: "Record") -> None: + # tbrecord watching threads are handled by handler.py + pass + + def _update_environment_record(self, environment: "EnvironmentRecord") -> None: + self._environment_obj.MergeFrom(environment) + self._debounce_config() + + def send_environment(self, record: "Record") -> None: + """Inject environment info into config and upload as a JSON file.""" + self._update_environment_record(record.environment) + + environment_json = json.dumps(proto_util.message_to_dict(self._environment_obj)) + + with open( + os.path.join(self._settings.files_dir, filenames.METADATA_FNAME), "w" + ) as f: + f.write(environment_json) + + self._save_file(filesystem.GlobStr(filenames.METADATA_FNAME), policy="now") + + def send_request_link_artifact(self, record: "Record") -> None: + if not (record.control.req_resp or record.control.mailbox_slot): + raise ValueError( + f"Expected either `req_resp` or `mailbox_slot`, got: {record.control!r}" + ) + result = proto_util._result_from_record(record) + link = record.request.link_artifact + client_id = link.client_id + server_id = link.server_id + portfolio_name = link.portfolio_name + 
def send_use_artifact(self, record: "Record") -> None:
    """Apply the job-builder side effects of a used artifact.

    Nothing is sent anywhere; the record is only inspected to update
    ``self._job_builder``:

    * if the record carries a partial job name, that job's source id is set
      from the artifact id so the job builder can rebuild the job;
    * otherwise, if the artifact type is ``"job"``, the job builder is
      disabled.
    """
    used = record.use_artifact
    partial_job_name = used.partial.job_name

    if partial_job_name:
        # job is partial, let job builder rebuild job, set job source dict
        self._job_builder.set_partial_source_id(used.id)
        return

    if used.type == "job":
        self._job_builder.disable = True
entity=artifact.entity, + project=artifact.project, + type=artifact.type, + name=artifact.name, + client_id=artifact.client_id, + sequence_client_id=artifact.sequence_client_id, + metadata=metadata, + ttl_duration_seconds=artifact.ttl_duration_seconds or None, + description=artifact.description or None, + aliases=artifact.aliases, + tags=artifact.tags, + use_after_commit=artifact.use_after_commit, + distributed_id=artifact.distributed_id, + finalize=artifact.finalize, + incremental=artifact.incremental_beta1, + history_step=history_step, + base_id=artifact.base_id or None, + ) + + self._job_builder._handle_server_artifact(res, artifact) + + if artifact.manifest.manifest_file_path: + with contextlib.suppress(FileNotFoundError): + os.remove(artifact.manifest.manifest_file_path) + return res + + def send_alert(self, record: "Record") -> None: + from packaging.version import parse + + alert = record.alert + max_cli_version = self._max_cli_version() + if max_cli_version is None or parse(max_cli_version) < parse("0.10.9"): + logger.warning( + "This W&B server doesn't support alerts, " + "have your administrator install wandb/local >= 0.9.31" + ) + else: + try: + self._api.notify_scriptable_run_alert( + title=alert.title, + text=alert.text, + level=alert.level, + wait_duration=alert.wait_duration, + ) + except Exception: + logger.exception(f"send_alert: failed for alert {alert.title!r}") + + def finish(self) -> None: + logger.info("shutting down sender") + # if self._tb_watcher: + # self._tb_watcher.finish() + self._output_raw_finish() + if self._dir_watcher: + self._dir_watcher.finish() + self._dir_watcher = None + if self._pusher: + self._pusher.finish() + self._pusher.join() + self._pusher = None + if self._fs: + self._fs.finish(self._exit_code) + self._fs = None + get_sentry().end_session() + + def _max_cli_version(self) -> Optional[str]: + server_info = self.get_server_info() + max_cli_version = server_info.get("cliVersionInfo", {}).get( + "max_cli_version", None + ) 
+ if not isinstance(max_cli_version, str): + return None + return max_cli_version + + def get_viewer_server_info(self) -> None: + if self._cached_server_info and self._cached_viewer: + return + self._cached_viewer, self._cached_server_info = self._api.viewer_server_info() + + def get_viewer_info(self) -> Dict[str, Any]: + if not self._cached_viewer: + self.get_viewer_server_info() + return self._cached_viewer + + def get_server_info(self) -> Dict[str, Any]: + if not self._cached_server_info: + self.get_viewer_server_info() + return self._cached_server_info + + def get_local_info(self) -> "LocalInfo": + """Queries the server to get the local version information. + + First, we perform an introspection, if it returns empty we deduce that the + docker image is out-of-date. Otherwise, we use the returned values to deduce the + state of the local server. + """ + local_info = wandb_internal_pb2.LocalInfo() + if self._settings._offline: + local_info.out_of_date = False + return local_info + + latest_local_version = "latest" + + # Assuming the query is successful if the result is empty it indicates that + # the backend is out of date since it doesn't have the desired field + server_info = self.get_server_info() + latest_local_version_info = server_info.get("latestLocalVersionInfo", {}) + if latest_local_version_info is None: + local_info.out_of_date = False + else: + local_info.out_of_date = latest_local_version_info.get("outOfDate", True) + local_info.version = latest_local_version_info.get( + "latestVersionString", latest_local_version + ) + return local_info + + def _flush_job(self) -> None: + if self._job_builder.disable or self._settings._offline: + return + self._job_builder.set_config(self._consolidated_config.non_internal_config()) + summary_dict = self._cached_summary.copy() + summary_dict.pop("_wandb", None) + self._job_builder.set_summary(summary_dict) + + artifact = self._job_builder.build(api=self._api) + if artifact is not None and self._run is not None: + 
class ConfigState:
    """The configuration of a run.

    The config is stored as a tree of nested dictionaries whose top-level
    ``_wandb`` key is treated as internal (see ``non_internal_config``).
    """

    def __init__(self, tree: Optional[Dict[str, Any]] = None) -> None:
        # A tree with string-valued nodes and JSON leaves.
        #
        # Leaves are Python objects that are valid JSON values:
        #
        # * Primitives like strings and numbers
        # * Dictionaries from strings to JSON objects
        # * Lists of JSON objects
        self._tree: Dict[str, Any] = tree or {}

    def non_internal_config(self) -> Dict[str, Any]:
        """Returns the config settings minus "_wandb"."""
        return {k: v for k, v in self._tree.items() if k != _WANDB_INTERNAL_KEY}

    def update_from_proto(
        self,
        config_record: wandb_internal_pb2.ConfigRecord,
    ) -> None:
        """Applies update and remove commands.

        Updates are applied first, then removals, each in record order.
        """
        for config_item in config_record.update:
            self._update_at_path(
                _key_path(config_item),
                json.loads(config_item.value_json),
            )

        for config_item in config_record.remove:
            self._delete_at_path(_key_path(config_item))

    def merge_resumed_config(self, old_config_tree: Dict[str, Any]) -> None:
        """Merges the config from a run that's being resumed.

        Values already present in the current config win; the old config only
        fills in keys that are not set.
        """
        # Add any top-level keys that aren't already set.
        self._add_unset_keys_from_subtree(old_config_tree, [])

        # When resuming a run, we want to ensure that some of the old config
        # keys are maintained. So we have this logic here to add back
        # any keys that were in the old config but not in the new config.
        for key in ["viz", "visualize", "mask/class_labels"]:
            self._add_unset_keys_from_subtree(
                old_config_tree,
                [_WANDB_INTERNAL_KEY, key],
            )

    def _add_unset_keys_from_subtree(
        self,
        old_config_tree: Dict[str, Any],
        path: Sequence[str],
    ) -> None:
        """Uses the given subtree for keys that aren't already set."""
        old_subtree = _subtree(old_config_tree, path, create=False)
        if not old_subtree:
            return

        new_subtree = _subtree(self._tree, path, create=True)
        assert new_subtree is not None

        for key, value in old_subtree.items():
            if key not in new_subtree:
                new_subtree[key] = value

    def to_backend_dict(
        self,
        telemetry_record: telemetry.TelemetryRecord,
        framework: Optional[str],
        start_time_millis: int,
        metric_pbdicts: Sequence[Dict[int, Any]],
        environment_record: wandb_internal_pb2.EnvironmentRecord,
    ) -> BackendConfigDict:
        """Returns a dictionary representation expected by the backend.

        The backend expects the configuration in a specific format, and the
        config is also used to store additional metadata about the run.

        The metadata is written under the ``_wandb`` key of the returned
        dictionary only; the config tree held by this object is not modified.

        Args:
            telemetry_record: Telemetry information to insert.
            framework: The detected framework used in the run (e.g. TensorFlow).
            start_time_millis: The run's start time in Unix milliseconds.
            metric_pbdicts: List of dict representations of metric protobuffers.
            environment_record: Environment information to insert; it is only
                included when its ``writer_id`` is set.

        Returns:
            A mapping from each top-level config key (including ``_wandb``) to
            a ``{"desc": None, "value": ...}`` entry.
        """
        backend_dict = self._tree.copy()
        # ``copy()`` is shallow, so take a private copy of the internal
        # subtree: the metadata below belongs in the payload, not in the
        # run's stored config.
        wandb_internal = dict(backend_dict.get(_WANDB_INTERNAL_KEY) or {})
        backend_dict[_WANDB_INTERNAL_KEY] = wandb_internal

        ###################################################
        # Telemetry information
        ###################################################
        py_version = telemetry_record.python_version
        if py_version:
            wandb_internal["python_version"] = py_version

        cli_version = telemetry_record.cli_version
        if cli_version:
            wandb_internal["cli_version"] = cli_version

        if framework:
            wandb_internal["framework"] = framework

        huggingface_version = telemetry_record.huggingface_version
        if huggingface_version:
            wandb_internal["huggingface_version"] = huggingface_version

        wandb_internal["is_jupyter_run"] = telemetry_record.env.jupyter
        wandb_internal["is_kaggle_kernel"] = telemetry_record.env.kaggle
        wandb_internal["start_time"] = start_time_millis

        # The full telemetry record.
        wandb_internal["t"] = proto_util.proto_encode_to_dict(telemetry_record)

        ###################################################
        # Metrics
        ###################################################
        if metric_pbdicts:
            wandb_internal["m"] = metric_pbdicts

        ###################################################
        # Environment
        ###################################################
        writer_id = environment_record.writer_id
        if writer_id:
            environment_dict = proto_util.message_to_dict(environment_record)
            wandb_internal["e"] = {writer_id: environment_dict}

        # Build the result from ``backend_dict`` (not ``self._tree``) so the
        # ``_wandb`` metadata is present even when the run's config did not
        # already contain that key.
        return BackendConfigDict(
            {
                key: {
                    # Configurations can be stored in a hand-written YAML file,
                    # and users can add descriptions to their hyperparameters
                    # there. However, we don't support a way to set descriptions
                    # via code, so this is always None.
                    "desc": None,
                    "value": value,
                }
                for key, value in backend_dict.items()
            }
        )

    def _update_at_path(
        self,
        key_path: Sequence[str],
        value: Any,
    ) -> None:
        """Sets the value at the path in the config tree.

        Intermediate dictionaries along ``key_path`` are created as needed.
        """
        subtree = _subtree(self._tree, key_path[:-1], create=True)
        assert subtree is not None

        subtree[key_path[-1]] = value

    def _delete_at_path(
        self,
        key_path: Sequence[str],
    ) -> None:
        """Removes the subtree at the path in the config tree.

        Removing a path that does not exist is a no-op, whether it is the
        parent or the final key that is missing.
        """
        subtree = _subtree(self._tree, key_path[:-1], create=False)
        if subtree:
            subtree.pop(key_path[-1], None)
+ """ + + def __init__(self, data: dict[str, Any]) -> None: + super().__init__(**data) + + def __setattr__(self, name: str, value: object) -> None: + raise AttributeError("Error: SettingsStatic is a readonly object") + + def __setitem__(self, key: str, val: object) -> None: + raise AttributeError("Error: SettingsStatic is a readonly object") + + def keys(self) -> Iterable[str]: + return self.__dict__.keys() + + def __getitem__(self, key: str) -> Any: + return self.__dict__[key] + + def __getattr__(self, name: str) -> Any: + try: + return self.__dict__[name] + except KeyError: + raise AttributeError(f"SettingsStatic has no attribute {name}") + + def __str__(self) -> str: + return str(self.__dict__) + + def __contains__(self, key: str) -> bool: + return key in self.__dict__ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/internal/tb_watcher.py b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/tb_watcher.py new file mode 100644 index 0000000000000000000000000000000000000000..4fe7cfca8177cad1352b0e2a7b25a827d533fa51 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/internal/tb_watcher.py @@ -0,0 +1,520 @@ +"""tensorboard watcher.""" + +import glob +import logging +import os +import queue +import socket +import sys +import threading +import time +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +import wandb +from wandb import util +from wandb.plot import CustomChart +from wandb.sdk.lib import filesystem + +from . 
def is_tfevents_file_created_by(
    path: str, hostname: Optional[str], start_time: Optional[float]
) -> bool:
    """Check if a path is a tfevents file.

    Optionally checks that it was created by [hostname] after [start_time].

    The basename is split on dots; the component right after ``tfevents`` is
    read as the creation time and the components after that are compared with
    the dot-separated parts of ``hostname``.

    tensorboard tfevents filename format:
    https://github.com/tensorflow/tensorboard/blob/f3f26b46981da5bd46a5bb93fcf02d9eb7608bc1/tensorboard/summary/writer/event_file_writer.py#L81
    tensorflow tfevents filename format:
    https://github.com/tensorflow/tensorflow/blob/8f597046dc30c14b5413813d02c0e0aed399c177/tensorflow/core/util/events_writer.cc#L68

    Args:
        path: Path of the candidate file.
        hostname: Host the file must have been written by, or ``None`` to skip
            the check.
        start_time: Earliest acceptable creation time (truncated to whole
            seconds), or ``None`` to skip the check.

    Raises:
        ValueError: If ``path`` is empty.
    """
    if not path:
        raise ValueError("Path must be a nonempty string")

    name = os.path.basename(path)
    if name.endswith((".profile-empty", ".sagemaker-uploaded")):
        return False

    pieces = name.split(".")
    if "tfevents" not in pieces:
        return False
    marker = pieces.index("tfevents")

    # check the hostname, which may have dots
    if hostname is not None:
        host_parts = hostname.split(".")
        host_start = marker + 2
        # A short slice (filename ran out of components) compares unequal.
        if pieces[host_start : host_start + len(host_parts)] != host_parts:
            return False

    if start_time is None:
        return True

    stamp_pos = marker + 1
    if stamp_pos >= len(pieces):
        return False
    try:
        created_time = int(pieces[stamp_pos])
    except ValueError:
        return False

    # Ensure that the file is newer than our start time, and that it was
    # created from the same hostname.
    # TODO: we should also check the PID (also contained in the tfevents
    # filename). Can we assume that our parent pid is the user process
    # that wrote these files?
    return created_time >= int(start_time)
+ self._watcher_queue = queue.PriorityQueue() + wandb.tensorboard.reset_state() # type: ignore + + def _calculate_namespace(self, logdir: str, rootdir: str) -> Optional[str]: + namespace: Optional[str] + dirs = list(self._logdirs) + [logdir] + + if os.path.isfile(logdir): + filename = os.path.basename(logdir) + else: + filename = "" + + if rootdir == "": + rootdir = util.to_forward_slash_path( + os.path.dirname(os.path.commonprefix(dirs)) + ) + # Tensorboard loads all tfevents files in a directory and prepends + # their values with the path. Passing namespace to log allows us + # to nest the values in wandb + # Note that we strip '/' instead of os.sep, because elsewhere we've + # converted paths to forward slash. + namespace = logdir.replace(filename, "").replace(rootdir, "").strip("/") + + # TODO: revisit this heuristic, it exists because we don't know the + # root log directory until more than one tfevents file is written to + if len(dirs) == 1 and namespace not in ["train", "validation"]: + namespace = None + else: + namespace = logdir.replace(filename, "").replace(rootdir, "").strip("/") + + return namespace + + def add(self, logdir: str, save: bool, root_dir: str) -> None: + logdir = util.to_forward_slash_path(logdir) + root_dir = util.to_forward_slash_path(root_dir) + if logdir in self._logdirs: + return + namespace = self._calculate_namespace(logdir, root_dir) + # TODO(jhr): implement the deferred tbdirwatcher to find namespace + + if not self._consumer: + self._consumer = TBEventConsumer( + self, self._watcher_queue, self._run_proto, self._settings + ) + self._consumer.start() + + tbdir_watcher = TBDirWatcher( + self, logdir, save, namespace, self._watcher_queue, self._force + ) + self._logdirs[logdir] = tbdir_watcher + tbdir_watcher.start() + + def finish(self) -> None: + for tbdirwatcher in self._logdirs.values(): + tbdirwatcher.shutdown() + for tbdirwatcher in self._logdirs.values(): + tbdirwatcher.finish() + if self._consumer: + self._consumer.finish() 
+ + +class TBDirWatcher: + def __init__( + self, + tbwatcher: "TBWatcher", + logdir: str, + save: bool, + namespace: Optional[str], + queue: "PriorityQueue", + force: bool = False, + ) -> None: + self.directory_watcher = util.get_module( + "tensorboard.backend.event_processing.directory_watcher", + required="Please install tensorboard package", + ) + # self.event_file_loader = util.get_module( + # "tensorboard.backend.event_processing.event_file_loader", + # required="Please install tensorboard package", + # ) + self.tf_compat = util.get_module( + "tensorboard.compat", required="Please install tensorboard package" + ) + self._tbwatcher = tbwatcher + self._generator = self.directory_watcher.DirectoryWatcher( + logdir, self._loader(save, namespace), self._is_our_tfevents_file + ) + self._thread = threading.Thread(target=self._thread_except_body) + self._first_event_timestamp = None + self._shutdown = threading.Event() + self._queue = queue + self._file_version = None + self._namespace = namespace + self._logdir = logdir + self._hostname = socket.gethostname() + self._force = force + self._process_events_lock = threading.Lock() + + def start(self) -> None: + self._thread.start() + + def _is_our_tfevents_file(self, path: str) -> bool: + """Check if a path has been modified since launch and contains tfevents.""" + if not path: + raise ValueError("Path must be a nonempty string") + path = self.tf_compat.tf.compat.as_str_any(path) + if self._force: + return is_tfevents_file_created_by(path, None, None) + else: + return is_tfevents_file_created_by( + path, self._hostname, self._tbwatcher._settings.x_start_time + ) + + def _loader( + self, save: bool = True, namespace: Optional[str] = None + ) -> "EventFileLoader": + """Incredibly hacky class generator to optionally save / prefix tfevent files.""" + _loader_interface = self._tbwatcher._interface + _loader_settings = self._tbwatcher._settings + try: + from tensorboard.backend.event_processing import event_file_loader + 
except ImportError: + raise Exception("Please install tensorboard package") + + class EventFileLoader(event_file_loader.EventFileLoader): + def __init__(self, file_path: str) -> None: + super().__init__(file_path) + if save: + if REMOTE_FILE_TOKEN in file_path: + logger.warning( + "Not persisting remote tfevent file: %s", file_path + ) + else: + # TODO: save plugins? + logdir = os.path.dirname(file_path) + parts = list(os.path.split(logdir)) + if namespace and parts[-1] == namespace: + parts.pop() + logdir = os.path.join(*parts) + _link_and_save_file( + path=file_path, + base_path=logdir, + interface=_loader_interface, + settings=_loader_settings, + ) + + return EventFileLoader + + def _process_events(self, shutdown_call: bool = False) -> None: + try: + with self._process_events_lock: + for event in self._generator.Load(): + self.process_event(event) + except ( + self.directory_watcher.DirectoryDeletedError, + StopIteration, + RuntimeError, + OSError, + ) as e: + # When listing s3 the directory may not yet exist, or could be empty + logger.debug("Encountered tensorboard directory watcher error: %s", e) + if not self._shutdown.is_set() and not shutdown_call: + time.sleep(ERROR_DELAY) + + def _thread_except_body(self) -> None: + try: + self._thread_body() + except Exception: + logger.exception("generic exception in TBDirWatcher thread") + raise + + def _thread_body(self) -> None: + """Check for new events every second.""" + shutdown_time: Optional[float] = None + while True: + self._process_events() + if self._shutdown.is_set(): + now = time.time() + if not shutdown_time: + shutdown_time = now + SHUTDOWN_DELAY + elif now > shutdown_time: + break + time.sleep(1) + + def process_event(self, event: "ProtoEvent") -> None: + # print("\nEVENT:::", self._logdir, self._namespace, event, "\n") + if self._first_event_timestamp is None: + self._first_event_timestamp = event.wall_time + + if event.HasField("file_version"): + self._file_version = event.file_version + + if 
event.HasField("summary"): + self._queue.put(Event(event, self._namespace)) + + def shutdown(self) -> None: + self._process_events(shutdown_call=True) + self._shutdown.set() + + def finish(self) -> None: + self.shutdown() + self._thread.join() + + +class Event: + """An event wrapper to enable priority queueing.""" + + def __init__(self, event: "ProtoEvent", namespace: Optional[str]): + self.event = event + self.namespace = namespace + self.created_at = time.time() + + def __lt__(self, other: "Event") -> bool: + if self.event.wall_time < other.event.wall_time: + return True + return False + + +class TBEventConsumer: + """Consume tfevents from a priority queue. + + There should always only be one of these per run_manager. We wait for 10 seconds of + queued events to reduce the chance of multiple tfevent files triggering out of order + steps. + """ + + def __init__( + self, + tbwatcher: TBWatcher, + queue: "PriorityQueue", + run_proto: "RunRecord", + settings: "SettingsStatic", + delay: int = 10, + ) -> None: + self._tbwatcher = tbwatcher + self._queue = queue + self._thread = threading.Thread(target=self._thread_except_body) + self._shutdown = threading.Event() + self.tb_history = TBHistory() + self._delay = delay + + # This is a bit of a hack to get file saving to work as it does in the user + # process. Since we don't have a real run object, we have to define the + # datatypes callback ourselves. 
+ def datatypes_cb(fname: filesystem.GlobStr) -> None: + files: FilesDict = dict(files=[(fname, "now")]) + self._tbwatcher._interface.publish_files(files) + + # this is only used for logging artifacts + self._internal_run = internal_run.InternalRun(run_proto, settings, datatypes_cb) + self._internal_run._set_internal_run_interface(self._tbwatcher._interface) + + def start(self) -> None: + self._start_time = time.time() + self._thread.start() + + def finish(self) -> None: + self._delay = 0 + self._shutdown.set() + self._thread.join() + while not self._queue.empty(): + event = self._queue.get(True, 1) + if event: + self._handle_event(event, history=self.tb_history) + items = self.tb_history._get_and_reset() + for item in items: + self._save_row( + item, + ) + + def _thread_except_body(self) -> None: + try: + self._thread_body() + except Exception: + logger.exception("generic exception in TBEventConsumer thread") + raise + + def _thread_body(self) -> None: + while True: + try: + event = self._queue.get(True, 1) + # Wait self._delay seconds from consumer start before logging events + if ( + time.time() < self._start_time + self._delay + and not self._shutdown.is_set() + ): + self._queue.put(event) + time.sleep(0.1) + continue + except queue.Empty: + event = None + if self._shutdown.is_set(): + break + if event: + self._handle_event(event, history=self.tb_history) + items = self.tb_history._get_and_reset() + for item in items: + self._save_row( + item, + ) + # flush uncommitted data + self.tb_history._flush() + items = self.tb_history._get_and_reset() + for item in items: + self._save_row(item) + + def _handle_event( + self, event: "ProtoEvent", history: Optional["TBHistory"] = None + ) -> None: + wandb.tensorboard._log( # type: ignore + event.event, + step=event.event.step, + namespace=event.namespace, + history=history, + ) + + def _save_row(self, row: "HistoryDict") -> None: + chart_keys = set() + for k, v in row.items(): + if isinstance(v, CustomChart): + 
chart_keys.add(k) + v.set_key(k) + self._tbwatcher._interface.publish_config( + key=v.spec.config_key, + val=v.spec.config_value, + ) + + for k in chart_keys: + chart = row.pop(k) + if isinstance(chart, CustomChart): + row[chart.spec.table_key] = chart.table + + self._tbwatcher._interface.publish_history( + self._internal_run, + row, + publish_step=False, + ) + + +class TBHistory: + _data: "HistoryDict" + _added: "List[HistoryDict]" + + def __init__(self) -> None: + self._step = 0 + self._step_size = 0 + self._data = dict() + self._added = [] + + def _flush(self) -> None: + if not self._data: + return + # A single tensorboard step may have too much data + # we just drop the largest keys in the step if it does. + # TODO: we could flush the data across multiple steps + if self._step_size > util.MAX_LINE_BYTES: + metrics = [(k, sys.getsizeof(v)) for k, v in self._data.items()] + metrics.sort(key=lambda t: t[1], reverse=True) + bad = 0 + dropped_keys = [] + for k, v in metrics: + # TODO: (cvp) Added a buffer of 100KiB, this feels rather brittle. 
+ if self._step_size - bad < util.MAX_LINE_BYTES - 100000: + break + else: + bad += v + dropped_keys.append(k) + del self._data[k] + wandb.termwarn( + f"Step {self._step} exceeds max data limit, dropping {len(dropped_keys)} of the largest keys:" + ) + print("\t" + ("\n\t".join(dropped_keys))) # noqa: T201 + self._data["_step"] = self._step + self._added.append(self._data) + self._step += 1 + self._step_size = 0 + + def add(self, d: "HistoryDict") -> None: + self._flush() + self._data = dict() + self._data.update(self._track_history_dict(d)) + + def _track_history_dict(self, d: "HistoryDict") -> "HistoryDict": + e = {} + for k in d.keys(): + e[k] = d[k] + self._step_size += sys.getsizeof(e[k]) + return e + + def _row_update(self, d: "HistoryDict") -> None: + self._data.update(self._track_history_dict(d)) + + def _get_and_reset(self) -> "List[HistoryDict]": + added = self._added[:] + self._added = [] + return added diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/launch/loader.py b/.venv/lib/python3.13/site-packages/wandb/sdk/launch/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..d8015a25d233688ab2e2e1634d0bed621121235d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/launch/loader.py @@ -0,0 +1,249 @@ +"""Utilities for the agent.""" + +from typing import Any, Dict, Optional + +import wandb +from wandb.apis.internal import Api +from wandb.docker import is_docker_installed +from wandb.sdk.launch.errors import LaunchError + +from .builder.abstract import AbstractBuilder +from .environment.abstract import AbstractEnvironment +from .registry.abstract import AbstractRegistry +from .runner.abstract import AbstractRunner + +WANDB_RUNNERS = { + "local-container", + "local-process", + "kubernetes", + "vertex", + "sagemaker", +} + + +def environment_from_config(config: Optional[Dict[str, Any]]) -> AbstractEnvironment: + """Create an environment from a config. 
+ + This helper function is used to create an environment from a config. The + config should have a "type" key that specifies the type of environment to + create. The remaining keys are passed to the environment's from_config + method. If the config is None or empty, a LocalEnvironment is returned. + + Arguments: + config (Dict[str, Any]): The config. + + Returns: + Environment: The environment constructed. + """ + if not config: + from .environment.local_environment import LocalEnvironment + + return LocalEnvironment() # This is the default, dummy environment. + env_type = config.get("type") + if not env_type: + raise LaunchError( + "Could not create environment from config. Environment type not specified!" + ) + if env_type == "local": + from .environment.local_environment import LocalEnvironment + + return LocalEnvironment.from_config(config) + if env_type == "aws": + from .environment.aws_environment import AwsEnvironment + + return AwsEnvironment.from_config(config) + if env_type == "gcp": + from .environment.gcp_environment import GcpEnvironment + + return GcpEnvironment.from_config(config) + if env_type == "azure": + from .environment.azure_environment import AzureEnvironment + + return AzureEnvironment.from_config(config) + raise LaunchError( + f"Could not create environment from config. Invalid type: {env_type}" + ) + + +def registry_from_config( + config: Optional[Dict[str, Any]], environment: AbstractEnvironment +) -> AbstractRegistry: + """Create a registry from a config. + + This helper function is used to create a registry from a config. The + config should have a "type" key that specifies the type of registry to + create. The remaining keys are passed to the registry's from_config + method. If the config is None or empty, a LocalRegistry is returned. + + Arguments: + config (Dict[str, Any]): The registry config. + environment (Environment): The environment of the registry. + + Returns: + The registry if config is not None, otherwise None. 
+ + Raises: + LaunchError: If the registry is not configured correctly. + """ + if not config: + from .registry.local_registry import LocalRegistry + + return LocalRegistry() # This is the default, dummy registry. + + wandb.termwarn( + "The `registry` block of the launch agent config is being deprecated. " + "Please specify an image repository URI under the `builder.destination` " + "key of your launch agent config. See " + "https://docs.wandb.ai/guides/launch/setup-agent-advanced#agent-configuration " + "for more information." + ) + + registry_type = config.get("type") + if registry_type is None or registry_type == "local": + from .registry.local_registry import LocalRegistry + + return LocalRegistry() # This is the default, dummy registry. + if registry_type == "ecr": + from .registry.elastic_container_registry import ElasticContainerRegistry + + return ElasticContainerRegistry.from_config(config) + if registry_type == "gcr": + from .registry.google_artifact_registry import GoogleArtifactRegistry + + return GoogleArtifactRegistry.from_config(config) + if registry_type == "acr": + from .registry.azure_container_registry import AzureContainerRegistry + + return AzureContainerRegistry.from_config(config) + raise LaunchError( + f"Could not create registry from config. Invalid registry type: {registry_type}" + ) + + +def builder_from_config( + config: Optional[Dict[str, Any]], + environment: AbstractEnvironment, + registry: AbstractRegistry, +) -> AbstractBuilder: + """Create a builder from a config. + + This helper function is used to create a builder from a config. The + config should have a "type" key that specifies the type of builder to import + and create. The remaining keys are passed to the builder's from_config + method. If the config is None or empty, a default builder is returned. + + The default builder will be a DockerBuilder if we find a working docker cli + on the system, otherwise it will be a NoOpBuilder. 
+ + Arguments: + config (Dict[str, Any]): The builder config. + registry (Registry): The registry of the builder. + + Returns: + The builder. + + Raises: + LaunchError: If the builder is not configured correctly. + """ + if not config: + if is_docker_installed(): + from .builder.docker_builder import DockerBuilder + + return DockerBuilder.from_config( + {}, environment, registry + ) # This is the default builder. + + from .builder.noop import NoOpBuilder + + return NoOpBuilder.from_config({}, environment, registry) + + builder_type = config.get("type") + if builder_type is None: + raise LaunchError( + "Could not create builder from config. Builder type not specified" + ) + if builder_type == "docker": + from .builder.docker_builder import DockerBuilder + + return DockerBuilder.from_config(config, environment, registry) + if builder_type == "kaniko": + from .builder.kaniko_builder import KanikoBuilder + + return KanikoBuilder.from_config(config, environment, registry) + if builder_type == "noop": + from .builder.noop import NoOpBuilder + + return NoOpBuilder.from_config(config, environment, registry) + raise LaunchError( + f"Could not create builder from config. Invalid builder type: {builder_type}" + ) + + +def runner_from_config( + runner_name: str, + api: Api, + runner_config: Dict[str, Any], + environment: AbstractEnvironment, + registry: AbstractRegistry, +) -> AbstractRunner: + """Create a runner from a config. + + This helper function is used to create a runner from a config. The + config should have a "type" key that specifies the type of runner to import + and create. The remaining keys are passed to the runner's from_config + method. If the config is None or empty, a LocalContainerRunner is returned. + + Arguments: + runner_name (str): The name of the backend. + api (Api): The API. + runner_config (Dict[str, Any]): The backend config. + + Returns: + The runner. + + Raises: + LaunchError: If the runner is not configured correctly. 
+ """ + if not runner_name or runner_name in ["local-container", "local"]: + from .runner.local_container import LocalContainerRunner + + return LocalContainerRunner(api, runner_config, environment, registry) + if runner_name == "local-process": + from .runner.local_process import LocalProcessRunner + + return LocalProcessRunner(api, runner_config) + if runner_name == "sagemaker": + from .environment.aws_environment import AwsEnvironment + + if not isinstance(environment, AwsEnvironment): + try: + environment = AwsEnvironment.from_default() + except LaunchError as e: + raise LaunchError( + "Could not create Sagemaker runner. " + "Environment must be an instance of AwsEnvironment." + ) from e + from .runner.sagemaker_runner import SageMakerRunner + + return SageMakerRunner(api, runner_config, environment, registry) + if runner_name in ["vertex", "gcp-vertex"]: + from .environment.gcp_environment import GcpEnvironment + + if not isinstance(environment, GcpEnvironment): + try: + environment = GcpEnvironment.from_default() + except LaunchError as e: + raise LaunchError( + "Could not create Vertex runner. " + "Environment must be an instance of GcpEnvironment." + ) from e + from .runner.vertex_runner import VertexRunner + + return VertexRunner(api, runner_config, environment, registry) + if runner_name == "kubernetes": + from .runner.kubernetes_runner import KubernetesRunner + + return KubernetesRunner(api, runner_config, environment, registry) + raise LaunchError( + f"Could not create runner from config. 
Invalid runner name: {runner_name}" + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/launch/wandb_reference.py b/.venv/lib/python3.13/site-packages/wandb/sdk/launch/wandb_reference.py new file mode 100644 index 0000000000000000000000000000000000000000..5de34c04bf3aa77068da19cb6b4af8f1e163709d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/launch/wandb_reference.py @@ -0,0 +1,138 @@ +"""Support for parsing W&B URLs (which might be user provided) into constituent parts.""" + +from dataclasses import dataclass +from enum import IntEnum +from typing import Optional +from urllib.parse import urlparse + +PREFIX_HTTP = "http://" +PREFIX_HTTPS = "https://" + + +class ReferenceType(IntEnum): + RUN = 1 + JOB = 2 + + +# Ideally we would not overload the URL paths as we do. +# TODO: Not sure these are exhaustive, and even if so more special paths might get added. +# Would be good to have restrictions that we could check. +RESERVED_NON_ENTITIES = ( + "create-team", + "fully-connected", + "registry", + "settings", + "subscriptions", +) +RESERVED_NON_PROJECTS = ( + "likes", + "projects", +) +RESERVED_JOB_PATHS = ("_view",) + + +@dataclass +class WandbReference: + # TODO: This will include port, should we separate that out? 
+ host: Optional[str] = None + + entity: Optional[str] = None + project: Optional[str] = None + + # Set when we don't know how to parse yet + path: Optional[str] = None + + # Reference type will determine what other fields are set + ref_type: Optional[ReferenceType] = None + + run_id: Optional[str] = None + + job_name: Optional[str] = None + job_alias: str = "latest" # In addition to an alias can be a version specifier + + def is_bare(self) -> bool: + return self.host is None + + def is_job(self) -> bool: + return self.ref_type == ReferenceType.JOB + + def is_run(self) -> bool: + return self.ref_type == ReferenceType.RUN + + def is_job_or_run(self) -> bool: + return self.is_job() or self.is_run() + + def job_reference(self) -> str: + assert self.is_job() + return f"{self.job_name}:{self.job_alias}" + + def job_reference_scoped(self) -> str: + assert self.entity + assert self.project + unscoped = self.job_reference() + return f"{self.entity}/{self.project}/{unscoped}" + + def url_host(self) -> str: + return f"{PREFIX_HTTPS}{self.host}" if self.host else "" + + def url_entity(self) -> str: + assert self.entity + return f"{self.url_host()}/{self.entity}" + + def url_project(self) -> str: + assert self.project + return f"{self.url_entity()}/{self.project}" + + @staticmethod + def parse(uri: str) -> Optional["WandbReference"]: + """Attempt to parse a string as a W&B URL.""" + # TODO: Error if HTTP and host is not localhost? 
+ if ( + not uri.startswith("/") + and not uri.startswith(PREFIX_HTTP) + and not uri.startswith(PREFIX_HTTPS) + ): + return None + + ref = WandbReference() + + # This takes care of things like query and fragment + parsed = urlparse(uri) + if parsed.netloc: + ref.host = parsed.netloc + + if not parsed.path.startswith("/"): + return ref + + ref.path = parsed.path[1:] + parts = ref.path.split("/") + if len(parts) > 0: + if parts[0] not in RESERVED_NON_ENTITIES: + ref.path = None + ref.entity = parts[0] + if len(parts) > 1: + if parts[1] not in RESERVED_NON_PROJECTS: + ref.project = parts[1] + if len(parts) > 3 and parts[2] == "runs": + ref.ref_type = ReferenceType.RUN + ref.run_id = parts[3] + elif ( + len(parts) > 4 + and parts[2] == "artifacts" + and parts[3] == "job" + ): + ref.ref_type = ReferenceType.JOB + ref.job_name = parts[4] + if len(parts) > 5 and parts[5] not in RESERVED_JOB_PATHS: + ref.job_alias = parts[5] + # TODO: Right now we are not tracking selection as part of URL state in the Jobs tab. + # If that changes we'll want to update this. 
+ + return ref + + @staticmethod + def is_uri_job_or_run(uri: str) -> bool: + ref = WandbReference.parse(uri) + if ref and ref.is_job_or_run(): + return True + return False diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/mailbox/mailbox.py b/.venv/lib/python3.13/site-packages/wandb/sdk/mailbox/mailbox.py new file mode 100644 index 0000000000000000000000000000000000000000..a455ad57797ce4ee377f56f774d48448e5ec69be --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/mailbox/mailbox.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import logging +import secrets +import string +import threading +from typing import Awaitable, Callable + +from wandb.proto import wandb_internal_pb2 as pb +from wandb.proto import wandb_server_pb2 as spb +from wandb.sdk.lib import asyncio_manager + +from .mailbox_handle import MailboxHandle +from .response_handle import MailboxResponseHandle + +_logger = logging.getLogger(__name__) + + +class MailboxClosedError(Exception): + """The mailbox has been closed and cannot be used.""" + + +class Mailbox: + """Matches service responses to requests. + + The mailbox can set an address on a server request and create a handle for + waiting for a response to that record. Responses are delivered by calling + `deliver()`. The `close()` method abandons all handles in case the + service process becomes unreachable. + """ + + def __init__( + self, + asyncer: asyncio_manager.AsyncioManager, + cancel: Callable[[str], Awaitable[None]], + ) -> None: + """Create a mailbox. + + Args: + asyncer: Asyncio runner for scheduling async operations. + cancel: A callback that can be used to cancel a request by ID. 
+ """ + self._asyncer = asyncer + self._cancel = cancel + + self._handles: dict[str, MailboxResponseHandle] = {} + self._handles_lock = threading.Lock() + self._closed = False + + def require_response( + self, + request: spb.ServerRequest | pb.Record, + ) -> MailboxHandle[spb.ServerResponse]: + """Set a response address on a request. + + Args: + request: The request on which to set a request ID or mailbox slot. + This is mutated. An address must not already be set. + + Returns: + A handle for waiting for the response to the request. + + Raises: + MailboxClosedError: If the mailbox has been closed, in which case + no new responses are expected to be delivered and new handles + cannot be created. + """ + if isinstance(request, spb.ServerRequest): + if (address := request.request_id) or ( + address := request.record_publish.control.mailbox_slot + ): + raise ValueError(f"Request already has an address ({address})") + + address = self._new_address() + request.request_id = address + if request.HasField("record_publish"): + request.record_publish.control.mailbox_slot = address + if request.HasField("record_communicate"): + request.record_communicate.control.mailbox_slot = address + else: + if address := request.control.mailbox_slot: + raise ValueError(f"Request already has an address ({address})") + + address = self._new_address() + request.control.mailbox_slot = address + + with self._handles_lock: + if self._closed: + raise MailboxClosedError() + + handle = MailboxResponseHandle( + address, + asyncer=self._asyncer, + cancel=self._cancel, + ) + self._handles[address] = handle + + return handle + + def _new_address(self) -> str: + """Returns an unused address for a request. + + Assumes `_handles_lock` is held. + """ + + def generate(): + return "".join( + secrets.choice(string.ascii_lowercase + string.digits) + for _ in range(12) + ) + + address = generate() + + # Being extra cautious. This loop will almost never be entered. 
+ while address in self._handles: + address = generate() + + return address + + async def deliver(self, response: spb.ServerResponse) -> None: + """Deliver a response from the service. + + If the response address is invalid, this does nothing. + It is a no-op if the mailbox has been closed. + """ + address = response.request_id + if not address: + kind: str | None = response.WhichOneof("server_response_type") + if kind == "result_communicate": + result_type = response.result_communicate.WhichOneof("result_type") + kind = f"result_communicate.{result_type}" + + _logger.error(f"Received response with no mailbox slot: {kind}") + return + + with self._handles_lock: + # NOTE: If the mailbox is closed, this returns None because + # we clear the dict. + handle = self._handles.pop(address, None) + + # It is not an error if there is no handle for the address: + # handles can be abandoned if the result is no longer needed. + if handle: + await handle.deliver(response) + + def close(self) -> None: + """Indicate no further responses will be delivered. + + Abandons all handles. 
+ """ + with self._handles_lock: + self._closed = True + + _logger.info( + f"Closing mailbox, abandoning {len(self._handles)} handles.", + ) + + for handle in self._handles.values(): + handle.abandon() + self._handles.clear() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/mailbox/mailbox_handle.py b/.venv/lib/python3.13/site-packages/wandb/sdk/mailbox/mailbox_handle.py new file mode 100644 index 0000000000000000000000000000000000000000..a5744643291b8816a38a291b475132f22883110d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/mailbox/mailbox_handle.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import abc +from typing import Callable, Generic + +from typing_extensions import TypeVar, override + +from wandb.sdk.lib import asyncio_manager + +_T = TypeVar("_T") +_S = TypeVar("_S") + + +class HandleAbandonedError(Exception): + """The handle has no response and has been abandoned.""" + + +class MailboxHandle(abc.ABC, Generic[_T]): + """A handle for waiting on a response to a request.""" + + def __init__(self, asyncer: asyncio_manager.AsyncioManager) -> None: + self._asyncer = asyncer + + @property + def asyncer(self) -> asyncio_manager.AsyncioManager: + """The asyncio thread to which the handle belongs. + + The handle's async methods must be run using this object. + """ + return self._asyncer + + def map(self, fn: Callable[[_T], _S]) -> MailboxHandle[_S]: + """Returns a transformed handle. + + Methods on the returned handle call methods on this handle, but the + response type is derived using the given function. + + Args: + fn: A function to apply to this handle's result to get the new + handle's result. The function should be pure and fast. + """ + return _MailboxMappedHandle(self, fn) + + @abc.abstractmethod + def cancel(self) -> None: + """Cancel the handle, requesting any associated work to not complete. + + Any calls to `wait_or` or `wait_async` will raise `HandleAbandonedError` + if they aren't resolved within a short time. 
+ + Cancellation is best-effort. Most exceptions are logged and suppressed. + """ + + @abc.abstractmethod + def wait_or(self, *, timeout: float | None) -> _T: + """Wait for a response or a timeout. + + It is an error to call this from an async function. + On error, including KeyboardInterrupt or a timeout, + the handle cancels itself. + + Args: + timeout: A finite number of seconds or None to never time out. + If less than or equal to zero, times out immediately unless + the response is available. + + Returns: + The response if it arrives before the timeout or has already arrived. + + Raises: + TimeoutError: If the timeout is reached. + HandleAbandonedError: If the handle becomes abandoned. + """ + + @abc.abstractmethod + async def wait_async(self, *, timeout: float | None) -> _T: + """Wait for a response or timeout. + + This must run in an `asyncio` event loop. + On error, including asyncio cancellation, KeyboardInterrupt or + a timeout, the handle cancels itself. + + Args: + timeout: A finite number of seconds or None to never time out. + + Returns: + The response if it arrives before the timeout or has already arrived. + + Raises: + TimeoutError: If the timeout is reached. + HandleAbandonedError: If the handle becomes abandoned. 
+ """ + + +class _MailboxMappedHandle(Generic[_S], MailboxHandle[_S]): + """A mailbox handle whose result is derived from another handle.""" + + def __init__( + self, + handle: MailboxHandle[_T], + fn: Callable[[_T], _S], + ) -> None: + super().__init__(handle.asyncer) + self._handle = handle + self._fn = fn + + @override + def cancel(self) -> None: + self._handle.cancel() + + @override + def wait_or(self, *, timeout: float | None) -> _S: + return self._fn(self._handle.wait_or(timeout=timeout)) + + @override + async def wait_async(self, *, timeout: float | None) -> _S: + response = await self._handle.wait_async(timeout=timeout) + return self._fn(response) diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/mailbox/response_handle.py b/.venv/lib/python3.13/site-packages/wandb/sdk/mailbox/response_handle.py new file mode 100644 index 0000000000000000000000000000000000000000..2fc1f56495e5f39fcbbb45be73f01364aa506e3e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/mailbox/response_handle.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import asyncio +import logging +import math +from typing import Awaitable, Callable + +from typing_extensions import override + +from wandb.proto import wandb_server_pb2 as spb +from wandb.sdk.lib import asyncio_manager + +from .mailbox_handle import HandleAbandonedError, MailboxHandle + +_logger = logging.getLogger(__name__) + + +class MailboxResponseHandle(MailboxHandle[spb.ServerResponse]): + """A general handle for any ServerResponse.""" + + def __init__( + self, + address: str, + *, + asyncer: asyncio_manager.AsyncioManager, + cancel: Callable[[str], Awaitable[None]], + ) -> None: + super().__init__(asyncer) + + self._address = address + self._cancel_fn = cancel + + self._abandoned = False + self._response: spb.ServerResponse | None = None + + # Initialized on first use in the asyncio thread. 
+ self._done_event: asyncio.Event | None = None + + async def deliver(self, response: spb.ServerResponse) -> None: + if self._abandoned: + return + + if self._response: + raise ValueError( + f"A response has already been delivered to {self._address}." + ) + + self._response = response + + if not self._done_event: + self._done_event = asyncio.Event() + self._done_event.set() + + @override + def cancel(self) -> None: + # Cancel on a best-effort basis and ignore exceptions. + async def impl() -> None: + try: + await self._cancel_fn(self._address) + except Exception: + _logger.exception("Failed to cancel request %r", self._address) + + try: + self.abandon() + self.asyncer.run_soon(impl) + except Exception: + _logger.exception( + "Failed to abandon and cancel request %r", + self._address, + ) + + def abandon(self) -> None: + """Indicate the handle will not receive a response. + + This causes any code blocked on `wait_or` or `wait_async` to raise + a `HandleAbandonedError` after a short time. 
+ """ + + async def impl() -> None: + self._abandoned = True + + if not self._done_event: + self._done_event = asyncio.Event() + self._done_event.set() + + self.asyncer.run_soon(impl) + + @override + def wait_or(self, *, timeout: float | None) -> spb.ServerResponse: + return self.asyncer.run(lambda: self.wait_async(timeout=timeout)) + + @override + async def wait_async(self, *, timeout: float | None) -> spb.ServerResponse: + if timeout is not None and not math.isfinite(timeout): + raise ValueError("Timeout must be finite or None.") + + if not self._done_event: + self._done_event = asyncio.Event() + + try: + await asyncio.wait_for(self._done_event.wait(), timeout=timeout) + + except (asyncio.TimeoutError, TimeoutError) as e: + if self._response: + return self._response + elif self._abandoned: + raise HandleAbandonedError() + else: + self.cancel() + raise TimeoutError( + f"Timed out waiting for response on {self._address}" + ) from e + + except: + self.cancel() + raise + + else: + if self._response: + return self._response + + assert self._abandoned + raise HandleAbandonedError() diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/mailbox/wait_with_progress.py b/.venv/lib/python3.13/site-packages/wandb/sdk/mailbox/wait_with_progress.py new file mode 100644 index 0000000000000000000000000000000000000000..d0f0de358a1a4d4b2a12f63b1709a84bac3abbc6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/mailbox/wait_with_progress.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import time +from typing import Any, Callable, Coroutine, List, TypeVar, cast + +from wandb.sdk.lib import asyncio_compat + +from .mailbox_handle import MailboxHandle + +_T = TypeVar("_T") + + +def wait_with_progress( + handle: MailboxHandle[_T], + *, + timeout: float | None, + display_progress: Callable[[], Coroutine[Any, Any, None]], +) -> _T: + """Wait for a handle, possibly displaying progress to the user. 
+ + Equivalent to passing a single handle to `wait_all_with_progress`. + """ + return wait_all_with_progress( + [handle], + timeout=timeout, + display_progress=display_progress, + )[0] + + +def wait_all_with_progress( + handle_list: list[MailboxHandle[_T]], + *, + timeout: float | None, + display_progress: Callable[[], Coroutine[Any, Any, None]], +) -> list[_T]: + """Wait for multiple handles, possibly displaying progress to the user. + + Args: + handle_list: The handles to wait for. + timeout: A number of seconds after which to raise a TimeoutError, + or None if this should never timeout. + display_progress: An asyncio function that displays progress to + the user. This function runs using the handles' AsyncioManager. + + Returns: + A list where the Nth item is the Nth handle's result. + + Raises: + ValueError: If the handles live in different asyncio threads. + TimeoutError: If the overall timeout expires. + HandleAbandonedError: If any handle becomes abandoned. + Exception: Any exception from the display function is propagated. + """ + if not handle_list: + return [] + + asyncer = handle_list[0].asyncer + for handle in handle_list: + if handle.asyncer is not asyncer: + raise ValueError("Handles have different AsyncioManagers.") + + start_time = time.monotonic() + + async def progress_loop_with_timeout() -> list[_T]: + async with asyncio_compat.cancel_on_exit(display_progress()): + if timeout is not None: + elapsed_time = time.monotonic() - start_time + remaining_timeout = timeout - elapsed_time + else: + remaining_timeout = None + + return await _wait_handles_async( + handle_list, + timeout=remaining_timeout, + ) + + return asyncer.run(progress_loop_with_timeout) + + +async def _wait_handles_async( + handle_list: list[MailboxHandle[_T]], + *, + timeout: float | None, +) -> list[_T]: + """Asynchronously wait for multiple mailbox handles. + + Just like _wait_handles. 
+ """ + results: list[_T | None] = [None for _ in handle_list] + + async def wait_single(index: int) -> None: + handle = handle_list[index] + results[index] = await handle.wait_async(timeout=timeout) + + async with asyncio_compat.open_task_group() as task_group: + for index in range(len(handle_list)): + task_group.start_soon(wait_single(index)) + + # NOTE: `list` is not subscriptable until Python 3.10, so we use List. + return cast(List[_T], results) diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/verify/__init__.py b/.venv/lib/python3.13/site-packages/wandb/sdk/verify/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/verify/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/verify/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2bc82d4ce3bc4b523a4c69d60ca594264f25503 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/verify/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/verify/__pycache__/verify.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/sdk/verify/__pycache__/verify.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ff2f63e35bd3d09653916cc034f2ed9f30e462f Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/sdk/verify/__pycache__/verify.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/verify/verify.py b/.venv/lib/python3.13/site-packages/wandb/sdk/verify/verify.py new file mode 100644 index 0000000000000000000000000000000000000000..8cadbf310a9b5c273011fee2c66c78f5c18ed6ea --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/verify/verify.py @@ -0,0 +1,555 @@ +"""Utilities for wandb verify.""" + +import contextlib +import getpass +import io +import 
def check_host(host: str) -> bool:
    """Refuse to run the verification suite against the public W&B API.

    The host is normalized before comparison (surrounding whitespace and
    trailing slashes removed, case folded) so that spellings such as
    ``https://api.wandb.ai/`` or ``HTTPS://API.WANDB.AI`` are rejected too.

    Args:
        host: The base URL or host name the checks would run against.

    Returns:
        False, after printing a failure, if the host is api.wandb.ai;
        True otherwise.
    """
    normalized = host.strip().rstrip("/").lower()
    if normalized in ("api.wandb.ai", "http://api.wandb.ai", "https://api.wandb.ai"):
        print_results("Cannot run wandb verify against api.wandb.ai", False)
        return False
    return True
def check_cors_configuration(url: str, origin: str) -> None:
    """Warn if the object store does not answer a CORS preflight request.

    Sends an OPTIONS request for a GET from `origin` and reports a warning
    when the response has no ``Access-Control-Allow-Origin`` header. A
    request that cannot be completed is reported as a warning as well,
    instead of aborting the remaining checks.

    Args:
        url: The object-store URL to probe.
        origin: The origin to send in the ``Origin`` request header.
    """
    print("Checking CORs configuration of the bucket".ljust(72, "."), end="")  # noqa: T201
    fail_string = None
    try:
        # requests waits indefinitely unless a timeout is given.
        res_get = requests.options(
            url,
            headers={"Origin": origin, "Access-Control-Request-Method": "GET"},
            timeout=60,
        )
    except requests.RequestException as e:
        fail_string = (
            "Could not reach your object store to check its CORs "
            f"configuration: {e}"
        )
    else:
        if res_get.headers.get("Access-Control-Allow-Origin") is None:
            fail_string = (
                "Your object store does not have a valid CORs configuration, "
                f"you must allow GET and PUT to Origin: {origin}"
            )

    print_results(fail_string, True)
run. Contact W&B for support." + ) + + try: + run.log({"HT%3ML ": wandb.Html('Link')}) + except Exception: + failed_test_strings.append( + "Failed to log to media. Contact W&B for support." + ) + + run.save(filepath) + public_api = wandb.Api() + prev_run = public_api.run(f"{entity}/{PROJECT_NAME}/{run_id}") + # raise Exception(prev_run.__dict__) + if prev_run is None: + failed_test_strings.append( + "Failed to access run through API. Contact W&B for support." + ) + print_results(failed_test_strings, False) + return False + for key, value in config.items(): + if prev_run.config.get(key) != value: + failed_test_strings.append( + "Read config values don't match run config. Contact W&B for support." + ) + break + if logged and ( + prev_run.history_keys["keys"]["loss"]["previousValue"] != 0.1 + or prev_run.history_keys["lastStep"] != 11 + or prev_run.history_keys["keys"]["dict.val1"]["previousValue"] != 1.0 + or prev_run.history_keys["keys"]["dict.val2"]["previousValue"] != 2 + ): + failed_test_strings.append( + "History metrics don't match logged values. Check database encoding." + ) + + if logged and prev_run.summary["loss"] != 1.0 / 10: + failed_test_strings.append( + "Read summary values don't match expected value. Check database encoding, or contact W&B for support." + ) + # TODO: (kdg) refactor this so it doesn't rely on an exception handler + try: + read_file = retry_fn(partial(prev_run.file, filepath)) + # There's a race where the file hasn't been processed in the queue, + # we just retry until we get a download + read_file = retry_fn(partial(read_file.download, replace=True)) + except Exception: + failed_test_strings.append( + "Unable to download file. Check SQS configuration, topic configuration and bucket permissions." + ) + + print_results(failed_test_strings, False) + return False + contents = read_file.read() + if contents != "test": + failed_test_strings.append( + "Contents of downloaded file do not match uploaded contents. Contact W&B for support." 
def verify_manifest(
    downloaded_manifest: Dict[str, Any],
    computed_manifest: Dict[str, Any],
    fails_list: List[str],
) -> None:
    """Check that every computed manifest entry matches the downloaded one.

    Each key of `computed_manifest` must exist in `downloaded_manifest`
    with the same ``digest`` and ``size``. A field missing from an entry is
    compared as None rather than raising. At most one failure message is
    appended, on the first mismatch.

    The comparison uses explicit conditions instead of ``assert`` so that
    it still runs when Python is started with ``-O``.

    Args:
        downloaded_manifest: Manifest contents of the downloaded artifact,
            keyed by entry path.
        computed_manifest: Manifest contents computed locally, keyed by
            entry path.
        fails_list: List of failure messages; mutated in place.
    """
    for key, computed_entry in computed_manifest.items():
        downloaded_entry = downloaded_manifest.get(key)
        if (
            downloaded_entry is None
            or computed_entry.get("digest") != downloaded_entry.get("digest")
            or computed_entry.get("size") != downloaded_entry.get("size")
        ):
            fails_list.append(
                "Artifact manifest does not appear as expected. Contact W&B for support."
            )
            return
f.close() + artifact.add_file(f.name) + + try: + log_art_run.log_artifact(artifact, aliases=alias) + except Exception as e: + failed_test_strings.append(f"Unable to log artifact. {e}") + return False, None, failed_test_strings + + with wandb.init( + id=nice_id("use_artifact"), + project=PROJECT_NAME, + config={"test": "artifact use"}, + ) as use_art_run: + try: + used_art = use_art_run.use_artifact(f"{name}:{alias}") + except Exception as e: + failed_test_strings.append(f"Unable to use artifact. {e}") + return False, None, failed_test_strings + try: + used_art.download(root=download_dir) + except Exception: + failed_test_strings.append( + "Unable to download artifact. Check bucket permissions." + ) + return False, None, failed_test_strings + + return True, used_art, failed_test_strings + + +def check_artifacts() -> bool: + print("Checking artifact save and download workflows".ljust(72, "."), end="") # noqa: T201 + failed_test_strings: List[str] = [] + + # test checksum + sing_art_dir = "./verify_sing_art" + alias = "sing_art1" + name = nice_id("sing-artys") + singular_art = artifact_with_path_or_paths(name, singular=True) + cont_test, download_artifact, failed_test_strings = log_use_download_artifact( + singular_art, alias, name, sing_art_dir, failed_test_strings, False + ) + if not cont_test or download_artifact is None: + print_results(failed_test_strings, False) + return False + try: + download_artifact.verify(root=sing_art_dir) + except ValueError: + failed_test_strings.append( + "Artifact does not contain expected checksum. Contact W&B for support." 
def check_graphql_put(api: Api, host: str) -> Tuple[bool, Optional[str]]:
    """Check that a file saved with a run can be read back.

    Writes a small file, saves it in a new run, then looks the run up
    through the public API and downloads the file again.

    Args:
        api: Not used by this function's body.
        host: Not used by this function's body.

    Returns:
        A ``(passed, url)`` pair. `url` is the ``url`` attribute of the
        saved file, or None when the run or the file could not be
        retrieved.
    """
    # check graphql endpoint using an upload
    print("Checking signed URL upload".ljust(72, "."), end="")  # noqa: T201
    failed_test_strings = []
    gql_fp = "gql_test_file.txt"
    with open(gql_fp, "w") as f:
        f.write("test2")
    with wandb.init(
        id=nice_id("graphql_put"),
        reinit=True,
        project=PROJECT_NAME,
        config={"test": "put to graphql"},
    ) as run:
        run.save(gql_fp)
    public_api = wandb.Api()
    prev_run = public_api.run(f"{run.entity}/{PROJECT_NAME}/{run.id}")
    if prev_run is None:
        failed_test_strings.append(
            "Unable to access previous run through public API. Contact W&B for support."
        )
        print_results(failed_test_strings, False)
        return False, None
    # TODO: (kdg) refactor this so it doesn't rely on an exception handler
    url = None
    try:
        read_file = retry_fn(partial(prev_run.file, gql_fp))
        url = read_file.url
        read_file = retry_fn(partial(read_file.download, replace=True))
    except Exception:
        read_file = None
    # retry_fn returns None when every attempt failed, so a failed download
    # must be handled here rather than crashing on `.read()` below.
    if read_file is None:
        failed_test_strings.append(
            "Unable to read file successfully saved through a put request. Check SQS configurations, bucket permissions and topic configs."
        )
        print_results(failed_test_strings, False)
        return False, None
    contents = read_file.read()
    # Explicit comparison: an `assert` would be stripped under `python -O`.
    if contents != "test2":
        failed_test_strings.append(
            "Read file contents do not match saved file contents. Contact W&B for support."
        )

    print_results(failed_test_strings, False)
    return len(failed_test_strings) == 0, url
def check_wandb_version(api: Api) -> None:
    """Compare the installed wandb version with the server's supported range.

    Prints a failure when the installed version is below the server's
    ``min_cli_version`` and a warning when it is above ``max_cli_version``.
    When the server reports no ``max_cli_version``, the upper-bound check
    is skipped and the upgrade hint does not pin a version.

    Args:
        api: Internal API object used to fetch the server info.
    """
    print("Checking wandb package version is up to date".ljust(72, "."), end="")  # noqa: T201
    _, server_info = api.viewer_server_info()
    fail_string = None
    warning = False
    cli_version_info = server_info.get("cliVersionInfo", {})
    max_cli_version = cli_version_info.get("max_cli_version", None)
    min_cli_version = cli_version_info.get("min_cli_version", "0.0.1")

    from packaging.version import parse

    # Without a reported maximum there is nothing to pin the upgrade to;
    # passing None to `parse` would raise a TypeError.
    if max_cli_version is None:
        upgrade_target = "wandb"
    else:
        upgrade_target = f"wandb=={max_cli_version}"

    installed = parse(wandb.__version__)
    if installed < parse(min_cli_version):
        fail_string = f"wandb version out of date, please run pip install --upgrade {upgrade_target}"
    elif max_cli_version is not None and installed > parse(max_cli_version):
        fail_string = (
            "wandb version is not supported by your local installation. This could "
            "cause some issues. If you're having problems try: please run `pip "
            f"install --upgrade {upgrade_target}`"
        )
        warning = True

    print_results(fail_string, warning)
def retry_fn(
    fn: Callable,
    *,
    min_retries: Optional[int] = None,
    max_time: Optional[float] = None,
    delay: float = 1.0,
) -> Any:
    """Call `fn` until it succeeds or the retry budget is used up.

    Attempts continue while fewer than `min_retries` calls have been made
    or less than `max_time` seconds have elapsed. Elapsed time is measured
    with a monotonic clock so system clock adjustments cannot shorten or
    extend the retry window.

    Args:
        fn: Zero-argument callable to invoke.
        min_retries: Minimum number of attempts; defaults to MIN_RETRYS.
        max_time: Time budget in seconds; defaults to GET_RUN_MAX_TIME.
        delay: Seconds to sleep after a failed attempt.

    Returns:
        The first successful result of `fn()`, or None if every attempt
        raised. Exceptions from `fn` are swallowed, so callers must treat
        None as failure.
    """
    # Module constants are resolved at call time so explicit arguments can
    # replace them.
    attempts_floor = MIN_RETRYS if min_retries is None else min_retries
    time_budget = GET_RUN_MAX_TIME if max_time is None else max_time

    started_at = time.monotonic()
    attempts = 0
    while attempts < attempts_floor or time.monotonic() - started_at < time_budget:
        attempts += 1
        try:
            return fn()
        except Exception:
            time.sleep(delay)
    return None
+# if this is done right we might make sure this is pickle-able +# we might be able to do this on other objects like Run? +class Config: + """Config object. + + Config objects are intended to hold all of the hyperparameters associated + with a wandb run and are saved with the run object when `wandb.init` is + called. + + We recommend setting the config once when initializing your run by passing + the `config` parameter to `init`: + + ``` + wandb.init(config=my_config_dict) + ``` + + You can create a file called `config-defaults.yaml`, and it will + automatically be loaded as each run's config. You can also pass the name + of the file as the `config` parameter to `init`: + + ``` + wandb.init(config="my_config.yaml") + ``` + + See https://docs.wandb.com/guides/track/config#file-based-configs. + + Examples: + Basic usage + ``` + with wandb.init(config={"epochs": 4}) as run: + for x in range(run.config.epochs): + # train + ``` + + Nested values + ``` + with wandb.init(config={"train": {"epochs": 4}}) as run: + for x in range(run.config["train"]["epochs"]): + # train + ``` + + Using absl flags + ``` + flags.DEFINE_string("model", None, "model to run") # name, default, help + with wandb.init() as run: + run.config.update(flags.FLAGS) # adds all absl flags to config + ``` + + Argparse flags + ```python + with wandb.init(config={"epochs": 4}) as run: + parser = argparse.ArgumentParser() + parser.add_argument( + "-b", + "--batch-size", + type=int, + default=8, + metavar="N", + help="input batch size for training (default: 8)", + ) + args = parser.parse_args() + run.config.update(args) + ``` + + Using TensorFlow flags (deprecated in tensorflow v2) + ```python + flags = tf.app.flags + flags.DEFINE_string("data_dir", "/tmp/data") + flags.DEFINE_integer("batch_size", 128, "Batch size.") + + with wandb.init() as run: + run.config.update(flags.FLAGS) + ``` + """ + + def __init__(self): + object.__setattr__(self, "_items", dict()) + object.__setattr__(self, "_locked", dict()) + 
object.__setattr__(self, "_users", dict()) + object.__setattr__(self, "_users_inv", dict()) + object.__setattr__(self, "_users_cnt", 0) + object.__setattr__(self, "_callback", None) + object.__setattr__(self, "_settings", None) + object.__setattr__(self, "_artifact_callback", None) + + self._load_defaults() + + def _set_callback(self, cb): + object.__setattr__(self, "_callback", cb) + + def _set_artifact_callback(self, cb): + object.__setattr__(self, "_artifact_callback", cb) + + def _set_settings(self, settings): + object.__setattr__(self, "_settings", settings) + + def __repr__(self): + return str(dict(self)) + + def keys(self): + return [k for k in self._items.keys() if not k.startswith("_")] + + def _as_dict(self): + return self._items + + def as_dict(self): + # TODO: add telemetry, deprecate, then remove + return dict(self) + + def __getitem__(self, key): + return self._items[key] + + def __iter__(self): + return iter(self._items) + + def _check_locked(self, key, ignore_locked=False) -> bool: + locked = self._locked.get(key) + if locked is not None: + locked_user = self._users_inv[locked] + if not ignore_locked: + wandb.termwarn( + f"Config item '{key}' was locked by '{locked_user}' (ignored update)." 
+ ) + return True + return False + + def __setitem__(self, key, val): + if self._check_locked(key): + return + with wandb.sdk.lib.telemetry.context() as tel: + tel.feature.set_config_item = True + self._raise_value_error_on_nested_artifact(val, nested=True) + key, val = self._sanitize(key, val) + self._items[key] = val + logger.info("config set %s = %s - %s", key, val, self._callback) + if self._callback: + self._callback(key=key, val=val) + + def items(self): + return [(k, v) for k, v in self._items.items() if not k.startswith("_")] + + __setattr__ = __setitem__ + + def __getattr__(self, key): + try: + return self.__getitem__(key) + except KeyError as ke: + raise AttributeError( + f"{self.__class__!r} object has no attribute {key!r}" + ) from ke + + def __contains__(self, key): + return key in self._items + + def _update(self, d, allow_val_change=None, ignore_locked=None): + parsed_dict = wandb_helper.parse_config(d) + locked_keys = set() + for key in list(parsed_dict): + if self._check_locked(key, ignore_locked=ignore_locked): + locked_keys.add(key) + sanitized = self._sanitize_dict( + parsed_dict, allow_val_change, ignore_keys=locked_keys + ) + self._items.update(sanitized) + return sanitized + + def update(self, d, allow_val_change=None): + sanitized = self._update(d, allow_val_change) + if self._callback: + self._callback(data=sanitized) + + def get(self, *args): + return self._items.get(*args) + + def persist(self): + """Call the callback if it's set.""" + if self._callback: + self._callback(data=self._as_dict()) + + def setdefaults(self, d): + d = wandb_helper.parse_config(d) + # strip out keys already configured + d = {k: v for k, v in d.items() if k not in self._items} + d = self._sanitize_dict(d) + self._items.update(d) + if self._callback: + self._callback(data=d) + + def _get_user_id(self, user) -> int: + if user not in self._users: + self._users[user] = self._users_cnt + self._users_inv[self._users_cnt] = user + object.__setattr__(self, "_users_cnt", 
self._users_cnt + 1) + + return self._users[user] + + def update_locked(self, d, user=None, _allow_val_change=None): + """Shallow-update config with `d` and lock config updates on d's keys.""" + num = self._get_user_id(user) + + for k, v in d.items(): + k, v = self._sanitize(k, v, allow_val_change=_allow_val_change) + self._locked[k] = num + self._items[k] = v + + if self._callback: + self._callback(data=d) + + def merge_locked(self, d, user=None, _allow_val_change=None): + """Recursively merge-update config with `d` and lock config updates on d's keys.""" + num = self._get_user_id(user) + callback_d = {} + + for k, v in d.items(): + k, v = self._sanitize(k, v, allow_val_change=_allow_val_change) + self._locked[k] = num + + if ( + k in self._items + and isinstance(self._items[k], dict) + and isinstance(v, dict) + ): + self._items[k] = config_util.merge_dicts(self._items[k], v) + else: + self._items[k] = v + + callback_d[k] = self._items[k] + + if self._callback: + self._callback(data=callback_d) + + def _load_defaults(self): + conf_dict = config_util.dict_from_config_file("config-defaults.yaml") + if conf_dict is not None: + self.update(conf_dict) + + def _sanitize_dict( + self, + config_dict, + allow_val_change=None, + ignore_keys: Optional[set] = None, + ): + sanitized = {} + self._raise_value_error_on_nested_artifact(config_dict) + for k, v in config_dict.items(): + if ignore_keys and k in ignore_keys: + continue + k, v = self._sanitize(k, v, allow_val_change) + sanitized[k] = v + return sanitized + + def _sanitize(self, key, val, allow_val_change=None): + # TODO: enable WBValues in the config in the future + # refuse all WBValues which is all Media and Histograms + if isinstance(val, wandb.sdk.data_types.base_types.wb_value.WBValue): + raise TypeError("WBValue objects cannot be added to the run config") + # Let jupyter change config freely by default + if self._settings and self._settings._jupyter and allow_val_change is None: + allow_val_change = True + # We 
class ConfigStatic:
    """Read-only view over a copy of a config mapping.

    The values are copied into the instance's ``__dict__`` at construction,
    so they can be read as attributes or by key. Any attempt to assign an
    attribute or an item raises AttributeError.
    """

    def __init__(self, config):
        # Bypass our own __setattr__, which always raises.
        snapshot = dict(config)
        object.__setattr__(self, "__dict__", snapshot)

    def __setattr__(self, name, value):
        raise AttributeError("Error: run.config_static is a readonly object")

    def __setitem__(self, key, val):
        raise AttributeError("Error: run.config_static is a readonly object")

    def keys(self):
        return vars(self).keys()

    def __getitem__(self, key):
        return vars(self)[key]

    def __str__(self):
        return str(vars(self))
+ if exclude and include: + raise UsageError("Expected at most only one of exclude or include") + if isinstance(params, str): + params = config_util.dict_from_config_file(params, must_exist=True) + params = _to_dict(params) + if include: + params = {key: value for key, value in params.items() if key in include} + if exclude: + params = {key: value for key, value in params.items() if key not in exclude} + return params + + +def _to_dict(params): + if isinstance(params, dict): + return params + + # Handle some cases where params is not a dictionary + # by trying to convert it into a dictionary + meta = inspect.getmodule(params) + if meta: + is_tf_flags_module = ( + isinstance(params, types.ModuleType) + and meta.__name__ == "tensorflow.python.platform.flags" + ) + if is_tf_flags_module or meta.__name__ == "absl.flags": + params = params.FLAGS + meta = inspect.getmodule(params) + + # newer tensorflow flags (post 1.4) uses absl.flags + if meta and meta.__name__ == "absl.flags._flagvalues": + params = {name: params[name].value for name in dir(params)} + elif not hasattr(params, "__dict__"): + raise TypeError("config must be a dict or have a __dict__ attribute.") + elif "__flags" in vars(params): + # for older tensorflow flags (pre 1.4) + if not "__parsed" not in vars(params): + params._parse_flags() + params = vars(params)["__flags"] + else: + # params is a Namespace object (argparse) + # or something else + params = vars(params) + + # assume argparse Namespace + return params diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_init.py b/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_init.py new file mode 100644 index 0000000000000000000000000000000000000000..e15da40a3f24fdda6fb876e9117411c888af581c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_init.py @@ -0,0 +1,1595 @@ +"""Defines wandb.init() and associated classes and methods. + +`wandb.init()` indicates the beginning of a new run. 
In an ML training pipeline, +you could add `wandb.init()` to the beginning of your training script as well as +your evaluation script, and each step would be tracked as a run in W&B. + +For more on using `wandb.init()`, including code snippets, check out our +[guide and FAQs](https://docs.wandb.ai/guides/track/launch). +""" + +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import json +import logging +import os +import pathlib +import platform +import sys +import tempfile +import time +from typing import TYPE_CHECKING, Iterable, Iterator, Sequence + +from typing_extensions import Any, Literal, Protocol, Self + +import wandb +import wandb.env +from wandb import env, trigger +from wandb.analytics import get_sentry +from wandb.errors import CommError, Error, UsageError +from wandb.errors.links import url_registry +from wandb.errors.util import ProtobufErrorHandler +from wandb.integration import sagemaker, weave +from wandb.proto.wandb_telemetry_pb2 import Deprecated +from wandb.sdk.lib import ipython as wb_ipython +from wandb.sdk.lib import progress, runid, wb_logging +from wandb.sdk.lib.paths import StrPath +from wandb.util import _is_artifact_representation + +from . 
def _handle_launch_config(settings: Settings) -> dict[str, Any]:
    """Load the run config provided through a W&B Launch context.

    Nothing is loaded unless ``settings.launch`` is truthy. Otherwise the
    first available source wins, checked in this order:

    1. the ``WANDB_CONFIG`` environment variable, parsed as JSON;
    2. the ``overrides.run_config`` entry of the JSON file at
       ``settings.launch_config_path``, if that file exists;
    3. the concatenation of ``WANDB_CONFIG_0``, ``WANDB_CONFIG_1``, ...
       (consecutive indices until the first missing one), parsed as JSON.

    Args:
        settings: The run's settings. Only ``launch`` and
            ``launch_config_path`` are read.

    Returns:
        The launch run config. An empty dict is returned when Launch is
        not active, when no source is present, when the launch file has no
        ``run_config`` override, or when an environment value is malformed
        JSON (a warning is printed in that case).
    """
    launch_run_config: dict[str, Any] = {}
    if not settings.launch:
        return launch_run_config

    env_config = os.environ.get("WANDB_CONFIG")
    if env_config is not None:
        try:
            launch_run_config = json.loads(env_config)
        except (ValueError, SyntaxError):
            wandb.termwarn("Malformed WANDB_CONFIG, using original config")
    elif settings.launch_config_path and os.path.exists(settings.launch_config_path):
        with open(settings.launch_config_path) as fp:
            launch_config = json.loads(fp.read())
        # "overrides" and "run_config" may be missing or null in the file;
        # fall back to an empty dict so the declared return type holds.
        overrides = launch_config.get("overrides") or {}
        launch_run_config = overrides.get("run_config") or {}
    else:
        # The config may be spread over numbered variables; collect them in
        # index order and stop at the first index that is not set.
        chunks = []
        while (chunk := os.environ.get(f"WANDB_CONFIG_{len(chunks)}")) is not None:
            chunks.append(chunk)
        if chunks:
            try:
                launch_run_config = json.loads("".join(chunks))
            except (ValueError, SyntaxError):
                wandb.termwarn("Malformed WANDB_CONFIG, using original config")

    return launch_run_config
def _concat_printer_callbacks(
    cbs: Iterable[_PrinterCallback],
) -> _PrinterCallback:
    """Combine several printer callbacks into a single callback.

    Args:
        cbs: The callbacks to combine.

    Returns:
        A callback that, given a printer, invokes every callback in `cbs`
        with that printer, in iteration order.
    """

    def run_all(run_printer: printer.Printer) -> None:
        # `cbs` is iterated each time the combined callback is invoked;
        # it is not copied when the combined callback is created.
        for callback in cbs:
            callback(run_printer)

    return run_all
def clear_run_path_if_sweep_or_launch(
    self,
    init_settings: Settings,
) -> _PrinterCallback:
    """Drop project/entity/run_id from `init_settings` in a Sweep or Launch.

    When the singleton settings have a `sweep_id` or have `launch` set,
    each of `project`, `entity` and `run_id` that is not None on
    `init_settings` is reset to None and a warning is recorded for it.
    Outside of those contexts `init_settings` is left untouched.

    Args:
        init_settings: Settings specified in the call to `wandb.init()`.
            Modified in place.

    Returns:
        A callback that displays the recorded warnings.
    """
    global_settings = self._wl.settings
    if global_settings.sweep_id:
        context = "when running a sweep"
    elif global_settings.launch:
        context = "when running from a wandb launch context"
    else:
        return _noop_printer_callback()

    # Messages are built now, while the overridden values are still known.
    messages = []
    for key in ("project", "entity", "run_id"):
        value = getattr(init_settings, key)
        if value is None:
            continue
        messages.append(f"Ignoring {key} {value!r} {context}.")
        setattr(init_settings, key, None)

    def show_messages(run_printer: printer.Printer) -> None:
        for message in messages:
            run_printer.display(message, level="warn")

    return show_messages
+ if not settings.sagemaker_disable and sagemaker.is_using_sagemaker(): + if sagemaker.set_run_id(settings): + self._logger.info("set run ID and group based on SageMaker") + self._telemetry.feature.sagemaker = True + + # get status of code saving before applying user settings + save_code_pre_user_settings = settings.save_code + if not settings._offline and not settings._noop: + user_settings = self._wl._load_user_settings() + if user_settings is not None: + settings.update_from_dict(user_settings) + + # ensure that user settings don't set saving to true + # if user explicitly set these to false in UI + if save_code_pre_user_settings is False: + settings.save_code = False + + # TODO: remove this once we refactor the client. This is a temporary + # fix to make sure that we use the same project name for wandb-core. + # The reason this is not going through the settings object is to + # avoid failure cases in other parts of the code that will be + # removed with the switch to wandb-core. + if settings.project is None: + settings.project = wandb.util.auto_project_name(settings.program) + + settings.x_start_time = time.time() + + # In shared mode, generate a unique label if not provided. + # The label is used to distinguish between system metrics and console logs + # from different writers to the same run. + if settings._shared and not settings.x_label: + # TODO: If executed in a known distributed environment (e.g. Ray or SLURM), + # use the env vars to generate a label (e.g. SLURM_JOB_ID or RANK) + prefix = settings.host or "" + label = runid.generate_id() + settings.x_label = f"{prefix}-{label}" if prefix else label + + return settings, _concat_printer_callbacks(warning_callbacks) + + def _load_autoresume_run_id(self, resume_file: pathlib.Path) -> str | None: + """Returns the run_id stored in the auto-resume file, if any. + + Returns `None` if the file does not exist or is not in a valid format. + + Args: + resume_file: The file path to use for resume='auto' mode. 
def _save_autoresume_run_id(
    self,
    *,
    resume_file: pathlib.Path,
    run_id: str,
) -> None:
    """Write the run ID to the auto-resume file.

    The file is written as a JSON object with a single ``"run_id"`` key.
    Any existing file at that path is overwritten.

    Args:
        resume_file: The file path to use for resume='auto' mode.
        run_id: The run ID to store.
    """
    # Create every missing ancestor, not just the immediate parent, so a
    # resume file under a not-yet-created directory tree can be written.
    resume_file.parent.mkdir(parents=True, exist_ok=True)
    with resume_file.open("w") as f:
        json.dump({"run_id": run_id}, f)
def set_sync_dir_suffix(self, settings: Settings) -> None:
    """Pick an `x_sync_dir_suffix` so that `sync_dir` does not exist yet.

    While the path in `settings.sync_dir` exists, the suffix is set to
    "1", "2", ... in turn and the path is checked again. If the initial
    path does not exist, the settings are left unchanged.

    This does not protect against several processes creating runs at the
    same time.

    Args:
        settings: Fully initialized settings other than the
            x_sync_dir_suffix setting which will be modified.
    """
    attempt = 0
    # `settings.sync_dir` is re-read on every pass, after the suffix was
    # updated by the previous pass.
    while pathlib.Path(settings.sync_dir).exists():
        attempt += 1
        settings.x_sync_dir_suffix = str(attempt)
Use" + " `config=wandb.helper.parse_config(config_object," + " include=('key',))` instead.", + ) + ) + config = parse_config( + config or dict(), + include=config_include_keys, + exclude=config_exclude_keys, + ) + + result = _ConfigParts( + base_no_artifacts=dict(), + sweep_no_artifacts=dict(), + launch_no_artifacts=dict(), + artifacts=dict(), + ) + + if not settings.sagemaker_disable and sagemaker.is_using_sagemaker(): + sagemaker_config = sagemaker.parse_sm_config() + self._split_artifacts_from_config( + sagemaker_config, + config_target=result.base_no_artifacts, + artifacts=result.artifacts, + ) + self._telemetry.feature.sagemaker = True + + if self._wl.config: + self._split_artifacts_from_config( + self._wl.config, + config_target=result.base_no_artifacts, + artifacts=result.artifacts, + ) + + if config and isinstance(config, dict): + self._split_artifacts_from_config( + config, + config_target=result.base_no_artifacts, + artifacts=result.artifacts, + ) + + if self._wl._sweep_config: + self._split_artifacts_from_config( + self._wl._sweep_config, + config_target=result.sweep_no_artifacts, + artifacts=result.artifacts, + ) + + if launch_config := _handle_launch_config(settings): + self._split_artifacts_from_config( + launch_config, + config_target=result.launch_no_artifacts, + artifacts=result.artifacts, + ) + + wandb_internal = result.base_no_artifacts.setdefault("_wandb", dict()) + + if settings.save_code and settings.program_relpath: + wandb_internal["code_path"] = paths.LogicalPath( + os.path.join("code", settings.program_relpath) + ) + if settings.fork_from is not None: + wandb_internal["branch_point"] = { + "run_id": settings.fork_from.run, + "step": settings.fork_from.value, + } + if settings.resume_from is not None: + wandb_internal["branch_point"] = { + "run_id": settings.resume_from.run, + "step": settings.resume_from.value, + } + + return result + + def teardown(self) -> None: + # TODO: currently this is only called on failed wandb.init attempts + # 
def _safe_symlink(
    self, base: str, target: str, name: str, delete: bool = False
) -> None:
    """Best-effort creation of the symlink ``base/name`` pointing at `target`.

    The link stores `target` as a path relative to `base`. It is first
    created under a temporary name that includes the process ID and then
    renamed to its final name. Every `OSError` along the way is swallowed,
    so a failure leaves the link missing or unchanged instead of raising.
    Does nothing when the platform's `os` module has no `symlink`.

    Args:
        base: Directory in which the link is created.
        target: Path the link should point to.
        name: File name of the link inside `base`.
        delete: If True, try to remove an existing ``base/name`` first.
    """
    # TODO(jhr): do this with relpaths, but i can't figure it out on no sleep
    if not hasattr(os, "symlink"):
        return

    link_path = os.path.join(base, name)
    tmp_name = os.path.join(base, f"{name}.{os.getpid()}")

    if delete:
        try:
            os.remove(link_path)
        except OSError:
            pass
    target = os.path.relpath(target, base)

    try:
        os.symlink(target, tmp_name)
    except OSError:
        # Nothing was created, so there is nothing to clean up.
        return

    try:
        os.rename(tmp_name, link_path)
    except OSError:
        # The rename failed after the temporary link was created; remove
        # it so stray "<name>.<pid>" links do not pile up in `base`.
        try:
            os.remove(tmp_name)
        except OSError:
            pass
def _jupyter_teardown(self) -> None:
    """Undo the Jupyter hooks and save notebook state.

    Saves the notebook history when a run is set, logs the run's code
    when the notebook reports that the ipynb was saved, unregisters the
    pre/post cell hooks, and restores the original display publisher.
    Requires `self.notebook` to be set.
    """
    notebook = self.notebook
    assert notebook
    shell = notebook.shell

    if self.run:
        notebook.save_history(self.run)

    if notebook.save_ipynb():
        assert self.run is not None
        logged = self.run.log_code(root=None)
        self._logger.info("saved code and history: %s", logged)
    self._logger.info("cleaning up jupyter logic")

    hooks = (
        ("pre_run_cell", self._pre_run_cell_hook),
        ("post_run_cell", self._post_run_cell_hook),
    )
    for event_name, hook in hooks:
        shell.events.unregister(event_name, hook)

    # Put back the publish function saved under `_orig_publish` and drop
    # the saved reference.
    display_pub = shell.display_pub
    display_pub.publish = display_pub._orig_publish
    del display_pub._orig_publish
+ + This is a context manager that closes and unregisters the log handler + in case of an uncaught exception, so that future logged messages do not + modify this run's log file. + """ + filesystem.mkdir_exists_ok(os.path.dirname(settings.log_user)) + filesystem.mkdir_exists_ok(os.path.dirname(settings.log_internal)) + filesystem.mkdir_exists_ok(os.path.dirname(settings.sync_file)) + filesystem.mkdir_exists_ok(settings.files_dir) + filesystem.mkdir_exists_ok(settings._tmp_code_dir) + + if settings.symlink: + self._safe_symlink( + os.path.dirname(settings.sync_symlink_latest), + os.path.dirname(settings.sync_file), + os.path.basename(settings.sync_symlink_latest), + delete=True, + ) + self._safe_symlink( + os.path.dirname(settings.log_symlink_user), + settings.log_user, + os.path.basename(settings.log_symlink_user), + delete=True, + ) + self._safe_symlink( + os.path.dirname(settings.log_symlink_internal), + settings.log_internal, + os.path.basename(settings.log_symlink_internal), + delete=True, + ) + + assert settings.run_id + handler = wb_logging.add_file_handler( + settings.run_id, + pathlib.Path(settings.log_user), + ) + + if env.is_debug(): + handler.setLevel(logging.DEBUG) + + disposed = False + + def dispose_handler() -> None: + nonlocal disposed + + if not disposed: + disposed = True + logging.getLogger("wandb").removeHandler(handler) + handler.close() + + try: + self._teardown_hooks.append( + TeardownHook( + call=dispose_handler, + stage=TeardownStage.LATE, + ) + ) + + self._wl._early_logger_flush(logging.getLogger("wandb")) + self._logger.info(f"Logging user logs to {settings.log_user}") + self._logger.info(f"Logging internal logs to {settings.log_internal}") + + yield + except Exception: + dispose_handler() + raise + + def make_disabled_run(self, config: _ConfigParts) -> Run: + """Returns a Run-like object where all methods are no-ops. 
+ + This method is used when the `mode` setting is set to "disabled", such as + by wandb.init(mode="disabled") or by setting the WANDB_MODE environment + variable to "disabled". + + It creates a Run object that mimics the behavior of a normal Run but doesn't + communicate with the W&B servers. + + The returned Run object has all expected attributes and methods, but they + are no-op versions that don't perform any actual logging or communication. + """ + run_id = runid.generate_id() + drun = Run( + settings=Settings( + mode="disabled", + root_dir=tempfile.gettempdir(), + run_id=run_id, + run_tags=tuple(), + run_notes=None, + run_group=None, + run_name=f"dummy-{run_id}", + project="dummy", + entity="dummy", + ) + ) + # config, summary, and metadata objects + drun._config = wandb.sdk.wandb_config.Config() + drun._config.update(config.sweep_no_artifacts) + drun._config.update(config.base_no_artifacts) + drun.summary = SummaryDisabled() # type: ignore + + # methods + drun.log = lambda data, *_, **__: drun.summary.update(data) # type: ignore[method-assign] + drun.finish = lambda *_, **__: module.unset_globals() # type: ignore[method-assign] + drun.join = drun.finish # type: ignore[method-assign] + drun.define_metric = lambda *_, **__: wandb.sdk.wandb_metric.Metric("dummy") # type: ignore[method-assign] + drun.save = lambda *_, **__: False # type: ignore[method-assign] + for symbol in ( + "alert", + "finish_artifact", + "get_project_url", + "get_sweep_url", + "get_url", + "link_artifact", + "link_model", + "use_artifact", + "log_code", + "log_model", + "use_model", + "mark_preempting", + "restore", + "status", + "watch", + "unwatch", + "upsert_artifact", + "_finish", + ): + setattr(drun, symbol, lambda *_, **__: None) # type: ignore + + # set properties to None + for attr in ("url", "project_url", "sweep_url"): + setattr(type(drun), attr, property(lambda _: None)) + + class _ChainableNoOp: + """An object that allows chaining arbitrary attributes and method calls.""" + + 
def __getattr__(self, _: str) -> Self: + return self + + def __call__(self, *_: Any, **__: Any) -> Self: + return self + + class _ChainableNoOpField: + # This is used to chain arbitrary attributes and method calls. + # For example, `run.log_artifact().state` will work in disabled mode. + def __init__(self) -> None: + self._value = None + + def __set__(self, instance: Any, value: Any) -> None: + self._value = value + + def __get__(self, instance: Any, owner: type) -> Any: + return _ChainableNoOp() if (self._value is None) else self._value + + def __call__(self, *args: Any, **kwargs: Any) -> _ChainableNoOp: + return _ChainableNoOp() + + drun.log_artifact = _ChainableNoOpField() # type: ignore + # attributes + drun._start_time = time.time() + drun._starting_step = 0 + drun._step = 0 + drun._attach_id = None + drun._backend = None + + # set the disabled run as the global run + module.set_global( + run=drun, + config=drun.config, + log=drun.log, + summary=drun.summary, + save=drun.save, + use_artifact=drun.use_artifact, + log_artifact=drun.log_artifact, + define_metric=drun.define_metric, + alert=drun.alert, + watch=drun.watch, + unwatch=drun.unwatch, + ) + return drun + + def init( # noqa: C901 + self, + settings: Settings, + config: _ConfigParts, + run_printer: printer.Printer, + ) -> Run: + self._logger.info("calling init triggers") + trigger.call("on_init") + + assert self._wl is not None + + self._logger.info( + f"wandb.init called with sweep_config: {config.sweep_no_artifacts}" + f"\nconfig: {config.base_no_artifacts}" + ) + + if previous_run := self._wl.most_recent_active_run: + if ( + settings.reinit in (True, "finish_previous") + # calling wandb.init() in notebooks finishes previous runs + # by default for user convenience. + or (settings.reinit == "default" and wb_ipython.in_notebook()) + ): + run_printer.display( + "Finishing previous runs because reinit is set" + f" to {settings.reinit!r}." 
+ ) + self._wl.finish_all_active_runs() + + elif settings.reinit == "create_new": + self._logger.info( + "wandb.init() called while a run is active," + " and reinit is set to 'create_new', so continuing" + ) + + elif settings.resume == "must": + raise wandb.Error( + "Cannot resume a run while another run is active." + " You must either finish it using run.finish()," + " or use reinit='create_new' when calling wandb.init()." + ) + + else: + run_printer.display( + "wandb.init() called while a run is active and reinit is" + f" set to {settings.reinit!r}, so returning the previous" + " run." + ) + + with telemetry.context(run=previous_run) as tel: + tel.feature.init_return_run = True + + return previous_run + + self._logger.info("starting backend") + + service = self._wl.ensure_service() + self._logger.info("sending inform_init request") + service.inform_init( + settings=settings.to_proto(), + run_id=settings.run_id, # type: ignore + ) + + backend = Backend(settings=settings, service=service) + backend.ensure_launched() + self._logger.info("backend started and connected") + + run = Run( + config=config.base_no_artifacts, + settings=settings, + sweep_config=config.sweep_no_artifacts, + launch_config=config.launch_no_artifacts, + ) + + # Populate initial telemetry + with telemetry.context(run=run, obj=self._telemetry) as tel: + tel.cli_version = wandb.__version__ + tel.python_version = platform.python_version() + tel.platform = f"{platform.system()}-{platform.machine()}".lower() + hf_version = _huggingface_version() + if hf_version: + tel.huggingface_version = hf_version + if settings._jupyter: + tel.env.jupyter = True + if settings._ipython: + tel.env.ipython = True + if settings._colab: + tel.env.colab = True + if settings._kaggle: + tel.env.kaggle = True + if settings._windows: + tel.env.windows = True + + if settings.launch: + tel.feature.launch = True + + for module_name in telemetry.list_telemetry_imports(only_imported=True): + setattr(tel.imports_init, 
module_name, True) + + if os.environ.get("PEX"): + tel.env.pex = True + + if settings._aws_lambda: + tel.env.aws_lambda = True + + if settings.x_flow_control_disabled: + tel.feature.flow_control_disabled = True + if settings.x_flow_control_custom: + tel.feature.flow_control_custom = True + if settings._shared: + wandb.termwarn( + "The `shared` mode feature is experimental and may change. " + "Please contact support@wandb.com for guidance and to report any issues." + ) + tel.feature.shared_mode = True + + if settings.x_label: + tel.feature.user_provided_label = True + + if wandb.env.dcgm_profiling_enabled(): + tel.feature.dcgm_profiling_enabled = True + + if not settings.label_disable: + if self.notebook: + run._label_probe_notebook(self.notebook) + else: + run._label_probe_main() + + for deprecated_feature, msg in self.deprecated_features_used: + warn_and_record_deprecation( + feature=deprecated_feature, + message=msg, + run=run, + ) + + self._logger.info("updated telemetry") + + run._set_library(self._wl) + run._set_backend(backend) + run._set_teardown_hooks(self._teardown_hooks) + + assert backend.interface + backend.interface.publish_header() + + # Using GitRepo() blocks & can be slow, depending on user's current git setup. + # We don't want to block run initialization/start request, so populate run's git + # info beforehand. + if not (settings.disable_git or settings.x_disable_machine_info): + run._populate_git_info() + + if settings._offline and settings.resume: + wandb.termwarn( + "`resume` will be ignored since W&B syncing is set to `offline`. " + f"Starting a new run with run id {run.id}." 
+ ) + error: wandb.Error | None = None + + timeout = settings.init_timeout + + self._logger.info( + f"communicating run to backend with {timeout} second timeout", + ) + + run_init_handle = backend.interface.deliver_run(run) + + try: + with progress.progress_printer( + run_printer, + default_text="Waiting for wandb.init()...", + ) as progress_printer: + result = wait_with_progress( + run_init_handle, + timeout=timeout, + display_progress=functools.partial( + progress.loop_printing_operation_stats, + progress_printer, + backend.interface, + ), + ) + + except TimeoutError: + # This may either be an issue with the W&B server (a CommError) + # or a bug in the SDK (an Error). We cannot distinguish between + # the two causes here. + raise CommError( + f"Run initialization has timed out after {timeout} sec." + + " Please try increasing the timeout with the `init_timeout`" + + " setting: `wandb.init(settings=wandb.Settings(init_timeout=120))`." + ) from None + + assert result.run_result + + if error := ProtobufErrorHandler.to_exception(result.run_result.error): + raise error + + if not result.run_result.HasField("run"): + raise Error("Assertion failed: run_result is missing the run field") + + if result.run_result.run.resumed: + self._logger.info("run resumed") + with telemetry.context(run=run) as tel: + tel.feature.resumed = result.run_result.run.resumed + run._set_run_obj(result.run_result.run) + + self._logger.info("starting run threads in backend") + + assert backend.interface + + run_start_handle = backend.interface.deliver_run_start(run) + try: + # TODO: add progress to let user know we are doing something + run_start_handle.wait_or(timeout=30) + except TimeoutError: + pass + + backend.interface.publish_probe_system_info() + + assert self._wl is not None + self.run = run + + run._handle_launch_artifact_overrides() + if ( + settings.launch + and settings.launch_config_path + and os.path.exists(settings.launch_config_path) + ): + run.save(settings.launch_config_path) + 
# put artifacts in run config here + # since doing so earlier will cause an error + # as the run is not upserted + for k, v in config.artifacts.items(): + run.config.update({k: v}, allow_val_change=True) + job_artifact = run._launch_artifact_mapping.get( + wandb.util.LAUNCH_JOB_ARTIFACT_SLOT_NAME + ) + if job_artifact: + run.use_artifact(job_artifact) + + self.backend = backend + + if settings.reinit != "create_new": + _set_global_run(run) + + run._on_start() + self._logger.info("run started, returning control to user process") + return run + + +def _attach( + attach_id: str | None = None, + run_id: str | None = None, + *, + run: Run | None = None, +) -> Run | None: + """Attach to a run currently executing in another process/thread. + + Args: + attach_id: (str, optional) The id of the run or an attach identifier + that maps to a run. + run_id: (str, optional) The id of the run to attach to. + run: (Run, optional) The run instance to attach + """ + attach_id = attach_id or run_id + if not ((attach_id is None) ^ (run is None)): + raise UsageError("Either (`attach_id` or `run_id`) or `run` must be specified") + + attach_id = attach_id or (run._attach_id if run else None) + + if attach_id is None: + raise UsageError( + "Either `attach_id` or `run_id` must be specified or `run` must have `_attach_id`" + ) + + _wl = wandb_setup.singleton() + logger = _wl._get_logger() + + service = _wl.ensure_service() + + try: + attach_settings = service.inform_attach(attach_id=attach_id) + except Exception as e: + raise UsageError(f"Unable to attach to run {attach_id}") from e + + settings = _wl.settings.model_copy() + settings.update_from_dict( + { + "run_id": attach_id, + "x_start_time": attach_settings.x_start_time.value, + "mode": attach_settings.mode.value, + } + ) + + # TODO: consolidate this codepath with wandb.init() + backend = Backend(settings=settings, service=service) + backend.ensure_launched() + logger.info("attach backend started and connected") + + if run is None: + run 
def try_create_root_dir(settings: Settings) -> None:
    """Make sure `settings.root_dir` exists and is usable, else use the temp dir.

    The directory is created (including parents) if it is missing. When it
    cannot be created, or exists but is not both readable and writable, a
    warning is emitted and `settings.root_dir` is rewritten to the system
    temporary directory.

    Args:
        settings: The run's settings. Only `root_dir` is read; it is
            overwritten with the system temp directory on fallback.

    Raises:
        ValueError: If a fallback is needed but the system temp directory is
            itself not readable and writable.
    """
    rw_access = os.W_OK | os.R_OK

    # Work out *why* the configured directory is unusable (if it is), so the
    # warning and the fallback can be handled in one place below.
    problem = None
    try:
        os.makedirs(settings.root_dir, exist_ok=True)
    except OSError:
        problem = f"Unable to create root directory {settings.root_dir}"
    else:
        if not os.access(settings.root_dir, rw_access):
            problem = f"Path {settings.root_dir} wasn't read/writable"

    if problem is None:
        # The configured directory is fine; nothing to change.
        return

    wandb.termwarn(problem, repeat=False)

    fallback_dir = tempfile.gettempdir()
    if not os.access(fallback_dir, rw_access):
        raise ValueError(
            f"System temp directory ({fallback_dir}) is not writable/readable, "
            "please set the `dir` argument in `wandb.init()` to a writable/readable directory."
        )

    settings.root_dir = fallback_dir
    wandb.termwarn(
        f"Falling back to temporary directory {fallback_dir}.",
        repeat=False,
    )
    os.makedirs(settings.root_dir, exist_ok=True)
sync_tensorboard: bool | None = None, + monitor_gym: bool | None = None, + settings: Settings | dict[str, Any] | None = None, + anonymous: DoNotSet = UNSET, +) -> Run: + r"""Start a new run to track and log to W&B. + + In an ML training pipeline, you could add `wandb.init()` to the beginning of + your training script as well as your evaluation script, and each piece would + be tracked as a run in W&B. + + `wandb.init()` spawns a new background process to log data to a run, and it + also syncs data to https://wandb.ai by default, so you can see your results + in real-time. When you're done logging data, call `wandb.Run.finish()` to end the run. + If you don't call `run.finish()`, the run will end when your script exits. + + Run IDs must not contain any of the following special characters `/ \ # ? % :` + + Args: + entity: The username or team name the runs are logged to. + The entity must already exist, so ensure you create your account + or team in the UI before starting to log runs. If not specified, the + run will default your default entity. To change the default entity, + go to your settings and update the + "Default location to create new projects" under "Default team". + project: The name of the project under which this run will be logged. + If not specified, we use a heuristic to infer the project name based + on the system, such as checking the git root or the current program + file. If we can't infer the project name, the project will default to + `"uncategorized"`. + dir: The absolute path to the directory where experiment logs and + metadata files are stored. If not specified, this defaults + to the `./wandb` directory. Note that this does not affect the + location where artifacts are stored when calling `download()`. + id: A unique identifier for this run, used for resuming. It must be unique + within the project and cannot be reused once a run is deleted. 
For + a short descriptive name, use the `name` field, + or for saving hyperparameters to compare across runs, use `config`. + name: A short display name for this run, which appears in the UI to help + you identify it. By default, we generate a random two-word name + allowing easy cross-reference runs from table to charts. Keeping these + run names brief enhances readability in chart legends and tables. For + saving hyperparameters, we recommend using the `config` field. + notes: A detailed description of the run, similar to a commit message in + Git. Use this argument to capture any context or details that may + help you recall the purpose or setup of this run in the future. + tags: A list of tags to label this run in the UI. Tags are helpful for + organizing runs or adding temporary identifiers like "baseline" or + "production." You can easily add, remove tags, or filter by tags in + the UI. + If resuming a run, the tags provided here will replace any existing + tags. To add tags to a resumed run without overwriting the current + tags, use `run.tags += ("new_tag",)` after calling `run = wandb.init()`. + config: Sets `wandb.config`, a dictionary-like object for storing input + parameters to your run, such as model hyperparameters or data + preprocessing settings. + The config appears in the UI in an overview page, allowing you to + group, filter, and sort runs based on these parameters. + Keys should not contain periods (`.`), and values should be + smaller than 10 MB. + If a dictionary, `argparse.Namespace`, or `absl.flags.FLAGS` is + provided, the key-value pairs will be loaded directly into + `wandb.config`. + If a string is provided, it is interpreted as a path to a YAML file, + from which configuration values will be loaded into `wandb.config`. + config_exclude_keys: A list of specific keys to exclude from `wandb.config`. + config_include_keys: A list of specific keys to include in `wandb.config`. 
+ allow_val_change: Controls whether config values can be modified after their + initial set. By default, an exception is raised if a config value is + overwritten. For tracking variables that change during training, such as + a learning rate, consider using `wandb.log()` instead. By default, this + is `False` in scripts and `True` in Notebook environments. + group: Specify a group name to organize individual runs as part of a larger + experiment. This is useful for cases like cross-validation or running + multiple jobs that train and evaluate a model on different test sets. + Grouping allows you to manage related runs collectively in the UI, + making it easy to toggle and review results as a unified experiment. + job_type: Specify the type of run, especially helpful when organizing runs + within a group as part of a larger experiment. For example, in a group, + you might label runs with job types such as "train" and "eval". + Defining job types enables you to easily filter and group similar runs + in the UI, facilitating direct comparisons. + mode: Specifies how run data is managed, with the following options: + - `"online"` (default): Enables live syncing with W&B when a network + connection is available, with real-time updates to visualizations. + - `"offline"`: Suitable for air-gapped or offline environments; data + is saved locally and can be synced later. Ensure the run folder + is preserved to enable future syncing. + - `"disabled"`: Disables all W&B functionality, making the run’s methods + no-ops. Typically used in testing to bypass W&B operations. + - `"shared"`: (This is an experimental feature). Allows multiple processes, + possibly on different machines, to simultaneously log to the same run. + In this approach you use a primary node and one or more worker nodes + to log data to the same run. Within the primary node you + initialize a run. For each worker node, initialize a run + using the run ID used by the primary node. 
+ force: Determines if a W&B login is required to run the script. If `True`, + the user must be logged in to W&B; otherwise, the script will not + proceed. If `False` (default), the script can proceed without a login, + switching to offline mode if the user is not logged in. + reinit: Shorthand for the "reinit" setting. Determines the behavior of + `wandb.init()` when a run is active. + resume: Controls the behavior when resuming a run with the specified `id`. + Available options are: + - `"allow"`: If a run with the specified `id` exists, it will resume + from the last step; otherwise, a new run will be created. + - `"never"`: If a run with the specified `id` exists, an error will + be raised. If no such run is found, a new run will be created. + - `"must"`: If a run with the specified `id` exists, it will resume + from the last step. If no run is found, an error will be raised. + - `"auto"`: Automatically resumes the previous run if it crashed on + this machine; otherwise, starts a new run. + - `True`: Deprecated. Use `"auto"` instead. + - `False`: Deprecated. Use the default behavior (leaving `resume` + unset) to always start a new run. + If `resume` is set, `fork_from` and `resume_from` cannot be + used. When `resume` is unset, the system will always start a new run. + resume_from: Specifies a moment in a previous run to resume a run from, + using the format `{run_id}?_step={step}`. This allows users to truncate + the history logged to a run at an intermediate step and resume logging + from that step. The target run must be in the same project. + If an `id` argument is also provided, the `resume_from` argument will + take precedence. + `resume`, `resume_from` and `fork_from` cannot be used together, only + one of them can be used at a time. + Note that this feature is in beta and may change in the future. + fork_from: Specifies a point in a previous run from which to fork a new + run, using the format `{id}?_step={step}`. 
This creates a new run that + resumes logging from the specified step in the target run’s history. + The target run must be part of the current project. + If an `id` argument is also provided, it must be different from the + `fork_from` argument, an error will be raised if they are the same. + `resume`, `resume_from` and `fork_from` cannot be used together, only + one of them can be used at a time. + Note that this feature is in beta and may change in the future. + save_code: Enables saving the main script or notebook to W&B, aiding in + experiment reproducibility and allowing code comparisons across runs in + the UI. By default, this is disabled, but you can change the default to + enable on your settings page. + tensorboard: Deprecated. Use `sync_tensorboard` instead. + sync_tensorboard: Enables automatic syncing of W&B logs from TensorBoard + or TensorBoardX, saving relevant event files for viewing in + the W&B UI. + monitor_gym: Enables automatic logging of videos of the environment when + using OpenAI Gym. + settings: Specifies a dictionary or `wandb.Settings` object with advanced + settings for the run. + + Returns: + A `Run` object. + + Raises: + Error: If some unknown or internal error happened during the run + initialization. + AuthenticationError: If the user failed to provide valid credentials. + CommError: If there was a problem communicating with the WandB server. + UsageError: If the user provided invalid arguments. + KeyboardInterrupt: If user interrupts the run. + + Examples: + `wandb.init()` returns a `Run` object. Use the run object to log data, + save artifacts, and manage the run lifecycle. 
+ + ```python + import wandb + + config = {"lr": 0.01, "batch_size": 32} + with wandb.init(config=config) as run: + # Log accuracy and loss to the run + acc = 0.95 # Example accuracy + loss = 0.05 # Example loss + run.log({"accuracy": acc, "loss": loss}) + ``` + """ + init_telemetry = telemetry.TelemetryRecord() + + init_settings = Settings() + if isinstance(settings, dict): + init_settings = Settings(**settings) + elif isinstance(settings, Settings): + init_settings = settings + + # Explicit function arguments take precedence over settings + if job_type is not None: + init_settings.run_job_type = job_type + if dir is not None: + init_settings.root_dir = dir # type: ignore + if project is not None: + init_settings.project = project + if entity is not None: + init_settings.entity = entity + if reinit is not None: + init_settings.reinit = reinit + if tags is not None: + init_settings.run_tags = tuple(tags) + if group is not None: + init_settings.run_group = group + if name is not None: + init_settings.run_name = name + if notes is not None: + init_settings.run_notes = notes + if anonymous is not UNSET: + init_settings.anonymous = anonymous + if mode is not None: + init_settings.mode = mode # type: ignore + if resume is not None: + init_settings.resume = resume # type: ignore + if force is not None: + init_settings.force = force + # TODO: deprecate "tensorboard" in favor of "sync_tensorboard" + if tensorboard is not None: + init_settings.sync_tensorboard = tensorboard + if sync_tensorboard is not None: + init_settings.sync_tensorboard = sync_tensorboard + if save_code is not None: + init_settings.save_code = save_code + if id is not None: + init_settings.run_id = id + if fork_from is not None: + init_settings.fork_from = fork_from # type: ignore + if resume_from is not None: + init_settings.resume_from = resume_from # type: ignore + + if config is not None: + init_telemetry.feature.set_init_config = True + + wl: wandb_setup._WandbSetup | None = None + + try: + wl = 
wandb_setup.singleton() + + wi = _WandbInit(wl, init_telemetry) + + wi.maybe_login(init_settings) + run_settings, show_warnings = wi.make_run_settings(init_settings) + + if isinstance(run_settings.reinit, bool): + wi.deprecated_features_used.append( + ( + Deprecated(run__reinit_bool=True), + "Using a boolean value for 'reinit' is deprecated." + " Use 'return_previous' or 'finish_previous' instead.", + ) + ) + + if run_settings.run_id is not None: + init_telemetry.feature.set_init_id = True + if run_settings.run_name is not None: + init_telemetry.feature.set_init_name = True + if run_settings.run_tags is not None: + init_telemetry.feature.set_init_tags = True + if run_settings._offline: + init_telemetry.feature.offline = True + if run_settings.fork_from is not None: + init_telemetry.feature.fork_mode = True + if run_settings.resume_from is not None: + init_telemetry.feature.rewind_mode = True + + wi.set_run_id(run_settings) + wi.set_sync_dir_suffix(run_settings) + run_printer = printer.new_printer(run_settings) + show_warnings(run_printer) + + with contextlib.ExitStack() as exit_stack: + exit_stack.enter_context(wb_logging.log_to_run(run_settings.run_id)) + + run_config = wi.make_run_config( + settings=run_settings, + config=config, + config_exclude_keys=config_exclude_keys, + config_include_keys=config_include_keys, + ) + + if run_settings._noop: + return wi.make_disabled_run(run_config) + + try_create_root_dir(run_settings) + exit_stack.enter_context(wi.setup_run_log_directory(run_settings)) + + if run_settings._jupyter: + wi.monkeypatch_ipython(run_settings) + + if monitor_gym: + _monkeypatch_openai_gym() + + if wandb.patched["tensorboard"]: + # NOTE: The user may have called the patch function directly. 
+ init_telemetry.feature.tensorboard_patch = True + if run_settings.sync_tensorboard: + _monkeypatch_tensorboard() + init_telemetry.feature.tensorboard_sync = True + + if run_settings.x_server_side_derived_summary: + init_telemetry.feature.server_side_derived_summary = True + + run = wi.init(run_settings, run_config, run_printer) + + # Set up automatic Weave integration if Weave is installed + weave.setup(run_settings.entity, run_settings.project) + + return run + + except KeyboardInterrupt as e: + if wl: + wl._get_logger().warning("interrupted", exc_info=e) + + raise + + except Exception as e: + if wl: + wl._get_logger().exception("error in wandb.init()", exc_info=e) + + get_sentry().reraise(e) diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_login.py b/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_login.py new file mode 100644 index 0000000000000000000000000000000000000000..f5bde7ef4da976a2db07b784a013ebe5136873f3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_login.py @@ -0,0 +1,342 @@ +"""Log in to Weights & Biases. + +This authenticates your machine to log data to your account. +""" + +from __future__ import annotations + +import click + +import wandb +from wandb.errors import AuthenticationError, term +from wandb.sdk import wandb_setup +from wandb.sdk.lib import settings_file, wbauth +from wandb.sdk.lib.deprecation import UNSET, DoNotSet + +from ..apis import InternalApi + + +def login( + key: str | None = None, + relogin: bool | None = None, + host: str | None = None, + force: bool | None = None, + timeout: int | None = None, + verify: bool = False, + referrer: str | None = None, + anonymous: DoNotSet = UNSET, +) -> bool: + """Log into W&B. + + You generally don't have to use this because most W&B methods that need + authentication can log in implicitly. This is the programmatic counterpart + to the `wandb login` CLI. 
+ + This updates global credentials for the session (affecting all wandb usage + in the current Python process after this call) and possibly the .netrc file. + + If the identity_token_file setting is set, like through the + WANDB_IDENTITY_TOKEN_FILE environment variable, then this is a no-op. + + Otherwise, if an explicit API key is provided, it is used and written to + the system .netrc file. If no key is provided, but the session is already + authenticated, then the session key is used for verification (if verify + is True) and the .netrc file is not updated. + + If none of the above is true, then this gets the API key from the first of: + + - The WANDB_API_KEY environment variable + - The api_key setting in a system or workspace settings file + - The .netrc file (either ~/.netrc, ~/_netrc or the path specified by the + NETRC environment variable) + - An interactive prompt (if available) + + Args: + key: The API key to use. + relogin: If true, get the API key from an interactive prompt, skipping + reading .netrc, environment variables, etc. + host: The W&B server URL to connect to. + force: If true, disallows selecting offline mode in the interactive + prompt. + timeout: Number of seconds to wait for user input in the interactive + prompt. This can be used as a failsafe if an interactive prompt + is incorrectly shown in a non-interactive environment. + verify: Verify the credentials with the W&B server and raise an + AuthenticationError on failure. + referrer: The referrer to use in the URL login request for analytics. + + Returns: + bool: If `key` is configured. + + Raises: + AuthenticationError: If `api_key` fails verification with the server. + UsageError: If `api_key` cannot be configured and no tty. 
def _update_system_settings(
    system_settings: settings_file.SettingsFiles,
    *,
    host: str | None,
) -> None:
    """Apply login-related changes to the system settings files and save them.

    Args:
        system_settings: The settings files object to modify and save.
        host: The server URL passed to login, if any. A falsy value leaves
            `base_url` untouched; the default cloud URL clears `base_url`;
            any other value is stored as `base_url`.
    """
    # 'anonymous' is deprecated, so it is always removed for now.
    system_settings.clear("anonymous", globally=True)

    if host and host != "https://api.wandb.ai":
        system_settings.set("base_url", host, globally=True)
    elif host:
        # The default host is represented by the absence of a base_url entry.
        system_settings.clear("base_url", globally=True)

    try:
        system_settings.save()
    except settings_file.SaveSettingsError as save_error:
        # Failing to persist settings is reported as a warning, not raised.
        wandb.termwarn(str(save_error))
+ + Returns: + A pair (is_successful, key). + """ + settings = wandb_setup.singleton().settings + + if host is None: + host_url = wbauth.HostUrl(settings.base_url, app_url=settings.app_url) + else: + host_url = wbauth.HostUrl(host) + + if relogin is None: + relogin = settings.relogin + if force is None: + force = settings.force + if timeout is None: + timeout = settings.login_timeout + if _silent is None: + _silent = settings.silent + + if wandb.util._is_kaggle() and not wandb.util._has_internet(): + term.termerror( + "To use W&B in kaggle you must enable internet in the settings" + + " panel on the right." + ) + return False, None + + if key: + auth: wbauth.Auth | None = _use_explicit_key( + key, + host=host_url, + settings=settings, + update_api_key=update_api_key, + silent=_silent, + ) + else: + auth = _find_or_prompt_for_key( + settings, + host=host_url, + force=force, + relogin=relogin, + referrer=referrer, + input_timeout=timeout, + ) + + if verify and isinstance(auth, wbauth.AuthApiKey): + _verify_login(key=auth.api_key, base_url=auth.host.url) + + wandb_setup.singleton().update_user_settings() + if not _silent: + _print_logged_in_message(settings, host=str(host_url)) + + if auth is None: + return False, None + elif isinstance(auth, wbauth.AuthApiKey): + return True, auth.api_key + else: + return True, None + + +def _use_explicit_key( + key: str, + settings: wandb.Settings, + *, + host: wbauth.HostUrl, + update_api_key: bool, + silent: bool, +) -> wbauth.Auth: + """Log in with an explicit key. + + Same arguments as `_login()`. + """ + if settings._notebook and not silent: + term.termwarn( + "If you're specifying your api key in code, ensure this" + + " code is not shared publicly." + + "\nConsider setting the WANDB_API_KEY environment variable," + + " or running `wandb login` from the command line." 
def _find_or_prompt_for_key(
    settings: wandb.Settings,
    *,
    host: wbauth.HostUrl,
    force: bool,
    relogin: bool,
    referrer: str,
    input_timeout: float | None,
) -> wbauth.Auth | None:
    """Log in without an explicit key.

    Delegates to `wbauth.authenticate_session()`. If that yields no
    credentials, `settings.mode` is changed: to "disabled" when the call
    timed out, to "offline" otherwise.

    Same arguments as `_login()`; `force` is forwarded as both `no_offline`
    and `no_create`.

    Returns:
        Whatever `authenticate_session()` returned, or None on timeout.
    """
    try:
        session_auth = wbauth.authenticate_session(
            host=host,
            source="wandb.login()",
            no_offline=force,
            no_create=force,
            referrer=referrer,
            input_timeout=input_timeout,
            relogin=relogin,
        )
    except TimeoutError:
        term.termwarn("W&B disabled due to login timeout.")
        settings.mode = "disabled"
        return None

    if session_auth:
        return session_auth

    # No credentials and no timeout: continue without a login.
    term.termlog("Using W&B in offline mode.")
    settings.mode = "offline"
    return session_auth
class Metric:
    """Options for one defined metric.

    Instances hold the values given at construction time; `_commit()` packs
    them into a `pb.MetricRecord` and hands it to the callback registered via
    `_set_callback()`, if there is one.
    """

    _callback: Optional[Callable[[pb.MetricRecord], None]]
    _name: str
    _step_metric: Optional[str]
    _step_sync: Optional[bool]
    _hidden: Optional[bool]
    _summary: Optional[Sequence[str]]
    _goal: Optional[str]
    _overwrite: Optional[bool]

    def __init__(
        self,
        name: str,
        step_metric: Optional[str] = None,
        step_sync: Optional[bool] = None,
        hidden: Optional[bool] = None,
        summary: Optional[Sequence[str]] = None,
        goal: Optional[str] = None,
        overwrite: Optional[bool] = None,
    ) -> None:
        """Store the metric options; no record is sent until `_commit()`."""
        if step_sync is None:
            # When unspecified, step_sync follows whether a step metric is set.
            step_sync = step_metric is not None

        self._callback = None
        self._name = name
        self._step_metric = step_metric
        self._step_sync = step_sync
        self._hidden = hidden
        self._summary = summary
        self._goal = goal
        self._overwrite = overwrite

    def _set_callback(self, cb: Callable[[pb.MetricRecord], None]) -> None:
        """Register the function that receives the record built by `_commit()`."""
        self._callback = cb

    @property
    def name(self) -> str:
        """The metric name given at construction."""
        return self._name

    @property
    def step_metric(self) -> Optional[str]:
        """The step metric name, or None if not set."""
        return self._step_metric

    @property
    def step_sync(self) -> Optional[bool]:
        """The step_sync option (defaulted in `__init__` when not given)."""
        return self._step_sync

    @property
    def summary(self) -> Optional[Tuple[str, ...]]:
        """The summary options as a tuple, or None if not set."""
        return None if self._summary is None else tuple(self._summary)

    @property
    def hidden(self) -> Optional[bool]:
        """The hidden option, or None if not set."""
        return self._hidden

    @property
    def goal(self) -> Optional[str]:
        """The goal spelled out ("minimize"/"maximize"), or None if not set.

        Raises:
            KeyError: If a non-empty goal other than "min" or "max" was stored.
        """
        if not self._goal:
            return None
        return {"min": "minimize", "max": "maximize"}[self._goal]

    def _commit(self) -> None:
        """Build a `pb.MetricRecord` from the stored options and emit it.

        Only truthy options are written to the record. The record is passed
        to the callback if one is registered; otherwise it is discarded.
        """
        record = pb.MetricRecord()
        record.options.defined = True

        # A trailing "*" marks the name as a glob pattern.
        if self._name.endswith("*"):
            record.glob_name = self._name
        else:
            record.name = self._name

        if self._step_metric:
            record.step_metric = self._step_metric
        if self._step_sync:
            record.options.step_sync = self._step_sync
        if self._hidden:
            record.options.hidden = self._hidden

        if self._summary:
            requested = set(self._summary)
            # Each recognised summary kind maps to a same-named boolean field.
            for kind in ("min", "max", "mean", "last", "copy", "none", "best", "first"):
                if kind in requested:
                    setattr(record.summary, kind, True)

        if self._goal == "min":
            record.goal = record.GOAL_MINIMIZE
        elif self._goal == "max":
            record.goal = record.GOAL_MAXIMIZE

        if self._overwrite:
            record._control.overwrite = self._overwrite

        if self._callback:
            self._callback(record)
+ ) + + def apply(self) -> None: + """Call require_* method for supported features.""" + last_message: str = "" + for feature_item in self._features: + full_feature = feature_item.split("@", 2)[0] + feature = full_feature.split(":", 2)[0] + func_str = "require_{}".format(feature.replace("-", "_")) + func = getattr(self, func_str, None) + if not func: + last_message = f"require() unsupported requirement: {feature}" + wandb.termwarn(last_message) + continue + func() + + if last_message: + raise UnsupportedError(last_message) + + +def require( + requirement: str | Iterable[str] | None = None, + experiment: str | Iterable[str] | None = None, +) -> None: + """Indicate which experimental features are used by the script. + + This should be called before any other `wandb` functions, ideally right + after importing `wandb`. + + Args: + requirement: The name of a feature to require or an iterable of + feature names. + experiment: An alias for `requirement`. + + Raises: + wandb.errors.UnsupportedError: If a feature name is unknown. 
import os
from functools import wraps
from typing import Any, Callable, Dict, TypeVar, cast

FuncT = TypeVar("FuncT", bound=Callable[..., Any])

# Maps a requirement name to the environment variable that must be set
# (to any non-empty value) for the gated feature to be usable.
requirement_env_var_mapping: Dict[str, str] = {
    "report-editing:v0": "WANDB_REQUIRE_REPORT_EDITING_V0"
}


def requires(requirement: str) -> FuncT:  # type: ignore
    """Decorate functions to gate features with wandb.require.

    Args:
        requirement: A key of ``requirement_env_var_mapping``.

    Returns:
        A decorator. The decorated function checks the mapped environment
        variable on every call and only then runs the wrapped function.

    Raises:
        KeyError: At decoration time, if ``requirement`` is not a known key.
        Exception: At call time, if the environment variable is unset or empty.
    """
    # Looked up once, when the decorator is created, so an unknown requirement
    # fails at definition time rather than on first call.
    env_var = requirement_env_var_mapping[requirement]

    def deco(func: FuncT) -> FuncT:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if not os.getenv(env_var):
                raise Exception(
                    f"You need to enable this feature with `wandb.require({requirement!r})`"
                )
            return func(*args, **kwargs)

        return cast(FuncT, wrapper)

    return cast(FuncT, deco)


class RequiresMixin:
    """Mixin that refuses construction unless its requirement is enabled.

    Subclasses set ``requirement`` to a key of ``requirement_env_var_mapping``.
    The check runs from both ``__init__`` and ``__post_init__``, so it fires
    whichever of the two gets invoked on the subclass.
    """

    # Must be overridden with a registered requirement name; an unregistered
    # value (such as this default) makes the check raise KeyError.
    requirement = ""

    def __init__(self) -> None:
        self._check_if_requirements_met()

    def __post_init__(self) -> None:
        self._check_if_requirements_met()

    def _check_if_requirements_met(self) -> None:
        """Raise unless the requirement's environment variable is set.

        Raises:
            KeyError: If ``self.requirement`` is not a known requirement.
            Exception: If the mapped environment variable is unset or empty.
        """
        env_var = requirement_env_var_mapping[self.requirement]
        if not os.getenv(env_var):
            # The closing quote, parenthesis and backtick were misplaced
            # before, producing `wandb.require("name)"` in the message.
            raise Exception(
                f'You must explicitly enable this feature with `wandb.require("{self.requirement}")`'
            )
b/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_run.py @@ -0,0 +1,4106 @@ +from __future__ import annotations + +import contextlib +import functools +import glob +import json +import logging +import numbers +import os +import pathlib +import re +import sys +import threading +import time +import traceback +from collections.abc import Mapping +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from enum import IntEnum +from types import TracebackType +from typing import TYPE_CHECKING, Callable, Sequence, TextIO, TypeVar + +from typing_extensions import Any, Concatenate, Literal, NamedTuple, ParamSpec + +import wandb +import wandb.env +import wandb.util +from wandb import trigger +from wandb.analytics import get_sentry +from wandb.errors import CommError, UsageError +from wandb.errors.links import url_registry +from wandb.integration.torch import wandb_torch +from wandb.plot import CustomChart, Visualize +from wandb.proto.wandb_internal_pb2 import ( + MetricRecord, + PollExitResponse, + Result, + RunRecord, +) +from wandb.proto.wandb_telemetry_pb2 import Deprecated +from wandb.sdk.lib import wb_logging +from wandb.sdk.lib.filesystem import ( + FilesDict, + GlobStr, + LinkStats, + PolicyName, + link_or_copy_with_policy, + unlink_path, + validate_glob_path, +) +from wandb.sdk.lib.import_hooks import ( + register_post_import_hook, + unregister_post_import_hook, +) +from wandb.sdk.lib.paths import FilePathStr, StrPath +from wandb.util import ( + _is_artifact_object, + _is_artifact_string, + _is_artifact_version_weave_dict, + _is_py_requirements_or_dockerfile, + _resolve_aliases, + add_import_hook, + parse_artifact_string, +) + +from . 
class RunStatusChecker:
    """Poll the background process for run updates on three daemon threads.

    One thread each is responsible for:

    - whether the user has requested a stop,
    - the network status,
    - internal warning messages.

    Every poller keeps its in-flight mailbox handle in an attribute guarded
    by its own lock, so that `stop()` can cancel a pending wait.
    """

    _stop_status_lock: threading.Lock
    _stop_status_handle: MailboxHandle[Result] | None
    _network_status_lock: threading.Lock
    _network_status_handle: MailboxHandle[Result] | None
    _internal_messages_lock: threading.Lock
    _internal_messages_handle: MailboxHandle[Result] | None

    def __init__(
        self,
        run_id: str,
        interface: InterfaceBase,
        settings: Settings,
        stop_polling_interval: int = 15,
        retry_polling_interval: int = 5,
        internal_messages_polling_interval: int = 10,
    ) -> None:
        """Create the pollers without starting them.

        Args:
            run_id: Run ID used as the logging context of each poller.
            interface: Object providing the `deliver_*` request methods.
            settings: Settings consulted before printing internal warnings.
            stop_polling_interval: Seconds per stop-status poll cycle.
            retry_polling_interval: Seconds per network-status poll cycle.
            internal_messages_polling_interval: Seconds per internal-messages
                poll cycle.
        """
        self._run_id = run_id
        self._interface = interface
        self._settings = settings
        self._stop_polling_interval = stop_polling_interval
        self._retry_polling_interval = retry_polling_interval
        self._internal_messages_polling_interval = internal_messages_polling_interval

        # Set by stop(); every poller checks it to know when to exit.
        self._join_event = threading.Event()

        def daemon_thread(target: Any, name: str) -> threading.Thread:
            return threading.Thread(target=target, name=name, daemon=True)

        self._stop_status_lock = threading.Lock()
        self._stop_status_handle = None
        self._stop_thread = daemon_thread(self.check_stop_status, "ChkStopThr")

        self._network_status_lock = threading.Lock()
        self._network_status_handle = None
        self._network_status_thread = daemon_thread(
            self.check_network_status, "NetStatThr"
        )

        self._internal_messages_lock = threading.Lock()
        self._internal_messages_handle = None
        self._internal_messages_thread = daemon_thread(
            self.check_internal_messages, "IntMsgThr"
        )

    def start(self) -> None:
        """Start the three poller threads."""
        for poller in (
            self._stop_thread,
            self._network_status_thread,
            self._internal_messages_thread,
        ):
            poller.start()

    @staticmethod
    def _abandon_status_check(
        lock: threading.Lock,
        handle: MailboxHandle[Result] | None,
    ):
        """Cancel `handle`, if there is one, while holding `lock`."""
        with lock:
            if not handle:
                return
            handle.cancel()

    def _loop_check_status(
        self,
        *,
        lock: threading.Lock,
        set_handle: Any,
        timeout: int,
        request: Any,
        process: Any,
    ) -> None:
        """Repeatedly issue a request and process its result until joined.

        Args:
            lock: Guards publication of the in-flight handle.
            set_handle: Called with the in-flight handle, then with `None`
                once the wait is over.
            timeout: Seconds to wait for a result; also the target duration
                of one loop cycle.
            request: Zero-argument callable returning a mailbox handle.
            process: Called with each truthy result.
        """
        pending: MailboxHandle[Result] | None = None
        while True:
            cycle_start = time.monotonic()
            # A handle whose wait timed out is reused instead of re-requested.
            if not pending:
                try:
                    pending = request()
                except MailboxClosedError:
                    # This can happen if the service process dies.
                    break
            assert pending

            with lock:
                # Checked under the lock so stop() cannot miss a handle that
                # is published after it looked.
                if self._join_event.is_set():
                    break
                set_handle(pending)

            try:
                result = pending.wait_or(timeout=timeout)
            except HandleAbandonedError:
                # This can happen if the service process dies.
                break
            except TimeoutError:
                result = None

            with lock:
                set_handle(None)

            if result:
                process(result)
                pending = None

            # Sleep out the rest of the cycle, waking early on a join request.
            elapsed = time.monotonic() - cycle_start
            remaining = max(timeout - elapsed, 0)
            if self._join_event.wait(timeout=remaining):
                break

    def _poll_in_run_context(
        self,
        *,
        lock: threading.Lock,
        handle_attr: str,
        timeout: int,
        request: Any,
        process: Any,
    ) -> None:
        """Run one poll loop inside this run's logging context.

        The in-flight handle is stored on `self` under `handle_attr`. On
        `BrokenPipeError` whatever handle is stored there is cancelled.
        """
        with wb_logging.log_to_run(self._run_id):
            try:
                self._loop_check_status(
                    lock=lock,
                    set_handle=lambda handle: setattr(self, handle_attr, handle),
                    timeout=timeout,
                    request=request,
                    process=process,
                )
            except BrokenPipeError:
                self._abandon_status_check(lock, getattr(self, handle_attr))

    def check_network_status(self) -> None:
        """Poll network status and print each reported network response."""

        def _process_network_status(result: Result) -> None:
            responses = result.response.network_status_response.network_responses
            for hr in responses:
                # we use 0 for non-http errors (eg wandb errors)
                if hr.http_status_code in (200, 0):
                    wandb.termlog(f"{hr.http_response_text}")
                    continue
                wandb.termlog(
                    f"{hr.http_status_code} encountered ({hr.http_response_text.rstrip()}), retrying request"
                )

        self._poll_in_run_context(
            lock=self._network_status_lock,
            handle_attr="_network_status_handle",
            timeout=self._retry_polling_interval,
            request=self._interface.deliver_network_status,
            process=_process_network_status,
        )

    def check_stop_status(self) -> None:
        """Poll for a stop request and interrupt the main thread on one."""

        def _process_stop_status(result: Result) -> None:
            from wandb.agents import pyagent

            stop_status = result.response.stop_status_response
            # TODO(frz): The pyagent check is required
            # until WB-3606 is resolved on server side.
            if stop_status.run_should_stop and not pyagent.is_running():  # type: ignore
                interrupt.interrupt_main()

        self._poll_in_run_context(
            lock=self._stop_status_lock,
            handle_attr="_stop_status_handle",
            timeout=self._stop_polling_interval,
            request=self._interface.deliver_stop_status,
            process=_process_stop_status,
        )

    def check_internal_messages(self) -> None:
        """Poll internal messages and print warnings unless output is muted."""

        def _process_internal_messages(result: Result) -> None:
            muted = (
                not self._settings.show_warnings
                or self._settings.quiet
                or self._settings.silent
            )
            if muted:
                return
            warnings = result.response.internal_messages_response.messages.warning
            for msg in warnings:
                wandb.termwarn(msg, repeat=False)

        self._poll_in_run_context(
            lock=self._internal_messages_lock,
            handle_attr="_internal_messages_handle",
            timeout=self._internal_messages_polling_interval,
            request=self._interface.deliver_internal_messages,
            process=_process_internal_messages,
        )

    def stop(self) -> None:
        """Request all pollers to exit and cancel their in-flight handles."""
        self._join_event.set()
        in_flight = (
            (self._stop_status_lock, self._stop_status_handle),
            (self._network_status_lock, self._network_status_handle),
            (self._internal_messages_lock, self._internal_messages_handle),
        )
        for lock, handle in in_flight:
            self._abandon_status_check(lock, handle)

    def join(self) -> None:
        """Stop the pollers and wait for their threads to finish."""
        self.stop()
        for poller in (
            self._stop_thread,
            self._network_status_thread,
            self._internal_messages_thread,
        ):
            poller.join()
in the logging context. + + Any logs during the execution of the method go to the run's log file + and not to other runs' log files. + + This is meant for use on all public methods and some callbacks. Private + methods can be assumed to be called from some public method somewhere. + The general rule is to use it on methods that can be called from a + context that isn't specific to this run (such as all user code or + internal methods that aren't run-specific). + """ + + @functools.wraps(func) + def wrapper(self: Run, *args: _P.args, **kwargs: _P.kwargs) -> _T: + # In "attach" usage, many properties of the Run are not initially + # populated. + if hasattr(self, "_settings"): + run_id = self._settings.run_id + else: + run_id = self._attach_id + + with wb_logging.log_to_run(run_id): + return func(self, *args, **kwargs) + + return wrapper + + +_is_attaching: str = "" + + +def _attach( + func: Callable[Concatenate[Run, _P], _T], +) -> Callable[Concatenate[Run, _P], _T]: + """Decorate a Run method to auto-attach when in a new process. + + When in a forked process or using a pickled Run instance, this automatically + connects to the service process to "attach" to the existing run. + """ + + @functools.wraps(func) + def wrapper(self: Run, *args: _P.args, **kwargs: _P.kwargs) -> _T: + global _is_attaching + + # The _attach_id attribute is only None when running in the "disable + # service" mode. + # + # Since it is set early in `__init__` and included in the run's pickled + # state, the attribute always exists. + is_using_service = self._attach_id is not None + + # The _attach_pid attribute is not pickled, so it might not exist. + # It is set when the run is initialized. 
+ attach_pid = getattr(self, "_attach_pid", None) + + if is_using_service and attach_pid != os.getpid(): + if _is_attaching: + raise RuntimeError( + f"Trying to attach `{func.__name__}`" + + f" while in the middle of attaching `{_is_attaching}`" + ) + + _is_attaching = func.__name__ + try: + wandb._attach(run=self) # type: ignore + finally: + _is_attaching = "" + + return func(self, *args, **kwargs) + + return wrapper + + +def _raise_if_finished( + func: Callable[Concatenate[Run, _P], _T], +) -> Callable[Concatenate[Run, _P], _T]: + """Decorate a Run method to raise an error after the run is finished.""" + + @functools.wraps(func) + def wrapper_fn(self: Run, *args: _P.args, **kwargs: _P.kwargs) -> _T: + if not getattr(self, "_is_finished", False): + return func(self, *args, **kwargs) + + message = ( + f"Run ({self.id}) is finished. The call to" + f" `{func.__name__}` will be ignored." + f" Please make sure that you are using an active run." + ) + + raise UsageError(message) + + return wrapper_fn + + +@dataclass +class RunStatus: + sync_items_total: int = field(default=0) + sync_items_pending: int = field(default=0) + sync_time: datetime | None = field(default=None) + + +class Run: + """A unit of computation logged by W&B. Typically, this is an ML experiment. + + Call [`wandb.init()`](https://docs.wandb.ai/ref/python/init/) to create a + new run. `wandb.init()` starts a new run and returns a `wandb.Run` object. + Each run is associated with a unique ID (run ID). W&B recommends using + a context (`with` statement) manager to automatically finish the run. + + For distributed training experiments, you can either track each process + separately using one run per process or track all processes to a single run. + See [Log distributed training experiments](https://docs.wandb.ai/guides/track/log/distributed-training) + for more information. + + You can log data to a run with `wandb.Run.log()`. Anything you log using + `wandb.Run.log()` is sent to that run. 
See + [Create an experiment](https://docs.wandb.ai/guides/track/create-an-experiment/) or + [`wandb.init`](https://docs.wandb.ai/ref/python/init/) API reference page + or more information. + + There is a another `Run` object in the + [`wandb.apis.public`](https://docs.wandb.ai/ref/python/public-api/api/) + namespace. Use this object is to interact with runs that have already been + created. + + Attributes: + summary: (Summary) A summary of the run, which is a dictionary-like + object. For more information, see + [Log summary metrics](https://docs.wandb.ai/guides/track/log/log-summary/). + + Examples: + Create a run with `wandb.init()`: + + ```python + import wandb + + # Start a new run and log some data + # Use context manager (`with` statement) to automatically finish the run + with wandb.init(entity="entity", project="project") as run: + run.log({"accuracy": acc, "loss": loss}) + ``` + + + """ + + _telemetry_obj: telemetry.TelemetryRecord + _telemetry_obj_active: bool + _telemetry_obj_dirty: bool + _telemetry_obj_flushed: bytes + + _teardown_hooks: list[TeardownHook] + + _backend: wandb.sdk.backend.backend.Backend | None + _internal_run_interface: wandb.sdk.interface.interface_queue.InterfaceQueue | None + _wl: _WandbSetup | None + + _out_redir: redirect.RedirectBase | None + _err_redir: redirect.RedirectBase | None + _redirect_cb: Callable[[str, str], None] | None + _redirect_raw_cb: Callable[[str, str], None] | None + _output_writer: filesystem.CRDedupedFile | None + + _atexit_cleanup_called: bool + _hooks: ExitHooks | None + _exit_code: int | None + + _run_status_checker: RunStatusChecker | None + + _sampled_history: SampledHistoryResponse | None + _final_summary: GetSummaryResponse | None + _poll_exit_handle: MailboxHandle[Result] | None + _poll_exit_response: PollExitResponse | None + _internal_messages_response: InternalMessagesResponse | None + + _stdout_slave_fd: int | None + _stderr_slave_fd: int | None + _artifact_slots: list[str] + + _init_pid: int + 
def _init(
    self,
    settings: Settings,
    config: dict[str, Any] | None = None,
    sweep_config: dict[str, Any] | None = None,
    launch_config: dict[str, Any] | None = None,
) -> None:
    """Set up the run's in-memory state from settings and initial config.

    Creates the config and summary objects and wires their callbacks,
    resets step/time counters, clears every backend/redirect/exit-related
    attribute to its "not started" value, creates an empty telemetry
    record, configures the sentry scope, and finally records the current
    PID and uses the run ID as the attach ID.

    Args:
        settings: Settings for this run; stored as `self._settings`.
        config: Initial config values; applied with value changes allowed
            and locked keys ignored.
        sweep_config: If truthy, merged into the config as locked values
            owned by user "sweep".
        launch_config: If truthy, merged into the config as locked values
            owned by user "launch" (after `sweep_config`).
    """
    self._settings = settings

    # Callbacks and settings are attached before any value is written to
    # the config below.
    self._config = wandb_config.Config()
    self._config._set_callback(self._config_callback)
    self._config._set_artifact_callback(self._config_artifact_callback)
    self._config._set_settings(self._settings)

    # The _wandb key is always expected on the run config.
    wandb_key = "_wandb"
    self._config._update({wandb_key: dict()})

    # TODO: perhaps this should be a property that is a noop on a finished run
    self.summary = wandb_summary.Summary(
        self._summary_get_current_summary_callback,
    )
    self.summary._set_update_callback(self._summary_update_callback)

    self._step = 0
    self._starting_step = 0
    self._start_runtime = 0
    # TODO: eventually would be nice to make this configurable using self._settings._start_time
    # need to test (jhr): if you set start time to 2 days ago and run a test for 15 minutes,
    # does the total time get calculated right (not as 2 days and 15 minutes)?
    self._start_time = time.time()

    self._printer = printer.new_printer(settings)

    # Created lazily by the `_torch` property.
    self._torch_history: wandb_torch.TorchHistory | None = None  # type: ignore

    self._backend = None
    self._internal_run_interface = None
    self._wl = None
    # Avoid calling wandb.Api() repeatedly in _public_api()
    self._cached_public_api: PublicApi | None = None
    self._hooks = None
    self._teardown_hooks = []

    self._output_writer = None
    self._out_redir = None
    self._err_redir = None
    self._stdout_slave_fd = None
    self._stderr_slave_fd = None

    self._exit_code = None
    self._exit_result = None

    self._used_artifact_slots: dict[str, str] = {}

    # Created when the run "starts".
    self._run_status_checker = None

    self._sampled_history = None
    self._final_summary = None
    self._poll_exit_response = None
    self._internal_messages_response = None
    self._poll_exit_handle = None

    # Initialize telemetry object
    self._telemetry_obj = telemetry.TelemetryRecord()
    self._telemetry_obj_active = False
    self._telemetry_obj_flushed = b""
    self._telemetry_obj_dirty = False

    self._atexit_cleanup_called = False

    # Initial scope setup for sentry.
    # This might get updated when the actual run comes back.
    get_sentry().configure_scope(
        tags=dict(self._settings),
        process_context="user",
    )

    # Filled in by `_initialize_launch_artifact_maps`.
    self._launch_artifact_mapping: dict[str, Any] = {}
    self._unique_launch_artifact_sequence_names: dict[str, Any] = {}

    # Populate config
    config = config or dict()
    self._config._update(config, allow_val_change=True, ignore_locked=True)

    if sweep_config:
        self._config.merge_locked(
            sweep_config, user="sweep", _allow_val_change=True
        )

    if launch_config:
        self._config.merge_locked(
            launch_config, user="launch", _allow_val_change=True
        )

    # if run is from a launch queue, add queue id to _wandb config
    launch_queue_name = wandb.env.get_launch_queue_name()
    if launch_queue_name:
        self._config[wandb_key]["launch_queue_name"] = launch_queue_name

    launch_queue_entity = wandb.env.get_launch_queue_entity()
    if launch_queue_entity:
        self._config[wandb_key]["launch_queue_entity"] = launch_queue_entity

    launch_trace_id = wandb.env.get_launch_trace_id()
    if launch_trace_id:
        self._config[wandb_key]["launch_trace_id"] = launch_trace_id

    # NOTE(review): this None is overwritten a few lines below without
    # being read in between within this method.
    self._attach_id = None
    self._is_attached = False
    self._is_finished = False

    # PID of the process that initialized the run; the `_attach` decorator
    # compares it with the current PID.
    self._attach_pid = os.getpid()
    self._forked = False
    # for now, use runid as attach id, this could/should be versioned in the future
    self._attach_id = self._settings.run_id
json.loads(fp.read()) + if launch_config.get("overrides", {}).get("artifacts") is not None: + artifacts = launch_config.get("overrides").get("artifacts") + self._initialize_launch_artifact_maps(artifacts) + + def _initialize_launch_artifact_maps(self, artifacts: dict[str, Any]) -> None: + for key, item in artifacts.items(): + self._launch_artifact_mapping[key] = item + artifact_sequence_tuple_or_slot = key.split(":") + + if len(artifact_sequence_tuple_or_slot) == 2: + sequence_name = artifact_sequence_tuple_or_slot[0].split("/")[-1] + if self._unique_launch_artifact_sequence_names.get(sequence_name): + self._unique_launch_artifact_sequence_names.pop(sequence_name) + else: + self._unique_launch_artifact_sequence_names[sequence_name] = item + + def _telemetry_callback(self, telem_obj: telemetry.TelemetryRecord) -> None: + if not hasattr(self, "_telemetry_obj") or self._is_finished: + return + + self._telemetry_obj.MergeFrom(telem_obj) + self._telemetry_obj_dirty = True + self._telemetry_flush() + + def _telemetry_flush(self) -> None: + if not hasattr(self, "_telemetry_obj"): + return + if not self._telemetry_obj_active: + return + if not self._telemetry_obj_dirty: + return + if self._backend and self._backend.interface: + serialized = self._telemetry_obj.SerializeToString() + if serialized == self._telemetry_obj_flushed: + return + self._backend.interface._publish_telemetry(self._telemetry_obj) + self._telemetry_obj_flushed = serialized + self._telemetry_obj_dirty = False + + def _freeze(self) -> None: + self._frozen = True + + def __setattr__(self, attr: str, value: object) -> None: + if getattr(self, "_frozen", None) and not hasattr(self, attr): + raise Exception(f"Attribute {attr} is not supported on Run object.") + super().__setattr__(attr, value) + + def __deepcopy__(self, memo: dict[int, Any]) -> Run: + return self + + def __getstate__(self) -> Any: + """Return run state as a custom pickle.""" + # We only pickle in service mode + if not self._settings: + return 
+ + _attach_id = self._attach_id + if not _attach_id: + return + + return dict( + _attach_id=_attach_id, + _init_pid=self._init_pid, + _is_finished=self._is_finished, + ) + + def __setstate__(self, state: Any) -> None: + """Set run state from a custom pickle.""" + if not state: + return + + _attach_id = state.get("_attach_id") + if not _attach_id: + return + + if state["_init_pid"] == os.getpid(): + raise RuntimeError("attach in the same process is not supported currently") + + self.__dict__.update(state) + + @property + def _torch(self) -> wandb_torch.TorchHistory: # type: ignore + if self._torch_history is None: + self._torch_history = wandb_torch.TorchHistory() # type: ignore + return self._torch_history + + @property + @_log_to_run + @_attach + def settings(self) -> Settings: + """A frozen copy of run's Settings object.""" + return self._settings.model_copy(deep=True) + + @property + @_log_to_run + @_attach + def dir(self) -> str: + """The directory where files associated with the run are saved.""" + return self._settings.files_dir + + @property + @_log_to_run + @_attach + def config(self) -> wandb_config.Config: + """Config object associated with this run.""" + return self._config + + @property + @_log_to_run + @_attach + def config_static(self) -> wandb_config.ConfigStatic: + """Static config object associated with this run.""" + return wandb_config.ConfigStatic(self._config) + + @property + @_log_to_run + @_attach + def name(self) -> str | None: + """Display name of the run. + + Display names are not guaranteed to be unique and may be descriptive. + By default, they are randomly generated. 
@tags.setter
@_log_to_run
@_raise_if_finished
def tags(self, tags: Sequence) -> None:
    """Replace the run's tags and publish the change.

    The tags are stored on the settings as a tuple. If the settings reject
    them with a `ValueError`, a warning is printed, the tags are left
    unchanged and nothing is published.

    Args:
        tags: The new tags for the run.
    """
    with telemetry.context(run=self) as tel:
        tel.feature.set_run_tags = True

    try:
        self._settings.run_tags = tuple(tags)
    except ValueError as e:
        # For runtime tag setting, warn instead of crash
        # Extract the core error message without the pydantic wrapper
        error_msg = str(e)
        # Test for exactly the text that is split on. The guard used to look
        # for "Value error," (no trailing space), so a message containing
        # that but not "Value error, " raised IndexError from this handler.
        marker = "Value error, "
        if marker in error_msg:
            # Extract the actual error message after "Value error, "
            error_msg = error_msg.split(marker)[1].split(" [type=")[0]
        wandb.termwarn(f"Invalid tag detected: {error_msg} Tags not updated.")
        return

    if self._backend and self._backend.interface:
        self._backend.interface.publish_run(self)
sweep associated with the run, if there is one.""" + return self._settings.sweep_id + + def _get_path(self) -> str: + return "/".join( + e + for e in [ + self._settings.entity, + self._settings.project, + self._settings.run_id, + ] + if e is not None + ) + + @property + @_log_to_run + @_attach + def path(self) -> str: + """Path to the run. + + Run paths include entity, project, and run ID, in the format + `entity/project/run_id`. + """ + return self._get_path() + + @property + @_log_to_run + @_attach + def start_time(self) -> float: + """Unix timestamp (in seconds) of when the run started.""" + return self._start_time + + @property + @_log_to_run + @_attach + def starting_step(self) -> int: + """The first step of the run. + + + """ + return self._starting_step + + @property + @_log_to_run + @_attach + def resumed(self) -> bool: + """True if the run was resumed, False otherwise.""" + return self._settings.resumed + + @property + @_log_to_run + @_attach + def step(self) -> int: + """Current value of the step. + + This counter is incremented by `wandb.Run.log()`. + + + """ + return self._step + + @property + @_log_to_run + @_attach + def offline(self) -> bool: + """True if the run is offline, False otherwise.""" + return self._settings._offline + + @property + @_log_to_run + @_attach + def disabled(self) -> bool: + """True if the run is disabled, False otherwise.""" + return self._settings._noop + + @property + @_log_to_run + @_attach + def group(self) -> str: + """Returns the name of the group associated with this run. + + Grouping runs together allows related experiments to be organized and + visualized collectively in the W&B UI. This is especially useful for + scenarios such as distributed training or cross-validation, where + multiple runs should be viewed and managed as a unified experiment. + + In shared mode, where all processes share the same run object, + setting a group is usually unnecessary, since there is only one + run and no grouping is required. 
def project_name(self) -> str:
    """Deprecated accessor for the run's project name; use `run.project`.

    Records and warns about the deprecation on every call, then returns the
    value of the `project` property.
    """
    notice = (
        "The project_name method is deprecated and will be removed in a"
        " future release. Please use `run.project` instead."
    )
    deprecation.warn_and_record_deprecation(
        feature=Deprecated(run__project_name=True),
        message=notice,
    )
    return self.project
+ + Offline runs do not have a project URL. + """ + if self._settings._offline: + wandb.termwarn("URL not available in offline run") + return None + return self._settings.project_url + + @_raise_if_finished + @_log_to_run + @_attach + def log_code( + self, + root: str | None = ".", + name: str | None = None, + include_fn: Callable[[str, str], bool] + | Callable[[str], bool] = _is_py_requirements_or_dockerfile, + exclude_fn: Callable[[str, str], bool] + | Callable[[str], bool] = filenames.exclude_wandb_fn, + ) -> Artifact | None: + """Save the current state of your code to a W&B Artifact. + + By default, it walks the current directory and logs all files that end with `.py`. + + Args: + root: The relative (to `os.getcwd()`) or absolute path to recursively find code from. + name: (str, optional) The name of our code artifact. By default, we'll name + the artifact `source-$PROJECT_ID-$ENTRYPOINT_RELPATH`. There may be scenarios where you want + many runs to share the same artifact. Specifying name allows you to achieve that. + include_fn: A callable that accepts a file path and (optionally) root path and + returns True when it should be included and False otherwise. This + defaults to `lambda path, root: path.endswith(".py")`. + exclude_fn: A callable that accepts a file path and (optionally) root path and + returns `True` when it should be excluded and `False` otherwise. This + defaults to a function that excludes all files within `/.wandb/` + and `/wandb/` directories. 
+ + Examples: + Basic usage + + ```python + import wandb + + with wandb.init() as run: + run.log_code() + ``` + + Advanced usage + + ```python + import wandb + + with wandb.init() as run: + run.log_code( + root="../", + include_fn=lambda path: path.endswith(".py") or path.endswith(".ipynb"), + exclude_fn=lambda path, root: os.path.relpath(path, root).startswith( + "cache/" + ), + ) + ``` + + Returns: + An `Artifact` object if code was logged + """ + from wandb.sdk.artifacts._internal_artifact import InternalArtifact + + if name is None: + if self.settings._jupyter: + notebook_name = None + if self.settings.notebook_name: + notebook_name = self.settings.notebook_name + elif self.settings.x_jupyter_path: + if self.settings.x_jupyter_path.startswith("fileId="): + notebook_name = self.settings.x_jupyter_name + else: + notebook_name = self.settings.x_jupyter_path + name_string = f"{self._settings.project}-{notebook_name}" + else: + name_string = ( + f"{self._settings.project}-{self._settings.program_relpath}" + ) + name = wandb.util.make_artifact_name_safe(f"source-{name_string}") + art = InternalArtifact(name, "code") + files_added = False + if root is not None: + root = os.path.abspath(root) + for file_path in filenames.filtered_dir(root, include_fn, exclude_fn): + files_added = True + save_name = os.path.relpath(file_path, root) + art.add_file(file_path, name=save_name) + # Add any manually staged files such as ipynb notebooks + for dirpath, _, files in os.walk(self._settings._tmp_code_dir): + for fname in files: + file_path = os.path.join(dirpath, fname) + save_name = os.path.relpath(file_path, self._settings._tmp_code_dir) + files_added = True + art.add_file(file_path, name=save_name) + if not files_added: + wandb.termwarn( + "No relevant files were detected in the specified directory. No code will be logged to your run." 
+ ) + return None + + artifact = self._log_artifact(art) + + self._config.update( + {"_wandb": {"code_path": artifact.name}}, + allow_val_change=True, + ) + + return artifact + + @_log_to_run + def get_sweep_url(self) -> str | None: + """This method is deprecated and will be removed in a future release. Use `run.sweep_url` instead. + + The URL of the sweep associated with the run, if there is one. + Offline runs do not have a sweep URL. + + + """ + deprecation.warn_and_record_deprecation( + feature=Deprecated(run__get_sweep_url=True), + message=( + "The get_sweep_url method is deprecated and will be removed in a" + " future release. Please use `run.sweep_url` instead." + ), + ) + return self.sweep_url + + @property + @_attach + def sweep_url(self) -> str | None: + """URL of the sweep associated with the run, if there is one. + + Offline runs do not have a sweep URL. + """ + if self._settings._offline: + wandb.termwarn("URL not available in offline run") + return None + return self._settings.sweep_url + + @_log_to_run + def get_url(self) -> str | None: + """This method is deprecated and will be removed in a future release. Use `run.url` instead. + + URL of the W&B run, if there is one. Offline runs do not have a URL. + + + """ + deprecation.warn_and_record_deprecation( + feature=Deprecated(run__get_url=True), + message=( + "The get_url method is deprecated and will be removed in a" + " future release. Please use `run.url` instead." + ), + ) + return self.url + + @property + @_log_to_run + @_attach + def url(self) -> str | None: + """The url for the W&B run, if there is one. + + Offline runs will not have a url. + """ + if self._settings._offline: + wandb.termwarn("URL not available in offline run") + return None + return self._settings.run_url + + @property + @_log_to_run + @_attach + def entity(self) -> str: + """The name of the W&B entity associated with the run. + + Entity can be a username or the name of a team or organization. 
+ """ + return self._settings.entity or "" + + def _label_internal( + self, + code: str | None = None, + repo: str | None = None, + code_version: str | None = None, + ) -> None: + with telemetry.context(run=self) as tel: + if code and RE_LABEL.match(code): + tel.label.code_string = code + if repo and RE_LABEL.match(repo): + tel.label.repo_string = repo + if code_version and RE_LABEL.match(code_version): + tel.label.code_version = code_version + + def _label( + self, + code: str | None = None, + repo: str | None = None, + code_version: str | None = None, + **kwargs: str, + ) -> None: + if self._settings.label_disable: + return + for k, v in (("code", code), ("repo", repo), ("code_version", code_version)): + if v and not RE_LABEL.match(v): + wandb.termwarn( + f"Label added for '{k}' with invalid identifier '{v}' (ignored).", + repeat=False, + ) + for v in kwargs: + wandb.termwarn( + f"Label added for unsupported key {v!r} (ignored).", + repeat=False, + ) + + self._label_internal(code=code, repo=repo, code_version=code_version) + + # update telemetry in the backend immediately for _label() callers + self._telemetry_flush() + + def _label_probe_lines(self, lines: list[str]) -> None: + if not lines: + return + parsed = telemetry._parse_label_lines(lines) + if not parsed: + return + label_dict = {} + code = parsed.get("code") or parsed.get("c") + if code: + label_dict["code"] = code + repo = parsed.get("repo") or parsed.get("r") + if repo: + label_dict["repo"] = repo + code_ver = parsed.get("version") or parsed.get("v") + if code_ver: + label_dict["code_version"] = code_ver + self._label_internal(**label_dict) + + def _label_probe_main(self) -> None: + m = sys.modules.get("__main__") + if not m: + return + doc = getattr(m, "__doc__", None) + if not doc: + return + + doclines = doc.splitlines() + self._label_probe_lines(doclines) + + # TODO: annotate jupyter Notebook class + def _label_probe_notebook(self, notebook: Any) -> None: + logger.info("probe notebook") + lines = 
None + try: + data = notebook.probe_ipynb() + cell0 = data.get("cells", [])[0] + lines = cell0.get("source") + # kaggle returns a string instead of a list + if isinstance(lines, str): + lines = lines.split() + except Exception as e: + logger.info(f"Unable to probe notebook: {e}") + return + if lines: + self._label_probe_lines(lines) + + @_log_to_run + @_attach + def display(self, height: int = 420, hidden: bool = False) -> bool: + """Display this run in Jupyter.""" + if self._settings.silent: + return False + + if not ipython.in_jupyter() or ipython.in_vscode_notebook(): + return False + + try: + from IPython import display + except ImportError: + wandb.termwarn(".display() only works in jupyter environments") + return False + + display.display(display.HTML(self.to_html(height, hidden))) + return True + + @_log_to_run + @_attach + def to_html(self, height: int = 420, hidden: bool = False) -> str: + """Generate HTML containing an iframe displaying the current run. + + If the run is being displayed in a VSCode notebook, + the string representation of the run is returned instead. + + + """ + if ipython.in_vscode_notebook(): + import html + + return html.escape(str(self)) + + url = self._settings.run_url + "?jupyter=true" + style = f"border:none;width:100%;height:{height}px;" + prefix = "" + if hidden: + style += "display:none;" + prefix = ipython.toggle_button() + return prefix + f"" + + def _repr_mimebundle_( + self, include: Any | None = None, exclude: Any | None = None + ) -> dict[str, str]: + return {"text/html": self.to_html(hidden=True)} + + @_log_to_run + @_raise_if_finished + def _config_callback( + self, + key: tuple[str, ...] 
| str | None = None, + val: Any | None = None, + data: dict[str, object] | None = None, + ) -> None: + logger.info(f"config_cb {key} {val} {data}") + if self._backend and self._backend.interface: + self._backend.interface.publish_config(key=key, val=val, data=data) + + @_log_to_run + def _config_artifact_callback( + self, key: str, val: str | Artifact | dict + ) -> Artifact: + from wandb.apis import public + from wandb.sdk.artifacts.artifact import Artifact + + # artifacts can look like dicts as they are passed into the run config + # since the run config stores them on the backend as a dict with fields shown + # in wandb.util.artifact_to_json + if _is_artifact_version_weave_dict(val): + assert isinstance(val, dict) + public_api = self._public_api() + artifact = Artifact._from_id(val["id"], public_api.client) + + assert artifact + return self.use_artifact(artifact) + elif _is_artifact_string(val): + # this will never fail, but is required to make mypy happy + assert isinstance(val, str) + artifact_string, base_url, is_id = parse_artifact_string(val) + overrides = {} + if base_url is not None: + overrides = {"base_url": base_url} + public_api = public.Api(overrides) + else: + public_api = self._public_api() + if is_id: + artifact = Artifact._from_id(artifact_string, public_api._client) + else: + artifact = public_api._artifact(name=artifact_string) + # in the future we'll need to support using artifacts from + # different instances of wandb. 
+ + assert artifact + return self.use_artifact(artifact) + elif _is_artifact_object(val): + return self.use_artifact(val) + else: + raise ValueError( + f"Cannot call _config_artifact_callback on type {type(val)}" + ) + + def _set_config_wandb(self, key: str, val: Any) -> None: + self._config_callback(key=("_wandb", key), val=val) + + @_log_to_run + @_raise_if_finished + def _summary_update_callback(self, summary_record: SummaryRecord) -> None: + with telemetry.context(run=self) as tel: + tel.feature.set_summary = True + if self._backend and self._backend.interface: + self._backend.interface.publish_summary(self, summary_record) + + @_log_to_run + def _summary_get_current_summary_callback(self) -> dict[str, Any]: + if self._is_finished: + # TODO: WB-18420: fetch summary from backend and stage it before run is finished + wandb.termwarn("Summary data not available in finished run") + return {} + if not self._backend or not self._backend.interface: + return {} + handle = self._backend.interface.deliver_get_summary() + + try: + result = handle.wait_or(timeout=self._settings.summary_timeout) + except TimeoutError: + return {} + + get_summary_response = result.response.get_summary_response + return proto_util.dict_from_proto_list(get_summary_response.item) + + @_log_to_run + def _metric_callback(self, metric_record: MetricRecord) -> None: + if self._backend and self._backend.interface: + self._backend.interface._publish_metric(metric_record) + + @_log_to_run + def _publish_file(self, fname: str) -> None: + """Mark a run file to be uploaded with the run. + + This is a W&B-internal function: it can be used by other internal + wandb code. + + Args: + fname: The path to the file in the run's files directory, relative + to the run's files directory. 
+ """ + if not self._backend or not self._backend.interface: + return + files: FilesDict = dict(files=[(GlobStr(fname), "now")]) + self._backend.interface.publish_files(files) + + def _pop_all_charts( + self, + data: dict[str, Any], + key_prefix: str | None = None, + ) -> dict[str, Any]: + """Pops all charts from a dictionary including nested charts. + + This function will return a mapping of the charts and a dot-separated + key for each chart. Indicating the path to the chart in the data dictionary. + """ + keys_to_remove = set() + charts: dict[str, Any] = {} + for k, v in data.items(): + key = f"{key_prefix}.{k}" if key_prefix else k + if isinstance(v, Visualize): + keys_to_remove.add(k) + charts[key] = v + elif isinstance(v, CustomChart): + keys_to_remove.add(k) + charts[key] = v + elif isinstance(v, dict): + nested_charts = self._pop_all_charts(v, key) + charts.update(nested_charts) + + for k in keys_to_remove: + data.pop(k) + + return charts + + def _serialize_custom_charts( + self, + data: dict[str, Any], + ) -> dict[str, Any]: + """Process and replace chart objects with their underlying table values. + + This processes the chart objects passed to `wandb.Run.log()`, replacing their entries + in the given dictionary (which is saved to the run's history) and adding them + to the run's config. + + Args: + data: Dictionary containing data that may include plot objects + Plot objects can be nested in dictionaries, which will be processed recursively. + + Returns: + The processed dictionary with custom charts transformed into tables. 
+ """ + if not data: + return data + + charts = self._pop_all_charts(data) + for k, v in charts.items(): + v.set_key(k) + self._config_callback( + val=v.spec.config_value, + key=v.spec.config_key, + ) + + if isinstance(v, CustomChart): + data[v.spec.table_key] = v.table + elif isinstance(v, Visualize): + data[k] = v.table + + return data + + @_log_to_run + def _partial_history_callback( + self, + data: dict[str, Any], + step: int | None = None, + commit: bool | None = None, + ) -> None: + if not (self._backend and self._backend.interface): + return + + data = data.copy() # avoid modifying the original data + + # Serialize custom charts before publishing + data = self._serialize_custom_charts(data) + + not_using_tensorboard = len(wandb.patched["tensorboard"]) == 0 + self._backend.interface.publish_partial_history( + self, + data, + user_step=self._step, + step=step, + flush=commit, + publish_step=not_using_tensorboard, + ) + + @_log_to_run + def _console_callback(self, name: str, data: str) -> None: + if self._backend and self._backend.interface: + # nowait=True so that this can be called from an asyncio context. + self._backend.interface.publish_output(name, data, nowait=True) + + @_log_to_run + @_raise_if_finished + def _console_raw_callback(self, name: str, data: str) -> None: + # NOTE: console output is only allowed on the process which installed the callback + # this will prevent potential corruption in the socket to the service. Other methods + # are protected by the _attach run decorator, but this callback was installed on the + # write function of stdout and stderr streams. + console_pid = getattr(self, "_attach_pid", 0) + if console_pid != os.getpid(): + return + + if self._backend and self._backend.interface: + # nowait=True so that this can be called from an asyncio context. 
+ self._backend.interface.publish_output_raw(name, data, nowait=True) + + @_log_to_run + def _tensorboard_callback( + self, logdir: str, save: bool = True, root_logdir: str = "" + ) -> None: + logger.info("tensorboard callback: %s, %s", logdir, save) + if self._backend and self._backend.interface: + self._backend.interface.publish_tbdata(logdir, save, root_logdir) + + def _set_library(self, library: _WandbSetup) -> None: + self._wl = library + + def _set_backend(self, backend: wandb.sdk.backend.backend.Backend) -> None: + self._backend = backend + + def _set_internal_run_interface( + self, + interface: wandb.sdk.interface.interface_queue.InterfaceQueue, + ) -> None: + self._internal_run_interface = interface + + def _set_teardown_hooks(self, hooks: list[TeardownHook]) -> None: + self._teardown_hooks = hooks + + def _set_run_obj(self, run_obj: RunRecord) -> None: # noqa: C901 + if run_obj.starting_step: + self._starting_step = run_obj.starting_step + self._step = run_obj.starting_step + + if run_obj.start_time: + self._start_time = run_obj.start_time.ToMicroseconds() / 1e6 + + if run_obj.runtime: + self._start_runtime = run_obj.runtime + + # Grab the config from resuming + if run_obj.config: + c_dict = config_util.dict_no_value_from_proto_list(run_obj.config.update) + # We update the config object here without triggering the callback + self._config._update(c_dict, allow_val_change=True, ignore_locked=True) + # Update the summary, this will trigger an un-needed graphql request :( + if run_obj.summary: + summary_dict = {} + for orig in run_obj.summary.update: + summary_dict[orig.key] = json.loads(orig.value_json) + if summary_dict: + self.summary.update(summary_dict) + + # update settings from run_obj + if run_obj.run_id: + self._settings.run_id = run_obj.run_id + if run_obj.entity: + self._settings.entity = run_obj.entity + if run_obj.project: + self._settings.project = run_obj.project + if run_obj.run_group: + self._settings.run_group = run_obj.run_group + if 
run_obj.job_type: + self._settings.run_job_type = run_obj.job_type + if run_obj.display_name: + self._settings.run_name = run_obj.display_name + if run_obj.notes: + self._settings.run_notes = run_obj.notes + if run_obj.tags: + self._settings.run_tags = tuple(run_obj.tags) + if run_obj.sweep_id: + self._settings.sweep_id = run_obj.sweep_id + if run_obj.host: + self._settings.host = run_obj.host + if run_obj.resumed: + self._settings.resumed = run_obj.resumed + if run_obj.git: + if run_obj.git.remote_url: + self._settings.git_remote_url = run_obj.git.remote_url + if run_obj.git.commit: + self._settings.git_commit = run_obj.git.commit + + if run_obj.forked: + self._forked = run_obj.forked + + get_sentry().configure_scope( + process_context="user", + tags=dict(self._settings), + ) + + def _populate_git_info(self) -> None: + from .lib.gitlib import GitRepo + + # Use user-provided git info if available, otherwise resolve it from the environment + try: + repo = GitRepo( + root=self._settings.git_root, + remote=self._settings.git_remote, + remote_url=self._settings.git_remote_url, + commit=self._settings.git_commit, + lazy=False, + ) + self._settings.git_remote_url = repo.remote_url + self._settings.git_commit = repo.last_commit + except Exception: + wandb.termwarn("Cannot find valid git repo associated with this directory.") + + def _add_singleton( + self, data_type: str, key: str, value: dict[int | str, str] + ) -> None: + """Store a singleton item to wandb config. + + A singleton in this context is a piece of data that is continually + logged with the same value in each history step, but represented + as a single item in the config. + + We do this to avoid filling up history with a lot of repeated unnecessary data + + Add singleton can be called many times in one run, and it will only be + updated when the value changes. The last value logged will be the one + persisted to the server. 
+ """ + value_extra = {"type": data_type, "key": key, "value": value} + + if data_type not in self._config["_wandb"]: + self._config["_wandb"][data_type] = {} + + if data_type in self._config["_wandb"][data_type]: + old_value = self._config["_wandb"][data_type][key] + else: + old_value = None + + if value_extra != old_value: + self._config["_wandb"][data_type][key] = value_extra + self._config.persist() + + def _log( + self, + data: dict[str, Any], + step: int | None = None, + commit: bool | None = None, + ) -> None: + if not isinstance(data, Mapping): + raise TypeError("wandb.log must be passed a dictionary") + + if any(not isinstance(key, str) for key in data.keys()): + raise TypeError("Key values passed to `wandb.log` must be strings.") + + self._partial_history_callback(data, step, commit) + + if step is not None: + if os.getpid() != self._init_pid or self._is_attached: + wandb.termwarn( + "Note that setting step in multiprocessing can result in data loss. " + "Please use `run.define_metric(...)` to define a custom metric " + "to log your step values.", + repeat=False, + ) + # if step is passed in when tensorboard_sync is used we honor the step passed + # to make decisions about how to close out the history record, but will strip + # this history later on in publish_history() + if len(wandb.patched["tensorboard"]) > 0: + wandb.termwarn( + "Step cannot be set when using tensorboard syncing. " + "Please use `run.define_metric(...)` to define a custom metric " + "to log your step values.", + repeat=False, + ) + if step > self._step: + self._step = step + + if (step is None and commit is None) or commit: + self._step += 1 + + @_log_to_run + @_raise_if_finished + @_attach + def log( + self, + data: dict[str, Any], + step: int | None = None, + commit: bool | None = None, + ) -> None: + """Upload run data. + + Use `log` to log data from runs, such as scalars, images, video, + histograms, plots, and tables. 
See [Log objects and media](https://docs.wandb.ai/guides/track/log) for + code snippets, best practices, and more. + + Basic usage: + + ```python + import wandb + + with wandb.init() as run: + run.log({"train-loss": 0.5, "accuracy": 0.9}) + ``` + + The previous code snippet saves the loss and accuracy to the run's + history and updates the summary values for these metrics. + + Visualize logged data in a workspace at [wandb.ai](https://wandb.ai), + or locally on a [self-hosted instance](https://docs.wandb.ai/guides/hosting) + of the W&B app, or export data to visualize and explore locally, such as in a + Jupyter notebook, with the [Public API](https://docs.wandb.ai/guides/track/public-api-guide). + + Logged values don't have to be scalars. You can log any + [W&B supported Data Type](https://docs.wandb.ai/ref/python/data-types/) + such as images, audio, video, and more. For example, you can use + `wandb.Table` to log structured data. See + [Log tables, visualize and query data](https://docs.wandb.ai/guides/models/tables/tables-walkthrough) + tutorial for more details. + + W&B organizes metrics with a forward slash (`/`) in their name + into sections named using the text before the final slash. For example, + the following results in two sections named "train" and "validate": + + ```python + with wandb.init() as run: + # Log metrics in the "train" section. + run.log( + { + "train/accuracy": 0.9, + "train/loss": 30, + "validate/accuracy": 0.8, + "validate/loss": 20, + } + ) + ``` + + Only one level of nesting is supported; `run.log({"a/b/c": 1})` + produces a section named "a". + + `run.log()` is not intended to be called more than a few times per second. + For optimal performance, limit your logging to once every N iterations, + or collect data over multiple iterations and log it in a single step. + + By default, each call to `log` creates a new "step". + The step must always increase, and it is not possible to log + to a previous step. 
You can use any metric as the X axis in charts. + See [Custom log axes](https://docs.wandb.ai/guides/track/log/customize-logging-axes/) + for more details. + + In many cases, it is better to treat the W&B step like + you'd treat a timestamp rather than a training step. + + ```python + with wandb.init() as run: + # Example: log an "epoch" metric for use as an X axis. + run.log({"epoch": 40, "train-loss": 0.5}) + ``` + + It is possible to use multiple `wandb.Run.log()` invocations to log to + the same step with the `step` and `commit` parameters. + The following are all equivalent: + + ```python + with wandb.init() as run: + # Normal usage: + run.log({"train-loss": 0.5, "accuracy": 0.8}) + run.log({"train-loss": 0.4, "accuracy": 0.9}) + + # Implicit step without auto-incrementing: + run.log({"train-loss": 0.5}, commit=False) + run.log({"accuracy": 0.8}) + run.log({"train-loss": 0.4}, commit=False) + run.log({"accuracy": 0.9}) + + # Explicit step: + run.log({"train-loss": 0.5}, step=current_step) + run.log({"accuracy": 0.8}, step=current_step) + current_step += 1 + run.log({"train-loss": 0.4}, step=current_step) + run.log({"accuracy": 0.9}, step=current_step, commit=True) + ``` + + Args: + data: A `dict` with `str` keys and values that are serializable + Python objects including: `int`, `float` and `string`; + any of the `wandb.data_types`; lists, tuples and NumPy arrays + of serializable Python objects; other `dict`s of this + structure. + step: The step number to log. If `None`, then an implicit + auto-incrementing step is used. See the notes in + the description. + commit: If true, finalize and upload the step. If false, then + accumulate data for the step. See the notes in the description. + If `step` is `None`, then the default is `commit=True`; + otherwise, the default is `commit=False`. + + Examples: + For more and more detailed examples, see + [our guides to logging](https://docs.wandb.com/guides/track/log). 
+ + Basic usage + + ```python + import wandb + + with wandb.init() as run: + run.log({"train-loss": 0.5, "accuracy": 0.9 + ``` + + Incremental logging + + ```python + import wandb + + with wandb.init() as run: + run.log({"loss": 0.2}, commit=False) + # Somewhere else when I'm ready to report this step: + run.log({"accuracy": 0.8}) + ``` + + Histogram + + ```python + import numpy as np + import wandb + + # sample gradients at random from normal distribution + gradients = np.random.randn(100, 100) + with wandb.init() as run: + run.log({"gradients": wandb.Histogram(gradients)}) + ``` + + Image from NumPy + + ```python + import numpy as np + import wandb + + with wandb.init() as run: + examples = [] + for i in range(3): + pixels = np.random.randint(low=0, high=256, size=(100, 100, 3)) + image = wandb.Image(pixels, caption=f"random field {i}") + examples.append(image) + run.log({"examples": examples}) + ``` + + Image from PIL + + ```python + import numpy as np + from PIL import Image as PILImage + import wandb + + with wandb.init() as run: + examples = [] + for i in range(3): + pixels = np.random.randint( + low=0, + high=256, + size=(100, 100, 3), + dtype=np.uint8, + ) + pil_image = PILImage.fromarray(pixels, mode="RGB") + image = wandb.Image(pil_image, caption=f"random field {i}") + examples.append(image) + run.log({"examples": examples}) + ``` + + Video from NumPy + + ```python + import numpy as np + import wandb + + with wandb.init() as run: + # axes are (time, channel, height, width) + frames = np.random.randint( + low=0, + high=256, + size=(10, 3, 100, 100), + dtype=np.uint8, + ) + run.log({"video": wandb.Video(frames, fps=4)}) + ``` + + Matplotlib plot + + ```python + from matplotlib import pyplot as plt + import numpy as np + import wandb + + with wandb.init() as run: + fig, ax = plt.subplots() + x = np.linspace(0, 10) + y = x * x + ax.plot(x, y) # plot y = x^2 + run.log({"chart": fig}) + ``` + + PR Curve + + ```python + import wandb + + with wandb.init() as run: 
+ run.log({"pr": wandb.plot.pr_curve(y_test, y_probas, labels)}) + ``` + + 3D Object + + ```python + import wandb + + with wandb.init() as run: + run.log( + { + "generated_samples": [ + wandb.Object3D(open("sample.obj")), + wandb.Object3D(open("sample.gltf")), + wandb.Object3D(open("sample.glb")), + ] + } + ) + ``` + + Raises: + wandb.Error: If called before `wandb.init()`. + ValueError: If invalid data is passed. + + """ + if step is not None: + with telemetry.context(run=self) as tel: + tel.feature.set_step_log = True + + if self._settings._shared and step is not None: + wandb.termwarn( + "In shared mode, the use of `wandb.log` with the step argument is not supported " + f"and will be ignored. Please refer to {url_registry.url('define-metric')} " + "on how to customize your x-axis.", + repeat=False, + ) + self._log(data=data, step=step, commit=commit) + + @_log_to_run + @_raise_if_finished + @_attach + def save( + self, + glob_str: str | os.PathLike, + base_path: str | os.PathLike | None = None, + policy: PolicyName = "live", + ) -> bool | list[str]: + """Sync one or more files to W&B. + + Relative paths are relative to the current working directory. + + A Unix glob, such as "myfiles/*", is expanded at the time `save` is + called regardless of the `policy`. In particular, new files are not + picked up automatically. + + A `base_path` may be provided to control the directory structure of + uploaded files. It should be a prefix of `glob_str`, and the directory + structure beneath it is preserved. + + When given an absolute path or glob and no `base_path`, one + directory level is preserved as in the example above. + + Files are automatically deduplicated: calling `save()` multiple times + on the same file without modifications will not re-upload it. + + Args: + glob_str: A relative or absolute path or Unix glob. + base_path: A path to use to infer a directory structure; see examples. + policy: One of `live`, `now`, or `end`. 
+ - live: upload the file as it changes, overwriting the previous version + - now: upload the file once now + - end: upload file when the run ends + + Returns: + Paths to the symlinks created for the matched files. + + For historical reasons, this may return a boolean in legacy code. + + ```python + import wandb + + run = wandb.init() + + run.save("these/are/myfiles/*") + # => Saves files in a "these/are/myfiles/" folder in the run. + + run.save("these/are/myfiles/*", base_path="these") + # => Saves files in an "are/myfiles/" folder in the run. + + run.save("/Users/username/Documents/run123/*.txt") + # => Saves files in a "run123/" folder in the run. See note below. + + run.save("/Users/username/Documents/run123/*.txt", base_path="/Users") + # => Saves files in a "username/Documents/run123/" folder in the run. + + run.save("files/*/saveme.txt") + # => Saves each "saveme.txt" file in an appropriate subdirectory + # of "files/". + + # Explicitly finish the run since a context manager is not used. + run.finish() + ``` + """ + if isinstance(glob_str, bytes): + # Preserved for backward compatibility: allow bytes inputs. + glob_str = glob_str.decode("utf-8") + if isinstance(glob_str, str) and (glob_str.startswith(("gs://", "s3://"))): + # Provide a better error message for a common misuse. + wandb.termlog(f"{glob_str} is a cloud storage url, can't save file to W&B.") + return [] + # NOTE: We use PurePath instead of Path because WindowsPath doesn't + # like asterisks and errors out in resolve(). It also makes logical + # sense: globs aren't real paths, they're just path-like strings. + glob_path = pathlib.PurePath(glob_str) + resolved_glob_path = pathlib.PurePath(os.path.abspath(glob_path)) + + if base_path is not None: + base_path = pathlib.Path(base_path) + elif not glob_path.is_absolute(): + base_path = pathlib.Path(".") + else: + # Absolute glob paths with no base path get special handling. + wandb.termwarn( + "Saving files without folders. 
If you want to preserve " + "subdirectories pass base_path to wandb.save, i.e. " + 'wandb.save("/mnt/folder/file.h5", base_path="/mnt")', + repeat=False, + ) + base_path = resolved_glob_path.parent.parent + + if policy not in ("live", "end", "now"): + raise ValueError( + 'Only "live", "end" and "now" policies are currently supported.' + ) + + resolved_base_path = pathlib.PurePath(os.path.abspath(base_path)) + + return self._save( + resolved_glob_path, + resolved_base_path, + policy, + ) + + def _save( + self, + glob_path: pathlib.PurePath, + base_path: pathlib.PurePath, + policy: PolicyName, + ) -> list[str]: + """Materialize matched files into the run's files/ dir for syncing. + + Strategy: + 1) If settings.symlink is True, try symlink. + 2) Else (or if symlink fails), try hardlink (same-volume files). + 3) Else copy and, if requested policy == "live", downgrade those files to "now". + + Args: + glob_path: Absolute path glob pattern for files to save. + base_path: Base path to determine relative directory structure. + policy: Upload policy - "live", "now", or "end". + + Returns: + List of absolute paths to files in the wandb run directory. + + Raises: + ValueError: If glob_path is invalid relative to base_path. + """ + validate_glob_path(glob_path, base_path) + + relative_glob = glob_path.relative_to(base_path) + relative_glob_str = GlobStr(str(relative_glob)) + + with telemetry.context(run=self) as tel: + tel.feature.save = True + + files_root = pathlib.Path(self._settings.files_dir) + preexisting = set(files_root.glob(relative_glob_str)) + + # Expand sources deterministically. + src_paths = [ + pathlib.Path(p).absolute() + for p in sorted(glob.glob(GlobStr(str(base_path / relative_glob_str)))) + ] + + stats = LinkStats() + publish_entries = [] + created_targets = set() + + for src in src_paths: + # Preserve directory structure under base_path. 
+ rel = pathlib.Path(*src.parts[len(base_path.parts) :]) + dst = files_root / rel + created_targets.add(dst) + + # If already the same file, just publish with requested policy. + with contextlib.suppress(OSError): + if dst.exists() and src.samefile(dst): + publish_entries.append( + (GlobStr(str(dst.relative_to(files_root))), policy) + ) + continue + + dst.parent.mkdir(parents=True, exist_ok=True) + unlink_path(dst) + + effective_policy = link_or_copy_with_policy( + self._settings, src, dst, policy, stats + ) + publish_entries.append( + (GlobStr(str(dst.relative_to(files_root))), effective_policy) + ) + + # Include pre-existing matches we didn't touch. + for p in sorted(preexisting): + if p not in created_targets: + publish_entries.append( + (GlobStr(str(p.relative_to(files_root))), policy) + ) + + stats.emit_warnings() + + files_dict: FilesDict = {"files": publish_entries} + if self._backend and self._backend.interface: + self._backend.interface.publish_files(files_dict) + + abs_targets = {files_root / pathlib.Path(g) for (g, _pol) in publish_entries} + return [str(p) for p in sorted(abs_targets)] + + @_log_to_run + @_attach + def restore( + self, + name: str, + run_path: str | None = None, + replace: bool = False, + root: str | None = None, + ) -> None | TextIO: + return restore( + name, + run_path or self._get_path(), + replace, + root or self._settings.files_dir, + ) + + @_log_to_run + @_attach + def finish( + self, + exit_code: int | None = None, + quiet: bool | None = None, + ) -> None: + """Finish a run and upload any remaining data. + + Marks the completion of a W&B run and ensures all data is synced to the server. + The run's final state is determined by its exit conditions and sync status. + + Run States: + - Running: Active run that is logging data and/or sending heartbeats. + - Crashed: Run that stopped sending heartbeats unexpectedly. + - Finished: Run completed successfully (`exit_code=0`) with all data synced. 
+ - Failed: Run completed with errors (`exit_code!=0`). + - Killed: Run was forcibly stopped before it could finish. + + Args: + exit_code: Integer indicating the run's exit status. Use 0 for success, + any other value marks the run as failed. + quiet: Deprecated. Configure logging verbosity using `wandb.Settings(quiet=...)`. + """ + if quiet is not None: + deprecation.warn_and_record_deprecation( + feature=Deprecated(run__finish_quiet=True), + message=( + "The `quiet` argument to `wandb.run.finish()` is deprecated, " + "use `wandb.Settings(quiet=...)` to set this instead." + ), + run=self, + ) + return self._finish(exit_code) + + @_log_to_run + def _finish( + self, + exit_code: int | None = None, + ) -> None: + if self._is_finished: + return + + assert self._wl + + logger.info(f"finishing run {self._get_path()}") + with telemetry.context(run=self) as tel: + tel.feature.finish = True + + # Run hooks that need to happen before the last messages to the + # internal service, like Jupyter hooks. + for hook in self._teardown_hooks: + if hook.stage == TeardownStage.EARLY: + hook.call() + + # Early-stage hooks may use methods that require _is_finished + # to be False, so we set this after running those hooks. + self._is_finished = True + self._wl.remove_active_run(self) + + try: + self._atexit_cleanup(exit_code=exit_code) + + # Run hooks that should happen after the last messages to the + # internal service, like detaching the logger. + for hook in self._teardown_hooks: + if hook.stage == TeardownStage.LATE: + hook.call() + self._teardown_hooks = [] + + # Inform the service that we're done sending messages for this run. + # + # TODO: Why not do this in _atexit_cleanup()? 
+ if self._settings.run_id: + service = self._wl.assert_service() + service.inform_finish(run_id=self._settings.run_id) + + finally: + if wandb.run is self: + module.unset_globals() + get_sentry().end_session() + + @_log_to_run + @_raise_if_finished + @_attach + def status( + self, + ) -> RunStatus: + """Get sync info from the internal backend, about the current run's sync status.""" + if not self._backend or not self._backend.interface: + return RunStatus() + + handle_run_status = self._backend.interface.deliver_request_run_status() + result = handle_run_status.wait_or(timeout=None) + sync_data = result.response.run_status_response + + sync_time = None + if sync_data.sync_time.seconds: + sync_time = datetime.fromtimestamp( + sync_data.sync_time.seconds + sync_data.sync_time.nanos / 1e9 + ) + return RunStatus( + sync_items_total=sync_data.sync_items_total, + sync_items_pending=sync_data.sync_items_pending, + sync_time=sync_time, + ) + + def _add_panel( + self, visualize_key: str, panel_type: str, panel_config: dict + ) -> None: + config = { + "panel_type": panel_type, + "panel_config": panel_config, + } + self._config_callback(val=config, key=("_wandb", "visualize", visualize_key)) + + def _redirect( + self, + stdout_slave_fd: int | None, + stderr_slave_fd: int | None, + console: str | None = None, + ) -> None: + if console is None: + console = self._settings.console + # only use raw for service to minimize potential changes + if console == "wrap": + console = "wrap_raw" + logger.info("redirect: %s", console) + + out_redir: redirect.RedirectBase + err_redir: redirect.RedirectBase + + # raw output handles the output_log writing in the internal process + if console in {"redirect", "wrap_emu"}: + output_log_path = os.path.join( + self._settings.files_dir, filenames.OUTPUT_FNAME + ) + # output writer might have been set up, see wrap_fallback case + if not self._output_writer: + self._output_writer = filesystem.CRDedupedFile( + open(output_log_path, "wb") + ) + + if 
console == "redirect": + logger.info("Redirecting console.") + out_redir = redirect.Redirect( + src="stdout", + cbs=[ + lambda data: self._console_callback("stdout", data), + self._output_writer.write, # type: ignore + ], + flush_periodically=(self._settings.mode == "online"), + ) + err_redir = redirect.Redirect( + src="stderr", + cbs=[ + lambda data: self._console_callback("stderr", data), + self._output_writer.write, # type: ignore + ], + flush_periodically=(self._settings.mode == "online"), + ) + if os.name == "nt": + + def wrap_fallback() -> None: + if self._out_redir: + self._out_redir.uninstall() + if self._err_redir: + self._err_redir.uninstall() + msg = ( + "Tensorflow detected. Stream redirection is not supported " + "on Windows when tensorflow is imported. Falling back to " + "wrapping stdout/err." + ) + wandb.termlog(msg) + self._redirect(None, None, console="wrap") + + add_import_hook("tensorflow", wrap_fallback) + elif console == "wrap_emu": + logger.info("Wrapping output streams.") + out_redir = redirect.StreamWrapper( + src="stdout", + cbs=[ + lambda data: self._console_callback("stdout", data), + self._output_writer.write, # type: ignore + ], + flush_periodically=(self._settings.mode == "online"), + ) + err_redir = redirect.StreamWrapper( + src="stderr", + cbs=[ + lambda data: self._console_callback("stderr", data), + self._output_writer.write, # type: ignore + ], + flush_periodically=(self._settings.mode == "online"), + ) + elif console == "wrap_raw": + logger.info("Wrapping output streams.") + out_redir = redirect.StreamRawWrapper( + src="stdout", + cbs=[ + lambda data: self._console_raw_callback("stdout", data), + ], + ) + err_redir = redirect.StreamRawWrapper( + src="stderr", + cbs=[ + lambda data: self._console_raw_callback("stderr", data), + ], + ) + elif console == "off": + return + else: + raise ValueError("unhandled console") + try: + # save stdout and stderr before installing new write functions + out_redir.install() + err_redir.install() 
+ self._out_redir = out_redir + self._err_redir = err_redir + logger.info("Redirects installed.") + except Exception as e: + wandb.termwarn(f"Failed to redirect: {e}") + logger.exception("Failed to redirect.") + return + + def _restore(self) -> None: + logger.info("restore") + # TODO(jhr): drain and shutdown all threads + if self._out_redir: + self._out_redir.uninstall() + if self._err_redir: + self._err_redir.uninstall() + logger.info("restore done") + + def _atexit_cleanup(self, exit_code: int | None = None) -> None: + if self._backend is None: + logger.warning("process exited without backend configured") + return + if self._atexit_cleanup_called: + return + self._atexit_cleanup_called = True + + exit_code = exit_code or (self._hooks and self._hooks.exit_code) or 0 + self._exit_code = exit_code + logger.info(f"got exitcode: {exit_code}") + + # Delete this run's "resume" file if the run finished successfully. + # + # This is used by the "auto" resume mode, which resumes from the last + # failed (or unfinished/crashed) run. If we reach this line, then this + # run shouldn't be a candidate for "auto" resume. 
+ if exit_code == 0: + if os.path.exists(self._settings.resume_fname): + os.remove(self._settings.resume_fname) + + try: + self._on_finish() + + except KeyboardInterrupt: + if not wandb.wandb_agent._is_running(): # type: ignore + wandb.termerror("Control-C detected -- Run data was not synced") + raise + + except Exception: + self._console_stop() + logger.exception("Problem finishing run") + wandb.termerror("Problem finishing run") + raise + + Run._footer( + sampled_history=self._sampled_history, + final_summary=self._final_summary, + poll_exit_response=self._poll_exit_response, + internal_messages_response=self._internal_messages_response, + settings=self._settings, + printer=self._printer, + ) + + def _console_start(self) -> None: + logger.info("atexit reg") + self._hooks = ExitHooks() + + self._redirect(self._stdout_slave_fd, self._stderr_slave_fd) + + def _console_stop(self) -> None: + self._restore() + if self._output_writer: + self._output_writer.close() + self._output_writer = None + + def _on_start(self) -> None: + self._header() + + if self._settings.save_code and self._settings.code_dir is not None: + self.log_code(self._settings.code_dir) + + if self._settings.x_save_requirements: + if self._backend and self._backend.interface: + from wandb.util import working_set + + logger.debug( + "Saving list of pip packages installed into the current environment" + ) + self._backend.interface.publish_python_packages(working_set()) + + if self._backend and self._backend.interface and not self._settings._offline: + assert self._settings.run_id + self._run_status_checker = RunStatusChecker( + self._settings.run_id, + interface=self._backend.interface, + settings=self._settings, + ) + self._run_status_checker.start() + + self._console_start() + self._on_ready() + + def _on_attach(self) -> None: + """Event triggered when run is attached to another run.""" + with telemetry.context(run=self) as tel: + tel.feature.attach = True + + self._is_attached = True + self._on_ready() 
+ + def _register_telemetry_import_hooks( + self, + ) -> None: + def _telemetry_import_hook( + run: Run, + module: Any, + ) -> None: + with telemetry.context(run=run) as tel: + try: + name = getattr(module, "__name__", None) + if name is not None: + setattr(tel.imports_finish, name, True) + except AttributeError: + return + + import_telemetry_set = telemetry.list_telemetry_imports() + import_hook_fn = functools.partial(_telemetry_import_hook, self) + if not self._settings.run_id: + return + for module_name in import_telemetry_set: + register_post_import_hook( + import_hook_fn, + self._settings.run_id, + module_name, + ) + + def _on_ready(self) -> None: + """Event triggered when run is ready for the user.""" + assert self._wl + self._wl.add_active_run(self) + + self._register_telemetry_import_hooks() + + # start reporting any telemetry changes + self._telemetry_obj_active = True + self._telemetry_flush() + + try: + self._detect_and_apply_job_inputs() + except Exception: + logger.exception("Problem applying launch job inputs") + + # object is about to be returned to the user, don't let them modify it + self._freeze() + + if not self._settings.resume: + if os.path.exists(self._settings.resume_fname): + os.remove(self._settings.resume_fname) + + def _detect_and_apply_job_inputs(self) -> None: + """If the user has staged launch inputs, apply them to the run.""" + from wandb.sdk.launch.inputs.internal import StagedLaunchInputs + + StagedLaunchInputs().apply(self) + + def _make_job_source_reqs(self) -> tuple[list[str], dict[str, Any], dict[str, Any]]: + from wandb.util import working_set + + installed_packages_list = sorted(f"{d.key}=={d.version}" for d in working_set()) + input_types = TypeRegistry.type_of(self.config.as_dict()).to_json() + output_types = TypeRegistry.type_of(self.summary._as_dict()).to_json() + + return installed_packages_list, input_types, output_types + + def _construct_job_artifact( + self, + name: str, + source_dict: JobSourceDict, + 
installed_packages_list: list[str], + patch_path: os.PathLike | None = None, + ) -> Artifact: + from wandb.sdk.artifacts._internal_artifact import InternalArtifact + from wandb.sdk.internal import job_builder + + job_artifact = InternalArtifact(name, job_builder.JOB_ARTIFACT_TYPE) + if patch_path and os.path.exists(patch_path): + job_artifact.add_file(FilePathStr(patch_path), "diff.patch") + with job_artifact.new_file("requirements.frozen.txt") as f: + f.write("\n".join(installed_packages_list)) + with job_artifact.new_file("wandb-job.json") as f: + f.write(json.dumps(source_dict)) + + return job_artifact + + def _create_image_job( + self, + input_types: dict[str, Any], + output_types: dict[str, Any], + installed_packages_list: list[str], + docker_image_name: str | None = None, + args: list[str] | None = None, + ) -> Artifact | None: + docker_image_name = docker_image_name or os.getenv("WANDB_DOCKER") + + if not docker_image_name: + return None + + name = wandb.util.make_artifact_name_safe(f"job-{docker_image_name}") + s_args: Sequence[str] = args if args is not None else self._settings._args + source_info: JobSourceDict = { + "_version": "v0", + "source_type": "image", + "source": {"image": docker_image_name, "args": s_args}, + "input_types": input_types, + "output_types": output_types, + "runtime": self._settings._python, + } + job_artifact = self._construct_job_artifact( + name, source_info, installed_packages_list + ) + + return job_artifact + + def _log_job_artifact_with_image( + self, docker_image_name: str, args: list[str] | None = None + ) -> Artifact: + packages, in_types, out_types = self._make_job_source_reqs() + job_artifact = self._create_image_job( + in_types, + out_types, + packages, + args=args, + docker_image_name=docker_image_name, + ) + + assert job_artifact + artifact = self.log_artifact(job_artifact) + + if not artifact: + raise wandb.Error(f"Job Artifact log unsuccessful: {artifact}") + else: + return artifact + + def _on_finish(self) -> None: 
+ trigger.call("on_finished") + + if self._run_status_checker is not None: + self._run_status_checker.stop() + + self._console_stop() # TODO: there's a race here with jupyter console logging + + assert self._backend and self._backend.interface + + if self._settings.x_update_finish_state: + exit_handle = self._backend.interface.deliver_exit(self._exit_code) + else: + exit_handle = self._backend.interface.deliver_finish_without_exit() + + with progress.progress_printer( + self._printer, + default_text="Finishing up...", + ) as progress_printer: + # Wait for the run to complete. + wait_with_progress( + exit_handle, + timeout=None, + display_progress=functools.partial( + progress.loop_printing_operation_stats, + progress_printer, + self._backend.interface, + ), + ) + + poll_exit_handle = self._backend.interface.deliver_poll_exit() + result = poll_exit_handle.wait_or(timeout=None) + self._poll_exit_response = result.response.poll_exit_response + + internal_messages_handle = self._backend.interface.deliver_internal_messages() + result = internal_messages_handle.wait_or(timeout=None) + self._internal_messages_response = result.response.internal_messages_response + + # dispatch all our final requests + + final_summary_handle = self._backend.interface.deliver_get_summary() + sampled_history_handle = ( + self._backend.interface.deliver_request_sampled_history() + ) + + result = sampled_history_handle.wait_or(timeout=None) + self._sampled_history = result.response.sampled_history_response + + result = final_summary_handle.wait_or(timeout=None) + self._final_summary = result.response.get_summary_response + + if self._backend: + self._backend.cleanup() + + if self._run_status_checker: + self._run_status_checker.join() + + if self._settings.run_id: + self._unregister_telemetry_import_hooks(self._settings.run_id) + + @staticmethod + def _unregister_telemetry_import_hooks(run_id: str) -> None: + import_telemetry_set = telemetry.list_telemetry_imports() + for module_name in 
import_telemetry_set: + unregister_post_import_hook(module_name, run_id) + + @_log_to_run + @_raise_if_finished + @_attach + def define_metric( + self, + name: str, + step_metric: str | wandb_metric.Metric | None = None, + step_sync: bool | None = None, + hidden: bool | None = None, + summary: str | None = None, + goal: str | None = None, + overwrite: bool | None = None, + ) -> wandb_metric.Metric: + """Customize metrics logged with `wandb.Run.log()`. + + Args: + name: The name of the metric to customize. + step_metric: The name of another metric to serve as the X-axis + for this metric in automatically generated charts. + step_sync: Automatically insert the last value of step_metric into + `wandb.Run.log()` if it is not provided explicitly. Defaults to True + if step_metric is specified. + hidden: Hide this metric from automatic plots. + summary: Specify aggregate metrics added to summary. + Supported aggregations include "min", "max", "mean", "last", + "first", "best", "copy" and "none". "none" prevents a summary + from being generated. "best" is used together with the goal + parameter, "best" is deprecated and should not be used, use + "min" or "max" instead. "copy" is deprecated and should not be + used. + goal: Specify how to interpret the "best" summary type. + Supported options are "minimize" and "maximize". "goal" is + deprecated and should not be used, use "min" or "max" instead. + overwrite: If false, then this call is merged with previous + `define_metric` calls for the same metric by using their + values for any unspecified parameters. If true, then + unspecified parameters overwrite values specified by + previous calls. + + Returns: + An object that represents this call but can otherwise be discarded. 
+ """ + if summary and "copy" in summary: + deprecation.warn_and_record_deprecation( + feature=Deprecated(run__define_metric_copy=True), + message="define_metric(summary='copy') is deprecated and will be removed.", + run=self, + ) + + if (summary and "best" in summary) or goal is not None: + deprecation.warn_and_record_deprecation( + feature=Deprecated(run__define_metric_best_goal=True), + message="define_metric(summary='best', goal=...) is deprecated and will be removed. " + "Use define_metric(summary='min') or define_metric(summary='max') instead.", + run=self, + ) + + return self._define_metric( + name, + step_metric, + step_sync, + hidden, + summary, + goal, + overwrite, + ) + + def _define_metric( + self, + name: str, + step_metric: str | wandb_metric.Metric | None = None, + step_sync: bool | None = None, + hidden: bool | None = None, + summary: str | None = None, + goal: str | None = None, + overwrite: bool | None = None, + ) -> wandb_metric.Metric: + if not name: + raise wandb.Error("define_metric() requires non-empty name argument") + if isinstance(step_metric, wandb_metric.Metric): + step_metric = step_metric.name + for arg_name, arg_val, exp_type in ( + ("name", name, str), + ("step_metric", step_metric, str), + ("step_sync", step_sync, bool), + ("hidden", hidden, bool), + ("summary", summary, str), + ("goal", goal, str), + ("overwrite", overwrite, bool), + ): + # NOTE: type checking is broken for isinstance and str + if arg_val is not None and not isinstance(arg_val, exp_type): + arg_type = type(arg_val).__name__ + raise wandb.Error( + f"Unhandled define_metric() arg: {arg_name} type: {arg_type}" + ) + stripped = name[:-1] if name.endswith("*") else name + if "*" in stripped: + raise wandb.Error( + f"Unhandled define_metric() arg: name (glob suffixes only): {name}" + ) + summary_ops: Sequence[str] | None = None + if summary: + summary_items = [s.lower() for s in summary.split(",")] + summary_ops = [] + valid = {"min", "max", "mean", "best", "last", 
"copy", "none", "first"} + # TODO: deprecate copy and best + for i in summary_items: + if i not in valid: + raise wandb.Error(f"Unhandled define_metric() arg: summary op: {i}") + summary_ops.append(i) + with telemetry.context(run=self) as tel: + tel.feature.metric_summary = True + # TODO: deprecate goal + goal_cleaned: str | None = None + if goal is not None: + goal_cleaned = goal[:3].lower() + valid_goal = {"min", "max"} + if goal_cleaned not in valid_goal: + raise wandb.Error(f"Unhandled define_metric() arg: goal: {goal}") + with telemetry.context(run=self) as tel: + tel.feature.metric_goal = True + if hidden: + with telemetry.context(run=self) as tel: + tel.feature.metric_hidden = True + if step_sync: + with telemetry.context(run=self) as tel: + tel.feature.metric_step_sync = True + + with telemetry.context(run=self) as tel: + tel.feature.metric = True + + m = wandb_metric.Metric( + name=name, + step_metric=step_metric, + step_sync=step_sync, + summary=summary_ops, + hidden=hidden, + goal=goal_cleaned, + overwrite=overwrite, + ) + m._set_callback(self._metric_callback) + m._commit() + return m + + @_log_to_run + @_attach + def watch( + self, + models: torch.nn.Module | Sequence[torch.nn.Module], + criterion: torch.F | None = None, # type: ignore + log: Literal["gradients", "parameters", "all"] | None = "gradients", + log_freq: int = 1000, + idx: int | None = None, + log_graph: bool = False, + ) -> None: + """Hook into given PyTorch model to monitor gradients and the model's computational graph. + + This function can track parameters, gradients, or both during training. + + Args: + models: A single model or a sequence of models to be monitored. + criterion: The loss function being optimized (optional). + log: Specifies whether to log "gradients", "parameters", or "all". + Set to None to disable logging. (default="gradients"). + log_freq: Frequency (in batches) to log gradients and parameters. 
(default=1000) + idx: Index used when tracking multiple models with `wandb.watch`. (default=None) + log_graph: Whether to log the model's computational graph. (default=False) + + Raises: + ValueError: + If `wandb.init()` has not been called or if any of the models are not instances + of `torch.nn.Module`. + """ + wandb.sdk._watch(self, models, criterion, log, log_freq, idx, log_graph) + + @_log_to_run + @_attach + def unwatch( + self, models: torch.nn.Module | Sequence[torch.nn.Module] | None = None + ) -> None: + """Remove pytorch model topology, gradient and parameter hooks. + + Args: + models: Optional list of pytorch models that have had watch called on them. + """ + wandb.sdk._unwatch(self, models=models) + + @_log_to_run + @_raise_if_finished + @_attach + def link_artifact( + self, + artifact: Artifact, + target_path: str, + aliases: list[str] | None = None, + ) -> Artifact: + """Link the artifact to a collection. + + The term “link” refers to pointers that connect where W&B stores the + artifact and where the artifact is accessible in the registry. W&B + does not duplicate artifacts when you link an artifact to a collection. + + View linked artifacts in the Registry UI for the specified collection. + + Args: + artifact: The artifact object to link to the collection. + target_path: The path of the collection. Path consists of the prefix + "wandb-registry-" along with the registry name and the + collection name `wandb-registry-{REGISTRY_NAME}/{COLLECTION_NAME}`. + aliases: Add one or more aliases to the linked artifact. The + "latest" alias is automatically applied to the most recent artifact + you link. + + Returns: + The linked artifact. 
+ + """ + from .artifacts._validators import ArtifactPath + + if artifact.is_draft() and not artifact._is_draft_save_started(): + artifact = self._log_artifact(artifact) + + if self._settings._offline: + # TODO: implement offline mode + sync + raise NotImplementedError + + # Normalize the target "entity/project/collection" with defaults + # inferred from this run's entity and project, if needed. + # + # HOWEVER, if the target path is a registry collection, avoid setting + # the target entity to the run's entity. Instead, delegate to + # Artifact.link() to resolve the required org entity. + target = ArtifactPath.from_str(target_path) + if not target.is_registry_path(): + target = target.with_defaults(prefix=self.entity, project=self.project) + + return artifact.link(target.to_str(), aliases) + + @_log_to_run + @_raise_if_finished + @_attach + def use_artifact( + self, + artifact_or_name: str | Artifact, + type: str | None = None, + aliases: list[str] | None = None, + use_as: str | None = None, + ) -> Artifact: + """Declare an artifact as an input to a run. + + Call `download` or `file` on the returned object to get the contents locally. + + Args: + artifact_or_name: The name of the artifact to use. May be prefixed + with the name of the project the artifact was logged to + ("entity" or "entity/project"). If no + entity is specified in the name, the Run or API setting's entity is used. + Valid names can be in the following forms + - name:version + - name:alias + type: The type of artifact to use. + aliases: Aliases to apply to this artifact + use_as: This argument is deprecated and does nothing. + + Returns: + An `Artifact` object. 
+ + Examples: + ```python + import wandb + + run = wandb.init(project="") + + # Use an artifact by name and alias + artifact_a = run.use_artifact(artifact_or_name=":") + + # Use an artifact by name and version + artifact_b = run.use_artifact(artifact_or_name=":v") + + # Use an artifact by entity/project/name:alias + artifact_c = run.use_artifact( + artifact_or_name="//:" + ) + + # Use an artifact by entity/project/name:version + artifact_d = run.use_artifact( + artifact_or_name="//:v" + ) + + # Explicitly finish the run since a context manager is not used. + run.finish() + ``` + + """ + from wandb.apis import internal + from wandb.sdk.artifacts.artifact import Artifact + + if self._settings._offline: + raise TypeError("Cannot use artifact when in offline mode.") + + api = internal.Api( + default_settings={ + "entity": self._settings.entity, + "project": self._settings.project, + } + ) + api.set_current_run_id(self._settings.run_id) + + if use_as is not None: + deprecation.warn_and_record_deprecation( + feature=Deprecated(run__use_artifact_use_as=True), + message=( + "`use_as` argument is deprecated and does not affect the behaviour of `run.use_artifact`" + ), + ) + + if isinstance(artifact_or_name, str): + name = artifact_or_name + public_api = self._public_api() + artifact = public_api._artifact(type=type, name=name) + if type is not None and type != artifact.type: + raise ValueError( + f"Supplied type {type} does not match type {artifact.type} of artifact {artifact.name}" + ) + api.use_artifact( + artifact.id, + entity_name=self._settings.entity, + project_name=self._settings.project, + artifact_entity_name=artifact.entity, + artifact_project_name=artifact.project, + ) + else: + artifact = artifact_or_name + if aliases is None: + aliases = [] + elif isinstance(aliases, str): + aliases = [aliases] + if isinstance(artifact_or_name, Artifact) and artifact.is_draft(): + if use_as is not None: + wandb.termwarn( + "Indicating use_as is not supported when using a draft 
artifact" + ) + self._log_artifact( + artifact, + aliases=aliases, + is_user_created=True, + use_after_commit=True, + ) + artifact.wait() + elif isinstance(artifact, Artifact) and not artifact.is_draft(): + api.use_artifact( + artifact.id, + artifact_entity_name=artifact.entity, + artifact_project_name=artifact.project, + ) + else: + raise ValueError( + 'You must pass an artifact name (e.g. "pedestrian-dataset:v1"), ' + "an instance of `wandb.Artifact`, or `wandb.Api().artifact()` to `use_artifact`" + ) + if self._backend and self._backend.interface: + self._backend.interface.publish_use_artifact(artifact) + return artifact + + @_log_to_run + @_raise_if_finished + @_attach + def log_artifact( + self, + artifact_or_path: Artifact | StrPath, + name: str | None = None, + type: str | None = None, + aliases: list[str] | None = None, + tags: list[str] | None = None, + ) -> Artifact: + """Declare an artifact as an output of a run. + + Args: + artifact_or_path: (str or Artifact) A path to the contents of this artifact, + can be in the following forms: + - `/local/directory` + - `/local/directory/file.txt` + - `s3://bucket/path` + You can also pass an Artifact object created by calling + `wandb.Artifact`. + name: (str, optional) An artifact name. Valid names can be in the following forms: + - name:version + - name:alias + - digest + This will default to the basename of the path prepended with the current + run id if not specified. + type: (str) The type of artifact to log, examples include `dataset`, `model` + aliases: (list, optional) Aliases to apply to this artifact, + defaults to `["latest"]` + tags: (list, optional) Tags to apply to this artifact, if any. + + Returns: + An `Artifact` object. 
+ """ + return self._log_artifact( + artifact_or_path, + name=name, + type=type, + aliases=aliases, + tags=tags, + ) + + @_log_to_run + @_raise_if_finished + @_attach + def upsert_artifact( + self, + artifact_or_path: Artifact | str, + name: str | None = None, + type: str | None = None, + aliases: list[str] | None = None, + distributed_id: str | None = None, + ) -> Artifact: + """Declare (or append to) a non-finalized artifact as output of a run. + + Note that you must call run.finish_artifact() to finalize the artifact. + This is useful when distributed jobs need to all contribute to the same artifact. + + Args: + artifact_or_path: A path to the contents of this artifact, + can be in the following forms: + - `/local/directory` + - `/local/directory/file.txt` + - `s3://bucket/path` + name: An artifact name. May be prefixed with "entity/project". Defaults + to the basename of the path prepended with the current run ID + if not specified. Valid names can be in the following forms: + - name:version + - name:alias + - digest + type: The type of artifact to log. Common examples include `dataset`, `model`. + aliases: Aliases to apply to this artifact, defaults to `["latest"]`. + distributed_id: Unique string that all distributed jobs share. If None, + defaults to the run's group name. + + Returns: + An `Artifact` object. 
+ """ + if self._settings.run_group is None and distributed_id is None: + raise TypeError( + "Cannot upsert artifact unless run is in a group or distributed_id is provided" + ) + if distributed_id is None: + distributed_id = self._settings.run_group or "" + return self._log_artifact( + artifact_or_path, + name=name, + type=type, + aliases=aliases, + distributed_id=distributed_id, + finalize=False, + ) + + @_log_to_run + @_raise_if_finished + @_attach + def finish_artifact( + self, + artifact_or_path: Artifact | str, + name: str | None = None, + type: str | None = None, + aliases: list[str] | None = None, + distributed_id: str | None = None, + ) -> Artifact: + """Finishes a non-finalized artifact as output of a run. + + Subsequent "upserts" with the same distributed ID will result in a new version. + + Args: + artifact_or_path: A path to the contents of this artifact, + can be in the following forms: + - `/local/directory` + - `/local/directory/file.txt` + - `s3://bucket/path` + You can also pass an Artifact object created by calling + `wandb.Artifact`. + name: An artifact name. May be prefixed with entity/project. + Valid names can be in the following forms: + - name:version + - name:alias + - digest + This will default to the basename of the path prepended with the current + run id if not specified. + type: The type of artifact to log, examples include `dataset`, `model` + aliases: Aliases to apply to this artifact, + defaults to `["latest"]` + distributed_id: Unique string that all distributed jobs share. If None, + defaults to the run's group name. + + Returns: + An `Artifact` object. 
+ """ + if self._settings.run_group is None and distributed_id is None: + raise TypeError( + "Cannot finish artifact unless run is in a group or distributed_id is provided" + ) + if distributed_id is None: + distributed_id = self._settings.run_group or "" + + return self._log_artifact( + artifact_or_path, + name, + type, + aliases, + distributed_id=distributed_id, + finalize=True, + ) + + def _log_artifact( + self, + artifact_or_path: Artifact | StrPath, + name: str | None = None, + type: str | None = None, + aliases: list[str] | None = None, + tags: list[str] | None = None, + distributed_id: str | None = None, + finalize: bool = True, + is_user_created: bool = False, + use_after_commit: bool = False, + ) -> Artifact: + from .artifacts._validators import validate_aliases, validate_tags + + if not finalize and distributed_id is None: + raise TypeError("Must provide distributed_id if artifact is not finalize") + + if aliases is not None: + aliases = validate_aliases(aliases) + + # Check if artifact tags are supported + if tags is not None: + tags = validate_tags(tags) + + artifact, aliases = self._prepare_artifact( + artifact_or_path, name, type, aliases + ) + + artifact.metadata = {**artifact.metadata} # triggers validation + + artifact.distributed_id = distributed_id + self._assert_can_log_artifact(artifact) + if self._backend and self._backend.interface: + if not self._settings._offline: + handle = self._backend.interface.deliver_artifact( + self, + artifact, + aliases, + tags, + self.step, + finalize=finalize, + is_user_created=is_user_created, + use_after_commit=use_after_commit, + ) + artifact._set_save_handle(handle, self._public_api().client) + else: + self._backend.interface.publish_artifact( + self, + artifact, + aliases, + tags, + finalize=finalize, + is_user_created=is_user_created, + use_after_commit=use_after_commit, + ) + elif self._internal_run_interface: + self._internal_run_interface.publish_artifact( + self, + artifact, + aliases, + tags, + 
finalize=finalize, + is_user_created=is_user_created, + use_after_commit=use_after_commit, + ) + return artifact + + def _public_api(self, overrides: dict[str, str] | None = None) -> PublicApi: + if self._cached_public_api is not None: + return self._cached_public_api + + # NOTE: PublicApi is only for type checking, still need to import + from wandb.apis import public + + overrides = {"run": self._settings.run_id} # type: ignore + if not self._settings._offline: + overrides["entity"] = self._settings.entity or "" + overrides["project"] = self._settings.project or "" + overrides["base_url"] = self._settings.base_url + + self._cached_public_api = public.Api(overrides, api_key=self._settings.api_key) + return self._cached_public_api + + # TODO(jhr): annotate this + def _assert_can_log_artifact(self, artifact) -> None: # type: ignore + import requests + + from wandb.sdk.artifacts.artifact import Artifact + + if self._settings._offline: + return + try: + public_api = self._public_api() + entity = public_api.settings["entity"] + project = public_api.settings["project"] + expected_type = Artifact._expected_type( + entity, project, artifact.name, public_api.client + ) + except requests.exceptions.RequestException: + # Just return early if there is a network error. This is + # ok, as this function is intended to help catch an invalid + # type early, but not a hard requirement for valid operation. 
+ return + if expected_type is not None and artifact.type != expected_type: + raise ValueError( + f"Artifact {artifact.name} already exists with type '{expected_type}'; " + f"cannot create another with type '{artifact.type}'" + ) + if entity and artifact._source_entity and entity != artifact._source_entity: + raise ValueError( + f"Artifact {artifact.name} is owned by entity " + f"'{artifact._source_entity}'; it can't be moved to '{entity}'" + ) + if project and artifact._source_project and project != artifact._source_project: + raise ValueError( + f"Artifact {artifact.name} exists in project " + f"'{artifact._source_project}'; it can't be moved to '{project}'" + ) + + def _prepare_artifact( + self, + artifact_or_path: Artifact | StrPath, + name: str | None = None, + type: str | None = None, + aliases: list[str] | None = None, + ) -> tuple[Artifact, list[str]]: + from wandb.sdk.artifacts.artifact import Artifact + + if isinstance(artifact_or_path, (str, os.PathLike)): + name = ( + name + or f"run-{self._settings.run_id}-{os.path.basename(artifact_or_path)}" + ) + artifact = Artifact(name, type or "unspecified") + if os.path.isfile(artifact_or_path): + artifact.add_file(str(artifact_or_path)) + elif os.path.isdir(artifact_or_path): + artifact.add_dir(str(artifact_or_path)) + elif "://" in str(artifact_or_path): + artifact.add_reference(str(artifact_or_path)) + else: + raise ValueError( + "path must be a file, directory or external" + "reference like s3://bucket/path" + ) + else: + artifact = artifact_or_path + if not isinstance(artifact, Artifact): + raise TypeError( + "You must pass an instance of wandb.Artifact or a " + "valid file path to log_artifact" + ) + + artifact.finalize() + return artifact, _resolve_aliases(aliases) + + @_log_to_run + @_raise_if_finished + @_attach + def log_model( + self, + path: StrPath, + name: str | None = None, + aliases: list[str] | None = None, + ) -> None: + """Logs a model artifact containing the contents inside the 'path' to a 
run and marks it as an output to this run. + + The name of model artifact can only contain alphanumeric characters, + underscores, and hyphens. + + Args: + path: (str) A path to the contents of this model, + can be in the following forms: + - `/local/directory` + - `/local/directory/file.txt` + - `s3://bucket/path` + name: A name to assign to the model artifact that + the file contents will be added to. This will default to the + basename of the path prepended with the current run id if + not specified. + aliases: Aliases to apply to the created model artifact, + defaults to `["latest"]` + + Raises: + ValueError: If name has invalid special characters. + + Returns: + None + """ + self._log_artifact( + artifact_or_path=path, name=name, type="model", aliases=aliases + ) + + @_log_to_run + @_raise_if_finished + @_attach + def use_model(self, name: str) -> FilePathStr: + """Download the files logged in a model artifact 'name'. + + Args: + name: A model artifact name. 'name' must match the name of an existing logged + model artifact. May be prefixed with `entity/project/`. Valid names + can be in the following forms + - model_artifact_name:version + - model_artifact_name:alias + + Returns: + path (str): Path to downloaded model artifact file(s). + + Raises: + AssertionError: If model artifact 'name' is of a type that does + not contain the substring 'model'. + """ + if self._settings._offline: + # Downloading artifacts is not supported when offline. + raise RuntimeError("`use_model` not supported in offline mode.") + + artifact = self.use_artifact(artifact_or_name=name) + if "model" not in str(artifact.type.lower()): + raise AssertionError( + "You can only use this method for 'model' artifacts." + " For an artifact to be a 'model' artifact, its type property" + " must contain the substring 'model'." 
+ ) + + path = artifact.download() + + # If returned directory contains only one file, return path to that file + dir_list = os.listdir(path) + if len(dir_list) == 1: + return FilePathStr(os.path.join(path, dir_list[0])) + return path + + @_log_to_run + @_raise_if_finished + @_attach + def link_model( + self, + path: StrPath, + registered_model_name: str, + name: str | None = None, + aliases: list[str] | None = None, + ) -> Artifact | None: + """Log a model artifact version and link it to a registered model in the model registry. + + Linked model versions are visible in the UI for the specified registered model. + + This method will: + - Check if 'name' model artifact has been logged. If so, use the artifact version that matches the files + located at 'path' or log a new version. Otherwise log files under 'path' as a new model artifact, 'name' + of type 'model'. + - Check if registered model with name 'registered_model_name' exists in the 'model-registry' project. + If not, create a new registered model with name 'registered_model_name'. + - Link version of model artifact 'name' to registered model, 'registered_model_name'. + - Attach aliases from 'aliases' list to the newly linked model artifact version. + + Args: + path: (str) A path to the contents of this model, can be in the + following forms: + - `/local/directory` + - `/local/directory/file.txt` + - `s3://bucket/path` + registered_model_name: The name of the registered model that the + model is to be linked to. A registered model is a collection of + model versions linked to the model registry, typically + representing a team's specific ML Task. The entity that this + registered model belongs to will be derived from the run. + name: The name of the model artifact that files in 'path' will be + logged to. This will default to the basename of the path + prepended with the current run id if not specified. + aliases: Aliases that will only be applied on this linked artifact + inside the registered model. 
The alias "latest" will always be + applied to the latest version of an artifact that is linked. + + Raises: + AssertionError: If registered_model_name is a path or + if model artifact 'name' is of a type that does not contain + the substring 'model'. + ValueError: If name has invalid special characters. + + Returns: + The linked artifact if linking was successful, otherwise `None`. + """ + name_parts = registered_model_name.split("/") + if len(name_parts) != 1: + raise AssertionError( + "Please provide only the name of the registered model." + " Do not append the entity or project name." + ) + + project = "model-registry" + target_path = self.entity + "/" + project + "/" + registered_model_name + + public_api = self._public_api() + try: + artifact = public_api._artifact(name=f"{name}:latest") + if "model" not in str(artifact.type.lower()): + raise AssertionError( + "You can only use this method for 'model' artifacts." + " For an artifact to be a 'model' artifact, its type" + " property must contain the substring 'model'." + ) + + artifact = self._log_artifact( + artifact_or_path=path, name=name, type=artifact.type + ) + except (ValueError, CommError): + artifact = self._log_artifact( + artifact_or_path=path, name=name, type="model" + ) + return self.link_artifact( + artifact=artifact, target_path=target_path, aliases=aliases + ) + + @_log_to_run + @_raise_if_finished + @_attach + def alert( + self, + title: str, + text: str, + level: str | AlertLevel | None = None, + wait_duration: int | float | timedelta | None = None, + ) -> None: + """Create an alert with the given title and text. + + Args: + title: The title of the alert, must be less than 64 characters long. + text: The text body of the alert. + level: The alert level to use, either: `INFO`, `WARN`, or `ERROR`. + wait_duration: The time to wait (in seconds) before sending another + alert with this title. 
+ """ + level = level or AlertLevel.INFO + level_str: str = level.value if isinstance(level, AlertLevel) else level + if level_str not in {lev.value for lev in AlertLevel}: + raise ValueError("level must be one of 'INFO', 'WARN', or 'ERROR'") + + wait_duration = wait_duration or timedelta(minutes=1) + if isinstance(wait_duration, int) or isinstance(wait_duration, float): + wait_duration = timedelta(seconds=wait_duration) + elif not callable(getattr(wait_duration, "total_seconds", None)): + raise TypeError( + "wait_duration must be an int, float, or datetime.timedelta" + ) + wait_duration = int(wait_duration.total_seconds() * 1000) + + if self._backend and self._backend.interface: + self._backend.interface.publish_alert(title, text, level_str, wait_duration) + + def __enter__(self) -> Run: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool: + exception_raised = exc_type is not None + if exception_raised: + traceback.print_exception(exc_type, exc_val, exc_tb) + exit_code = 1 if exception_raised else 0 + self._finish(exit_code=exit_code) + return not exception_raised + + @_log_to_run + @_raise_if_finished + @_attach + def mark_preempting(self) -> None: + """Mark this run as preempting. + + Also tells the internal process to immediately report this to server. + """ + if self._backend and self._backend.interface: + self._backend.interface.publish_preempting() + + @property + @_log_to_run + @_raise_if_finished + @_attach + def _system_metrics(self) -> dict[str, list[tuple[datetime, float]]]: + """Returns a dictionary of system metrics. + + Returns: + A dictionary of system metrics. 
+ """ + from wandb.proto import wandb_internal_pb2 + + def pb_to_dict( + system_metrics_pb: wandb_internal_pb2.GetSystemMetricsResponse, + ) -> dict[str, list[tuple[datetime, float]]]: + res = {} + + for metric, records in system_metrics_pb.system_metrics.items(): + measurements = [] + for record in records.record: + # Convert timestamp to datetime + dt = datetime.fromtimestamp( + record.timestamp.seconds, tz=timezone.utc + ) + dt = dt.replace(microsecond=record.timestamp.nanos // 1000) + + measurements.append((dt, record.value)) + + res[metric] = measurements + + return res + + if not self._backend or not self._backend.interface: + return {} + + handle = self._backend.interface.deliver_get_system_metrics() + + try: + result = handle.wait_or(timeout=1) + except TimeoutError: + return {} + else: + try: + response = result.response.get_system_metrics_response + return pb_to_dict(response) if response else {} + except Exception: + logger.exception("Error getting system metrics.") + return {} + + # ------------------------------------------------------------------------------ + # HEADER + # ------------------------------------------------------------------------------ + def _header(self) -> None: + self._header_wandb_version_info() + self._header_sync_info() + self._header_run_info() + + def _header_wandb_version_info(self) -> None: + if self._settings.quiet or self._settings.silent: + return + + # TODO: add this to a higher verbosity level + self._printer.display(f"Tracking run with wandb version {wandb.__version__}") + + def _header_sync_info(self) -> None: + sync_location_msg = f"Run data is saved locally in {self._printer.files(self._settings.sync_dir)}" + + if self._settings._offline: + offline_warning = ( + f"W&B syncing is set to {self._printer.code('`offline`')} " + f"in this directory. Run {self._printer.code('`wandb online`')} " + f"or set {self._printer.code('WANDB_MODE=online')} " + "to enable cloud syncing." 
+ ) + self._printer.display([offline_warning, sync_location_msg]) + else: + messages = [sync_location_msg] + + if not self._printer.supports_html: + disable_sync_msg = ( + f"Run {self._printer.code('`wandb offline`')} to turn off syncing." + ) + messages.append(disable_sync_msg) + + if not self._settings.quiet and not self._settings.silent: + self._printer.display(messages) + + def _header_run_info(self) -> None: + settings, printer = self._settings, self._printer + + if settings._offline or settings.silent: + return + + run_url = settings.run_url + project_url = settings.project_url + sweep_url = settings.sweep_url + + run_state_str = ( + "Resuming run" + if settings.resumed or settings.resume_from + else "Syncing run" + ) + run_name = settings.run_name + if not run_name: + return + + if printer.supports_html: + import wandb.jupyter + + if not wandb.jupyter.display_if_magic_is_used(self): + run_line = f"{printer.link(run_url, run_name)}" + project_line, sweep_line = "", "" + + if not settings.quiet: + doc_html = printer.link(url_registry.url("developer-guide"), "docs") + + project_html = printer.link(project_url, "Weights & Biases") + project_line = f"to {project_html} ({doc_html})" + + if sweep_url: + sweep_line = f"Sweep page: {printer.link(sweep_url, sweep_url)}" + + printer.display( + [f"{run_state_str} {run_line} {project_line}", sweep_line], + ) + + elif run_name: + printer.display(f"{run_state_str} {printer.name(run_name)}") + + if not settings.quiet: + # TODO: add verbosity levels and add this to higher levels + printer.display( + f"{printer.emoji('star')} View project at {printer.link(project_url)}" + ) + if sweep_url: + printer.display( + f"{printer.emoji('broom')} View sweep at {printer.link(sweep_url)}" + ) + printer.display( + f"{printer.emoji('rocket')} View run at {printer.link(run_url)}", + ) + + # ------------------------------------------------------------------------------ + # FOOTER + # 
------------------------------------------------------------------------------ + # Note: All the footer methods are static methods since we want to share the printing logic + # with the service execution path that doesn't have access to the run instance + @staticmethod + def _footer( + sampled_history: SampledHistoryResponse | None = None, + final_summary: GetSummaryResponse | None = None, + poll_exit_response: PollExitResponse | None = None, + internal_messages_response: InternalMessagesResponse | None = None, + *, + settings: Settings, + printer: printer.Printer, + ) -> None: + Run._footer_history_summary_info( + history=sampled_history, + summary=final_summary, + settings=settings, + printer=printer, + ) + + Run._footer_sync_info( + poll_exit_response=poll_exit_response, + settings=settings, + printer=printer, + ) + Run._footer_log_dir_info(settings=settings, printer=printer) + Run._footer_internal_messages( + internal_messages_response=internal_messages_response, + settings=settings, + printer=printer, + ) + + @staticmethod + def _footer_sync_info( + poll_exit_response: PollExitResponse | None = None, + *, + settings: Settings, + printer: printer.Printer, + ) -> None: + if settings.silent: + return + + if settings._offline: + if not settings.quiet: + printer.display( + [ + "You can sync this run to the cloud by running:", + printer.code(f"wandb sync {settings.sync_dir}"), + ], + ) + return + + info = [] + if settings.run_name and settings.run_url: + info.append( + f"{printer.emoji('rocket')} View run {printer.name(settings.run_name)} at: {printer.link(settings.run_url)}" + ) + if settings.project_url: + info.append( + f"{printer.emoji('star')} View project at: {printer.link(settings.project_url)}" + ) + if poll_exit_response and poll_exit_response.file_counts: + logger.info("logging synced files") + file_counts = poll_exit_response.file_counts + info.append( + f"Synced {file_counts.wandb_count} W&B file(s), {file_counts.media_count} media file(s), " + 
f"{file_counts.artifact_count} artifact file(s) and {file_counts.other_count} other file(s)", + ) + printer.display(info) + + @staticmethod + def _footer_log_dir_info( + *, + settings: Settings, + printer: printer.Printer, + ) -> None: + if settings.quiet or settings.silent: + return + + log_dir = settings.log_user or settings.log_internal + if log_dir: + log_dir = os.path.dirname(log_dir.replace(os.getcwd(), ".")) + printer.display( + f"Find logs at: {printer.files(log_dir)}", + ) + + @staticmethod + def _footer_history_summary_info( + history: SampledHistoryResponse | None = None, + summary: GetSummaryResponse | None = None, + *, + settings: Settings, + printer: printer.Printer, + ) -> None: + if settings.quiet or settings.silent: + return + + panel: list[str] = [] + + if history and ( + history_grid := Run._footer_history(history, printer, settings) + ): + panel.append(history_grid) + + if summary and ( + summary_grid := Run._footer_summary(summary, printer, settings) + ): + panel.append(summary_grid) + + if panel: + printer.display(printer.panel(panel)) + + @staticmethod + def _footer_history( + history: SampledHistoryResponse, + printer: printer.Printer, + settings: Settings, + ) -> str | None: + """Returns the run history formatted for printing to the console.""" + sorted_history_items = sorted( + (item for item in history.item if not item.key.startswith("_")), + key=lambda item: item.key, + ) + + history_rows: list[list[str]] = [] + for item in sorted_history_items: + if len(history_rows) >= settings.max_end_of_run_history_metrics: + break + + values = wandb.util.downsample( + item.values_float or item.values_int, + 40, + ) + + if sparkline := printer.sparklines(values): + history_rows.append([item.key, sparkline]) + + if not history_rows: + return None + + if len(history_rows) < len(sorted_history_items): + remaining = len(sorted_history_items) - len(history_rows) + history_rows.append([f"+{remaining:,d}", "..."]) + + return printer.grid(history_rows, "Run 
history:") + + @staticmethod + def _footer_summary( + summary: GetSummaryResponse, + printer: printer.Printer, + settings: Settings, + ) -> str | None: + """Returns the run summary formatted for printing to the console.""" + sorted_summary_items = sorted( + ( + item + for item in summary.item + if not item.key.startswith("_") and not item.nested_key + ), + key=lambda item: item.key, + ) + + summary_rows: list[list[str]] = [] + skipped = 0 + for item in sorted_summary_items: + if len(summary_rows) >= settings.max_end_of_run_summary_metrics: + break + + try: + value = json.loads(item.value_json) + except json.JSONDecodeError: + logger.exception(f"Error decoding summary[{item.key!r}]") + skipped += 1 + continue + + if isinstance(value, str): + value = value[:20] + "..." * (len(value) >= 20) + summary_rows.append([item.key, value]) + elif isinstance(value, numbers.Number): + value = round(value, 5) if isinstance(value, float) else value + summary_rows.append([item.key, str(value)]) + else: + skipped += 1 + + if not summary_rows: + return None + + if len(summary_rows) < len(sorted_summary_items) - skipped: + remaining = len(sorted_summary_items) - len(summary_rows) - skipped + summary_rows.append([f"+{remaining:,d}", "..."]) + + return printer.grid(summary_rows, "Run summary:") + + @staticmethod + def _footer_internal_messages( + internal_messages_response: InternalMessagesResponse | None = None, + *, + settings: Settings, + printer: printer.Printer, + ) -> None: + if settings.quiet or settings.silent: + return + + if not internal_messages_response: + return + + for message in internal_messages_response.messages.warning: + printer.display(message, level="warn") + + +# We define this outside of the run context to support restoring before init +def restore( + name: str, + run_path: str | None = None, + replace: bool = False, + root: str | None = None, +) -> None | TextIO: + """Download the specified file from cloud storage. 
+ + File is placed into the current directory or run directory. + By default, will only download the file if it doesn't already exist. + + Args: + name: The name of the file. + run_path: Optional path to a run to pull files from, i.e. `username/project_name/run_id` + if wandb.init has not been called, this is required. + replace: Whether to download the file even if it already exists locally + root: The directory to download the file to. Defaults to the current + directory or the run directory if wandb.init was called. + + Returns: + None if it can't find the file, otherwise a file object open for reading. + + Raises: + CommError: If W&B can't connect to the W&B backend. + ValueError: If the file is not found or can't find run_path. + """ + from wandb.apis import public + + is_disabled = wandb.run is not None and wandb.run.disabled + run = None if is_disabled else wandb.run + if run_path is None: + if run is not None: + run_path = run.path + else: + raise ValueError( + "run_path required when calling wandb.restore before wandb.init" + ) + if root is None: + if run is not None: + root = run.dir + api = public.Api() + api_run = api.run(run_path) + if root is None: + root = os.getcwd() + path = os.path.join(root, name) + if os.path.exists(path) and replace is False: + return open(path) + if is_disabled: + return None + files = api_run.files([name]) + if len(files) == 0: + return None + # if the file does not exist, the file has an md5 of 0 + if files[0].md5 == "0": + raise ValueError(f"File {name} not found in {run_path or root}.") + return files[0].download(root=root, replace=True) + + +# propagate our doc string to the runs restore method +try: + Run.restore.__doc__ = restore.__doc__ +except AttributeError: + pass + + +def finish( + exit_code: int | None = None, + quiet: bool | None = None, +) -> None: + """Finish a run and upload any remaining data. + + Marks the completion of a W&B run and ensures all data is synced to the server. 
+ The run's final state is determined by its exit conditions and sync status. + + Run States: + - Running: Active run that is logging data and/or sending heartbeats. + - Crashed: Run that stopped sending heartbeats unexpectedly. + - Finished: Run completed successfully (`exit_code=0`) with all data synced. + - Failed: Run completed with errors (`exit_code!=0`). + + Args: + exit_code: Integer indicating the run's exit status. Use 0 for success, + any other value marks the run as failed. + quiet: Deprecated. Configure logging verbosity using `wandb.Settings(quiet=...)`. + """ + if wandb.run: + wandb.run.finish(exit_code=exit_code, quiet=quiet) diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_settings.py b/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_settings.py new file mode 100644 index 0000000000000000000000000000000000000000..85bdad6ea279200a7d5bec080db7cf1147922d7a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_settings.py @@ -0,0 +1,2289 @@ +from __future__ import annotations + +import json +import logging +import os +import pathlib +import platform +import re +import shutil +import socket +import sys +import traceback +from datetime import datetime + +# Optional and Union are used for type hinting instead of | because +# the latter is not supported in pydantic<2.6 and Python<3.10. +# Dict, List, and Tuple are used for backwards compatibility +# with pydantic v1 and Python<3.9. 
+from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union +from urllib.parse import quote, unquote + +from google.protobuf.wrappers_pb2 import BoolValue, DoubleValue, Int32Value, StringValue +from pydantic import BaseModel, ConfigDict, Field +from typing_extensions import Self + +import wandb +from wandb import env, util +from wandb._pydantic import ( + IS_PYDANTIC_V2, + AliasChoices, + ValidationError, + computed_field, + field_validator, + model_validator, +) +from wandb.errors import UsageError +from wandb.proto import wandb_settings_pb2 +from wandb.sdk.lib import deprecation, settings_file, urls + +from .lib import credentials, filesystem, ipython +from .lib.run_moment import RunMoment + +if not IS_PYDANTIC_V2: + from pydantic import root_validator + + +def _path_convert(*args: str) -> str: + """Join path and apply os.path.expanduser to it.""" + return os.path.expanduser(os.path.join(*args)) + + +CLIENT_ONLY_SETTINGS = ( + "anonymous", + "app_url_override", + "files_dir", + "max_end_of_run_history_metrics", + "max_end_of_run_summary_metrics", + "reinit", + "x_files_dir", + "x_sync_dir_suffix", +) +"""Python-only keys that are not fields on the settings proto.""" + + +class Settings(BaseModel, validate_assignment=True): + """Settings for the W&B SDK. + + This class manages configuration settings for the W&B SDK, + ensuring type safety and validation of all settings. Settings are accessible + as attributes and can be initialized programmatically, through environment + variables (`WANDB_ prefix`), and with configuration files. + + The settings are organized into three categories: + 1. Public settings: Core configuration options that users can safely modify to customize + W&B's behavior for their specific needs. + 2. Internal settings: Settings prefixed with 'x_' that handle low-level SDK behavior. + These settings are primarily for internal use and debugging. 
While they can be modified, + they are not considered part of the public API and may change without notice in future + versions. + 3. Computed settings: Read-only settings that are automatically derived from other settings or + the environment. + """ + + # Pydantic Model configuration. + model_config = ConfigDict( + extra="forbid", # throw an error if extra fields are provided + validate_default=True, # validate default values + use_attribute_docstrings=True, # for field descriptions + revalidate_instances="always", + ) + + # Public settings. + + allow_offline_artifacts: bool = True + """Flag to allow table artifacts to be synced in offline mode. + + To revert to the old behavior, set this to False. + """ + + allow_val_change: bool = False + """Flag to allow modification of `Config` values after they've been set.""" + + anonymous: deprecation.DoNotSet = Field( + default=deprecation.UNSET, + exclude=True, + ) + """Deprecated and will be removed.""" + + api_key: Optional[str] = None + """The W&B API key.""" + + azure_account_url_to_access_key: Optional[Dict[str, str]] = None + """Mapping of Azure account URLs to their corresponding access keys for Azure integration.""" + + app_url_override: Optional[str] = None + """Override for the 'app' URL for the W&B UI. + + The `app_url` is normally computed based on `base_url`, but this can be + used to set it explicitly. + + WANDB_APP_URL is the corresponding environment variable. + """ + + base_url: str = "https://api.wandb.ai" + """The URL of the W&B backend for data synchronization.""" + + code_dir: Optional[str] = None + """Directory containing the code to be tracked by W&B.""" + + config_paths: Optional[Sequence[str]] = None + """Paths to files to load configuration from into the `Config` object.""" + + console: Literal["auto", "off", "wrap", "redirect", "wrap_raw", "wrap_emu"] = Field( + default="auto", + validate_default=True, + ) + """The type of console capture to be applied. 
+ + Possible values are: + - "auto" - Automatically selects the console capture method based on the + system environment and settings. + - "off" - Disables console capture. + - "redirect" - Redirects low-level file descriptors for capturing output. + - "wrap" - Overrides the write methods of sys.stdout/sys.stderr. Will be + mapped to either "wrap_raw" or "wrap_emu" based on the state of the system. + - "wrap_raw" - Same as "wrap" but captures raw output directly instead of + through an emulator. Derived from the `wrap` setting and should not be set manually. + - "wrap_emu" - Same as "wrap" but captures output through an emulator. + Derived from the `wrap` setting and should not be set manually. + """ + + console_multipart: bool = False + """Enable multipart console logging. + + When True, the SDK writes console output to timestamped files + under the `logs/` directory instead of a single `output.log`. + + Each part is uploaded as soon as it is closed, giving users live + access to logs while the run is active. Rollover cadence is + controlled by `console_chunk_max_bytes` and/or `console_chunk_max_seconds`. + If both limits are `0`, all logs are uploaded once at run finish. + + Note: Uploaded chunks are immutable; terminal control sequences + that modify previous lines (e.g., progress bars using carriage returns) + only affect the current chunk. + """ + + console_chunk_max_bytes: int = 0 + """Size-based rollover threshold for multipart console logs, in bytes. + + Starts a new console log file when the current part reaches this + size. Has an effect only when `console_multipart` is `True`. + Can be combined with `console_chunk_max_seconds`; whichever limit is + hit first triggers the rollover. A value of `0` disables the + size-based limit. + """ + + console_chunk_max_seconds: int = 0 + """Time-based rollover threshold for multipart console logs, in seconds. + + Starts a new console log file after this many seconds have elapsed + since the current part began. 
Requires `console_multipart` to be + `True`. May be used with `console_chunk_max_bytes`; the first limit + reached closes the part. A value of `0` disables the time-based + limit. + """ + + credentials_file: str = Field( + default_factory=lambda: str(credentials.DEFAULT_WANDB_CREDENTIALS_FILE) + ) + """Path to file for writing temporary access tokens.""" + + disable_code: bool = False + """Whether to disable capturing the code.""" + + disable_git: bool = False + """Whether to disable capturing the git state.""" + + disable_job_creation: bool = True + """Whether to disable the creation of a job artifact for W&B Launch.""" + + docker: Optional[str] = None + """The Docker image used to execute the script.""" + + email: Optional[str] = None + """The email address of the user.""" + + entity: Optional[str] = None + """The W&B entity, such as a user or a team.""" + + organization: Optional[str] = None + """The W&B organization.""" + + force: bool = False + """Whether to pass the `force` flag to `wandb.login()`.""" + + fork_from: Optional[RunMoment] = None + """Specifies a point in a previous execution of a run to fork from. + + The point is defined by the run ID, a metric, and its value. + Currently, only the metric '_step' is supported. + """ + + git_commit: Optional[str] = None + """The git commit hash to associate with the run.""" + + git_remote: str = "origin" + """The git remote to associate with the run.""" + + git_remote_url: Optional[str] = None + """The URL of the git remote repository.""" + + git_root: Optional[str] = None + """Root directory of the git repository.""" + + heartbeat_seconds: int = 30 + """Interval in seconds between heartbeat signals sent to the W&B servers. 
+ + + """ + + host: Optional[str] = None + """Hostname of the machine running the script.""" + + http_proxy: Optional[str] = None + """Custom proxy servers for http requests to W&B.""" + + https_proxy: Optional[str] = None + """Custom proxy servers for https requests to W&B.""" + + identity_token_file: Optional[str] = None + """Path to file containing an identity token (JWT) for authentication.""" + + ignore_globs: Sequence[str] = () + """Unix glob patterns relative to `files_dir` specifying files to exclude from upload.""" + + init_timeout: float = 90.0 + """Time in seconds to wait for the `wandb.init` call to complete before timing out.""" + + insecure_disable_ssl: bool = False + """Whether to insecurely disable SSL verification.""" + + job_name: Optional[str] = None + """Name of the Launch job running the script.""" + + job_source: Optional[Literal["repo", "artifact", "image"]] = None + """Source type for Launch.""" + + label_disable: bool = False + """Whether to disable automatic labeling features.""" + + launch: bool = False + """Flag to indicate if the run is being launched through W&B Launch. + + + """ + + launch_config_path: Optional[str] = None + """Path to the launch configuration file.""" + + login_timeout: Optional[float] = None + """Time in seconds to wait for login operations before timing out.""" + + mode: Literal["online", "offline", "shared", "disabled", "dryrun", "run"] = Field( + default="online", + validate_default=True, + ) + """The operating mode for W&B logging and synchronization.""" + + notebook_name: Optional[str] = None + """Name of the notebook if running in a Jupyter-like environment.""" + + program: Optional[str] = None + """Path to the script that created the run, if available.""" + + program_abspath: Optional[str] = None + """The absolute path from the root repository directory to the script that + created the run. + + Root repository directory is defined as the directory containing the + .git directory, if it exists. 
Otherwise, it's the current working directory. + """ + + program_relpath: Optional[str] = None + """The relative path to the script that created the run.""" + + project: Optional[str] = None + """The W&B project ID.""" + + quiet: bool = False + """Flag to suppress non-essential output.""" + + reinit: Union[ + Literal[ + "default", + "return_previous", + "finish_previous", + "create_new", + ], + bool, + ] = "default" + """What to do when `wandb.init()` is called while a run is active. + + Options: + - "default": Use "finish_previous" in notebooks and "return_previous" + otherwise. + - "return_previous": Return the most recently created run + that is not yet finished. This does not update `wandb.run`; see + the "create_new" option. + - "finish_previous": Finish all active runs, then return a new run. + - "create_new": Create a new run without modifying other active runs. + Does not update `wandb.run` and top-level functions like `wandb.log`. + Because of this, some older integrations that rely on the global run + will not work. + + Can also be a boolean, but this is deprecated. False is the same as + "return_previous", and True is the same as "finish_previous". + """ + + relogin: bool = False + """Flag to force a new login attempt.""" + + resume: Optional[Literal["allow", "must", "never", "auto"]] = None + """Specifies the resume behavior for the run. + + Options: + - "must": Resumes from an existing run with the same ID. If no such run exists, + it will result in failure. + - "allow": Attempts to resume from an existing run with the same ID. If none is + found, a new run will be created. + - "never": Always starts a new run. If a run with the same ID already exists, + it will result in failure. + - "auto": Automatically resumes from the most recent failed run on the same + machine. + """ + + resume_from: Optional[RunMoment] = None + """Specifies a point in a previous execution of a run to resume from. + + The point is defined by the run ID, a metric, and its value. 
+ Currently, only the metric '_step' is supported. + """ + + resumed: bool = False + """Indication from the server about the state of the run. + + This is different from resume, a user provided flag. + + """ + + root_dir: str = Field(default_factory=lambda: os.path.abspath(os.getcwd())) + """The root directory to use as the base for all run-related paths. + + In particular, this is used to derive the wandb directory and the run directory. + """ + + run_group: Optional[str] = None + """Group identifier for related runs. + + Used for grouping runs in the UI. + """ + + run_id: Optional[str] = None + """The ID of the run.""" + + run_job_type: Optional[str] = None + """Type of job being run (e.g., training, evaluation).""" + + run_name: Optional[str] = None + """Human-readable name for the run.""" + + run_notes: Optional[str] = None + """Additional notes or description for the run.""" + + run_tags: Optional[Tuple[str, ...]] = None + """Tags to associate with the run for organization and filtering.""" + + sagemaker_disable: bool = False + """Flag to disable SageMaker-specific functionality.""" + + save_code: Optional[bool] = None + """Whether to save the code associated with the run.""" + + settings_system: Optional[str] = None + """Path to the system-wide settings file.""" + + max_end_of_run_history_metrics: int = 10 + """Maximum number of history sparklines to display at the end of a run.""" + + max_end_of_run_summary_metrics: int = 10 + """Maximum number of summary metrics to display at the end of a run.""" + + show_colors: Optional[bool] = None + """Whether to use colored output in the console. + + + """ + + show_emoji: Optional[bool] = None + """Whether to show emoji in the console output. 
+ + + """ + + show_errors: bool = True + """Whether to display error messages.""" + + show_info: bool = True + """Whether to display informational messages.""" + + show_warnings: bool = True + """Whether to display warning messages.""" + + silent: bool = False + """Flag to suppress all output.""" + + start_method: Optional[str] = None + """Method to use for starting subprocesses. + + This is deprecated and will be removed in a future release. + + """ + + strict: Optional[bool] = None + """Whether to enable strict mode for validation and error checking.""" + + summary_timeout: int = 60 + """Time in seconds to wait for summary operations before timing out.""" + + summary_warnings: int = 5 + """Maximum number of summary warnings to display. + + + """ + + sweep_id: Optional[str] = None + """Identifier of the sweep this run belongs to.""" + + sweep_param_path: Optional[str] = None + """Path to the sweep parameters configuration.""" + + symlink: bool = Field( + default_factory=lambda: False if platform.system() == "Windows" else True + ) + """Whether to use symlinks (True by default except on Windows).""" + + sync_tensorboard: Optional[bool] = None + """Whether to synchronize TensorBoard logs with W&B.""" + + table_raise_on_max_row_limit_exceeded: bool = False + """Whether to raise an exception when table row limits are exceeded.""" + + use_dot_wandb: Optional[bool] = None + """Whether to use a hidden `.wandb` or visible `wandb` directory for run data. + + If True, the SDK uses `.wandb`. If False, `wandb`. + If not set, defaults to `.wandb` if it already exists, otherwise `wandb`. + """ + + username: Optional[str] = None + """Username.""" + + # Internal settings. + # + # These are typically not meant to be set by the user and should not be considered + # a part of the public API as they may change or be removed in future versions. + + x_cli_only_mode: bool = False + """Flag to indicate that the SDK is running in CLI-only mode. 
+ + + """ + + x_disable_meta: bool = False + """Flag to disable the collection of system metadata.""" + + x_disable_stats: bool = False + """Flag to disable the collection of system metrics.""" + + x_disable_viewer: bool = False + """Flag to disable the early viewer query. + + + """ + + x_disable_machine_info: bool = False + """Flag to disable automatic machine info collection. + + + """ + + x_executable: Optional[str] = None + """Path to the Python executable. + + + """ + + x_extra_http_headers: Optional[Dict[str, str]] = None + """Additional headers to add to all outgoing HTTP requests.""" + + x_file_stream_max_bytes: Optional[int] = None + """An approximate maximum request size for the filestream API. + + Its purpose is to prevent HTTP requests from failing due to + containing too much data. This number is approximate: + requests will be slightly larger. + + """ + + x_file_stream_max_line_bytes: Optional[int] = None + """Maximum line length for filestream JSONL files. + + + """ + + x_file_stream_transmit_interval: Optional[float] = None + """Interval in seconds between filestream transmissions. + + + """ + + # Filestream retry client configuration. + + x_file_stream_retry_max: Optional[int] = None + """Max number of retries for filestream operations. + + + """ + + x_file_stream_retry_wait_min_seconds: Optional[float] = None + """Minimum wait time between retries for filestream operations. + + + """ + + x_file_stream_retry_wait_max_seconds: Optional[float] = None + """Maximum wait time between retries for filestream operations. + + + """ + + x_file_stream_timeout_seconds: Optional[float] = None + """Timeout in seconds for individual filestream HTTP requests. + + + """ + + # file transfer retry client configuration + + x_file_transfer_retry_max: Optional[int] = None + """Max number of retries for file transfer operations. 
+ + + """ + + x_file_transfer_retry_wait_min_seconds: Optional[float] = None + """Minimum wait time between retries for file transfer operations. + + + """ + + x_file_transfer_retry_wait_max_seconds: Optional[float] = None + """Maximum wait time between retries for file transfer operations. + + + """ + + x_file_transfer_timeout_seconds: Optional[float] = None + """Timeout in seconds for individual file transfer HTTP requests. + + + """ + + x_files_dir: Optional[str] = None + """Override setting for the computed files_dir. + + DEPRECATED, DO NOT USE. This private setting is not respected by wandb-core + but will continue to work for some legacy Python code. + + + """ + + x_flow_control_custom: Optional[bool] = None + """Flag indicating custom flow control for filestream. + + TODO: Not implemented in wandb-core. + + """ + + x_flow_control_disabled: Optional[bool] = None + """Flag indicating flow control is disabled for filestream. + + TODO: Not implemented in wandb-core. + + """ + + # graphql retry client configuration + + x_graphql_retry_max: Optional[int] = None + """Max number of retries for GraphQL operations. + + + """ + + x_graphql_retry_wait_min_seconds: Optional[float] = None + """Minimum wait time between retries for GraphQL operations. + + + """ + + x_graphql_retry_wait_max_seconds: Optional[float] = None + """Maximum wait time between retries for GraphQL operations. + + + """ + + x_graphql_timeout_seconds: Optional[float] = None + """Timeout in seconds for individual GraphQL requests. + + + """ + + x_internal_check_process: float = 8.0 + """Interval for internal process health checks in seconds. + + + """ + + x_jupyter_name: Optional[str] = None + """Name of the Jupyter notebook. + + + """ + + x_jupyter_path: Optional[str] = None + """Path to the Jupyter notebook. + + + """ + + x_jupyter_root: Optional[str] = None + """Root directory of the Jupyter notebook. 
+ + + """ + + x_label: Optional[str] = None + """Label to assign to system metrics and console logs collected for the run. + + This is used to group data by on the frontend and can be used to distinguish data + from different processes in a distributed training job. + """ + + x_live_policy_rate_limit: Optional[int] = None + """Rate limit for live policy updates in seconds. + + + """ + + x_live_policy_wait_time: Optional[int] = None + """Wait time between live policy updates in seconds. + + + """ + + x_log_level: int = logging.INFO + """Logging level for internal operations. + + + """ + + x_network_buffer: Optional[int] = None + """Size of the network buffer used in flow control. + + TODO: Not implemented in wandb-core. + + """ + + x_primary: bool = Field( + default=True, validation_alias=AliasChoices("x_primary", "x_primary_node") + ) + """Determines whether to save internal wandb files and metadata. + + In a distributed setting, this is useful for avoiding file overwrites + from secondary processes when only system metrics and logs are needed, + as the primary process handles the main logging. + """ + + x_proxies: Optional[Dict[str, str]] = None + """Custom proxy servers for requests to W&B. + + This is deprecated and will be removed in a future release. + Please use `http_proxy` and `https_proxy` instead. + + """ + + x_runqueue_item_id: Optional[str] = None + """ID of the Launch run queue item being processed. + + + """ + + x_save_requirements: bool = True + """Flag to save the requirements file.""" + + x_server_side_derived_summary: bool = False + """Flag to delegate automatic computation of summary from history to the server. + + This does not disable user-provided summary updates. + """ + + x_server_side_expand_glob_metrics: bool = True + """Flag to delegate glob matching of metrics in define_metric to the server. + + If the server does not support this, the client will perform the glob matching. 
+ + """ + + x_service_transport: Optional[str] = None + """Transport method for communication with the wandb service. + + + """ + + x_service_wait: float = 30.0 + """Time in seconds to wait for the wandb-core internal service to start.""" + + x_skip_transaction_log: bool = False + """Whether to skip saving the run events to the transaction log. + + This is only relevant for online runs. Can be used to reduce the amount of + data written to disk. + + Should be used with caution, as it removes the gurantees about + recoverability. + """ + + x_start_time: Optional[float] = None + """The start time of the run in seconds since the Unix epoch. + + + """ + + x_stats_pid: int = os.getpid() + """PID of the process that started the wandb-core process to collect system stats for. + + + """ + + x_stats_sampling_interval: float = Field(default=15.0) + """Sampling interval for the system monitor in seconds.""" + + x_stats_neuron_monitor_config_path: Optional[str] = None + """Path to the default config file for the neuron-monitor tool. + + This is used to monitor AWS Trainium devices. + + """ + + x_stats_dcgm_exporter: Optional[str] = None + """Endpoint to extract Nvidia DCGM metrics from. + + Options: + - Extract DCGM-related metrics from a query to the Prometheus `/api/v1/query` endpoint. + It is a common practice to aggregate metrics reported by the instances of the DCGM Exporter + running on different nodes in a cluster using Prometheus. + - TODO: Parse metrics directly from the `/metrics` endpoint of the DCGM Exporter. + + Examples: + - `http://localhost:9400/api/v1/query?query=DCGM_FI_DEV_GPU_TEMP{node="l1337", cluster="globular"}`. + - TODO: `http://192.168.0.1:9400/metrics`. 
+ + """ + + x_stats_open_metrics_endpoints: Optional[Dict[str, str]] = None + """OpenMetrics `/metrics` endpoints to monitor for system metrics.""" + + x_stats_open_metrics_filters: Union[ + Dict[str, Dict[str, str]], Sequence[str], None + ] = None + """Filter to apply to metrics collected from OpenMetrics `/metrics` endpoints. + + Supports two formats: + - `{"metric regex pattern, including endpoint name as prefix": {"label": "label value regex pattern"}}` + - `("metric regex pattern 1", "metric regex pattern 2", ...)` + """ + + x_stats_open_metrics_http_headers: Optional[Dict[str, str]] = None + """HTTP headers to add to OpenMetrics requests.""" + + x_stats_disk_paths: Optional[Sequence[str]] = ("/",) + """System paths to monitor for disk usage.""" + + x_stats_cpu_count: Optional[int] = None + """System CPU count. + + If set, overrides the auto-detected value in the run metadata. + """ + + x_stats_cpu_logical_count: Optional[int] = None + """Logical CPU count. + + If set, overrides the auto-detected value in the run metadata. + """ + + x_stats_gpu_count: Optional[int] = None + """GPU device count. + + If set, overrides the auto-detected value in the run metadata. + """ + + x_stats_gpu_type: Optional[str] = None + """GPU device type. + + If set, overrides the auto-detected value in the run metadata. + """ + + x_stats_gpu_device_ids: Optional[Sequence[int]] = None + """GPU device indices to monitor. + + If not set, the system monitor captures metrics for all GPUs. + Assumes 0-based indexing matching CUDA/ROCm device enumeration. + """ + + x_stats_buffer_size: int = 0 + """Number of system metric samples to buffer in memory in the wandb-core process. + + Can be accessed via run._system_metrics. + + """ + + x_stats_coreweave_metadata_base_url: str = "http://169.254.169.254" + """The scheme and hostname for contacting the CoreWeave metadata server. + + Only accessible from within a CoreWeave cluster. 
+ + """ + + x_stats_coreweave_metadata_endpoint: str = "/api/v2/cloud-init/meta-data" + """The relative path on the CoreWeave metadata server to which to make requests. + + This must not include the schema and hostname prefix. + Only accessible from within a CoreWeave cluster. + + """ + + x_stats_track_process_tree: bool = False + """Monitor the entire process tree for resource usage, starting from `x_stats_pid`. + + When `True`, the system monitor aggregates the RSS, CPU%, and thread count + from the process with PID `x_stats_pid` and all of its descendants. + This can have a performance overhead and is disabled by default. + """ + + x_sync: bool = False + """Flag to indicate whether we are syncing a run from the transaction log. + + + """ + + x_sync_dir_suffix: str = "" + """Suffix to add to the run's directory name (sync_dir). + + This is set in wandb.init() to avoid naming conflicts. + If set, it is joined to the default name with a dash. + """ + + x_update_finish_state: bool = True + """Flag to indicate whether this process can update the run's final state on the server. + + Set to False in distributed training when only the main process should determine the final state. + """ + + # Model validator to catch legacy settings. + @model_validator(mode="before") + @classmethod + def catch_private_settings(cls, values): + """Check if a private field is provided and assign to the corresponding public one. + + This is a compatibility layer to handle previous versions of the settings. + + + """ + new_values = {} + for key in values: + # Internal settings are prefixed with "x_" instead of "_" + # as Pydantic does not allow "_" in field names. + if key.startswith("_"): + new_values["x" + key] = values[key] + else: + new_values[key] = values[key] + return new_values + + if IS_PYDANTIC_V2: + + @model_validator(mode="after") + def validate_mutual_exclusion_of_branching_args(self) -> Self: + """Check if `fork_from`, `resume`, and `resume_from` are mutually exclusive. 
+ + + """ + if ( + sum( + o is not None + for o in [self.fork_from, self.resume, self.resume_from] + ) + > 1 + ): + raise ValueError( + "`fork_from`, `resume`, or `resume_from` are mutually exclusive. " + "Please specify only one of them." + ) + return self + + @model_validator(mode="after") + def validate_skip_transaction_log(self): + """Validate x_skip_transaction_log. + + + """ + if self._offline and self.x_skip_transaction_log: + raise ValueError("Cannot skip transaction log in offline mode") + return self + else: + + @root_validator(pre=False) # type: ignore [call-overload] + @classmethod + def validate_mutual_exclusion_of_branching_args(cls, values): + if ( + sum( + values.get(o) is not None + for o in ["fork_from", "resume", "resume_from"] + ) + > 1 + ): + raise ValueError( + "`fork_from`, `resume`, or `resume_from` are mutually exclusive. " + "Please specify only one of them." + ) + return values + + @root_validator(pre=False) # type: ignore [call-overload] + @classmethod + def validate_skip_transaction_log(cls, values): + if values.get("_offline") and values.get("x_skip_transaction_log"): + raise ValueError("Cannot skip transaction log in offline mode") + return values + + # Field validators. + @field_validator("anonymous", mode="after") + @classmethod + def validate_anonymous(cls, value: object) -> object: + if value is not deprecation.UNSET: + wandb.termwarn( + "The anonymous setting has no effect and will be removed" + + " in a future version.", + repeat=False, + ) + + return value + + @field_validator("api_key", mode="after") + @classmethod + def validate_api_key(cls, value): + """Validate the API key. + + + """ + if value is not None and (len(value) > len(value.strip())): + raise UsageError("API key cannot start or end with whitespace") + return value + + @field_validator("base_url", mode="after") + @classmethod + def validate_base_url(cls, value): + """Validate the base URL. 
+ + + """ + urls.validate_url(value) + # wandb.ai-specific checks + if re.match(r".*wandb\.ai[^\.]*$", value) and "api." not in value: + # user might guess app.wandb.ai or wandb.ai is the default cloud server + raise ValueError( + f"{value} is not a valid server address, did you mean https://api.wandb.ai?" + ) + elif re.match(r".*wandb\.ai[^\.]*$", value) and not value.startswith("https"): + raise ValueError("http is not secure, please use https://api.wandb.ai") + return value.rstrip("/") + + @field_validator("code_dir", mode="before") + @classmethod + def validate_code_dir(cls, value): + """Validate the code directory. + + + """ + # TODO: add native support for pathlib.Path + if isinstance(value, pathlib.Path): + return str(value) + return value + + @field_validator("console", mode="after") + @classmethod + def validate_console(cls, value, values): + """Validate the console capture method. + + + """ + if value != "auto": + return value + + return "wrap" + + @field_validator("console_chunk_max_bytes", mode="after") + @classmethod + def validate_console_chunk_max_bytes(cls, value): + """Validate the console_chunk_max_bytes value. + + + """ + if value < 0: + raise ValueError("console_chunk_max_bytes must be non-negative") + + return value + + @field_validator("console_chunk_max_seconds", mode="after") + @classmethod + def validate_console_chunk_max_seconds(cls, value): + """Validate the console_chunk_max_seconds value. + + + """ + if value < 0: + raise ValueError("console_chunk_max_seconds must be non-negative") + + return value + + @field_validator("x_executable", mode="before") + @classmethod + def validate_x_executable(cls, value): + """Validate the Python executable path. 
+ + + """ + # TODO: add native support for pathlib.Path + if isinstance(value, pathlib.Path): + return str(value) + return value + + @field_validator("x_extra_http_headers", mode="before") + @classmethod + def validate_x_extra_http_headers(cls, value): + if isinstance(value, str): + return json.loads(value) + return value + + @field_validator("x_file_stream_max_line_bytes", mode="after") + @classmethod + def validate_file_stream_max_line_bytes(cls, value): + """Validate the maximum line length for filestream JSONL files. + + + """ + if value is not None and value < 1: + raise ValueError("File stream max line bytes must be greater than 0") + return value + + @field_validator("x_files_dir", mode="before") + @classmethod + def validate_x_files_dir(cls, value): + """Validate the files directory. + + + """ + # TODO: add native support for pathlib.Path + if isinstance(value, pathlib.Path): + return str(value) + return value + + @field_validator("fork_from", mode="before") + @classmethod + def validate_fork_from(cls, value, values) -> Optional[RunMoment]: + """Validate the fork_from field. + + + """ + run_moment = cls._runmoment_preprocessor(value) + + if hasattr(values, "data"): + # pydantic v2 + values = values.data + else: + # pydantic v1 + values = values + + if ( + run_moment + and values.get("run_id") is not None + and values.get("run_id") == run_moment.run + ): + raise ValueError( + "Provided `run_id` is the same as the run to `fork_from`. " + "Please provide a different `run_id` or remove the `run_id` argument. " + "If you want to rewind the current run, please use `resume_from` instead." + ) + return run_moment + + @field_validator("http_proxy", mode="after") + @classmethod + def validate_http_proxy(cls, value): + """Validate the HTTP proxy. 
+ + + """ + if value is None: + return None + urls.validate_url(value) + return value.rstrip("/") + + @field_validator("https_proxy", mode="after") + @classmethod + def validate_https_proxy(cls, value): + """Validate the HTTPS proxy. + + + """ + if value is None: + return None + urls.validate_url(value) + return value.rstrip("/") + + @field_validator("ignore_globs", mode="after") + @classmethod + def validate_ignore_globs(cls, value): + """Validate the ignore globs. + + + """ + return tuple(value) if not isinstance(value, tuple) else value + + @field_validator("program", mode="before") + @classmethod + def validate_program(cls, value): + """Validate the program path. + + + """ + # TODO: add native support for pathlib.Path + if isinstance(value, pathlib.Path): + return str(value) + return value + + @field_validator("program_abspath", mode="before") + @classmethod + def validate_program_abspath(cls, value): + """Validate the absolute program path. + + + """ + # TODO: add native support for pathlib.Path + if isinstance(value, pathlib.Path): + return str(value) + return value + + @field_validator("program_relpath", mode="before") + @classmethod + def validate_program_relpath(cls, value): + """Validate the relative program path. + + + """ + # TODO: add native support for pathlib.Path + if isinstance(value, pathlib.Path): + return str(value) + return value + + @field_validator("project", mode="after") + @classmethod + def validate_project(cls, value, values): + """Validate the project name. 
+ + + """ + if value is None: + return None + invalid_chars_list = list("/\\#?%:") + if len(value) > 128: + raise UsageError(f"Invalid project name {value!r}: exceeded 128 characters") + invalid_chars = {char for char in invalid_chars_list if char in value} + if invalid_chars: + raise UsageError( + f"Invalid project name {value!r}: " + f"cannot contain characters {','.join(invalid_chars_list)!r}, " + f"found {','.join(invalid_chars)!r}" + ) + return value + + @field_validator("resume", mode="before") + @classmethod + def validate_resume(cls, value): + """Validate the resume behavior. + + + """ + if value is False: + return None + if value is True: + return "auto" + return value + + @field_validator("resume_from", mode="before") + @classmethod + def validate_resume_from(cls, value, values) -> Optional[RunMoment]: + """Validate the resume_from field. + + + """ + run_moment = cls._runmoment_preprocessor(value) + + if hasattr(values, "data"): + # pydantic v2 + values = values.data + else: + # pydantic v1 + values = values + + if ( + run_moment + and values.get("run_id") is not None + and values.get("run_id") != run_moment.run + ): + raise ValueError( + "Both `run_id` and `resume_from` have been specified with different ids." + ) + return run_moment + + @field_validator("root_dir", mode="before") + @classmethod + def validate_root_dir(cls, value): + """Validate the root directory. + + + """ + # TODO: add native support for pathlib.Path + if isinstance(value, pathlib.Path): + return str(value) + return value + + @field_validator("run_id", mode="after") + @classmethod + def validate_run_id(cls, value, values): + """Validate the run ID. 
+ + + """ + if value is None: + return None + + if len(value) == 0: + raise UsageError("Run ID cannot be empty") + if len(value) > len(value.strip()): + raise UsageError("Run ID cannot start or end with whitespace") + if not bool(value.strip()): + raise UsageError("Run ID cannot contain only whitespace") + + # check if the run id contains any reserved characters + reserved_chars = ":;,#?/'" + if any(char in reserved_chars for char in value): + raise UsageError(f"Run ID cannot contain the characters: {reserved_chars}") + return value + + @field_validator("settings_system", mode="after") + @classmethod + def validate_settings_system(cls, value): + """Validate the system settings file path. + + + """ + if value is None: + return None + elif isinstance(value, pathlib.Path): + return str(_path_convert(value)) + else: + return _path_convert(value) + + @field_validator("x_service_wait", mode="after") + @classmethod + def validate_service_wait(cls, value): + """Validate the service wait time. + + + """ + if value < 0: + raise UsageError("Service wait time cannot be negative") + return value + + @field_validator("start_method", mode="after") + @classmethod + def validate_start_method(cls, value): + """Validate the start method for subprocesses. + + + """ + if value is None: + return value + wandb.termwarn( + "`start_method` is deprecated and will be removed in a future version " + "of wandb. This setting is currently non-functional and safely ignored.", + repeat=False, + ) + return value + + @field_validator("x_stats_coreweave_metadata_base_url", mode="after") + @classmethod + def validate_x_stats_coreweave_metadata_base_url(cls, value): + urls.validate_url(value) + return value.rstrip("/") + + @field_validator("x_stats_gpu_device_ids", mode="before") + @classmethod + def validate_x_stats_gpu_device_ids(cls, value): + """Validate the GPU device IDs. 
+ + + """ + if isinstance(value, str): + return json.loads(value) + return value + + @field_validator("x_stats_neuron_monitor_config_path", mode="before") + @classmethod + def validate_x_stats_neuron_monitor_config_path(cls, value): + """Validate the path to the neuron-monitor config file. + + + """ + # TODO: add native support for pathlib.Path + if isinstance(value, pathlib.Path): + return str(value) + return value + + @field_validator("x_stats_open_metrics_endpoints", mode="before") + @classmethod + def validate_stats_open_metrics_endpoints(cls, value): + """Validate the OpenMetrics endpoints. + + + """ + if isinstance(value, str): + return json.loads(value) + return value + + @field_validator("x_stats_open_metrics_filters", mode="before") + @classmethod + def validate_stats_open_metrics_filters(cls, value): + """Validate the OpenMetrics filters. + + + """ + if isinstance(value, str): + return json.loads(value) + return value + + @field_validator("x_stats_open_metrics_http_headers", mode="before") + @classmethod + def validate_stats_open_metrics_http_headers(cls, value): + """Validate the OpenMetrics HTTP headers. + + + """ + if isinstance(value, str): + return json.loads(value) + return value + + @field_validator("x_stats_sampling_interval", mode="after") + @classmethod + def validate_stats_sampling_interval(cls, value): + """Validate the stats sampling interval. + + + """ + if value < 0.1: + raise UsageError("Stats sampling interval cannot be less than 0.1 seconds") + return value + + @field_validator("sweep_id", mode="after") + @classmethod + def validate_sweep_id(cls, value): + """Validate the sweep ID. 
+ + + """ + if value is None: + return None + if len(value) == 0: + raise UsageError("Sweep ID cannot be empty") + if len(value) > len(value.strip()): + raise UsageError("Sweep ID cannot start or end with whitespace") + if not bool(value.strip()): + raise UsageError("Sweep ID cannot contain only whitespace") + return value + + @field_validator("run_tags", mode="before") + @classmethod + def validate_run_tags(cls, value): + """Validate run tags. + + Validates that each tag: + - Is between 1 and 64 characters in length (inclusive) + - Converts single string values to tuple format + - Preserves None values + + + + Args: + value: A string, list, tuple, or None representing tags + + Returns: + tuple: A tuple of validated tags, or None + + Raises: + ValueError: If any tag is empty or exceeds 64 characters + """ + if value is None: + return None + + # Convert to tuple if needed + if isinstance(value, str): + tags = (value,) + else: + tags = tuple(value) + + # Validate each tag and accumulate errors + errors = [] + for i, tag in enumerate(tags): + tag_str = str(tag) + if len(tag_str) == 0: + errors.append( + f"Tag at index {i} is empty. Tags must be between 1 and 64 characters" + ) + elif len(tag_str) > 64: + # Truncate long tags for display + display_tag = ( + f"{tag_str[:20]}...{tag_str[-20:]}" + if len(tag_str) > 43 + else tag_str + ) + errors.append( + f"Tag '{display_tag}' is {len(tag_str)} characters. Tags must be between 1 and 64 characters" + ) + + # Raise combined error if any validation issues were found + if errors: + raise ValueError("; ".join(errors)) + + return tags + + @field_validator("sweep_param_path", mode="before") + @classmethod + def validate_sweep_param_path(cls, value): + """Validate the sweep parameter path. + + + """ + # TODO: add native support for pathlib.Path + if isinstance(value, pathlib.Path): + return str(value) + return value + + # Computed fields. 
+ + @computed_field # type: ignore[prop-decorator] + @property + def _args(self) -> List[str]: + if not self._jupyter: + return sys.argv[1:] + return [] + + @computed_field # type: ignore[prop-decorator] + @property + def _aws_lambda(self) -> bool: + """Check if we are running in a lambda environment.""" + from sentry_sdk.integrations.aws_lambda import ( # type: ignore[import-not-found] + get_lambda_bootstrap, + ) + + lambda_bootstrap = get_lambda_bootstrap() + if not lambda_bootstrap or not hasattr( + lambda_bootstrap, "handle_event_request" + ): + return False + return True + + @computed_field # type: ignore[prop-decorator] + @property + def _code_path_local(self) -> Optional[str]: + """The relative path from the current working directory to the code path. + + For example, if the code path is /home/user/project/example.py, and the + current working directory is /home/user/project, then the code path local + is example.py. + + If couldn't find the relative path, this will be an empty string. 
+ """ + return self._get_program_relpath(self.program) if self.program else None + + @computed_field # type: ignore[prop-decorator] + @property + def _colab(self) -> bool: + return "google.colab" in sys.modules + + @computed_field # type: ignore[prop-decorator] + @property + def _ipython(self) -> bool: + return ipython.in_ipython() + + @computed_field # type: ignore[prop-decorator] + @property + def _jupyter(self) -> bool: + return ipython.in_jupyter() + + @computed_field # type: ignore[prop-decorator] + @property + def _kaggle(self) -> bool: + return util._is_likely_kaggle() + + @computed_field # type: ignore[prop-decorator] + @property + def _noop(self) -> bool: + return self.mode == "disabled" + + @computed_field # type: ignore[prop-decorator] + @property + def _notebook(self) -> bool: + return self._ipython or self._jupyter or self._colab or self._kaggle + + @computed_field # type: ignore[prop-decorator] + @property + def _offline(self) -> bool: + return self.mode in ("offline", "dryrun") + + @computed_field # type: ignore[prop-decorator] + @property + def _os(self) -> str: + """The operating system of the machine running the script.""" + return platform.platform(aliased=True) + + @computed_field # type: ignore[prop-decorator] + @property + def _platform(self) -> str: + return f"{platform.system()}-{platform.machine()}".lower() + + @computed_field # type: ignore[prop-decorator] + @property + def _python(self) -> str: + return f"{platform.python_implementation()} {platform.python_version()}" + + @computed_field # type: ignore[prop-decorator] + @property + def _shared(self) -> bool: + """Whether we are in shared mode. + + In "shared" mode, multiple processes can write to the same run, + for example from different machines. 
+ """ + return self.mode == "shared" + + @computed_field # type: ignore[prop-decorator] + @property + def _start_datetime(self) -> str: + if self.x_start_time is None: + return "" + datetime_now = datetime.fromtimestamp(self.x_start_time) + return datetime_now.strftime("%Y%m%d_%H%M%S") + + @computed_field # type: ignore[prop-decorator] + @property + def _tmp_code_dir(self) -> str: + return _path_convert(self.sync_dir, "tmp", "code") + + @computed_field # type: ignore[prop-decorator] + @property + def _windows(self) -> bool: + return platform.system() == "Windows" + + @computed_field # type: ignore[prop-decorator] + @property + def app_url(self) -> str: + """The URL for the W&B UI, usually https://wandb.ai. + + This is different from `base_url` (like https://api.wandb.ai) which + is used to access W&B APIs programmatically. + """ + return self.app_url_override or util.api_to_app_url(self.base_url) + + @computed_field # type: ignore[prop-decorator] + @property + def colab_url(self) -> Optional[str]: + """The URL to the Colab notebook, if running in Colab.""" + if not self._colab: + return None + if self.x_jupyter_path and self.x_jupyter_path.startswith("fileId="): + unescaped = unquote(self.x_jupyter_path) + return "https://colab.research.google.com/notebook#" + unescaped + return None + + @computed_field # type: ignore[prop-decorator] + @property + def deployment(self) -> Literal["local", "cloud"]: + return "local" if self.is_local else "cloud" + + @computed_field # type: ignore[prop-decorator] + @property + def files_dir(self) -> str: + """Absolute path to the local directory where the run's files are stored.""" + # Must match the logic in settings.go in the service process. 
+ return self.x_files_dir or _path_convert(self.sync_dir, "files") + + @computed_field # type: ignore[prop-decorator] + @property + def is_local(self) -> bool: + return str(self.base_url) != "https://api.wandb.ai" + + @computed_field # type: ignore[prop-decorator] + @property + def log_dir(self) -> str: + """The directory for storing log files.""" + return _path_convert(self.sync_dir, "logs") + + @computed_field # type: ignore[prop-decorator] + @property + def log_internal(self) -> str: + """The path to the file to use for internal logs.""" + return _path_convert(self.log_dir, "debug-internal.log") + + @computed_field # type: ignore[prop-decorator] + @property + def log_symlink_internal(self) -> str: + """The path to the symlink to the internal log file of the most recent run.""" + return _path_convert(self.wandb_dir, "debug-internal.log") + + @computed_field # type: ignore[prop-decorator] + @property + def log_symlink_user(self) -> str: + """The path to the symlink to the user-process log file of the most recent run.""" + return _path_convert(self.wandb_dir, "debug.log") + + @computed_field # type: ignore[prop-decorator] + @property + def log_user(self) -> str: + """The path to the file to use for user-process logs.""" + return _path_convert(self.log_dir, "debug.log") + + @computed_field # type: ignore[prop-decorator] + @property + def project_url(self) -> str: + """The W&B URL where the project can be viewed.""" + project_url = self._project_url_base() + if not project_url: + return "" + + return project_url + + @computed_field # type: ignore[prop-decorator] + @property + def resume_fname(self) -> str: + """The path to the resume file.""" + return _path_convert(self.wandb_dir, "wandb-resume.json") + + @computed_field # type: ignore[prop-decorator] + @property + def run_mode(self) -> Literal["run", "offline-run"]: + """The mode of the run. 
Can be either "run" or "offline-run".""" + return "run" if not self._offline else "offline-run" + + @computed_field # type: ignore[prop-decorator] + @property + def run_url(self) -> str: + """The W&B URL where the run can be viewed.""" + project_url = self._project_url_base() + if not all([project_url, self.run_id]): + return "" + + # Exclude specific safe characters from URL encoding to prevent 404 errors + safe_chars = "=+&$@" + return f"{project_url}/runs/{quote(self.run_id or '', safe=safe_chars)}" + + @computed_field # type: ignore[prop-decorator] + @property + def settings_workspace(self) -> str: + """The path to the workspace settings file.""" + return _path_convert(self.wandb_dir, "settings") + + @computed_field # type: ignore[prop-decorator] + @property + def sweep_url(self) -> str: + """The W&B URL where the sweep can be viewed.""" + project_url = self._project_url_base() + if not all([project_url, self.sweep_id]): + return "" + + return f"{project_url}/sweeps/{quote(self.sweep_id or '')}" + + @computed_field # type: ignore[prop-decorator] + @property + def sync_dir(self) -> str: + """The directory for storing the run's files.""" + name = f"{self.run_mode}-{self.timespec}-{self.run_id}" + + if self.x_sync_dir_suffix: + name += f"-{self.x_sync_dir_suffix}" + + return _path_convert(self.wandb_dir, name) + + @computed_field # type: ignore[prop-decorator] + @property + def sync_file(self) -> str: + """Path to the append-only binary transaction log file.""" + return _path_convert(self.sync_dir, f"run-{self.run_id}.wandb") + + @computed_field # type: ignore[prop-decorator] + @property + def sync_symlink_latest(self) -> str: + """Path to the symlink to the most recent run's transaction log file.""" + return _path_convert(self.wandb_dir, "latest-run") + + @computed_field # type: ignore[prop-decorator] + @property + def timespec(self) -> str: + """The time specification for the run.""" + return self._start_datetime + + @computed_field # type: 
ignore[prop-decorator] + @property + def wandb_dir(self) -> str: + """Full path to the wandb directory.""" + if self.use_dot_wandb is None: + use_dot = pathlib.Path(self.root_dir, ".wandb").exists() + else: + use_dot = self.use_dot_wandb + + dirname = ".wandb" if use_dot else "wandb" + return str(pathlib.Path(self.root_dir, dirname).expanduser()) + + # Methods to collect and update settings from different sources. + # + # The Settings class does not track the source of the settings, + # so it is up to the developer to ensure that the settings are applied + # in the correct order. Most of the updates are done in + # wandb/sdk/wandb_setup.py::_WandbSetup._settings_setup. + + def read_system_settings(self) -> settings_file.SettingsFiles: + """Read settings from the workspace and global settings files. + + The files are determined by the settings_system and settings_workspace + settings. + + The resulting object is a snapshot of the system settings at the time + this function is used and does not reflect the settings on this Settings + object. It can be used to update the files, and it should be short-lived + since it does not reflect external changes to the files. + + Updating the settings files does not update this Settings instance + and vice versa. + + + """ + local_settings = pathlib.Path(self.settings_workspace) + + if self.settings_system: + global_settings = pathlib.Path(self.settings_system) + else: + global_settings = None + + return settings_file.SettingsFiles( + global_settings=global_settings, + local_settings=local_settings, + ) + + def update_from_system_settings(self) -> None: + """Load settings from the settings files. + + If settings files contain invalid settings, prints and suppresses + the error. 
+ + + """ + system_settings = self.read_system_settings() + + if len(system_settings.sources) == 0: + return + elif len(system_settings.sources) == 1: + source_string = str(system_settings.sources[0]) + else: + source_string = "\n" + "\n".join( + f" {source}" for source in system_settings.sources + ) + + # Print at the start so that users can diagnose uncaught exceptions. + if not self.quiet: + printed_sources = True + wandb.termlog(f"Loading settings from {source_string}") + else: + printed_sources = False + + try: + parsed_settings = _parse_system_settings(system_settings) + except Exception as e: + if not printed_sources: + wandb.termerror(f"Failed to load settings from {source_string}") + + if isinstance(e, ValidationError): + # Pydantic ValidationErrors have detailed messages that we can + # print without a stack trace. + wandb.termerror(str(e)) + + else: + # For all other errors, we need to dump a stack trace to make + # sure they're debuggable. + tb = traceback.format_exception(type(e), e, e.__traceback__) + wandb.termerror("".join(tb)) + + return + + # We parse and set in different steps so that we do not partially + # apply a broken settings file. + # + # Note that this runs validation functions a second time, but we expect + # them to succeed. + self.update_from_settings(parsed_settings) + + def update_from_env_vars(self, environ: Dict[str, Any]): + """Update settings from environment variables. 
+ + + """ + env_prefix: str = "WANDB_" + private_env_prefix: str = env_prefix + "_" + special_env_var_names = { + env.APP_URL: "app_url_override", + "WANDB_SERVICE_TRANSPORT": "x_service_transport", + env.DIR: "root_dir", + env.NAME: "run_name", + env.NOTES: "run_notes", + env.TAGS: "run_tags", + env.JOB_TYPE: "run_job_type", + env.HTTP_TIMEOUT: "x_graphql_timeout_seconds", + env.FILE_PUSHER_TIMEOUT: "x_file_transfer_timeout_seconds", + env.USER_EMAIL: "email", + } + + for setting, value in environ.items(): + if not setting.startswith(env_prefix): + continue + + if setting in special_env_var_names: + key = special_env_var_names[setting] + elif setting.startswith(private_env_prefix): + key = "x_" + setting[len(private_env_prefix) :].lower() + else: + # otherwise, strip the prefix and convert to lowercase + key = setting[len(env_prefix) :].lower() + + if key not in self.__dict__: + continue + + if key in ("ignore_globs", "run_tags"): + value = value.split(",") + + if value is None: + continue + + setattr(self, key, value) + + def update_from_system_environment(self): + """Update settings from the system environment. + + + """ + # For code saving, only allow env var override if value from server is true, or + # if no preference was specified. 
+ if (self.save_code is True or self.save_code is None) and ( + os.getenv(env.SAVE_CODE) is not None + or os.getenv(env.DISABLE_CODE) is not None + ): + self.save_code = env.should_save_code() + + if os.getenv(env.DISABLE_GIT) is not None: + self.disable_git = env.disable_git() + + # Attempt to get notebook information if not already set by the user + if self._jupyter and (self.notebook_name is None or self.notebook_name == ""): + meta = wandb.jupyter.notebook_metadata(self.silent) # type: ignore + self.x_jupyter_path = meta.get("path") + self.x_jupyter_name = meta.get("name") + self.x_jupyter_root = meta.get("root") + elif ( + self._jupyter + and self.notebook_name is not None + and os.path.exists(self.notebook_name) + ): + self.x_jupyter_path = self.notebook_name + self.x_jupyter_name = self.notebook_name + self.x_jupyter_root = os.getcwd() + elif self._jupyter: + wandb.termwarn( + "WANDB_NOTEBOOK_NAME should be a path to a notebook file, " + f"couldn't find {self.notebook_name}.", + ) + + # host is populated by update_from_env_vars if the corresponding env + # vars exist -- but if they don't, we'll fill them in here. + if self.host is None: + self.host = socket.gethostname() # type: ignore + + _executable = ( + self.x_executable + or os.environ.get(env._EXECUTABLE) + or sys.executable + or shutil.which("python3") + or "python3" + ) + self.x_executable = _executable + + if self.docker is None: + self.docker = env.get_docker(util.image_id_from_k8s()) + + # proceed if not in CLI mode + if self.x_cli_only_mode: + return + + program = self.program or self._get_program() + + if program is not None: + self._setup_code_paths(program) + else: + program = "" + + self.program = program + + def update_from_dict(self, settings: Dict[str, Any]) -> None: + """Update settings from a dictionary. 
+ + + """ + for key, value in dict(settings).items(): + if value is not None: + setattr(self, key, value) + + def update_from_settings(self, settings: Settings) -> None: + """Update settings from another instance of `Settings`. + + + """ + d = {field: getattr(settings, field) for field in settings.model_fields_set} + if d: + self.update_from_dict(d) + + # Helper methods. + + def to_proto(self) -> wandb_settings_pb2.Settings: + """Generate a protobuf representation of the settings. + + + """ + settings_proto = wandb_settings_pb2.Settings() + for k, v in self.model_dump(exclude_none=True).items(): + if k in CLIENT_ONLY_SETTINGS: + continue + + # Special case for x_stats_open_metrics_filters. + if k == "x_stats_open_metrics_filters": + if isinstance(v, (list, set, tuple)): + setting = getattr(settings_proto, k) + setting.sequence.value.extend(v) + elif isinstance(v, dict): + setting = getattr(settings_proto, k) + for key, value in v.items(): + for kk, vv in value.items(): + setting.mapping.value[key].value[kk] = vv + else: + raise TypeError(f"Unsupported type {type(v)} for setting {k}") + continue + + # Special case for RunMoment fields. 
+ if k in ("fork_from", "resume_from"): + run_moment = ( + v + if isinstance(v, RunMoment) + else RunMoment( + run=v.get("run"), + value=v.get("value"), + metric=v.get("metric"), + ) + ) + getattr(settings_proto, k).CopyFrom( + wandb_settings_pb2.RunMoment( + run=run_moment.run, + value=run_moment.value, + metric=run_moment.metric, + ) + ) + continue + + if isinstance(v, bool): + getattr(settings_proto, k).CopyFrom(BoolValue(value=v)) + elif isinstance(v, int): + getattr(settings_proto, k).CopyFrom(Int32Value(value=v)) + elif isinstance(v, float): + getattr(settings_proto, k).CopyFrom(DoubleValue(value=v)) + elif isinstance(v, str): + getattr(settings_proto, k).CopyFrom(StringValue(value=v)) + elif isinstance(v, (list, set, tuple)): + # we only support sequences of strings for now + sequence = getattr(settings_proto, k) + sequence.value.extend(v) + elif isinstance(v, dict): + mapping = getattr(settings_proto, k) + for key, value in v.items(): + # we only support dicts with string values for now + mapping.value[key] = value + elif v is None: + # None means that the setting value was not set. + pass + else: + raise TypeError(f"Unsupported type {type(v)} for setting {k}") + + return settings_proto + + def _get_program(self) -> Optional[str]: + """Get the program that started the current process.""" + if self._jupyter: + # If in a notebook, try to get the program from the notebook metadata. + if self.notebook_name: + return self.notebook_name + + if not self.x_jupyter_path: + return self.program + + if self.x_jupyter_path.startswith("fileId="): + return self.x_jupyter_name + + return self.x_jupyter_path + + # If not in a notebook, try to get the program from the environment + # or the __main__ module for scripts run as `python -m ...`. 
+ program = os.getenv(env.PROGRAM) + if program is not None: + return program + + try: + import __main__ + except ImportError: + return None + + try: + if __main__.__spec__ is None: + python_args = __main__.__file__ + else: + python_args = f"-m {__main__.__spec__.name}" + except AttributeError: + return None + + return python_args + + @staticmethod + def _get_program_relpath(program: str, root: Optional[str] = None) -> Optional[str]: + """Get the relative path to the program from the root directory.""" + if not program: + return None + + root = root or os.getcwd() + if not root: + return None + + # For windows, if the root and program are on different drives, + # os.path.relpath will raise a ValueError. + if not filesystem.are_paths_on_same_drive( + pathlib.Path(root), pathlib.Path(program) + ): + return None + + full_path_to_program = os.path.join( + root, os.path.relpath(os.getcwd(), root), program + ) + if os.path.exists(full_path_to_program): + relative_path = os.path.relpath(full_path_to_program, start=root) + if "../" in relative_path: + return None + return relative_path + + return None + + def _project_url_base(self) -> str: + """Construct the base URL for the project.""" + if not all([self.entity, self.project]): + return "" + + return f"{self.app_url}/{quote(self.entity or '')}/{quote(self.project or '')}" + + @staticmethod + def _runmoment_preprocessor( + val: Union[RunMoment, str, None], + ) -> Optional[RunMoment]: + """Preprocess the setting for forking or resuming a run.""" + if isinstance(val, RunMoment) or val is None: + return val + elif isinstance(val, str): + return RunMoment.from_uri(val) + + if not IS_PYDANTIC_V2: + + def model_copy(self, *args, **kwargs): + return self.copy(*args, **kwargs) + + def model_dump(self, **kwargs): + """Compatibility method for Pydantic v1 to mimic v2's model_dump. + + In v1, this is equivalent to dict() but also includes computed properties. 
+ + Args: + **kwargs: Options passed to the dict method + - exclude_none: Whether to exclude fields with None values + + Returns: + A dictionary of the model's fields and computed properties + """ + # Handle exclude_none separately since it's named differently in v1 + exclude_none = kwargs.pop("exclude_none", False) + + # Start with regular fields from dict() + result = self.dict(**kwargs) + + # Get all computed properties + for name in dir(self.__class__): + attr = getattr(self.__class__, name, None) + if isinstance(attr, property): + try: + # Only include properties that don't raise errors + value = getattr(self, name) + result[name] = value + except (AttributeError, NotImplementedError, TypeError, ValueError): + # Skip properties that can't be accessed or raise errors + pass + elif isinstance(attr, RunMoment): + value = getattr(self, name) + result[name] = value + + # Special Pydantic attributes that should always be excluded + exclude_fields = { + "model_config", + "model_fields", + "model_fields_set", + "__fields__", + "__model_fields_set", + "__pydantic_self__", + "__pydantic_initialised__", + } + + # Remove special Pydantic attributes + for field in exclude_fields: + if field in result: + del result[field] + + if exclude_none: + # Remove None values from the result + return {k: v for k, v in result.items() if v is not None} + + return result + + @property + def model_fields_set(self) -> set: + """Return a set of fields that have been explicitly set. + + This is a compatibility property for Pydantic v1 to mimic v2's model_fields_set. 
+ """ + return getattr(self, "__fields_set__", set()) + + def _setup_code_paths(self, program: str): + """Sets the program_abspath and program_relpath settings.""" + if self._jupyter and self.x_jupyter_root: + self._infer_code_paths_for_jupyter(program) + else: + self._infer_code_path_for_program(program) + + def _infer_code_path_for_program(self, program: str): + """Finds the program's absolute and relative paths.""" + from .lib.gitlib import GitRepo + + try: + root = ( + GitRepo().root or os.getcwd() if not self.disable_git else os.getcwd() + ) + except Exception: + # if the git command fails, fall back to the current working directory + root = os.getcwd() + + self.program_relpath = self.program_relpath or self._get_program_relpath( + program, root + ) + + program_abspath = os.path.abspath( + os.path.join(root, os.path.relpath(os.getcwd(), root), program) + ) + + if os.path.exists(program_abspath): + self.program_abspath = program_abspath + + def _infer_code_paths_for_jupyter(self, program: str): + """Find the notebook's absolute and relative paths. + + Since the notebook's execution environment + is not the same as the current working directory. + We utilize the metadata provided by the jupyter server. + """ + if not self.x_jupyter_root or not program: + return None + + self.program_abspath = os.path.abspath( + os.path.join(self.x_jupyter_root, program) + ) + self.program_relpath = program + + +def _parse_system_settings( + system_settings: settings_file.SettingsFiles, +) -> Settings: + """Validate settings from a settings file. + + Returns: + A validated Settings object. + + Raises: + ValidationError: on invalid data. + Exception: arbitrary errors can occur when constructing Settings. + """ + fields: dict[str, Any] = dict() + + value: object # Can be transformed arbitrarily. 
+ for key, value in system_settings.all().items(): + if key == "ignore_globs": + fields[key] = value.split(",") + + elif key == "anonymous": + wandb.termwarn( + "Deprecated setting 'anonymous' has no effect and will be" + + " removed in a future version of wandb." + + " Please delete it manually or by running `wandb login`" + + " to avoid errors.", + repeat=False, + ) + fields[key] = deprecation.UNSET + + elif key in ("settings_system", "root_dir"): + wandb.termwarn( + f"Ignoring setting {key!r} which is not allowed in a settings file." + + " Please delete it manually to avoid errors in the future." + ) + + else: + fields[key] = value + + # NOTE: Field validators must raise ValueError for Pydantic to wrap them + # in a ValidationError. Other kinds of errors will bubble up unaltered. + # + # Unfortunately, some validators return a UsageError, which has special + # handling in the CLI and may require care to change. + return Settings(**fields) diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_setup.py b/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_setup.py new file mode 100644 index 0000000000000000000000000000000000000000..7e5d05069802b9524baebe785bca7a31120adc24 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_setup.py @@ -0,0 +1,591 @@ +"""Global W&B library state. + +This module manages global state, which for wandb includes: + +- Settings configured through `wandb.setup()` +- The list of active runs +- A subprocess ("the internal service") that asynchronously uploads metrics + +This module is fork-aware: in a forked process such as that spawned by the +`multiprocessing` module, `wandb.singleton()` returns a new object, not the +one inherited from the parent process. This requirement comes from backward +compatibility with old design choices: the hardest one to fix is that wandb +was originally designed to have a single run for the entire process that +`wandb.init()` was meant to return. 
Back then, the only way to create +multiple simultaneous runs in a single script was to run subprocesses, and since +the built-in `multiprocessing` module forks by default, this required a PID +check to make `wandb.init()` ignore the inherited global run. + +Another reason for fork-awareness is that the process that starts up +the internal service owns it and is responsible for shutting it down, +and child processes shouldn't also try to do that. This is easier to +redesign. +""" + +from __future__ import annotations + +import logging +import os +import pathlib +import sys +import threading +from typing import TYPE_CHECKING, Any, Union + +import wandb +import wandb.integration.sagemaker as sagemaker +from wandb.env import CONFIG_DIR +from wandb.errors import UsageError +from wandb.sdk.lib import asyncio_manager, import_hooks, wb_logging + +from .lib import config_util, server + +if TYPE_CHECKING: + from wandb.sdk import wandb_run + from wandb.sdk.lib.service.service_connection import ServiceConnection + from wandb.sdk.wandb_settings import Settings + + +class _EarlyLogger: + """Early logger which captures logs in memory until logging can be configured.""" + + def __init__(self) -> None: + self._log: list[tuple] = [] + self._exception: list[tuple] = [] + # support old warn() as alias of warning() + self.warn = self.warning + + def debug(self, msg: str, *args: Any, **kwargs: Any) -> None: + self._log.append((logging.DEBUG, msg, args, kwargs)) + + def info(self, msg: str, *args: Any, **kwargs: Any) -> None: + self._log.append((logging.INFO, msg, args, kwargs)) + + def warning(self, msg: str, *args: Any, **kwargs: Any) -> None: + self._log.append((logging.WARNING, msg, args, kwargs)) + + def error(self, msg: str, *args: Any, **kwargs: Any) -> None: + self._log.append((logging.ERROR, msg, args, kwargs)) + + def critical(self, msg: str, *args: Any, **kwargs: Any) -> None: + self._log.append((logging.CRITICAL, msg, args, kwargs)) + + def exception(self, msg: str, *args: 
Any, **kwargs: Any) -> None: + self._exception.append((msg, args, kwargs)) + + def log(self, level: str, msg: str, *args: Any, **kwargs: Any) -> None: + self._log.append((level, msg, args, kwargs)) + + def _flush(self, new_logger: Logger) -> None: + assert self is not new_logger + for level, msg, args, kwargs in self._log: + new_logger.log(level, msg, *args, **kwargs) + for msg, args, kwargs in self._exception: + new_logger.exception(msg, *args, **kwargs) + + +Logger = Union[logging.Logger, _EarlyLogger] + + +class _WandbSetup: + """W&B library singleton.""" + + def __init__(self, pid: int) -> None: + self._asyncer = asyncio_manager.AsyncioManager() + self._asyncer.start() + + self._connection: ServiceConnection | None = None + + self._active_runs: list[wandb_run.Run] = [] + self._active_runs_lock = threading.Lock() + + self._sweep_config: dict | None = None + self._server: server.Server | None = None + self._pid = pid + + # TODO(jhr): defer strict checks until settings are fully initialized + # and logging is ready + self._logger: Logger = _EarlyLogger() + + self._settings: Settings | None = None + self._settings_environ: dict[str, str] | None = None + + @property + def asyncer(self) -> asyncio_manager.AsyncioManager: + """The internal asyncio thread used by wandb.""" + return self._asyncer + + def add_active_run(self, run: wandb_run.Run) -> None: + """Append a run to the active runs list. + + This must be called when a run is initialized. + + Args: + run: A newly initialized run. + """ + with self._active_runs_lock: + if run not in self._active_runs: + self._active_runs.append(run) + + def remove_active_run(self, run: wandb_run.Run) -> None: + """Remove the run from the active runs list. + + This must be called when a run is finished. + + Args: + run: A run that is finished or crashed. + """ + try: + with self._active_runs_lock: + self._active_runs.remove(run) + except ValueError: + pass # Removing a run multiple times is not an error. 
+ + @property + def most_recent_active_run(self) -> wandb_run.Run | None: + """The most recently initialized run that is not yet finished.""" + with self._active_runs_lock: + if not self._active_runs: + return None + + return self._active_runs[-1] + + def finish_all_active_runs(self) -> None: + """Finish all unfinished runs. + + NOTE: This is slightly inefficient as it finishes runs one at a time. + This only exists to support using the `reinit="finish_previous"` + setting together with `reinit="create_new"` which does not seem to be a + useful pattern. Since `"create_new"` should eventually become the + default and only behavior, it does not seem worth optimizing. + """ + # Take a snapshot as each call to `finish()` modifies `_active_runs`. + with self._active_runs_lock: + runs_copy = list(self._active_runs) + + for run in runs_copy: + run.finish() + + def did_environment_change(self) -> bool: + """Check if os.environ has changed since settings were initialized.""" + if not self._settings_environ: + return False + + exclude_env_vars = {"WANDB_SERVICE", "WANDB_KUBEFLOW_URL"} + singleton_env = { + k: v + for k, v in self._settings_environ.items() + if k.startswith("WANDB_") and k not in exclude_env_vars + } + os_env = { + k: v + for k, v in os.environ.items() + if k.startswith("WANDB_") and k not in exclude_env_vars + } + + return ( + set(singleton_env.keys()) != set(os_env.keys()) # + or set(singleton_env.values()) != set(os_env.values()) + ) + + def _load_settings( + self, + *, + system_settings_path: str | None, + disable_sagemaker: bool, + overrides: Settings | None = None, + ) -> None: + """Load settings from environment variables, config files, etc. + + Args: + system_settings_path: Location of system settings file to use. + If not provided, reads the WANDB_CONFIG_DIR environment + variable or uses the default location. + disable_sagemaker: If true, skips modifying settings based on + SageMaker. + overrides: Additional settings to apply to the global settings. 
+ """ + from wandb.sdk.wandb_settings import Settings + + self._settings = Settings() + + # the pid of the process to monitor for system stats + pid = os.getpid() + self._logger.info(f"Current SDK version is {wandb.__version__}") + self._logger.info(f"Configure stats pid to {pid}") + self._settings.x_stats_pid = pid + + if system_settings_path: + self._settings.settings_system = system_settings_path + elif config_dir_str := os.getenv(CONFIG_DIR, None): + config_dir = pathlib.Path(config_dir_str).expanduser() + self._settings.settings_system = str(config_dir / "settings") + else: + self._settings.settings_system = str( + pathlib.Path("~", ".config", "wandb", "settings").expanduser() + ) + self._settings.update_from_system_settings() + + # load settings from the environment variables + self._logger.info("Loading settings from environment variables") + self._settings_environ = os.environ.copy() + self._settings.update_from_env_vars(self._settings_environ) + + # infer settings from the system environment + self._settings.update_from_system_environment() + + # load SageMaker settings + if ( + not self._settings.sagemaker_disable + and not disable_sagemaker + and sagemaker.is_using_sagemaker() + ): + self._logger.info("Loading SageMaker settings") + sagemaker.set_global_settings(self._settings) + + # load settings from the passed init/setup settings + if overrides: + self._settings.update_from_settings(overrides) + + wandb.termsetup(self._settings, None) + + def _update(self, settings: Settings | None) -> None: + """Update settings, initializing them if necessary. + + Args: + settings: Overrides to apply, if any. 
+ """ + if not self._settings: + system_settings_path = settings.settings_system if settings else None + disable_sagemaker = settings.sagemaker_disable if settings else False + self._load_settings( + system_settings_path=system_settings_path, + disable_sagemaker=disable_sagemaker, + overrides=settings, + ) + + # This is 'elif' because load_settings already applies overrides. + elif settings: + self._settings.update_from_settings(settings) + + def update_user_settings(self) -> None: + # Get rid of cached results to force a refresh. + self._server = None + + def _early_logger_flush(self, new_logger: Logger) -> None: + if self._logger is new_logger: + return + + if isinstance(self._logger, _EarlyLogger): + self._logger._flush(new_logger) + self._logger = new_logger + + def _get_logger(self) -> Logger: + return self._logger + + @property + def settings(self) -> Settings: + """The global wandb settings. + + Initializes settings if they have not yet been loaded. + """ + if not self._settings: + self._load_settings( + system_settings_path=None, + disable_sagemaker=False, + ) + assert self._settings + + return self._settings + + @property + def settings_if_loaded(self) -> Settings | None: + """The global wandb settings, or None if not yet loaded.""" + return self._settings + + def _get_entity(self) -> str | None: + if self._settings and self._settings._offline: + return None + entity = self.viewer.get("entity") + return entity + + def _get_username(self) -> str | None: + if self._settings and self._settings._offline: + return None + return self.viewer.get("username") + + def _get_teams(self) -> list[str]: + if self._settings and self._settings._offline: + return [] + teams = self.viewer.get("teams") + if teams: + teams = [team["node"]["name"] for team in teams["edges"]] + return teams or [] + + @property + def viewer(self) -> dict[str, Any]: + if self._server is None: + self._server = server.Server(settings=self.settings) + + return self._server.viewer + + def 
_load_user_settings(self) -> dict[str, Any] | None: + # offline? + if self._server is None: + return None + + flags = self._server._flags + user_settings = dict() + if "code_saving_enabled" in flags: + user_settings["save_code"] = flags["code_saving_enabled"] + + email = self.viewer.get("email", None) + if email: + user_settings["email"] = email + + return user_settings + + @property + def config(self) -> dict: + sweep_path = self.settings.sweep_param_path + if sweep_path: + self._sweep_config = config_util.dict_from_config_file( + sweep_path, must_exist=True + ) + + config = {} + + # if config_paths was set, read in config dict + if self.settings.config_paths: + # TODO(jhr): handle load errors, handle list of files + for config_path in self.settings.config_paths: + config_dict = config_util.dict_from_config_file(config_path) + if config_dict: + config.update(config_dict) + + return config + + def _teardown(self, exit_code: int | None = None) -> None: + import_hooks.unregister_all_post_import_hooks() + + if self._connection: + internal_exit_code = self._connection.teardown(exit_code or 0) + else: + internal_exit_code = None + + self._asyncer.join() + + if internal_exit_code not in (None, 0): + sys.exit(internal_exit_code) + + def ensure_service(self) -> ServiceConnection: + """Returns a connection to the service process creating it if needed.""" + if self._connection: + return self._connection + + from wandb.sdk.lib.service import service_connection + + self._connection = service_connection.connect_to_service( + self._asyncer, + self.settings, + ) + return self._connection + + def assert_service(self) -> ServiceConnection: + """Returns a connection to the service process, asserting it exists. + + Unlike ensure_service(), this will not start up a service process + if it didn't already exist. 
+ """ + if not self._connection: + raise AssertionError("Expected service process to exist.") + + return self._connection + + +_singleton: _WandbSetup | None = None +"""The W&B library singleton, or None if not yet set up. + +The value is invalid and must not be used if `os.getpid() != _singleton._pid`. +""" + +_singleton_lock = threading.Lock() + + +def singleton() -> _WandbSetup: + """The W&B singleton for the current process. + + The first call to this in this process (which may be a fork of another + process) creates the singleton, and all subsequent calls return it + until teardown(). This does not start the service process. + """ + return _setup(start_service=False, load_settings=False) + + +@wb_logging.log_to_all_runs() +def _setup( + settings: Settings | None = None, + start_service: bool = True, + load_settings: bool = True, +) -> _WandbSetup: + """Set up library context. + + Args: + settings: Global settings to set, or updates to the global settings + if the singleton has already been initialized. + start_service: Whether to start up the service process. + NOTE: A service process will only be started if allowed by the + global settings (after the given updates). The service will not + start up if the mode resolves to "disabled". + load_settings: Whether to load settings from the environment + if creating a new singleton. If False, then settings and + start_service must be None. 
+ """ + global _singleton + + if not load_settings and settings: + raise ValueError("Cannot pass settings if load_settings is False.") + if not load_settings and start_service: + raise ValueError("Cannot use start_service if load_settings is False.") + + pid = os.getpid() + with _singleton_lock: + if _singleton and _singleton._pid == pid: + current_singleton = _singleton + else: + current_singleton = _WandbSetup(pid=pid) + + if load_settings: + current_singleton._update(settings) + + if start_service and not current_singleton.settings._noop: + current_singleton.ensure_service() + + _singleton = current_singleton + + # Update after configuring the _singleton. + # + # Do not hold the lock while updating credentials, as it writes back + # to settings. + if settings: + _maybe_update_credentials(settings) + + return current_singleton + + +def _maybe_update_credentials(settings: Settings) -> None: + """Update session credentials if they're set on settings. + + This is a refactoring step for moving credentials into a separate module + and out of settings. If a user calls `wandb.setup()` explicitly with an + api_key or other credential, this overwrites credentials that might have + been set by a call to `wandb.login()`. + """ + if settings.api_key and settings.identity_token_file: + raise UsageError( + "The api_key and identity_token_file settings cannot be used together." 
+ ) + + from wandb.sdk.lib import wbauth + + if settings.api_key: + wbauth.use_explicit_auth( + wbauth.AuthApiKey( + host=wbauth.HostUrl(settings.base_url, app_url=settings.app_url), + api_key=settings.api_key, + ), + source="wandb.setup()", + ) + + elif settings.identity_token_file: + wbauth.use_explicit_auth( + wbauth.AuthIdentityTokenFile( + host=wbauth.HostUrl(settings.base_url, app_url=settings.app_url), + path=settings.identity_token_file, + ), + source="wandb.setup()", + ) + + +def setup(settings: Settings | None = None) -> _WandbSetup: + """Prepares W&B for use in the current process and its children. + + You can usually ignore this as it is implicitly called by `wandb.init()`. + + When using wandb in multiple processes, calling `wandb.setup()` + in the parent process before starting child processes may improve + performance and resource utilization. + + Note that `wandb.setup()` modifies `os.environ`, and it is important + that child processes inherit the modified environment variables. + + See also `wandb.teardown()`. + + Args: + settings: Configuration settings to apply globally. These can be + overridden by subsequent `wandb.init()` calls. 
+ + Example: + ```python + import multiprocessing + + import wandb + + + def run_experiment(params): + with wandb.init(config=params): + # Run experiment + pass + + + if __name__ == "__main__": + # Start backend and set global config + wandb.setup(settings={"project": "my_project"}) + + # Define experiment parameters + experiment_params = [ + {"learning_rate": 0.01, "epochs": 10}, + {"learning_rate": 0.001, "epochs": 20}, + ] + + # Start multiple processes, each running a separate experiment + processes = [] + for params in experiment_params: + p = multiprocessing.Process(target=run_experiment, args=(params,)) + p.start() + processes.append(p) + + # Wait for all processes to complete + for p in processes: + p.join() + + # Optional: Explicitly shut down the backend + wandb.teardown() + ``` + """ + return _setup(settings=settings) + + +@wb_logging.log_to_all_runs() +def teardown(exit_code: int | None = None) -> None: + """Waits for W&B to finish and frees resources. + + Completes any runs that were not explicitly finished + using `run.finish()` and waits for all data to be uploaded. + + It is recommended to call this at the end of a session + that used `wandb.setup()`. It is invoked automatically + in an `atexit` hook, but this is not reliable in certain setups + such as when using Python's `multiprocessing` module. 
+ """ + global _singleton + + from wandb.sdk.lib import wbauth + + with _singleton_lock: + orig_singleton = _singleton + _singleton = None + + if orig_singleton: + orig_singleton._teardown(exit_code=exit_code) + + wbauth.unauthenticate_session(update_settings=False) diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_summary.py b/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_summary.py new file mode 100644 index 0000000000000000000000000000000000000000..938bb9820b4bb992f2b1120e591fc7b0303b103f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_summary.py @@ -0,0 +1,149 @@ +import abc +import typing as t + +from .interface.summary_record import SummaryItem, SummaryRecord + + +def _get_dict(d): + if isinstance(d, dict): + return d + # assume argparse Namespace + return vars(d) + + +class SummaryDict(metaclass=abc.ABCMeta): + """dict-like wrapper for the nested dictionaries in a SummarySubDict. + + Triggers self._root._callback on property changes. + """ + + @abc.abstractmethod + def _as_dict(self): + raise NotImplementedError + + @abc.abstractmethod + def _update(self, record: SummaryRecord): + raise NotImplementedError + + def keys(self): + return [k for k in self._as_dict().keys() if k != "_wandb"] + + def get(self, key, default=None): + return self._as_dict().get(key, default) + + def __getitem__(self, key): + item = self._as_dict()[key] + + if isinstance(item, dict): + # this nested dict needs to be wrapped: + wrapped_item = SummarySubDict() + object.__setattr__(wrapped_item, "_items", item) + object.__setattr__(wrapped_item, "_parent", self) + object.__setattr__(wrapped_item, "_parent_key", key) + + return wrapped_item + + # this item isn't a nested dict + return item + + __getattr__ = __getitem__ + + def __setitem__(self, key, val): + self.update({key: val}) + + __setattr__ = __setitem__ + + def __delattr__(self, key): + record = SummaryRecord() + item = SummaryItem() + item.key = (key,) + record.remove = (item,) + 
self._update(record) + + __delitem__ = __delattr__ + + def update(self, d: t.Dict): + record = SummaryRecord() + for key, value in d.items(): + item = SummaryItem() + item.key = (key,) + item.value = value + record.update.append(item) + + self._update(record) + + +class Summary(SummaryDict): + """Track single values for each metric for each run. + + By default, a metric's summary is the last value of its History. + + For example, `wandb.log({'accuracy': 0.9})` will add a new step to History and + update Summary to the latest value. In some cases, it's more useful to have + the maximum or minimum of a metric instead of the final value. You can set + history manually `(wandb.summary['accuracy'] = best_acc)`. + + In the UI, summary metrics appear in the table to compare across runs. + Summary metrics are also used in visualizations like the scatter plot and + parallel coordinates chart. + + After training has completed, you may want to save evaluation metrics to a + run. Summary can handle numpy arrays and PyTorch/TensorFlow tensors. When + you save one of these types to Summary, we persist the entire tensor in a + binary file and store high level metrics in the summary object, such as min, + mean, variance, and 95th percentile. 
+ + Examples: + ```python + wandb.init(config=args) + + best_accuracy = 0 + for epoch in range(1, args.epochs + 1): + test_loss, test_accuracy = test() + if test_accuracy > best_accuracy: + wandb.run.summary["best_accuracy"] = test_accuracy + best_accuracy = test_accuracy + ``` + """ + + _update_callback: t.Callable + _get_current_summary_callback: t.Callable + + def __init__(self, get_current_summary_callback: t.Callable): + super().__init__() + object.__setattr__(self, "_update_callback", None) + object.__setattr__( + self, "_get_current_summary_callback", get_current_summary_callback + ) + + def _set_update_callback(self, update_callback: t.Callable): + object.__setattr__(self, "_update_callback", update_callback) + + def _as_dict(self): + return self._get_current_summary_callback() + + def _update(self, record: SummaryRecord): + if self._update_callback: # type: ignore + self._update_callback(record) + + +class SummarySubDict(SummaryDict): + """Non-root node of the summary data structure. + + Contains a path to itself from the root. + """ + + _items: t.Dict + _parent: SummaryDict + _parent_key: str + + def __init__(self): + object.__setattr__(self, "_items", dict()) + object.__setattr__(self, "_parent", None) + object.__setattr__(self, "_parent_key", None) + + def _as_dict(self): + return self._items + + def _update(self, record: SummaryRecord): + return self._parent._update(record._add_next_parent(self._parent_key)) diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_sweep.py b/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_sweep.py new file mode 100644 index 0000000000000000000000000000000000000000..1b29848e21dcc9712fdc61f81ef75744b028801a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_sweep.py @@ -0,0 +1,121 @@ +import urllib.parse +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union + +import wandb +from wandb import env + +from . 
import wandb_login + +if TYPE_CHECKING: + from wandb.wandb_controller import _WandbController + + +def _get_sweep_url(api, sweep_id): + """Return sweep url if we can figure it out.""" + if api.api_key: + if api.settings("entity") is None: + viewer = api.viewer() + if viewer.get("entity"): + api.set_setting("entity", viewer["entity"]) + project = api.settings("project") + if not project: + return + if api.settings("entity"): + return "{base}/{entity}/{project}/sweeps/{sweepid}".format( + base=api.app_url, + entity=urllib.parse.quote(api.settings("entity")), + project=urllib.parse.quote(project), + sweepid=urllib.parse.quote(sweep_id), + ) + + +def sweep( + sweep: Union[dict, Callable], + entity: Optional[str] = None, + project: Optional[str] = None, + prior_runs: Optional[List[str]] = None, +) -> str: + """Initialize a hyperparameter sweep. + + Search for hyperparameters that optimizes a cost function + of a machine learning model by testing various combinations. + + Make note the unique identifier, `sweep_id`, that is returned. + At a later step provide the `sweep_id` to a sweep agent. + + See [Sweep configuration structure](https://docs.wandb.ai/guides/sweeps/define-sweep-configuration) + for information on how to define your sweep. + + Args: + sweep: The configuration of a hyperparameter search. + (or configuration generator). + If you provide a callable, ensure that the callable does + not take arguments and that it returns a dictionary that + conforms to the W&B sweep config spec. + entity: The username or team name where you want to send W&B + runs created by the sweep to. Ensure that the entity you + specify already exists. If you don't specify an entity, + the run will be sent to your default entity, + which is usually your username. + project: The name of the project where W&B runs created from + the sweep are sent to. If the project is not specified, the + run is sent to a project labeled 'Uncategorized'. 
+ prior_runs: The run IDs of existing runs to add to this sweep. + + Returns: + str: A unique identifier for the sweep. + """ + from wandb.apis import InternalApi + from wandb.sdk.launch.sweeps.utils import handle_sweep_config_violations + + if callable(sweep): + sweep = sweep() + """Sweep create for controller api and jupyter (eventually for cli).""" + + # Project may be only found in the sweep config. + if project is None and isinstance(sweep, dict): + project = sweep.get("project", None) + + if entity: + env.set_entity(entity) + if project: + env.set_project(project) + + # Make sure we are logged in + if wandb.run is None: + wandb_login._login(_silent=True) + api = InternalApi() + sweep_id, warnings = api.upsert_sweep(sweep, prior_runs=prior_runs) + handle_sweep_config_violations(warnings) + print("Create sweep with ID:", sweep_id) # noqa: T201 + sweep_url = _get_sweep_url(api, sweep_id) + if sweep_url: + print("Sweep URL:", sweep_url) # noqa: T201 + return sweep_id + + +def controller( + sweep_id_or_config: Optional[Union[str, Dict]] = None, + entity: Optional[str] = None, + project: Optional[str] = None, +) -> "_WandbController": + """Public sweep controller constructor. + + Examples: + ```python + import wandb + + tuner = wandb.controller(...) + print(tuner.sweep_config) + print(tuner.sweep_id) + tuner.configure_search(...) + tuner.configure_stopping(...) 
+ ``` + + """ + from ..wandb_controller import _WandbController + + c = _WandbController( + sweep_id_or_config=sweep_id_or_config, entity=entity, project=project + ) + return c diff --git a/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_watch.py b/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_watch.py new file mode 100644 index 0000000000000000000000000000000000000000..170a5c55ffec85d2c6cb052028183104a2cbb132 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/sdk/wandb_watch.py @@ -0,0 +1,146 @@ +"""watch.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Sequence + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal # type: ignore + +import wandb + +from .lib import telemetry + +if TYPE_CHECKING: + import torch # type: ignore [import-not-found] + +logger = logging.getLogger("wandb") + +_global_watch_idx = 0 + + +def _watch( + run: wandb.Run, + models: torch.nn.Module | Sequence[torch.nn.Module], + criterion: torch.F | None = None, + log: Literal["gradients", "parameters", "all"] | None = "gradients", + log_freq: int = 1000, + idx: int | None = None, + log_graph: bool = False, +): + """Hooks into the given PyTorch model(s) to monitor gradients and the model's computational graph. + + This function can track parameters, gradients, or both during training. It should be + extended to support arbitrary machine learning models in the future. + + Args: + run (wandb.Run): The run object to log to. + models (Union[torch.nn.Module, Sequence[torch.nn.Module]]): + A single model or a sequence of models to be monitored. + criterion (Optional[torch.F]): + The loss function being optimized (optional). + log (Optional[Literal["gradients", "parameters", "all"]]): + Specifies whether to log "gradients", "parameters", or "all". + Set to None to disable logging. (default="gradients") + log_freq (int): + Frequency (in batches) to log gradients and parameters. 
(default=1000) + idx (Optional[int]): + Index used when tracking multiple models with `wandb.watch`. (default=None) + log_graph (bool): + Whether to log the model's computational graph. (default=False) + + Returns: + wandb.Graph: + The graph object, which will be populated after the first backward pass. + + Raises: + ValueError: If `wandb.init` has not been called. + TypeError: If any of the models are not instances of `torch.nn.Module`. + """ + global _global_watch_idx + + with telemetry.context() as tel: + tel.feature.watch = True + + logger.info("Watching") + + if log not in {"gradients", "parameters", "all", None}: + raise ValueError("log must be one of 'gradients', 'parameters', 'all', or None") + + log_parameters = log in {"parameters", "all"} + log_gradients = log in {"gradients", "all"} + + if not isinstance(models, (tuple, list)): + models = (models,) + + torch = wandb.util.get_module( + "torch", required="wandb.watch only works with pytorch, couldn't import torch." + ) + + for model in models: + if not isinstance(model, torch.nn.Module): + raise TypeError( + f"Expected a pytorch model (torch.nn.Module). 
Received {type(model)}" + ) + + graphs = [] + prefix = "" + + if idx is None: + idx = _global_watch_idx + for local_idx, model in enumerate(models): + global_idx = idx + local_idx + _global_watch_idx += 1 + if global_idx > 0: + # TODO: this makes ugly chart names like gradients/graph_1conv1d.bias + prefix = f"graph_{global_idx}" + + if log_parameters: + run._torch.add_log_parameters_hook( + model, + prefix=prefix, + log_freq=log_freq, + ) + + if log_gradients: + run._torch.add_log_gradients_hook( + model, + prefix=prefix, + log_freq=log_freq, + ) + + if log_graph: + graph = run._torch.hook_torch(model, criterion, graph_idx=global_idx) + graphs.append(graph) + # NOTE: the graph is set in run.summary by hook_torch on the backward pass + return graphs + + +def _unwatch( + run: wandb.Run, models: torch.nn.Module | Sequence[torch.nn.Module] | None = None +) -> None: + """Remove pytorch model topology, gradient and parameter hooks. + + Args: + run (wandb.Run): + The run object to log to. + models (torch.nn.Module | Sequence[torch.nn.Module]): + Optional list of pytorch models that have had watch called on them + """ + if models: + if not isinstance(models, (tuple, list)): + models = (models,) + for model in models: + if not hasattr(model, "_wandb_hook_names"): + wandb.termwarn(f"{model} model has not been watched") + else: + for name in model._wandb_hook_names: + run._torch.unhook(name) + delattr(model, "_wandb_hook_names") + # TODO: we should also remove recursively model._wandb_watch_called + + else: + run._torch.unhook_all() diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/__init__.py b/.venv/lib/python3.13/site-packages/wandb/vendor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/setup.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/setup.py new file mode 100644 index 
0000000000000000000000000000000000000000..92482c03536227f47dcab140914feba539ff4d70 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/setup.py @@ -0,0 +1,40 @@ +from setuptools import setup, find_packages + +setup( + name='gql', + version='0.2.0', + description='GraphQL client for Python', + long_description=open('README.md').read(), + long_description_content_type="text/markdown", + url='https://github.com/graphql-python/gql', + author='Syrus Akbary', + author_email='me@syrusakbary.com', + license='MIT', + classifiers=[ + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Developers', + 'Topic :: Software Development :: Libraries', + 'Programming Language :: Python :: 2', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: Implementation :: PyPy', + ], + keywords='api graphql protocol rest relay gql client', + packages=find_packages(include=["gql*"]), + install_requires=[ + 'six>=1.10.0', + 'graphql-core>=0.5.0,<2', + 'promise>=2.0,<3', + 'requests>=2.12,<3' + ], + tests_require=[ + 'pytest>=3,<4', + 'pytest-cov>=2.8,<3', + 'mock>=3,<4', + 'vcrpy>=2.1,<3' + ], +) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/__init__.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/starwars/__init__.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/starwars/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git 
a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/starwars/fixtures.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/starwars/fixtures.py new file mode 100644 index 0000000000000000000000000000000000000000..51f29a59248cdeee13a7fc34b1c92154a0977e81 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/starwars/fixtures.py @@ -0,0 +1,96 @@ +from collections import namedtuple + +Human = namedtuple('Human', 'id name friends appearsIn homePlanet') + +luke = Human( + id='1000', + name='Luke Skywalker', + friends=['1002', '1003', '2000', '2001'], + appearsIn=[4, 5, 6], + homePlanet='Tatooine', +) + +vader = Human( + id='1001', + name='Darth Vader', + friends=['1004'], + appearsIn=[4, 5, 6], + homePlanet='Tatooine', +) + +han = Human( + id='1002', + name='Han Solo', + friends=['1000', '1003', '2001'], + appearsIn=[4, 5, 6], + homePlanet=None, +) + +leia = Human( + id='1003', + name='Leia Organa', + friends=['1000', '1002', '2000', '2001'], + appearsIn=[4, 5, 6], + homePlanet='Alderaan', +) + +tarkin = Human( + id='1004', + name='Wilhuff Tarkin', + friends=['1001'], + appearsIn=[4], + homePlanet=None, +) + +humanData = { + '1000': luke, + '1001': vader, + '1002': han, + '1003': leia, + '1004': tarkin, +} + +Droid = namedtuple('Droid', 'id name friends appearsIn primaryFunction') + +threepio = Droid( + id='2000', + name='C-3PO', + friends=['1000', '1002', '1003', '2001'], + appearsIn=[4, 5, 6], + primaryFunction='Protocol', +) + +artoo = Droid( + id='2001', + name='R2-D2', + friends=['1000', '1002', '1003'], + appearsIn=[4, 5, 6], + primaryFunction='Astromech', +) + +droidData = { + '2000': threepio, + '2001': artoo, +} + + +def getCharacter(id): + return humanData.get(id) or droidData.get(id) + + +def getFriends(character): + return map(getCharacter, character.friends) + + +def getHero(episode): + if episode == 5: + return luke + return artoo + + +def getHuman(id): + return humanData.get(id) + + +def 
getDroid(id): + return droidData.get(id) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/starwars/schema.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/starwars/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..da2d88d8cb674ddde76b593420e2405d39c53c4d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/starwars/schema.py @@ -0,0 +1,146 @@ +from graphql.type import (GraphQLArgument, GraphQLEnumType, GraphQLEnumValue, + GraphQLField, GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, GraphQLSchema, + GraphQLString) + +from .fixtures import getDroid, getFriends, getHero, getHuman + +episodeEnum = GraphQLEnumType( + 'Episode', + description='One of the films in the Star Wars Trilogy', + values={ + 'NEWHOPE': GraphQLEnumValue( + 4, + description='Released in 1977.', + ), + 'EMPIRE': GraphQLEnumValue( + 5, + description='Released in 1980.', + ), + 'JEDI': GraphQLEnumValue( + 6, + description='Released in 1983.', + ) + } +) + +characterInterface = GraphQLInterfaceType( + 'Character', + description='A character in the Star Wars Trilogy', + fields=lambda: { + 'id': GraphQLField( + GraphQLNonNull(GraphQLString), + description='The id of the character.' + ), + 'name': GraphQLField( + GraphQLString, + description='The name of the character.' + ), + 'friends': GraphQLField( + GraphQLList(characterInterface), + description='The friends of the character, or an empty list if they have none.' + ), + 'appearsIn': GraphQLField( + GraphQLList(episodeEnum), + description='Which movies they appear in.' 
+ ), + }, + resolve_type=lambda character, *_: humanType if getHuman(character.id) else droidType, +) + +humanType = GraphQLObjectType( + 'Human', + description='A humanoid creature in the Star Wars universe.', + fields=lambda: { + 'id': GraphQLField( + GraphQLNonNull(GraphQLString), + description='The id of the human.', + ), + 'name': GraphQLField( + GraphQLString, + description='The name of the human.', + ), + 'friends': GraphQLField( + GraphQLList(characterInterface), + description='The friends of the human, or an empty list if they have none.', + resolver=lambda human, *_: getFriends(human), + ), + 'appearsIn': GraphQLField( + GraphQLList(episodeEnum), + description='Which movies they appear in.', + ), + 'homePlanet': GraphQLField( + GraphQLString, + description='The home planet of the human, or null if unknown.', + ) + }, + interfaces=[characterInterface] +) + +droidType = GraphQLObjectType( + 'Droid', + description='A mechanical creature in the Star Wars universe.', + fields=lambda: { + 'id': GraphQLField( + GraphQLNonNull(GraphQLString), + description='The id of the droid.', + ), + 'name': GraphQLField( + GraphQLString, + description='The name of the droid.', + ), + 'friends': GraphQLField( + GraphQLList(characterInterface), + description='The friends of the droid, or an empty list if they have none.', + resolver=lambda droid, *_: getFriends(droid), + ), + 'appearsIn': GraphQLField( + GraphQLList(episodeEnum), + description='Which movies they appear in.', + ), + 'primaryFunction': GraphQLField( + GraphQLString, + description='The primary function of the droid.', + ) + }, + interfaces=[characterInterface] +) + +queryType = GraphQLObjectType( + 'Query', + fields=lambda: { + 'hero': GraphQLField( + characterInterface, + args={ + 'episode': GraphQLArgument( + description='If omitted, returns the hero of the whole saga. 
If ' + 'provided, returns the hero of that particular episode.', + type=episodeEnum, + ) + }, + resolver=lambda root, args, *_: getHero(args.get('episode')), + ), + 'human': GraphQLField( + humanType, + args={ + 'id': GraphQLArgument( + description='id of the human', + type=GraphQLNonNull(GraphQLString), + ) + }, + resolver=lambda root, args, *_: getHuman(args['id']), + ), + 'droid': GraphQLField( + droidType, + args={ + 'id': GraphQLArgument( + description='id of the droid', + type=GraphQLNonNull(GraphQLString), + ) + }, + resolver=lambda root, args, *_: getDroid(args['id']), + ), + } +) + +StarWarsSchema = GraphQLSchema(query=queryType, types=[humanType, droidType]) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/starwars/test_dsl.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/starwars/test_dsl.py new file mode 100644 index 0000000000000000000000000000000000000000..e8fa1b6196f8eb8f35cf202b4017841ab63e8a74 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/starwars/test_dsl.py @@ -0,0 +1,293 @@ +import pytest + +from gql import Client +from gql.dsl import DSLSchema + +from .schema import StarWarsSchema + + +@pytest.fixture +def ds(): + client = Client(schema=StarWarsSchema) + ds = DSLSchema(client) + return ds + + +def test_hero_name_query(ds): + query = ''' +hero { + name +} + '''.strip() + query_dsl = ds.Query.hero.select( + ds.Character.name + ) + assert query == str(query_dsl) + + +def test_hero_name_and_friends_query(ds): + query = ''' +hero { + id + name + friends { + name + } +} + '''.strip() + query_dsl = ds.Query.hero.select( + ds.Character.id, + ds.Character.name, + ds.Character.friends.select( + ds.Character.name, + ) + ) + assert query == str(query_dsl) + + +def test_nested_query(ds): + query = ''' +hero { + name + friends { + name + appearsIn + friends { + name + } + } +} + '''.strip() + query_dsl = ds.Query.hero.select( + ds.Character.name, + 
ds.Character.friends.select( + ds.Character.name, + ds.Character.appears_in, + ds.Character.friends.select( + ds.Character.name + ) + ) + ) + assert query == str(query_dsl) + + +def test_fetch_luke_query(ds): + query = ''' +human(id: "1000") { + name +} + '''.strip() + query_dsl = ds.Query.human(id="1000").select( + ds.Human.name, + ) + + assert query == str(query_dsl) + + +# def test_fetch_some_id_query(): +# query = ''' +# query FetchSomeIDQuery($someId: String!) { +# human(id: $someId) { +# name +# } +# } +# ''' +# params = { +# 'someId': '1000', +# } +# expected = { +# 'human': { +# 'name': 'Luke Skywalker', +# } +# } +# result = schema.execute(query, None, params) +# assert not result.errors +# assert result.data == expected + + +# def test_fetch_some_id_query2(): +# query = ''' +# query FetchSomeIDQuery($someId: String!) { +# human(id: $someId) { +# name +# } +# } +# ''' +# params = { +# 'someId': '1002', +# } +# expected = { +# 'human': { +# 'name': 'Han Solo', +# } +# } +# result = schema.execute(query, None, params) +# assert not result.errors +# assert result.data == expected + + +# def test_invalid_id_query(): +# query = ''' +# query humanQuery($id: String!) 
{ +# human(id: $id) { +# name +# } +# } +# ''' +# params = { +# 'id': 'not a valid id', +# } +# expected = { +# 'human': None +# } +# result = schema.execute(query, None, params) +# assert not result.errors +# assert result.data == expected + + +def test_fetch_luke_aliased(ds): + query = ''' +luke: human(id: "1000") { + name +} + '''.strip() + query_dsl = ds.Query.human.args(id=1000).alias('luke').select( + ds.Character.name, + ) + assert query == str(query_dsl) + + +# def test_fetch_luke_and_leia_aliased(): +# query = ''' +# query FetchLukeAndLeiaAliased { +# luke: human(id: "1000") { +# name +# } +# leia: human(id: "1003") { +# name +# } +# } +# ''' +# expected = { +# 'luke': { +# 'name': 'Luke Skywalker', +# }, +# 'leia': { +# 'name': 'Leia Organa', +# } +# } +# result = schema.execute(query) +# assert not result.errors +# assert result.data == expected + + +# def test_duplicate_fields(): +# query = ''' +# query DuplicateFields { +# luke: human(id: "1000") { +# name +# homePlanet +# } +# leia: human(id: "1003") { +# name +# homePlanet +# } +# } +# ''' +# expected = { +# 'luke': { +# 'name': 'Luke Skywalker', +# 'homePlanet': 'Tatooine', +# }, +# 'leia': { +# 'name': 'Leia Organa', +# 'homePlanet': 'Alderaan', +# } +# } +# result = schema.execute(query) +# assert not result.errors +# assert result.data == expected + + +# def test_use_fragment(): +# query = ''' +# query UseFragment { +# luke: human(id: "1000") { +# ...HumanFragment +# } +# leia: human(id: "1003") { +# ...HumanFragment +# } +# } +# fragment HumanFragment on Human { +# name +# homePlanet +# } +# ''' +# expected = { +# 'luke': { +# 'name': 'Luke Skywalker', +# 'homePlanet': 'Tatooine', +# }, +# 'leia': { +# 'name': 'Leia Organa', +# 'homePlanet': 'Alderaan', +# } +# } +# result = schema.execute(query) +# assert not result.errors +# assert result.data == expected + + +# def test_check_type_of_r2(): +# query = ''' +# query CheckTypeOfR2 { +# hero { +# __typename +# name +# } +# } +# ''' +# expected = { 
+# 'hero': { +# '__typename': 'Droid', +# 'name': 'R2-D2', +# } +# } +# result = schema.execute(query) +# assert not result.errors +# assert result.data == expected + + +# def test_check_type_of_luke(): +# query = ''' +# query CheckTypeOfLuke { +# hero(episode: EMPIRE) { +# __typename +# name +# } +# } +# ''' +# expected = { +# 'hero': { +# '__typename': 'Human', +# 'name': 'Luke Skywalker', +# } +# } +# result = schema.execute(query) +# assert not result.errors +# assert result.data == expected + + +def test_hero_name_query_result(ds): + result = ds.query( + ds.Query.hero.select( + ds.Character.name + ) + ) + expected = { + 'hero': { + 'name': 'R2-D2' + } + } + assert result == expected diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/starwars/test_query.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/starwars/test_query.py new file mode 100644 index 0000000000000000000000000000000000000000..6e3ebaf379aac98e2691cfd0ac1e438c25d7ed1d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/starwars/test_query.py @@ -0,0 +1,355 @@ +import pytest +from graphql.error import format_error + +from gql import Client, gql + +from .schema import StarWarsSchema + + +@pytest.fixture +def client(): + return Client(schema=StarWarsSchema) + + +def test_hero_name_query(client): + query = gql(''' + query HeroNameQuery { + hero { + name + } + } + ''') + expected = { + 'hero': { + 'name': 'R2-D2' + } + } + result = client.execute(query) + assert result == expected + + +def test_hero_name_and_friends_query(client): + query = gql(''' + query HeroNameAndFriendsQuery { + hero { + id + name + friends { + name + } + } + } + ''') + expected = { + 'hero': { + 'id': '2001', + 'name': 'R2-D2', + 'friends': [ + {'name': 'Luke Skywalker'}, + {'name': 'Han Solo'}, + {'name': 'Leia Organa'}, + ] + } + } + result = client.execute(query) + assert result == expected + + +def test_nested_query(client): + query = gql(''' + 
query NestedQuery { + hero { + name + friends { + name + appearsIn + friends { + name + } + } + } + } + ''') + expected = { + 'hero': { + 'name': 'R2-D2', + 'friends': [ + { + 'name': 'Luke Skywalker', + 'appearsIn': ['NEWHOPE', 'EMPIRE', 'JEDI'], + 'friends': [ + { + 'name': 'Han Solo', + }, + { + 'name': 'Leia Organa', + }, + { + 'name': 'C-3PO', + }, + { + 'name': 'R2-D2', + }, + ] + }, + { + 'name': 'Han Solo', + 'appearsIn': ['NEWHOPE', 'EMPIRE', 'JEDI'], + 'friends': [ + { + 'name': 'Luke Skywalker', + }, + { + 'name': 'Leia Organa', + }, + { + 'name': 'R2-D2', + }, + ] + }, + { + 'name': 'Leia Organa', + 'appearsIn': ['NEWHOPE', 'EMPIRE', 'JEDI'], + 'friends': [ + { + 'name': 'Luke Skywalker', + }, + { + 'name': 'Han Solo', + }, + { + 'name': 'C-3PO', + }, + { + 'name': 'R2-D2', + }, + ] + }, + ] + } + } + result = client.execute(query) + assert result == expected + + +def test_fetch_luke_query(client): + query = gql(''' + query FetchLukeQuery { + human(id: "1000") { + name + } + } + ''') + expected = { + 'human': { + 'name': 'Luke Skywalker', + } + } + result = client.execute(query) + assert result == expected + + +def test_fetch_some_id_query(client): + query = gql(''' + query FetchSomeIDQuery($someId: String!) { + human(id: $someId) { + name + } + } + ''') + params = { + 'someId': '1000', + } + expected = { + 'human': { + 'name': 'Luke Skywalker', + } + } + result = client.execute(query, variable_values=params) + assert result == expected + + +def test_fetch_some_id_query2(client): + query = gql(''' + query FetchSomeIDQuery($someId: String!) { + human(id: $someId) { + name + } + } + ''') + params = { + 'someId': '1002', + } + expected = { + 'human': { + 'name': 'Han Solo', + } + } + result = client.execute(query, variable_values=params) + assert result == expected + + +def test_invalid_id_query(client): + query = gql(''' + query humanQuery($id: String!) 
{ + human(id: $id) { + name + } + } + ''') + params = { + 'id': 'not a valid id', + } + expected = { + 'human': None + } + result = client.execute(query, variable_values=params) + assert result == expected + + +def test_fetch_luke_aliased(client): + query = gql(''' + query FetchLukeAliased { + luke: human(id: "1000") { + name + } + } + ''') + expected = { + 'luke': { + 'name': 'Luke Skywalker', + } + } + result = client.execute(query) + assert result == expected + + +def test_fetch_luke_and_leia_aliased(client): + query = gql(''' + query FetchLukeAndLeiaAliased { + luke: human(id: "1000") { + name + } + leia: human(id: "1003") { + name + } + } + ''') + expected = { + 'luke': { + 'name': 'Luke Skywalker', + }, + 'leia': { + 'name': 'Leia Organa', + } + } + result = client.execute(query) + assert result == expected + + +def test_duplicate_fields(client): + query = gql(''' + query DuplicateFields { + luke: human(id: "1000") { + name + homePlanet + } + leia: human(id: "1003") { + name + homePlanet + } + } + ''') + expected = { + 'luke': { + 'name': 'Luke Skywalker', + 'homePlanet': 'Tatooine', + }, + 'leia': { + 'name': 'Leia Organa', + 'homePlanet': 'Alderaan', + } + } + result = client.execute(query) + assert result == expected + + +def test_use_fragment(client): + query = gql(''' + query UseFragment { + luke: human(id: "1000") { + ...HumanFragment + } + leia: human(id: "1003") { + ...HumanFragment + } + } + fragment HumanFragment on Human { + name + homePlanet + } + ''') + expected = { + 'luke': { + 'name': 'Luke Skywalker', + 'homePlanet': 'Tatooine', + }, + 'leia': { + 'name': 'Leia Organa', + 'homePlanet': 'Alderaan', + } + } + result = client.execute(query) + assert result == expected + + +def test_check_type_of_r2(client): + query = gql(''' + query CheckTypeOfR2 { + hero { + __typename + name + } + } + ''') + expected = { + 'hero': { + '__typename': 'Droid', + 'name': 'R2-D2', + } + } + result = client.execute(query) + assert result == expected + + +def 
test_check_type_of_luke(client): + query = gql(''' + query CheckTypeOfLuke { + hero(episode: EMPIRE) { + __typename + name + } + } + ''') + expected = { + 'hero': { + '__typename': 'Human', + 'name': 'Luke Skywalker', + } + } + result = client.execute(query) + assert result == expected + + +def test_parse_error(client): + result = None + with pytest.raises(Exception) as excinfo: + query = gql(''' + qeury + ''') + result = client.execute(query) + error = excinfo.value + formatted_error = format_error(error) + assert formatted_error['locations'] == [{'column': 13, 'line': 2}] + assert 'Syntax Error GraphQL request (2:13) Unexpected Name "qeury"' in formatted_error['message'] + assert not result diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/starwars/test_validation.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/starwars/test_validation.py new file mode 100644 index 0000000000000000000000000000000000000000..e2699e95f16072e579a54572fdc7f391de1f43bb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/starwars/test_validation.py @@ -0,0 +1,171 @@ +import pytest +from graphql import graphql +from graphql.utils.introspection_query import introspection_query + +from gql import Client, gql + +from .schema import StarWarsSchema + +introspection = graphql(StarWarsSchema, introspection_query).data + + +@pytest.fixture +def local_schema(): + return Client(schema=StarWarsSchema) + + +@pytest.fixture +def typedef_schema(): + return Client(type_def=''' +schema { + query: Query +} + +interface Character { + appearsIn: [Episode] + friends: [Character] + id: String! + name: String +} + +type Droid implements Character { + appearsIn: [Episode] + friends: [Character] + id: String! + name: String + primaryFunction: String +} + +enum Episode { + EMPIRE + JEDI + NEWHOPE +} + +type Human implements Character { + appearsIn: [Episode] + friends: [Character] + homePlanet: String + id: String! 
+ name: String +} + +type Query { + droid(id: String!): Droid + hero(episode: Episode): Character + human(id: String!): Human +}''') + + +@pytest.fixture +def introspection_schema(): + return Client(introspection=introspection) + + +@pytest.fixture(params=['local_schema', 'typedef_schema', 'introspection_schema']) +def client(request): + return request.getfixturevalue(request.param) + + +def validation_errors(client, query): + query = gql(query) + try: + client.validate(query) + return False + except Exception: + return True + + +def test_nested_query_with_fragment(client): + query = ''' + query NestedQueryWithFragment { + hero { + ...NameAndAppearances + friends { + ...NameAndAppearances + friends { + ...NameAndAppearances + } + } + } + } + fragment NameAndAppearances on Character { + name + appearsIn + } + ''' + assert not validation_errors(client, query) + + +def test_non_existent_fields(client): + query = ''' + query HeroSpaceshipQuery { + hero { + favoriteSpaceship + } + } + ''' + assert validation_errors(client, query) + + +def test_require_fields_on_object(client): + query = ''' + query HeroNoFieldsQuery { + hero + } + ''' + assert validation_errors(client, query) + + +def test_disallows_fields_on_scalars(client): + query = ''' + query HeroFieldsOnScalarQuery { + hero { + name { + firstCharacterOfName + } + } + } + ''' + assert validation_errors(client, query) + + +def test_disallows_object_fields_on_interfaces(client): + query = ''' + query DroidFieldOnCharacter { + hero { + name + primaryFunction + } + } + ''' + assert validation_errors(client, query) + + +def test_allows_object_fields_in_fragments(client): + query = ''' + query DroidFieldInFragment { + hero { + name + ...DroidFields + } + } + fragment DroidFields on Droid { + primaryFunction + } + ''' + assert not validation_errors(client, query) + + +def test_allows_object_fields_in_inline_fragments(client): + query = ''' + query DroidFieldInFragment { + hero { + name + ... 
on Droid { + primaryFunction + } + } + } + ''' + assert not validation_errors(client, query) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/test_client.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/test_client.py new file mode 100644 index 0000000000000000000000000000000000000000..14f06c431aed8a4744fb3a0e941b6abe484817dd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/test_client.py @@ -0,0 +1,31 @@ +import pytest +import mock + +from gql import Client, gql +from gql.transport.requests import RequestsHTTPTransport + + +@mock.patch('gql.transport.requests.RequestsHTTPTransport.execute') +def test_retries(execute_mock): + expected_retries = 3 + execute_mock.side_effect = Exception("fail") + + client = Client( + retries=expected_retries, + transport=RequestsHTTPTransport(url='http://swapi.graphene-python.org/graphql') + ) + + query = gql(''' + { + myFavoriteFilm: film(id:"RmlsbToz") { + id + title + episodeId + } + } + ''') + + with pytest.raises(Exception): + client.execute(query) + + assert execute_mock.call_count == expected_retries diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/test_transport.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/test_transport.py new file mode 100644 index 0000000000000000000000000000000000000000..93334ba091334175a79a4cf9fb37a2f7b21371fd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/tests/test_transport.py @@ -0,0 +1,89 @@ +import pytest +import requests +import vcr + +from gql import Client, gql +from gql.transport.requests import RequestsHTTPTransport + +# https://github.com/graphql-python/swapi-graphene +URL = 'http://127.0.0.1:8000/graphql' + + +@pytest.fixture +def client(): + with vcr.use_cassette('tests/fixtures/vcr_cassettes/client.yaml'): + request = requests.get( + URL, + headers={ + 'Host': 'swapi.graphene-python.org', + 'Accept': 'text/html', + } + ) + 
request.raise_for_status() + csrf = request.cookies['csrftoken'] + + return Client( + transport=RequestsHTTPTransport( + url=URL, + cookies={"csrftoken": csrf}, + headers={'x-csrftoken': csrf}), + fetch_schema_from_transport=True + ) + + +def test_hero_name_query(client): + query = gql(''' + { + myFavoriteFilm: film(id:"RmlsbToz") { + id + title + episodeId + characters(first:5) { + edges { + node { + name + } + } + } + } + } + ''') + expected = { + "myFavoriteFilm": { + "id": "RmlsbToz", + "title": "Return of the Jedi", + "episodeId": 6, + "characters": { + "edges": [ + { + "node": { + "name": "Luke Skywalker" + } + }, + { + "node": { + "name": "C-3PO" + } + }, + { + "node": { + "name": "R2-D2" + } + }, + { + "node": { + "name": "Darth Vader" + } + }, + { + "node": { + "name": "Leia Organa" + } + } + ] + } + } + } + with vcr.use_cassette('tests/fixtures/vcr_cassettes/execute.yaml'): + result = client.execute(query) + assert result == expected diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/__init__.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f24fce1295c27bf2575c3a41c8569ce257dd7ffa --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/__init__.py @@ -0,0 +1,4 @@ +from .gql import gql +from .client import Client + +__all__ = ['gql', 'Client'] diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edd43fa1cae5726511dd30126340d90dfcbd19da Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/__pycache__/__init__.cpython-313.pyc differ diff --git 
a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/__pycache__/client.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/__pycache__/client.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7e7bca6d353be9abcf500d00da20118703d3438 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/__pycache__/client.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/__pycache__/gql.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/__pycache__/gql.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3375d252b59bc1ecf4e6714d397c54486925c324 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/__pycache__/gql.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/client.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/client.py new file mode 100644 index 0000000000000000000000000000000000000000..ab3ed7f0e121a908e892ed0537f215f0aa942f93 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/client.py @@ -0,0 +1,75 @@ +import logging + +from wandb_graphql import parse, introspection_query, build_ast_schema, build_client_schema +from wandb_graphql.validation import validate + +from .transport.local_schema import LocalSchemaTransport + +log = logging.getLogger(__name__) + + +class RetryError(Exception): + """Custom exception thrown when retry logic fails""" + def __init__(self, retries_count, last_exception): + message = "Failed %s retries: %s" % (retries_count, last_exception) + super(RetryError, self).__init__(message) + self.last_exception = last_exception + + +class Client(object): + def __init__(self, schema=None, introspection=None, type_def=None, transport=None, + 
fetch_schema_from_transport=False, retries=0): + assert not(type_def and introspection), 'Cant provide introspection type definition at the same time' + if transport and fetch_schema_from_transport: + assert not schema, 'Cant fetch the schema from transport if is already provided' + introspection = transport.execute(parse(introspection_query)).data + if introspection: + assert not schema, 'Cant provide introspection and schema at the same time' + schema = build_client_schema(introspection) + elif type_def: + assert not schema, 'Cant provide Type definition and schema at the same time' + type_def_ast = parse(type_def) + schema = build_ast_schema(type_def_ast) + elif schema and not transport: + transport = LocalSchemaTransport(schema) + + self.schema = schema + self.introspection = introspection + self.transport = transport + self.retries = retries + + def validate(self, document): + if not self.schema: + raise Exception("Cannot validate locally the document, you need to pass a schema.") + validation_errors = validate(self.schema, document) + if validation_errors: + raise validation_errors[0] + + def execute(self, document, *args, **kwargs): + if self.schema: + self.validate(document) + + result = self._get_result(document, *args, **kwargs) + if result.errors: + raise Exception(str(result.errors[0])) + + return result.data + + def _get_result(self, document, *args, **kwargs): + if not self.retries: + return self.transport.execute(document, *args, **kwargs) + + last_exception = None + retries_count = 0 + while retries_count < self.retries: + try: + result = self.transport.execute(document, *args, **kwargs) + return result + except Exception as e: + last_exception = e + log.warning("Request failed with exception %s. 
Retrying for the %s time...", + e, retries_count + 1, exc_info=True) + finally: + retries_count += 1 + + raise RetryError(retries_count, last_exception) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/dsl.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/dsl.py new file mode 100644 index 0000000000000000000000000000000000000000..052e1ebd277b98cf80dec0d8e2313593fa04c3d4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/dsl.py @@ -0,0 +1,152 @@ +from collections.abc import Iterable +import decimal +from functools import partial + +from wandb_graphql.language import ast +from wandb_graphql.language.printer import print_ast +from wandb_graphql.type import (GraphQLField, GraphQLList, + GraphQLNonNull, GraphQLEnumType) + +from .utils import to_camel_case + + +class DSLSchema(object): + def __init__(self, client): + self.client = client + + @property + def schema(self): + return self.client.schema + + def __getattr__(self, name): + type_def = self.schema.get_type(name) + return DSLType(type_def) + + def query(self, *args, **kwargs): + return self.execute(query(*args, **kwargs)) + + def mutate(self, *args, **kwargs): + return self.query(*args, operation='mutate', **kwargs) + + def execute(self, document): + return self.client.execute(document) + + +class DSLType(object): + def __init__(self, type): + self.type = type + + def __getattr__(self, name): + formatted_name, field_def = self.get_field(name) + return DSLField(formatted_name, field_def) + + def get_field(self, name): + camel_cased_name = to_camel_case(name) + + if name in self.type.fields: + return name, self.type.fields[name] + + if camel_cased_name in self.type.fields: + return camel_cased_name, self.type.fields[camel_cased_name] + + raise KeyError('Field {} doesnt exist in type {}.'.format(name, self.type.name)) + + +def selections(*fields): + for _field in fields: + yield field(_field).ast + + +def get_ast_value(value): 
+ if isinstance(value, ast.Node): + return value + if isinstance(value, str): + return ast.StringValue(value=value) + elif isinstance(value, bool): + return ast.BooleanValue(value=value) + elif isinstance(value, (float, decimal.Decimal)): + return ast.FloatValue(value=value) + elif isinstance(value, int): + return ast.IntValue(value=value) + return None + + +class DSLField(object): + + def __init__(self, name, field): + self.field = field + self.ast_field = ast.Field(name=ast.Name(value=name), arguments=[]) + self.selection_set = None + + def select(self, *fields): + if not self.ast_field.selection_set: + self.ast_field.selection_set = ast.SelectionSet(selections=[]) + self.ast_field.selection_set.selections.extend(selections(*fields)) + return self + + def __call__(self, *args, **kwargs): + return self.args(*args, **kwargs) + + def alias(self, alias): + self.ast_field.alias = ast.Name(value=alias) + return self + + def args(self, **args): + for name, value in args.items(): + arg = self.field.args.get(name) + arg_type_serializer = get_arg_serializer(arg.type) + value = arg_type_serializer(value) + self.ast_field.arguments.append( + ast.Argument( + name=ast.Name(value=name), + value=get_ast_value(value) + ) + ) + return self + + @property + def ast(self): + return self.ast_field + + def __str__(self): + return print_ast(self.ast_field) + + +def field(field, **args): + if isinstance(field, GraphQLField): + return DSLField(field).args(**args) + elif isinstance(field, DSLField): + return field + + raise Exception('Received incompatible query field: "{}".'.format(field)) + + +def query(*fields): + return ast.Document( + definitions=[ast.OperationDefinition( + operation='query', + selection_set=ast.SelectionSet( + selections=list(selections(*fields)) + ) + )] + ) + + +def serialize_list(serializer, values): + assert isinstance(values, Iterable), 'Expected iterable, received "{}"'.format(repr(values)) + return [serializer(v) for v in values] + + +def 
get_arg_serializer(arg_type): + if isinstance(arg_type, GraphQLNonNull): + return get_arg_serializer(arg_type.of_type) + if isinstance(arg_type, GraphQLList): + inner_serializer = get_arg_serializer(arg_type.of_type) + return partial(serialize_list, inner_serializer) + if isinstance(arg_type, GraphQLEnumType): + return lambda value: ast.EnumValue(value=arg_type.serialize(value)) + return arg_type.serialize + + +def var(name): + return ast.Variable(name=name) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/gql.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/gql.py new file mode 100644 index 0000000000000000000000000000000000000000..21edd394ad7dec411676982f285a977cb4d07752 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/gql.py @@ -0,0 +1,10 @@ +from wandb_graphql.language.parser import parse +from wandb_graphql.language.source import Source + + +def gql(request_string): + if isinstance(request_string, str): + source = Source(request_string, 'GraphQL request') + return parse(source) + else: + raise Exception('Received incompatible request "{}".'.format(request_string)) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/__init__.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b93e5e7e053de2b2581a19d81f1ccc4a6537508 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/__pycache__/__init__.cpython-313.pyc differ diff 
--git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/__pycache__/http.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/__pycache__/http.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8fede917c7714efc9d35f813db1d0638e77b675 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/__pycache__/http.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/__pycache__/local_schema.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/__pycache__/local_schema.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38e65cb3be1ead0fbf45aa535d92a5a9b5eb4ff1 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/__pycache__/local_schema.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/http.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/http.py new file mode 100644 index 0000000000000000000000000000000000000000..bbd0e04758f0db007b316041518b0db14c53eb12 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/http.py @@ -0,0 +1,6 @@ +class HTTPTransport(object): + + def __init__(self, url, headers=None, cookies=None): + self.url = url + self.headers = headers + self.cookies = cookies diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/local_schema.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/local_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..5bc7d33dc589a1cd1a0410bfd9465f9cb4a7a50a --- /dev/null +++ 
b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/local_schema.py @@ -0,0 +1,15 @@ +from wandb_graphql.execution import execute + + +class LocalSchemaTransport(object): + + def __init__(self, schema): + self.schema = schema + + def execute(self, document, *args, **kwargs): + return execute( + self.schema, + document, + *args, + **kwargs + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py new file mode 100644 index 0000000000000000000000000000000000000000..305ca8af968f65e172e6ba262febaef02d70cf0e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py @@ -0,0 +1,46 @@ +from __future__ import absolute_import + +import requests +from wandb_graphql.execution import ExecutionResult +from wandb_graphql.language.printer import print_ast + +from .http import HTTPTransport + + +class RequestsHTTPTransport(HTTPTransport): + def __init__(self, url, auth=None, use_json=False, timeout=None, **kwargs): + """ + :param url: The GraphQL URL + :param auth: Auth tuple or callable to enable Basic/Digest/Custom HTTP Auth + :param use_json: Send request body as JSON instead of form-urlencoded + :param timeout: Specifies a default timeout for requests (Default: None) + """ + super(RequestsHTTPTransport, self).__init__(url, **kwargs) + self.auth = auth + self.default_timeout = timeout + self.use_json = use_json + + def execute(self, document, variable_values=None, timeout=None): + query_str = print_ast(document) + payload = { + 'query': query_str, + 'variables': variable_values or {} + } + + data_key = 'json' if self.use_json else 'data' + post_args = { + 'headers': self.headers, + 'auth': self.auth, + 'cookies': self.cookies, + 'timeout': timeout or self.default_timeout, + data_key: payload + } + request = requests.post(self.url, **post_args) + 
request.raise_for_status() + + result = request.json() + assert 'errors' in result or 'data' in result, 'Received non-compatible response "{}"'.format(result) + return ExecutionResult( + errors=result.get('errors'), + data=result.get('data') + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/utils.py b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d59964977fda9da9c877841f5cfd17b219e6bbb4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/gql-0.2.0/wandb_gql/utils.py @@ -0,0 +1,21 @@ +import re + + +# From this response in Stackoverflow +# http://stackoverflow.com/a/19053800/1072990 +def to_camel_case(snake_str): + components = snake_str.split('_') + # We capitalize the first letter of each component except the first one + # with the 'title' method and join them together. + return components[0] + "".join(x.title() if x else '_' for x in components[1:]) + + +# From this response in Stackoverflow +# http://stackoverflow.com/a/1176023/1072990 +def to_snake_case(name): + s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() + + +def to_const(string): + return re.sub(r'[\W|^]+', '_', string).upper() diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/setup.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..16ecb33816b535c767a272e7e06eb9799e011e8b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/setup.py @@ -0,0 +1,86 @@ +from setuptools import setup, find_packages +from setuptools.command.test import test as TestCommand +import sys + +if sys.version_info[0] < 3: + import __builtin__ as builtins +else: + import builtins + +# This is a bit (!) 
hackish: we are setting a global variable so that the main +# graphql __init__ can detect if it is being loaded by the setup routine, to +# avoid attempting to load components that aren't built yet: +# the numpy distutils extensions that are used by scikit-learn to recursively +# build the compiled extensions in sub-packages is based on the Python import +# machinery. +if 'test' not in sys.argv: + builtins.__GRAPHQL_SETUP__ = True + +version = __import__('graphql').get_version() + +install_requires = [ + 'six>=1.10.0', + 'promise>=2.0' +] + +tests_requires = [ + 'pytest==3.0.2', + 'pytest-django==2.9.1', + 'pytest-cov==2.3.1', + 'coveralls', + 'gevent==1.1rc1', + 'six>=1.10.0', + 'pytest-benchmark==3.0.0', + 'pytest-mock==1.2', +] + +class PyTest(TestCommand): + def finalize_options(self): + TestCommand.finalize_options(self) + self.test_args = ['graphql', '-vrsx'] + self.test_suite = True + + def run_tests(self): + #import here, cause outside the eggs aren't loaded + import pytest + errno = pytest.main(self.test_args) + sys.exit(errno) + + +setup( + name='graphql-core', + version=version, + description='GraphQL implementation for Python', + url='https://github.com/graphql-python/graphql-core', + download_url='https://github.com/graphql-python/graphql-core/releases', + author='Syrus Akbary, Jake Heinz, Taeho Kim', + author_email='Syrus Akbary , Jake Heinz , Taeho Kim ', + license='MIT', + classifiers=[ + 'Development Status :: 5 - Production/Stable', + 'Intended Audience :: Developers', + 'Topic :: Software Development :: Libraries', + 'Programming Language :: Python :: 2', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: Implementation :: PyPy', + 'License :: OSI Approved :: MIT License', + 'Topic :: Database :: Front-Ends', + 'Topic :: Internet :: WWW/HTTP', + ], + 
+ keywords='api graphql protocol rest', + packages=find_packages(exclude=['tests', 'tests_py35']), + install_requires=install_requires, + tests_require=tests_requires, + cmdclass = {'test': PyTest}, + extras_require={ + 'gevent': [ + 'gevent==1.1rc1' + ], + 'test': tests_requires + } +) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/__init__.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d53457d40101cf306e08a2918db023290081fa28 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/__init__.py @@ -0,0 +1,287 @@ +''' +GraphQL.js provides a reference implementation for the GraphQL specification +but is also a useful utility for operating on GraphQL files and building +sophisticated tools. + +This primary module exports a general purpose function for fulfilling all +steps of the GraphQL specification in a single operation, but also includes +utilities for every part of the GraphQL specification: + + - Parsing the GraphQL language. + - Building a GraphQL type schema. + - Validating a GraphQL request against a type schema. + - Executing a GraphQL request against a type schema. + +This also includes utility functions for operating on GraphQL types and +GraphQL documents to facilitate building tools. + +You may also import from each sub-directory directly. For example, the +following two import statements are equivalent: + + from graphql import parse + from graphql.language.base import parse +''' +from .pyutils.version import get_version + + +try: + # This variable is injected in the __builtins__ by the build + # process. 
It used to enable importing subpackages when + # the required packages are not installed + __GRAPHQL_SETUP__ +except NameError: + __GRAPHQL_SETUP__ = False + + +VERSION = (1, 1, 0, 'final', 0) + +__version__ = get_version(VERSION) + + +if not __GRAPHQL_SETUP__: + # The primary entry point into fulfilling a GraphQL request. + from .graphql import ( + graphql + ) + + # Create and operate on GraphQL type definitions and schema. + from .type import ( # no import order + GraphQLSchema, + + # Definitions + GraphQLScalarType, + GraphQLObjectType, + GraphQLInterfaceType, + GraphQLUnionType, + GraphQLEnumType, + GraphQLInputObjectType, + GraphQLList, + GraphQLNonNull, + GraphQLField, + GraphQLInputObjectField, + GraphQLArgument, + + # "Enum" of Type Kinds + TypeKind, + + # "Enum" of Directive locations + DirectiveLocation, + + # Scalars + GraphQLInt, + GraphQLFloat, + GraphQLString, + GraphQLBoolean, + GraphQLID, + + # Directive definition + GraphQLDirective, + + # Built-in directives defined by the Spec + specified_directives, + GraphQLSkipDirective, + GraphQLIncludeDirective, + GraphQLDeprecatedDirective, + + # Constant Deprecation Reason + DEFAULT_DEPRECATION_REASON, + + # GraphQL Types for introspection. + __Schema, + __Directive, + __DirectiveLocation, + __Type, + __Field, + __InputValue, + __EnumValue, + __TypeKind, + + # Meta-field definitions. + SchemaMetaFieldDef, + TypeMetaFieldDef, + TypeNameMetaFieldDef, + + # Predicates + is_type, + is_input_type, + is_output_type, + is_leaf_type, + is_composite_type, + is_abstract_type, + + # Un-modifiers + get_nullable_type, + get_named_type, + ) + + # Parse and operate on GraphQL language source files. + from .language.base import ( # no import order + Source, + get_location, + + # Parse + parse, + parse_value, + + # Print + print_ast, + + # Visit + visit, + ParallelVisitor, + TypeInfoVisitor, + BREAK, + ) + + # Execute GraphQL queries. 
+ from .execution import ( # no import order + execute, + MiddlewareManager, + middlewares + ) + + # Validate GraphQL queries. + from .validation import ( # no import order + validate, + specified_rules, + ) + + # Create and format GraphQL errors. + from .error import ( + GraphQLError, + format_error, + ) + + # Utilities for operating on GraphQL type schema and parsed sources. + from .utils.base import ( + # The GraphQL query recommended for a full schema introspection. + introspection_query, + + # Gets the target Operation from a Document + get_operation_ast, + + # Build a GraphQLSchema from an introspection result. + build_client_schema, + + # Build a GraphQLSchema from a parsed GraphQL Schema language AST. + build_ast_schema, + + # Extends an existing GraphQLSchema from a parsed GraphQL Schema + # language AST. + extend_schema, + + # Print a GraphQLSchema to GraphQL Schema language. + print_schema, + + # Create a GraphQLType from a GraphQL language AST. + type_from_ast, + + # Create a JavaScript value from a GraphQL language AST. + value_from_ast, + + # Create a GraphQL language AST from a JavaScript value. + ast_from_value, + + # A helper to use within recursive-descent visitors which need to be aware of + # the GraphQL type system. + TypeInfo, + + # Determine if JavaScript values adhere to a GraphQL type. + is_valid_value, + + # Determine if AST values adhere to a GraphQL type. + is_valid_literal_value, + + # Concatenates multiple AST together. + concat_ast, + + # Comparators for types + is_equal_type, + is_type_sub_type_of, + do_types_overlap, + + # Asserts a string is a valid GraphQL name. 
+ assert_valid_name, + ) + + __all__ = ( + 'graphql', + 'GraphQLBoolean', + 'GraphQLEnumType', + 'GraphQLFloat', + 'GraphQLID', + 'GraphQLInputObjectType', + 'GraphQLInt', + 'GraphQLInterfaceType', + 'GraphQLList', + 'GraphQLNonNull', + 'GraphQLField', + 'GraphQLInputObjectField', + 'GraphQLArgument', + 'GraphQLObjectType', + 'GraphQLScalarType', + 'GraphQLSchema', + 'GraphQLString', + 'GraphQLUnionType', + 'GraphQLDirective', + 'specified_directives', + 'GraphQLSkipDirective', + 'GraphQLIncludeDirective', + 'GraphQLDeprecatedDirective', + 'DEFAULT_DEPRECATION_REASON', + 'TypeKind', + 'DirectiveLocation', + '__Schema', + '__Directive', + '__DirectiveLocation', + '__Type', + '__Field', + '__InputValue', + '__EnumValue', + '__TypeKind', + 'SchemaMetaFieldDef', + 'TypeMetaFieldDef', + 'TypeNameMetaFieldDef', + 'get_named_type', + 'get_nullable_type', + 'is_abstract_type', + 'is_composite_type', + 'is_input_type', + 'is_leaf_type', + 'is_output_type', + 'is_type', + 'BREAK', + 'ParallelVisitor', + 'Source', + 'TypeInfoVisitor', + 'get_location', + 'parse', + 'parse_value', + 'print_ast', + 'visit', + 'execute', + 'MiddlewareManager', + 'middlewares', + 'specified_rules', + 'validate', + 'GraphQLError', + 'format_error', + 'TypeInfo', + 'assert_valid_name', + 'ast_from_value', + 'build_ast_schema', + 'build_client_schema', + 'concat_ast', + 'do_types_overlap', + 'extend_schema', + 'get_operation_ast', + 'introspection_query', + 'is_equal_type', + 'is_type_sub_type_of', + 'is_valid_literal_value', + 'is_valid_value', + 'print_schema', + 'type_from_ast', + 'value_from_ast', + 'get_version', + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97b7d5629b11539fd616307d9460f73777c0f50b Binary files /dev/null and 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/__pycache__/graphql.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/__pycache__/graphql.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f7cc288da998c18f0bf9d4c029c243a0f5da238 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/__pycache__/graphql.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__init__.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fdcf81493f1248b2c8bd503d1f0cf6ec83df7599 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__init__.py @@ -0,0 +1,6 @@ +from .base import GraphQLError +from .located_error import GraphQLLocatedError +from .syntax_error import GraphQLSyntaxError +from .format_error import format_error + +__all__ = ['GraphQLError', 'GraphQLLocatedError', 'GraphQLSyntaxError', 'format_error'] diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f061ed64c8e64de1d93281b36c3cc7a4cd9c92af Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__pycache__/base.cpython-313.pyc 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__pycache__/base.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cd9a426ffdf4ab208203cbc75dce4f172162ff7 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__pycache__/base.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__pycache__/format_error.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__pycache__/format_error.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbcbf2c07f9e1b79ce3421190f8728477580c731 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__pycache__/format_error.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__pycache__/located_error.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__pycache__/located_error.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37a44dd564fc2203225befb6ec1a303115783a6e Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__pycache__/located_error.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__pycache__/syntax_error.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__pycache__/syntax_error.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aecc9c6798416ac657281904c90ba594e456691d Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/error/__pycache__/syntax_error.cpython-313.pyc 
class GraphQLError(Exception):
    """Base GraphQL error, optionally tied to AST nodes or source positions."""

    __slots__ = 'message', 'nodes', 'stack', 'original_error', '_source', '_positions'

    def __init__(self, message, nodes=None, stack=None, source=None, positions=None):
        """Store the message and the optional location information.

        :param message: Human-readable error message.
        :param nodes: Sequence of AST nodes the error relates to.
        :param stack: Traceback used by ``reraise``.
        :param source: Explicit source object; overrides the one found on ``nodes``.
        :param positions: Explicit character offsets; override those found on ``nodes``.
        """
        super(GraphQLError, self).__init__(message)
        self.message = message
        self.nodes = nodes
        self.stack = stack
        self._source = source
        self._positions = positions

    @property
    def source(self):
        """Explicit source if given, else the source of the first node's ``loc``."""
        if self._source:
            return self._source
        if self.nodes:
            node = self.nodes[0]
            return node and node.loc and node.loc.source

    @property
    def positions(self):
        """Explicit positions if given, else ``loc.start`` of every located node.

        Returns ``None`` when no position is known.
        """
        if self._positions:
            return self._positions
        if self.nodes is not None:
            # Test for a location, not for a truthy offset: offset 0 (the very
            # first character of the document) is a valid position.
            node_positions = [node.loc.start for node in self.nodes if node.loc]
            if node_positions:
                return node_positions

    def reraise(self):
        """Re-raise this error, attaching the stored traceback when there is one."""
        if self.stack:
            raise self.with_traceback(self.stack)
        else:
            raise self

    @property
    def locations(self):
        """``get_location`` results for every position, or ``None`` without source/positions."""
        source = self.source
        if self.positions and source:
            return [get_location(source, pos) for pos in self.positions]


def format_error(error):
    """Return a dict with the error ``message`` and, when known, its ``locations``.

    Each location is rendered as ``{'line': ..., 'column': ...}``.
    """
    formatted_error = {
        'message': error.message,
    }
    locations = error.locations
    if locations is not None:
        formatted_error['locations'] = [
            {'line': loc.line, 'column': loc.column}
            for loc in locations
        ]

    return formatted_error


class GraphQLLocatedError(GraphQLError):
    """GraphQLError wrapping an arbitrary exception raised for the given nodes."""

    def __init__(self, nodes, original_error=None):
        """Build the message and traceback from ``original_error``.

        :param nodes: AST nodes the original error is attributed to.
        :param original_error: Wrapped exception; a generic message is used
            when it is missing or falsy.
        """
        if original_error:
            try:
                message = str(original_error)
            except UnicodeEncodeError:
                # Fallback kept from the original implementation.
                message = original_error.message.encode('utf-8')
        else:
            message = 'An unknown error occurred.'

        # Prefer a traceback carried by the wrapped error, else whatever
        # exception is currently being handled (None outside an except block).
        if hasattr(original_error, 'stack'):
            stack = original_error.stack
        else:
            stack = sys.exc_info()[2]

        super(GraphQLLocatedError, self).__init__(
            message=message,
            nodes=nodes,
            stack=stack
        )
        self.original_error = original_error


class GraphQLSyntaxError(GraphQLError):
    """Error raised for a syntax problem at ``position`` in ``source``."""

    def __init__(self, source, position, description):
        """Compose a message containing the location and a source excerpt."""
        location = get_location(source, position)
        super(GraphQLSyntaxError, self).__init__(
            message=u'Syntax Error {} ({}:{}) {}\n\n{}'.format(
                source.name,
                location.line,
                location.column,
                description,
                highlight_source_at_location(source, location),
            ),
            source=source,
            positions=[position],
        )


def highlight_source_at_location(source, location):
    """Render the line at ``location`` with its neighbours and a ``^`` marker.

    :param source: Object whose ``body`` attribute holds the source text.
    :param location: Object with 1-based ``line`` and ``column`` attributes.
    :return: Excerpt of up to three numbered lines plus the marker line.  When
        ``location.line`` lies beyond the last line of ``body`` (empty body, or
        an error reported after a trailing newline) the located line and the
        marker are omitted instead of raising ``IndexError``.
    """
    line = location.line
    lines = source.body.splitlines()
    pad_len = len(str(line + 1))
    render_line = (u'{:>' + str(pad_len) + '}: {}\n').format
    result = u''
    if line >= 2:
        result += render_line(line - 1, lines[line - 2])
    if line <= len(lines):
        result += render_line(line, lines[line - 1])
        result += ' ' * (1 + pad_len + location.column) + '^\n'
    if line < len(lines):
        result += render_line(line + 1, lines[line])
    return result
0000000000000000000000000000000000000000..546f425ce95a4c8b129dbd27d53afb56e2c37579 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__init__.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +""" +Terminology + +"Definitions" are the generic name for top-level statements in the document. +Examples of this include: +1) Operations (such as a query) +2) Fragments + +"Operations" are a generic name for requests in the document. +Examples of this include: +1) query, +2) mutation + +"Selections" are the statements that can appear legally and at +single level of the query. These include: +1) field references e.g "a" +2) fragment "spreads" e.g. "...c" +3) inline fragment "spreads" e.g. "...on Type { a }" +""" +from .executor import execute +from .base import ExecutionResult +from .middleware import middlewares, MiddlewareManager + + +__all__ = ['execute', 'ExecutionResult', 'MiddlewareManager', 'middlewares'] diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61181bdeac95b15db1caa4abc5552d020232364a Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__pycache__/base.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__pycache__/base.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..535c2c17fd78fe608238be9b0939522ddfa8e706 Binary files /dev/null and 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__pycache__/base.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__pycache__/executor.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__pycache__/executor.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eeb37f5c7f40e38381da08076f6c522331bcbca2 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__pycache__/executor.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__pycache__/middleware.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__pycache__/middleware.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c78e149b8769264231999e73f3f1098895f0d711 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__pycache__/middleware.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__pycache__/values.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__pycache__/values.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..262d3f749d68357a04dc4cfbdee96bb58d79b0f6 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/__pycache__/values.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/base.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/base.py new file mode 100644 index 
0000000000000000000000000000000000000000..7929cd5a9b151cc2e65758e62cb2f8dfabae4a7f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/base.py @@ -0,0 +1,311 @@ +# -*- coding: utf-8 -*- +from ..error import GraphQLError +from ..language import ast +from ..pyutils.default_ordered_dict import DefaultOrderedDict +from ..type.definition import GraphQLInterfaceType, GraphQLUnionType +from ..type.directives import GraphQLIncludeDirective, GraphQLSkipDirective +from ..type.introspection import (SchemaMetaFieldDef, TypeMetaFieldDef, + TypeNameMetaFieldDef) +from ..utils.type_from_ast import type_from_ast +from .values import get_argument_values, get_variable_values + +Undefined = object() + + +class ExecutionContext(object): + """Data that must be available at all points during query execution. + + Namely, schema of the type system that is currently executing, + and the fragments defined in the query document""" + + __slots__ = 'schema', 'fragments', 'root_value', 'operation', 'variable_values', 'errors', 'context_value', \ + 'argument_values_cache', 'executor', 'middleware', '_subfields_cache' + + def __init__(self, schema, document_ast, root_value, context_value, variable_values, operation_name, executor, middleware): + """Constructs a ExecutionContext object from the arguments passed + to execute, which we will pass throughout the other execution + methods.""" + errors = [] + operation = None + fragments = {} + + for definition in document_ast.definitions: + if isinstance(definition, ast.OperationDefinition): + if not operation_name and operation: + raise GraphQLError('Must provide operation name if query contains multiple operations.') + + if not operation_name or definition.name and definition.name.value == operation_name: + operation = definition + + elif isinstance(definition, ast.FragmentDefinition): + fragments[definition.name.value] = definition + + else: + raise GraphQLError( + u'GraphQL cannot execute a 
request containing a {}.'.format(definition.__class__.__name__), + definition + ) + + if not operation: + if operation_name: + raise GraphQLError(u'Unknown operation named "{}".'.format(operation_name)) + + else: + raise GraphQLError('Must provide an operation.') + + variable_values = get_variable_values(schema, operation.variable_definitions or [], variable_values) + + self.schema = schema + self.fragments = fragments + self.root_value = root_value + self.operation = operation + self.variable_values = variable_values + self.errors = errors + self.context_value = context_value + self.argument_values_cache = {} + self.executor = executor + self.middleware = middleware + self._subfields_cache = {} + + def get_field_resolver(self, field_resolver): + if not self.middleware: + return field_resolver + return self.middleware.get_field_resolver(field_resolver) + + def get_argument_values(self, field_def, field_ast): + k = field_def, field_ast + result = self.argument_values_cache.get(k) + + if not result: + result = self.argument_values_cache[k] = get_argument_values(field_def.args, field_ast.arguments, + self.variable_values) + + return result + + def get_sub_fields(self, return_type, field_asts): + k = return_type, tuple(field_asts) + if k not in self._subfields_cache: + subfield_asts = DefaultOrderedDict(list) + visited_fragment_names = set() + for field_ast in field_asts: + selection_set = field_ast.selection_set + if selection_set: + subfield_asts = collect_fields( + self, return_type, selection_set, + subfield_asts, visited_fragment_names + ) + self._subfields_cache[k] = subfield_asts + return self._subfields_cache[k] + + +class ExecutionResult(object): + """The result of execution. 
`data` is the result of executing the + query, `errors` is null if no errors occurred, and is a + non-empty array if an error occurred.""" + + __slots__ = 'data', 'errors', 'invalid' + + def __init__(self, data=None, errors=None, invalid=False): + self.data = data + self.errors = errors + + if invalid: + assert data is None + + self.invalid = invalid + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, ExecutionResult) and + self.data == other.data and + self.errors == other.errors and + self.invalid == other.invalid + ) + ) + + +def get_operation_root_type(schema, operation): + op = operation.operation + if op == 'query': + return schema.get_query_type() + + elif op == 'mutation': + mutation_type = schema.get_mutation_type() + + if not mutation_type: + raise GraphQLError( + 'Schema is not configured for mutations', + [operation] + ) + + return mutation_type + + elif op == 'subscription': + subscription_type = schema.get_subscription_type() + + if not subscription_type: + raise GraphQLError( + 'Schema is not configured for subscriptions', + [operation] + ) + + return subscription_type + + raise GraphQLError( + 'Can only execute queries, mutations and subscriptions', + [operation] + ) + + +def collect_fields(ctx, runtime_type, selection_set, fields, prev_fragment_names): + """ + Given a selectionSet, adds all of the fields in that selection to + the passed in map of fields, and returns it at the end. + + collect_fields requires the "runtime type" of an object. For a field which + returns and Interface or Union type, the "runtime type" will be the actual + Object type returned by that field. 
+ """ + for selection in selection_set.selections: + directives = selection.directives + + if isinstance(selection, ast.Field): + if not should_include_node(ctx, directives): + continue + + name = get_field_entry_key(selection) + fields[name].append(selection) + + elif isinstance(selection, ast.InlineFragment): + if not should_include_node( + ctx, directives) or not does_fragment_condition_match( + ctx, selection, runtime_type): + continue + + collect_fields(ctx, runtime_type, selection.selection_set, fields, prev_fragment_names) + + elif isinstance(selection, ast.FragmentSpread): + frag_name = selection.name.value + + if frag_name in prev_fragment_names or not should_include_node(ctx, directives): + continue + + prev_fragment_names.add(frag_name) + fragment = ctx.fragments.get(frag_name) + frag_directives = fragment.directives + if not fragment or not \ + should_include_node(ctx, frag_directives) or not \ + does_fragment_condition_match(ctx, fragment, runtime_type): + continue + + collect_fields(ctx, runtime_type, fragment.selection_set, fields, prev_fragment_names) + + return fields + + +def should_include_node(ctx, directives): + """Determines if a field should be included based on the @include and + @skip directives, where @skip has higher precidence than @include.""" + # TODO: Refactor based on latest code + if directives: + skip_ast = None + + for directive in directives: + if directive.name.value == GraphQLSkipDirective.name: + skip_ast = directive + break + + if skip_ast: + args = get_argument_values( + GraphQLSkipDirective.args, + skip_ast.arguments, + ctx.variable_values, + ) + if args.get('if') is True: + return False + + include_ast = None + + for directive in directives: + if directive.name.value == GraphQLIncludeDirective.name: + include_ast = directive + break + + if include_ast: + args = get_argument_values( + GraphQLIncludeDirective.args, + include_ast.arguments, + ctx.variable_values, + ) + + if args.get('if') is False: + return False + + return 
True + + +def does_fragment_condition_match(ctx, fragment, type_): + type_condition_ast = fragment.type_condition + if not type_condition_ast: + return True + + conditional_type = type_from_ast(ctx.schema, type_condition_ast) + if conditional_type.is_same_type(type_): + return True + + if isinstance(conditional_type, (GraphQLInterfaceType, GraphQLUnionType)): + return ctx.schema.is_possible_type(conditional_type, type_) + + return False + + +def get_field_entry_key(node): + """Implements the logic to compute the key of a given field's entry""" + if node.alias: + return node.alias.value + return node.name.value + + +class ResolveInfo(object): + __slots__ = ('field_name', 'field_asts', 'return_type', 'parent_type', + 'schema', 'fragments', 'root_value', 'operation', 'variable_values') + + def __init__(self, field_name, field_asts, return_type, parent_type, + schema, fragments, root_value, operation, variable_values): + self.field_name = field_name + self.field_asts = field_asts + self.return_type = return_type + self.parent_type = parent_type + self.schema = schema + self.fragments = fragments + self.root_value = root_value + self.operation = operation + self.variable_values = variable_values + + +def default_resolve_fn(source, args, context, info): + """If a resolve function is not given, then a default resolve behavior is used which takes the property of the source object + of the same name as the field and returns it as the result, or if it's a function, returns the result of calling that function.""" + name = info.field_name + property = getattr(source, name, None) + if callable(property): + return property() + return property + + +def get_field_def(schema, parent_type, field_name): + """This method looks up the field on the given type defintion. + It has special casing for the two introspection fields, __schema + and __typename. 
import collections
from collections.abc import Iterable
import functools
import logging
import sys

from wandb_promise import Promise, promise_for_dict, is_thenable

from ..error import GraphQLError, GraphQLLocatedError
from ..pyutils.default_ordered_dict import DefaultOrderedDict
from ..pyutils.ordereddict import OrderedDict
from ..type import (GraphQLEnumType, GraphQLInterfaceType, GraphQLList,
                    GraphQLNonNull, GraphQLObjectType, GraphQLScalarType,
                    GraphQLSchema, GraphQLUnionType)
from .base import (ExecutionContext, ExecutionResult, ResolveInfo, Undefined,
                   collect_fields, default_resolve_fn, get_field_def,
                   get_operation_root_type)
from .executors.sync import SyncExecutor
from .experimental.executor import execute as experimental_execute
from .middleware import MiddlewareManager

logger = logging.getLogger(__name__)


# When true, execute() delegates every call to the experimental executor.
use_experimental_executor = False


def execute(schema, document_ast, root_value=None, context_value=None,
            variable_values=None, operation_name=None, executor=None,
            return_promise=False, middleware=None):
    """Execute the operation selected from ``document_ast`` against ``schema``.

    Returns the promise itself when ``return_promise`` is true; otherwise
    waits for the executor to finish and returns ``promise.get()``, i.e. the
    ``ExecutionResult`` built in ``on_resolve``.  A rejection of the promise
    is appended to the context's error list and the data becomes ``None``.
    """
    if use_experimental_executor:
        return experimental_execute(
            schema, document_ast, root_value, context_value,
            variable_values, operation_name, executor,
            return_promise, middleware
        )

    # NOTE(review): these argument checks are plain asserts, so they are not
    # enforced when Python runs with -O.
    assert schema, 'Must provide schema'
    assert isinstance(schema, GraphQLSchema), (
        'Schema must be an instance of GraphQLSchema. Also ensure that there are ' +
        'not multiple versions of GraphQL installed in your node_modules directory.'
    )
    if middleware:
        # Anything that is not already a manager is unpacked into one.
        if not isinstance(middleware, MiddlewareManager):
            middleware = MiddlewareManager(*middleware)

        assert isinstance(middleware, MiddlewareManager), (
            'middlewares have to be an instance'
            ' of MiddlewareManager. Received "{}".'.format(middleware)
        )

    if executor is None:
        executor = SyncExecutor()

    context = ExecutionContext(
        schema,
        document_ast,
        root_value,
        context_value,
        variable_values,
        operation_name,
        executor,
        middleware
    )

    # This rebinds the local name ``executor``; the executor instance was
    # already handed to the context above and is reached below through
    # ``context.executor``.
    def executor(resolve, reject):
        return resolve(execute_operation(context, context.operation, root_value))

    def on_rejected(error):
        # Record the failure; the chained on_resolve then receives None as data.
        context.errors.append(error)
        return None

    def on_resolve(data):
        if not context.errors:
            return ExecutionResult(data=data)
        return ExecutionResult(data=data, errors=context.errors)

    promise = Promise(executor).catch(on_rejected).then(on_resolve)
    if return_promise:
        return promise
    context.executor.wait_until_finished()
    return promise.get()


def execute_operation(exe_context, operation, root_value):
    """Collect the operation's top-level fields and execute them.

    Mutations go through ``execute_fields_serially``; every other operation
    kind goes through ``execute_fields``.
    """
    type = get_operation_root_type(exe_context.schema, operation)
    fields = collect_fields(
        exe_context,
        type,
        operation.selection_set,
        DefaultOrderedDict(list),
        set()
    )

    if operation.operation == 'mutation':
        return execute_fields_serially(exe_context, type, root_value, fields)

    return execute_fields(exe_context, type, root_value, fields)


def execute_fields_serially(exe_context, parent_type, source_value, fields):
    """Resolve ``fields`` one after another by chaining promises.

    Returns a promise for an ordered dict mapping response names to results;
    fields whose resolution yields ``Undefined`` are left out.
    """
    def execute_field_callback(results, response_name):
        field_asts = fields[response_name]
        result = resolve_field(
            exe_context,
            parent_type,
            source_value,
            field_asts
        )
        if result is Undefined:
            return results

        if is_thenable(result):
            # Store the value once it resolves and pass the accumulated
            # mapping on to the next link of the chain.
            def collect_result(resolved_result):
                results[response_name] = resolved_result
                return results

            return result.then(collect_result, None)

        results[response_name] = result
        return results

    def execute_field(prev_promise, response_name):
        # Each field is started only from the previous field's ``then``.
        return prev_promise.then(lambda results: execute_field_callback(results, response_name))

    return functools.reduce(execute_field, fields.keys(), Promise.resolve(collections.OrderedDict()))


def execute_fields(exe_context, parent_type, source_value, fields):
    """Resolve every field and return the results keyed by response name.

    Returns the ordered dict directly when no result is thenable, otherwise
    the value of ``promise_for_dict`` over that dict.  Fields whose
    resolution yields ``Undefined`` are left out.
    """
    contains_promise = False

    final_results = OrderedDict()

    for response_name, field_asts in fields.items():
        result = resolve_field(exe_context, parent_type, source_value, field_asts)
        if result is Undefined:
            continue

        final_results[response_name] = result
        if is_thenable(result):
            contains_promise = True

    if not contains_promise:
        return final_results

    return promise_for_dict(final_results)


def resolve_field(exe_context, parent_type, source, field_asts):
    field_ast = field_asts[0]
    field_name = field_ast.name.value

    field_def = get_field_def(exe_context.schema, parent_type, field_name)
    if not field_def:
        return Undefined

    return_type = field_def.type
    resolve_fn = field_def.resolver or default_resolve_fn

    # We wrap the resolve_fn from the middleware
    resolve_fn_middleware = exe_context.get_field_resolver(resolve_fn)

    # Build a dict of arguments from the field.arguments AST, using the variables scope to
    # fulfill any variable references.
+ args = exe_context.get_argument_values(field_def, field_ast) + + # The resolve function's optional third argument is a context value that + # is provided to every resolve function within an execution. It is commonly + # used to represent an authenticated user, or request-specific caches. + context = exe_context.context_value + + # The resolve function's optional third argument is a collection of + # information about the current execution state. + info = ResolveInfo( + field_name, + field_asts, + return_type, + parent_type, + schema=exe_context.schema, + fragments=exe_context.fragments, + root_value=exe_context.root_value, + operation=exe_context.operation, + variable_values=exe_context.variable_values, + ) + + executor = exe_context.executor + result = resolve_or_error(resolve_fn_middleware, source, args, context, info, executor) + + return complete_value_catching_error( + exe_context, + return_type, + field_asts, + info, + result + ) + + +def resolve_or_error(resolve_fn, source, args, context, info, executor): + try: + return executor.execute(resolve_fn, source, args, context, info) + except Exception as e: + logger.exception("An error occurred while resolving field {}.{}".format( + info.parent_type.name, info.field_name + )) + e.stack = sys.exc_info()[2] + return e + + +def complete_value_catching_error(exe_context, return_type, field_asts, info, result): + # If the field type is non-nullable, then it is resolved without any + # protection from errors. + if isinstance(return_type, GraphQLNonNull): + return complete_value(exe_context, return_type, field_asts, info, result) + + # Otherwise, error protection is applied, logging the error and + # resolving a null value for this field if one is encountered. 
+ try: + completed = complete_value(exe_context, return_type, field_asts, info, result) + if is_thenable(completed): + def handle_error(error): + exe_context.errors.append(error) + return None + + return completed.catch(handle_error) + + return completed + except Exception as e: + exe_context.errors.append(e) + return None + + +def complete_value(exe_context, return_type, field_asts, info, result): + """ + Implements the instructions for completeValue as defined in the + "Field entries" section of the spec. + + If the field type is Non-Null, then this recursively completes the value for the inner type. It throws a field + error if that completion returns null, as per the "Nullability" section of the spec. + + If the field type is a List, then this recursively completes the value for the inner type on each item in the + list. + + If the field type is a Scalar or Enum, ensures the completed value is a legal value of the type by calling the + `serialize` method of GraphQL type definition. + + If the field is an abstract type, determine the runtime type of the value and then complete based on that type. + + Otherwise, the field type expects a sub-selection set, and will complete the value by evaluating all + sub-selections. + """ + # If field type is NonNull, complete for inner type, and throw field error if result is null. + + if is_thenable(result): + return Promise.resolve(result).then( + lambda resolved: complete_value( + exe_context, + return_type, + field_asts, + info, + resolved + ), + lambda error: Promise.rejected(GraphQLLocatedError(field_asts, original_error=error)) + ) + + # print return_type, type(result) + if isinstance(result, Exception): + raise GraphQLLocatedError(field_asts, original_error=result) + + if isinstance(return_type, GraphQLNonNull): + return complete_nonnull_value(exe_context, return_type, field_asts, info, result) + + # If result is null-like, return null. 
+ if result is None: + return None + + # If field type is List, complete each item in the list with the inner type + if isinstance(return_type, GraphQLList): + return complete_list_value(exe_context, return_type, field_asts, info, result) + + # If field type is Scalar or Enum, serialize to a valid value, returning null if coercion is not possible. + if isinstance(return_type, (GraphQLScalarType, GraphQLEnumType)): + return complete_leaf_value(return_type, result) + + if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)): + return complete_abstract_value(exe_context, return_type, field_asts, info, result) + + if isinstance(return_type, GraphQLObjectType): + return complete_object_value(exe_context, return_type, field_asts, info, result) + + assert False, u'Cannot complete value of unexpected type "{}".'.format(return_type) + + +def complete_list_value(exe_context, return_type, field_asts, info, result): + """ + Complete a list value by completing each item in the list with the inner type + """ + assert isinstance(result, Iterable), \ + ('User Error: expected iterable, but did not find one ' + + 'for field {}.{}.').format(info.parent_type, info.field_name) + + item_type = return_type.of_type + completed_results = [] + contains_promise = False + for item in result: + completed_item = complete_value_catching_error(exe_context, item_type, field_asts, info, item) + if not contains_promise and is_thenable(completed_item): + contains_promise = True + + completed_results.append(completed_item) + + return Promise.all(completed_results) if contains_promise else completed_results + + +def complete_leaf_value(return_type, result): + """ + Complete a Scalar or Enum by serializing to a valid value, returning null if serialization is not possible. 
+ """ + # serialize = getattr(return_type, 'serialize', None) + # assert serialize, 'Missing serialize method on type' + + return return_type.serialize(result) + + +def complete_abstract_value(exe_context, return_type, field_asts, info, result): + """ + Complete an value of an abstract type by determining the runtime type of that value, then completing based + on that type. + """ + runtime_type = None + + # Field type must be Object, Interface or Union and expect sub-selections. + if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)): + if return_type.resolve_type: + runtime_type = return_type.resolve_type(result, exe_context.context_value, info) + else: + runtime_type = get_default_resolve_type_fn(result, exe_context.context_value, info, return_type) + + if isinstance(runtime_type, str): + runtime_type = info.schema.get_type(runtime_type) + + if not isinstance(runtime_type, GraphQLObjectType): + raise GraphQLError( + ('Abstract type {} must resolve to an Object type at runtime ' + + 'for field {}.{} with value "{}", received "{}".').format( + return_type, + info.parent_type, + info.field_name, + result, + runtime_type, + ), + field_asts + ) + + if not exe_context.schema.is_possible_type(return_type, runtime_type): + raise GraphQLError( + u'Runtime Object type "{}" is not a possible type for "{}".'.format(runtime_type, return_type), + field_asts + ) + + return complete_object_value(exe_context, runtime_type, field_asts, info, result) + + +def get_default_resolve_type_fn(value, context, info, abstract_type): + possible_types = info.schema.get_possible_types(abstract_type) + for type in possible_types: + if callable(type.is_type_of) and type.is_type_of(value, context, info): + return type + + +def complete_object_value(exe_context, return_type, field_asts, info, result): + """ + Complete an Object value by evaluating all sub-selections. 
+ """ + if return_type.is_type_of and not return_type.is_type_of(result, exe_context.context_value, info): + raise GraphQLError( + u'Expected value of type "{}" but got: {}.'.format(return_type, type(result).__name__), + field_asts + ) + + # Collect sub-fields to execute to complete this value. + subfield_asts = exe_context.get_sub_fields(return_type, field_asts) + return execute_fields(exe_context, return_type, result, subfield_asts) + + +def complete_nonnull_value(exe_context, return_type, field_asts, info, result): + """ + Complete a NonNull value by completing the inner type + """ + completed = complete_value( + exe_context, return_type.of_type, field_asts, info, result + ) + if completed is None: + raise GraphQLError( + 'Cannot return null for non-nullable field {}.{}.'.format(info.parent_type, info.field_name), + field_asts + ) + + return completed diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/__init__.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05d9c0b644996c2493dde093d14311d2acc38784 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/__pycache__/sync.cpython-313.pyc 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/__pycache__/sync.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e58a776c6bde9e324b1d345a08318841a51ca7c Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/__pycache__/sync.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/asyncio.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/asyncio.py new file mode 100644 index 0000000000000000000000000000000000000000..eabeca1461d486cffacbc2d83bfe68315146c19d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/asyncio.py @@ -0,0 +1,53 @@ +from __future__ import absolute_import + +from asyncio import Future, get_event_loop, iscoroutine, wait + +from wandb_promise import Promise + +try: + from asyncio import ensure_future +except ImportError: + # ensure_future is only implemented in Python 3.4.4+ + def ensure_future(coro_or_future, loop=None): + """Wrap a coroutine or an awaitable in a future. + + If the argument is a Future, it is returned directly. 
+ """ + if isinstance(coro_or_future, Future): + if loop is not None and loop is not coro_or_future._loop: + raise ValueError('loop argument must agree with Future') + return coro_or_future + elif iscoroutine(coro_or_future): + if loop is None: + loop = get_event_loop() + task = loop.create_task(coro_or_future) + if task._source_traceback: + del task._source_traceback[-1] + return task + else: + raise TypeError('A Future, a coroutine or an awaitable is required') + + +class AsyncioExecutor(object): + + def __init__(self, loop=None): + if loop is None: + loop = get_event_loop() + self.loop = loop + self.futures = [] + + def wait_until_finished(self): + # if there are futures to wait for + while self.futures: + # wait for the futures to finish + futures = self.futures + self.futures = [] + self.loop.run_until_complete(wait(futures)) + + def execute(self, fn, *args, **kwargs): + result = fn(*args, **kwargs) + if isinstance(result, Future) or iscoroutine(result): + future = ensure_future(result, loop=self.loop) + self.futures.append(future) + return Promise.resolve(future) + return result diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/gevent.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/gevent.py new file mode 100644 index 0000000000000000000000000000000000000000..67395d3f7f359c14899c229ded73f0e6883c79b3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/gevent.py @@ -0,0 +1,22 @@ +from __future__ import absolute_import + +import gevent +from wandb_promise import Promise + +from .utils import process + + +class GeventExecutor(object): + + def __init__(self): + self.jobs = [] + + def wait_until_finished(self): + [j.join() for j in self.jobs] + # gevent.joinall(self.jobs) + + def execute(self, fn, *args, **kwargs): + promise = Promise() + job = gevent.spawn(process, promise, fn, 
args, kwargs) + self.jobs.append(job) + return promise diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/process.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/process.py new file mode 100644 index 0000000000000000000000000000000000000000..51301c864a5579c6a93df07da92e2ef520264c0c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/process.py @@ -0,0 +1,32 @@ +from multiprocessing import Process, Queue + +from wandb_promise import Promise + +from .utils import process + + +def queue_process(q): + promise, fn, args, kwargs = q.get() + process(promise, fn, args, kwargs) + + +class ProcessExecutor(object): + + def __init__(self): + self.processes = [] + self.q = Queue() + + def wait_until_finished(self): + for _process in self.processes: + _process.join() + self.q.close() + self.q.join_thread() + + def execute(self, fn, *args, **kwargs): + promise = Promise() + + self.q.put([promise, fn, args, kwargs], False) + _process = Process(target=queue_process, args=(self.q)) + _process.start() + self.processes.append(_process) + return promise diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/sync.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/sync.py new file mode 100644 index 0000000000000000000000000000000000000000..85f8471b6a148bf22184b5b0d9ea06bff78c0b2e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/sync.py @@ -0,0 +1,7 @@ +class SyncExecutor(object): + + def wait_until_finished(self): + pass + + def execute(self, fn, *args, **kwargs): + return fn(*args, **kwargs) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/thread.py 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/thread.py new file mode 100644 index 0000000000000000000000000000000000000000..28e8883608849e6012e936086f3f6131c0407678 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/thread.py @@ -0,0 +1,35 @@ +from multiprocessing.pool import ThreadPool +from threading import Thread + +from wandb_promise import Promise + +from .utils import process + + +class ThreadExecutor(object): + + pool = None + + def __init__(self, pool=False): + self.threads = [] + if pool: + self.execute = self.execute_in_pool + self.pool = ThreadPool(processes=pool) + else: + self.execute = self.execute_in_thread + + def wait_until_finished(self): + for thread in self.threads: + thread.join() + + def execute_in_thread(self, fn, *args, **kwargs): + promise = Promise() + thread = Thread(target=process, args=(promise, fn, args, kwargs)) + thread.start() + self.threads.append(thread) + return promise + + def execute_in_pool(self, fn, *args, **kwargs): + promise = Promise() + self.pool.map(lambda input: process(*input), [(promise, fn, args, kwargs)]) + return promise diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/utils.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4fc44875b8e42a4e08bce82c5826c2979572e389 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/executors/utils.py @@ -0,0 +1,6 @@ +def process(p, f, args, kwargs): + try: + val = f(*args, **kwargs) + p.do_resolve(val) + except Exception as e: + p.do_reject(e) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__init__.py 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..312ecc01d7983742af64f528e14adb366a671750 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__pycache__/executor.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__pycache__/executor.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e677ff805aac79b6a5acac79c87369021870b18d Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__pycache__/executor.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__pycache__/fragment.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__pycache__/fragment.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7631b98f0a6d13af8ee3a47548e16114a41895fa Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/__pycache__/fragment.cpython-313.pyc differ diff --git 
a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/executor.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/executor.py new file mode 100644 index 0000000000000000000000000000000000000000..a0d9d99c4e89e576caa51f1070f133a2bcebc832 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/executor.py @@ -0,0 +1,66 @@ +from wandb_promise import Promise + +from ...type import GraphQLSchema +from ..base import ExecutionContext, ExecutionResult, get_operation_root_type +from ..executors.sync import SyncExecutor +from ..middleware import MiddlewareManager +from .fragment import Fragment + + +def execute(schema, document_ast, root_value=None, context_value=None, + variable_values=None, operation_name=None, executor=None, + return_promise=False, middleware=None): + assert schema, 'Must provide schema' + assert isinstance(schema, GraphQLSchema), ( + 'Schema must be an instance of GraphQLSchema. Also ensure that there are ' + + 'not multiple versions of GraphQL installed in your node_modules directory.' + ) + if middleware: + if not isinstance(middleware, MiddlewareManager): + middleware = MiddlewareManager(*middleware) + assert isinstance(middleware, MiddlewareManager), ( + 'middlewares have to be an instance' + ' of MiddlewareManager. 
Received "{}".'.format(middleware) + ) + + if executor is None: + executor = SyncExecutor() + + context = ExecutionContext( + schema, + document_ast, + root_value, + context_value, + variable_values, + operation_name, + executor, + middleware + ) + + def executor(resolve, reject): + return resolve(execute_operation(context, context.operation, root_value)) + + def on_rejected(error): + context.errors.append(error) + return None + + def on_resolve(data): + if not context.errors: + return ExecutionResult(data=data) + return ExecutionResult(data=data, errors=context.errors) + + promise = Promise(executor).catch(on_rejected).then(on_resolve) + if return_promise: + return promise + context.executor.wait_until_finished() + return promise.get() + + +def execute_operation(exe_context, operation, root_value): + type = get_operation_root_type(exe_context.schema, operation) + execute_serially = operation.operation == 'mutation' + + fragment = Fragment(type=type, field_asts=[operation], context=exe_context) + if execute_serially: + return fragment.resolve_serially(root_value) + return fragment.resolve(root_value) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/fragment.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/fragment.py new file mode 100644 index 0000000000000000000000000000000000000000..ccef4b405066a3e77a923152e46b9cad4c2dcee1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/fragment.py @@ -0,0 +1,252 @@ +import functools + +from wandb_promise import Promise, is_thenable, promise_for_dict + +from ...pyutils.cached_property import cached_property +from ...pyutils.default_ordered_dict import DefaultOrderedDict +from ...type import (GraphQLInterfaceType, GraphQLList, GraphQLNonNull, + GraphQLObjectType, GraphQLUnionType) +from ..base import ResolveInfo, Undefined, 
collect_fields, get_field_def +from ..values import get_argument_values +from ...error import GraphQLError +try: + from itertools import izip as zip +except: + pass + + +def get_base_type(type): + if isinstance(type, (GraphQLList, GraphQLNonNull)): + return get_base_type(type.of_type) + return type + + +def get_subfield_asts(context, return_type, field_asts): + subfield_asts = DefaultOrderedDict(list) + visited_fragment_names = set() + for field_ast in field_asts: + selection_set = field_ast.selection_set + if selection_set: + subfield_asts = collect_fields( + context, return_type, selection_set, + subfield_asts, visited_fragment_names + ) + return subfield_asts + + +def get_resolvers(context, type, field_asts): + from .resolver import field_resolver + subfield_asts = get_subfield_asts(context, type, field_asts) + + for response_name, field_asts in subfield_asts.items(): + field_ast = field_asts[0] + field_name = field_ast.name.value + field_def = get_field_def(context and context.schema, type, field_name) + if not field_def: + continue + field_base_type = get_base_type(field_def.type) + field_fragment = None + info = ResolveInfo( + field_name, + field_asts, + field_base_type, + parent_type=type, + schema=context and context.schema, + fragments=context and context.fragments, + root_value=context and context.root_value, + operation=context and context.operation, + variable_values=context and context.variable_values, + ) + if isinstance(field_base_type, GraphQLObjectType): + field_fragment = Fragment( + type=field_base_type, + field_asts=field_asts, + info=info, + context=context + ) + elif isinstance(field_base_type, (GraphQLInterfaceType, GraphQLUnionType)): + field_fragment = AbstractFragment( + abstract_type=field_base_type, + field_asts=field_asts, + info=info, + context=context + ) + resolver = field_resolver(field_def, exe_context=context, info=info, fragment=field_fragment) + args = get_argument_values( + field_def.args, + field_ast.arguments, + context and 
context.variable_values + ) + yield (response_name, Field(resolver, args, context and context.context_value, info)) + + +class Field(object): + __slots__ = ('fn', 'args', 'context', 'info') + + def __init__(self, fn, args, context, info): + self.fn = fn + self.args = args + self.context = context + self.info = info + + def execute(self, root): + return self.fn(root, self.args, self.context, self.info) + + +class Fragment(object): + + def __init__(self, type, field_asts, context=None, info=None): + self.type = type + self.field_asts = field_asts + self.context = context + self.info = info + + @cached_property + def partial_resolvers(self): + return list(get_resolvers( + self.context, + self.type, + self.field_asts + )) + + @cached_property + def fragment_container(self): + try: + fields = next(zip(*self.partial_resolvers)) + except StopIteration: + fields = tuple() + + class FragmentInstance(dict): + # def __init__(self): + # self.fields = fields + # _fields = ('c','b','a') + set = dict.__setitem__ + # def set(self, name, value): + # self[name] = value + + def __iter__(self): + return iter(fields) + + return FragmentInstance + + def have_type(self, root): + return not self.type.is_type_of or self.type.is_type_of(root, self.context.context_value, self.info) + + def resolve(self, root): + if root and not self.have_type(root): + raise GraphQLError( + u'Expected value of type "{}" but got: {}.'.format(self.type, type(root).__name__), + self.info.field_asts + ) + + contains_promise = False + + final_results = self.fragment_container() + # return OrderedDict( + # ((field_name, field_resolver(root, field_args, context, info)) + # for field_name, field_resolver, field_args, context, info in self.partial_resolvers) + # ) + for response_name, field_resolver in self.partial_resolvers: + + result = field_resolver.execute(root) + if result is Undefined: + continue + + if not contains_promise and is_thenable(result): + contains_promise = True + + final_results[response_name] = 
result + + if not contains_promise: + return final_results + + return promise_for_dict(final_results) + # return { + # field_name: field_resolver(root, field_args, context, info) + # for field_name, field_resolver, field_args, context, info in self.partial_resolvers + # } + + def resolve_serially(self, root): + def execute_field_callback(results, resolver): + response_name, field_resolver = resolver + + result = field_resolver.execute(root) + + if result is Undefined: + return results + + if is_thenable(result): + def collect_result(resolved_result): + results[response_name] = resolved_result + return results + + return result.then(collect_result) + + results[response_name] = result + return results + + def execute_field(prev_promise, resolver): + return prev_promise.then(lambda results: execute_field_callback(results, resolver)) + + return functools.reduce(execute_field, self.partial_resolvers, Promise.resolve(self.fragment_container())) + + def __eq__(self, other): + return isinstance(other, Fragment) and ( + other.type == self.type and + other.field_asts == self.field_asts and + other.context == self.context and + other.info == self.info + ) + + +class AbstractFragment(object): + + def __init__(self, abstract_type, field_asts, context=None, info=None): + self.abstract_type = abstract_type + self.field_asts = field_asts + self.context = context + self.info = info + self._fragments = {} + + @cached_property + def possible_types(self): + return self.context.schema.get_possible_types(self.abstract_type) + + @cached_property + def possible_types_with_is_type_of(self): + return [ + (type, type.is_type_of) for type in self.possible_types if callable(type.is_type_of) + ] + + def get_fragment(self, type): + if isinstance(type, str): + type = self.context.schema.get_type(type) + + if type not in self._fragments: + assert type in self.possible_types, ( + 'Runtime Object type "{}" is not a possible type for "{}".' 
+ ).format(type, self.abstract_type) + self._fragments[type] = Fragment( + type, + self.field_asts, + self.context, + self.info + ) + + return self._fragments[type] + + def resolve_type(self, result): + return_type = self.abstract_type + context = self.context.context_value + + if return_type.resolve_type: + return return_type.resolve_type(result, context, self.info) + + for type, is_type_of in self.possible_types_with_is_type_of: + if is_type_of(result, context, self.info): + return type + + def resolve(self, root): + _type = self.resolve_type(root) + fragment = self.get_fragment(_type) + return fragment.resolve(root) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/resolver.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/resolver.py new file mode 100644 index 0000000000000000000000000000000000000000..4d39cacd1dc810ec6a4de2fa709077c248760983 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/resolver.py @@ -0,0 +1,151 @@ +import sys +from collections.abc import Iterable +from functools import partial + +from wandb_promise import Promise, is_thenable + +from ...error import GraphQLError, GraphQLLocatedError +from ...type import (GraphQLEnumType, GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, GraphQLScalarType, + GraphQLUnionType) +from ..base import default_resolve_fn +from ...execution import executor +from .utils import imap, normal_map + + +def on_complete_resolver(on_error, __func, exe_context, info, __resolver, *args, **kwargs): + try: + result = __resolver(*args, **kwargs) + if isinstance(result, Exception): + return on_error(result) + # return Promise.resolve(result).then(__func).catch(on_error) + if is_thenable(result): + # TODO: Remove this, if a promise is resolved with an Exception, + # it should raise by default. 
This is fixing an old behavior + # in the Promise package + def on_resolve(value): + if isinstance(value, Exception): + return on_error(value) + return value + return result.then(on_resolve).then(__func).catch(on_error) + return __func(result) + except Exception as e: + return on_error(e) + + +def complete_list_value(inner_resolver, exe_context, info, on_error, result): + if result is None: + return None + + assert isinstance(result, Iterable), \ + ('User Error: expected iterable, but did not find one ' + + 'for field {}.{}.').format(info.parent_type, info.field_name) + + completed_results = normal_map(inner_resolver, result) + + if not any(imap(is_thenable, completed_results)): + return completed_results + + return Promise.all(completed_results).catch(on_error) + + +def complete_nonnull_value(exe_context, info, result): + if result is None: + raise GraphQLError( + 'Cannot return null for non-nullable field {}.{}.'.format(info.parent_type, info.field_name), + info.field_asts + ) + return result + + +def complete_leaf_value(serialize, result): + if result is None: + return None + return serialize(result) + + +def complete_object_value(fragment_resolve, exe_context, on_error, result): + if result is None: + return None + + result = fragment_resolve(result) + if is_thenable(result): + return result.catch(on_error) + return result + + +def field_resolver(field, fragment=None, exe_context=None, info=None): + # resolver = exe_context.get_field_resolver(field.resolver or default_resolve_fn) + resolver = field.resolver or default_resolve_fn + if exe_context: + # We decorate the resolver with the middleware + resolver = exe_context.get_field_resolver(resolver) + return type_resolver(field.type, resolver, + fragment, exe_context, info, catch_error=True) + + +def type_resolver(return_type, resolver, fragment=None, exe_context=None, info=None, catch_error=False): + if isinstance(return_type, GraphQLNonNull): + return type_resolver_non_null(return_type, resolver, fragment, 
exe_context, info) + + if isinstance(return_type, (GraphQLScalarType, GraphQLEnumType)): + return type_resolver_leaf(return_type, resolver, exe_context, info, catch_error) + + if isinstance(return_type, (GraphQLList)): + return type_resolver_list(return_type, resolver, fragment, exe_context, info, catch_error) + + if isinstance(return_type, (GraphQLObjectType)): + assert fragment and fragment.type == return_type, 'Fragment and return_type dont match' + return type_resolver_fragment(return_type, resolver, fragment, exe_context, info, catch_error) + + if isinstance(return_type, (GraphQLInterfaceType, GraphQLUnionType)): + assert fragment, 'You need to pass a fragment to resolve a Interface or Union' + return type_resolver_fragment(return_type, resolver, fragment, exe_context, info, catch_error) + + raise Exception("The resolver have to be created for a fragment") + + +def on_error(exe_context, info, catch_error, e): + error = e + if not isinstance(e, (GraphQLLocatedError, GraphQLError)): + error = GraphQLLocatedError(info.field_asts, original_error=e) + if catch_error: + exe_context.errors.append(error) + executor.logger.exception("An error occurred while resolving field {}.{}".format( + info.parent_type.name, info.field_name + )) + error.stack = sys.exc_info()[2] + return None + raise error + + +def type_resolver_fragment(return_type, resolver, fragment, exe_context, info, catch_error): + on_complete_type_error = partial(on_error, exe_context, info, catch_error) + complete_object_value_resolve = partial( + complete_object_value, + fragment.resolve, + exe_context, + on_complete_type_error) + on_resolve_error = partial(on_error, exe_context, info, catch_error) + return partial(on_complete_resolver, on_resolve_error, complete_object_value_resolve, exe_context, info, resolver) + + +def type_resolver_non_null(return_type, resolver, fragment, exe_context, info): # no catch_error + resolver = type_resolver(return_type.of_type, resolver, fragment, exe_context, info) + 
nonnull_complete = partial(complete_nonnull_value, exe_context, info) + on_resolve_error = partial(on_error, exe_context, info, False) + return partial(on_complete_resolver, on_resolve_error, nonnull_complete, exe_context, info, resolver) + + +def type_resolver_leaf(return_type, resolver, exe_context, info, catch_error): + leaf_complete = partial(complete_leaf_value, return_type.serialize) + on_resolve_error = partial(on_error, exe_context, info, catch_error) + return partial(on_complete_resolver, on_resolve_error, leaf_complete, exe_context, info, resolver) + + +def type_resolver_list(return_type, resolver, fragment, exe_context, info, catch_error): + item_type = return_type.of_type + inner_resolver = type_resolver(item_type, lambda item: item, fragment, exe_context, info, catch_error=True) + on_resolve_error = partial(on_error, exe_context, info, catch_error) + list_complete = partial(complete_list_value, inner_resolver, exe_context, info, on_resolve_error) + return partial(on_complete_resolver, on_resolve_error, list_complete, exe_context, info, resolver) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/utils.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8aef421d9b89f3d7853e68603768af4e1bfd5271 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/experimental/utils.py @@ -0,0 +1,7 @@ +try: + from itertools import imap + normal_map = map +except: + def normal_map(func, iter): + return list(map(func, iter)) + imap = map diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/middleware.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/middleware.py new file mode 100644 index 
0000000000000000000000000000000000000000..7a18f2988857c62e23b66fb17e53cadfa70a6460 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/middleware.py @@ -0,0 +1,57 @@ +import inspect +from functools import partial +from itertools import chain + +from wandb_promise import Promise + +MIDDLEWARE_RESOLVER_FUNCTION = 'resolve' + + +class MiddlewareManager(object): + + def __init__(self, *middlewares, **kwargs): + self.middlewares = middlewares + self.wrap_in_promise = kwargs.get('wrap_in_promise', True) + self._middleware_resolvers = list(get_middleware_resolvers(middlewares)) + self._cached_resolvers = {} + + def get_field_resolver(self, field_resolver): + if field_resolver not in self._cached_resolvers: + self._cached_resolvers[field_resolver] = middleware_chain( + field_resolver, + self._middleware_resolvers, + wrap_in_promise=self.wrap_in_promise, + ) + + return self._cached_resolvers[field_resolver] + + +middlewares = MiddlewareManager + + +def get_middleware_resolvers(middlewares): + for middleware in middlewares: + # If the middleware is a function instead of a class + if inspect.isfunction(middleware): + yield middleware + if not hasattr(middleware, MIDDLEWARE_RESOLVER_FUNCTION): + continue + yield getattr(middleware, MIDDLEWARE_RESOLVER_FUNCTION) + + +def middleware_chain(func, middlewares, wrap_in_promise): + if not middlewares: + return func + if wrap_in_promise: + middlewares = chain((func, make_it_promise), middlewares) + else: + middlewares = chain((func,), middlewares) + last_func = None + for middleware in middlewares: + last_func = partial(middleware, last_func) if last_func else middleware + + return last_func + + +def make_it_promise(next, *a, **b): + return Promise.resolve(next(*a, **b)) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/values.py 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/values.py new file mode 100644 index 0000000000000000000000000000000000000000..473ada13c2d8f9ab9029cad419cd0f195380969b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/execution/values.py @@ -0,0 +1,145 @@ +from collections.abc import Iterable +import json + +from ..error import GraphQLError +from ..language.printer import print_ast +from ..type import (GraphQLEnumType, GraphQLInputObjectType, GraphQLList, + GraphQLNonNull, GraphQLScalarType, is_input_type) +from ..utils.is_valid_value import is_valid_value +from ..utils.type_from_ast import type_from_ast +from ..utils.value_from_ast import value_from_ast + +__all__ = ['get_variable_values', 'get_argument_values'] + + +def get_variable_values(schema, definition_asts, inputs): + """Prepares an object map of variables of the correct type based on the provided variable definitions and arbitrary input. 
+ If the input cannot be parsed to match the variable definitions, a GraphQLError will be thrown.""" + if inputs is None: + inputs = {} + + values = {} + for def_ast in definition_asts: + var_name = def_ast.variable.name.value + value = get_variable_value(schema, def_ast, inputs.get(var_name)) + values[var_name] = value + + return values + + +def get_argument_values(arg_defs, arg_asts, variables=None): + """Prepares an object map of argument values given a list of argument + definitions and list of argument AST nodes.""" + if not arg_defs: + return {} + + if arg_asts: + arg_ast_map = {arg.name.value: arg for arg in arg_asts} + else: + arg_ast_map = {} + + result = {} + for name, arg_def in arg_defs.items(): + value_ast = arg_ast_map.get(name) + if value_ast: + value_ast = value_ast.value + + value = value_from_ast( + value_ast, + arg_def.type, + variables + ) + + if value is None: + value = arg_def.default_value + + if value is not None: + # We use out_name as the output name for the + # dict if exists + result[arg_def.out_name or name] = value + + return result + + +def get_variable_value(schema, definition_ast, input): + """Given a variable definition, and any value of input, return a value which adheres to the variable definition, + or throw an error.""" + type = type_from_ast(schema, definition_ast.type) + variable = definition_ast.variable + + if not type or not is_input_type(type): + raise GraphQLError( + 'Variable "${}" expected value of type "{}" which cannot be used as an input type.'.format( + variable.name.value, + print_ast(definition_ast.type), + ), + [definition_ast] + ) + + input_type = type + errors = is_valid_value(input, input_type) + if not errors: + if input is None: + default_value = definition_ast.default_value + if default_value: + return value_from_ast(default_value, input_type) + + return coerce_value(input_type, input) + + if input is None: + raise GraphQLError( + 'Variable "${}" of required type "{}" was not provided.'.format( + 
variable.name.value, + print_ast(definition_ast.type) + ), + [definition_ast] + ) + + message = (u'\n' + u'\n'.join(errors)) if errors else u'' + raise GraphQLError( + 'Variable "${}" got invalid value {}.{}'.format( + variable.name.value, + json.dumps(input, sort_keys=True), + message + ), + [definition_ast] + ) + + +def coerce_value(type, value): + """Given a type and any value, return a runtime value coerced to match the type.""" + if isinstance(type, GraphQLNonNull): + # Note: we're not checking that the result of coerceValue is + # non-null. + # We only call this function after calling isValidValue. + return coerce_value(type.of_type, value) + + if value is None: + return None + + if isinstance(type, GraphQLList): + item_type = type.of_type + if not isinstance(value, str) and isinstance(value, Iterable): + return [coerce_value(item_type, item) for item in value] + else: + return [coerce_value(item_type, value)] + + if isinstance(type, GraphQLInputObjectType): + fields = type.fields + obj = {} + for field_name, field in fields.items(): + field_value = coerce_value(field.type, value.get(field_name)) + if field_value is None: + field_value = field.default_value + + if field_value is not None: + # We use out_name as the output name for the + # dict if exists + obj[field.out_name or field_name] = field_value + + return obj + + assert isinstance(type, (GraphQLScalarType, GraphQLEnumType)), \ + 'Must be input type' + + return type.parse_value(value) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/graphql.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/graphql.py new file mode 100644 index 0000000000000000000000000000000000000000..e8d311c94adec29f221ae2af13efa4a0e84816e2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/graphql.py @@ -0,0 +1,60 @@ +from .execution import ExecutionResult, execute +from .language.ast import Document +from 
.language.parser import parse +from .language.source import Source +from .validation import validate + + +# This is the primary entry point function for fulfilling GraphQL operations +# by parsing, validating, and executing a GraphQL document along side a +# GraphQL schema. + +# More sophisticated GraphQL servers, such as those which persist queries, +# may wish to separate the validation and execution phases to a static time +# tooling step, and a server runtime step. + +# schema: +# The GraphQL type system to use when validating and executing a query. +# requestString: +# A GraphQL language formatted string representing the requested operation. +# rootValue: +# The value provided as the first argument to resolver functions on the top +# level type (e.g. the query object type). +# variableValues: +# A mapping of variable name to runtime value to use for all variables +# defined in the requestString. +# operationName: +# The name of the operation to use if requestString contains multiple +# possible operations. Can be omitted if requestString contains only +# one operation. 
+def graphql(schema, request_string='', root_value=None, context_value=None, + variable_values=None, operation_name=None, executor=None, + return_promise=False, middleware=None): + try: + if isinstance(request_string, Document): + ast = request_string + else: + source = Source(request_string, 'GraphQL request') + ast = parse(source) + validation_errors = validate(schema, ast) + if validation_errors: + return ExecutionResult( + errors=validation_errors, + invalid=True, + ) + return execute( + schema, + ast, + root_value, + context_value, + operation_name=operation_name, + variable_values=variable_values or {}, + executor=executor, + return_promise=return_promise, + middleware=middleware, + ) + except Exception as e: + return ExecutionResult( + errors=[e], + invalid=True, + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__init__.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b79e4a9aa019934ad14009e5895b0746a3c4dd9 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/ast.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/ast.cpython-313.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..857ce4f2ca5ae79250d88fe5aad83ca3f35c5fd2 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/ast.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/base.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/base.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f75c955806d9f0e37b63c4c64cf4fcb18fab5993 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/base.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/lexer.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/lexer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54af600e9260506284f5fde8664e47275e8d2b3f Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/lexer.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/location.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/location.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9474009eea49289a53048d6fab4904f4d042985e Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/location.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/parser.cpython-313.pyc 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/parser.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64de06b7d697d7b20f8ca76a8a8e101e63b8122c Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/parser.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/printer.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/printer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0868086cb3d821ef234e70d0f97dad62a873ab95 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/printer.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/source.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/source.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d43e5fe54d6aeb557490f38034089cb9eb0df9e Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/source.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/visitor.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/visitor.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a476e1d71b54e67a5c1a4bd5bf54b4cb6441cf7a Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/visitor.cpython-313.pyc differ diff 
--git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/visitor_meta.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/visitor_meta.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72e15f88dca80225fa69c461fd183a28e4145fb7 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/__pycache__/visitor_meta.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/ast.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/ast.py new file mode 100644 index 0000000000000000000000000000000000000000..6fffae84b8835e5d6a71e0b15534de730b607c91 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/ast.py @@ -0,0 +1,1349 @@ +# This is autogenerated code. DO NOT change this manually. +# Run scripts/generate_ast.py to generate this file. 
+ + +class Node(object): + __slots__ = () + + +class Definition(Node): + __slots__ = () + + +class Document(Node): + __slots__ = ('loc', 'definitions',) + _fields = ('definitions',) + + def __init__(self, definitions, loc=None): + self.loc = loc + self.definitions = definitions + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, Document) and + # self.loc == other.loc and + self.definitions == other.definitions + ) + ) + + def __repr__(self): + return ('Document(' + 'definitions={self.definitions!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.definitions, + self.loc + ) + + def __hash__(self): + return id(self) + + +class OperationDefinition(Definition): + __slots__ = ('loc', 'operation', 'name', 'variable_definitions', 'directives', 'selection_set',) + _fields = ('operation', 'name', 'variable_definitions', 'directives', 'selection_set',) + + def __init__(self, operation, selection_set, name=None, variable_definitions=None, directives=None, loc=None): + self.loc = loc + self.operation = operation + self.name = name + self.variable_definitions = variable_definitions + self.directives = directives + self.selection_set = selection_set + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, OperationDefinition) and + # self.loc == other.loc and + self.operation == other.operation and + self.name == other.name and + self.variable_definitions == other.variable_definitions and + self.directives == other.directives and + self.selection_set == other.selection_set + ) + ) + + def __repr__(self): + return ('OperationDefinition(' + 'operation={self.operation!r}' + ', name={self.name!r}' + ', variable_definitions={self.variable_definitions!r}' + ', directives={self.directives!r}' + ', selection_set={self.selection_set!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.operation, + self.selection_set, + self.name, + self.variable_definitions, + self.directives, 
+ self.loc + ) + + def __hash__(self): + return id(self) + + +class VariableDefinition(Node): + __slots__ = ('loc', 'variable', 'type', 'default_value',) + _fields = ('variable', 'type', 'default_value',) + + def __init__(self, variable, type, default_value=None, loc=None): + self.loc = loc + self.variable = variable + self.type = type + self.default_value = default_value + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, VariableDefinition) and + # self.loc == other.loc and + self.variable == other.variable and + self.type == other.type and + self.default_value == other.default_value + ) + ) + + def __repr__(self): + return ('VariableDefinition(' + 'variable={self.variable!r}' + ', type={self.type!r}' + ', default_value={self.default_value!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.variable, + self.type, + self.default_value, + self.loc + ) + + def __hash__(self): + return id(self) + + +class SelectionSet(Node): + __slots__ = ('loc', 'selections',) + _fields = ('selections',) + + def __init__(self, selections, loc=None): + self.loc = loc + self.selections = selections + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, SelectionSet) and + # self.loc == other.loc and + self.selections == other.selections + ) + ) + + def __repr__(self): + return ('SelectionSet(' + 'selections={self.selections!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.selections, + self.loc + ) + + def __hash__(self): + return id(self) + + +class Selection(Node): + __slots__ = () + + +class Field(Selection): + __slots__ = ('loc', 'alias', 'name', 'arguments', 'directives', 'selection_set',) + _fields = ('alias', 'name', 'arguments', 'directives', 'selection_set',) + + def __init__(self, name, alias=None, arguments=None, directives=None, selection_set=None, loc=None): + self.loc = loc + self.alias = alias + self.name = name + self.arguments = arguments + 
self.directives = directives + self.selection_set = selection_set + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, Field) and + # self.loc == other.loc and + self.alias == other.alias and + self.name == other.name and + self.arguments == other.arguments and + self.directives == other.directives and + self.selection_set == other.selection_set + ) + ) + + def __repr__(self): + return ('Field(' + 'alias={self.alias!r}' + ', name={self.name!r}' + ', arguments={self.arguments!r}' + ', directives={self.directives!r}' + ', selection_set={self.selection_set!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.name, + self.alias, + self.arguments, + self.directives, + self.selection_set, + self.loc + ) + + def __hash__(self): + return id(self) + + +class Argument(Node): + __slots__ = ('loc', 'name', 'value',) + _fields = ('name', 'value',) + + def __init__(self, name, value, loc=None): + self.loc = loc + self.name = name + self.value = value + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, Argument) and + # self.loc == other.loc and + self.name == other.name and + self.value == other.value + ) + ) + + def __repr__(self): + return ('Argument(' + 'name={self.name!r}' + ', value={self.value!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.name, + self.value, + self.loc + ) + + def __hash__(self): + return id(self) + + +class FragmentSpread(Selection): + __slots__ = ('loc', 'name', 'directives',) + _fields = ('name', 'directives',) + + def __init__(self, name, directives=None, loc=None): + self.loc = loc + self.name = name + self.directives = directives + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, FragmentSpread) and + # self.loc == other.loc and + self.name == other.name and + self.directives == other.directives + ) + ) + + def __repr__(self): + return ('FragmentSpread(' + 'name={self.name!r}' + ', 
directives={self.directives!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.name, + self.directives, + self.loc + ) + + def __hash__(self): + return id(self) + + +class InlineFragment(Selection): + __slots__ = ('loc', 'type_condition', 'directives', 'selection_set',) + _fields = ('type_condition', 'directives', 'selection_set',) + + def __init__(self, type_condition, selection_set, directives=None, loc=None): + self.loc = loc + self.type_condition = type_condition + self.directives = directives + self.selection_set = selection_set + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, InlineFragment) and + # self.loc == other.loc and + self.type_condition == other.type_condition and + self.directives == other.directives and + self.selection_set == other.selection_set + ) + ) + + def __repr__(self): + return ('InlineFragment(' + 'type_condition={self.type_condition!r}' + ', directives={self.directives!r}' + ', selection_set={self.selection_set!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.type_condition, + self.selection_set, + self.directives, + self.loc + ) + + def __hash__(self): + return id(self) + + +class FragmentDefinition(Definition): + __slots__ = ('loc', 'name', 'type_condition', 'directives', 'selection_set',) + _fields = ('name', 'type_condition', 'directives', 'selection_set',) + + def __init__(self, name, type_condition, selection_set, directives=None, loc=None): + self.loc = loc + self.name = name + self.type_condition = type_condition + self.directives = directives + self.selection_set = selection_set + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, FragmentDefinition) and + # self.loc == other.loc and + self.name == other.name and + self.type_condition == other.type_condition and + self.directives == other.directives and + self.selection_set == other.selection_set + ) + ) + + def __repr__(self): + return 
('FragmentDefinition(' + 'name={self.name!r}' + ', type_condition={self.type_condition!r}' + ', directives={self.directives!r}' + ', selection_set={self.selection_set!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.name, + self.type_condition, + self.selection_set, + self.directives, + self.loc + ) + + def __hash__(self): + return id(self) + + +class Value(Node): + __slots__ = () + + +class Variable(Value): + __slots__ = ('loc', 'name',) + _fields = ('name',) + + def __init__(self, name, loc=None): + self.loc = loc + self.name = name + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, Variable) and + # self.loc == other.loc and + self.name == other.name + ) + ) + + def __repr__(self): + return ('Variable(' + 'name={self.name!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.name, + self.loc + ) + + def __hash__(self): + return id(self) + + +class IntValue(Value): + __slots__ = ('loc', 'value',) + _fields = ('value',) + + def __init__(self, value, loc=None): + self.loc = loc + self.value = value + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, IntValue) and + # self.loc == other.loc and + self.value == other.value + ) + ) + + def __repr__(self): + return ('IntValue(' + 'value={self.value!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.value, + self.loc + ) + + def __hash__(self): + return id(self) + + +class FloatValue(Value): + __slots__ = ('loc', 'value',) + _fields = ('value',) + + def __init__(self, value, loc=None): + self.loc = loc + self.value = value + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, FloatValue) and + # self.loc == other.loc and + self.value == other.value + ) + ) + + def __repr__(self): + return ('FloatValue(' + 'value={self.value!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.value, + self.loc + ) + + def __hash__(self): 
+ return id(self) + + +class StringValue(Value): + __slots__ = ('loc', 'value',) + _fields = ('value',) + + def __init__(self, value, loc=None): + self.loc = loc + self.value = value + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, StringValue) and + # self.loc == other.loc and + self.value == other.value + ) + ) + + def __repr__(self): + return ('StringValue(' + 'value={self.value!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.value, + self.loc + ) + + def __hash__(self): + return id(self) + + +class BooleanValue(Value): + __slots__ = ('loc', 'value',) + _fields = ('value',) + + def __init__(self, value, loc=None): + self.loc = loc + self.value = value + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, BooleanValue) and + # self.loc == other.loc and + self.value == other.value + ) + ) + + def __repr__(self): + return ('BooleanValue(' + 'value={self.value!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.value, + self.loc + ) + + def __hash__(self): + return id(self) + + +class EnumValue(Value): + __slots__ = ('loc', 'value',) + _fields = ('value',) + + def __init__(self, value, loc=None): + self.loc = loc + self.value = value + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, EnumValue) and + # self.loc == other.loc and + self.value == other.value + ) + ) + + def __repr__(self): + return ('EnumValue(' + 'value={self.value!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.value, + self.loc + ) + + def __hash__(self): + return id(self) + + +class ListValue(Value): + __slots__ = ('loc', 'values',) + _fields = ('values',) + + def __init__(self, values, loc=None): + self.loc = loc + self.values = values + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, ListValue) and + # self.loc == other.loc and + self.values == other.values + ) + ) + + def 
__repr__(self): + return ('ListValue(' + 'values={self.values!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.values, + self.loc + ) + + def __hash__(self): + return id(self) + + +class ObjectValue(Value): + __slots__ = ('loc', 'fields',) + _fields = ('fields',) + + def __init__(self, fields, loc=None): + self.loc = loc + self.fields = fields + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, ObjectValue) and + # self.loc == other.loc and + self.fields == other.fields + ) + ) + + def __repr__(self): + return ('ObjectValue(' + 'fields={self.fields!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.fields, + self.loc + ) + + def __hash__(self): + return id(self) + + +class ObjectField(Node): + __slots__ = ('loc', 'name', 'value',) + _fields = ('name', 'value',) + + def __init__(self, name, value, loc=None): + self.loc = loc + self.name = name + self.value = value + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, ObjectField) and + # self.loc == other.loc and + self.name == other.name and + self.value == other.value + ) + ) + + def __repr__(self): + return ('ObjectField(' + 'name={self.name!r}' + ', value={self.value!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.name, + self.value, + self.loc + ) + + def __hash__(self): + return id(self) + + +class Directive(Node): + __slots__ = ('loc', 'name', 'arguments',) + _fields = ('name', 'arguments',) + + def __init__(self, name, arguments=None, loc=None): + self.loc = loc + self.name = name + self.arguments = arguments + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, Directive) and + # self.loc == other.loc and + self.name == other.name and + self.arguments == other.arguments + ) + ) + + def __repr__(self): + return ('Directive(' + 'name={self.name!r}' + ', arguments={self.arguments!r}' + ')').format(self=self) + + def __copy__(self): + 
return type(self)( + self.name, + self.arguments, + self.loc + ) + + def __hash__(self): + return id(self) + + +class Type(Node): + __slots__ = () + + +class NamedType(Type): + __slots__ = ('loc', 'name',) + _fields = ('name',) + + def __init__(self, name, loc=None): + self.loc = loc + self.name = name + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, NamedType) and + # self.loc == other.loc and + self.name == other.name + ) + ) + + def __repr__(self): + return ('NamedType(' + 'name={self.name!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.name, + self.loc + ) + + def __hash__(self): + return id(self) + + +class ListType(Type): + __slots__ = ('loc', 'type',) + _fields = ('type',) + + def __init__(self, type, loc=None): + self.loc = loc + self.type = type + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, ListType) and + # self.loc == other.loc and + self.type == other.type + ) + ) + + def __repr__(self): + return ('ListType(' + 'type={self.type!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.type, + self.loc + ) + + def __hash__(self): + return id(self) + + +class NonNullType(Type): + __slots__ = ('loc', 'type',) + _fields = ('type',) + + def __init__(self, type, loc=None): + self.loc = loc + self.type = type + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, NonNullType) and + # self.loc == other.loc and + self.type == other.type + ) + ) + + def __repr__(self): + return ('NonNullType(' + 'type={self.type!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.type, + self.loc + ) + + def __hash__(self): + return id(self) + + +class Name(Node): + __slots__ = ('loc', 'value',) + _fields = ('value',) + + def __init__(self, value, loc=None): + self.loc = loc + self.value = value + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, Name) and + # self.loc == 
other.loc and + self.value == other.value + ) + ) + + def __repr__(self): + return ('Name(' + 'value={self.value!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.value, + self.loc + ) + + def __hash__(self): + return id(self) + + +# Type System Definition + +class TypeDefinition(Node): + pass + + +class TypeSystemDefinition(TypeDefinition): + pass + + +class SchemaDefinition(TypeSystemDefinition): + __slots__ = ('loc', 'directives', 'operation_types',) + _fields = ('operation_types',) + + def __init__(self, operation_types, loc=None, directives=None): + self.operation_types = operation_types + self.loc = loc + self.directives = directives + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, SchemaDefinition) and + self.operation_types == other.operation_types and + self.directives == other.directives + ) + ) + + def __repr__(self): + return ('SchemaDefinition(' + 'operation_types={self.operation_types!r}' + ', directives={self.directives!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.operation_types, + self.loc, + self.directives, + ) + + def __hash__(self): + return id(self) + + +class OperationTypeDefinition(Node): + __slots__ = ('loc', 'operation', 'type',) + _fields = ('operation', 'type',) + + def __init__(self, operation, type, loc=None): + self.operation = operation + self.type = type + self.loc = loc + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, OperationTypeDefinition) and + self.operation == other.operation and + self.type == other.type + ) + ) + + def __repr__(self): + return ('OperationTypeDefinition(' + 'operation={self.operation!r}' + ', type={self.type!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.operation, + self.type, + self.loc + ) + + def __hash__(self): + return id(self) + + +class ObjectTypeDefinition(TypeDefinition): + __slots__ = ('loc', 'name', 'interfaces', 'directives', 
'fields',) + _fields = ('name', 'interfaces', 'fields',) + + def __init__(self, name, fields, interfaces=None, loc=None, directives=None): + self.loc = loc + self.name = name + self.interfaces = interfaces + self.fields = fields + self.directives = directives + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, ObjectTypeDefinition) and + # self.loc == other.loc and + self.name == other.name and + self.interfaces == other.interfaces and + self.fields == other.fields and + self.directives == other.directives + ) + ) + + def __repr__(self): + return ('ObjectTypeDefinition(' + 'name={self.name!r}' + ', interfaces={self.interfaces!r}' + ', fields={self.fields!r}' + ', directives={self.directives!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.name, + self.fields, + self.interfaces, + self.loc, + self.directives, + ) + + def __hash__(self): + return id(self) + + +class FieldDefinition(Node): + __slots__ = ('loc', 'name', 'arguments', 'type', 'directives',) + _fields = ('name', 'arguments', 'type',) + + def __init__(self, name, arguments, type, loc=None, directives=None): + self.loc = loc + self.name = name + self.arguments = arguments + self.type = type + self.directives = directives + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, FieldDefinition) and + # self.loc == other.loc and + self.name == other.name and + self.arguments == other.arguments and + self.type == other.type and + self.directives == other.directives + ) + ) + + def __repr__(self): + return ('FieldDefinition(' + 'name={self.name!r}' + ', arguments={self.arguments!r}' + ', type={self.type!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.name, + self.arguments, + self.type, + self.loc, + self.directives, + ) + + def __hash__(self): + return id(self) + + +class InputValueDefinition(Node): + __slots__ = ('loc', 'name', 'type', 'default_value', 'directives') + _fields = ('name', 
'type', 'default_value',) + + def __init__(self, name, type, default_value=None, loc=None, + directives=None): + self.loc = loc + self.name = name + self.type = type + self.default_value = default_value + self.directives = directives + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, InputValueDefinition) and + # self.loc == other.loc and + self.name == other.name and + self.type == other.type and + self.default_value == other.default_value and + self.directives == other.directives + ) + ) + + def __repr__(self): + return ('InputValueDefinition(' + 'name={self.name!r}' + ', type={self.type!r}' + ', default_value={self.default_value!r}' + ', directives={self.directives!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.name, + self.type, + self.default_value, + self.loc, + self.directives, + ) + + def __hash__(self): + return id(self) + + +class InterfaceTypeDefinition(TypeDefinition): + __slots__ = ('loc', 'name', 'fields', 'directives',) + _fields = ('name', 'fields',) + + def __init__(self, name, fields, loc=None, directives=None): + self.loc = loc + self.name = name + self.fields = fields + self.directives = directives + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, InterfaceTypeDefinition) and + # self.loc == other.loc and + self.name == other.name and + self.fields == other.fields and + self.directives == other.directives + ) + ) + + def __repr__(self): + return ('InterfaceTypeDefinition(' + 'name={self.name!r}' + ', fields={self.fields!r}' + ', directives={self.directives!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.name, + self.fields, + self.loc, + self.directives, + ) + + def __hash__(self): + return id(self) + + +class UnionTypeDefinition(TypeDefinition): + __slots__ = ('loc', 'name', 'types', 'directives',) + _fields = ('name', 'types',) + + def __init__(self, name, types, loc=None, directives=None): + self.loc = loc + 
self.name = name + self.types = types + self.directives = directives + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, UnionTypeDefinition) and + # self.loc == other.loc and + self.name == other.name and + self.types == other.types and + self.directives == other.directives + ) + ) + + def __repr__(self): + return ('UnionTypeDefinition(' + 'name={self.name!r}' + ', types={self.types!r}' + ', directives={self.directives!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.name, + self.types, + self.loc, + self.directives, + ) + + def __hash__(self): + return id(self) + + +class ScalarTypeDefinition(TypeDefinition): + __slots__ = ('loc', 'name', 'directives',) + _fields = ('name',) + + def __init__(self, name, loc=None, directives=None): + self.loc = loc + self.name = name + self.directives = directives + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, ScalarTypeDefinition) and + # self.loc == other.loc and + self.name == other.name and + self.directives == other.directives + ) + ) + + def __repr__(self): + return ('ScalarTypeDefinition(' + 'name={self.name!r}' + 'directives={self.directives!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.name, + self.loc, + self.directives + ) + + def __hash__(self): + return id(self) + + +class EnumTypeDefinition(TypeDefinition): + __slots__ = ('loc', 'name', 'values', 'directives',) + _fields = ('name', 'values',) + + def __init__(self, name, values, loc=None, directives=None): + self.loc = loc + self.name = name + self.values = values + self.directives = directives + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, EnumTypeDefinition) and + # self.loc == other.loc and + self.name == other.name and + self.values == other.values and + self.directives == other.directives + ) + ) + + def __repr__(self): + return ('EnumTypeDefinition(' + 'name={self.name!r}' + ', 
values={self.values!r}' + ', directives={self.directives!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.name, + self.values, + self.loc, + self.directives, + ) + + def __hash__(self): + return id(self) + + +class EnumValueDefinition(Node): + __slots__ = ('loc', 'name', 'directives',) + _fields = ('name',) + + def __init__(self, name, loc=None, directives=None): + self.loc = loc + self.name = name + self.directives = directives + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, EnumValueDefinition) and + # self.loc == other.loc and + self.name == other.name and + self.directives == other.directives + ) + ) + + def __repr__(self): + return ('EnumValueDefinition(' + 'name={self.name!r}' + ', directives={self.directives!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.name, + self.loc, + self.directives, + ) + + def __hash__(self): + return id(self) + + +class InputObjectTypeDefinition(TypeDefinition): + __slots__ = ('loc', 'name', 'fields', 'directives',) + _fields = ('name', 'fields',) + + def __init__(self, name, fields, loc=None, directives=None): + self.loc = loc + self.name = name + self.fields = fields + self.directives = directives + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, InputObjectTypeDefinition) and + # self.loc == other.loc and + self.name == other.name and + self.fields == other.fields and + self.directives == other.directives + ) + ) + + def __repr__(self): + return ('InputObjectTypeDefinition(' + 'name={self.name!r}' + ', fields={self.fields!r}' + ', directives={self.directives!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.name, + self.fields, + self.loc, + self.directives, + ) + + def __hash__(self): + return id(self) + + +class TypeExtensionDefinition(TypeSystemDefinition): + __slots__ = ('loc', 'definition',) + _fields = ('definition',) + + def __init__(self, definition, loc=None): + 
self.loc = loc + self.definition = definition + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, TypeExtensionDefinition) and + # self.loc == other.loc and + self.definition == other.definition + ) + ) + + def __repr__(self): + return ('TypeExtensionDefinition(' + 'definition={self.definition!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.definition, + self.loc + ) + + def __hash__(self): + return id(self) + + +class DirectiveDefinition(TypeSystemDefinition): + __slots__ = ('loc', 'name', 'arguments', 'locations') + _fields = ('name', 'locations') + + def __init__(self, name, locations, arguments=None, loc=None): + self.name = name + self.locations = locations + self.loc = loc + self.arguments = arguments + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, DirectiveDefinition) and + self.name == other.name and + self.locations == other.locations and + # self.loc == other.loc and + self.arguments == other.arguments + ) + ) + + def __repr__(self): + return ('DirectiveDefinition(' + 'name={self.name!r}, ' + 'locations={self.locations!r}' + ')').format(self=self) + + def __copy__(self): + return type(self)( + self.name, + self.locations, + self.arguments, + self.loc, + ) + + def __hash__(self): + return id(self) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/base.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/base.py new file mode 100644 index 0000000000000000000000000000000000000000..f6d9d91b0b66764daeabb1bade9200bf8142b701 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/base.py @@ -0,0 +1,19 @@ +from .lexer import Lexer +from .location import get_location +from .parser import parse, parse_value +from .printer import print_ast +from .source import Source +from .visitor import BREAK, ParallelVisitor, TypeInfoVisitor, 
visit + +__all__ = [ + 'Lexer', + 'get_location', + 'parse', + 'parse_value', + 'print_ast', + 'Source', + 'BREAK', + 'ParallelVisitor', + 'TypeInfoVisitor', + 'visit', +] diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/lexer.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/lexer.py new file mode 100644 index 0000000000000000000000000000000000000000..3bdad3c586195bce43ae5d5a94a2a3c74f94a376 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/lexer.py @@ -0,0 +1,435 @@ +import json + +from ..error import GraphQLSyntaxError + +__all__ = ['Token', 'Lexer', 'TokenKind', + 'get_token_desc', 'get_token_kind_desc'] + + +class Token(object): + __slots__ = 'kind', 'start', 'end', 'value' + + def __init__(self, kind, start, end, value=None): + self.kind = kind + self.start = start + self.end = end + self.value = value + + def __repr__(self): + return u''.format( + get_token_kind_desc(self.kind), + self.start, + self.end, + repr(self.value) + ) + + def __eq__(self, other): + return (self.kind == other.kind and + self.start == other.start and + self.end == other.end and + self.value == other.value) + + +class Lexer(object): + __slots__ = 'source', 'prev_position' + + def __init__(self, source): + self.source = source + self.prev_position = 0 + + def next_token(self, reset_position=None): + if reset_position is None: + reset_position = self.prev_position + token = read_token(self.source, reset_position) + self.prev_position = token.end + return token + + +class TokenKind(object): + EOF = 1 + BANG = 2 + DOLLAR = 3 + PAREN_L = 4 + PAREN_R = 5 + SPREAD = 6 + COLON = 7 + EQUALS = 8 + AT = 9 + BRACKET_L = 10 + BRACKET_R = 11 + BRACE_L = 12 + PIPE = 13 + BRACE_R = 14 + NAME = 15 + VARIABLE = 16 + INT = 17 + FLOAT = 18 + STRING = 19 + + +def get_token_desc(token): + if token.value: + return u'{} "{}"'.format( + 
get_token_kind_desc(token.kind), + token.value + ) + else: + return get_token_kind_desc(token.kind) + + +def get_token_kind_desc(kind): + return TOKEN_DESCRIPTION[kind] + + +TOKEN_DESCRIPTION = { + TokenKind.EOF: 'EOF', + TokenKind.BANG: '!', + TokenKind.DOLLAR: '$', + TokenKind.PAREN_L: '(', + TokenKind.PAREN_R: ')', + TokenKind.SPREAD: '...', + TokenKind.COLON: ':', + TokenKind.EQUALS: '=', + TokenKind.AT: '@', + TokenKind.BRACKET_L: '[', + TokenKind.BRACKET_R: ']', + TokenKind.BRACE_L: '{', + TokenKind.PIPE: '|', + TokenKind.BRACE_R: '}', + TokenKind.NAME: 'Name', + TokenKind.VARIABLE: 'Variable', + TokenKind.INT: 'Int', + TokenKind.FLOAT: 'Float', + TokenKind.STRING: 'String', +} + + +def char_code_at(s, pos): + if 0 <= pos < len(s): + return ord(s[pos]) + + return None + + +PUNCT_CODE_TO_KIND = { + ord('!'): TokenKind.BANG, + ord('$'): TokenKind.DOLLAR, + ord('('): TokenKind.PAREN_L, + ord(')'): TokenKind.PAREN_R, + ord(':'): TokenKind.COLON, + ord('='): TokenKind.EQUALS, + ord('@'): TokenKind.AT, + ord('['): TokenKind.BRACKET_L, + ord(']'): TokenKind.BRACKET_R, + ord('{'): TokenKind.BRACE_L, + ord('|'): TokenKind.PIPE, + ord('}'): TokenKind.BRACE_R, +} + + +def print_char_code(code): + if code is None: + return '' + + if code < 0x007F: + return json.dumps(chr(code)) + + return '"\\u%04X"' % code + + +def read_token(source, from_position): + """Gets the next token from the source starting at the given position. 
+ + This skips over whitespace and comments until it finds the next lexable + token, then lexes punctuators immediately or calls the appropriate + helper fucntion for more complicated tokens.""" + body = source.body + body_length = len(body) + + position = position_after_whitespace(body, from_position) + + if position >= body_length: + return Token(TokenKind.EOF, position, position) + + code = char_code_at(body, position) + + if code < 0x0020 and code not in (0x0009, 0x000A, 0x000D): + raise GraphQLSyntaxError( + source, position, + u'Invalid character {}.'.format(print_char_code(code)) + ) + + kind = PUNCT_CODE_TO_KIND.get(code) + if kind is not None: + return Token(kind, position, position + 1) + + if code == 46: # . + if char_code_at(body, position + 1) == char_code_at(body, position + 2) == 46: + return Token(TokenKind.SPREAD, position, position + 3) + + elif 65 <= code <= 90 or code == 95 or 97 <= code <= 122: + # A-Z, _, a-z + return read_name(source, position) + + elif code == 45 or 48 <= code <= 57: # -, 0-9 + return read_number(source, position, code) + + elif code == 34: # " + return read_string(source, position) + + raise GraphQLSyntaxError( + source, position, + u'Unexpected character {}.'.format(print_char_code(code))) + + +ignored_whitespace_characters = frozenset([ + # BOM + 0xFEFF, + # White Space + 0x0009, # tab + 0x0020, # space + # Line Terminator + 0x000A, # new line + 0x000D, # carriage return + # Comma + 0x002C +]) + + +def position_after_whitespace(body, start_position): + """Reads from body starting at start_position until it finds a + non-whitespace or commented character, then returns the position of + that character for lexing.""" + body_length = len(body) + position = start_position + while position < body_length: + code = char_code_at(body, position) + if code in ignored_whitespace_characters: + position += 1 + + elif code == 35: # #, skip comments + position += 1 + while position < body_length: + code = char_code_at(body, position) + 
if not (code is not None and (code > 0x001F or code == 0x0009) and code not in (0x000A, 0x000D)): + break + + position += 1 + else: + break + return position + + +def read_number(source, start, first_code): + """Reads a number token from the source file, either a float + or an int depending on whether a decimal point appears. + """ + code = first_code + body = source.body + position = start + is_float = False + + if code == 45: # - + position += 1 + code = char_code_at(body, position) + + if code == 48: # 0 + position += 1 + code = char_code_at(body, position) + + if code is not None and 48 <= code <= 57: + raise GraphQLSyntaxError( + source, + position, + u'Invalid number, unexpected digit after 0: {}.'.format(print_char_code(code)) + ) + else: + position = read_digits(source, position, code) + code = char_code_at(body, position) + + if code == 46: # . + is_float = True + + position += 1 + code = char_code_at(body, position) + position = read_digits(source, position, code) + code = char_code_at(body, position) + + if code in (69, 101): # E e + is_float = True + position += 1 + code = char_code_at(body, position) + if code in (43, 45): # + - + position += 1 + code = char_code_at(body, position) + + position = read_digits(source, position, code) + + return Token( + TokenKind.FLOAT if is_float else TokenKind.INT, + start, + position, + body[start:position] + ) + + +def read_digits(source, start, first_code): + body = source.body + position = start + code = first_code + + if code is not None and 48 <= code <= 57: # 0 - 9 + while True: + position += 1 + code = char_code_at(body, position) + + if not (code is not None and 48 <= code <= 57): + break + + return position + + raise GraphQLSyntaxError( + source, + position, + u'Invalid number, expected digit but got: {}.'.format(print_char_code(code)) + ) + + +ESCAPED_CHAR_CODES = { + 34: '"', + 47: '/', + 92: '\\', + 98: '\b', + 102: '\f', + 110: '\n', + 114: '\r', + 116: '\t', +} + + +def read_string(source, start): + 
"""Reads a string token from the source file. + + "([^"\\\u000A\u000D\u2028\u2029]|(\\(u[0-9a-fA-F]{4}|["\\/bfnrt])))*" + """ + body = source.body + body_length = len(body) + + position = start + 1 + chunk_start = position + code = 0 + value = [] + append = value.append + + while position < body_length: + code = char_code_at(body, position) + if not ( + code is not None and + code not in ( + # LineTerminator + 0x000A, 0x000D, + # Quote + 34 + ) + ): + break + + if code < 0x0020 and code != 0x0009: + raise GraphQLSyntaxError( + source, + position, + u'Invalid character within String: {}.'.format(print_char_code(code)) + ) + + position += 1 + if code == 92: # \ + append(body[chunk_start:position - 1]) + + code = char_code_at(body, position) + escaped = ESCAPED_CHAR_CODES.get(code) + if escaped is not None: + append(escaped) + + elif code == 117: # u + char_code = uni_char_code( + char_code_at(body, position + 1) or 0, + char_code_at(body, position + 2) or 0, + char_code_at(body, position + 3) or 0, + char_code_at(body, position + 4) or 0, + ) + + if char_code < 0: + raise GraphQLSyntaxError( + source, position, + u'Invalid character escape sequence: \\u{}.'.format(body[position + 1: position + 5]) + ) + + append(chr(char_code)) + position += 4 + else: + raise GraphQLSyntaxError( + source, position, + u'Invalid character escape sequence: \\{}.'.format(chr(code)) + ) + + position += 1 + chunk_start = position + + if code != 34: # Quote (") + raise GraphQLSyntaxError(source, position, 'Unterminated string') + + append(body[chunk_start:position]) + return Token(TokenKind.STRING, start, position + 1, u''.join(value)) + + +def uni_char_code(a, b, c, d): + """Converts four hexidecimal chars to the integer that the + string represents. For example, uniCharCode('0','0','0','f') + will return 15, and uniCharCode('0','0','f','f') returns 255. + + Returns a negative number on error, if a char was invalid. 
+ + This is implemented by noting that char2hex() returns -1 on error, + which means the result of ORing the char2hex() will also be negative. + """ + return (char2hex(a) << 12 | char2hex(b) << 8 | + char2hex(c) << 4 | char2hex(d)) + + +def char2hex(a): + """Converts a hex character to its integer value. + '0' becomes 0, '9' becomes 9 + 'A' becomes 10, 'F' becomes 15 + 'a' becomes 10, 'f' becomes 15 + + Returns -1 on error.""" + if 48 <= a <= 57: # 0-9 + return a - 48 + elif 65 <= a <= 70: # A-F + return a - 55 + elif 97 <= a <= 102: # a-f + return a - 87 + return -1 + + +def read_name(source, position): + """Reads an alphanumeric + underscore name from the source. + + [_A-Za-z][_0-9A-Za-z]*""" + body = source.body + body_length = len(body) + end = position + 1 + + while end != body_length: + code = char_code_at(body, end) + if not (code is not None and ( + code == 95 or # _ + 48 <= code <= 57 or # 0-9 + 65 <= code <= 90 or # A-Z + 97 <= code <= 122 # a-z + )): + break + + end += 1 + + return Token(TokenKind.NAME, position, end, body[position:end]) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/location.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/location.py new file mode 100644 index 0000000000000000000000000000000000000000..c478dcbd94d1db7d25f936a3784063e8b9c273b0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/location.py @@ -0,0 +1,30 @@ +__all__ = ['get_location', 'SourceLocation'] + + +class SourceLocation(object): + __slots__ = 'line', 'column' + + def __init__(self, line, column): + self.line = line + self.column = column + + def __repr__(self): + return ''.format(self.line, self.column) + + def __eq__(self, other): + return ( + isinstance(other, SourceLocation) and + self.line == other.line and + self.column == other.column + ) + + +def get_location(source, position): + lines = 
source.body[:position].splitlines() + if lines: + line = len(lines) + column = len(lines[-1]) + 1 + else: + line = 1 + column = 1 + return SourceLocation(line, column) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/parser.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..d2854a303095d4fcbbaed9f29a9e4dcca5e06db4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/parser.py @@ -0,0 +1,779 @@ +from . import ast +from ..error import GraphQLSyntaxError +from .lexer import Lexer, TokenKind, get_token_desc, get_token_kind_desc +from .source import Source + +__all__ = ['parse'] + + +def parse(source, **kwargs): + """Given a GraphQL source, parses it into a Document.""" + options = {'no_location': False, 'no_source': False} + options.update(kwargs) + source_obj = source + + if isinstance(source, str): + source_obj = Source(source) + + parser = Parser(source_obj, options) + return parse_document(parser) + + +def parse_value(source, **kwargs): + options = {'no_location': False, 'no_source': False} + options.update(kwargs) + source_obj = source + + if isinstance(source, str): + source_obj = Source(source) + + parser = Parser(source_obj, options) + return parse_value_literal(parser, False) + + +class Parser(object): + __slots__ = 'lexer', 'source', 'options', 'prev_end', 'token' + + def __init__(self, source, options): + self.lexer = Lexer(source) + self.source = source + self.options = options + self.prev_end = 0 + self.token = self.lexer.next_token() + + +class Loc(object): + __slots__ = 'start', 'end', 'source' + + def __init__(self, start, end, source=None): + self.start = start + self.end = end + self.source = source + + def __repr__(self): + source = ' source={}'.format(self.source) if self.source else '' + return ''.format(self.start, 
self.end, source) + + def __eq__(self, other): + return ( + isinstance(other, Loc) and + self.start == other.start and + self.end == other.end and + self.source == other.source + ) + + +def loc(parser, start): + """Returns a location object, used to identify the place in + the source that created a given parsed object.""" + if parser.options['no_location']: + return None + + if parser.options['no_source']: + return Loc(start, parser.prev_end) + + return Loc(start, parser.prev_end, parser.source) + + +def advance(parser): + """Moves the internal parser object to the next lexed token.""" + prev_end = parser.token.end + parser.prev_end = prev_end + parser.token = parser.lexer.next_token(prev_end) + + +def peek(parser, kind): + """Determines if the next token is of a given kind""" + return parser.token.kind == kind + + +def skip(parser, kind): + """If the next token is of the given kind, return true after advancing + the parser. Otherwise, do not change the parser state + and throw an error.""" + match = parser.token.kind == kind + if match: + advance(parser) + + return match + + +def expect(parser, kind): + """If the next token is of the given kind, return that token after + advancing the parser. Otherwise, do not change the parser state and + return False.""" + token = parser.token + if token.kind == kind: + advance(parser) + return token + + raise GraphQLSyntaxError( + parser.source, + token.start, + u'Expected {}, found {}'.format( + get_token_kind_desc(kind), + get_token_desc(token) + ) + ) + + +def expect_keyword(parser, value): + """If the next token is a keyword with the given value, return that + token after advancing the parser. 
Otherwise, do not change the parser + state and return False.""" + token = parser.token + if token.kind == TokenKind.NAME and token.value == value: + advance(parser) + return token + + raise GraphQLSyntaxError( + parser.source, + token.start, + u'Expected "{}", found {}'.format(value, get_token_desc(token)) + ) + + +def unexpected(parser, at_token=None): + """Helper function for creating an error when an unexpected lexed token + is encountered.""" + token = at_token or parser.token + return GraphQLSyntaxError( + parser.source, + token.start, + u'Unexpected {}'.format(get_token_desc(token)) + ) + + +def any(parser, open_kind, parse_fn, close_kind): + """Returns a possibly empty list of parse nodes, determined by + the parse_fn. This list begins with a lex token of openKind + and ends with a lex token of closeKind. Advances the parser + to the next lex token after the closing token.""" + expect(parser, open_kind) + nodes = [] + while not skip(parser, close_kind): + nodes.append(parse_fn(parser)) + + return nodes + + +def many(parser, open_kind, parse_fn, close_kind): + """Returns a non-empty list of parse nodes, determined by + the parse_fn. This list begins with a lex token of openKind + and ends with a lex token of closeKind. Advances the parser + to the next lex token after the closing token.""" + expect(parser, open_kind) + nodes = [parse_fn(parser)] + while not skip(parser, close_kind): + nodes.append(parse_fn(parser)) + + return nodes + + +def parse_name(parser): + """Converts a name lex token into a name parse node.""" + token = expect(parser, TokenKind.NAME) + return ast.Name( + value=token.value, + loc=loc(parser, token.start) + ) + + +# Implements the parsing rules in the Document section. 
+ +def parse_document(parser): + start = parser.token.start + definitions = [] + while True: + definitions.append(parse_definition(parser)) + + if skip(parser, TokenKind.EOF): + break + + return ast.Document( + definitions=definitions, + loc=loc(parser, start) + ) + + +def parse_definition(parser): + if peek(parser, TokenKind.BRACE_L): + return parse_operation_definition(parser) + + if peek(parser, TokenKind.NAME): + name = parser.token.value + + if name in ('query', 'mutation', 'subscription'): + return parse_operation_definition(parser) + elif name == 'fragment': + return parse_fragment_definition(parser) + elif name in ('schema', 'scalar', 'type', 'interface', 'union', 'enum', 'input', 'extend', 'directive'): + return parse_type_system_definition(parser) + + raise unexpected(parser) + + +# Implements the parsing rules in the Operations section. +def parse_operation_definition(parser): + start = parser.token.start + if peek(parser, TokenKind.BRACE_L): + return ast.OperationDefinition( + operation='query', + name=None, + variable_definitions=None, + directives=[], + selection_set=parse_selection_set(parser), + loc=loc(parser, start) + ) + + operation = parse_operation_type(parser) + + name = None + if peek(parser, TokenKind.NAME): + name = parse_name(parser) + + return ast.OperationDefinition( + operation=operation, + name=name, + variable_definitions=parse_variable_definitions(parser), + directives=parse_directives(parser), + selection_set=parse_selection_set(parser), + loc=loc(parser, start) + ) + + +def parse_operation_type(parser): + operation_token = expect(parser, TokenKind.NAME) + operation = operation_token.value + if operation == 'query': + return 'query' + elif operation == 'mutation': + return 'mutation' + elif operation == 'subscription': + return 'subscription' + + raise unexpected(parser, operation_token) + + +def parse_variable_definitions(parser): + if peek(parser, TokenKind.PAREN_L): + return many( + parser, + TokenKind.PAREN_L, + 
parse_variable_definition, + TokenKind.PAREN_R + ) + + return [] + + +def parse_variable_definition(parser): + start = parser.token.start + + return ast.VariableDefinition( + variable=parse_variable(parser), + type=expect(parser, TokenKind.COLON) and parse_type(parser), + default_value=parse_value_literal(parser, True) if skip(parser, TokenKind.EQUALS) else None, + loc=loc(parser, start) + ) + + +def parse_variable(parser): + start = parser.token.start + expect(parser, TokenKind.DOLLAR) + + return ast.Variable( + name=parse_name(parser), + loc=loc(parser, start) + ) + + +def parse_selection_set(parser): + start = parser.token.start + return ast.SelectionSet( + selections=many(parser, TokenKind.BRACE_L, parse_selection, TokenKind.BRACE_R), + loc=loc(parser, start) + ) + + +def parse_selection(parser): + if peek(parser, TokenKind.SPREAD): + return parse_fragment(parser) + else: + return parse_field(parser) + + +def parse_field(parser): + # Corresponds to both Field and Alias in the spec + start = parser.token.start + + name_or_alias = parse_name(parser) + if skip(parser, TokenKind.COLON): + alias = name_or_alias + name = parse_name(parser) + else: + alias = None + name = name_or_alias + + return ast.Field( + alias=alias, + name=name, + arguments=parse_arguments(parser), + directives=parse_directives(parser), + selection_set=parse_selection_set(parser) if peek(parser, TokenKind.BRACE_L) else None, + loc=loc(parser, start) + ) + + +def parse_arguments(parser): + if peek(parser, TokenKind.PAREN_L): + return many( + parser, TokenKind.PAREN_L, + parse_argument, TokenKind.PAREN_R) + + return [] + + +def parse_argument(parser): + start = parser.token.start + + return ast.Argument( + name=parse_name(parser), + value=expect(parser, TokenKind.COLON) and parse_value_literal(parser, False), + loc=loc(parser, start) + ) + + +# Implements the parsing rules in the Fragments section. 
+ +def parse_fragment(parser): + # Corresponds to both FragmentSpread and InlineFragment in the spec + start = parser.token.start + expect(parser, TokenKind.SPREAD) + + if peek(parser, TokenKind.NAME) and parser.token.value != 'on': + return ast.FragmentSpread( + name=parse_fragment_name(parser), + directives=parse_directives(parser), + loc=loc(parser, start) + ) + + type_condition = None + if parser.token.value == 'on': + advance(parser) + type_condition = parse_named_type(parser) + + return ast.InlineFragment( + type_condition=type_condition, + directives=parse_directives(parser), + selection_set=parse_selection_set(parser), + loc=loc(parser, start) + ) + + +def parse_fragment_definition(parser): + start = parser.token.start + expect_keyword(parser, 'fragment') + + return ast.FragmentDefinition( + name=parse_fragment_name(parser), + type_condition=expect_keyword(parser, 'on') and parse_named_type(parser), + directives=parse_directives(parser), + selection_set=parse_selection_set(parser), + loc=loc(parser, start) + ) + + +def parse_fragment_name(parser): + if parser.token.value == 'on': + raise unexpected(parser) + + return parse_name(parser) + + +def parse_value_literal(parser, is_const): + token = parser.token + if token.kind == TokenKind.BRACKET_L: + return parse_list(parser, is_const) + + elif token.kind == TokenKind.BRACE_L: + return parse_object(parser, is_const) + + elif token.kind == TokenKind.INT: + advance(parser) + return ast.IntValue(value=token.value, loc=loc(parser, token.start)) + + elif token.kind == TokenKind.FLOAT: + advance(parser) + return ast.FloatValue(value=token.value, loc=loc(parser, token.start)) + + elif token.kind == TokenKind.STRING: + advance(parser) + return ast.StringValue(value=token.value, loc=loc(parser, token.start)) + + elif token.kind == TokenKind.NAME: + if token.value in ('true', 'false'): + advance(parser) + return ast.BooleanValue(value=token.value == 'true', loc=loc(parser, token.start)) + + if token.value != 'null': + 
advance(parser) + return ast.EnumValue(value=token.value, loc=loc(parser, token.start)) + + elif token.kind == TokenKind.DOLLAR: + if not is_const: + return parse_variable(parser) + + raise unexpected(parser) + + +# Implements the parsing rules in the Values section. +def parse_variable_value(parser): + return parse_value_literal(parser, False) + + +def parse_const_value(parser): + return parse_value_literal(parser, True) + + +def parse_list(parser, is_const): + start = parser.token.start + item = parse_const_value if is_const else parse_variable_value + + return ast.ListValue( + values=any( + parser, TokenKind.BRACKET_L, + item, TokenKind.BRACKET_R), + loc=loc(parser, start) + ) + + +def parse_object(parser, is_const): + start = parser.token.start + expect(parser, TokenKind.BRACE_L) + fields = [] + + while not skip(parser, TokenKind.BRACE_R): + fields.append(parse_object_field(parser, is_const)) + + return ast.ObjectValue(fields=fields, loc=loc(parser, start)) + + +def parse_object_field(parser, is_const): + start = parser.token.start + return ast.ObjectField( + name=parse_name(parser), + value=expect(parser, TokenKind.COLON) and parse_value_literal(parser, is_const), + loc=loc(parser, start) + ) + + +# Implements the parsing rules in the Directives section. + +def parse_directives(parser): + directives = [] + while peek(parser, TokenKind.AT): + directives.append(parse_directive(parser)) + return directives + + +def parse_directive(parser): + start = parser.token.start + expect(parser, TokenKind.AT) + + return ast.Directive( + name=parse_name(parser), + arguments=parse_arguments(parser), + loc=loc(parser, start), + ) + + +# Implements the parsing rules in the Types section. 
+def parse_type(parser): + """Handles the 'Type': TypeName, ListType, and NonNullType + parsing rules.""" + start = parser.token.start + if skip(parser, TokenKind.BRACKET_L): + ast_type = parse_type(parser) + expect(parser, TokenKind.BRACKET_R) + ast_type = ast.ListType(type=ast_type, loc=loc(parser, start)) + + else: + ast_type = parse_named_type(parser) + + if skip(parser, TokenKind.BANG): + return ast.NonNullType(type=ast_type, loc=loc(parser, start)) + + return ast_type + + +def parse_named_type(parser): + start = parser.token.start + return ast.NamedType( + name=parse_name(parser), + loc=loc(parser, start), + ) + + +def parse_type_system_definition(parser): + ''' + TypeSystemDefinition : + - SchemaDefinition + - TypeDefinition + - TypeExtensionDefinition + - DirectiveDefinition + + TypeDefinition : + - ScalarTypeDefinition + - ObjectTypeDefinition + - InterfaceTypeDefinition + - UnionTypeDefinition + - EnumTypeDefinition + - InputObjectTypeDefinition + ''' + if not peek(parser, TokenKind.NAME): + raise unexpected(parser) + + name = parser.token.value + + if name == 'schema': + return parse_schema_definition(parser) + + elif name == 'scalar': + return parse_scalar_type_definition(parser) + + elif name == 'type': + return parse_object_type_definition(parser) + + elif name == 'interface': + return parse_interface_type_definition(parser) + + elif name == 'union': + return parse_union_type_definition(parser) + + elif name == 'enum': + return parse_enum_type_definition(parser) + + elif name == 'input': + return parse_input_object_type_definition(parser) + + elif name == 'extend': + return parse_type_extension_definition(parser) + + elif name == 'directive': + return parse_directive_definition(parser) + + raise unexpected(parser) + + +def parse_schema_definition(parser): + start = parser.token.start + expect_keyword(parser, 'schema') + directives = parse_directives(parser) + operation_types = many( + parser, + TokenKind.BRACE_L, + parse_operation_type_definition, + 
TokenKind.BRACE_R + ) + + return ast.SchemaDefinition( + directives=directives, + operation_types=operation_types, + loc=loc(parser, start) + ) + + +def parse_operation_type_definition(parser): + start = parser.token.start + operation = parse_operation_type(parser) + expect(parser, TokenKind.COLON) + + return ast.OperationTypeDefinition( + operation=operation, + type=parse_named_type(parser), + loc=loc(parser, start) + ) + + +def parse_scalar_type_definition(parser): + start = parser.token.start + expect_keyword(parser, 'scalar') + + return ast.ScalarTypeDefinition( + name=parse_name(parser), + directives=parse_directives(parser), + loc=loc(parser, start), + ) + + +def parse_object_type_definition(parser): + start = parser.token.start + expect_keyword(parser, 'type') + return ast.ObjectTypeDefinition( + name=parse_name(parser), + interfaces=parse_implements_interfaces(parser), + directives=parse_directives(parser), + fields=any( + parser, + TokenKind.BRACE_L, + parse_field_definition, + TokenKind.BRACE_R + ), + loc=loc(parser, start), + ) + + +def parse_implements_interfaces(parser): + types = [] + if parser.token.value == 'implements': + advance(parser) + + while True: + types.append(parse_named_type(parser)) + + if not peek(parser, TokenKind.NAME): + break + + return types + + +def parse_field_definition(parser): + start = parser.token.start + + return ast.FieldDefinition( + name=parse_name(parser), + arguments=parse_argument_defs(parser), + type=expect(parser, TokenKind.COLON) and parse_type(parser), + directives=parse_directives(parser), + loc=loc(parser, start), + ) + + +def parse_argument_defs(parser): + if not peek(parser, TokenKind.PAREN_L): + return [] + + return many(parser, TokenKind.PAREN_L, parse_input_value_def, TokenKind.PAREN_R) + + +def parse_input_value_def(parser): + start = parser.token.start + + return ast.InputValueDefinition( + name=parse_name(parser), + type=expect(parser, TokenKind.COLON) and parse_type(parser), + 
default_value=parse_const_value(parser) if skip(parser, TokenKind.EQUALS) else None, + directives=parse_directives(parser), + loc=loc(parser, start), + ) + + +def parse_interface_type_definition(parser): + start = parser.token.start + expect_keyword(parser, 'interface') + + return ast.InterfaceTypeDefinition( + name=parse_name(parser), + directives=parse_directives(parser), + fields=any(parser, TokenKind.BRACE_L, parse_field_definition, TokenKind.BRACE_R), + loc=loc(parser, start), + ) + + +def parse_union_type_definition(parser): + start = parser.token.start + expect_keyword(parser, 'union') + + return ast.UnionTypeDefinition( + name=parse_name(parser), + directives=parse_directives(parser), + types=expect(parser, TokenKind.EQUALS) and parse_union_members(parser), + loc=loc(parser, start), + ) + + +def parse_union_members(parser): + members = [] + + while True: + members.append(parse_named_type(parser)) + + if not skip(parser, TokenKind.PIPE): + break + + return members + + +def parse_enum_type_definition(parser): + start = parser.token.start + expect_keyword(parser, 'enum') + + return ast.EnumTypeDefinition( + name=parse_name(parser), + directives=parse_directives(parser), + values=many(parser, TokenKind.BRACE_L, parse_enum_value_definition, TokenKind.BRACE_R), + loc=loc(parser, start), + ) + + +def parse_enum_value_definition(parser): + start = parser.token.start + + return ast.EnumValueDefinition( + name=parse_name(parser), + directives=parse_directives(parser), + loc=loc(parser, start), + ) + + +def parse_input_object_type_definition(parser): + start = parser.token.start + expect_keyword(parser, 'input') + + return ast.InputObjectTypeDefinition( + name=parse_name(parser), + directives=parse_directives(parser), + fields=any(parser, TokenKind.BRACE_L, parse_input_value_def, TokenKind.BRACE_R), + loc=loc(parser, start), + ) + + +def parse_type_extension_definition(parser): + start = parser.token.start + expect_keyword(parser, 'extend') + + return 
ast.TypeExtensionDefinition( + definition=parse_object_type_definition(parser), + loc=loc(parser, start) + ) + + +def parse_directive_definition(parser): + start = parser.token.start + expect_keyword(parser, 'directive') + expect(parser, TokenKind.AT) + + name = parse_name(parser) + args = parse_argument_defs(parser) + expect_keyword(parser, 'on') + + locations = parse_directive_locations(parser) + return ast.DirectiveDefinition( + name=name, + locations=locations, + arguments=args, + loc=loc(parser, start) + ) + + +def parse_directive_locations(parser): + locations = [] + + while True: + locations.append(parse_name(parser)) + + if not skip(parser, TokenKind.PIPE): + break + + return locations diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/printer.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/printer.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a6dd360d51afff285d7f2822611dbbd3b5a510 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/printer.py @@ -0,0 +1,193 @@ +import json + +from .visitor import Visitor, visit + +__all__ = ['print_ast'] + + +def print_ast(ast): + return visit(ast, PrintingVisitor()) + + +class PrintingVisitor(Visitor): + __slots__ = () + + def leave_Name(self, node, *args): + return node.value + + def leave_Variable(self, node, *args): + return '$' + node.name + + def leave_Document(self, node, *args): + return join(node.definitions, '\n\n') + '\n' + + def leave_OperationDefinition(self, node, *args): + name = node.name + selection_set = node.selection_set + op = node.operation + var_defs = wrap('(', join(node.variable_definitions, ', '), ')') + directives = join(node.directives, ' ') + + if not name and not directives and not var_defs and op == 'query': + return selection_set + + return join([op, join([name, var_defs]), directives, selection_set], ' ') + + def 
leave_VariableDefinition(self, node, *args): + return node.variable + ': ' + node.type + wrap(' = ', node.default_value) + + def leave_SelectionSet(self, node, *args): + return block(node.selections) + + def leave_Field(self, node, *args): + return join([ + wrap('', node.alias, ': ') + node.name + wrap('(', join(node.arguments, ', '), ')'), + join(node.directives, ' '), + node.selection_set + ], ' ') + + def leave_Argument(self, node, *args): + return node.name + ': ' + node.value + + # Fragments + + def leave_FragmentSpread(self, node, *args): + return '...' + node.name + wrap(' ', join(node.directives, ' ')) + + def leave_InlineFragment(self, node, *args): + return join([ + '...', + wrap('on ', node.type_condition), + join(node.directives, ''), + node.selection_set + ], ' ') + + def leave_FragmentDefinition(self, node, *args): + return ('fragment {} on {} '.format(node.name, node.type_condition) + + wrap('', join(node.directives, ' '), ' ') + + node.selection_set) + + # Value + + def leave_IntValue(self, node, *args): + return node.value + + def leave_FloatValue(self, node, *args): + return node.value + + def leave_StringValue(self, node, *args): + return json.dumps(node.value) + + def leave_BooleanValue(self, node, *args): + return json.dumps(node.value) + + def leave_EnumValue(self, node, *args): + return node.value + + def leave_ListValue(self, node, *args): + return '[' + join(node.values, ', ') + ']' + + def leave_ObjectValue(self, node, *args): + return '{' + join(node.fields, ', ') + '}' + + def leave_ObjectField(self, node, *args): + return node.name + ': ' + node.value + + # Directive + + def leave_Directive(self, node, *args): + return '@' + node.name + wrap('(', join(node.arguments, ', '), ')') + + # Type + + def leave_NamedType(self, node, *args): + return node.name + + def leave_ListType(self, node, *args): + return '[' + node.type + ']' + + def leave_NonNullType(self, node, *args): + return node.type + '!' 
+ + # Type Definitions: + + def leave_SchemaDefinition(self, node, *args): + return join([ + 'schema', + join(node.directives, ' '), + block(node.operation_types), + ], ' ') + + def leave_OperationTypeDefinition(self, node, *args): + return '{}: {}'.format(node.operation, node.type) + + def leave_ScalarTypeDefinition(self, node, *args): + return 'scalar ' + node.name + wrap(' ', join(node.directives, ' ')) + + def leave_ObjectTypeDefinition(self, node, *args): + return join([ + 'type', + node.name, + wrap('implements ', join(node.interfaces, ', ')), + join(node.directives, ' '), + block(node.fields) + ], ' ') + + def leave_FieldDefinition(self, node, *args): + return ( + node.name + + wrap('(', join(node.arguments, ', '), ')') + + ': ' + + node.type + + wrap(' ', join(node.directives, ' ')) + ) + + def leave_InputValueDefinition(self, node, *args): + return node.name + ': ' + node.type + wrap(' = ', node.default_value) + wrap(' ', join(node.directives, ' ')) + + def leave_InterfaceTypeDefinition(self, node, *args): + return 'interface ' + node.name + wrap(' ', join(node.directives, ' ')) + ' ' + block(node.fields) + + def leave_UnionTypeDefinition(self, node, *args): + return 'union ' + node.name + wrap(' ', join(node.directives, ' ')) + ' = ' + join(node.types, ' | ') + + def leave_EnumTypeDefinition(self, node, *args): + return 'enum ' + node.name + wrap(' ', join(node.directives, ' ')) + ' ' + block(node.values) + + def leave_EnumValueDefinition(self, node, *args): + return node.name + wrap(' ', join(node.directives, ' ')) + + def leave_InputObjectTypeDefinition(self, node, *args): + return 'input ' + node.name + wrap(' ', join(node.directives, ' ')) + ' ' + block(node.fields) + + def leave_TypeExtensionDefinition(self, node, *args): + return 'extend ' + node.definition + + def leave_DirectiveDefinition(self, node, *args): + return 'directive @{}{} on {}'.format(node.name, wrap( + '(', join(node.arguments, ', '), ')'), ' | '.join(node.locations)) + + +def 
join(maybe_list, separator=''): + if maybe_list: + return separator.join(filter(None, maybe_list)) + return '' + + +def block(_list): + '''Given a list, print each item on its own line, wrapped in an indented "{ }" block.''' + if _list: + return indent('{\n' + join(_list, '\n')) + '\n}' + return '{}' + + +def wrap(start, maybe_str, end=''): + if maybe_str: + return start + maybe_str + end + return '' + + +def indent(maybe_str): + if maybe_str: + return maybe_str.replace('\n', '\n ') + return maybe_str diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/source.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/source.py new file mode 100644 index 0000000000000000000000000000000000000000..14e22fac7cc24237418e7845c3ca2ad6b76e1eb3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/source.py @@ -0,0 +1,18 @@ +__all__ = ['Source'] + + +class Source(object): + __slots__ = 'body', 'name' + + def __init__(self, body, name='GraphQL'): + self.body = body + self.name = name + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, Source) and + self.body == other.body and + self.name == other.name + ) + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor.py new file mode 100644 index 0000000000000000000000000000000000000000..95cd69a04af1abb7161e6b3926fbe16bb0f1ac3c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor.py @@ -0,0 +1,222 @@ +from copy import copy + +from . 
import ast +from .visitor_meta import QUERY_DOCUMENT_KEYS, VisitorMeta + + +class Falsey(object): + + def __nonzero__(self): + return False + + def __bool__(self): + return False + + +BREAK = object() +REMOVE = Falsey() + + +class Stack(object): + __slots__ = 'in_array', 'index', 'keys', 'edits', 'prev' + + def __init__(self, in_array, index, keys, edits, prev): + self.in_array = in_array + self.index = index + self.keys = keys + self.edits = edits + self.prev = prev + + +def visit(root, visitor, key_map=None): + visitor_keys = key_map or QUERY_DOCUMENT_KEYS + + stack = None + in_array = isinstance(root, list) + keys = [root] + index = -1 + edits = [] + parent = None + path = [] + ancestors = [] + new_root = root + leave = visitor.leave + enter = visitor.enter + path_pop = path.pop + ancestors_pop = ancestors.pop + path_append = path.append + ancestors_append = ancestors.append + + while True: + index += 1 + is_leaving = index == len(keys) + is_edited = is_leaving and edits + if is_leaving: + key = path_pop() if ancestors else None + node = parent + parent = ancestors_pop() if ancestors else None + + if is_edited: + if in_array: + node = list(node) + + else: + node = copy(node) + edit_offset = 0 + for edit_key, edit_value in edits: + if in_array: + edit_key -= edit_offset + + if in_array and edit_value is REMOVE: + node.pop(edit_key) + edit_offset += 1 + + else: + if isinstance(node, list): + node[edit_key] = edit_value + + else: + setattr(node, edit_key, edit_value) + + index = stack.index + keys = stack.keys + edits = stack.edits + in_array = stack.in_array + stack = stack.prev + + else: + if parent: + key = index if in_array else keys[index] + if isinstance(parent, list): + node = parent[key] + + else: + node = getattr(parent, key, None) + + else: + key = None + node = new_root + + if node is REMOVE or node is None: + continue + + if parent: + path_append(key) + + result = None + + if not isinstance(node, list): + assert isinstance(node, ast.Node), 'Invalid AST 
Node: ' + repr(node) + + if is_leaving: + result = leave(node, key, parent, path, ancestors) + + else: + result = enter(node, key, parent, path, ancestors) + + if result is BREAK: + break + + if result is False: + if not is_leaving: + path_pop() + continue + + elif result is not None: + edits.append((key, result)) + if not is_leaving: + if isinstance(result, ast.Node): + node = result + + else: + path_pop() + continue + + if result is None and is_edited: + edits.append((key, node)) + + if not is_leaving: + stack = Stack(in_array, index, keys, edits, stack) + in_array = isinstance(node, list) + keys = node if in_array else visitor_keys.get(type(node), None) or [] + index = -1 + edits = [] + + if parent: + ancestors_append(parent) + + parent = node + + if not stack: + break + + if edits: + new_root = edits[-1][1] + + return new_root + + +class Visitor(metaclass=VisitorMeta): + __slots__ = () + + def enter(self, node, key, parent, path, ancestors): + method = self._get_enter_handler(type(node)) + if method: + return method(self, node, key, parent, path, ancestors) + + def leave(self, node, key, parent, path, ancestors): + method = self._get_leave_handler(type(node)) + if method: + return method(self, node, key, parent, path, ancestors) + + +class ParallelVisitor(Visitor): + __slots__ = 'skipping', 'visitors' + + def __init__(self, visitors): + self.visitors = visitors + self.skipping = [None] * len(visitors) + + def enter(self, node, key, parent, path, ancestors): + for i, visitor in enumerate(self.visitors): + if not self.skipping[i]: + result = visitor.enter(node, key, parent, path, ancestors) + if result is False: + self.skipping[i] = node + elif result is BREAK: + self.skipping[i] = BREAK + elif result is not None: + return result + + def leave(self, node, key, parent, path, ancestors): + for i, visitor in enumerate(self.visitors): + if not self.skipping[i]: + result = visitor.leave(node, key, parent, path, ancestors) + if result is BREAK: + self.skipping[i] = 
BREAK + elif result is not None and result is not False: + return result + elif self.skipping[i] == node: + self.skipping[i] = REMOVE + + +class TypeInfoVisitor(Visitor): + __slots__ = 'visitor', 'type_info' + + def __init__(self, type_info, visitor): + self.type_info = type_info + self.visitor = visitor + + def enter(self, node, key, parent, path, ancestors): + self.type_info.enter(node) + result = self.visitor.enter(node, key, parent, path, ancestors) + if result is not None: + self.type_info.leave(node) + if isinstance(result, ast.Node): + self.type_info.enter(result) + return result + + def leave(self, node, key, parent, path, ancestors): + result = self.visitor.leave(node, key, parent, path, ancestors) + self.type_info.leave(node) + return result diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor_meta.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..db2e640931abe178a2f0f58d11f0d96442fc1249 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/language/visitor_meta.py @@ -0,0 +1,82 @@ +from . 
class VisitorMeta(type):
    """Metaclass that collects ``enter_<Kind>`` / ``leave_<Kind>`` methods.

    For every class it creates, it builds two tables keyed by the AST type
    that ``AST_KIND_TO_TYPE`` maps ``<Kind>`` to, and stores them on the class
    together with their ``.get`` methods for fast lookup.
    """

    def __new__(cls, name, bases, attrs):
        enter_table = {}
        leave_table = {}

        # Start from whatever the base classes already registered; later bases
        # override earlier ones, and the class body overrides all of them.
        for base in bases:
            enter_table.update(getattr(base, '_enter_handlers', ()))
            leave_table.update(getattr(base, '_leave_handlers', ()))

        for attr_name, member in attrs.items():
            # Both prefixes are six characters long.
            prefix, kind = attr_name[:6], attr_name[6:]
            if prefix == 'enter_':
                # An unknown kind maps to None, so the handler lands under the None key.
                enter_table[AST_KIND_TO_TYPE.get(kind)] = member
            elif prefix == 'leave_':
                leave_table[AST_KIND_TO_TYPE.get(kind)] = member

        attrs['_enter_handlers'] = enter_table
        attrs['_leave_handlers'] = leave_table
        attrs['_get_enter_handler'] = enter_table.get
        attrs['_get_leave_handler'] = leave_table.get
        return super().__new__(cls, name, bases, attrs)
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__pycache__/cached_property.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__pycache__/default_ordered_dict.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__pycache__/default_ordered_dict.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56c98b9451c6eb155521e6db6f00afa90c9b9c86 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__pycache__/default_ordered_dict.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__pycache__/ordereddict.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__pycache__/ordereddict.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a266a7f04ec07c2cde83f1a3da655cc297a373a3 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__pycache__/ordereddict.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__pycache__/pair_set.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__pycache__/pair_set.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c512f85783ef21c29b091f54ed83f96e74c3300b Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__pycache__/pair_set.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__pycache__/version.cpython-313.pyc 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__pycache__/version.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d5e134508d3b58a17169f0fc8d659e9bdd0782d Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/__pycache__/version.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/cached_property.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/cached_property.py new file mode 100644 index 0000000000000000000000000000000000000000..b5db6d48a9300e45cd68ff5a80f5d6085d5d4fca --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/pyutils/cached_property.py @@ -0,0 +1,17 @@ +class cached_property(object): + """ A property that is only computed once per instance and then replaces + itself with an ordinary attribute. Deleting the attribute resets the + property. 
# Container types that are compared structurally rather than with ``==``.
obj = (dict, list, tuple)


def contain_subset(expected, actual):
    """Return True when ``expected`` is structurally contained in ``actual``.

    * Values whose types are unrelated (neither is a subclass of the other,
      and they are not both dicts) never match.
    * Non-container values are compared with ``==``.
    * For lists and tuples, every element of ``expected`` must be contained in
      at least one element of ``actual`` (order and multiplicity are ignored).
    * For dicts, every key of ``expected`` must be present in ``actual`` with
      a value that contains the expected one; extra keys in ``actual`` are
      ignored.
    """
    t_actual = type(actual)
    t_expected = type(expected)
    both_dicts = issubclass(t_actual, dict) and issubclass(t_expected, dict)
    related = issubclass(t_actual, t_expected) or issubclass(t_expected, t_actual)
    if not both_dicts and not related:
        return False

    if not isinstance(expected, obj):
        return expected == actual

    # A non-empty expectation can never be found inside an empty container.
    if expected and not actual:
        return False

    # Tuples are listed in ``obj`` and must be handled like lists; sending
    # them to the mapping branch below would fail on ``expected.keys()``.
    if isinstance(expected, (list, tuple)):
        return all(
            any(contain_subset(wanted, candidate) for candidate in actual)
            for wanted in expected
        )

    for key, wanted in expected.items():
        found = actual.get(key)
        if isinstance(wanted, obj) and found is not None:
            if not contain_subset(wanted, found):
                return False
            continue
        if found != wanted:
            return False
    return True
class DefaultOrderedDict(OrderedDict):
    """An ``OrderedDict`` that creates missing values with ``default_factory``.

    ``default_factory`` must be ``None`` or a callable; with ``None`` a missing
    key raises ``KeyError`` as in a plain dictionary.
    """
    __slots__ = 'default_factory',

    # Source: http://stackoverflow.com/a/6190500/562769
    def __init__(self, default_factory=None, *a, **kw):
        if default_factory is not None and not callable(default_factory):
            raise TypeError('first argument must be callable')

        OrderedDict.__init__(self, *a, **kw)
        self.default_factory = default_factory

    def __missing__(self, key):
        """Create, store and return the default value for ``key``."""
        if self.default_factory is None:
            raise KeyError(key)
        self[key] = value = self.default_factory()
        return value

    def __reduce__(self):
        """Pickle as ``type(self)(default_factory)`` followed by the items."""
        if self.default_factory is None:
            args = tuple()
        else:
            args = self.default_factory,
        return type(self), args, None, None, iter(self.items())

    def copy(self):
        """Return a shallow copy that keeps the same ``default_factory``."""
        return self.__copy__()

    def __copy__(self):
        return type(self)(self.default_factory, self)

    def __deepcopy__(self, memo):
        # Pass ``memo`` on so shared or recursive values are copied only once.
        return type(self)(self.default_factory,
                          copy.deepcopy(list(self.items()), memo))

    def __repr__(self):
        # OrderedDict.__repr__ renders "<ClassName>(<items>)". Strip that
        # wrapper using the real class-name length so subclasses, whose names
        # are not 18 characters long, still get a well-formed repr.
        wrapper_len = len(type(self).__name__) + 1
        return 'DefaultOrderedDict(%s, %s)' % (self.default_factory,
                                               OrderedDict.__repr__(self)[wrapper_len:-1])
class PairSet(object):
    """Symmetric set of ``(a, b)`` pairs, each tagged with an exclusivity flag.

    Pairs are stored in a two-level dictionary, ``_data[a][b] -> flag``, and
    every pair is recorded in both directions.
    """
    __slots__ = '_data',

    def __init__(self):
        self._data = {}

    def __contains__(self, item):
        """Support ``(a, b, are_mutually_exclusive) in pair_set``."""
        return self.has(item[0], item[1], item[2])

    def __str__(self):
        return str(self._data)

    def __repr__(self):
        return str(self._data)

    def has(self, a, b, are_mutually_exclusive):
        """Report whether the pair was added with a compatible flag."""
        inner = self._data.get(a)
        stored_flag = inner.get(b) if inner else inner
        if stored_flag is None:
            return False

        # are_mutually_exclusive being false is a superset of being true,
        # hence a query made without exclusivity is only satisfied by a pair
        # that was itself added without exclusivity.
        return True if are_mutually_exclusive else not stored_flag

    def add(self, a, b, are_mutually_exclusive):
        """Record the pair in both directions and return ``self``."""
        for left, right in ((a, b), (b, a)):
            _pair_set_add(self._data, left, right, are_mutually_exclusive)
        return self


def _pair_set_add(data, a, b, are_mutually_exclusive):
    """Store ``data[a][b] = are_mutually_exclusive``, creating ``data[a]`` if needed."""
    bucket = data.get(a)
    if bucket:
        bucket[b] = are_mutually_exclusive
    else:
        data[a] = {b: are_mutually_exclusive}
def get_git_changeset():
    """Return a numeric identifier of the latest git changeset.

    The result is the UTC timestamp of the changeset in YYYYMMDDHHMMSS format.
    This value isn't guaranteed to be unique, but collisions are very unlikely,
    so it's sufficient for generating the development version numbers.

    Returns ``None`` when git cannot be run or prints no usable timestamp
    (for example when the package is not inside a git checkout).
    """
    repo_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    try:
        # Argument list, no shell: nothing is interpreted by a command shell.
        completed = subprocess.run(
            ['git', 'log', '--pretty=format:%ct', '--quiet', '-1', 'HEAD'],
            stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            cwd=repo_dir, universal_newlines=True,
        )
        # Timezone-aware conversion; datetime.utcfromtimestamp is deprecated.
        changeset_time = datetime.datetime.fromtimestamp(
            int(completed.stdout), datetime.timezone.utc
        )
    except (OSError, subprocess.SubprocessError, ValueError, OverflowError):
        # git missing or failing to start, empty or non-numeric output, or a
        # timestamp outside the supported range: no changeset id is available.
        return None
    return changeset_time.strftime('%Y%m%d%H%M%S')
is_composite_type, + is_input_type, + is_leaf_type, + is_type, + get_nullable_type, + is_output_type +) +from .directives import ( + # "Enum" of Directive locations + DirectiveLocation, + + # Directive definition + GraphQLDirective, + + # Built-in directives defined by the Spec + specified_directives, + GraphQLSkipDirective, + GraphQLIncludeDirective, + GraphQLDeprecatedDirective, + + # Constant Deprecation Reason + DEFAULT_DEPRECATION_REASON, +) +from .scalars import ( # no import order + GraphQLInt, + GraphQLFloat, + GraphQLString, + GraphQLBoolean, + GraphQLID, +) +from .schema import GraphQLSchema + +from .introspection import ( + # "Enum" of Type Kinds + TypeKind, + + # GraphQL Types for introspection. + __Schema, + __Directive, + __DirectiveLocation, + __Type, + __Field, + __InputValue, + __EnumValue, + __TypeKind, + + # Meta-field definitions. + SchemaMetaFieldDef, + TypeMetaFieldDef, + TypeNameMetaFieldDef +) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf0db08a3b4459bd55287c12eee028f79679ef99 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/definition.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/definition.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b577b22f9fe759907cea292bf254e425ca95de6 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/definition.cpython-313.pyc differ diff --git 
a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/directives.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/directives.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1eb804b8c18e191553b57cf88f4b9c93faa4398a Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/directives.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/introspection.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/introspection.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0ab0299689f1bd04d673726091fa6b4a345610f Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/introspection.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/scalars.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/scalars.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a84751ce9e7738e8a3cfeff977dbe5acdb6c572 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/scalars.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/schema.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/schema.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1bb473cc53df979bb16137d6b42571425653d8b Binary files /dev/null and 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/schema.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/typemap.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/typemap.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88977563d11de8aec6c8ccdbf2ee003fe28ab3a5 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/__pycache__/typemap.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/definition.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/definition.py new file mode 100644 index 0000000000000000000000000000000000000000..1d0c4b4c182372fdf5b411dd4f1337c796ec6d14 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/definition.py @@ -0,0 +1,619 @@ +from collections.abc import Mapping, Hashable +import collections +import copy + +from ..language import ast +from ..pyutils.cached_property import cached_property +from ..pyutils.ordereddict import OrderedDict +from ..utils.assert_valid_name import assert_valid_name + + +def is_type(type): + return isinstance(type, ( + GraphQLScalarType, + GraphQLObjectType, + GraphQLInterfaceType, + GraphQLUnionType, + GraphQLEnumType, + GraphQLInputObjectType, + GraphQLList, + GraphQLNonNull + )) + + +def is_input_type(type): + named_type = get_named_type(type) + return isinstance(named_type, ( + GraphQLScalarType, + GraphQLEnumType, + GraphQLInputObjectType, + )) + + +def is_output_type(type): + named_type = get_named_type(type) + return isinstance(named_type, ( + GraphQLScalarType, + GraphQLObjectType, + GraphQLInterfaceType, + GraphQLUnionType, + GraphQLEnumType + )) + + +def 
def is_leaf_type(type):
    """Return True when ``type`` itself is a scalar or enum type."""
    leaf_kinds = (GraphQLScalarType, GraphQLEnumType)
    return isinstance(type, leaf_kinds)


def is_composite_type(type):
    """Return True when the named type under ``type`` is an object, interface or union."""
    composite_kinds = (GraphQLObjectType, GraphQLInterfaceType, GraphQLUnionType)
    return isinstance(get_named_type(type), composite_kinds)


def is_abstract_type(type):
    """Return True when ``type`` itself is an interface or union type."""
    abstract_kinds = (GraphQLInterfaceType, GraphQLUnionType)
    return isinstance(type, abstract_kinds)


def get_nullable_type(type):
    """Strip one ``GraphQLNonNull`` wrapper, if present."""
    return type.of_type if isinstance(type, GraphQLNonNull) else type


def get_named_type(type):
    """Strip every ``GraphQLList`` / ``GraphQLNonNull`` wrapper and return what is left."""
    wrappers = (GraphQLList, GraphQLNonNull)
    inner = type
    while isinstance(inner, wrappers):
        inner = inner.of_type
    return inner


class GraphQLType(object):
    """Common base of the type classes; rendered as its ``name``."""
    __slots__ = 'name',

    def __str__(self):
        return self.name

    def is_same_type(self, other):
        """Return True when ``other`` has exactly the same class and the same name."""
        if self.__class__ is not other.__class__:
            return False
        return self.name == other.name


def none_func(x):
    """Ignore ``x`` and return ``None``; stands in for a missing parse function."""
    return None


class GraphQLScalarType(GraphQLType):
    """Scalar Type Definition

    A scalar is defined by a name and a set of coercion functions.
    ``serialize`` is mandatory; ``parse_value`` and ``parse_literal`` must be
    given together or not at all, and default to a function returning
    ``None``.

    Example:

        def coerce_odd(value):
            if value % 2 == 1:
                return value
            return None

        OddType = GraphQLScalarType(name='Odd', serialize=coerce_odd)
    """

    __slots__ = 'name', 'description', 'serialize', 'parse_value', 'parse_literal'

    def __init__(self, name, description=None, serialize=None, parse_value=None, parse_literal=None):
        assert name, 'Type must be named.'
        assert_valid_name(name)
        # The name is set before the remaining checks because their messages
        # format ``self``, which renders as the name.
        self.name = name
        self.description = description

        assert callable(serialize), (
            '{} must provide "serialize" function. If this custom Scalar is '
            'also used as an input type, ensure "parse_value" and "parse_literal" '
            'functions are also provided.'
        ).format(self)

        parsers_supplied = not (parse_value is None and parse_literal is None)
        if parsers_supplied:
            assert callable(parse_value) and callable(parse_literal), (
                '{} must provide both "parse_value" and "parse_literal" functions.'.format(self)
            )

        self.serialize = serialize
        self.parse_value = parse_value if parse_value else none_func
        self.parse_literal = parse_literal if parse_literal else none_func

    def __str__(self):
        return self.name
def define_field_map(type, field_map):
    """Resolve and validate the field mapping of ``type``.

    ``field_map`` may be a mapping or a callable returning one. The mapping
    must be non-empty. Every field name, and every argument name of a field
    that exposes a truthy ``args`` attribute, is passed to
    ``assert_valid_name``.

    Returns a new ``OrderedDict`` built from the resolved mapping.
    """
    resolved = field_map() if callable(field_map) else field_map

    assert isinstance(resolved, Mapping) and len(resolved) > 0, (
        '{} fields must be a mapping (dict / OrderedDict) with field names as keys or a '
        'function which returns such a mapping.'
    ).format(type)

    for field_name, field in resolved.items():
        assert_valid_name(field_name)

        field_args = getattr(field, 'args', None)
        if not field_args:
            continue

        assert isinstance(field_args, Mapping), (
            '{}.{} args must be a mapping (dict / OrderedDict) with argument names as keys.'.format(type,
                                                                                                   field_name)
        )
        for arg_name, _unused_arg in field_args.items():
            assert_valid_name(arg_name)

    return OrderedDict(resolved)
class GraphQLField(object):
    """A field of an object or interface type: its type, arguments and resolver."""
    __slots__ = 'type', 'args', 'resolver', 'deprecation_reason', 'description'

    def __init__(self, type, args=None, resolver=None, deprecation_reason=None, description=None):
        self.type = type
        # A falsy ``args`` (None or an empty mapping) becomes a fresh OrderedDict.
        self.args = args if args else OrderedDict()
        self.resolver = resolver
        self.deprecation_reason = deprecation_reason
        self.description = description

    def __eq__(self, other):
        """Fields are equal when identical or when all five attributes match."""
        if self is other:
            return True
        if not isinstance(other, GraphQLField):
            return False
        return (
            self.type == other.type and
            self.args == other.args and
            self.resolver == other.resolver and
            self.deprecation_reason == other.deprecation_reason and
            self.description == other.description
        )

    def __hash__(self):
        # Identity-based: two fields that compare equal may still hash differently.
        return id(self)


class GraphQLArgument(object):
    """An argument accepted by a field: its type, default value and output name."""
    __slots__ = 'type', 'default_value', 'description', 'out_name'

    def __init__(self, type, default_value=None, description=None, out_name=None):
        self.type = type
        self.default_value = default_value
        self.description = description
        self.out_name = out_name

    def __eq__(self, other):
        """Arguments are equal when identical or when all four attributes match."""
        if self is other:
            return True
        if not isinstance(other, GraphQLArgument):
            return False
        return (
            self.type == other.type and
            self.default_value == other.default_value and
            self.description == other.description and
            self.out_name == other.out_name
        )

    def __hash__(self):
        # Identity-based: two arguments that compare equal may still hash differently.
        return id(self)
class GraphQLUnionType(GraphQLType):
    """Union Type Definition

    A union lists the object types a field may return and may carry a
    ``resolve_type`` function. ``types`` may be a list/tuple or a callable
    returning one; it is validated by ``define_types`` when the ``types``
    property is read.

    Example:

        PetType = GraphQLUnionType(
            name='Pet',
            types=[DogType, CatType],
            resolve_type=resolve_pet_type)
    """

    def __init__(self, name, types=None, resolve_type=None, description=None):
        assert name, 'Type must be named.'
        assert_valid_name(name)
        self.name = name
        self.description = description

        if resolve_type is not None:
            assert callable(resolve_type), '{} must provide "resolve_type" as a function.'.format(self)

        self.resolve_type = resolve_type
        self._types = types

    @cached_property
    def types(self):
        """The validated member types of this union."""
        return define_types(self, self._types)


def define_types(union_type, types):
    """Resolve and validate the member types of ``union_type``.

    ``types`` may be a list/tuple or a callable returning one, and must not be
    empty. Every member must be a ``GraphQLObjectType``; when the union has no
    callable ``resolve_type``, every member must provide a callable
    ``is_type_of``.

    Returns the resolved list/tuple unchanged.
    """
    members = types() if callable(types) else types

    assert isinstance(members, (list, tuple)) and len(
        members) > 0, 'Must provide types for Union {}.'.format(union_type.name)

    # Members must identify themselves when the union cannot resolve types.
    members_must_self_identify = not callable(union_type.resolve_type)

    for member in members:
        assert isinstance(member, GraphQLObjectType), (
            '{} may only contain Object types, it cannot contain: {}.'.format(union_type, member)
        )

        if members_must_self_identify:
            assert callable(member.is_type_of), (
                'Union Type {} does not provide a "resolve_type" function '
                'and possible Type {} does not provide a "is_type_of" '
                'function. There is no way to resolve this possible type '
                'during execution.'
            ).format(union_type, member)

    return members
class GraphQLEnumValue(object):
    """One value of an enum type.

    ``name`` is the public name, ``value`` the internal value; both default
    to ``None``. Instances compare equal when all four attributes match.
    """
    __slots__ = 'name', 'value', 'deprecation_reason', 'description'

    def __init__(self, value=None, deprecation_reason=None, description=None, name=None):
        self.name = name
        self.value = value
        self.deprecation_reason = deprecation_reason
        self.description = description

    def __eq__(self, other):
        return (
            self is other or (
                isinstance(other, GraphQLEnumValue) and
                self.name == other.name and
                self.value == other.value and
                self.deprecation_reason == other.deprecation_reason and
                self.description == other.description
            )
        )

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable on Python 3.
        # The hash uses a subset of the attributes compared by __eq__, so
        # equal values always hash equally; ``value`` is left out because it
        # may be an unhashable object.
        return hash((self.name, self.deprecation_reason, self.description))
+ ).format(self) + if not isinstance(fields, (collections.OrderedDict, OrderedDict)): + fields = OrderedDict(sorted(list(fields.items()))) + + for field_name, field in fields.items(): + assert_valid_name(field_name) + + return fields + + +class GraphQLInputObjectField(object): + __slots__ = 'type', 'default_value', 'description', 'out_name' + + def __init__(self, type, default_value=None, description=None, out_name=None): + self.type = type + self.default_value = default_value + self.description = description + self.out_name = out_name + + def __eq__(self, other): + return ( + self is other or ( + isinstance(other, GraphQLInputObjectField) and + self.type == other.type and + self.description == other.description and + self.out_name == other.out_name + ) + ) + + +class GraphQLList(GraphQLType): + """List Modifier + + A list is a kind of type marker, a wrapping type which points to another + type. Lists are often created within the context of defining the fields + of an object type. + + Example: + + class PersonType(GraphQLObjectType): + name = 'Person' + + def get_fields(self): + return { + 'parents': GraphQLField(GraphQLList(PersonType())), + 'children': GraphQLField(GraphQLList(PersonType())), + } + """ + __slots__ = 'of_type', + + def __init__(self, type): + assert is_type(type), 'Can only create List of a GraphQLType but got: {}.'.format(type) + self.of_type = type + + def __str__(self): + return '[' + str(self.of_type) + ']' + + def is_same_type(self, other): + return isinstance(other, GraphQLList) and self.of_type.is_same_type(other.of_type) + + +class GraphQLNonNull(GraphQLType): + """Non-Null Modifier + + A non-null is a kind of type marker, a wrapping type which points to another type. Non-null types enforce that their values are never null + and can ensure an error is raised if this ever occurs during a request. 
It is useful for fields which you can make a strong guarantee on + non-nullability, for example usually the id field of a database row will never be null. + + Example: + + class RowType(GraphQLObjectType): + name = 'Row' + fields = { + 'id': GraphQLField(type=GraphQLNonNull(GraphQLString())) + } + + Note: the enforcement of non-nullability occurs within the executor. + """ + __slots__ = 'of_type', + + def __init__(self, type): + assert is_type(type) and not isinstance(type, GraphQLNonNull), ( + 'Can only create NonNull of a Nullable GraphQLType but got: {}.'.format(type) + ) + self.of_type = type + + def __str__(self): + return str(self.of_type) + '!' + + def is_same_type(self, other): + return isinstance(other, GraphQLNonNull) and self.of_type.is_same_type(other.of_type) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/directives.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/directives.py new file mode 100644 index 0000000000000000000000000000000000000000..c566b4d5d8ebdc43a12ae4648f01074eab0b1ae8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/directives.py @@ -0,0 +1,132 @@ +from collections.abc import Iterable, Mapping + +from ..pyutils.ordereddict import OrderedDict +from ..utils.assert_valid_name import assert_valid_name +from .definition import GraphQLArgument, GraphQLNonNull, is_input_type +from .scalars import GraphQLBoolean, GraphQLString + + +class DirectiveLocation(object): + # Operations + QUERY = 'QUERY' + MUTATION = 'MUTATION' + SUBSCRIPTION = 'SUBSCRIPTION' + FIELD = 'FIELD' + FRAGMENT_DEFINITION = 'FRAGMENT_DEFINITION' + FRAGMENT_SPREAD = 'FRAGMENT_SPREAD' + INLINE_FRAGMENT = 'INLINE_FRAGMENT' + + # Schema Definitions + SCHEMA = 'SCHEMA' + SCALAR = 'SCALAR' + OBJECT = 'OBJECT' + FIELD_DEFINITION = 'FIELD_DEFINITION' + ARGUMENT_DEFINITION = 'ARGUMENT_DEFINITION' + INTERFACE = 'INTERFACE' + UNION = 'UNION' 
+ ENUM = 'ENUM' + ENUM_VALUE = 'ENUM_VALUE' + INPUT_OBJECT = 'INPUT_OBJECT' + INPUT_FIELD_DEFINITION = 'INPUT_FIELD_DEFINITION' + + OPERATION_LOCATIONS = [ + QUERY, + MUTATION, + SUBSCRIPTION + ] + + FRAGMENT_LOCATIONS = [ + FRAGMENT_DEFINITION, + FRAGMENT_SPREAD, + INLINE_FRAGMENT + ] + + FIELD_LOCATIONS = [ + FIELD + ] + + +class GraphQLDirective(object): + __slots__ = 'name', 'args', 'description', 'locations' + + def __init__(self, name, description=None, args=None, locations=None): + assert name, 'Directive must be named.' + assert_valid_name(name) + assert isinstance(locations, Iterable), 'Must provide locations for directive.' + + self.name = name + self.description = description + self.locations = locations + + if args: + assert isinstance(args, Mapping), '{} args must be a dict with argument names as keys.'.format(name) + for arg_name, _arg in args.items(): + assert_valid_name(arg_name) + assert is_input_type(_arg.type), '{}({}) argument type must be Input Type but got {}.'.format( + name, + arg_name, + _arg.type) + self.args = args or OrderedDict() + + +"""Used to conditionally include fields or fragments.""" +GraphQLIncludeDirective = GraphQLDirective( + name='include', + description='Directs the executor to include this field or fragment only when the `if` argument is true.', + args={ + 'if': GraphQLArgument( + type=GraphQLNonNull(GraphQLBoolean), + description='Included when true.', + ), + }, + locations=[ + DirectiveLocation.FIELD, + DirectiveLocation.FRAGMENT_SPREAD, + DirectiveLocation.INLINE_FRAGMENT, + ] +) + +"""Used to conditionally skip (exclude) fields or fragments.""" +GraphQLSkipDirective = GraphQLDirective( + name='skip', + description='Directs the executor to skip this field or fragment when the `if` argument is true.', + args={ + 'if': GraphQLArgument( + type=GraphQLNonNull(GraphQLBoolean), + description='Skipped when true.', + ), + }, + locations=[ + DirectiveLocation.FIELD, + DirectiveLocation.FRAGMENT_SPREAD, + 
DirectiveLocation.INLINE_FRAGMENT, + ] +) + +"""Constant string used for default reason for a deprecation.""" +DEFAULT_DEPRECATION_REASON = 'No longer supported' + +"""Used to declare element of a GraphQL schema as deprecated.""" +GraphQLDeprecatedDirective = GraphQLDirective( + name='deprecated', + description='Marks an element of a GraphQL schema as no longer supported.', + args={ + 'reason': GraphQLArgument( + type=GraphQLString, + description=('Explains why this element was deprecated, usually also including a suggestion for how to' + 'access supported similar data. Formatted in [Markdown]' + '(https://daringfireball.net/projects/markdown/).'), + default_value=DEFAULT_DEPRECATION_REASON + ), + }, + locations=[ + DirectiveLocation.FIELD_DEFINITION, + DirectiveLocation.ENUM_VALUE, + ] +) + +specified_directives = [ + GraphQLIncludeDirective, + GraphQLSkipDirective, + GraphQLDeprecatedDirective +] diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/introspection.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/introspection.py new file mode 100644 index 0000000000000000000000000000000000000000..b2732fc82b0888598e82b06b91ecebc2e16980df --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/introspection.py @@ -0,0 +1,440 @@ +from collections import OrderedDict, namedtuple + +from ..language.printer import print_ast +from ..utils.ast_from_value import ast_from_value +from .definition import (GraphQLArgument, GraphQLEnumType, GraphQLEnumValue, + GraphQLField, GraphQLInputObjectType, + GraphQLInterfaceType, GraphQLList, GraphQLNonNull, + GraphQLObjectType, GraphQLScalarType, + GraphQLUnionType) +from .directives import DirectiveLocation +from .scalars import GraphQLBoolean, GraphQLString + +InputField = namedtuple('InputField', ['name', 'description', 'type', 'default_value']) +Field = namedtuple('Field', ['name', 'type', 'description', 
'args', 'deprecation_reason']) + + +def input_fields_to_list(input_fields): + fields = [] + for field_name, field in input_fields.items(): + fields.append(InputField( + name=field_name, + description=field.description, + type=field.type, + default_value=field.default_value)) + return fields + + +__Schema = GraphQLObjectType( + '__Schema', + description='A GraphQL Schema defines the capabilities of a GraphQL server. It ' + 'exposes all available types and directives on the server, as well as ' + 'the entry points for query, mutation and subscription operations.', + fields=lambda: OrderedDict([ + ('types', GraphQLField( + description='A list of all types supported by this server.', + type=GraphQLNonNull(GraphQLList(GraphQLNonNull(__Type))), + resolver=lambda schema, *_: schema.get_type_map().values(), + )), + ('queryType', GraphQLField( + description='The type that query operations will be rooted at.', + type=GraphQLNonNull(__Type), + resolver=lambda schema, *_: schema.get_query_type(), + )), + ('mutationType', GraphQLField( + description='If this server supports mutation, the type that ' + 'mutation operations will be rooted at.', + type=__Type, + resolver=lambda schema, *_: schema.get_mutation_type(), + )), + ('subscriptionType', GraphQLField( + description='If this server support subscription, the type ' + 'that subscription operations will be rooted at.', + type=__Type, + resolver=lambda schema, *_: schema.get_subscription_type(), + )), + ('directives', GraphQLField( + description='A list of all directives supported by this server.', + type=GraphQLNonNull(GraphQLList(GraphQLNonNull(__Directive))), + resolver=lambda schema, *_: schema.get_directives(), + )), + ])) + +_on_operation_locations = set(DirectiveLocation.OPERATION_LOCATIONS) +_on_fragment_locations = set(DirectiveLocation.FRAGMENT_LOCATIONS) +_on_field_locations = set(DirectiveLocation.FIELD_LOCATIONS) + +__Directive = GraphQLObjectType( + '__Directive', + description='A Directive provides a way to 
describe alternate runtime execution and ' + 'type validation behavior in a GraphQL document.' + '\n\nIn some cases, you need to provide options to alter GraphQL\'s ' + 'execution behavior in ways field arguments will not suffice, such as ' + 'conditionally including or skipping a field. Directives provide this by ' + 'describing additional information to the executor.', + fields=lambda: OrderedDict([ + ('name', GraphQLField(GraphQLNonNull(GraphQLString))), + ('description', GraphQLField(GraphQLString)), + ('locations', GraphQLField( + type=GraphQLNonNull(GraphQLList(GraphQLNonNull(__DirectiveLocation))), + )), + ('args', GraphQLField( + type=GraphQLNonNull(GraphQLList(GraphQLNonNull(__InputValue))), + resolver=lambda directive, *args: input_fields_to_list(directive.args), + )), + ('onOperation', GraphQLField( + type=GraphQLNonNull(GraphQLBoolean), + deprecation_reason='Use `locations`.', + resolver=lambda directive, *args: set(directive.locations) & _on_operation_locations, + )), + ('onFragment', GraphQLField( + type=GraphQLNonNull(GraphQLBoolean), + deprecation_reason='Use `locations`.', + resolver=lambda directive, *args: set(directive.locations) & _on_fragment_locations, + )), + ('onField', GraphQLField( + type=GraphQLNonNull(GraphQLBoolean), + deprecation_reason='Use `locations`.', + resolver=lambda directive, *args: set(directive.locations) & _on_field_locations, + )) + ])) + +__DirectiveLocation = GraphQLEnumType( + '__DirectiveLocation', + description=( + 'A Directive can be adjacent to many parts of the GraphQL language, a ' + + '__DirectiveLocation describes one such possible adjacencies.' + ), + values=OrderedDict([ + ('QUERY', GraphQLEnumValue( + DirectiveLocation.QUERY, + description='Location adjacent to a query operation.' + )), + ('MUTATION', GraphQLEnumValue( + DirectiveLocation.MUTATION, + description='Location adjacent to a mutation operation.' 
+ )), + ('SUBSCRIPTION', GraphQLEnumValue( + DirectiveLocation.SUBSCRIPTION, + description='Location adjacent to a subscription operation.' + )), + ('FIELD', GraphQLEnumValue( + DirectiveLocation.FIELD, + description='Location adjacent to a field.' + )), + ('FRAGMENT_DEFINITION', GraphQLEnumValue( + DirectiveLocation.FRAGMENT_DEFINITION, + description='Location adjacent to a fragment definition.' + )), + ('FRAGMENT_SPREAD', GraphQLEnumValue( + DirectiveLocation.FRAGMENT_SPREAD, + description='Location adjacent to a fragment spread.' + )), + ('INLINE_FRAGMENT', GraphQLEnumValue( + DirectiveLocation.INLINE_FRAGMENT, + description='Location adjacent to an inline fragment.' + )), + ('SCHEMA', GraphQLEnumValue( + DirectiveLocation.SCHEMA, + description='Location adjacent to a schema definition.' + )), + ('SCALAR', GraphQLEnumValue( + DirectiveLocation.SCALAR, + description='Location adjacent to a scalar definition.' + )), + ('OBJECT', GraphQLEnumValue( + DirectiveLocation.OBJECT, + description='Location adjacent to an object definition.' + )), + ('FIELD_DEFINITION', GraphQLEnumValue( + DirectiveLocation.FIELD_DEFINITION, + description='Location adjacent to a field definition.' + )), + ('ARGUMENT_DEFINITION', GraphQLEnumValue( + DirectiveLocation.ARGUMENT_DEFINITION, + description='Location adjacent to an argument definition.' + )), + ('INTERFACE', GraphQLEnumValue( + DirectiveLocation.INTERFACE, + description='Location adjacent to an interface definition.' + )), + ('UNION', GraphQLEnumValue( + DirectiveLocation.UNION, + description='Location adjacent to a union definition.' + )), + ('ENUM', GraphQLEnumValue( + DirectiveLocation.ENUM, + description='Location adjacent to an enum definition.' + )), + ('ENUM_VALUE', GraphQLEnumValue( + DirectiveLocation.ENUM_VALUE, + description='Location adjacent to an enum value definition.' + )), + ('INPUT_OBJECT', GraphQLEnumValue( + DirectiveLocation.INPUT_OBJECT, + description='Location adjacent to an input object definition.' 
+ )), + ('INPUT_FIELD_DEFINITION', GraphQLEnumValue( + DirectiveLocation.INPUT_FIELD_DEFINITION, + description='Location adjacent to an input object field definition.' + )), + ])) + + +class TypeKind(object): + SCALAR = 'SCALAR' + OBJECT = 'OBJECT' + INTERFACE = 'INTERFACE' + UNION = 'UNION' + ENUM = 'ENUM' + INPUT_OBJECT = 'INPUT_OBJECT' + LIST = 'LIST' + NON_NULL = 'NON_NULL' + + +class TypeFieldResolvers(object): + _kinds = ( + (GraphQLScalarType, TypeKind.SCALAR), + (GraphQLObjectType, TypeKind.OBJECT), + (GraphQLInterfaceType, TypeKind.INTERFACE), + (GraphQLUnionType, TypeKind.UNION), + (GraphQLEnumType, TypeKind.ENUM), + (GraphQLInputObjectType, TypeKind.INPUT_OBJECT), + (GraphQLList, TypeKind.LIST), + (GraphQLNonNull, TypeKind.NON_NULL), + ) + + @classmethod + def kind(cls, type, *_): + for klass, kind in cls._kinds: + if isinstance(type, klass): + return kind + + raise Exception('Unknown kind of type: {}'.format(type)) + + @staticmethod + def fields(type, args, *_): + if isinstance(type, (GraphQLObjectType, GraphQLInterfaceType)): + fields = [] + include_deprecated = args.get('includeDeprecated') + for field_name, field in type.fields.items(): + if field.deprecation_reason and not include_deprecated: + continue + fields.append(Field( + name=field_name, + description=field.description, + type=field.type, + args=field.args, + deprecation_reason=field.deprecation_reason + )) + return fields + return None + + @staticmethod + def interfaces(type, *_): + if isinstance(type, GraphQLObjectType): + return type.interfaces + + @staticmethod + def possible_types(type, args, context, info): + if isinstance(type, (GraphQLInterfaceType, GraphQLUnionType)): + return info.schema.get_possible_types(type) + + @staticmethod + def enum_values(type, args, *_): + if isinstance(type, GraphQLEnumType): + values = type.values + if not args.get('includeDeprecated'): + values = [v for v in values if not v.deprecation_reason] + + return values + + @staticmethod + def input_fields(type, 
*_): + if isinstance(type, GraphQLInputObjectType): + return input_fields_to_list(type.fields) + + +__Type = GraphQLObjectType( + '__Type', + description='The fundamental unit of any GraphQL Schema is the type. There are ' + 'many kinds of types in GraphQL as represented by the `__TypeKind` enum.' + '\n\nDepending on the kind of a type, certain fields describe ' + 'information about that type. Scalar types provide no information ' + 'beyond a name and description, while Enum types provide their values. ' + 'Object and Interface types provide the fields they describe. Abstract ' + 'types, Union and Interface, provide the Object types possible ' + 'at runtime. List and NonNull types compose other types.', + fields=lambda: OrderedDict([ + ('kind', GraphQLField( + type=GraphQLNonNull(__TypeKind), + resolver=TypeFieldResolvers.kind + )), + ('name', GraphQLField(GraphQLString)), + ('description', GraphQLField(GraphQLString)), + ('fields', GraphQLField( + type=GraphQLList(GraphQLNonNull(__Field)), + args={ + 'includeDeprecated': GraphQLArgument( + GraphQLBoolean, + default_value=False + ) + }, + resolver=TypeFieldResolvers.fields + )), + ('interfaces', GraphQLField( + type=GraphQLList(GraphQLNonNull(__Type)), + resolver=TypeFieldResolvers.interfaces + )), + ('possibleTypes', GraphQLField( + type=GraphQLList(GraphQLNonNull(__Type)), + resolver=TypeFieldResolvers.possible_types + )), + ('enumValues', GraphQLField( + type=GraphQLList(GraphQLNonNull(__EnumValue)), + args={ + 'includeDeprecated': GraphQLArgument( + GraphQLBoolean, + default_value=False + ) + }, + resolver=TypeFieldResolvers.enum_values + )), + ('inputFields', GraphQLField( + type=GraphQLList(GraphQLNonNull(__InputValue)), + resolver=TypeFieldResolvers.input_fields + )), + ('ofType', GraphQLField( + type=__Type, + resolver=lambda type, *_: getattr(type, 'of_type', None) + )), + ])) + +__Field = GraphQLObjectType( + '__Field', + description='Object and Interface types are described by a list of Fields, each of ' 
+ 'which has a name, potentially a list of arguments, and a return type.', + fields=lambda: OrderedDict([ + ('name', GraphQLField(GraphQLNonNull(GraphQLString))), + ('description', GraphQLField(GraphQLString)), + ('args', GraphQLField( + type=GraphQLNonNull(GraphQLList(GraphQLNonNull(__InputValue))), + resolver=lambda field, *_: input_fields_to_list(field.args) + )), + ('type', GraphQLField(GraphQLNonNull(__Type))), + ('isDeprecated', GraphQLField( + type=GraphQLNonNull(GraphQLBoolean), + resolver=lambda field, *_: bool(field.deprecation_reason) + )), + ('deprecationReason', GraphQLField( + type=GraphQLString, + resolver=lambda field, *_: field.deprecation_reason + )) + ]) +) + +__InputValue = GraphQLObjectType( + '__InputValue', + description='Arguments provided to Fields or Directives and the input fields of an ' + 'InputObject are represented as Input Values which describe their type ' + 'and optionally a default value.', + fields=lambda: OrderedDict([ + ('name', GraphQLField(GraphQLNonNull(GraphQLString))), + ('description', GraphQLField(GraphQLString)), + ('type', GraphQLField(GraphQLNonNull(__Type))), + ('defaultValue', GraphQLField( + type=GraphQLString, + resolver=lambda input_val, *_: + None if input_val.default_value is None + else print_ast(ast_from_value(input_val.default_value, input_val)) + )) + ])) + +__EnumValue = GraphQLObjectType( + '__EnumValue', + description='One possible value for a given Enum. Enum values are unique values, not ' + 'a placeholder for a string or numeric value. 
However an Enum value is ' + 'returned in a JSON response as a string.', + fields=lambda: OrderedDict([ + ('name', GraphQLField(GraphQLNonNull(GraphQLString))), + ('description', GraphQLField(GraphQLString)), + ('isDeprecated', GraphQLField( + type=GraphQLNonNull(GraphQLBoolean), + resolver=lambda field, *_: bool(field.deprecation_reason) + )), + ('deprecationReason', GraphQLField( + type=GraphQLString, + resolver=lambda enum_value, *_: enum_value.deprecation_reason, + )) + ])) + +__TypeKind = GraphQLEnumType( + '__TypeKind', + description='An enum describing what kind of type a given `__Type` is', + values=OrderedDict([ + ('SCALAR', GraphQLEnumValue( + TypeKind.SCALAR, + description='Indicates this type is a scalar.' + )), + ('OBJECT', GraphQLEnumValue( + TypeKind.OBJECT, + description='Indicates this type is an object. ' + '`fields` and `interfaces` are valid fields.' + )), + ('INTERFACE', GraphQLEnumValue( + TypeKind.INTERFACE, + description='Indicates this type is an interface. ' + '`fields` and `possibleTypes` are valid fields.' + )), + ('UNION', GraphQLEnumValue( + TypeKind.UNION, + description='Indicates this type is a union. ' + '`possibleTypes` is a valid field.' + )), + ('ENUM', GraphQLEnumValue( + TypeKind.ENUM, + description='Indicates this type is an enum. ' + '`enumValues` is a valid field.' + )), + ('INPUT_OBJECT', GraphQLEnumValue( + TypeKind.INPUT_OBJECT, + description='Indicates this type is an input object. ' + '`inputFields` is a valid field.' + )), + ('LIST', GraphQLEnumValue( + TypeKind.LIST, + description='Indicates this type is a list. ' + '`ofType` is a valid field.' + )), + ('NON_NULL', GraphQLEnumValue( + TypeKind.NON_NULL, + description='Indicates this type is a non-null. ' + '`ofType` is a valid field.' 
+ )), + ])) + +IntrospectionSchema = __Schema + +SchemaMetaFieldDef = GraphQLField( + # name='__schema', + type=GraphQLNonNull(__Schema), + description='Access the current type schema of this server.', + resolver=lambda source, args, context, info: info.schema, + args={} +) + +TypeMetaFieldDef = GraphQLField( + type=__Type, + # name='__type', + description='Request the type information of a single type.', + args={'name': GraphQLArgument(GraphQLNonNull(GraphQLString))}, + resolver=lambda source, args, context, info: info.schema.get_type(args['name']) +) + +TypeNameMetaFieldDef = GraphQLField( + type=GraphQLNonNull(GraphQLString), + # name='__typename', + description='The name of the current Object type at runtime.', + resolver=lambda source, args, context, info: info.parent_type.name, + args={} +) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/scalars.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/scalars.py new file mode 100644 index 0000000000000000000000000000000000000000..62955b53d3f89cda049a07ff2b3b269b290aaff2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/scalars.py @@ -0,0 +1,131 @@ +from ..language.ast import BooleanValue, FloatValue, IntValue, StringValue +from .definition import GraphQLScalarType + +# As per the GraphQL Spec, Integers are only treated as valid when a valid +# 32-bit signed integer, providing the broadest support across platforms. +# +# n.b. JavaScript's integers are safe between -(2^53 - 1) and 2^53 - 1 because +# they are internally represented as IEEE 754 doubles. 
+MAX_INT = 2147483647 +MIN_INT = -2147483648 + + +def coerce_int(value): + if isinstance(value, int): + num = value + else: + try: + num = int(value) + except ValueError: + num = int(float(value)) + if MIN_INT <= num <= MAX_INT: + return num + raise Exception(( + "Int cannot represent non 32-bit signed integer value: {}" + ).format(value)) + + +def parse_int_literal(ast): + if isinstance(ast, IntValue): + num = int(ast.value) + if MIN_INT <= num <= MAX_INT: + return num + + +GraphQLInt = GraphQLScalarType( + name='Int', + description='The `Int` scalar type represents non-fractional signed whole numeric ' + 'values. Int can represent values between -(2^53 - 1) and 2^53 - 1 since ' + 'represented in JSON as double-precision floating point numbers specified' + 'by [IEEE 754](http://en.wikipedia.org/wiki/IEEE_floating_point).', + serialize=coerce_int, + parse_value=coerce_int, + parse_literal=parse_int_literal) + + +def coerce_float(value): + if isinstance(value, float): + return value + return float(value) + + +def parse_float_literal(ast): + if isinstance(ast, (FloatValue, IntValue)): + return float(ast.value) + return None + + +GraphQLFloat = GraphQLScalarType( + name='Float', + description='The `Float` scalar type represents signed double-precision fractional ' + 'values as specified by ' + '[IEEE 754](http://en.wikipedia.org/wiki/IEEE_floating_point). 
', + serialize=coerce_float, + parse_value=coerce_float, + parse_literal=parse_float_literal) + + +def coerce_string(value): + if isinstance(value, str): + return value + + if isinstance(value, bool): + return u'true' if value else u'false' + + return str(value) + + +def coerce_str(value): + if isinstance(value, str): + return value + + return str(value) + + +def parse_string_literal(ast): + if isinstance(ast, StringValue): + return ast.value + + return None + + +GraphQLString = GraphQLScalarType( + name='String', + description='The `String` scalar type represents textual data, represented as UTF-8 ' + 'character sequences. The String type is most often used by GraphQL to ' + 'represent free-form human-readable text.', + serialize=coerce_string, + parse_value=coerce_string, + parse_literal=parse_string_literal) + + +def parse_boolean_literal(ast): + if isinstance(ast, BooleanValue): + return ast.value + return None + + +GraphQLBoolean = GraphQLScalarType( + name='Boolean', + description='The `Boolean` scalar type represents `true` or `false`.', + serialize=bool, + parse_value=bool, + parse_literal=parse_boolean_literal) + + +def parse_id_literal(ast): + if isinstance(ast, (StringValue, IntValue)): + return ast.value + return None + + +GraphQLID = GraphQLScalarType( + name='ID', + description='The `ID` scalar type represents a unique identifier, often used to ' + 'refetch an object or as key for a cache. The ID type appears in a JSON ' + 'response as a String; however, it is not intended to be human-readable. 
' + 'When expected as an input type, any string (such as `"4"`) or integer ' + '(such as `4`) input value will be accepted as an ID.', + serialize=coerce_str, + parse_value=coerce_str, + parse_literal=parse_id_literal) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/schema.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..7f29b8fe176d1777edd4ba2a971a659169fcc543 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/schema.py @@ -0,0 +1,100 @@ +from collections.abc import Iterable + +from .definition import GraphQLObjectType +from .directives import GraphQLDirective, specified_directives +from .introspection import IntrospectionSchema +from .typemap import GraphQLTypeMap + + +class GraphQLSchema(object): + """Schema Definition + + A Schema is created by supplying the root types of each type of operation, query and mutation (optional). + A schema definition is then supplied to the validator and executor. + + Example: + + MyAppSchema = GraphQLSchema( + query=MyAppQueryRootType, + mutation=MyAppMutationRootType, + ) + + Note: If an array of `directives` are provided to GraphQLSchema, that will be + the exact list of directives represented and allowed. If `directives` is not + provided then a default set of the specified directives (e.g. @include and + @skip) will be used. If you wish to provide *additional* directives to these + specified directives, you must explicitly declare them. Example: + + MyAppSchema = GraphQLSchema( + ... 
+ directives=specified_directives.extend([MyCustomerDirective]), + ) + """ + __slots__ = '_query', '_mutation', '_subscription', '_type_map', '_directives', '_implementations', '_possible_type_map' + + def __init__(self, query, mutation=None, subscription=None, directives=None, types=None): + assert isinstance(query, GraphQLObjectType), 'Schema query must be Object Type but got: {}.'.format(query) + if mutation: + assert isinstance(mutation, GraphQLObjectType), \ + 'Schema mutation must be Object Type but got: {}.'.format(mutation) + + if subscription: + assert isinstance(subscription, GraphQLObjectType), \ + 'Schema subscription must be Object Type but got: {}.'.format(subscription) + + if types: + assert isinstance(types, Iterable), \ + 'Schema types must be iterable if provided but got: {}.'.format(types) + + self._query = query + self._mutation = mutation + self._subscription = subscription + if directives is None: + directives = specified_directives + + assert all(isinstance(d, GraphQLDirective) for d in directives), \ + 'Schema directives must be List[GraphQLDirective] if provided but got: {}.'.format( + directives + ) + self._directives = directives + + initial_types = [ + query, + mutation, + subscription, + IntrospectionSchema + ] + if types: + initial_types += types + self._type_map = GraphQLTypeMap(initial_types) + + def get_query_type(self): + return self._query + + def get_mutation_type(self): + return self._mutation + + def get_subscription_type(self): + return self._subscription + + def get_type_map(self): + return self._type_map + + def get_type(self, name): + return self._type_map.get(name) + + def get_directives(self): + return self._directives + + def get_directive(self, name): + for directive in self.get_directives(): + if directive.name == name: + return directive + + return None + + def get_possible_types(self, abstract_type): + return self._type_map.get_possible_types(abstract_type) + + def is_possible_type(self, abstract_type, 
possible_type): + return self._type_map.is_possible_type(abstract_type, possible_type) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/typemap.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/typemap.py new file mode 100644 index 0000000000000000000000000000000000000000..12733a661371c44ee63f7edaeca9eb27c04c2a80 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/type/typemap.py @@ -0,0 +1,145 @@ +from collections import OrderedDict, defaultdict +from collections.abc import Sequence +from functools import reduce + +from ..utils.type_comparators import is_equal_type, is_type_sub_type_of +from .definition import (GraphQLArgument, GraphQLField, + GraphQLInputObjectField, GraphQLInputObjectType, + GraphQLInterfaceType, GraphQLList, GraphQLNonNull, + GraphQLObjectType, GraphQLUnionType, is_input_type, + is_output_type) + + +class GraphQLTypeMap(OrderedDict): + + def __init__(self, types): + super(GraphQLTypeMap, self).__init__() + self.update(reduce(self.reducer, types, OrderedDict())) + self._possible_type_map = defaultdict(set) + + # Keep track of all implementations by interface name. + self._implementations = {} + for gql_type in self.values(): + if isinstance(gql_type, GraphQLObjectType): + for interface in gql_type.interfaces: + self._implementations.setdefault(interface.name, []).append(gql_type) + + # Enforce correct interface implementations. 
+ for type in self.values(): + if isinstance(type, GraphQLObjectType): + for interface in type.interfaces: + self.assert_object_implements_interface(self, type, interface) + + def get_possible_types(self, abstract_type): + if isinstance(abstract_type, GraphQLUnionType): + return abstract_type.types + assert isinstance(abstract_type, GraphQLInterfaceType) + return self._implementations.get(abstract_type.name, None) + + def is_possible_type(self, abstract_type, possible_type): + possible_types = self.get_possible_types(abstract_type) + assert isinstance(possible_types, Sequence), ( + 'Could not find possible implementing types for ${} in ' + + 'schema. Check that schema.types is defined and is an array of' + + 'all possible types in the schema.' + ).format(abstract_type) + + if not self._possible_type_map[abstract_type.name]: + self._possible_type_map[abstract_type.name].update([p.name for p in possible_types]) + + return possible_type.name in self._possible_type_map[abstract_type.name] + + @classmethod + def reducer(cls, map, type): + if not type: + return map + + if isinstance(type, GraphQLList) or isinstance(type, GraphQLNonNull): + return cls.reducer(map, type.of_type) + + if type.name in map: + assert map[type.name] == type, ( + 'Schema must contain unique named types but contains multiple types named "{}".' 
+ ).format(type.name) + + return map + + map[type.name] = type + + reduced_map = map + + if isinstance(type, (GraphQLUnionType)): + for t in type.types: + reduced_map = cls.reducer(reduced_map, t) + + if isinstance(type, GraphQLObjectType): + for t in type.interfaces: + reduced_map = cls.reducer(reduced_map, t) + + if isinstance(type, (GraphQLObjectType, GraphQLInterfaceType, GraphQLInputObjectType)): + field_map = type.fields + type_is_input = isinstance(type, GraphQLInputObjectType) + for field_name, field in field_map.items(): + if type_is_input: + assert isinstance(field, GraphQLInputObjectField), ( + '{}.{} must be an instance of GraphQLInputObjectField.'.format(type, field_name) + ) + assert is_input_type(field.type), ( + '{}.{} field type must be Input Type but got: {}.'.format(type, field_name, field.type) + ) + else: + assert isinstance(field, (GraphQLField, GraphQLField)), ( + '{}.{} must be an instance of GraphQLField.'.format(type, field_name) + ) + assert is_output_type(field.type), ( + '{}.{} field type must be Output Type but got: {}.'.format(type, field_name, field.type) + ) + for arg_name, arg in field.args.items(): + assert isinstance(arg, (GraphQLArgument, GraphQLArgument)), ( + '{}.{}({}:) argument must be an instance of GraphQLArgument.'.format(type, field_name, arg_name) + ) + assert is_input_type(arg.type), ( + '{}.{}({}:) argument type must be Input Type but got: {}.'.format(type, field_name, arg_name, + arg.type) + ) + reduced_map = cls.reducer(reduced_map, arg.type) + + reduced_map = cls.reducer(reduced_map, getattr(field, 'type', None)) + + return reduced_map + + @classmethod + def assert_object_implements_interface(cls, schema, object, interface): + object_field_map = object.fields + interface_field_map = interface.fields + + for field_name, interface_field in interface_field_map.items(): + object_field = object_field_map.get(field_name) + + assert object_field, '"{}" expects field "{}" but "{}" does not provide it.'.format( + interface, 
field_name, object + ) + + assert is_type_sub_type_of(schema, object_field.type, interface_field.type), ( + '{}.{} expects type "{}" but {}.{} provides type "{}".' + ).format(interface, field_name, interface_field.type, object, field_name, object_field.type) + + for arg_name, interface_arg in interface_field.args.items(): + object_arg = object_field.args.get(arg_name) + + assert object_arg, ( + '{}.{} expects argument "{}" but {}.{} does not provide it.' + ).format(interface, field_name, arg_name, object, field_name) + + assert is_equal_type(interface_arg.type, object_arg.type), ( + '{}.{}({}:) expects type "{}" but {}.{}({}:) provides type "{}".' + ).format(interface, field_name, arg_name, interface_arg.type, object, field_name, arg_name, object_arg.type) + + for arg_name, object_arg in object_field.args.items(): + interface_arg = interface_field.args.get(arg_name) + if not interface_arg: + assert not isinstance(object_arg.type, GraphQLNonNull), ( + '{}.{}({}:) is of required type ' + '"{}" but is not also provided by the ' + 'interface {}.{}.' 
+ ).format(object, field_name, arg_name, object_arg.type, interface, field_name) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__init__.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..710954b442e229231ea28fd3ff22de78810dc044 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/assert_valid_name.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/assert_valid_name.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..701804830a60160b4760648440af49224914c4e2 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/assert_valid_name.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/ast_from_value.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/ast_from_value.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02b016976fadc0d8ab234a1e765664eb51ec66e5 Binary files /dev/null and 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/ast_from_value.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/base.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/base.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f93cd976ad12f4195230fd124c95534f4614b199 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/base.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/build_ast_schema.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/build_ast_schema.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fe2af095bc84187cc6eb6634fba7b323bca2f9b Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/build_ast_schema.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/build_client_schema.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/build_client_schema.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16b8988ecabad8b2b26484b026e366e189bfb6cf Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/build_client_schema.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/concat_ast.cpython-313.pyc 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/concat_ast.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d664a06fac2e44d2184b27259ee93e04a7cb22c5 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/concat_ast.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/extend_schema.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/extend_schema.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01834c2eb84eaa2094b4768fcbce3f3d20492f29 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/extend_schema.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/get_field_def.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/get_field_def.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11bb34f042320bbcda940c26d29e27590e8f00dc Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/get_field_def.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/get_operation_ast.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/get_operation_ast.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f68d22a9563b8cc91b62aa8678eef6ca6f42fbcf Binary files /dev/null and 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/get_operation_ast.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/introspection_query.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/introspection_query.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e7c5d0a32c5996b1f72095bcb03b7da20f6fe8e Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/introspection_query.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/is_valid_literal_value.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/is_valid_literal_value.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cdecd5a4f3b99ed191e259ada61e74e3cc8f65d Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/is_valid_literal_value.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/is_valid_value.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/is_valid_value.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36afb82126203cc7e7b22609c09d1af93cba0e6b Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/is_valid_value.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/quoted_or_list.cpython-313.pyc 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/quoted_or_list.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab4c1c25d5d301cd669b625aba2258abeabdb3ba Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/quoted_or_list.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/schema_printer.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/schema_printer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2763bc2c607d4cba03a09550b3c8cf434608cbd Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/schema_printer.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/suggestion_list.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/suggestion_list.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5274a6ad00e8ca38ebb78bb25fac93bd45c8f9cd Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/suggestion_list.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/type_comparators.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/type_comparators.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2cf4c88da6a624ca6cb1482f5a0e512819101b2 Binary files /dev/null and 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/type_comparators.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/type_from_ast.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/type_from_ast.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..994b7d45c0e66257ed9a1f635aab06a8e94ad207 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/type_from_ast.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/type_info.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/type_info.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6bc40aa2e79c0bc2f0dcae5234939f752b6bdd8 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/type_info.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/value_from_ast.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/value_from_ast.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb1a2f04591485839a73632216b43baed8841798 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/__pycache__/value_from_ast.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/assert_valid_name.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/assert_valid_name.py new file mode 
100644 index 0000000000000000000000000000000000000000..40afe59622232f4fc82e0ebeba623a138a02f4c6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/assert_valid_name.py @@ -0,0 +1,9 @@ +import re + +NAME_PATTERN = r'^[_a-zA-Z][_a-zA-Z0-9]*$' +COMPILED_NAME_PATTERN = re.compile(NAME_PATTERN) + + +def assert_valid_name(name): + '''Helper to assert that provided names are valid.''' + assert COMPILED_NAME_PATTERN.match(name), 'Names must match /{}/ but "{}" does not.'.format(NAME_PATTERN, name) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_from_value.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_from_value.py new file mode 100644 index 0000000000000000000000000000000000000000..1355ef565184335ffe7f483da8be6234e766e27d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_from_value.py @@ -0,0 +1,65 @@ +import json +import re +import sys + +from ..language import ast +from ..type.definition import (GraphQLEnumType, GraphQLInputObjectType, + GraphQLList, GraphQLNonNull) +from ..type.scalars import GraphQLFloat + + +def ast_from_value(value, type=None): + if isinstance(type, GraphQLNonNull): + return ast_from_value(value, type.of_type) + + if value is None: + return None + + if isinstance(value, list): + item_type = type.of_type if isinstance(type, GraphQLList) else None + return ast.ListValue([ast_from_value(item, item_type) for item in value]) + + elif isinstance(type, GraphQLList): + return ast_from_value(value, type.of_type) + + if isinstance(value, bool): + return ast.BooleanValue(value) + + if isinstance(value, (int, float)): + string_num = str(value) + int_value = int(value) + is_int_value = string_num.isdigit() + + if is_int_value or (int_value == value and value < sys.maxsize): + if type == GraphQLFloat: + return ast.FloatValue(str(float(value))) + + 
return ast.IntValue(str(int(value))) + + return ast.FloatValue(string_num) + + if isinstance(value, str): + if isinstance(type, GraphQLEnumType) and re.match(r'^[_a-zA-Z][_a-zA-Z0-9]*$', value): + return ast.EnumValue(value) + + return ast.StringValue(json.dumps(value)[1:-1]) + + assert isinstance(value, dict) + + fields = [] + is_graph_ql_input_object_type = isinstance(type, GraphQLInputObjectType) + + for field_name, field_value in value.items(): + field_type = None + if is_graph_ql_input_object_type: + field_def = type.fields.get(field_name) + field_type = field_def and field_def.type + + field_value = ast_from_value(field_value, field_type) + if field_value: + fields.append(ast.ObjectField( + ast.Name(field_name), + field_value + )) + + return ast.ObjectValue(fields) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_code.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_code.py new file mode 100644 index 0000000000000000000000000000000000000000..8c307eead92e9a533a01f36b80c7187d060df9ae --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_code.py @@ -0,0 +1,49 @@ +from ..language.ast import Node +from ..language.parser import Loc + + +def ast_to_code(ast, indent=0): + """ + Converts an ast into a python code representation of the AST. 
+ """ + code = [] + + def append(line): + code.append((' ' * indent) + line) + + if isinstance(ast, Node): + append('ast.{}('.format(ast.__class__.__name__)) + indent += 1 + for i, k in enumerate(ast._fields, 1): + v = getattr(ast, k) + append('{}={},'.format( + k, + ast_to_code(v, indent), + )) + if ast.loc: + append('loc={}'.format(ast_to_code(ast.loc, indent))) + + indent -= 1 + append(')') + + elif isinstance(ast, Loc): + append('loc({}, {})'.format(ast.start, ast.end)) + + elif isinstance(ast, list): + if ast: + append('[') + indent += 1 + + for i, it in enumerate(ast, 1): + is_last = i == len(ast) + append(ast_to_code(it, indent) + (',' if not is_last else '')) + + indent -= 1 + append(']') + else: + append('[]') + + else: + append(repr(ast)) + + return '\n'.join(code).strip() diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_dict.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..a2b5cac5a134902ef98a7caeea8a4cc802f22cc8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/ast_to_dict.py @@ -0,0 +1,24 @@ +from ..language.ast import Node + + +def ast_to_dict(node, include_loc=False): + if isinstance(node, Node): + d = { + 'kind': node.__class__.__name__ + } + if hasattr(node, '_fields'): + for field in node._fields: + d[field] = ast_to_dict(getattr(node, field), include_loc) + + if include_loc and hasattr(node, 'loc') and node.loc: + d['loc'] = { + 'start': node.loc.start, + 'end': node.loc.end + } + + return d + + elif isinstance(node, list): + return [ast_to_dict(item, include_loc) for item in node] + + return node diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/base.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/base.py new file mode 100644 
index 0000000000000000000000000000000000000000..5e8958539b88c2637596e1ca03cc2195632376b7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/base.py @@ -0,0 +1,75 @@ +""" + Base GraphQL utilities + isort:skip_file +""" + +# The GraphQL query recommended for a full schema introspection. +from .introspection_query import introspection_query + +# Gets the target Operation from a Document +from .get_operation_ast import get_operation_ast + +# Build a GraphQLSchema from an introspection result. +from .build_client_schema import build_client_schema + +# Build a GraphQLSchema from a parsed GraphQL Schema language AST. +from .build_ast_schema import build_ast_schema + +# Extends an existing GraphQLSchema from a parsed GraphQL Schema language AST. +from .extend_schema import extend_schema + +# Print a GraphQLSchema to GraphQL Schema language. +from .schema_printer import print_schema, print_introspection_schema + +# Create a GraphQLType from a GraphQL language AST. +from .type_from_ast import type_from_ast + +# Create a JavaScript value from a GraphQL language AST. +from .value_from_ast import value_from_ast + +# Create a GraphQL language AST from a JavaScript value. +from .ast_from_value import ast_from_value + +# A helper to use within recursive-descent visitors which need to be aware of +# the GraphQL type system. +from .type_info import TypeInfo + +# Determine if JavaScript values adhere to a GraphQL type. +from .is_valid_value import is_valid_value + +# Determine if AST values adhere to a GraphQL type. +from .is_valid_literal_value import is_valid_literal_value + +# Concatenates multiple AST together. 
+from .concat_ast import concat_ast + +# Comparators for types +from .type_comparators import ( + is_equal_type, + is_type_sub_type_of, + do_types_overlap +) + +# Asserts that a string is a valid GraphQL name +from .assert_valid_name import assert_valid_name + +__all__ = [ + 'introspection_query', + 'get_operation_ast', + 'build_client_schema', + 'build_ast_schema', + 'extend_schema', + 'print_introspection_schema', + 'print_schema', + 'type_from_ast', + 'value_from_ast', + 'ast_from_value', + 'TypeInfo', + 'is_valid_value', + 'is_valid_literal_value', + 'concat_ast', + 'do_types_overlap', + 'is_equal_type', + 'is_type_sub_type_of', + 'assert_valid_name', +] diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/build_ast_schema.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/build_ast_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..e0c632b24bce9fe37227dc7b631c5ddcae90ee1d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/utils/build_ast_schema.py @@ -0,0 +1,291 @@ +from ..execution.values import get_argument_values +from ..language import ast +from ..pyutils.ordereddict import OrderedDict +from ..type import (GraphQLArgument, GraphQLBoolean, + GraphQLDeprecatedDirective, GraphQLDirective, + GraphQLEnumType, GraphQLEnumValue, GraphQLField, + GraphQLFloat, GraphQLID, GraphQLIncludeDirective, + GraphQLInputObjectField, GraphQLInputObjectType, + GraphQLInt, GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, GraphQLScalarType, + GraphQLSchema, GraphQLSkipDirective, GraphQLString, + GraphQLUnionType) +from ..type.introspection import (__Directive, __DirectiveLocation, + __EnumValue, __Field, __InputValue, __Schema, + __Type, __TypeKind) +from ..utils.value_from_ast import value_from_ast + + +def _build_wrapped_type(inner_type, input_type_ast): + if isinstance(input_type_ast, 
ast.ListType): + return GraphQLList(_build_wrapped_type(inner_type, input_type_ast.type)) + + if isinstance(input_type_ast, ast.NonNullType): + return GraphQLNonNull(_build_wrapped_type(inner_type, input_type_ast.type)) + + return inner_type + + +def _get_inner_type_name(type_ast): + if isinstance(type_ast, (ast.ListType, ast.NonNullType)): + return _get_inner_type_name(type_ast.type) + + return type_ast.name.value + + +def _get_named_type_ast(type_ast): + named_type = type_ast + while isinstance(named_type, (ast.ListType, ast.NonNullType)): + named_type = named_type.type + + return named_type + + +def _false(*_): + return False + + +def _none(*_): + return None + + +def build_ast_schema(document): + assert isinstance(document, ast.Document), 'must pass in Document ast.' + + schema_def = None + + type_asts = ( + ast.ScalarTypeDefinition, + ast.ObjectTypeDefinition, + ast.InterfaceTypeDefinition, + ast.EnumTypeDefinition, + ast.UnionTypeDefinition, + ast.InputObjectTypeDefinition, + ) + + type_defs = [] + directive_defs = [] + + for d in document.definitions: + if isinstance(d, ast.SchemaDefinition): + if schema_def: + raise Exception('Must provide only one schema definition.') + schema_def = d + if isinstance(d, type_asts): + type_defs.append(d) + elif isinstance(d, ast.DirectiveDefinition): + directive_defs.append(d) + + if not schema_def: + raise Exception('Must provide a schema definition.') + + query_type_name = None + mutation_type_name = None + subscription_type_name = None + + for operation_type in schema_def.operation_types: + type_name = operation_type.type.name.value + if operation_type.operation == 'query': + if query_type_name: + raise Exception('Must provide only one query type in schema.') + query_type_name = type_name + elif operation_type.operation == 'mutation': + if mutation_type_name: + raise Exception('Must provide only one mutation type in schema.') + mutation_type_name = type_name + elif operation_type.operation == 'subscription': + if 
subscription_type_name: + raise Exception('Must provide only one subscription type in schema.') + subscription_type_name = type_name + + if not query_type_name: + raise Exception('Must provide schema definition with query type.') + + ast_map = {d.name.value: d for d in type_defs} + + if query_type_name not in ast_map: + raise Exception('Specified query type "{}" not found in document.'.format(query_type_name)) + + if mutation_type_name and mutation_type_name not in ast_map: + raise Exception('Specified mutation type "{}" not found in document.'.format(mutation_type_name)) + + if subscription_type_name and subscription_type_name not in ast_map: + raise Exception('Specified subscription type "{}" not found in document.'.format(subscription_type_name)) + + inner_type_map = OrderedDict([ + ('String', GraphQLString), + ('Int', GraphQLInt), + ('Float', GraphQLFloat), + ('Boolean', GraphQLBoolean), + ('ID', GraphQLID), + ('__Schema', __Schema), + ('__Directive', __Directive), + ('__DirectiveLocation', __DirectiveLocation), + ('__Type', __Type), + ('__Field', __Field), + ('__InputValue', __InputValue), + ('__EnumValue', __EnumValue), + ('__TypeKind', __TypeKind), + ]) + + def get_directive(directive_ast): + return GraphQLDirective( + name=directive_ast.name.value, + locations=[node.value for node in directive_ast.locations], + args=make_input_values(directive_ast.arguments, GraphQLArgument), + ) + + def get_object_type(type_ast): + type = type_def_named(type_ast.name.value) + assert isinstance(type, GraphQLObjectType), 'AST must provide object type' + return type + + def produce_type_def(type_ast): + type_name = _get_named_type_ast(type_ast).name.value + type_def = type_def_named(type_name) + return _build_wrapped_type(type_def, type_ast) + + def type_def_named(type_name): + if type_name in inner_type_map: + return inner_type_map[type_name] + + if type_name not in ast_map: + raise Exception('Type "{}" not found in document'.format(type_name)) + + inner_type_def = 
make_schema_def(ast_map[type_name]) + if not inner_type_def: + raise Exception('Nothing constructed for "{}".'.format(type_name)) + + inner_type_map[type_name] = inner_type_def + return inner_type_def + + def make_schema_def(definition): + if not definition: + raise Exception('def must be defined.') + + handler = _schema_def_handlers.get(type(definition)) + if not handler: + raise Exception('Type kind "{}" not supported.'.format(type(definition).__name__)) + + return handler(definition) + + def make_type_def(definition): + return GraphQLObjectType( + name=definition.name.value, + fields=lambda: make_field_def_map(definition), + interfaces=make_implemented_interfaces(definition) + ) + + def make_field_def_map(definition): + return OrderedDict( + (f.name.value, GraphQLField( + type=produce_type_def(f.type), + args=make_input_values(f.arguments, GraphQLArgument), + deprecation_reason=get_deprecation_reason(f.directives), + )) + for f in definition.fields + ) + + def make_implemented_interfaces(definition): + return [produce_type_def(i) for i in definition.interfaces] + + def make_input_values(values, cls): + return OrderedDict( + (value.name.value, cls( + type=produce_type_def(value.type), + default_value=value_from_ast(value.default_value, produce_type_def(value.type)) + )) + for value in values + ) + + def make_interface_def(definition): + return GraphQLInterfaceType( + name=definition.name.value, + resolve_type=_none, + fields=lambda: make_field_def_map(definition) + ) + + def make_enum_def(definition): + values = OrderedDict((v.name.value, GraphQLEnumValue(deprecation_reason=get_deprecation_reason(v.directives))) + for v in definition.values) + return GraphQLEnumType( + name=definition.name.value, + values=values + ) + + def make_union_def(definition): + return GraphQLUnionType( + name=definition.name.value, + resolve_type=_none, + types=[produce_type_def(t) for t in definition.types] + ) + + def make_scalar_def(definition): + return GraphQLScalarType( + 
name=definition.name.value, + serialize=_none, + # Validation calls the parse functions to determine if a literal value is correct. + # Returning none, however would cause the scalar to fail validation. Returning false, + # will cause them to pass. + parse_literal=_false, + parse_value=_false + ) + + def make_input_object_def(definition): + return GraphQLInputObjectType( + name=definition.name.value, + fields=make_input_values(definition.fields, GraphQLInputObjectField) + ) + + _schema_def_handlers = { + ast.ObjectTypeDefinition: make_type_def, + ast.InterfaceTypeDefinition: make_interface_def, + ast.EnumTypeDefinition: make_enum_def, + ast.UnionTypeDefinition: make_union_def, + ast.ScalarTypeDefinition: make_scalar_def, + ast.InputObjectTypeDefinition: make_input_object_def + } + types = [type_def_named(definition.name.value) for definition in type_defs] + directives = [get_directive(d) for d in directive_defs] + + # If specified directive were not explicitly declared, add them. + find_skip_directive = (directive.name for directive in directives if directive.name == 'skip') + find_include_directive = (directive.name for directive in directives if directive.name == 'include') + find_deprecated_directive = (directive.name for directive in directives if directive.name == 'deprecated') + + if not next(find_skip_directive, None): + directives.append(GraphQLSkipDirective) + + if not next(find_include_directive, None): + directives.append(GraphQLIncludeDirective) + + if not next(find_deprecated_directive, None): + directives.append(GraphQLDeprecatedDirective) + + schema_kwargs = {'query': get_object_type(ast_map[query_type_name])} + + if mutation_type_name: + schema_kwargs['mutation'] = get_object_type(ast_map[mutation_type_name]) + + if subscription_type_name: + schema_kwargs['subscription'] = get_object_type(ast_map[subscription_type_name]) + + if directive_defs: + schema_kwargs['directives'] = directives + + if types: + schema_kwargs['types'] = types + + return 
def build_client_schema(introspection):
    """Build a GraphQLSchema from the result of an introspection query.

    The resulting schema is meant for client-side tooling only: every field
    resolver and every abstract-type ``resolve_type`` is ``no_execution``,
    which raises as soon as it is called.

    Args:
        introspection: Mapping with a ``'__schema'`` entry holding the
            introspection result (``types``, ``queryType`` and optionally
            ``mutationType``, ``subscriptionType`` and ``directives``).

    Returns:
        A GraphQLSchema containing every type listed in the introspection
        result.

    Raises:
        Exception: If a referenced type or type kind is unknown, or a
            LIST/NON_NULL reference carries no ``ofType``.
        AssertionError: If a type reference resolves to the wrong category
            of type (e.g. a non-input type used for an argument).
    """
    schema_introspection = introspection['__schema']

    type_introspection_map = {t['name']: t for t in schema_introspection['types']}

    # Constructed types by name, seeded with the built-in scalars and the
    # introspection types so those are reused instead of being rebuilt.
    type_def_cache = {
        'String': GraphQLString,
        'Int': GraphQLInt,
        'Float': GraphQLFloat,
        'Boolean': GraphQLBoolean,
        'ID': GraphQLID,
        '__Schema': __Schema,
        '__Directive': __Directive,
        '__DirectiveLocation': __DirectiveLocation,
        '__Type': __Type,
        '__Field': __Field,
        '__InputValue': __InputValue,
        '__EnumValue': __EnumValue,
        '__TypeKind': __TypeKind,
    }

    def get_type(type_ref):
        # Unwrap LIST / NON_NULL references recursively; anything else is a
        # named type reference.
        kind = type_ref.get('kind')

        if kind == TypeKind.LIST:
            item_ref = type_ref.get('ofType')
            if not item_ref:
                raise Exception('Decorated type deeper than introspection query.')

            return GraphQLList(get_type(item_ref))

        elif kind == TypeKind.NON_NULL:
            nullable_ref = type_ref.get('ofType')
            if not nullable_ref:
                raise Exception('Decorated type deeper than introspection query.')

            return GraphQLNonNull(get_type(nullable_ref))

        return get_named_type(type_ref['name'])

    def get_named_type(type_name):
        # Return the cached type, building (and caching) it on first use.
        if type_name in type_def_cache:
            return type_def_cache[type_name]

        type_introspection = type_introspection_map.get(type_name)
        if not type_introspection:
            raise Exception(
                'Invalid or incomplete schema, unknown type: {}. Ensure that a full introspection query '
                'is used in order to build a client schema.'.format(type_name)
            )

        type_def = type_def_cache[type_name] = build_type(type_introspection)
        return type_def

    def get_input_type(type_ref):
        input_type = get_type(type_ref)
        assert is_input_type(input_type), 'Introspection must provide input type for arguments.'
        return input_type

    def get_output_type(type_ref):
        output_type = get_type(type_ref)
        assert is_output_type(output_type), 'Introspection must provide output type for fields.'
        return output_type

    def get_object_type(type_ref):
        object_type = get_type(type_ref)
        assert isinstance(object_type, GraphQLObjectType), 'Introspection must provide object type for possibleTypes.'
        return object_type

    def get_interface_type(type_ref):
        interface_type = get_type(type_ref)
        assert isinstance(interface_type, GraphQLInterfaceType), \
            'Introspection must provide interface type for interfaces.'
        return interface_type

    def build_type(type_introspection):
        # Dispatch on the introspected ``kind`` to the matching builder.
        type_kind = type_introspection.get('kind')
        handler = type_builders.get(type_kind)
        if not handler:
            raise Exception(
                'Invalid or incomplete schema, unknown kind: {}. Ensure that a full introspection query '
                'is used in order to build a client schema.'.format(type_kind)
            )

        return handler(type_introspection)

    def build_scalar_def(scalar_introspection):
        return GraphQLScalarType(
            name=scalar_introspection['name'],
            description=scalar_introspection.get('description'),
            serialize=_none,
            parse_value=_false,
            parse_literal=_false
        )

    def build_object_def(object_introspection):
        # ``or []`` (rather than a .get default) also covers an explicit
        # JSON null stored under the key.
        return GraphQLObjectType(
            name=object_introspection['name'],
            description=object_introspection.get('description'),
            interfaces=[get_interface_type(i) for i in object_introspection.get('interfaces') or []],
            # Fields are built lazily so mutually-referencing types resolve.
            fields=lambda: build_field_def_map(object_introspection)
        )

    def build_interface_def(interface_introspection):
        return GraphQLInterfaceType(
            name=interface_introspection['name'],
            description=interface_introspection.get('description'),
            fields=lambda: build_field_def_map(interface_introspection),
            resolve_type=no_execution
        )

    def build_union_def(union_introspection):
        return GraphQLUnionType(
            name=union_introspection['name'],
            description=union_introspection.get('description'),
            types=[get_object_type(t) for t in union_introspection.get('possibleTypes') or []],
            resolve_type=no_execution
        )

    def build_enum_def(enum_introspection):
        return GraphQLEnumType(
            name=enum_introspection['name'],
            description=enum_introspection.get('description'),
            values=OrderedDict([
                (value_introspection['name'],
                 GraphQLEnumValue(description=value_introspection.get('description'),
                                  deprecation_reason=value_introspection.get('deprecationReason')))
                for value_introspection in enum_introspection.get('enumValues') or []
            ])
        )

    def build_input_object_def(input_object_introspection):
        return GraphQLInputObjectType(
            name=input_object_introspection['name'],
            description=input_object_introspection.get('description'),
            fields=lambda: build_input_value_def_map(
                input_object_introspection.get('inputFields'), GraphQLInputObjectField
            )
        )

    type_builders = {
        TypeKind.SCALAR: build_scalar_def,
        TypeKind.OBJECT: build_object_def,
        TypeKind.INTERFACE: build_interface_def,
        TypeKind.UNION: build_union_def,
        TypeKind.ENUM: build_enum_def,
        TypeKind.INPUT_OBJECT: build_input_object_def
    }

    def build_field_def_map(type_introspection):
        return OrderedDict([
            (f['name'], GraphQLField(
                type=get_output_type(f['type']),
                description=f.get('description'),
                resolver=no_execution,
                deprecation_reason=f.get('deprecationReason'),
                args=build_input_value_def_map(f.get('args'), GraphQLArgument)))
            for f in type_introspection.get('fields') or []
        ])

    def build_default_value(f):
        # ``defaultValue`` is a GraphQL-literal string; parse it and coerce
        # it against the declared input type.
        default_value = f.get('defaultValue')
        if default_value is None:
            return None

        return value_from_ast(parse_value(default_value), get_input_type(f['type']))

    def build_input_value_def_map(input_value_introspection, argument_type):
        # A missing or null list of input values is treated as empty.
        return OrderedDict([
            (f['name'], build_input_value(f, argument_type))
            for f in input_value_introspection or []
        ])

    def build_input_value(input_value_introspection, argument_type):
        return argument_type(
            # Descriptions are optional everywhere else in this builder, so
            # a missing key must not raise here either.
            description=input_value_introspection.get('description'),
            type=get_input_type(input_value_introspection['type']),
            default_value=build_default_value(input_value_introspection)
        )

    def build_directive(directive_introspection):
        # Support deprecated `on****` fields for building `locations`, as this
        # is used by GraphiQL which may need to support outdated servers.
        locations = list(directive_introspection.get('locations') or [])
        if not locations:
            if directive_introspection.get('onField', False):
                locations += list(DirectiveLocation.FIELD_LOCATIONS)
            if directive_introspection.get('onOperation', False):
                locations += list(DirectiveLocation.OPERATION_LOCATIONS)
            if directive_introspection.get('onFragment', False):
                locations += list(DirectiveLocation.FRAGMENT_LOCATIONS)

        return GraphQLDirective(
            name=directive_introspection['name'],
            description=directive_introspection.get('description'),
            args=build_input_value_def_map(directive_introspection.get('args'), GraphQLArgument),
            locations=locations
        )

    # Iterate through all types, getting the type definition for each, ensuring
    # that any type not directly referenced by a field will get created.
    types = [get_named_type(type_name) for type_name in type_introspection_map]

    query_type = get_object_type(schema_introspection['queryType'])

    mutation_type = None
    if schema_introspection.get('mutationType'):
        mutation_type = get_object_type(schema_introspection['mutationType'])

    subscription_type = None
    if schema_introspection.get('subscriptionType'):
        subscription_type = get_object_type(schema_introspection['subscriptionType'])

    # ``directives`` may be absent or null in the introspection result; the
    # previous subscript lookup raised KeyError when the key was absent.
    directives = [build_directive(d) for d in schema_introspection.get('directives') or []]

    return GraphQLSchema(
        query=query_type,
        mutation=mutation_type,
        subscription=subscription_type,
        directives=directives,
        types=types
    )
def extend_schema(schema, documentAST=None):
    """Produce a new schema from an existing schema plus a document that may
    contain GraphQL type definitions and type extensions.

    Because a schema represents a graph of references, it cannot be extended
    without effectively making a copy: every object, interface and union
    type of the original schema is rebuilt, applying extensions on the way.
    The original schema remains unaltered.

    Args:
        schema: The GraphQLSchema to extend.
        documentAST: An ``ast.Document``. Despite the ``None`` default it is
            required; the assertion below rejects a missing document.

    Returns:
        ``schema`` itself when the document holds neither type definitions
        nor type extensions, otherwise a new GraphQLSchema.

    Raises:
        AssertionError: If ``schema`` or ``documentAST`` has the wrong type.
        GraphQLError: If a defined type already exists, an extended type is
            unknown or not an object type, an extension re-adds an existing
            field or interface, or a referenced type cannot be found.
    """

    assert isinstance(
        schema, GraphQLSchema), 'Must provide valid GraphQLSchema'
    assert documentAST and isinstance(
        documentAST, ast.Document), 'Must provide valid Document AST'

    # Collect the type definitions and extensions found in the document.
    type_definition_map = {}
    type_extensions_map = defaultdict(list)

    for _def in documentAST.definitions:
        if isinstance(_def, (
            ast.ObjectTypeDefinition,
            ast.InterfaceTypeDefinition,
            ast.EnumTypeDefinition,
            ast.UnionTypeDefinition,
            ast.ScalarTypeDefinition,
            ast.InputObjectTypeDefinition,
        )):
            # Sanity check that none of the defined types conflict with the
            # schema's existing types.
            type_name = _def.name.value
            if schema.get_type(type_name):
                raise GraphQLError(
                    ('Type "{}" already exists in the schema. It cannot also ' +
                     'be defined in this type definition.').format(type_name),
                    [_def]
                )

            type_definition_map[type_name] = _def
        elif isinstance(_def, ast.TypeExtensionDefinition):
            # Sanity check that this type extension exists within the
            # schema's existing types.
            extended_type_name = _def.definition.name.value
            existing_type = schema.get_type(extended_type_name)
            if not existing_type:
                raise GraphQLError(
                    ('Cannot extend type "{}" because it does not ' +
                     'exist in the existing schema.').format(extended_type_name),
                    [_def.definition]
                )
            if not isinstance(existing_type, GraphQLObjectType):
                raise GraphQLError(
                    'Cannot extend non-object type "{}".'.format(
                        extended_type_name),
                    [_def.definition]
                )

            type_extensions_map[extended_type_name].append(_def)

    # Below are functions used for producing this schema that have closed over
    # this scope and have access to the schema, cache, and newly defined types.

    def get_type_from_def(type_def):
        named_type = _get_named_type(type_def.name)
        assert named_type, 'Invalid schema'
        return named_type

    def get_type_from_AST(astNode):
        named_type = _get_named_type(astNode.name.value)
        if not named_type:
            raise GraphQLError(
                ('Unknown type: "{}". Ensure that this type exists ' +
                 'either in the original schema, or is added in a type definition.').format(
                    astNode.name.value),
                [astNode]
            )
        return named_type

    # Given a name, returns a type from either the existing schema or an
    # added type; returns None when the name is unknown to both.
    def _get_named_type(typeName):
        cached_type_def = type_def_cache.get(typeName)
        if cached_type_def:
            return cached_type_def

        existing_type = schema.get_type(typeName)
        if existing_type:
            type_def = extend_type(existing_type)
            type_def_cache[typeName] = type_def
            return type_def

        type_ast = type_definition_map.get(typeName)
        if type_ast:
            type_def = build_type(type_ast)
            type_def_cache[typeName] = type_def
            return type_def

        return None

    # Given an existing type, return its rebuilt copy. Only object, interface
    # and union types are rebuilt; any other type is returned as-is.
    def extend_type(existing):
        if isinstance(existing, GraphQLObjectType):
            return extend_object_type(existing)
        if isinstance(existing, GraphQLInterfaceType):
            return extend_interface_type(existing)
        if isinstance(existing, GraphQLUnionType):
            return extend_union_type(existing)
        return existing

    def extend_object_type(existing):
        return GraphQLObjectType(
            name=existing.name,
            description=existing.description,
            # Thunks defer resolution until the type is first inspected.
            interfaces=lambda: extend_implemented_interfaces(existing),
            fields=lambda: extend_field_map(existing),
        )

    def extend_interface_type(existing):
        return GraphQLInterfaceType(
            name=existing.name,
            description=existing.description,
            fields=lambda: extend_field_map(existing),
            resolve_type=cannot_execute_client_schema,
        )

    def extend_union_type(existing):
        return GraphQLUnionType(
            name=existing.name,
            description=existing.description,
            types=list(map(get_type_from_def, existing.types)),
            resolve_type=cannot_execute_client_schema,
        )

    def extend_implemented_interfaces(existing):
        interfaces = list(map(get_type_from_def, existing.interfaces))

        # If there are any extensions to the interfaces, apply those here.
        extensions = type_extensions_map[existing.name]
        for extension in extensions:
            for namedType in extension.definition.interfaces:
                interface_name = namedType.name.value
                if any(_def.name == interface_name for _def in interfaces):
                    raise GraphQLError(
                        ('Type "{}" already implements "{}". ' +
                         'It cannot also be implemented in this type extension.').format(
                            existing.name, interface_name),
                        [namedType]
                    )
                interfaces.append(get_type_from_AST(namedType))

        return interfaces

    def extend_field_map(existing):
        new_field_map = OrderedDict()
        old_field_map = existing.fields
        for field_name, field in old_field_map.items():
            new_field_map[field_name] = GraphQLField(
                extend_field_type(field.type),
                description=field.description,
                deprecation_reason=field.deprecation_reason,
                args=field.args,
                resolver=cannot_execute_client_schema,
            )

        # If there are any extensions to the fields, apply those here.
        extensions = type_extensions_map[existing.name]
        for extension in extensions:
            for field in extension.definition.fields:
                field_name = field.name.value
                if field_name in old_field_map:
                    raise GraphQLError(
                        ('Field "{}.{}" already exists in the ' +
                         'schema. It cannot also be defined in this type extension.').format(
                            existing.name, field_name),
                        [field]
                    )
                new_field_map[field_name] = GraphQLField(
                    build_field_type(field.type),
                    args=build_input_values(field.arguments),
                    resolver=cannot_execute_client_schema,
                )

        return new_field_map

    def extend_field_type(field_type):
        # Re-wrap list / non-null modifiers around the rebuilt named type.
        if isinstance(field_type, GraphQLList):
            return GraphQLList(extend_field_type(field_type.of_type))
        if isinstance(field_type, GraphQLNonNull):
            return GraphQLNonNull(extend_field_type(field_type.of_type))
        return get_type_from_def(field_type)

    def build_type(type_ast):
        # Dispatch on the AST node class; returns None for unsupported nodes.
        _type_build = {
            ast.ObjectTypeDefinition: build_object_type,
            ast.InterfaceTypeDefinition: build_interface_type,
            ast.UnionTypeDefinition: build_union_type,
            ast.ScalarTypeDefinition: build_scalar_type,
            ast.EnumTypeDefinition: build_enum_type,
            ast.InputObjectTypeDefinition: build_input_object_type
        }
        func = _type_build.get(type(type_ast))
        if func:
            return func(type_ast)
        return None

    def build_object_type(type_ast):
        return GraphQLObjectType(
            type_ast.name.value,
            interfaces=lambda: build_implemented_interfaces(type_ast),
            fields=lambda: build_field_map(type_ast),
        )

    def build_interface_type(type_ast):
        return GraphQLInterfaceType(
            type_ast.name.value,
            fields=lambda: build_field_map(type_ast),
            resolve_type=cannot_execute_client_schema,
        )

    def build_union_type(type_ast):
        return GraphQLUnionType(
            type_ast.name.value,
            types=list(map(get_type_from_AST, type_ast.types)),
            resolve_type=cannot_execute_client_schema,
        )

    def build_scalar_type(type_ast):
        return GraphQLScalarType(
            type_ast.name.value,
            serialize=lambda *args, **kwargs: None,
            # Note: validation calls the parse functions to determine if a
            # literal value is correct. Returning null would cause use of custom
            # scalars to always fail validation. Returning false causes them to
            # always pass validation.
            parse_value=lambda *args, **kwargs: False,
            parse_literal=lambda *args, **kwargs: False,
        )

    def build_enum_type(type_ast):
        # OrderedDict keeps the values in definition order, consistent with
        # the other maps built in this function.
        return GraphQLEnumType(
            type_ast.name.value,
            values=OrderedDict(
                (v.name.value, GraphQLEnumValue()) for v in type_ast.values
            ),
        )

    def build_input_object_type(type_ast):
        return GraphQLInputObjectType(
            type_ast.name.value,
            fields=lambda: build_input_values(
                type_ast.fields, GraphQLInputObjectField),
        )

    def build_implemented_interfaces(type_ast):
        return list(map(get_type_from_AST, type_ast.interfaces))

    def build_field_map(type_ast):
        # Ordered for the same reason as in build_enum_type.
        return OrderedDict(
            (field.name.value, GraphQLField(
                build_field_type(field.type),
                args=build_input_values(field.arguments),
                resolver=cannot_execute_client_schema,
            ))
            for field in type_ast.fields
        )

    def build_input_values(values, input_type=GraphQLArgument):
        input_values = OrderedDict()
        for value in values:
            value_type = build_field_type(value.type)
            input_values[value.name.value] = input_type(
                value_type,
                default_value=value_from_ast(value.default_value, value_type)
            )
        return input_values

    def build_field_type(type_ast):
        if isinstance(type_ast, ast.ListType):
            return GraphQLList(build_field_type(type_ast.type))
        if isinstance(type_ast, ast.NonNullType):
            return GraphQLNonNull(build_field_type(type_ast.type))
        return get_type_from_AST(type_ast)

    # If this document contains no new types, then return the same unmodified
    # GraphQLSchema instance.
    if not type_extensions_map and not type_definition_map:
        return schema

    # A cache to use to store the actual GraphQLType definition objects by name.
    # Initialize to the GraphQL built in scalars and introspection types. All
    # functions above are closures so that this type def cache is within their
    # scope by the time they are first called.
    type_def_cache = {
        'String': GraphQLString,
        'Int': GraphQLInt,
        'Float': GraphQLFloat,
        'Boolean': GraphQLBoolean,
        'ID': GraphQLID,
        '__Schema': __Schema,
        '__Directive': __Directive,
        '__DirectiveLocation': __DirectiveLocation,
        '__Type': __Type,
        '__Field': __Field,
        '__InputValue': __InputValue,
        '__EnumValue': __EnumValue,
        '__TypeKind': __TypeKind,
    }

    # Get the root Query, Mutation, and Subscription types.
    query_type = get_type_from_def(schema.get_query_type())

    existing_mutation_type = schema.get_mutation_type()
    mutation_type = (
        get_type_from_def(existing_mutation_type) if existing_mutation_type else None
    )

    existing_subscription_type = schema.get_subscription_type()
    subscription_type = (
        get_type_from_def(existing_subscription_type) if existing_subscription_type else None
    )

    # Iterate through all types, getting the type definition for each, ensuring
    # that any type not directly referenced by a field will get created.
    types = [get_type_from_def(_def) for _def in schema.get_type_map().values()]

    # Do the same with new types, appending to the list of defined types.
    types += [get_type_from_AST(_def) for _def in type_definition_map.values()]

    # Then produce and return a Schema with these types.
    return GraphQLSchema(
        query=query_type,
        mutation=mutation_type,
        subscription=subscription_type,
        # Copy directives.
        directives=schema.get_directives(),
        types=types
    )
def get_operation_ast(document_ast, operation_name=None):
    """Pick an operation definition out of a parsed document.

    With a truthy ``operation_name``, return the first operation whose name
    matches, or None when there is no match. Without one, return the
    document's operation only if it is the sole operation present; a document
    with zero or several operations yields None.
    """
    operations = (
        node for node in document_ast.definitions
        if isinstance(node, ast.OperationDefinition)
    )

    if operation_name:
        for candidate in operations:
            if candidate.name and candidate.name.value == operation_name:
                return candidate
        return None

    # No name requested: the answer is only unambiguous when exactly one
    # operation exists, so bail out as soon as a second one shows up.
    sole_operation = None
    for candidate in operations:
        if sole_operation:
            return None
        sole_operation = candidate
    return sole_operation
def is_valid_literal_value(type, value_ast):
    """Check a literal value AST against an input type.

    Args:
        type: The expected input type (non-null, list, input object, scalar
            or enum).
        value_ast: The value AST node to check; a falsy value stands for a
            missing/null value.

    Returns:
        A list of error message strings; an empty list means the literal is
        valid. Every call returns a fresh list, so callers may mutate the
        result without affecting later calls.

    Raises:
        AssertionError: If ``type`` is not an input type.
    """
    if isinstance(type, GraphQLNonNull):
        of_type = type.of_type
        if not value_ast:
            return [u'Expected "{}", found null.'.format(type)]

        return is_valid_literal_value(of_type, value_ast)

    # A missing value is acceptable for any nullable type.
    if not value_ast:
        return []

    # Variables are not checked here; they are treated as valid.
    if isinstance(value_ast, ast.Variable):
        return []

    if isinstance(type, GraphQLList):
        item_type = type.of_type
        if isinstance(value_ast, ast.ListValue):
            errors = []

            for i, item_ast in enumerate(value_ast.values):
                item_errors = is_valid_literal_value(item_type, item_ast)
                for error in item_errors:
                    errors.append(u'In element #{}: {}'.format(i, error))

            return errors

        # A non-list literal is checked as if it were a single list item.
        return is_valid_literal_value(item_type, value_ast)

    if isinstance(type, GraphQLInputObjectType):
        if not isinstance(value_ast, ast.ObjectValue):
            return [u'Expected "{}", found not an object.'.format(type)]

        fields = type.fields
        field_asts = value_ast.fields

        errors = []
        for provided_field_ast in field_asts:
            if provided_field_ast.name.value not in fields:
                errors.append(u'In field "{}": Unknown field.'.format(provided_field_ast.name.value))

        field_ast_map = {field_ast.name.value: field_ast for field_ast in field_asts}

        for field_name, field in fields.items():
            # Fields absent from the literal are checked against None so
            # that missing non-null fields are reported.
            field_ast = field_ast_map.get(field_name)
            field_value_ast = field_ast.value if field_ast is not None else None
            subfield_errors = is_valid_literal_value(field.type, field_value_ast)
            errors.extend(u'In field "{}": {}'.format(field_name, e) for e in subfield_errors)

        return errors

    assert isinstance(type, (GraphQLScalarType, GraphQLEnumType)), 'Must be input type'

    # A scalar/enum literal is valid when the type parses it to a non-None
    # value.
    parse_result = type.parse_literal(value_ast)
    if parse_result is None:
        return [u'Expected type "{}", found {}.'.format(type.name, print_ast(value_ast))]

    return []
isinstance(type, GraphQLList): + item_type = type.of_type + if not isinstance(value, str) and isinstance(value, Iterable): + errors = [] + for i, item in enumerate(value): + item_errors = is_valid_value(item, item_type) + for error in item_errors: + errors.append(u'In element #{}: {}'.format(i, error)) + + return errors + + else: + return is_valid_value(value, item_type) + + if isinstance(type, GraphQLInputObjectType): + if not isinstance(value, Mapping): + return [u'Expected "{}", found not an object.'.format(type)] + + fields = type.fields + errors = [] + + for provided_field in sorted(value.keys()): + if provided_field not in fields: + errors.append(u'In field "{}": Unknown field.'.format(provided_field)) + + for field_name, field in fields.items(): + subfield_errors = is_valid_value(value.get(field_name), field.type) + errors.extend(u'In field "{}": {}'.format(field_name, e) for e in subfield_errors) + + return errors + + assert isinstance(type, (GraphQLScalarType, GraphQLEnumType)), \ + 'Must be input type' + + # Scalar/Enum input checks to ensure the type can parse the value to + # a non-null value. 
MAX_LENGTH = 5


def quoted_or_list(items):
    """Given [ A, B, C ] return '"A", "B" or "C"'.

    Only the first ``MAX_LENGTH`` items are used. Any iterable is accepted,
    not just sequences that support slicing.

    Args:
        items: Iterable of values; each is rendered with ``str.format``.

    Returns:
        The quoted items joined with commas, with ``or`` before the last one.

    Raises:
        ValueError: If ``items`` is empty. The previous implementation let a
            bare ``StopIteration`` escape here, which an enclosing ``for``
            loop or generator can swallow or turn into a RuntimeError.
    """
    from itertools import islice

    quoted = ['"{}"'.format(item) for item in islice(items, MAX_LENGTH)]
    if not quoted:
        raise ValueError('quoted_or_list() requires at least one item.')
    if len(quoted) == 1:
        return quoted[0]

    # Everything but the last item is comma-separated; the last is joined
    # with "or" (so two items give '"A" or "B"' with no comma).
    return '{} or {}'.format(', '.join(quoted[:-1]), quoted[-1])
def _print_filtered_schema(schema, directive_filter, type_filter):
    '''Print the schema definition followed by selected directives and types.

    `directive_filter` is called with each directive's name and `type_filter`
    with each type's name; only entries for which the filter returns a truthy
    value are printed. Types are emitted in sorted order of the type map's
    items. Sections are separated by a blank line and the result ends with a
    single newline.
    '''
    sections = [_print_schema_definition(schema)]

    sections.extend(
        _print_directive(directive)
        for directive in schema.get_directives()
        if directive_filter(directive.name)
    )

    sections.extend(
        _print_type(graphql_type)
        for type_name, graphql_type in sorted(schema.get_type_map().items())
        if type_filter(type_name)
    )

    return '\n\n'.join(sections) + '\n'
return _print_union(type) + + elif isinstance(type, GraphQLEnumType): + return _print_enum(type) + + assert isinstance(type, GraphQLInputObjectType) + return _print_input_object(type) + + +def _print_scalar(type): + return 'scalar {}'.format(type.name) + + +def _print_object(type): + interfaces = type.interfaces + implemented_interfaces = \ + ' implements {}'.format(', '.join(i.name for i in interfaces)) if interfaces else '' + + return ( + 'type {}{} {{\n' + '{}\n' + '}}' + ).format(type.name, implemented_interfaces, _print_fields(type)) + + +def _print_interface(type): + return ( + 'interface {} {{\n' + '{}\n' + '}}' + ).format(type.name, _print_fields(type)) + + +def _print_union(type): + return 'union {} = {}'.format(type.name, ' | '.join(str(t) for t in type.types)) + + +def _print_enum(type): + return ( + 'enum {} {{\n' + '{}\n' + '}}' + ).format(type.name, '\n'.join(' ' + v.name + _print_deprecated(v) for v in type.values)) + + +def _print_input_object(type): + return ( + 'input {} {{\n' + '{}\n' + '}}' + ).format(type.name, '\n'.join(' ' + _print_input_value(name, field) for name, field in type.fields.items())) + + +def _print_fields(type): + return '\n'.join(' {}{}: {}{}'.format(f_name, _print_args(f), f.type, _print_deprecated(f)) + for f_name, f in type.fields.items()) + + +def _print_deprecated(field_or_enum_value): + reason = field_or_enum_value.deprecation_reason + + if reason is None: + return '' + elif reason in ('', DEFAULT_DEPRECATION_REASON): + return ' @deprecated' + else: + return ' @deprecated(reason: {})'.format(print_ast(ast_from_value(reason))) + + +def _print_args(field_or_directives): + if not field_or_directives.args: + return '' + + return '({})'.format(', '.join(_print_input_value(arg_name, arg) for arg_name, arg in field_or_directives.args.items())) + + +def _print_input_value(name, arg): + if arg.default_value is not None: + default_value = ' = ' + print_ast(ast_from_value(arg.default_value, arg.type)) + else: + default_value = '' + 
def lexical_distance(a, b):
    '''
    Computes the lexical distance between strings A and B.

    The "distance" between two strings is given by counting the minimum number
    of edits needed to transform string A into string B. An edit can be an
    insertion, deletion, or substitution of a single character, or a swap of two
    adjacent characters.

    This distance can be useful for detecting typos in input or for sorting
    candidate suggestions.

    @returns distance in number of edits
    '''
    rows = len(a) + 1
    cols = len(b) + 1

    # d[i][j] holds the distance between a[:i] and b[:j]. The first row and
    # column are the cost of building a prefix from the empty string.
    d = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        d[i][0] = i
    for j in range(cols):
        d[0][j] = j

    for i in range(1, rows):
        for j in range(1, cols):
            cost = 0 if a[i - 1] == b[j - 1] else 1

            d[i][j] = min(
                d[i - 1][j] + 1,        # deletion
                d[i][j - 1] + 1,        # insertion
                d[i - 1][j - 1] + cost  # substitution (or match)
            )

            # Adjacent transposition. Both indices must be past the first
            # character; the previous `j < 1` guard could never be true, so
            # swaps were always charged as two edits.
            if (i > 1 and j > 1 and
                    a[i - 1] == b[j - 2] and
                    a[i - 2] == b[j - 1]):
                d[i][j] = min(d[i][j], d[i - 2][j - 2] + cost)

    return d[len(a)][len(b)]
def do_types_overlap(schema, t1, t2):
    '''Report whether two types can share a concrete type.

    Equal types always overlap. When both types are interfaces or unions,
    they overlap if any possible type of `t1` (per
    `schema.get_possible_types`) is also a possible type of `t2`. When only
    one of them is an interface or union, the answer is whether the other is
    one of its possible types (per `schema.is_possible_type`). Otherwise the
    types do not overlap.
    '''
    if t1 == t2:
        return True

    abstract_types = (GraphQLInterfaceType, GraphQLUnionType)

    if isinstance(t1, abstract_types):
        if isinstance(t2, abstract_types):
            # Both abstract: look for any intersection between their possible
            # concrete types. The generator lets any() stop at the first hit.
            return any(
                schema.is_possible_type(t2, possible)
                for possible in schema.get_possible_types(t1)
            )
        # Determine if the latter type is a possible concrete type of the former.
        return schema.is_possible_type(t1, t2)

    if isinstance(t2, abstract_types):
        return schema.is_possible_type(t2, t1)

    return False
# noinspection PyPep8Naming
class TypeInfo(metaclass=visitor_meta.VisitorMeta):
    '''Tracks schema type context while an AST is being traversed.

    `enter(node)` and `leave(node)` dispatch to the `enter_<Kind>` /
    `leave_<Kind>` handlers below, which push and pop entries on a set of
    stacks (current type, parent type, input type, field definition) and
    record the current directive and argument. Every handler that pushes in
    `enter_<Kind>` has a `leave_<Kind>` counterpart that pops.
    '''
    __slots__ = '_schema', '_type_stack', '_parent_type_stack', '_input_type_stack', '_field_def_stack', '_directive', \
        '_argument', '_get_field_def_fn'

    def __init__(self, schema, get_field_def_fn=get_field_def):
        '''Create an empty tracker for `schema`.

        `get_field_def_fn(schema, parent_type, field_node)` is used to look
        up field definitions when a Field node is entered.
        '''
        self._schema = schema
        self._type_stack = []
        self._parent_type_stack = []
        self._input_type_stack = []
        self._field_def_stack = []
        self._directive = None
        self._argument = None
        self._get_field_def_fn = get_field_def_fn

    def get_type(self):
        '''Return the innermost output type, or None when the stack is empty.'''
        if self._type_stack:
            return self._type_stack[-1]

    def get_parent_type(self):
        '''Return the innermost parent type, or None when the stack is empty.'''
        if self._parent_type_stack:
            return self._parent_type_stack[-1]

    def get_input_type(self):
        '''Return the innermost input type, or None when the stack is empty.'''
        if self._input_type_stack:
            return self._input_type_stack[-1]

    def get_field_def(self):
        '''Return the innermost field definition, or None when the stack is empty.'''
        if self._field_def_stack:
            return self._field_def_stack[-1]

    def get_directive(self):
        '''Return the directive currently being visited, if any.'''
        return self._directive

    def get_argument(self):
        '''Return the argument definition currently being visited, if any.'''
        return self._argument

    def leave(self, node):
        '''Run the leave handler registered for `node`'s type, if there is one.'''
        method = self._get_leave_handler(type(node))
        if method:
            return method(self)

    def enter(self, node):
        '''Run the enter handler registered for `node`'s type, if there is one.'''
        method = self._get_enter_handler(type(node))
        if method:
            return method(self, node)

    def enter_SelectionSet(self, node):
        named_type = get_named_type(self.get_type())
        composite_type = None
        if is_composite_type(named_type):
            composite_type = named_type
        self._parent_type_stack.append(composite_type)

    def enter_Field(self, node):
        parent_type = self.get_parent_type()
        field_def = None
        if parent_type:
            field_def = self._get_field_def_fn(self._schema, parent_type, node)
        self._field_def_stack.append(field_def)
        self._type_stack.append(field_def and field_def.type)

    def enter_Directive(self, node):
        self._directive = self._schema.get_directive(node.name.value)

    def enter_OperationDefinition(self, node):
        definition_type = None
        if node.operation == 'query':
            definition_type = self._schema.get_query_type()
        elif node.operation == 'mutation':
            definition_type = self._schema.get_mutation_type()
        elif node.operation == 'subscription':
            # Previously subscriptions pushed None, so nothing beneath a
            # subscription operation could resolve its type.
            definition_type = self._schema.get_subscription_type()

        self._type_stack.append(definition_type)

    def enter_InlineFragment(self, node):
        type_condition_ast = node.type_condition
        # Without a type condition the fragment applies to the current type.
        type = type_from_ast(self._schema, type_condition_ast) if type_condition_ast else self.get_type()
        self._type_stack.append(type)

    enter_FragmentDefinition = enter_InlineFragment

    def enter_VariableDefinition(self, node):
        self._input_type_stack.append(type_from_ast(self._schema, node.type))

    def enter_Argument(self, node):
        arg_def = None
        arg_type = None
        field_or_directive = self.get_directive() or self.get_field_def()
        if field_or_directive:
            arg_def = field_or_directive.args.get(node.name.value)
            if arg_def:
                arg_type = arg_def.type
        self._argument = arg_def
        self._input_type_stack.append(arg_type)

    def enter_ListValue(self, node):
        list_type = get_nullable_type(self.get_input_type())
        self._input_type_stack.append(
            list_type.of_type if isinstance(list_type, GraphQLList) else None
        )

    def enter_ObjectField(self, node):
        object_type = get_named_type(self.get_input_type())
        field_type = None
        if isinstance(object_type, GraphQLInputObjectType):
            input_field = object_type.fields.get(node.name.value)
            field_type = input_field.type if input_field else None
        self._input_type_stack.append(field_type)

    def leave_SelectionSet(self):
        pop(self._parent_type_stack)

    def leave_Field(self):
        pop(self._field_def_stack)
        pop(self._type_stack)

    def leave_Directive(self):
        self._directive = None

    def leave_OperationDefinition(self):
        pop(self._type_stack)

    leave_InlineFragment = leave_OperationDefinition
    leave_FragmentDefinition = leave_OperationDefinition

    def leave_VariableDefinition(self):
        pop(self._input_type_stack)

    def leave_Argument(self):
        self._argument = None
        pop(self._input_type_stack)

    def leave_ListValue(self):
        # Counterpart of enter_ListValue. This handler used to be named
        # leave_ListType, so list values never popped their item type and
        # ListType nodes popped an entry they had not pushed.
        # NOTE(review): assumes the metaclass maps handler names to AST node
        # classes by the suffix after enter_/leave_ -- confirm in visitor_meta.
        pop(self._input_type_stack)

    leave_ObjectField = leave_ListValue
+ return value_from_ast(value_ast, type.of_type, variables) + + if not value_ast: + return None + + if isinstance(value_ast, ast.Variable): + variable_name = value_ast.name.value + if not variables or variable_name not in variables: + return None + + # Note: we're not doing any checking that this variable is correct. We're assuming that this query + # has been validated and the variable usage here is of the correct type. + return variables[variable_name] + + if isinstance(type, GraphQLList): + item_type = type.of_type + if isinstance(value_ast, ast.ListValue): + return [value_from_ast(item_ast, item_type, variables) + for item_ast in value_ast.values] + + else: + return [value_from_ast(value_ast, item_type, variables)] + + if isinstance(type, GraphQLInputObjectType): + fields = type.fields + if not isinstance(value_ast, ast.ObjectValue): + return None + + field_asts = {} + + for field in value_ast.fields: + field_asts[field.name.value] = field + + obj = {} + for field_name, field in fields.items(): + field_ast = field_asts.get(field_name) + field_value_ast = None + + if field_ast: + field_value_ast = field_ast.value + + field_value = value_from_ast( + field_value_ast, field.type, variables + ) + if field_value is None: + field_value = field.default_value + + if field_value is not None: + # We use out_name as the output name for the + # dict if exists + obj[field.out_name or field_name] = field_value + + return obj + + assert isinstance(type, (GraphQLScalarType, GraphQLEnumType)), \ + 'Must be input type' + + return type.parse_literal(value_ast) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/__init__.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..893af3e790576776013bbd60f9db943d5d4094ba --- /dev/null +++ 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/__init__.py @@ -0,0 +1,4 @@ +from .validation import validate +from .rules import specified_rules + +__all__ = ['validate', 'specified_rules'] diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2e2462ba49132d91dd6aa8b8b32943ecaad2161 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/__pycache__/validation.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/__pycache__/validation.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c168c70659f423ec3d27d4c983ac88fa3f7279e2 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/__pycache__/validation.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__init__.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1ebd7d6ad8d9e7e49eaa2232d3050f087e923b9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__init__.py @@ -0,0 +1,79 @@ +from .arguments_of_correct_type import ArgumentsOfCorrectType +from .default_values_of_correct_type import DefaultValuesOfCorrectType +from .fields_on_correct_type import 
FieldsOnCorrectType +from .fragments_on_composite_types import FragmentsOnCompositeTypes +from .known_argument_names import KnownArgumentNames +from .known_directives import KnownDirectives +from .known_fragment_names import KnownFragmentNames +from .known_type_names import KnownTypeNames +from .lone_anonymous_operation import LoneAnonymousOperation +from .no_fragment_cycles import NoFragmentCycles +from .no_undefined_variables import NoUndefinedVariables +from .no_unused_fragments import NoUnusedFragments +from .no_unused_variables import NoUnusedVariables +from .overlapping_fields_can_be_merged import OverlappingFieldsCanBeMerged +from .possible_fragment_spreads import PossibleFragmentSpreads +from .provided_non_null_arguments import ProvidedNonNullArguments +from .scalar_leafs import ScalarLeafs +from .unique_argument_names import UniqueArgumentNames +from .unique_fragment_names import UniqueFragmentNames +from .unique_input_field_names import UniqueInputFieldNames +from .unique_operation_names import UniqueOperationNames +from .unique_variable_names import UniqueVariableNames +from .variables_are_input_types import VariablesAreInputTypes +from .variables_in_allowed_position import VariablesInAllowedPosition + +specified_rules = [ + UniqueOperationNames, + LoneAnonymousOperation, + KnownTypeNames, + FragmentsOnCompositeTypes, + VariablesAreInputTypes, + ScalarLeafs, + FieldsOnCorrectType, + UniqueFragmentNames, + KnownFragmentNames, + NoUnusedFragments, + PossibleFragmentSpreads, + NoFragmentCycles, + NoUndefinedVariables, + NoUnusedVariables, + KnownDirectives, + KnownArgumentNames, + UniqueArgumentNames, + ArgumentsOfCorrectType, + ProvidedNonNullArguments, + DefaultValuesOfCorrectType, + VariablesInAllowedPosition, + OverlappingFieldsCanBeMerged, + UniqueInputFieldNames, + UniqueVariableNames +] + +__all__ = [ + 'ArgumentsOfCorrectType', + 'DefaultValuesOfCorrectType', + 'FieldsOnCorrectType', + 'FragmentsOnCompositeTypes', + 'KnownArgumentNames', + 
'KnownDirectives', + 'KnownFragmentNames', + 'KnownTypeNames', + 'LoneAnonymousOperation', + 'NoFragmentCycles', + 'UniqueVariableNames', + 'NoUndefinedVariables', + 'NoUnusedFragments', + 'NoUnusedVariables', + 'OverlappingFieldsCanBeMerged', + 'PossibleFragmentSpreads', + 'ProvidedNonNullArguments', + 'ScalarLeafs', + 'UniqueArgumentNames', + 'UniqueFragmentNames', + 'UniqueInputFieldNames', + 'UniqueOperationNames', + 'VariablesAreInputTypes', + 'VariablesInAllowedPosition', + 'specified_rules' +] diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/__init__.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e079a97f07bee14c402e50117b65bed0db07084a Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/arguments_of_correct_type.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/arguments_of_correct_type.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f06827a87bd6cb3b2a8068e5b3df672a3cfa89c Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/arguments_of_correct_type.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/base.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/base.cpython-313.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..95bdbae98be270dbc4ca7d637ea01b98a786b132 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/base.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/default_values_of_correct_type.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/default_values_of_correct_type.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8422d1af19e90a3a42ced00a237cb0fb28d08fb7 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/default_values_of_correct_type.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/fields_on_correct_type.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/fields_on_correct_type.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39183517070137f4f11bba255f2f1a2f2e79f525 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/fields_on_correct_type.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/fragments_on_composite_types.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/fragments_on_composite_types.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa2a9e6ad6b9902c9684ce88b34f9f0d0395efd2 Binary files /dev/null and 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/fragments_on_composite_types.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/known_argument_names.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/known_argument_names.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..458cafb6e260eb4c15915f1dd4ca994ebc53b66b Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/known_argument_names.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/known_directives.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/known_directives.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8de2709625c0098ec05d2580bd64236126a3450 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/known_directives.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/known_fragment_names.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/known_fragment_names.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..943dc8152467ed72748a72f6d8911262507df91f Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/known_fragment_names.cpython-313.pyc differ diff --git 
a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/known_type_names.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/known_type_names.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4df4f70e05e8e07286fa0bc7180e684841446fde Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/known_type_names.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/lone_anonymous_operation.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/lone_anonymous_operation.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cf38d1bd757f8474ac7908d22101471f22e7103 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/lone_anonymous_operation.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/no_fragment_cycles.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/no_fragment_cycles.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0cf0fccfb6ef14fc3b7fd19268c1a5c44707692 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/no_fragment_cycles.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/no_undefined_variables.cpython-313.pyc 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/no_undefined_variables.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7df242e5310ab9e51f6140858e445d227136d94b Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/no_undefined_variables.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/no_unused_fragments.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/no_unused_fragments.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..563b1dde8a77f9d5a7f1584a7858d702351278b4 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/no_unused_fragments.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/no_unused_variables.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/no_unused_variables.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f2dd96862a975a458cf23dfce8ee51de9c911d3 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/no_unused_variables.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/overlapping_fields_can_be_merged.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/overlapping_fields_can_be_merged.cpython-313.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..8377388ace99189a46066f7450ff0de485fddf95 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/overlapping_fields_can_be_merged.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/possible_fragment_spreads.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/possible_fragment_spreads.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc9d2a1ea45df9b068584e7df4b81cdd9d836762 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/possible_fragment_spreads.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/provided_non_null_arguments.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/provided_non_null_arguments.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df44554f8933c592c199a44615cf5a4928c6048f Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/provided_non_null_arguments.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/scalar_leafs.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/scalar_leafs.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..528f6f0425e4fb1028790d5abc268986c3a139b4 Binary files /dev/null and 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/scalar_leafs.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/unique_argument_names.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/unique_argument_names.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6fc5804ec792bb86fabeb4a2075c2b3d9fc71f3 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/unique_argument_names.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/unique_fragment_names.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/unique_fragment_names.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af790de8a8ae37b1cf03c785482e958bb6ce65c7 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/unique_fragment_names.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/unique_input_field_names.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/unique_input_field_names.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..994cfa66444fb45a65e8ca4e1f2527372b629633 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/unique_input_field_names.cpython-313.pyc differ diff --git 
a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/unique_operation_names.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/unique_operation_names.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f6eaca111e5571e55f56f5c7a16e45c7e664436 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/unique_operation_names.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/unique_variable_names.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/unique_variable_names.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d942937474944572b0f9056f14284d43ee9ed35 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/unique_variable_names.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/variables_are_input_types.cpython-313.pyc b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/variables_are_input_types.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a173b1c2ba6a89080dbeda546e8dda8d76da04db Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/variables_are_input_types.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/variables_in_allowed_position.cpython-313.pyc 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/variables_in_allowed_position.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1264556c6c8691725a0b7b320a52323d51074a9 Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/__pycache__/variables_in_allowed_position.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/arguments_of_correct_type.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/arguments_of_correct_type.py new file mode 100644 index 0000000000000000000000000000000000000000..011fae79b59d229c156afa812c78d427879a99d6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/arguments_of_correct_type.py @@ -0,0 +1,24 @@ +from ...error import GraphQLError +from ...language.printer import print_ast +from ...utils.is_valid_literal_value import is_valid_literal_value +from .base import ValidationRule + + +class ArgumentsOfCorrectType(ValidationRule): + + def enter_Argument(self, node, key, parent, path, ancestors): + arg_def = self.context.get_argument() + if arg_def: + errors = is_valid_literal_value(arg_def.type, node.value) + if errors: + self.context.report_error(GraphQLError( + self.bad_value_message(node.name.value, arg_def.type, + print_ast(node.value), errors), + [node.value] + )) + return False + + @staticmethod + def bad_value_message(arg_name, type, value, verbose_errors): + message = (u'\n' + u'\n'.join(verbose_errors)) if verbose_errors else '' + return 'Argument "{}" has invalid value {}.{}'.format(arg_name, value, message) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/base.py 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/base.py new file mode 100644 index 0000000000000000000000000000000000000000..43bb53b7493ad3711fb0d7d99e195e200cfa51c6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/base.py @@ -0,0 +1,8 @@ +from ...language.visitor import Visitor + + +class ValidationRule(Visitor): + __slots__ = 'context', + + def __init__(self, context): + self.context = context diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/default_values_of_correct_type.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/default_values_of_correct_type.py new file mode 100644 index 0000000000000000000000000000000000000000..ad6346b46573289f43040425938cb327007e2393 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/default_values_of_correct_type.py @@ -0,0 +1,44 @@ +from ...error import GraphQLError +from ...language.printer import print_ast +from ...type.definition import GraphQLNonNull +from ...utils.is_valid_literal_value import is_valid_literal_value +from .base import ValidationRule + + +class DefaultValuesOfCorrectType(ValidationRule): + + def enter_VariableDefinition(self, node, key, parent, path, ancestors): + name = node.variable.name.value + default_value = node.default_value + type = self.context.get_input_type() + + if isinstance(type, GraphQLNonNull) and default_value: + self.context.report_error(GraphQLError( + self.default_for_non_null_arg_message(name, type, type.of_type), + [default_value] + )) + + if type and default_value: + errors = is_valid_literal_value(type, default_value) + if errors: + self.context.report_error(GraphQLError( + self.bad_value_for_default_arg_message(name, type, print_ast(default_value), errors), + [default_value] + )) + return False + + 
def enter_SelectionSet(self, node, key, parent, path, ancestors): + return False + + def enter_FragmentDefinition(self, node, key, parent, path, ancestors): + return False + + @staticmethod + def default_for_non_null_arg_message(var_name, type, guess_type): + return u'Variable "${}" of type "{}" is required and will not use the default value. ' \ + u'Perhaps you meant to use type "{}".'.format(var_name, type, guess_type) + + @staticmethod + def bad_value_for_default_arg_message(var_name, type, value, verbose_errors): + message = (u'\n' + u'\n'.join(verbose_errors)) if verbose_errors else u'' + return u'Variable "${}" of type "{}" has invalid default value: {}.{}'.format(var_name, type, value, message) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fields_on_correct_type.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fields_on_correct_type.py new file mode 100644 index 0000000000000000000000000000000000000000..55bb2221c9e4df059215227dcecacf9ec884e6e4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fields_on_correct_type.py @@ -0,0 +1,113 @@ +from collections import Counter + +from ...error import GraphQLError +from ...pyutils.ordereddict import OrderedDict +from ...type.definition import (GraphQLInterfaceType, GraphQLObjectType, + GraphQLUnionType) +from ...utils.quoted_or_list import quoted_or_list +from ...utils.suggestion_list import suggestion_list +from .base import ValidationRule + +try: + # Python 2 + from itertools import izip +except ImportError: + # Python 3 + izip = zip + + +def _undefined_field_message(field_name, type, suggested_types, + suggested_fields): + message = 'Cannot query field "{}" on type "{}".'.format(field_name, type) + + if suggested_types: + suggestions = quoted_or_list(suggested_types) + message += " Did you mean to use an inline fragment on 
{}?".format(suggestions) + elif suggested_fields: + suggestions = quoted_or_list(suggested_fields) + message += " Did you mean {}?".format(suggestions) + + return message + + +class OrderedCounter(Counter, OrderedDict): + pass + + +class FieldsOnCorrectType(ValidationRule): + '''Fields on correct type + + A GraphQL document is only valid if all fields selected are defined by the + parent type, or are an allowed meta field such as __typenamme + ''' + + def enter_Field(self, node, key, parent, path, ancestors): + parent_type = self.context.get_parent_type() + if not parent_type: + return + + field_def = self.context.get_field_def() + if not field_def: + # This field doesn't exist, lets look for suggestions. + schema = self.context.get_schema() + field_name = node.name.value + + # First determine if there are any suggested types to condition on. + suggested_type_names = get_suggested_type_names(schema, parent_type, field_name) + # if there are no suggested types perhaps it was a typo? + suggested_field_names = [] if suggested_type_names else get_suggested_field_names(schema, parent_type, field_name) + + # report an error including helpful suggestions. + self.context.report_error(GraphQLError( + _undefined_field_message(field_name, parent_type.name, suggested_type_names, suggested_field_names), + [node] + )) + + +def get_suggested_type_names(schema, output_type, field_name): + '''Go through all of the implementations of type, as well as the interfaces + that they implement. If any of those types include the provided field, + suggest them, sorted by how often the type is referenced, starting + with Interfaces.''' + + if isinstance(output_type, (GraphQLInterfaceType, GraphQLUnionType)): + suggested_object_types = [] + interface_usage_count = OrderedDict() + for possible_type in schema.get_possible_types(output_type): + if not possible_type.fields.get(field_name): + return + + # This object type defines this field. 
+ suggested_object_types.append(possible_type.name) + + for possible_interface in possible_type.interfaces: + if not possible_interface.fields.get(field_name): + continue + + # This interface type defines this field. + interface_usage_count[possible_interface.name] = ( + interface_usage_count.get(possible_interface.name, 0) + 1) + + # Suggest interface types based on how common they are. + suggested_interface_types = sorted(list(interface_usage_count.keys()), key=lambda k: interface_usage_count[k], + reverse=True) + + # Suggest both interface and object types. + suggested_interface_types.extend(suggested_object_types) + return suggested_interface_types + + # Otherwise, must be an Object type, which does not have possible fields. + return [] + + +def get_suggested_field_names(schema, graphql_type, field_name): + '''For the field name provided, determine if there are any similar field names + that may be the result of a typo.''' + + if isinstance(graphql_type, (GraphQLInterfaceType, GraphQLObjectType)): + possible_field_names = list(graphql_type.fields.keys()) + + return suggestion_list(field_name, possible_field_names) + + # Otherwise, must be a Union type, which does not define fields. 
+ return [] diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fragments_on_composite_types.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fragments_on_composite_types.py new file mode 100644 index 0000000000000000000000000000000000000000..a95e247c08fa14846786520ad9a4720bce0dba32 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/fragments_on_composite_types.py @@ -0,0 +1,33 @@ +from ...error import GraphQLError +from ...language.printer import print_ast +from ...type.definition import is_composite_type +from .base import ValidationRule + + +class FragmentsOnCompositeTypes(ValidationRule): + + def enter_InlineFragment(self, node, key, parent, path, ancestors): + type = self.context.get_type() + + if node.type_condition and type and not is_composite_type(type): + self.context.report_error(GraphQLError( + self.inline_fragment_on_non_composite_error_message(print_ast(node.type_condition)), + [node.type_condition] + )) + + def enter_FragmentDefinition(self, node, key, parent, path, ancestors): + type = self.context.get_type() + + if type and not is_composite_type(type): + self.context.report_error(GraphQLError( + self.fragment_on_non_composite_error_message(node.name.value, print_ast(node.type_condition)), + [node.type_condition] + )) + + @staticmethod + def inline_fragment_on_non_composite_error_message(type): + return 'Fragment cannot condition on non composite type "{}".'.format(type) + + @staticmethod + def fragment_on_non_composite_error_message(frag_name, type): + return 'Fragment "{}" cannot condition on non composite type "{}".'.format(frag_name, type) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_argument_names.py 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_argument_names.py new file mode 100644 index 0000000000000000000000000000000000000000..2b5860397ec2c58d9f45fee3e3eebdc0358fad80 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_argument_names.py @@ -0,0 +1,70 @@ +from ...error import GraphQLError +from ...language import ast +from ...utils.quoted_or_list import quoted_or_list +from ...utils.suggestion_list import suggestion_list +from .base import ValidationRule + + +def _unknown_arg_message(arg_name, field_name, type, suggested_args): + message = 'Unknown argument "{}" on field "{}" of type "{}".'.format(arg_name, field_name, type) + if suggested_args: + message += ' Did you mean {}?'.format(quoted_or_list(suggested_args)) + + return message + + +def _unknown_directive_arg_message(arg_name, directive_name, suggested_args): + message = 'Unknown argument "{}" on directive "@{}".'.format(arg_name, directive_name) + if suggested_args: + message += ' Did you mean {}?'.format(quoted_or_list(suggested_args)) + + return message + + +class KnownArgumentNames(ValidationRule): + + def enter_Argument(self, node, key, parent, path, ancestors): + argument_of = ancestors[-1] + + if isinstance(argument_of, ast.Field): + field_def = self.context.get_field_def() + if not field_def: + return + + field_arg_def = field_def.args.get(node.name.value) + + if not field_arg_def: + parent_type = self.context.get_parent_type() + assert parent_type + self.context.report_error(GraphQLError( + _unknown_arg_message( + node.name.value, + argument_of.name.value, + parent_type.name, + suggestion_list( + node.name.value, + (arg_name for arg_name in field_def.args.keys()) + ) + ), + [node] + )) + + elif isinstance(argument_of, ast.Directive): + directive = self.context.get_directive() + if not directive: + return + + directive_arg_def = directive.args.get(node.name.value) + + 
if not directive_arg_def: + self.context.report_error(GraphQLError( + _unknown_directive_arg_message( + node.name.value, + directive.name, + suggestion_list( + node.name.value, + (arg_name for arg_name in directive.args.keys()) + ) + ), + [node] + )) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_directives.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_directives.py new file mode 100644 index 0000000000000000000000000000000000000000..f672105df350babee2dd9b03fb11b5609dded655 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_directives.py @@ -0,0 +1,97 @@ +from ...error import GraphQLError +from ...language import ast +from ...type.directives import DirectiveLocation +from .base import ValidationRule + + +class KnownDirectives(ValidationRule): + + def enter_Directive(self, node, key, parent, path, ancestors): + directive_def = next(( + definition for definition in self.context.get_schema().get_directives() + if definition.name == node.name.value + ), None) + + if not directive_def: + return self.context.report_error(GraphQLError( + self.unknown_directive_message(node.name.value), + [node] + )) + + candidate_location = get_directive_location_for_ast_path(ancestors) + if not candidate_location: + self.context.report_error(GraphQLError( + self.misplaced_directive_message(node.name.value, node.type), + [node] + )) + elif candidate_location not in directive_def.locations: + self.context.report_error(GraphQLError( + self.misplaced_directive_message(node.name.value, candidate_location), + [node] + )) + + @staticmethod + def unknown_directive_message(directive_name): + return 'Unknown directive "{}".'.format(directive_name) + + @staticmethod + def misplaced_directive_message(directive_name, location): + return 'Directive "{}" may not be used on "{}".'.format(directive_name, 
location) + + +_operation_definition_map = { + 'query': DirectiveLocation.QUERY, + 'mutation': DirectiveLocation.MUTATION, + 'subscription': DirectiveLocation.SUBSCRIPTION, +} + + +def get_directive_location_for_ast_path(ancestors): + applied_to = ancestors[-1] + if isinstance(applied_to, ast.OperationDefinition): + return _operation_definition_map.get(applied_to.operation) + + elif isinstance(applied_to, ast.Field): + return DirectiveLocation.FIELD + + elif isinstance(applied_to, ast.FragmentSpread): + return DirectiveLocation.FRAGMENT_SPREAD + + elif isinstance(applied_to, ast.InlineFragment): + return DirectiveLocation.INLINE_FRAGMENT + + elif isinstance(applied_to, ast.FragmentDefinition): + return DirectiveLocation.FRAGMENT_DEFINITION + + elif isinstance(applied_to, ast.SchemaDefinition): + return DirectiveLocation.SCHEMA + + elif isinstance(applied_to, ast.ScalarTypeDefinition): + return DirectiveLocation.SCALAR + + elif isinstance(applied_to, ast.ObjectTypeDefinition): + return DirectiveLocation.OBJECT + + elif isinstance(applied_to, ast.FieldDefinition): + return DirectiveLocation.FIELD_DEFINITION + + elif isinstance(applied_to, ast.InterfaceTypeDefinition): + return DirectiveLocation.INTERFACE + + elif isinstance(applied_to, ast.UnionTypeDefinition): + return DirectiveLocation.UNION + + elif isinstance(applied_to, ast.EnumTypeDefinition): + return DirectiveLocation.ENUM + + elif isinstance(applied_to, ast.EnumValueDefinition): + return DirectiveLocation.ENUM_VALUE + + elif isinstance(applied_to, ast.InputObjectTypeDefinition): + return DirectiveLocation.INPUT_OBJECT + + elif isinstance(applied_to, ast.InputValueDefinition): + parent_node = ancestors[-3] + return (DirectiveLocation.INPUT_FIELD_DEFINITION + if isinstance(parent_node, ast.InputObjectTypeDefinition) + else DirectiveLocation.ARGUMENT_DEFINITION) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_fragment_names.py 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_fragment_names.py new file mode 100644 index 0000000000000000000000000000000000000000..6c7375e3c4c9d8b083b3873643114bbf7a8625b3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_fragment_names.py @@ -0,0 +1,19 @@ +from ...error import GraphQLError +from .base import ValidationRule + + +class KnownFragmentNames(ValidationRule): + + def enter_FragmentSpread(self, node, key, parent, path, ancestors): + fragment_name = node.name.value + fragment = self.context.get_fragment(fragment_name) + + if not fragment: + self.context.report_error(GraphQLError( + self.unknown_fragment_message(fragment_name), + [node.name] + )) + + @staticmethod + def unknown_fragment_message(fragment_name): + return 'Unknown fragment "{}".'.format(fragment_name) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_type_names.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_type_names.py new file mode 100644 index 0000000000000000000000000000000000000000..d7d1369933238e797f3efced67c7a38c5dbea689 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/known_type_names.py @@ -0,0 +1,43 @@ +from ...error import GraphQLError +from ...utils.quoted_or_list import quoted_or_list +from ...utils.suggestion_list import suggestion_list +from .base import ValidationRule + + +def _unknown_type_message(type, suggested_types): + message = 'Unknown type "{}".'.format(type) + if suggested_types: + message += ' Perhaps you meant {}?'.format(quoted_or_list(suggested_types)) + + return message + + +class KnownTypeNames(ValidationRule): + + def enter_ObjectTypeDefinition(self, node, *args): + return False + + def enter_InterfaceTypeDefinition(self, node, *args): + 
return False + + def enter_UnionTypeDefinition(self, node, *args): + return False + + def enter_InputObjectTypeDefinition(self, node, *args): + return False + + def enter_NamedType(self, node, *args): + schema = self.context.get_schema() + type_name = node.name.value + type = schema.get_type(type_name) + + if not type: + self.context.report_error( + GraphQLError( + _unknown_type_message( + type_name, + suggestion_list(type_name, list(schema.get_type_map().keys())) + ), + [node] + ) + ) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/lone_anonymous_operation.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/lone_anonymous_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..462a67f31174753baac5b421a0142c0c12c04e08 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/lone_anonymous_operation.py @@ -0,0 +1,23 @@ +from ...error import GraphQLError +from ...language import ast +from .base import ValidationRule + + +class LoneAnonymousOperation(ValidationRule): + __slots__ = 'operation_count', + + def __init__(self, context): + self.operation_count = 0 + super(LoneAnonymousOperation, self).__init__(context) + + def enter_Document(self, node, key, parent, path, ancestors): + self.operation_count = \ + sum(1 for definition in node.definitions if isinstance(definition, ast.OperationDefinition)) + + def enter_OperationDefinition(self, node, key, parent, path, ancestors): + if not node.name and self.operation_count > 1: + self.context.report_error(GraphQLError(self.anonymous_operation_not_alone_message(), [node])) + + @staticmethod + def anonymous_operation_not_alone_message(): + return 'This anonymous operation must be the only defined operation.' 
diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_fragment_cycles.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_fragment_cycles.py new file mode 100644 index 0000000000000000000000000000000000000000..d2e0d79f0e4bc0130180a44ab40fa6102f8d68b3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_fragment_cycles.py @@ -0,0 +1,59 @@ +from ...error import GraphQLError +from .base import ValidationRule + + +class NoFragmentCycles(ValidationRule): + __slots__ = 'errors', 'visited_frags', 'spread_path', 'spread_path_index_by_name' + + def __init__(self, context): + super(NoFragmentCycles, self).__init__(context) + self.errors = [] + self.visited_frags = set() + self.spread_path = [] + self.spread_path_index_by_name = {} + + def enter_OperationDefinition(self, node, key, parent, path, ancestors): + return False + + def enter_FragmentDefinition(self, node, key, parent, path, ancestors): + if node.name.value not in self.visited_frags: + self.detect_cycle_recursive(node) + return False + + def detect_cycle_recursive(self, fragment): + fragment_name = fragment.name.value + self.visited_frags.add(fragment_name) + + spread_nodes = self.context.get_fragment_spreads(fragment.selection_set) + if not spread_nodes: + return + + self.spread_path_index_by_name[fragment_name] = len(self.spread_path) + + for spread_node in spread_nodes: + spread_name = spread_node.name.value + cycle_index = self.spread_path_index_by_name.get(spread_name) + + if cycle_index is None: + self.spread_path.append(spread_node) + if spread_name not in self.visited_frags: + spread_fragment = self.context.get_fragment(spread_name) + if spread_fragment: + self.detect_cycle_recursive(spread_fragment) + self.spread_path.pop() + else: + cycle_path = self.spread_path[cycle_index:] + self.context.report_error(GraphQLError( + 
self.cycle_error_message( + spread_name, + [s.name.value for s in cycle_path] + ), + cycle_path + [spread_node] + )) + + self.spread_path_index_by_name[fragment_name] = None + + @staticmethod + def cycle_error_message(fragment_name, spread_names): + via = ' via {}'.format(', '.join(spread_names)) if spread_names else '' + return 'Cannot spread fragment "{}" within itself{}.'.format(fragment_name, via) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_undefined_variables.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_undefined_variables.py new file mode 100644 index 0000000000000000000000000000000000000000..81c9c4953cb57cb91b30e2488693f350704012c3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_undefined_variables.py @@ -0,0 +1,36 @@ +from ...error import GraphQLError +from .base import ValidationRule + + +class NoUndefinedVariables(ValidationRule): + __slots__ = 'defined_variable_names', + + def __init__(self, context): + self.defined_variable_names = set() + super(NoUndefinedVariables, self).__init__(context) + + @staticmethod + def undefined_var_message(var_name, op_name=None): + if op_name: + return 'Variable "${}" is not defined by operation "{}".'.format( + var_name, op_name + ) + return 'Variable "${}" is not defined.'.format(var_name) + + def enter_OperationDefinition(self, operation, key, parent, path, ancestors): + self.defined_variable_names = set() + + def leave_OperationDefinition(self, operation, key, parent, path, ancestors): + usages = self.context.get_recursive_variable_usages(operation) + + for variable_usage in usages: + node = variable_usage.node + var_name = node.name.value + if var_name not in self.defined_variable_names: + self.context.report_error(GraphQLError( + self.undefined_var_message(var_name, operation.name and operation.name.value), + [node, 
operation] + )) + + def enter_VariableDefinition(self, node, key, parent, path, ancestors): + self.defined_variable_names.add(node.variable.name.value) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_fragments.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_fragments.py new file mode 100644 index 0000000000000000000000000000000000000000..8d35f483612600227fa158ceb39eebd7eb8cdd6f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_fragments.py @@ -0,0 +1,38 @@ +from ...error import GraphQLError +from .base import ValidationRule + + +class NoUnusedFragments(ValidationRule): + __slots__ = 'fragment_definitions', 'operation_definitions', 'fragment_adjacencies', 'spread_names' + + def __init__(self, context): + super(NoUnusedFragments, self).__init__(context) + self.operation_definitions = [] + self.fragment_definitions = [] + + def enter_OperationDefinition(self, node, key, parent, path, ancestors): + self.operation_definitions.append(node) + return False + + def enter_FragmentDefinition(self, node, key, parent, path, ancestors): + self.fragment_definitions.append(node) + return False + + def leave_Document(self, node, key, parent, path, ancestors): + fragment_names_used = set() + + for operation in self.operation_definitions: + fragments = self.context.get_recursively_referenced_fragments(operation) + for fragment in fragments: + fragment_names_used.add(fragment.name.value) + + for fragment_definition in self.fragment_definitions: + if fragment_definition.name.value not in fragment_names_used: + self.context.report_error(GraphQLError( + self.unused_fragment_message(fragment_definition.name.value), + [fragment_definition] + )) + + @staticmethod + def unused_fragment_message(fragment_name): + return 'Fragment "{}" is never used.'.format(fragment_name) diff --git 
a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_variables.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_variables.py new file mode 100644 index 0000000000000000000000000000000000000000..a799e826472214019ab8b9037954165f4a299749 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/no_unused_variables.py @@ -0,0 +1,37 @@ +from ...error import GraphQLError +from .base import ValidationRule + + +class NoUnusedVariables(ValidationRule): + __slots__ = 'variable_definitions' + + def __init__(self, context): + self.variable_definitions = [] + super(NoUnusedVariables, self).__init__(context) + + def enter_OperationDefinition(self, node, key, parent, path, ancestors): + self.variable_definitions = [] + + def leave_OperationDefinition(self, operation, key, parent, path, ancestors): + variable_name_used = set() + usages = self.context.get_recursive_variable_usages(operation) + op_name = operation.name and operation.name.value or None + + for variable_usage in usages: + variable_name_used.add(variable_usage.node.name.value) + + for variable_definition in self.variable_definitions: + if variable_definition.variable.name.value not in variable_name_used: + self.context.report_error(GraphQLError( + self.unused_variable_message(variable_definition.variable.name.value, op_name), + [variable_definition] + )) + + def enter_VariableDefinition(self, node, key, parent, path, ancestors): + self.variable_definitions.append(node) + + @staticmethod + def unused_variable_message(variable_name, op_name): + if op_name: + return 'Variable "${}" is never used in operation "{}".'.format(variable_name, op_name) + return 'Variable "${}" is never used.'.format(variable_name) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/overlapping_fields_can_be_merged.py 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/overlapping_fields_can_be_merged.py new file mode 100644 index 0000000000000000000000000000000000000000..bc7f51acfc431a563cf5c3f7aeef1f4e9bacedf9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/overlapping_fields_can_be_merged.py @@ -0,0 +1,529 @@ +import itertools +from collections import OrderedDict + +from ...error import GraphQLError +from ...language import ast +from ...language.printer import print_ast +from ...pyutils.pair_set import PairSet +from ...type.definition import (GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, + get_named_type, is_leaf_type) +from ...utils.type_comparators import is_equal_type +from ...utils.type_from_ast import type_from_ast +from .base import ValidationRule + + +class OverlappingFieldsCanBeMerged(ValidationRule): + __slots__ = ('_compared_fragments', '_cached_fields_and_fragment_names', ) + + def __init__(self, context): + super(OverlappingFieldsCanBeMerged, self).__init__(context) + # A memoization for when two fragments are compared "between" each other for + # conflicts. Two fragments may be compared many times, so memoizing this can + # dramatically improve the performance of this validator. + self._compared_fragments = PairSet() + + # A cache for the "field map" and list of fragment names found in any given + # selection set. Selection sets may be asked for this information multiple + # times, so this improves the performance of this validator. + self._cached_fields_and_fragment_names = {} + + def leave_SelectionSet(self, node, key, parent, path, ancestors): + # Note: we validate on the reverse traversal so deeper conflicts will be + # caught first, for correct calculation of mutual exclusivity and for + # clearer error messages. 
+ # field_map = _collect_field_asts_and_defs( + # self.context, + # self.context.get_parent_type(), + # node + # ) + + # conflicts = _find_conflicts(self.context, False, field_map, self.compared_set) + conflicts = _find_conflicts_within_selection_set(self.context, self._cached_fields_and_fragment_names, + self._compared_fragments, self.context.get_parent_type(), + node) + + for (reason_name, reason), fields1, fields2 in conflicts: + self.context.report_error(GraphQLError( + self.fields_conflict_message(reason_name, reason), + list(fields1) + list(fields2) + )) + + @staticmethod + def same_type(type1, type2): + return is_equal_type(type1, type2) + # return type1.is_same_type(type2) + + @classmethod + def fields_conflict_message(cls, reason_name, reason): + return ( + 'Fields "{}" conflict because {}. ' + 'Use different aliases on the fields to fetch both if this was ' + 'intentional.' + ).format(reason_name, cls.reason_message(reason)) + + @classmethod + def reason_message(cls, reason): + if isinstance(reason, list): + return ' and '.join('subfields "{}" conflict because {}'.format(reason_name, cls.reason_message(sub_reason)) + for reason_name, sub_reason in reason) + + return reason + + +# Algorithm: +# +# Conflicts occur when two fields exist in a query which will produce the same +# response name, but represent differing values, thus creating a conflict. +# The algorithm below finds all conflicts via making a series of comparisons +# between fields. In order to compare as few fields as possible, this makes +# a series of comparisons "within" sets of fields and "between" sets of fields. +# +# Given any selection set, a collection produces both a set of fields by +# also including all inline fragments, as well as a list of fragments +# referenced by fragment spreads. +# +# A) Each selection set represented in the document first compares "within" its +# collected set of fields, finding any conflicts between every pair of +# overlapping fields. 
+# Note: This is the only time that a the fields "within" a set are compared +# to each other. After this only fields "between" sets are compared. +# +# B) Also, if any fragment is referenced in a selection set, then a +# comparison is made "between" the original set of fields and the +# referenced fragment. +# +# C) Also, if multiple fragments are referenced, then comparisons +# are made "between" each referenced fragment. +# +# D) When comparing "between" a set of fields and a referenced fragment, first +# a comparison is made between each field in the original set of fields and +# each field in the the referenced set of fields. +# +# E) Also, if any fragment is referenced in the referenced selection set, +# then a comparison is made "between" the original set of fields and the +# referenced fragment (recursively referring to step D). +# +# F) When comparing "between" two fragments, first a comparison is made between +# each field in the first referenced set of fields and each field in the the +# second referenced set of fields. +# +# G) Also, any fragments referenced by the first must be compared to the +# second, and any fragments referenced by the second must be compared to the +# first (recursively referring to step F). +# +# H) When comparing two fields, if both have selection sets, then a comparison +# is made "between" both selection sets, first comparing the set of fields in +# the first selection set with the set of fields in the second. +# +# I) Also, if any fragment is referenced in either selection set, then a +# comparison is made "between" the other set of fields and the +# referenced fragment. +# +# J) Also, if two fragments are referenced in both selection sets, then a +# comparison is made "between" the two fragments. 
+ +def _find_conflicts_within_selection_set(context, cached_fields_and_fragment_names, compared_fragments, parent_type, + selection_set): + """Find all conflicts found "within" a selection set, including those found via spreading in fragments. + + Called when visiting each SelectionSet in the GraphQL Document. + """ + conflicts = [] + + field_map, fragment_names = _get_fields_and_fragments_names(context, cached_fields_and_fragment_names, parent_type, + selection_set) + + # (A) Find all conflicts "within" the fields of this selection set. + # Note: this is the *only place* `collect_conflicts_within` is called. + _collect_conflicts_within( + context, + conflicts, + cached_fields_and_fragment_names, + compared_fragments, + field_map + ) + + # (B) Then collect conflicts between these fields and those represented by + # each spread fragment name found. + for i, fragment_name in enumerate(fragment_names): + _collect_conflicts_between_fields_and_fragment( + context, + conflicts, + cached_fields_and_fragment_names, + compared_fragments, + False, + field_map, + fragment_name, + ) + + # (C) Then compare this fragment with all other fragments found in this + # selection set to collect conflicts within fragments spread together. + # This compares each item in the list of fragment names to every other item + # in that same list (except for itself). 
+ for other_fragment_name in fragment_names[i+1:]: + _collect_conflicts_between_fragments( + context, + conflicts, + cached_fields_and_fragment_names, + compared_fragments, + False, + fragment_name, + other_fragment_name, + ) + + return conflicts + + +def _collect_conflicts_between_fields_and_fragment(context, conflicts, cached_fields_and_fragment_names, + compared_fragments, are_mutually_exclusive, field_map, + fragment_name): + + fragment = context.get_fragment(fragment_name) + + if not fragment: + return None + + field_map2, fragment_names2 = _get_referenced_fields_and_fragment_names(context, cached_fields_and_fragment_names, + fragment) + + # (D) First collect any conflicts between the provided collection of fields + # and the collection of fields represented by the given fragment. + _collect_conflicts_between(context, conflicts, cached_fields_and_fragment_names, compared_fragments, + are_mutually_exclusive, field_map, field_map2) + + # (E) Then collect any conflicts between the provided collection of fields + # and any fragment names found in the given fragment. + for fragment_name2 in fragment_names2: + _collect_conflicts_between_fields_and_fragment(context, conflicts, cached_fields_and_fragment_names, + compared_fragments, are_mutually_exclusive, field_map, + fragment_name2) + + +# Collect all conflicts found between two fragments, including via spreading in +# any nested fragments +def _collect_conflicts_between_fragments(context, conflicts, cached_fields_and_fragment_names, compared_fragments, + are_mutually_exclusive, fragment_name1, fragment_name2): + + fragment1 = context.get_fragment(fragment_name1) + fragment2 = context.get_fragment(fragment_name2) + + if not fragment1 or not fragment2: + return None + + # No need to compare a fragment to itself. + if fragment1 == fragment2: + return None + + # Memoize so two fragments are not compared for conflicts more than once. 
+ if compared_fragments.has(fragment_name1, fragment_name2, are_mutually_exclusive): + return None + + compared_fragments.add(fragment_name1, fragment_name2, are_mutually_exclusive) + + field_map1, fragment_names1 = _get_referenced_fields_and_fragment_names(context, cached_fields_and_fragment_names, + fragment1) + + field_map2, fragment_names2 = _get_referenced_fields_and_fragment_names(context, cached_fields_and_fragment_names, + fragment2) + + # (F) First, collect all conflicts between these two collections of fields + # (not including any nested fragments) + _collect_conflicts_between(context, conflicts, cached_fields_and_fragment_names, compared_fragments, + are_mutually_exclusive, field_map1, field_map2) + + # (G) Then collect conflicts between the first fragment and any nested + # fragments spread in the second fragment. + for _fragment_name2 in fragment_names2: + _collect_conflicts_between_fragments(context, conflicts, cached_fields_and_fragment_names, compared_fragments, + are_mutually_exclusive, fragment_name1, _fragment_name2) + + # (G) Then collect conflicts between the second fragment and any nested + # fragments spread in the first fragment. + for _fragment_name1 in fragment_names1: + _collect_conflicts_between_fragments(context, conflicts, cached_fields_and_fragment_names, compared_fragments, + are_mutually_exclusive, _fragment_name1, fragment_name2) + + +def _find_conflicts_between_sub_selection_sets(context, cached_fields_and_fragment_names, compared_fragments, + are_mutually_exclusive, parent_type1, selection_set1, + parent_type2, selection_set2): + """Find all conflicts found between two selection sets. + + Includes those found via spreading in fragments. Called when determining if conflicts exist + between the sub-fields of two overlapping fields. 
+ """ + + conflicts = [] + + field_map1, fragment_names1 = _get_fields_and_fragments_names(context, cached_fields_and_fragment_names, + parent_type1, selection_set1) + + field_map2, fragment_names2 = _get_fields_and_fragments_names(context, cached_fields_and_fragment_names, + parent_type2, selection_set2) + + # (H) First, collect all conflicts between these two collections of field. + _collect_conflicts_between(context, conflicts, cached_fields_and_fragment_names, compared_fragments, + are_mutually_exclusive, field_map1, field_map2) + + # (I) Then collect conflicts between the first collection of fields and + # those referenced by each fragment name associated with the second. + for fragment_name2 in fragment_names2: + _collect_conflicts_between_fields_and_fragment(context, conflicts, cached_fields_and_fragment_names, + compared_fragments, are_mutually_exclusive, field_map1, + fragment_name2) + + # (I) Then collect conflicts between the second collection of fields and + # those referenced by each fragment name associated with the first. + for fragment_name1 in fragment_names1: + _collect_conflicts_between_fields_and_fragment(context, conflicts, cached_fields_and_fragment_names, + compared_fragments, are_mutually_exclusive, field_map2, + fragment_name1) + + # (J) Also collect conflicts between any fragment names by the first and + # fragment names by the second. This compares each item in the first set of + # names to each item in the second set of names. 
+ for fragment_name1 in fragment_names1: + for fragment_name2 in fragment_names2: + _collect_conflicts_between_fragments(context, conflicts, cached_fields_and_fragment_names, + compared_fragments, are_mutually_exclusive, + fragment_name1, fragment_name2) + + return conflicts + + +def _collect_conflicts_within(context, conflicts, cached_fields_and_fragment_names, compared_fragments, field_map): + """Collect all Conflicts "within" one collection of fields.""" + + # field map is a keyed collection, where each key represents a response + # name and the value at that key is a list of all fields which provide that + # response name. For every response name, if there are multiple fields, they + # must be compared to find a potential conflict. + for response_name, fields in list(field_map.items()): + # This compares every field in the list to every other field in this list + # (except to itself). If the list only has one item, nothing needs to + # be compared. + for i, field in enumerate(fields): + for other_field in fields[i+1:]: + # within one collection is never mutually exclusive + conflict = _find_conflict(context, cached_fields_and_fragment_names, compared_fragments, False, + response_name, field, other_field) + if conflict: + conflicts.append(conflict) + + +def _collect_conflicts_between(context, conflicts, cached_fields_and_fragment_names, compared_fragments, + parent_fields_are_mutually_exclusive, field_map1, field_map2): + """Collect all Conflicts between two collections of fields. + + This is similar to, but different from the `collect_conflicts_within` function above. This check assumes that + `collect_conflicts_within` has already been called on each provided collection of fields. + This is true because this validator traverses each individual selection set. + """ + # A field map is a keyed collection, where each key represents a response + # name and the value at that key is a list of all fields which provide that + # response name. 
For any response name which appears in both provided field + # maps, each field from the first field map must be compared to every field + # in the second field map to find potential conflicts. + for response_name, fields1 in list(field_map1.items()): + fields2 = field_map2.get(response_name) + + if fields2: + for field1 in fields1: + for field2 in fields2: + conflict = _find_conflict(context, cached_fields_and_fragment_names, compared_fragments, + parent_fields_are_mutually_exclusive, response_name, field1, field2) + + if conflict: + conflicts.append(conflict) + + +def _find_conflict(context, cached_fields_and_fragment_names, compared_fragments, parent_fields_are_mutually_exclusive, + response_name, field1, field2): + """Determines if there is a conflict between two particular fields.""" + parent_type1, ast1, def1 = field1 + parent_type2, ast2, def2 = field2 + + # If it is known that two fields could not possibly apply at the same + # time, due to the parent types, then it is safe to permit them to diverge + # in aliased field or arguments used as they will not present any ambiguity + # by differing. + # It is known that two parent types could never overlap if they are + # different Object types. Interface or Union types might overlap - if not + # in the current state of the schema, then perhaps in some future version, + # thus may not safely diverge. + + are_mutually_exclusive = ( + parent_fields_are_mutually_exclusive or ( + parent_type1 != parent_type2 and + isinstance(parent_type1, GraphQLObjectType) and + isinstance(parent_type2, GraphQLObjectType) + ) + ) + + # The return type for each field. + type1 = def1 and def1.type + type2 = def2 and def2.type + + if not are_mutually_exclusive: + # Two aliases must refer to the same field. 
+ name1 = ast1.name.value + name2 = ast2.name.value + + if name1 != name2: + return ( + (response_name, '{} and {} are different fields'.format(name1, name2)), + [ast1], + [ast2] + ) + + # Two field calls must have the same arguments. + if not _same_arguments(ast1.arguments, ast2.arguments): + return ( + (response_name, 'they have differing arguments'), + [ast1], + [ast2] + ) + + if type1 and type2 and do_types_conflict(type1, type2): + return ( + (response_name, 'they return conflicting types {} and {}'.format(type1, type2)), + [ast1], + [ast2] + ) + + # Collect and compare sub-fields. Use the same "visited fragment names" list + # for both collections so fields in a fragment reference are never + # compared to themselves. + selection_set1 = ast1.selection_set + selection_set2 = ast2.selection_set + + if selection_set1 and selection_set2: + conflicts = _find_conflicts_between_sub_selection_sets(context, cached_fields_and_fragment_names, + compared_fragments, are_mutually_exclusive, + get_named_type(type1), selection_set1, + get_named_type(type2), selection_set2) + + return _subfield_conflicts(conflicts, response_name, ast1, ast2) + + +def _get_fields_and_fragments_names(context, cached_fields_and_fragment_names, parent_type, selection_set): + cached = cached_fields_and_fragment_names.get(selection_set) + + if not cached: + ast_and_defs = OrderedDict() + fragment_names = OrderedDict() + _collect_fields_and_fragment_names(context, parent_type, selection_set, ast_and_defs, fragment_names) + cached = [ast_and_defs, list(fragment_names.keys())] + cached_fields_and_fragment_names[selection_set] = cached + + return cached + + +def _get_referenced_fields_and_fragment_names(context, cached_fields_and_fragment_names, fragment): + """Given a reference to a fragment, return the represented collection of fields as well as a list of + nested fragment names referenced via fragment spreads.""" + + # Short-circuit building a type from the AST if possible. 
+ cached = cached_fields_and_fragment_names.get(fragment.selection_set) + + if cached: + return cached + + fragment_type = type_from_ast(context.get_schema(), fragment.type_condition) + + return _get_fields_and_fragments_names(context, cached_fields_and_fragment_names, + fragment_type, fragment.selection_set) + + +def _collect_fields_and_fragment_names(context, parent_type, selection_set, ast_and_defs, fragment_names): + + for selection in selection_set.selections: + if isinstance(selection, ast.Field): + field_name = selection.name.value + if isinstance(parent_type, (GraphQLObjectType, GraphQLInterfaceType)): + field_def = parent_type.fields.get(field_name) + else: + field_def = None + + response_name = selection.alias.value if selection.alias else field_name + + if not ast_and_defs.get(response_name): + ast_and_defs[response_name] = [] + + ast_and_defs[response_name].append([parent_type, selection, field_def]) + + elif isinstance(selection, ast.FragmentSpread): + fragment_names[selection.name.value] = True + elif isinstance(selection, ast.InlineFragment): + type_condition = selection.type_condition + if type_condition: + inline_fragment_type = type_from_ast(context.get_schema(), selection.type_condition) + else: + inline_fragment_type = parent_type + + _collect_fields_and_fragment_names(context, inline_fragment_type, selection.selection_set, ast_and_defs, + fragment_names) + + +def _subfield_conflicts(conflicts, response_name, ast1, ast2): + """Given a series of Conflicts which occurred between two sub-fields, generate a single Conflict.""" + if conflicts: + return ( + (response_name, [conflict[0] for conflict in conflicts]), + tuple(itertools.chain([ast1], *[conflict[1] for conflict in conflicts])), + tuple(itertools.chain([ast2], *[conflict[2] for conflict in conflicts])) + ) + + +def do_types_conflict(type1, type2): + if isinstance(type1, GraphQLList): + if isinstance(type2, GraphQLList): + return do_types_conflict(type1.of_type, type2.of_type) + return True + 
+ if isinstance(type2, GraphQLList): + if isinstance(type1, GraphQLList): + return do_types_conflict(type1.of_type, type2.of_type) + return True + + if isinstance(type1, GraphQLNonNull): + if isinstance(type2, GraphQLNonNull): + return do_types_conflict(type1.of_type, type2.of_type) + return True + + if isinstance(type2, GraphQLNonNull): + if isinstance(type1, GraphQLNonNull): + return do_types_conflict(type1.of_type, type2.of_type) + return True + + if is_leaf_type(type1) or is_leaf_type(type2): + return type1 != type2 + + return False + + +def _same_value(value1, value2): + return (not value1 and not value2) or print_ast(value1) == print_ast(value2) + + +def _same_arguments(arguments1, arguments2): + # Check to see if they are empty arguments or nones. If they are, we can + # bail out early. + if not (arguments1 or arguments2): + return True + + if len(arguments1) != len(arguments2): + return False + + arguments2_values_to_arg = {a.name.value: a for a in arguments2} + + for argument1 in arguments1: + argument2 = arguments2_values_to_arg.get(argument1.name.value) + if not argument2: + return False + + if not _same_value(argument1.value, argument2.value): + return False + + return True diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/possible_fragment_spreads.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/possible_fragment_spreads.py new file mode 100644 index 0000000000000000000000000000000000000000..b9cc4165867f59e042cd6233340dfa40a568a568 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/possible_fragment_spreads.py @@ -0,0 +1,44 @@ +from ...error import GraphQLError +from ...utils.type_comparators import do_types_overlap +from ...utils.type_from_ast import type_from_ast +from .base import ValidationRule + + +class PossibleFragmentSpreads(ValidationRule): + + def 
enter_InlineFragment(self, node, key, parent, path, ancestors): + frag_type = self.context.get_type() + parent_type = self.context.get_parent_type() + schema = self.context.get_schema() + if frag_type and parent_type and not do_types_overlap(schema, frag_type, parent_type): + self.context.report_error(GraphQLError( + self.type_incompatible_anon_spread_message(parent_type, frag_type), + [node] + )) + + def enter_FragmentSpread(self, node, key, parent, path, ancestors): + frag_name = node.name.value + frag_type = self.get_fragment_type(self.context, frag_name) + parent_type = self.context.get_parent_type() + schema = self.context.get_schema() + if frag_type and parent_type and not do_types_overlap(schema, frag_type, parent_type): + self.context.report_error(GraphQLError( + self.type_incompatible_spread_message(frag_name, parent_type, frag_type), + [node] + )) + + @staticmethod + def get_fragment_type(context, name): + frag = context.get_fragment(name) + return frag and type_from_ast(context.get_schema(), frag.type_condition) + + @staticmethod + def type_incompatible_spread_message(frag_name, parent_type, frag_type): + return 'Fragment {} cannot be spread here as objects of type {} can never be of type {}'.format(frag_name, + parent_type, + frag_type) + + @staticmethod + def type_incompatible_anon_spread_message(parent_type, frag_type): + return 'Fragment cannot be spread here as objects of type {} can never be of type {}'.format(parent_type, + frag_type) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/provided_non_null_arguments.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/provided_non_null_arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..09507277d922740bbc2d16b8e1dbc6401b31fc4b --- /dev/null +++ 
b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/provided_non_null_arguments.py @@ -0,0 +1,46 @@ +from ...error import GraphQLError +from ...type.definition import GraphQLNonNull +from .base import ValidationRule + + +class ProvidedNonNullArguments(ValidationRule): + + def leave_Field(self, node, key, parent, path, ancestors): + field_def = self.context.get_field_def() + if not field_def: + return False + + arg_asts = node.arguments or [] + arg_ast_map = {arg.name.value: arg for arg in arg_asts} + + for arg_name, arg_def in field_def.args.items(): + arg_ast = arg_ast_map.get(arg_name, None) + if not arg_ast and isinstance(arg_def.type, GraphQLNonNull): + self.context.report_error(GraphQLError( + self.missing_field_arg_message(node.name.value, arg_name, arg_def.type), + [node] + )) + + def leave_Directive(self, node, key, parent, path, ancestors): + directive_def = self.context.get_directive() + if not directive_def: + return False + + arg_asts = node.arguments or [] + arg_ast_map = {arg.name.value: arg for arg in arg_asts} + + for arg_name, arg_def in directive_def.args.items(): + arg_ast = arg_ast_map.get(arg_name, None) + if not arg_ast and isinstance(arg_def.type, GraphQLNonNull): + self.context.report_error(GraphQLError( + self.missing_directive_arg_message(node.name.value, arg_name, arg_def.type), + [node] + )) + + @staticmethod + def missing_field_arg_message(name, arg_name, type): + return 'Field "{}" argument "{}" of type "{}" is required but not provided.'.format(name, arg_name, type) + + @staticmethod + def missing_directive_arg_message(name, arg_name, type): + return 'Directive "{}" argument "{}" of type "{}" is required but not provided.'.format(name, arg_name, type) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/scalar_leafs.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/scalar_leafs.py new 
file mode 100644 index 0000000000000000000000000000000000000000..f03efbc7af3ab1116894e589cb52f87ae9813d3e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/scalar_leafs.py @@ -0,0 +1,33 @@ +from ...error import GraphQLError +from ...type.definition import get_named_type, is_leaf_type +from .base import ValidationRule + + +class ScalarLeafs(ValidationRule): + + def enter_Field(self, node, key, parent, path, ancestors): + type = self.context.get_type() + + if not type: + return + + if is_leaf_type(get_named_type(type)): + if node.selection_set: + self.context.report_error(GraphQLError( + self.no_subselection_allowed_message(node.name.value, type), + [node.selection_set] + )) + + elif not node.selection_set: + self.context.report_error(GraphQLError( + self.required_subselection_message(node.name.value, type), + [node] + )) + + @staticmethod + def no_subselection_allowed_message(field, type): + return 'Field "{}" of type "{}" must not have a sub selection.'.format(field, type) + + @staticmethod + def required_subselection_message(field, type): + return 'Field "{}" of type "{}" must have a sub selection.'.format(field, type) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_argument_names.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_argument_names.py new file mode 100644 index 0000000000000000000000000000000000000000..9e25c75c8e5cfbe6e4853f806de68e7f6bbd97a4 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_argument_names.py @@ -0,0 +1,32 @@ +from ...error import GraphQLError +from .base import ValidationRule + + +class UniqueArgumentNames(ValidationRule): + __slots__ = 'known_arg_names', + + def __init__(self, context): + super(UniqueArgumentNames, self).__init__(context) + self.known_arg_names = {} + + 
def enter_Field(self, node, key, parent, path, ancestors): + self.known_arg_names = {} + + def enter_Directive(self, node, key, parent, path, ancestors): + self.known_arg_names = {} + + def enter_Argument(self, node, key, parent, path, ancestors): + arg_name = node.name.value + + if arg_name in self.known_arg_names: + self.context.report_error(GraphQLError( + self.duplicate_arg_message(arg_name), + [self.known_arg_names[arg_name], node.name] + )) + else: + self.known_arg_names[arg_name] = node.name + return False + + @staticmethod + def duplicate_arg_message(field): + return 'There can only be one argument named "{}".'.format(field) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_fragment_names.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_fragment_names.py new file mode 100644 index 0000000000000000000000000000000000000000..91de32714570b977bc96164f941af1a5f22764b9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_fragment_names.py @@ -0,0 +1,28 @@ +from ...error import GraphQLError +from .base import ValidationRule + + +class UniqueFragmentNames(ValidationRule): + __slots__ = 'known_fragment_names', + + def __init__(self, context): + super(UniqueFragmentNames, self).__init__(context) + self.known_fragment_names = {} + + def enter_OperationDefinition(self, node, key, parent, path, ancestors): + return False + + def enter_FragmentDefinition(self, node, key, parent, path, ancestors): + fragment_name = node.name.value + if fragment_name in self.known_fragment_names: + self.context.report_error(GraphQLError( + self.duplicate_fragment_name_message(fragment_name), + [self.known_fragment_names[fragment_name], node.name] + )) + else: + self.known_fragment_names[fragment_name] = node.name + return False + + @staticmethod + def duplicate_fragment_name_message(field): + return 
'There can only be one fragment named "{}".'.format(field) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_input_field_names.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_input_field_names.py new file mode 100644 index 0000000000000000000000000000000000000000..6fe18bab577d8b1f81eebda4b60b1181fd9d93a3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_input_field_names.py @@ -0,0 +1,33 @@ +from ...error import GraphQLError +from .base import ValidationRule + + +class UniqueInputFieldNames(ValidationRule): + __slots__ = 'known_names', 'known_names_stack' + + def __init__(self, context): + super(UniqueInputFieldNames, self).__init__(context) + self.known_names = {} + self.known_names_stack = [] + + def enter_ObjectValue(self, node, key, parent, path, ancestors): + self.known_names_stack.append(self.known_names) + self.known_names = {} + + def leave_ObjectValue(self, node, key, parent, path, ancestors): + self.known_names = self.known_names_stack.pop() + + def enter_ObjectField(self, node, key, parent, path, ancestors): + field_name = node.name.value + if field_name in self.known_names: + self.context.report_error(GraphQLError( + self.duplicate_input_field_message(field_name), + [self.known_names[field_name], node.name] + )) + else: + self.known_names[field_name] = node.name + return False + + @staticmethod + def duplicate_input_field_message(field_name): + return 'There can only be one input field named "{}".'.format(field_name) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_operation_names.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_operation_names.py new file mode 100644 index 
0000000000000000000000000000000000000000..1fccbb9b2230fb708bed7ee42d5ff96ce4083b20 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_operation_names.py @@ -0,0 +1,31 @@ +from ...error import GraphQLError +from .base import ValidationRule + + +class UniqueOperationNames(ValidationRule): + __slots__ = 'known_operation_names', + + def __init__(self, context): + super(UniqueOperationNames, self).__init__(context) + self.known_operation_names = {} + + def enter_OperationDefinition(self, node, key, parent, path, ancestors): + operation_name = node.name + if not operation_name: + return + + if operation_name.value in self.known_operation_names: + self.context.report_error(GraphQLError( + self.duplicate_operation_name_message(operation_name.value), + [self.known_operation_names[operation_name.value], operation_name] + )) + else: + self.known_operation_names[operation_name.value] = operation_name + return False + + def enter_FragmentDefinition(self, node, key, parent, path, ancestors): + return False + + @staticmethod + def duplicate_operation_name_message(operation_name): + return 'There can only be one operation named "{}".'.format(operation_name) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_variable_names.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_variable_names.py new file mode 100644 index 0000000000000000000000000000000000000000..f471e3464343d8d304e8d62341daba736a2c4ab7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/unique_variable_names.py @@ -0,0 +1,27 @@ +from ...error import GraphQLError +from .base import ValidationRule + + +class UniqueVariableNames(ValidationRule): + __slots__ = 'known_variable_names', + + def __init__(self, context): + super(UniqueVariableNames, self).__init__(context) + 
self.known_variable_names = {} + + def enter_OperationDefinition(self, node, key, parent, path, ancestors): + self.known_variable_names = {} + + def enter_VariableDefinition(self, node, key, parent, path, ancestors): + variable_name = node.variable.name.value + if variable_name in self.known_variable_names: + self.context.report_error(GraphQLError( + self.duplicate_variable_message(variable_name), + [self.known_variable_names[variable_name], node.variable.name] + )) + else: + self.known_variable_names[variable_name] = node.variable.name + + @staticmethod + def duplicate_variable_message(operation_name): + return 'There can be only one variable named "{}".'.format(operation_name) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_are_input_types.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_are_input_types.py new file mode 100644 index 0000000000000000000000000000000000000000..f510fbba0ad1584ea5101e0869c5133f14dbb708 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_are_input_types.py @@ -0,0 +1,21 @@ +from ...error import GraphQLError +from ...language.printer import print_ast +from ...type.definition import is_input_type +from ...utils.type_from_ast import type_from_ast +from .base import ValidationRule + + +class VariablesAreInputTypes(ValidationRule): + + def enter_VariableDefinition(self, node, key, parent, path, ancestors): + type = type_from_ast(self.context.get_schema(), node.type) + + if type and not is_input_type(type): + self.context.report_error(GraphQLError( + self.non_input_type_on_variable_message(node.variable.name.value, print_ast(node.type)), + [node.type] + )) + + @staticmethod + def non_input_type_on_variable_message(variable_name, type_name): + return 'Variable "${}" cannot be non-input type "{}".'.format(variable_name, type_name) diff 
--git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_in_allowed_position.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_in_allowed_position.py new file mode 100644 index 0000000000000000000000000000000000000000..4117e0f8217f3d7ddee4fc6850d4899a130c0602 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/rules/variables_in_allowed_position.py @@ -0,0 +1,53 @@ +from ...error import GraphQLError +from ...type.definition import GraphQLNonNull +from ...utils.type_comparators import is_type_sub_type_of +from ...utils.type_from_ast import type_from_ast +from .base import ValidationRule + + +class VariablesInAllowedPosition(ValidationRule): + __slots__ = 'var_def_map' + + def __init__(self, context): + super(VariablesInAllowedPosition, self).__init__(context) + self.var_def_map = {} + + def enter_OperationDefinition(self, node, key, parent, path, ancestors): + self.var_def_map = {} + + def leave_OperationDefinition(self, operation, key, parent, path, ancestors): + usages = self.context.get_recursive_variable_usages(operation) + + for usage in usages: + node = usage.node + type = usage.type + var_name = node.name.value + var_def = self.var_def_map.get(var_name) + if var_def and type: + # A var type is allowed if it is the same or more strict (e.g. is + # a subtype of) than the expected type. It can be more strict if + # the variable type is non-null when the expected type is nullable. + # If both are list types, the variable item type can be more strict + # than the expected item type (contravariant). 
+ schema = self.context.get_schema() + var_type = type_from_ast(schema, var_def.type) + if var_type and not is_type_sub_type_of(schema, self.effective_type(var_type, var_def), type): + self.context.report_error(GraphQLError( + self.bad_var_pos_message(var_name, var_type, type), + [var_def, node] + )) + + def enter_VariableDefinition(self, node, key, parent, path, ancestors): + self.var_def_map[node.variable.name.value] = node + + @staticmethod + def effective_type(var_type, var_def): + if not var_def.default_value or isinstance(var_type, GraphQLNonNull): + return var_type + + return GraphQLNonNull(var_type) + + @staticmethod + def bad_var_pos_message(var_name, var_type, expected_type): + return 'Variable "{}" of type "{}" used in position expecting type "{}".'.format(var_name, var_type, + expected_type) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/validation.py b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/validation.py new file mode 100644 index 0000000000000000000000000000000000000000..3610dbf4812ff5c7971a326d2a4e68fd3805c6c7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/graphql-core-1.1/wandb_graphql/validation/validation.py @@ -0,0 +1,158 @@ +from ..language.ast import (FragmentDefinition, FragmentSpread, + OperationDefinition) +from ..language.visitor import ParallelVisitor, TypeInfoVisitor, Visitor, visit +from ..type import GraphQLSchema +from ..utils.type_info import TypeInfo +from .rules import specified_rules + + +def validate(schema, ast, rules=specified_rules): + assert schema, 'Must provide schema' + assert ast, 'Must provide document' + assert isinstance(schema, GraphQLSchema) + type_info = TypeInfo(schema) + return visit_using_rules(schema, type_info, ast, rules) + + +def visit_using_rules(schema, type_info, ast, rules): + context = ValidationContext(schema, ast, type_info) + visitors = [rule(context) for rule in rules] + 
visit(ast, TypeInfoVisitor(type_info, ParallelVisitor(visitors))) + return context.get_errors() + + +class VariableUsage(object): + __slots__ = 'node', 'type' + + def __init__(self, node, type): + self.node = node + self.type = type + + +class UsageVisitor(Visitor): + __slots__ = 'usages', 'type_info' + + def __init__(self, usages, type_info): + self.usages = usages + self.type_info = type_info + + def enter_VariableDefinition(self, node, key, parent, path, ancestors): + return False + + def enter_Variable(self, node, key, parent, path, ancestors): + usage = VariableUsage(node, type=self.type_info.get_input_type()) + self.usages.append(usage) + + +class ValidationContext(object): + __slots__ = ('_schema', '_ast', '_type_info', '_errors', '_fragments', '_fragment_spreads', + '_recursively_referenced_fragments', '_variable_usages', '_recursive_variable_usages') + + def __init__(self, schema, ast, type_info): + self._schema = schema + self._ast = ast + self._type_info = type_info + self._errors = [] + self._fragments = None + self._fragment_spreads = {} + self._recursively_referenced_fragments = {} + self._variable_usages = {} + self._recursive_variable_usages = {} + + def report_error(self, error): + self._errors.append(error) + + def get_errors(self): + return self._errors + + def get_schema(self): + return self._schema + + def get_variable_usages(self, node): + usages = self._variable_usages.get(node) + if usages is None: + usages = [] + sub_visitor = UsageVisitor(usages, self._type_info) + visit(node, TypeInfoVisitor(self._type_info, sub_visitor)) + self._variable_usages[node] = usages + + return usages + + def get_recursive_variable_usages(self, operation): + assert isinstance(operation, OperationDefinition) + usages = self._recursive_variable_usages.get(operation) + if usages is None: + usages = self.get_variable_usages(operation) + fragments = self.get_recursively_referenced_fragments(operation) + for fragment in fragments: + 
usages.extend(self.get_variable_usages(fragment)) + self._recursive_variable_usages[operation] = usages + + return usages + + def get_recursively_referenced_fragments(self, operation): + assert isinstance(operation, OperationDefinition) + fragments = self._recursively_referenced_fragments.get(operation) + if not fragments: + fragments = [] + collected_names = set() + nodes_to_visit = [operation.selection_set] + while nodes_to_visit: + node = nodes_to_visit.pop() + spreads = self.get_fragment_spreads(node) + for spread in spreads: + frag_name = spread.name.value + if frag_name not in collected_names: + collected_names.add(frag_name) + fragment = self.get_fragment(frag_name) + if fragment: + fragments.append(fragment) + nodes_to_visit.append(fragment.selection_set) + self._recursively_referenced_fragments[operation] = fragments + return fragments + + def get_fragment_spreads(self, node): + spreads = self._fragment_spreads.get(node) + if not spreads: + spreads = [] + sets_to_visit = [node] + while sets_to_visit: + _set = sets_to_visit.pop() + for selection in _set.selections: + if isinstance(selection, FragmentSpread): + spreads.append(selection) + elif selection.selection_set: + sets_to_visit.append(selection.selection_set) + + self._fragment_spreads[node] = spreads + return spreads + + def get_ast(self): + return self._ast + + def get_fragment(self, name): + fragments = self._fragments + if fragments is None: + self._fragments = fragments = {} + for statement in self.get_ast().definitions: + if isinstance(statement, FragmentDefinition): + fragments[statement.name.value] = statement + return fragments.get(name) + + def get_type(self): + return self._type_info.get_type() + + def get_parent_type(self): + return self._type_info.get_parent_type() + + def get_input_type(self): + return self._type_info.get_input_type() + + def get_field_def(self): + return self._type_info.get_field_def() + + def get_directive(self): + return self._type_info.get_directive() + + def 
get_argument(self): + return self._type_info.get_argument() diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/conftest.py b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..a36e6ba1ffed579764055782646373de3b9ed488 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/conftest.py @@ -0,0 +1,30 @@ +# Configuration for pytest to automatically collect types. +# Thanks to Guilherme Salgado. +import pytest + +try: + import pyannotate_runtime + PYANOTATE_PRESENT = True +except ImportError: + PYANOTATE_PRESENT = False + +if PYANOTATE_PRESENT: + def pytest_collection_finish(session): + """Handle the pytest collection finish hook: configure pyannotate. + Explicitly delay importing `collect_types` until all tests have + been collected. This gives gevent a chance to monkey patch the + world before importing pyannotate. + """ + from pyannotate_runtime import collect_types + collect_types.init_types_collection() + + @pytest.fixture(autouse=True) + def collect_types_fixture(): + from pyannotate_runtime import collect_types + collect_types.resume() + yield + collect_types.pause() + + def pytest_sessionfinish(session, exitstatus): + from pyannotate_runtime import collect_types + collect_types.dump_stats("type_info.json") diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/setup.py b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..9c13695c9d197a77ecbb069ae0bf2dd33d8b5fef --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/setup.py @@ -0,0 +1,64 @@ +import sys +from setuptools import setup, find_packages + +if sys.version_info[0] < 3: + import __builtin__ as builtins +else: + import builtins + +builtins.__SETUP__ = True + +version = __import__("promise").get_version() + + +IS_PY3 = sys.hexversion 
>= 0x03000000 + +tests_require = [ + "pytest>=2.7.3", + "pytest-cov", + "coveralls", + "futures", + "pytest-benchmark", + "mock", +] +if IS_PY3: + tests_require += ["pytest-asyncio"] + + +setup( + name="promise", + version=version, + description="Promises/A+ implementation for Python", + long_description=open("README.rst").read(), + url="https://github.com/syrusakbary/promise", + download_url="https://github.com/syrusakbary/promise/releases", + author="Syrus Akbary", + author_email="me@syrusakbary.com", + license="MIT", + classifiers=[ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Topic :: Software Development :: Libraries", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: Implementation :: PyPy", + "License :: OSI Approved :: MIT License", + ], + keywords="concurrent future deferred promise", + packages=find_packages(exclude=["tests"]), + # PEP-561: https://www.python.org/dev/peps/pep-0561/ + package_data={"promise": ["py.typed"]}, + extras_require={"test": tests_require}, + install_requires=[ + "typing>=3.6.4; python_version < '3.5'", + "six" + ], + tests_require=tests_require, +) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/__init__.py b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/conftest.py b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/conftest.py new file mode 100644 
index 0000000000000000000000000000000000000000..26122664faf386b96dbcffc63dcb2ea41b57c049 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/conftest.py @@ -0,0 +1,8 @@ +from sys import version_info + +collect_ignore = [] +if version_info[:2] < (3, 4): + collect_ignore.append("test_awaitable.py") +if version_info[:2] < (3, 5): + collect_ignore.append("test_awaitable_35.py") + collect_ignore.append("test_dataloader_awaitable_35.py") diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_awaitable.py b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_awaitable.py new file mode 100644 index 0000000000000000000000000000000000000000..aad7f2ce41eb1420aa19fba3f0af656dc0668586 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_awaitable.py @@ -0,0 +1,32 @@ +from asyncio import coroutine +from pytest import mark +from time import sleep +from promise import Promise + + +@mark.asyncio +@coroutine +def test_await(): + yield from Promise.resolve(True) + + +@mark.asyncio +@coroutine +def test_await_time(): + def resolve_or_reject(resolve, reject): + sleep(.1) + resolve(True) + + p = Promise(resolve_or_reject) + assert p.get() is True + + +@mark.asyncio +@coroutine +def test_promise_coroutine(): + @coroutine + def my_coro(): + yield from Promise.resolve(True) + + promise = Promise.resolve(my_coro()) + assert isinstance(promise, Promise) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_awaitable_35.py b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_awaitable_35.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa70503a74554f63fc1bb86928ac3e66e7857f5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_awaitable_35.py @@ -0,0 +1,47 @@ +from asyncio import sleep, Future, wait, FIRST_COMPLETED +from pytest import mark +from 
promise import Promise, is_thenable + + +@mark.asyncio +async def test_await(): + assert await Promise.resolve(True) + + +@mark.asyncio +async def test_promisify_coroutine(): + async def my_coroutine(): + await sleep(.01) + return True + + assert await Promise.resolve(my_coroutine()) + + +@mark.asyncio +async def test_coroutine_is_thenable(): + async def my_coroutine(): + await sleep(.01) + return True + + assert is_thenable(my_coroutine()) + + +@mark.asyncio +async def test_promisify_future(): + future = Future() + future.set_result(True) + assert await Promise.resolve(future) + + +@mark.asyncio +async def test_await_in_safe_promise(): + async def inner(): + @Promise.safe + def x(): + promise = Promise.resolve(True).then(lambda x: x) + return promise + + return await x() + + result = await inner() + assert result == True diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_benchmark.py b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..eb30f24e2f587ab7d806013a3fd7af077abc1c2b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_benchmark.py @@ -0,0 +1,116 @@ +from pytest import raises +import time +from promise import Promise, promisify, is_thenable + + +def test_benchmark_promise_creation(benchmark): + @benchmark + def create_promise(): # unnecessary function call + p = Promise() + + +def test_benchmark_promise_resolve(benchmark): + def create_promise(): + return Promise.resolve(True) + + result = benchmark(create_promise).get() + assert result == True + + +def test_benchmark_is_thenable_basic_type(benchmark): + def create_promise(): + return is_thenable(True) + + result = benchmark(create_promise) + assert result == False + + +def test_benchmark_is_thenable_custom_type(benchmark): + class MyType(object): + pass + + my_type_instance = MyType() + + def create_promise(): + return 
is_thenable(my_type_instance) + + result = benchmark(create_promise) + assert result == False + + +def test_benchmark_promise_creation_with_resolve(benchmark): + do_resolve = lambda resolve, reject: resolve(True) + + def create_promise(): # unnecessary function call + p = Promise(do_resolve) + # p._wait() + return p + + result = benchmark(create_promise).get() + assert result == True + + +def test_benchmark_promise_creation_with_reject(benchmark): + do_resolve = lambda resolve, reject: reject(Exception("Error")) + + def create_promise(): # unnecessary function call + p = Promise(do_resolve) + # p._wait() + return p + + with raises(Exception) as exc_info: + result = benchmark(create_promise).get() + + assert str(exc_info.value) == "Error" + + +# def test_benchmark_promisify_promise(benchmark): +# instance = Promise() + +# def create_promise(): # unnecessary function call +# return promisify(instance) + +# result = benchmark(create_promise) + +# assert isinstance(result, Promise) + + +def test_benchmark_promisify_custom_type(benchmark): + class CustomThenable(object): + pass + # def then(self, resolve, reject): + # return resolve(True) + + instance = CustomThenable() + + def create_promise(): # unnecessary function call + return Promise.resolve(instance) + + result = benchmark(create_promise) + + assert isinstance(result, Promise) + assert result.get() == instance + + +def test_benchmark_promise_all(benchmark): + values = range(1000) + + def create_promise(): # unnecessary function call + return Promise.all(values) + + result = benchmark(create_promise) + + assert isinstance(result, Promise) + assert result.get() == list(range(1000)) + + +def test_benchmark_promise_all_promise(benchmark): + values = [Promise.resolve(i) for i in range(100000)] + + def create_promise(): # unnecessary function call + return Promise.all(values) + + result = benchmark(create_promise) + + assert isinstance(result, Promise) + assert result.get() == list(range(100000)) diff --git 
a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_complex_threads.py b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_complex_threads.py new file mode 100644 index 0000000000000000000000000000000000000000..6cddfaac1e9900c9351c50eeaacbda6f93cef477 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_complex_threads.py @@ -0,0 +1,23 @@ +from time import sleep +from concurrent.futures import ThreadPoolExecutor +from promise import Promise +from operator import mul + +executor = ThreadPoolExecutor(max_workers=40000) + + +def promise_factorial(n): + if n < 2: + return 1 + sleep(.02) + a = executor.submit(promise_factorial, n - 1) + + def promise_then(r): + return mul(r, n) + + return Promise.resolve(a).then(promise_then) + + +def test_factorial(): + p = promise_factorial(10) + assert p.get() == 3628800 diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_dataloader.py b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..a8352fc6e3d5f7b6c71d949d87a59bb9ecdd79cd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_dataloader.py @@ -0,0 +1,452 @@ +from pytest import raises + +from promise import Promise, async_instance +from promise.dataloader import DataLoader + + +def id_loader(**options): + load_calls = [] + + resolve = options.pop("resolve", Promise.resolve) + + def fn(keys): + load_calls.append(keys) + return resolve(keys) + + identity_loader = DataLoader(fn, **options) + return identity_loader, load_calls + + +def test_build_a_simple_data_loader(): + def call_fn(keys): + return Promise.resolve(keys) + + identity_loader = DataLoader(call_fn) + + promise1 = identity_loader.load(1) + assert isinstance(promise1, Promise) + + value1 = promise1.get() + assert value1 == 1 + + +def 
test_supports_loading_multiple_keys_in_one_call(): + def call_fn(keys): + return Promise.resolve(keys) + + identity_loader = DataLoader(call_fn) + + promise_all = identity_loader.load_many([1, 2]) + assert isinstance(promise_all, Promise) + + values = promise_all.get() + assert values == [1, 2] + + promise_all = identity_loader.load_many([]) + assert isinstance(promise_all, Promise) + + values = promise_all.get() + assert values == [] + + +def test_batches_multiple_requests(): + @Promise.safe + def do(): + identity_loader, load_calls = id_loader() + + promise1 = identity_loader.load(1) + promise2 = identity_loader.load(2) + + p = Promise.all([promise1, promise2]) + + value1, value2 = p.get() + + assert value1 == 1 + assert value2 == 2 + + assert load_calls == [[1, 2]] + + do().get() + + +def test_batches_multiple_requests_with_max_batch_sizes(): + @Promise.safe + def do(): + identity_loader, load_calls = id_loader(max_batch_size=2) + + promise1 = identity_loader.load(1) + promise2 = identity_loader.load(2) + promise3 = identity_loader.load(3) + + p = Promise.all([promise1, promise2, promise3]) + + value1, value2, value3 = p.get() + + assert value1 == 1 + assert value2 == 2 + assert value3 == 3 + + assert load_calls == [[1, 2], [3]] + + do().get() + + +def test_coalesces_identical_requests(): + @Promise.safe + def do(): + identity_loader, load_calls = id_loader() + + promise1 = identity_loader.load(1) + promise2 = identity_loader.load(1) + + assert promise1 == promise2 + p = Promise.all([promise1, promise2]) + + value1, value2 = p.get() + + assert value1 == 1 + assert value2 == 1 + + assert load_calls == [[1]] + + do().get() + + +def test_caches_repeated_requests(): + @Promise.safe + def do(): + identity_loader, load_calls = id_loader() + + a, b = Promise.all([identity_loader.load("A"), identity_loader.load("B")]).get() + + assert a == "A" + assert b == "B" + + assert load_calls == [["A", "B"]] + + a2, c = Promise.all( + [identity_loader.load("A"), 
identity_loader.load("C")] + ).get() + + assert a2 == "A" + assert c == "C" + + assert load_calls == [["A", "B"], ["C"]] + + a3, b2, c2 = Promise.all( + [ + identity_loader.load("A"), + identity_loader.load("B"), + identity_loader.load("C"), + ] + ).get() + + assert a3 == "A" + assert b2 == "B" + assert c2 == "C" + + assert load_calls == [["A", "B"], ["C"]] + + do().get() + + +def test_clears_single_value_in_loader(): + @Promise.safe + def do(): + identity_loader, load_calls = id_loader() + + a, b = Promise.all([identity_loader.load("A"), identity_loader.load("B")]).get() + + assert a == "A" + assert b == "B" + + assert load_calls == [["A", "B"]] + + identity_loader.clear("A") + + a2, b2 = Promise.all( + [identity_loader.load("A"), identity_loader.load("B")] + ).get() + + assert a2 == "A" + assert b2 == "B" + + assert load_calls == [["A", "B"], ["A"]] + + do().get() + + +def test_clears_all_values_in_loader(): + @Promise.safe + def do(): + identity_loader, load_calls = id_loader() + + a, b = Promise.all([identity_loader.load("A"), identity_loader.load("B")]).get() + + assert a == "A" + assert b == "B" + + assert load_calls == [["A", "B"]] + + identity_loader.clear_all() + + a2, b2 = Promise.all( + [identity_loader.load("A"), identity_loader.load("B")] + ).get() + + assert a2 == "A" + assert b2 == "B" + + assert load_calls == [["A", "B"], ["A", "B"]] + + do().get() + + +def test_does_not_replace_cache_map(): + @Promise.safe + def do(): + identity_loader, _ = id_loader() + a, b = Promise.all([identity_loader.load("A"), identity_loader.load("B")]).get() + + assert a == "A" + assert b == "B" + + cache_map = identity_loader._promise_cache + + identity_loader.clear_all() + + assert id(identity_loader._promise_cache) == id(cache_map) + + do().get() + + +def test_allows_priming_the_cache(): + @Promise.safe + def do(): + identity_loader, load_calls = id_loader() + + identity_loader.prime("A", "A") + + a, b = Promise.all([identity_loader.load("A"), 
identity_loader.load("B")]).get() + + assert a == "A" + assert b == "B" + + assert load_calls == [["B"]] + + do().get() + + +def test_does_not_prime_keys_that_already_exist(): + @Promise.safe + def do(): + identity_loader, load_calls = id_loader() + + identity_loader.prime("A", "X") + + a1 = identity_loader.load("A").get() + b1 = identity_loader.load("B").get() + + assert a1 == "X" + assert b1 == "B" + + identity_loader.prime("A", "Y") + identity_loader.prime("B", "Y") + + a2 = identity_loader.load("A").get() + b2 = identity_loader.load("B").get() + + assert a2 == "X" + assert b2 == "B" + + assert load_calls == [["B"]] + + do().get() + + +# Represents Errors + + +def test_resolves_to_error_to_indicate_failure(): + @Promise.safe + def do(): + def resolve(keys): + mapped_keys = [ + key if key % 2 == 0 else Exception("Odd: {}".format(key)) + for key in keys + ] + return Promise.resolve(mapped_keys) + + even_loader, load_calls = id_loader(resolve=resolve) + + with raises(Exception) as exc_info: + even_loader.load(1).get() + + assert str(exc_info.value) == "Odd: 1" + + value2 = even_loader.load(2).get() + assert value2 == 2 + assert load_calls == [[1], [2]] + + do().get() + + +def test_can_represent_failures_and_successes_simultaneously(): + @Promise.safe + def do(): + def resolve(keys): + mapped_keys = [ + key if key % 2 == 0 else Exception("Odd: {}".format(key)) + for key in keys + ] + return Promise.resolve(mapped_keys) + + even_loader, load_calls = id_loader(resolve=resolve) + + promise1 = even_loader.load(1) + promise2 = even_loader.load(2) + + with raises(Exception) as exc_info: + promise1.get() + + assert str(exc_info.value) == "Odd: 1" + value2 = promise2.get() + assert value2 == 2 + assert load_calls == [[1, 2]] + + do().get() + + +def test_caches_failed_fetches(): + @Promise.safe + def do(): + def resolve(keys): + mapped_keys = [Exception("Error: {}".format(key)) for key in keys] + return Promise.resolve(mapped_keys) + + error_loader, load_calls = 
id_loader(resolve=resolve) + + with raises(Exception) as exc_info: + error_loader.load(1).get() + + assert str(exc_info.value) == "Error: 1" + + with raises(Exception) as exc_info: + error_loader.load(1).get() + + assert str(exc_info.value) == "Error: 1" + + assert load_calls == [[1]] + + do().get() + + +def test_caches_failed_fetches(): + @Promise.safe + def do(): + identity_loader, load_calls = id_loader() + + identity_loader.prime(1, Exception("Error: 1")) + + with raises(Exception) as exc_info: + identity_loader.load(1).get() + + assert load_calls == [] + + do().get() + + +# It is resilient to job queue ordering +# def test_batches_loads_occuring_within_promises(): +# @Promise.safe +# def do(): +# identity_loader, load_calls = id_loader() +# values = Promise.all([ +# identity_loader.load('A'), +# Promise.resolve(None).then(lambda v: Promise.resolve(None)).then( +# lambda v: identity_loader.load('B') +# ) +# ]).get() + +# assert values == ['A', 'B'] +# assert load_calls == [['A', 'B']] + +# do().get() + + +def test_catches_error_if_loader_resolver_fails(): + @Promise.safe + def do(): + def do_resolve(x): + raise Exception("AOH!") + + a_loader, a_load_calls = id_loader(resolve=do_resolve) + + with raises(Exception) as exc_info: + a_loader.load("A1").get() + + assert str(exc_info.value) == "AOH!" 
+ + do().get() + + +def test_can_call_a_loader_from_a_loader(): + @Promise.safe + def do(): + deep_loader, deep_load_calls = id_loader() + a_loader, a_load_calls = id_loader( + resolve=lambda keys: deep_loader.load(tuple(keys)) + ) + b_loader, b_load_calls = id_loader( + resolve=lambda keys: deep_loader.load(tuple(keys)) + ) + + a1, b1, a2, b2 = Promise.all( + [ + a_loader.load("A1"), + b_loader.load("B1"), + a_loader.load("A2"), + b_loader.load("B2"), + ] + ).get() + + assert a1 == "A1" + assert b1 == "B1" + assert a2 == "A2" + assert b2 == "B2" + + assert a_load_calls == [["A1", "A2"]] + assert b_load_calls == [["B1", "B2"]] + assert deep_load_calls == [[("A1", "A2"), ("B1", "B2")]] + + do().get() + + +def test_dataloader_clear_with_missing_key_works(): + @Promise.safe + def do(): + def do_resolve(x): + return x + + a_loader, a_load_calls = id_loader(resolve=do_resolve) + assert a_loader.clear("A1") == a_loader + + do().get() + + +def test_wrong_loader_return_type_does_not_block_async_instance(): + @Promise.safe + def do(): + def do_resolve(x): + return x + + a_loader, a_load_calls = id_loader(resolve=do_resolve) + + with raises(Exception): + a_loader.load("A1").get() + assert async_instance.have_drained_queues + with raises(Exception): + a_loader.load("A2").get() + assert async_instance.have_drained_queues + + do().get() diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_dataloader_awaitable_35.py b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_dataloader_awaitable_35.py new file mode 100644 index 0000000000000000000000000000000000000000..88161dd595ea19b608ca27a76b03621457469ad2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_dataloader_awaitable_35.py @@ -0,0 +1,99 @@ +from pytest import mark +from promise import Promise +from promise.dataloader import DataLoader + + +def id_loader(**options): + load_calls = [] + + resolve = options.pop("resolve", 
Promise.resolve) + + def fn(keys): + load_calls.append(keys) + return resolve(keys) + + identity_loader = DataLoader(fn, **options) + return identity_loader, load_calls + + +@mark.asyncio +async def test_await_dataloader(): + identity_loader, load_calls = id_loader() + + async def load_multiple(identity_loader): + one = identity_loader.load("load1") + two = identity_loader.load("load2") + return await Promise.all([one, two]) + + result = await load_multiple(identity_loader) + assert result == ["load1", "load2"] + assert load_calls == [["load1"], ["load2"]] + + +@mark.asyncio +async def test_await_dataloader_safe_promise(): + identity_loader, load_calls = id_loader() + + @Promise.safe + async def load_multiple(identity_loader): + one = identity_loader.load("load1") + two = identity_loader.load("load2") + return await Promise.all([one, two]) + + result = await load_multiple(identity_loader) + assert result == ["load1", "load2"] + assert load_calls == [["load1"], ["load2"]] + + +@mark.asyncio +async def test_await_dataloader_individual(): + identity_loader, load_calls = id_loader() + + async def load_one_then_two(identity_loader): + one = await identity_loader.load("load1") + two = await identity_loader.load("load2") + return [one, two] + + result = await load_one_then_two(identity_loader) + assert result == ["load1", "load2"] + assert load_calls == [["load1"], ["load2"]] + + +@mark.asyncio +async def test_await_dataloader_individual_safe_promise(): + identity_loader, load_calls = id_loader() + + @Promise.safe + async def load_one_then_two(identity_loader): + one = await identity_loader.load("load1") + two = await identity_loader.load("load2") + return [one, two] + + result = await load_one_then_two(identity_loader) + assert result == ["load1", "load2"] + assert load_calls == [["load1"], ["load2"]] + + +@mark.asyncio +async def test_await_dataloader_two(): + identity_loader, load_calls = id_loader() + + async def load_one_then_two(identity_loader): + one = await 
identity_loader.load("load1") + two = await identity_loader.load("load2") + return (one, two) + + result12 = await Promise.all([load_one_then_two(identity_loader)]) + + +@mark.asyncio +async def test_await_dataloader_two_safe_promise(): + identity_loader, load_calls = id_loader() + + @Promise.safe + async def load_one_then_two(identity_loader): + one = await identity_loader.load("load1") + two = await identity_loader.load("load2") + return (one, two) + + result12 = await Promise.all([load_one_then_two(identity_loader)]) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_extra.py b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_extra.py new file mode 100644 index 0000000000000000000000000000000000000000..4a083718d134ede122d5a761c592318a9fa5748e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_extra.py @@ -0,0 +1,670 @@ +# This exercises some capabilities above and beyond +# the Promises/A+ test suite +from time import sleep +from pytest import raises, fixture + +from threading import Event +from promise import ( + Promise, + is_thenable, + promisify, + promise_for_dict as free_promise_for_dict, +) +from concurrent.futures import Future +from threading import Thread + +from .utils import assert_exception + + +class DelayedFulfill(Thread): + def __init__(self, d, p, v): + self.delay = d + self.promise = p + self.value = v + Thread.__init__(self) + + def run(self): + sleep(self.delay) + self.promise.do_resolve(self.value) + + +class DelayedRejection(Thread): + def __init__(self, d, p, r): + self.delay = d + self.promise = p + self.reason = r + Thread.__init__(self) + + def run(self): + sleep(self.delay) + self.promise.do_reject(self.reason) + + +class FakeThenPromise: + def __init__(self, raises=True): + self.raises = raises + + def then(self, s=None, f=None): + if self.raises: + raise Exception("FakeThenPromise raises in 'then'") + + +def df(value, dtime): + p = 
Promise() + t = DelayedFulfill(dtime, p, value) + t.start() + + return p + + +def dr(reason, dtime): + p = Promise() + t = DelayedRejection(dtime, p, reason) + t.start() + + return p + + +# Static methods +def test_fulfilled(): + p = Promise.fulfilled(4) + assert p.is_fulfilled + assert p.get() == 4 + + +def test_rejected(): + p = Promise.rejected(Exception("Static rejected")) + assert p.is_rejected + with raises(Exception) as exc_info: + p.get() + assert str(exc_info.value) == "Static rejected" + + +# Fulfill +def test_fulfill_self(): + p = Promise() + with raises(TypeError) as excinfo: + p.do_resolve(p) + p.get() + + +# Exceptions +def test_exceptions(): + def throws(v): + assert False + + p1 = Promise() + p1.then(throws) + p1.do_resolve(5) + + p2 = Promise() + p2.catch(throws) + p2.do_reject(Exception()) + + with raises(Exception) as excinfo: + p2.get() + + +def test_thrown_exceptions_have_stacktrace(): + def throws(v): + assert False + + p3 = Promise.resolve("a").then(throws) + with raises(AssertionError) as assert_exc: + p3.get() + + assert assert_exc.traceback[-1].path.strpath == __file__ + + +def test_thrown_exceptions_preserve_stacktrace(): + def throws(v): + assert False + + def after_throws(v): + pass + + p3 = Promise.resolve("a").then(throws).then(after_throws) + with raises(AssertionError) as assert_exc: + p3.get() + + assert assert_exc.traceback[-1].path.strpath == __file__ + + +# WAIT +# def test_wait_when(): +# p1 = df(5, 0.01) +# assert p1.is_pending +# p1._wait() +# assert p1.is_fulfilled + + +def test_wait_if(): + p1 = Promise() + p1.do_resolve(5) + p1._wait() + assert p1.is_fulfilled + + +# def test_wait_timeout(): +# p1 = df(5, 0.1) +# assert p1.is_pending +# with raises(Exception) as exc_info: +# p1._wait(timeout=0.05) +# assert str(exc_info.value) == "Timeout" +# assert p1.is_pending +# p1._wait() +# assert p1.is_fulfilled + + +# # GET +# def test_get_when(): +# p1 = df(5, 0.01) +# assert p1.is_pending +# v = p1.get() +# assert p1.is_fulfilled 
+# assert 5 == v + + +def test_get_if(): + p1 = Promise() + p1.do_resolve(5) + v = p1.get() + assert p1.is_fulfilled + assert 5 == v + + +# def test_get_timeout(): +# p1 = df(5, 0.1) +# assert p1.is_pending +# with raises(Exception) as exc_info: +# p1._wait(timeout=0.05) +# assert str(exc_info.value) == "Timeout" +# assert p1.is_pending +# v = p1.get() +# assert p1.is_fulfilled +# assert 5 == v + + +# Promise.all +def test_promise_all_when(): + p1 = Promise() + p2 = Promise() + pl = Promise.all([p1, p2]) + assert p1.is_pending + assert p2.is_pending + assert pl.is_pending + p1.do_resolve(5) + p1._wait() + assert p1.is_fulfilled + assert p2.is_pending + assert pl.is_pending + p2.do_resolve(10) + p2._wait() + pl._wait() + assert p1.is_fulfilled + assert p2.is_fulfilled + assert pl.is_fulfilled + assert 5 == p1.get() + assert 10 == p2.get() + assert 5 == pl.get()[0] + assert 10 == pl.get()[1] + + +def test_promise_all_when_mixed_promises(): + p1 = Promise() + p2 = Promise() + pl = Promise.all([p1, 32, p2, False, True]) + assert p1.is_pending + assert p2.is_pending + assert pl.is_pending + p1.do_resolve(5) + p1._wait() + assert p1.is_fulfilled + assert p2.is_pending + assert pl.is_pending + p2.do_resolve(10) + p2._wait() + pl._wait() + assert p1.is_fulfilled + assert p2.is_fulfilled + assert pl.is_fulfilled + assert 5 == p1.get() + assert 10 == p2.get() + assert pl.get() == [5, 32, 10, False, True] + + +def test_promise_all_when_if_no_promises(): + pl = Promise.all([10, 32, False, True]) + assert pl.get() == [10, 32, False, True] + + +def test_promise_all_if(): + p1 = Promise() + p2 = Promise() + pd1 = Promise.all([p1, p2]) + pd2 = Promise.all([p1]) + pd3 = Promise.all([]) + pd3._wait() + assert p1.is_pending + assert p2.is_pending + assert pd1.is_pending + assert pd2.is_pending + assert pd3.is_fulfilled + p1.do_resolve(5) + p1._wait() + pd2._wait() + assert p1.is_fulfilled + assert p2.is_pending + assert pd1.is_pending + assert pd2.is_fulfilled + p2.do_resolve(10) + 
p2._wait() + pd1._wait() + pd2._wait() + assert p1.is_fulfilled + assert p2.is_fulfilled + assert pd1.is_fulfilled + assert pd2.is_fulfilled + assert 5 == p1.get() + assert 10 == p2.get() + assert 5 == pd1.get()[0] + assert 5 == pd2.get()[0] + assert 10 == pd1.get()[1] + assert [] == pd3.get() + + +# promise_for_dict +@fixture(params=[Promise.for_dict, free_promise_for_dict]) +def promise_for_dict(request): + return request.param + + +def test_dict_promise_when(promise_for_dict): + p1 = Promise() + p2 = Promise() + d = {"a": p1, "b": p2} + pd1 = promise_for_dict(d) + pd2 = promise_for_dict({"a": p1}) + pd3 = promise_for_dict({}) + assert p1.is_pending + assert p2.is_pending + assert pd1.is_pending + assert pd2.is_pending + pd3._wait() + assert pd3.is_fulfilled + p1.do_resolve(5) + p1._wait() + pd2._wait() + assert p1.is_fulfilled + assert p2.is_pending + assert pd1.is_pending + assert pd2.is_fulfilled + p2.do_resolve(10) + p2._wait() + pd1._wait() + assert p1.is_fulfilled + assert p2.is_fulfilled + assert pd1.is_fulfilled + assert pd2.is_fulfilled + assert 5 == p1.get() + assert 10 == p2.get() + assert 5 == pd1.get()["a"] + assert 5 == pd2.get()["a"] + assert 10 == pd1.get()["b"] + assert {} == pd3.get() + + +def test_dict_promise_if(promise_for_dict): + p1 = Promise() + p2 = Promise() + d = {"a": p1, "b": p2} + pd = promise_for_dict(d) + assert p1.is_pending + assert p2.is_pending + assert pd.is_pending + p1.do_resolve(5) + p1._wait() + assert p1.is_fulfilled + assert p2.is_pending + assert pd.is_pending + p2.do_resolve(10) + p2._wait() + assert p1.is_fulfilled + assert p2.is_fulfilled + # pd._wait() + # assert pd.is_fulfilled + # assert 5 == p1.get() + # assert 10 == p2.get() + # assert 5 == pd.get()["a"] + # assert 10 == pd.get()["b"] + + +def test_done(): + counter = [0] + r = Promise() + + def inc(_): + counter[0] += 1 + + def dec(_): + counter[0] -= 1 + + def end(_): + r.do_resolve(None) + + p = Promise() + p.done(inc, dec) + p.done(inc, dec) + p.done(end) + 
p.do_resolve(4) + + Promise.wait(r) + assert counter[0] == 2 + + r = Promise() + + counter = [0] + p = Promise() + p.done(inc, dec) + p.done(inc, dec) + p.done(None, end) + p.do_reject(Exception()) + + Promise.wait(r) + assert counter[0] == -2 + + +def test_done_all(): + counter = [0] + + def inc(_): + counter[0] += 1 + + def dec(_): + counter[0] -= 1 + + p = Promise() + r = Promise() + p.done_all() + p.done_all([(inc, dec)]) + p.done_all( + [ + (inc, dec), + (inc, dec), + {"success": inc, "failure": dec}, + lambda _: r.do_resolve(None), + ] + ) + p.do_resolve(4) + Promise.wait(r) + assert counter[0] == 4 + + p = Promise() + r = Promise() + p.done_all() + p.done_all([inc]) + p.done_all([(inc, dec)]) + p.done_all( + [ + (inc, dec), + {"success": inc, "failure": dec}, + (None, lambda _: r.do_resolve(None)), + ] + ) + p.do_reject(Exception("Uh oh!")) + Promise.wait(r) + assert counter[0] == 1 + + +def test_then_all(): + p = Promise() + + handlers = [ + ((lambda x: x * x), (lambda r: 1)), + {"success": (lambda x: x + x), "failure": (lambda r: 2)}, + ] + + results = ( + p.then_all() + + p.then_all([lambda x: x]) + + p.then_all([(lambda x: x * x, lambda r: 1)]) + + p.then_all(handlers) + ) + + p.do_resolve(4) + + assert [r.get() for r in results] == [4, 16, 16, 8] + + p = Promise() + + handlers = [ + ((lambda x: x * x), (lambda r: 1)), + {"success": (lambda x: x + x), "failure": (lambda r: 2)}, + ] + + results = ( + p.then_all() + + p.then_all([(lambda x: x * x, lambda r: 1)]) + + p.then_all(handlers) + ) + + p.do_reject(Exception()) + + assert [r.get() for r in results] == [1, 1, 2] + + +def test_do_resolve(): + p1 = Promise(lambda resolve, reject: resolve(0)) + assert p1.get() == 0 + assert p1.is_fulfilled + + +def test_do_resolve_fail_on_call(): + def raises(resolve, reject): + raise Exception("Fails") + + p1 = Promise(raises) + assert not p1.is_fulfilled + assert str(p1.reason) == "Fails" + + +def test_catch(): + p1 = Promise(lambda resolve, reject: resolve(0)) + p2 
= p1.then(lambda value: 1 / value).catch(lambda e: e).then(lambda e: type(e)) + assert p2.get() == ZeroDivisionError + assert p2.is_fulfilled + + +def test_is_thenable_promise(): + promise = Promise() + assert is_thenable(promise) + + +def test_is_thenable_then_object(): + promise = FakeThenPromise() + assert not is_thenable(promise) + + +def test_is_thenable_future(): + promise = Future() + assert is_thenable(promise) + + +def test_is_thenable_simple_object(): + assert not is_thenable(object()) + + +@fixture(params=[Promise.resolve]) +def resolve(request): + return request.param + + +def test_resolve_promise(resolve): + promise = Promise() + assert resolve(promise) == promise + + +def test_resolve_then_object(resolve): + promise = FakeThenPromise(raises=False) + p = resolve(promise) + assert isinstance(p, Promise) + + +def test_resolve_future(resolve): + future = Future() + promise = resolve(future) + assert promise.is_pending + future.set_result(1) + assert promise.get() == 1 + assert promise.is_fulfilled + + +def test_resolve_future_rejected(resolve): + future = Future() + promise = resolve(future) + assert promise.is_pending + future.set_exception(Exception("Future rejected")) + assert promise.is_rejected + assert_exception(promise.reason, Exception, "Future rejected") + + +def test_resolve_object(resolve): + val = object() + promised = resolve(val) + assert isinstance(promised, Promise) + assert promised.get() == val + + +def test_resolve_promise_subclass(): + class MyPromise(Promise): + pass + + p = Promise() + p.do_resolve(10) + m_p = MyPromise.resolve(p) + + assert isinstance(m_p, MyPromise) + assert m_p.get() == p.get() + + +def test_promise_repr_pending(): + promise = Promise() + assert repr(promise) == "".format(hex(id(promise))) + + +def test_promise_repr_pending(): + val = {1: 2} + promise = Promise.fulfilled(val) + promise._wait() + assert repr(promise) == "".format( + hex(id(promise)), repr(val) + ) + + +def test_promise_repr_fulfilled(): + val = {1: 
2} + promise = Promise.fulfilled(val) + promise._wait() + assert repr(promise) == "".format( + hex(id(promise)), repr(val) + ) + + +def test_promise_repr_rejected(): + err = Exception("Error!") + promise = Promise.rejected(err) + promise._wait() + assert repr(promise) == "".format( + hex(id(promise)), repr(err) + ) + + +def test_promise_loop(): + def by_two(result): + return result * 2 + + def executor(resolve, reject): + resolve(Promise.resolve(1).then(lambda v: Promise.resolve(v).then(by_two))) + + p = Promise(executor) + assert p.get(.1) == 2 + + +def test_resolve_future_like(resolve): + class CustomThenable(object): + def add_done_callback(self, f): + f(True) + + def done(self): + return True + + def exception(self): + pass + + def result(self): + return True + + instance = CustomThenable() + + promise = resolve(instance) + assert promise.get() == True + + +def sum_function(a, b): + return a + b + + +def test_promisify_function_resolved(resolve): + promisified_func = promisify(sum_function) + + result = promisified_func(1, 2) + assert isinstance(result, Promise) + assert result.get() == 3 + + +def test_promisify_function_rejected(resolve): + promisified_func = promisify(sum_function) + + result = promisified_func(None, None) + assert isinstance(result, Promise) + with raises(Exception) as exc_info_promise: + result.get() + + with raises(Exception) as exc_info: + sum_function(None, None) + + assert str(exc_info_promise.value) == str(exc_info.value) + + +def test_promises_with_only_then(): + context = {"success": False} + error = RuntimeError("Ooops!") + promise1 = Promise( + lambda resolve, reject: context.update({"promise1_reject": reject}) + ) + promise2 = promise1.then(lambda x: None) + promise3 = promise1.then(lambda x: None) + context["promise1_reject"](error) + + promise2._wait() + promise3._wait() + assert promise2.reason == error + assert promise3.reason == error + + +def test_promises_promisify_still_works_but_deprecated_for_non_callables(): + x = 
promisify(1) + assert isinstance(x, Promise) + assert x.get() == 1 + + +# def test_promise_loop(): +# values = Promise.resolve([1, None, 2]) +# def on_error(error): +# error + +# def executor(resolve, reject): +# resolve(Promise.resolve(values).then(lambda values: Promise.all([Promise.resolve(values[0])]).catch(on_error))) + +# p = Promise(executor) +# assert p.get(.1) == 2 diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_issues.py b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_issues.py new file mode 100644 index 0000000000000000000000000000000000000000..91974090443f6600bcee908b7412dd41c8049b03 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_issues.py @@ -0,0 +1,132 @@ +# This tests reported issues in the Promise package +from concurrent.futures import ThreadPoolExecutor +from promise import Promise +import time +import weakref +import gc + +executor = ThreadPoolExecutor(max_workers=40000) + + +def test_issue_11(): + # https://github.com/syrusakbary/promise/issues/11 + def test(x): + def my(resolve, reject): + if x > 0: + resolve(x) + else: + reject(Exception(x)) + + return Promise(my) + + promise_resolved = test(42).then(lambda x: x) + assert promise_resolved.get() == 42 + + promise_rejected = test(-42).then(lambda x: x, lambda e: str(e)) + assert promise_rejected.get() == "-42" + + +def identity(x, wait): + if wait: + time.sleep(wait) + return x + + +def promise_with_wait(x, wait): + return Promise.resolve(identity(x, wait)) + + +def test_issue_9(): + no_wait = Promise.all( + [promise_with_wait(x1, None).then(lambda y: x1 * y) for x1 in (0, 1, 2, 3)] + ).get() + wait_a_bit = Promise.all( + [promise_with_wait(x2, 0.05).then(lambda y: x2 * y) for x2 in (0, 1, 2, 3)] + ).get() + wait_longer = Promise.all( + [promise_with_wait(x3, 0.1).then(lambda y: x3 * y) for x3 in (0, 1, 2, 3)] + ).get() + + assert no_wait == wait_a_bit + assert no_wait == wait_longer + 
+ +@Promise.safe +def test_issue_9_safe(): + no_wait = Promise.all( + [promise_with_wait(x1, None).then(lambda y: x1 * y) for x1 in (0, 1, 2, 3)] + ).get() + wait_a_bit = Promise.all( + [promise_with_wait(x2, 0.05).then(lambda y: x2 * y) for x2 in (0, 1, 2, 3)] + ).get() + wait_longer = Promise.all( + [promise_with_wait(x3, 0.1).then(lambda y: x3 * y) for x3 in (0, 1, 2, 3)] + ).get() + + assert no_wait == [0, 3, 6, 9] + assert no_wait == wait_a_bit + assert no_wait == wait_longer + + +def test_issue_26(): + context = {"success": False} + promise1 = Promise( + lambda resolve, reject: context.update({"promise1_reject": reject}) + ) + promise1.then(lambda x: None) + promise1.then(lambda x: None) + context["promise1_reject"](RuntimeError("Ooops!")) + + promise2 = Promise( + lambda resolve, reject: context.update({"promise2_resolve": resolve}) + ) + promise3 = promise2.then(lambda x: context.update({"success": True})) + context["promise2_resolve"](None) + + # We wait so it works in asynchronous envs + promise3._wait(timeout=.1) + assert context["success"] + + +# def promise_in_executor(x, wait): +# return Promise.promisify(executor.submit(identity, x, wait)) + + +# @Promise.safe +# def test_issue_9_extra(): +# no_wait = Promise.all([promise_in_executor(x1, None).then(lambda y: x1*y) for x1 in (0,1,2,3)]).get() +# wait_a_bit = Promise.all([promise_in_executor(x2, 0.1).then(lambda y: x2*y) for x2 in (0,1,2,3)]).get() +# wait_longer = Promise.all([promise_in_executor(x3, 0.5).then(lambda y: x3*y) for x3 in (0,1,2,3)]).get() + +# assert no_wait == [0, 3, 6, 9] +# assert no_wait == wait_a_bit +# assert no_wait == wait_longer + + +def test_issue_33(): + def do(x): + v = Promise.resolve("ok").then(lambda x: x).get() + return v + + p = Promise.resolve(None).then(do) + assert p.get() == "ok" + + +def test_issue_75(): + def function_with_local_type(): + class A: + pass + + a = A() + assert a == Promise.resolve(a).get() + + return weakref.ref(A) + + weak_reference = 
function_with_local_type() + + # The local type 'A' from the function is still kept alive by reference cycles. + gc.collect() + + # Now the local type should have been garbage collected, + # such that the weak reference should be invalid. + assert not weak_reference() diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_promise_list.py b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_promise_list.py new file mode 100644 index 0000000000000000000000000000000000000000..e8dc35a99a92f083ce92552d6bc395f197475e93 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_promise_list.py @@ -0,0 +1,70 @@ +from pytest import raises + +from promise import Promise +from promise.promise_list import PromiseList + + +def all(promises): + return PromiseList(promises, Promise).promise + + +def test_empty_promises(): + all_promises = all([]) + assert all_promises.get() == [] + + +def test_bad_promises(): + all_promises = all(None) + + with raises(Exception) as exc_info: + all_promises.get() + + assert str(exc_info.value) == "PromiseList requires an iterable. Received None." 
+ + +def test_promise_basic(): + all_promises = all([1, 2]) + assert all_promises.get() == [1, 2] + + +def test_promise_mixed(): + all_promises = all([1, 2, Promise.resolve(3)]) + assert all_promises.get() == [1, 2, 3] + + +def test_promise_rejected(): + e = Exception("Error") + all_promises = all([1, 2, Promise.reject(e)]) + + with raises(Exception) as exc_info: + all_promises.get() + + assert str(exc_info.value) == "Error" + + +def test_promise_reject_skip_all_other_values(): + e1 = Exception("Error1") + e2 = Exception("Error2") + p = Promise() + all_promises = all([1, Promise.reject(e1), Promise.reject(e2)]) + + with raises(Exception) as exc_info: + all_promises.get() + + assert str(exc_info.value) == "Error1" + + +def test_promise_lazy_promise(): + p = Promise() + all_promises = all([1, 2, p]) + assert not all_promises.is_fulfilled + p.do_resolve(3) + assert all_promises.get() == [1, 2, 3] + + +def test_promise_contained_promise(): + p = Promise() + all_promises = all([1, 2, Promise.resolve(None).then(lambda v: p)]) + assert not all_promises.is_fulfilled + p.do_resolve(3) + assert all_promises.get() == [1, 2, 3] diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_thread_safety.py b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_thread_safety.py new file mode 100644 index 0000000000000000000000000000000000000000..ed55a84ff70b71f4a3b465a9c2cc5c9646f7c210 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/test_thread_safety.py @@ -0,0 +1,115 @@ +from promise import Promise +from promise.dataloader import DataLoader +import threading + + + +def test_promise_thread_safety(): + """ + Promise tasks should never be executed in a different thread from the one they are scheduled from, + unless the ThreadPoolExecutor is used. + + Here we assert that the pending promise tasks on thread 1 are not executed on thread 2 as thread 2 + resolves its own promise tasks. 
+ """ + event_1 = threading.Event() + event_2 = threading.Event() + + assert_object = {'is_same_thread': True} + + def task_1(): + thread_name = threading.current_thread().getName() + + def then_1(value): + # Enqueue tasks to run later. + # This relies on the fact that `then` does not execute the function synchronously when called from + # within another `then` callback function. + promise = Promise.resolve(None).then(then_2) + assert promise.is_pending + event_1.set() # Unblock main thread + event_2.wait() # Wait for thread 2 + + def then_2(value): + assert_object['is_same_thread'] = (thread_name == threading.current_thread().getName()) + + promise = Promise.resolve(None).then(then_1) + + def task_2(): + promise = Promise.resolve(None).then(lambda v: None) + promise.get() # Drain task queue + event_2.set() # Unblock thread 1 + + thread_1 = threading.Thread(target=task_1) + thread_1.start() + + event_1.wait() # Wait for Thread 1 to enqueue promise tasks + + thread_2 = threading.Thread(target=task_2) + thread_2.start() + + for thread in (thread_1, thread_2): + thread.join() + + assert assert_object['is_same_thread'] + + +def test_dataloader_thread_safety(): + """ + Dataloader should only batch `load` calls that happened on the same thread. + + Here we assert that `load` calls on thread 2 are not batched on thread 1 as + thread 1 batches its own `load` calls. 
+ """ + def load_many(keys): + thead_name = threading.current_thread().getName() + return Promise.resolve([thead_name for key in keys]) + + thread_name_loader = DataLoader(load_many) + + event_1 = threading.Event() + event_2 = threading.Event() + event_3 = threading.Event() + + assert_object = { + 'is_same_thread_1': True, + 'is_same_thread_2': True, + } + + def task_1(): + @Promise.safe + def do(): + promise = thread_name_loader.load(1) + event_1.set() + event_2.wait() # Wait for thread 2 to call `load` + assert_object['is_same_thread_1'] = ( + promise.get() == threading.current_thread().getName() + ) + event_3.set() # Unblock thread 2 + + do().get() + + def task_2(): + @Promise.safe + def do(): + promise = thread_name_loader.load(2) + event_2.set() + event_3.wait() # Wait for thread 1 to run `dispatch_queue_batch` + assert_object['is_same_thread_2'] = ( + promise.get() == threading.current_thread().getName() + ) + + do().get() + + thread_1 = threading.Thread(target=task_1) + thread_1.start() + + event_1.wait() # Wait for thread 1 to call `load` + + thread_2 = threading.Thread(target=task_2) + thread_2.start() + + for thread in (thread_1, thread_2): + thread.join() + + assert assert_object['is_same_thread_1'] + assert assert_object['is_same_thread_2'] diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/utils.py b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..73dd1619c250ecb2b42680afffad6884ae97282f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/tests/utils.py @@ -0,0 +1,3 @@ +def assert_exception(exception, expected_exception_cls, expected_message): + assert isinstance(exception, expected_exception_cls) + assert str(exception) == expected_message diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/wandb_promise/__pycache__/__init__.cpython-313.pyc 
b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/wandb_promise/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f05bb82019cf6cf52e75b3c9ddf5c86f087ceddd Binary files /dev/null and b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/wandb_promise/__pycache__/__init__.cpython-313.pyc differ diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/wandb_promise/compat.py b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/wandb_promise/compat.py new file mode 100644 index 0000000000000000000000000000000000000000..21b091989858e0588a4dc5e1ad10baa677cd7e0b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/wandb_promise/compat.py @@ -0,0 +1,32 @@ +try: + from inspect import iscoroutine +except ImportError: + + def iscoroutine(obj): # type: ignore + return False + + +try: + from asyncio import Future, ensure_future # type: ignore +except ImportError: + + class Future: # type: ignore + def __init__(self): + raise Exception("You need asyncio for using Futures") + + def set_result(self): + raise Exception("You need asyncio for using Futures") + + def set_exception(self): + raise Exception("You need asyncio for using Futures") + + def ensure_future(): # type: ignore + raise Exception("ensure_future needs asyncio for executing") + + +try: + from .iterate_promise import iterate_promise +except (SyntaxError, ImportError): + + def iterate_promise(promise): # type: ignore + raise Exception('You need "yield from" syntax for iterate in a Promise.') diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/wandb_promise/dataloader.py b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/wandb_promise/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..cb1dd0dd472671749cd26c15bc2fa82cd6facb73 --- /dev/null +++ 
b/.venv/lib/python3.13/site-packages/wandb/vendor/promise-2.3.0/wandb_promise/dataloader.py @@ -0,0 +1,326 @@ +from collections import namedtuple +try: + from collections.abc import Iterable +except ImportError: + from collections import Iterable +from functools import partial +from threading import local + +from .promise import Promise, async_instance, get_default_scheduler + +if False: + from typing import ( + Any, + List, + Sized, + Callable, + Optional, + Tuple, + Union, + Iterator, + Hashable, + ) # flake8: noqa + + +def get_chunks(iterable_obj, chunk_size=1): + # type: (List[Loader], int) -> Iterator + chunk_size = max(1, chunk_size) + return ( + iterable_obj[i : i + chunk_size] + for i in range(0, len(iterable_obj), chunk_size) + ) + + +Loader = namedtuple("Loader", "key,resolve,reject") + + +class DataLoader(local): + + batch = True + max_batch_size = None # type: int + cache = True + + def __init__( + self, + batch_load_fn=None, # type: Callable + batch=None, # type: Optional[Any] + max_batch_size=None, # type: Optional[int] + cache=None, # type: Optional[Any] + get_cache_key=None, # type: Optional[Any] + cache_map=None, # type: Optional[Any] + scheduler=None, # type: Optional[Any] + ): + # type: (...) -> None + + if batch_load_fn is not None: + self.batch_load_fn = batch_load_fn + + if not callable(self.batch_load_fn): + raise TypeError( + ( + "DataLoader must be have a batch_load_fn which accepts " + "List and returns Promise>, but got: {}." + ).format(batch_load_fn) + ) + + if batch is not None: + self.batch = batch + + if max_batch_size is not None: + self.max_batch_size = max_batch_size + + if cache is not None: + self.cache = cache + + self.get_cache_key = get_cache_key or (lambda x: x) + self._promise_cache = cache_map or {} + self._queue = [] # type: List[Loader] + self._scheduler = scheduler + + def load(self, key=None): + # type: (Hashable) -> Promise + """ + Loads a key, returning a `Promise` for the value represented by that key. 
+ """ + if key is None: + raise TypeError( + ( + "The loader.load() function must be called with a value," + + "but got: {}." + ).format(key) + ) + + cache_key = self.get_cache_key(key) + + # If caching and there is a cache-hit, return cached Promise. + if self.cache: + cached_promise = self._promise_cache.get(cache_key) + if cached_promise: + return cached_promise + + # Otherwise, produce a new Promise for this value. + + promise = Promise(partial(self.do_resolve_reject, key)) # type: ignore + + # If caching, cache this promise. + if self.cache: + self._promise_cache[cache_key] = promise + + return promise + + def do_resolve_reject(self, key, resolve, reject): + # type: (Hashable, Callable, Callable) -> None + # Enqueue this Promise to be dispatched. + self._queue.append(Loader(key=key, resolve=resolve, reject=reject)) + # Determine if a dispatch of this queue should be scheduled. + # A single dispatch should be scheduled per queue at the time when the + # queue changes from "empty" to "full". + if len(self._queue) == 1: + if self.batch: + # If batching, schedule a task to dispatch the queue. + enqueue_post_promise_job(partial(dispatch_queue, self), self._scheduler) + else: + # Otherwise dispatch the (queue of one) immediately. + dispatch_queue(self) + + def load_many(self, keys): + # type: (Iterable[Hashable]) -> Promise + """ + Loads multiple keys, promising an array of values + + >>> a, b = await my_loader.load_many([ 'a', 'b' ]) + + This is equivalent to the more verbose: + + >>> a, b = await Promise.all([ + >>> my_loader.load('a'), + >>> my_loader.load('b') + >>> ]) + """ + if not isinstance(keys, Iterable): + raise TypeError( + ( + "The loader.loadMany() function must be called with Array " + + "but got: {}." + ).format(keys) + ) + + return Promise.all([self.load(key) for key in keys]) + + def clear(self, key): + # type: (Hashable) -> DataLoader + """ + Clears the value at `key` from the cache, if it exists. Returns itself for + method chaining. 
+ """ + cache_key = self.get_cache_key(key) + self._promise_cache.pop(cache_key, None) + return self + + def clear_all(self): + # type: () -> DataLoader + """ + Clears the entire cache. To be used when some event results in unknown + invalidations across this particular `DataLoader`. Returns itself for + method chaining. + """ + self._promise_cache.clear() + return self + + def prime(self, key, value): + # type: (Hashable, Any) -> DataLoader + """ + Adds the provied key and value to the cache. If the key already exists, no + change is made. Returns itself for method chaining. + """ + cache_key = self.get_cache_key(key) + + # Only add the key if it does not already exist. + if cache_key not in self._promise_cache: + # Cache a rejected promise if the value is an Error, in order to match + # the behavior of load(key). + if isinstance(value, Exception): + promise = Promise.reject(value) + else: + promise = Promise.resolve(value) + + self._promise_cache[cache_key] = promise + + return self + + +# Private: Enqueue a Job to be executed after all "PromiseJobs" Jobs. +# +# ES6 JavaScript uses the concepts Job and JobQueue to schedule work to occur +# after the current execution context has completed: +# http://www.ecma-international.org/ecma-262/6.0/#sec-jobs-and-job-queues +# +# Node.js uses the `process.nextTick` mechanism to implement the concept of a +# Job, maintaining a global FIFO JobQueue for all Jobs, which is flushed after +# the current call stack ends. +# +# When calling `then` on a Promise, it enqueues a Job on a specific +# "PromiseJobs" JobQueue which is flushed in Node as a single Job on the +# global JobQueue. +# +# DataLoader batches all loads which occur in a single frame of execution, but +# should include in the batch all loads which occur during the flushing of the +# "PromiseJobs" JobQueue after that same execution frame. 
+# +# In order to avoid the DataLoader dispatch Job occuring before "PromiseJobs", +# A Promise Job is created with the sole purpose of enqueuing a global Job, +# ensuring that it always occurs after "PromiseJobs" ends. + +# Private: cached resolved Promise instance +cache = local() + +def enqueue_post_promise_job(fn, scheduler): + # type: (Callable, Any) -> None + global cache + if not hasattr(cache, 'resolved_promise'): + cache.resolved_promise = Promise.resolve(None) + if not scheduler: + scheduler = get_default_scheduler() + + def on_promise_resolve(v): + # type: (Any) -> None + async_instance.invoke(fn, scheduler) + + cache.resolved_promise.then(on_promise_resolve) + + +def dispatch_queue(loader): + # type: (DataLoader) -> None + """ + Given the current state of a Loader instance, perform a batch load + from its current queue. + """ + # Take the current loader queue, replacing it with an empty queue. + queue = loader._queue + loader._queue = [] + + # If a maxBatchSize was provided and the queue is longer, then segment the + # queue into multiple batches, otherwise treat the queue as a single batch. + max_batch_size = loader.max_batch_size + + if max_batch_size and max_batch_size < len(queue): + chunks = get_chunks(queue, max_batch_size) + for chunk in chunks: + dispatch_queue_batch(loader, chunk) + else: + dispatch_queue_batch(loader, queue) + + +def dispatch_queue_batch(loader, queue): + # type: (DataLoader, List[Loader]) -> None + # Collect all keys to be loaded in this dispatch + keys = [l.key for l in queue] + + # Call the provided batch_load_fn for this loader with the loader queue's keys. 
+ try: + batch_promise = loader.batch_load_fn(keys) + except Exception as e: + failed_dispatch(loader, queue, e) + return None + + # Assert the expected response from batch_load_fn + if not batch_promise or not isinstance(batch_promise, Promise): + failed_dispatch( + loader, + queue, + TypeError( + ( + "DataLoader must be constructed with a function which accepts " + "Array and returns Promise>, but the function did " + "not return a Promise: {}." + ).format(batch_promise) + ), + ) + return None + + def batch_promise_resolved(values): + # type: (Sized) -> None + # Assert the expected resolution from batchLoadFn. + if not isinstance(values, Iterable): + raise TypeError( + ( + "DataLoader must be constructed with a function which accepts " + "Array and returns Promise>, but the function did " + "not return a Promise of an Array: {}." + ).format(values) + ) + + if len(values) != len(keys): + raise TypeError( + ( + "DataLoader must be constructed with a function which accepts " + "Array and returns Promise>, but the function did " + "not return a Promise of an Array of the same length as the Array " + "of keys." + "\n\nKeys:\n{}" + "\n\nValues:\n{}" + ).format(keys, values) + ) + + # Step through the values, resolving or rejecting each Promise in the + # loaded queue. + for l, value in zip(queue, values): + if isinstance(value, Exception): + l.reject(value) + else: + l.resolve(value) + + batch_promise.then(batch_promise_resolved).catch( + partial(failed_dispatch, loader, queue) + ) + + +def failed_dispatch(loader, queue, error): + # type: (DataLoader, Iterable[Loader], Exception) -> None + """ + Do not cache individual loads if the entire batch dispatch fails, + but still reject each request so they do not hang. 
+ """ + for l in queue: + loader.clear(l.key) + l.reject(error) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/pygments/__init__.py b/.venv/lib/python3.13/site-packages/wandb/vendor/pygments/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..394a85f2a4c1c7f84b47f2e17420c5a1ab761b00 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/pygments/__init__.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- +""" + Pygments + ~~~~~~~~ + + Pygments is a syntax highlighting package written in Python. + + It is a generic syntax highlighter for general use in all kinds of software + such as forum systems, wikis or other applications that need to prettify + source code. Highlights are: + + * a wide range of common languages and markup formats is supported + * special attention is paid to details, increasing quality by a fair amount + * support for new languages and formats are added easily + * a number of output formats, presently HTML, LaTeX, RTF, SVG, all image + formats that PIL supports, and ANSI sequences + * it is usable as a command-line tool and as a library + * ... and it highlights even Brainfuck! + + The `Pygments tip`_ is installable with ``easy_install Pygments==dev``. + + .. _Pygments tip: + http://bitbucket.org/birkenfeld/pygments-main/get/tip.zip#egg=Pygments-dev + + :copyright: Copyright 2006-2017 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" +import sys + +from pygments.util import StringIO, BytesIO + +__version__ = '2.2.0' +__docformat__ = 'restructuredtext' + +__all__ = ['lex', 'format', 'highlight'] + + +def lex(code, lexer): + """ + Lex ``code`` with ``lexer`` and return an iterable of tokens. 
+ """ + try: + return lexer.get_tokens(code) + except TypeError as err: + if (isinstance(err.args[0], str) and + ('unbound method get_tokens' in err.args[0] or + 'missing 1 required positional argument' in err.args[0])): + raise TypeError('lex() argument must be a lexer instance, ' + 'not a class') + raise + + +def format(tokens, formatter, outfile=None): # pylint: disable=redefined-builtin + """ + Format a tokenlist ``tokens`` with the formatter ``formatter``. + + If ``outfile`` is given and a valid file object (an object + with a ``write`` method), the result will be written to it, otherwise + it is returned as a string. + """ + try: + if not outfile: + realoutfile = getattr(formatter, 'encoding', None) and BytesIO() or StringIO() + formatter.format(tokens, realoutfile) + return realoutfile.getvalue() + else: + formatter.format(tokens, outfile) + except TypeError as err: + if (isinstance(err.args[0], str) and + ('unbound method format' in err.args[0] or + 'missing 1 required positional argument' in err.args[0])): + raise TypeError('format() argument must be a formatter instance, ' + 'not a class') + raise + + +def highlight(code, lexer, formatter, outfile=None): + """ + Lex ``code`` with ``lexer`` and format it with the formatter ``formatter``. + + If ``outfile`` is given and a valid file object (an object + with a ``write`` method), the result will be written to it, otherwise + it is returned as a string. 
+ """ + return format(lex(code, lexer), formatter, outfile) + + +if __name__ == '__main__': # pragma: no cover + from pygments.cmdline import main + sys.exit(main(sys.argv)) diff --git a/.venv/lib/python3.13/site-packages/wandb/vendor/pygments/cmdline.py b/.venv/lib/python3.13/site-packages/wandb/vendor/pygments/cmdline.py new file mode 100644 index 0000000000000000000000000000000000000000..5e1f39e2aa4c1ca05d2e8a5ea1700890460e24e1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/wandb/vendor/pygments/cmdline.py @@ -0,0 +1,568 @@ +# -*- coding: utf-8 -*- +""" + pygments.cmdline + ~~~~~~~~~~~~~~~~ + + Command line interface. + + :copyright: Copyright 2006-2017 by the Pygments team, see AUTHORS. + :license: BSD, see LICENSE for details. +""" + +from __future__ import print_function + +import sys +import getopt +from textwrap import dedent + +from pygments import __version__, highlight +from pygments.util import ClassNotFound, OptionError, docstring_headline, \ + guess_decode, guess_decode_from_terminal, terminal_encoding +from pygments.lexers import get_all_lexers, get_lexer_by_name, guess_lexer, \ + load_lexer_from_file, get_lexer_for_filename, find_lexer_class_for_filename +from pygments.lexers.special import TextLexer +from pygments.formatters.latex import LatexEmbeddedLexer, LatexFormatter +from pygments.formatters import get_all_formatters, get_formatter_by_name, \ + load_formatter_from_file, get_formatter_for_filename, find_formatter_class +from pygments.formatters.terminal import TerminalFormatter +from pygments.filters import get_all_filters, find_filter_class +from pygments.styles import get_all_styles, get_style_by_name + + +USAGE = """\ +Usage: %s [-l | -g] [-F [:]] [-f ] + [-O ] [-P ] [-s] [-v] [-x] [-o ] [] + + %s -S + + +

%(title)s

+ +''' + +DOC_HEADER_EXTERNALCSS = '''\ + + + + + %(title)s + + + + +

%(title)s

+ +''' + +DOC_FOOTER = '''\ + + +''' + + +class HtmlFormatter(Formatter): + r""" + Format tokens as HTML 4 ```` tags within a ``
`` tag, wrapped
+    in a ``
`` tag. The ``
``'s CSS class can be set by the `cssclass` + option. + + If the `linenos` option is set to ``"table"``, the ``
`` is
+    additionally wrapped inside a ```` which has one row and two
+    cells: one containing the line numbers and one containing the code.
+    Example:
+
+    .. sourcecode:: html
+
+        
+
+ + +
+
1
+            2
+
+
def foo(bar):
+              pass
+            
+
+ + (whitespace added to improve clarity). + + Wrapping can be disabled using the `nowrap` option. + + A list of lines can be specified using the `hl_lines` option to make these + lines highlighted (as of Pygments 0.11). + + With the `full` option, a complete HTML 4 document is output, including + the style definitions inside a ``